├── .codecov.yml ├── .devcontainer ├── Dockerfile └── devcontainer.json ├── .github └── FUNDING.yml ├── .gitignore ├── .pylintrc ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE.txt ├── MANIFEST.in ├── Makefile ├── README.md ├── azure-pipelines.yml ├── make.bat ├── nbconfig.py ├── notebooks ├── GP-Kernels.ipynb ├── Variational_API_Quickstart.ipynb ├── __init__.py ├── baseball.ipynb ├── basic-usage.ipynb ├── context_design │ ├── README.md │ ├── compare_speed_logps.ipynb │ ├── developer guide.ipynb │ ├── example_models.ipynb │ ├── hmc.ipynb │ ├── pymc3_samplers.ipynb │ └── tfp_samplers.ipynb ├── data │ ├── efron-morris-75-data.tsv │ ├── radon.csv │ └── rugby.csv ├── discrete_distributions_sampling.ipynb ├── eight_schools.ipynb ├── gaussian_process.ipynb ├── pymc4_design_guide.ipynb ├── radon_hierarchical.ipynb ├── rugby_analytics.ipynb └── utils.py ├── pymc4 ├── __init__.py ├── coroutine_model.py ├── distributions │ ├── __init__.py │ ├── batchstack.py │ ├── continuous.py │ ├── discrete.py │ ├── distribution.py │ ├── mixture.py │ ├── multivariate.py │ ├── state_functions.py │ ├── timeseries.py │ └── transforms.py ├── flow │ ├── __init__.py │ ├── executor.py │ ├── meta_executor.py │ ├── posterior_predictive_executor.py │ └── transformed_executor.py ├── forward_sampling.py ├── gp │ ├── __init__.py │ ├── _kernel.py │ ├── cov.py │ ├── gp.py │ ├── mean.py │ └── util.py ├── inference │ ├── __init__.py │ └── sampling.py ├── mcmc │ ├── __init__.py │ ├── samplers.py │ ├── tf_support.py │ └── utils.py ├── plots │ ├── __init__.py │ └── gp_plots.py ├── scopes.py ├── utils.py └── variational │ ├── __init__.py │ ├── approximations.py │ ├── updates.py │ └── util.py ├── pyproject.toml ├── requirements-dev.txt ├── requirements.txt ├── scripts ├── Dockerfile ├── README.md ├── container.bat ├── container.sh ├── create_testenv.sh └── lint.sh ├── setup.cfg ├── setup.py └── tests ├── __init__.py ├── conftest.py ├── pytest.ini ├── test_8schools.py ├── test_compound.py ├── test_discrete.py ├── test_distributions.py ├── test_executor.py ├── test_forward_sampling.py ├── test_gp.py ├── test_gp_cov.py ├── test_gp_mean.py ├── test_gp_util.py ├── test_mixture.py ├── test_plots.py ├── test_sampling.py ├── test_utils.py └── test_variational.py /.codecov.yml: -------------------------------------------------------------------------------- 1 | codecov: 2 | require_ci_to_pass: yes 3 | 4 | coverage: 5 | precision: 2 6 | round: down 7 | range: "70...100" 8 | 9 | status: 10 | project: yes 11 | patch: yes 12 | changes: no 13 | 14 | comment: 15 | layout: "reach, diff, flags, files" 16 | behavior: default 17 | require_changes: false # if true: only post the comment if coverage changes 18 | require_base: no # [yes :: must have a base report to post] 19 | require_head: yes # [yes :: must have a head report to post] 20 | branches: null # branch names that can post comment 21 | -------------------------------------------------------------------------------- /.devcontainer/Dockerfile: -------------------------------------------------------------------------------- 1 | #------------------------------------------------------------------------------------------------------------- 2 | # Copyright (c) Microsoft Corporation. All rights reserved. 3 | # Licensed under the MIT License. See https://go.microsoft.com/fwlink/?linkid=2090316 for license information. 
4 | # Original Dockerfile available at: https://github.com/microsoft/vscode-dev-containers/blob/master/containers/python-3-miniconda/.devcontainer/Dockerfile
5 | #-------------------------------------------------------------------------------------------------------------
6 |
7 | FROM continuumio/miniconda3
8 |
9 | # Avoid warnings by switching to noninteractive
10 | ENV DEBIAN_FRONTEND=noninteractive
11 |
12 | # This Dockerfile adds a non-root user with sudo access. Use the "remoteUser"
13 | # property in devcontainer.json to use it. On Linux, the container user's UID/GID
14 | # will be updated to match your local UID/GID (when using the dockerFile property).
15 | # See https://aka.ms/vscode-remote/containers/non-root-user for details.
16 | ARG USERNAME=vscode
17 | ARG USER_UID=1000
18 | ARG USER_GID=$USER_UID
19 |
20 | # Copy requirements.txt and requirements-dev.txt
21 | COPY requirements.txt requirements-dev.txt /tmp/conda-tmp/
22 |
23 | # Configure apt and install packages
24 | RUN apt-get update \
25 |     && apt-get -y install --no-install-recommends apt-utils dialog 2>&1 \
26 |     #
27 |     # Verify git, process tools, lsb-release (common in install instructions for CLIs) installed
28 |     && apt-get -y install git openssh-client less iproute2 procps lsb-release \
29 |     #
30 |     # Install pylint
31 |     && /opt/conda/bin/pip install pylint \
32 |     #
33 |     # Update environment.
34 |     && /opt/conda/bin/pip install -r /tmp/conda-tmp/requirements.txt \
35 |     && /opt/conda/bin/pip install -r /tmp/conda-tmp/requirements-dev.txt \
36 |     #
37 |     # Create a non-root user to use if preferred - see https://aka.ms/vscode-remote/containers/non-root-user.
38 |     && groupadd --gid $USER_GID $USERNAME \
39 |     && useradd -s /bin/bash --uid $USER_UID --gid $USER_GID -m $USERNAME \
40 |     # [Optional] Add sudo support for the non-root user
41 |     && apt-get install -y sudo \
42 |     && echo $USERNAME ALL=\(root\) NOPASSWD:ALL > /etc/sudoers.d/$USERNAME \
43 |     && chmod 0440 /etc/sudoers.d/$USERNAME \
44 |     # [Additional Customization]
45 |     && apt-get install -y nano vim emacs \
46 |     # Clean up
47 |     && apt-get autoremove -y \
48 |     && apt-get clean -y \
49 |     && rm -rf /var/lib/apt/lists/*
50 |
51 | # Switch back to dialog for any ad-hoc use of apt-get
52 | ENV DEBIAN_FRONTEND=dialog
53 |
-------------------------------------------------------------------------------- /.devcontainer/devcontainer.json: --------------------------------------------------------------------------------
1 | // For format details, see https://aka.ms/vscode-remote/devcontainer.json or this file's README at:
2 | // https://github.com/microsoft/vscode-dev-containers/tree/v0.117.1/containers/python-3-miniconda
3 | {
4 |     "name": "PyMC4 Development Container",
5 |     "context": "..",
6 |     "image": "registry.hub.docker.com/pymc/pymc4:devcontainer",
7 |     // Set *default* container specific settings.json values on container create.
8 |     "settings": {
9 |         "terminal.integrated.shell.linux": "/bin/bash",
10 |         "python.pythonPath": "/opt/conda/bin/python",
11 |         "python.linting.enabled": true,
12 |         "python.linting.pylintEnabled": true,
13 |         "python.linting.pylintPath": "/opt/conda/bin/pylint"
14 |     },
15 |     // Add the IDs of extensions you want installed when the container is created.
16 |     "extensions": [
17 |         "ms-python.python"
18 |     ],
19 |     // Use 'forwardPorts' to make a list of ports inside the container available locally.
20 |     // "forwardPorts": [],
21 |     // Use 'postCreateCommand' to run commands after the container is created.
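    // Note: the command below installs PyMC4 in editable ("develop") mode, so local source edits take effect without reinstalling.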
22 | "postCreateCommand": "python setup.py develop" 23 | // Uncomment to connect as a non-root user. See https://aka.ms/vscode-remote/containers/non-root. 24 | // "remoteUser": "vscode" 25 | } 26 | -------------------------------------------------------------------------------- /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] 4 | patreon: # Replace with a single Patreon username 5 | open_collective: # Replace with a single Open Collective username 6 | ko_fi: # Replace with a single Ko-fi username 7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel 8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry 9 | liberapay: # Replace with a single Liberapay username 10 | issuehunt: # Replace with a single IssueHunt username 11 | otechie: # Replace with a single Otechie username 12 | custom: ['https://numfocus.org/donate'] # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] 13 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Custom 2 | environment.yml 3 | testing-report.html 4 | pip-wheel-metadata/* 5 | 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # Unit test / coverage reports 34 | htmlcov/ 35 | .tox/ 36 | .coverage 37 | .coverage.* 38 | .cache 39 | nosetests.xml 40 | coverage.xml 41 | *.cover 42 | .hypothesis/ 43 | .pytest_cache/ 44 | 45 | # Sphinx documentation 46 | docs/_build/ 47 | 48 | # Jupyter Notebook 49 | .ipynb_checkpoints 50 | 51 | # pyenv 52 | .python-version 53 | 54 | # Environments 55 | .env 56 | .venv 57 | env/ 58 | pymc4-env/ 59 | venv/ 60 | pymc4-venv/ 61 | ENV/ 62 | env.bak/ 63 | venv.bak/ 64 | 65 | # mkdocs documentation 66 | /site 67 | 68 | # mypy 69 | .mypy_cache/ 70 | 71 | # Merge tool 72 | *.orig 73 | 74 | # VSCode 75 | .vscode/ 76 | 77 | # IntelliJ IDE 78 | .idea 79 | *.iml 80 | 81 | # Vim 82 | *.swp 83 | 84 | # OS generated files 85 | .DS_Store 86 | .DS_Store? 87 | ._* 88 | .Spotlight-V100 89 | .Trashes 90 | ehthumbs.db 91 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # PyMC4 Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to making participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 
11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies both within project spaces and in public spaces 49 | when an individual is representing the project or its community. Examples of 50 | representing a project or community include using an official project e-mail 51 | address, posting via an official social media account, or acting as an appointed 52 | representative at an online or offline event. Representation of a project may be 53 | further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting PyMC4 developer Christopher Fonnesbeck via email 59 | (chris.fonnesbeck@vanderbilt.edu) or phone (615-955-0380). Alternatively, you 60 | may also contact NumFOCUS Executive Director Leah Silen (512-222-5449), as PyMC4 61 | is a member of NumFOCUS and subscribes to their code of conduct as a 62 | precondition for continued membership. All complaints will be reviewed and 63 | investigated and will result in a response that is deemed necessary and 64 | appropriate to the circumstances. The project team is obligated to maintain 65 | confidentiality with regard to the reporter of an incident. Further details of 66 | specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 
71 |
72 | ## Attribution
73 |
74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
76 |
77 | [homepage]: https://www.contributor-covenant.org
78 |
-------------------------------------------------------------------------------- /CONTRIBUTING.md: --------------------------------------------------------------------------------
1 | # Contributing to PyMC4
2 |
3 | As a scientific community-driven software project, PyMC4 welcomes contributions
4 | from users. This document describes how users can contribute to the PyMC4
5 | project, and what workflow to follow to contribute as quickly and seamlessly as
6 | possible.
7 |
8 | There are four main ways of contributing to PyMC4 (in descending order of
9 | difficulty or scope):
10 |
11 | 1. **Adding new or improved functionality** to the codebase: these contributions
12 | directly extend PyMC4's functionality.
13 | 2. **Fixing outstanding issues or bugs** with the codebase: these range from
14 | low-level software bugs to high-level design problems.
15 | 3. **Contributing to the documentation or examples**: improving the
16 | documentation is _just as important_ as improving the codebase itself.
17 | 4. **Submitting bug reports or feature requests** via the GitHub issue tracker:
18 | even something as simple as leaving a "thumbs up" reaction to issues that are
19 | relevant to you!
20 |
21 | The first three types of contributions involve [opening a pull
22 | request](#opening-a-pull-request), whereas the fourth involves [creating an
23 | issue](#creating-an-issue).
24 |
25 | Finally, it also helps us if you spread the word: reference the project from
26 | your blog and articles, link to it from your website, or simply star it on
27 | GitHub to say "I use it"!
28 |
29 | ## Creating an Issue
30 |
31 | > Creating your first GitHub issue? Check out [the official GitHub
32 | > documentation](https://help.github.com/articles/creating-an-issue/) on how to
33 | > do that!
34 |
35 | We appreciate being notified of problems with the existing PyMC4 codebase. We
36 | prefer that issues be filed on the [GitHub issue
37 | tracker](https://github.com/pymc-devs/pymc4/issues), rather than on social media
38 | or by direct email to the developers.
39 |
40 | Please check that your issue is not currently being addressed by other issues or
41 | pull requests by using the GitHub search tool.
42 |
43 | ## Opening a Pull Request
44 |
45 | While reporting issues is valuable, we welcome and encourage users to submit
46 | patches for new or existing issues via pull requests (a.k.a. "PRs"). This is
47 | especially the case for simple fixes, such as fixing typos or tweaking
48 | documentation, which do not require a heavy investment of time and attention.
49 |
50 | The preferred workflow for contributing to PyMC4 is to fork the [GitHub
51 | repository](https://github.com/pymc-devs/pymc4/), clone it to your local
52 | machine, and develop on a feature branch.
53 |
54 | ### Step-by-step instructions
55 |
56 | 1. Fork the [project repository](https://github.com/pymc-devs/pymc4/) by
57 | clicking on the `Fork` button near the top right of the main repository page.
58 | This creates a copy of the code under your GitHub user account.
59 |
60 | 2. Clone your fork of the PyMC4 repo from your GitHub account to your local
61 | computer, and add the base repository as an upstream remote.
62 |
63 | ```bash
64 | $ git clone git@github.com:<your-github-username>/pymc4.git
65 | $ cd pymc4
66 | $ git remote add upstream git@github.com:pymc-devs/pymc4.git
67 | ```
68 |
69 | 3. Check out a `feature` branch to contain your edits.
70 |
71 | ```bash
72 | $ git checkout -b my-feature
73 | ```
74 |
75 | Always create a new `feature` branch. It's best practice to never work on the
76 | `master` branch of any repository.
77 |
78 | 4. To set up your development environment, you can use the `make` command line
79 | utility. Depending on whether you want to develop using a [Python virtual
80 | environment](https://docs.python.org/3/library/venv.html), [conda
81 | environment](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html)
82 | or [Docker
83 | image](https://docs.docker.com/develop/develop-images/image_management/), you
84 | can run one of `make venv`, `make conda` or `make docker`, respectively, and
85 | follow the resulting instructions (in blue) after the setup is finished.
86 |
87 | 5. Develop the feature on your feature branch. This is the fun part!
88 |
89 | 6. Once you are done developing, run `make black` and `make check` from the root
90 | `pymc4/` directory to blackify, lint and test the codebase. If you like, you
91 | can run `make lint` and `make test` to lint and test separately. Work through
92 | and fix any lint errors or failing tests. Don't hesitate to reach out to us
93 | through the [GitHub issue tracker](https://github.com/pymc-devs/pymc4/issues)
94 | if you run into problems!
95 |
96 | 7. Add changed files using `git add` and then `git commit`:
97 |
98 | ```bash
99 | $ git add your_modified_file.py
100 | $ git commit
101 | ```
102 |
103 | to record your changes locally.
104 |
105 | Then push the changes to your GitHub account with:
106 |
107 | ```bash
108 | $ git push -u origin my-feature
109 | ```
110 |
111 | 8. Go to the GitHub web page of your fork of the PyMC4 repo. Click the `Pull
112 | request` button to send your changes to the project's maintainers for review.
113 | This will notify the PyMC4 developers.
114 |
115 | ## Code of Conduct
116 |
117 | The PyMC4 project abides by the [Contributor
118 | Covenant](https://www.contributor-covenant.org/). You can find our code of
119 | conduct
120 | [here](https://github.com/pymc-devs/pymc4/blob/master/CODE_OF_CONDUCT.md).
121 |
-------------------------------------------------------------------------------- /MANIFEST.in: --------------------------------------------------------------------------------
1 | exclude *.md
2 | exclude *.toml
3 | exclude *.yml
4 | exclude Makefile
5 | recursive-exclude notebooks *.ipynb
6 | recursive-exclude scripts *
7 | recursive-exclude tests *
8 | recursive-exclude notebooks/context_design *
9 |
10 |
11 | recursive-include notebooks *.tsv
12 | recursive-include notebooks *.csv
13 | include *.txt
14 | include *.md
15 |
-------------------------------------------------------------------------------- /Makefile: --------------------------------------------------------------------------------
1 | .PHONY: help venv conda docker docstyle format style types black test lint check notebooks
2 | .DEFAULT_GOAL = help
3 |
4 | PYTHON = python
5 | PIP = pip
6 | CONDA = conda
7 | SHELL = bash
8 |
9 | help:
10 | 	@printf "Usage:\n"
11 | 	@grep -E '^[a-zA-Z_-]+:.*?# .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?# "}; {printf "\033[1;34mmake %-10s\033[0m%s\n", $$1, $$2}'
12 |
13 | conda: # Set up a conda environment for development.
14 | @printf "Creating conda environment...\n" 15 | ${CONDA} create --yes --name pymc4-env python=3.6 16 | ( \ 17 | ${CONDA} activate pymc4-env; \ 18 | ${PIP} install -U pip; \ 19 | ${PIP} install -r requirements.txt; \ 20 | ${PIP} install -r requirements-dev.txt; \ 21 | ${CONDA} deactivate; \ 22 | ) 23 | @printf "\n\nConda environment created! \033[1;34mRun \`conda activate pymc4-env\` to activate it.\033[0m\n\n\n" 24 | 25 | venv: # Set up a Python virtual environment for development. 26 | @printf "Creating Python virtual environment...\n" 27 | rm -rf pymc4-venv 28 | ${PYTHON} -m venv pymc4-venv 29 | ( \ 30 | source pymc4-venv/bin/activate; \ 31 | ${PIP} install -U pip; \ 32 | ${PIP} install -r requirements.txt; \ 33 | ${PIP} install -r requirements-dev.txt; \ 34 | deactivate; \ 35 | ) 36 | @printf "\n\nVirtual environment created! \033[1;34mRun \`source pymc4-venv/bin/activate\` to activate it.\033[0m\n\n\n" 37 | 38 | docker: # Set up a Docker image for development. 39 | @printf "Creating Docker image...\n" 40 | ${SHELL} ./scripts/container.sh --build 41 | 42 | docstyle: 43 | @printf "Checking documentation with pydocstyle...\n" 44 | pydocstyle pymc4/ 45 | @printf "\033[1;34mPydocstyle passes!\033[0m\n\n" 46 | 47 | format: 48 | @printf "Checking code style with black...\n" 49 | black --check --diff pymc4 tests 50 | @printf "\033[1;34mBlack passes!\033[0m\n\n" 51 | 52 | style: 53 | @printf "Checking code style with pylint...\n" 54 | pylint pymc4/ 55 | @printf "\033[1;34mPylint passes!\033[0m\n\n" 56 | 57 | types: 58 | @printf "Checking code type signatures with mypy...\n" 59 | python -m mypy --ignore-missing-imports pymc4/ 60 | @printf "\033[1;34mMypy passes!\033[0m\n\n" 61 | 62 | black: # Format code in-place using black. 63 | black pymc4/ tests/ 64 | 65 | notebooks: notebooks/* 66 | jupyter nbconvert --config nbconfig.py --execute --ExecutePreprocessor.kernel_name="pymc4-dev" --ExecutePreprocessor.timeout=1200 --to html 67 | rm notebooks/*.html 68 | 69 | test: # Test code using pytest. 70 | pytest -v pymc4 tests --doctest-modules --html=testing-report.html --self-contained-html 71 | 72 | lint: docstyle format style types # Lint code using pydocstyle, black, pylint and mypy. 73 | 74 | check: lint test # Both lint and test code. Runs `make lint` followed by `make test`. 75 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # NOTICE: Official development of this project has ceased, and it is no longer intended to become the next major version of PyMC. Ongoing development will continue on the [main PyMC repository](https://github.com/pymc-devs/pymc). 2 | 3 | See [the announcement](https://pymc-devs.medium.com/the-future-of-pymc3-or-theano-is-dead-long-live-theano-d8005f8a0e9b) for more details on the future of PyMC and Theano. 4 | 5 | [![Build Status](https://dev.azure.com/pymc-devs/pymc4/_apis/build/status/pymc-devs.pymc4?branchName=master)](https://dev.azure.com/pymc-devs/pymc4/_build/latest?definitionId=1&branchName=master) 6 | [![Coverage Status](https://codecov.io/gh/pymc-devs/pymc4/branch/master/graph/badge.svg)](https://codecov.io/gh/pymc-devs/pymc4) 7 | 8 | High-level interface to TensorFlow Probability. Do not use for anything serious. 9 | 10 | What works? 
11 |
12 | * Build most models you could build with PyMC3
13 | * Sample using NUTS, all in TF, fully vectorized across chains (multiple chains basically become free)
14 | * Automatic transforms of model to the real line
15 | * Prior and posterior predictive sampling
16 | * Deterministic variables
17 | * Trace that can be passed to ArviZ
18 |
19 | However, expect things to break or change without warning.
20 |
21 | See here for an example: https://github.com/pymc-devs/pymc4/blob/master/notebooks/radon_hierarchical.ipynb
22 | See here for the design document: https://github.com/pymc-devs/pymc4/blob/master/notebooks/pymc4_design_guide.ipynb
23 |
24 | ## Develop
25 |
26 | One easy way of developing on PyMC4 is to take advantage of the development containers!
27 | Using pre-built development environments allows you to develop on PyMC4 without needing to set up locally.
28 |
29 | To use the dev containers, you will need to have Docker and VSCode running locally on your machine,
30 | and will need the VSCode Remote extension (`ms-vscode-remote.vscode-remote-extensionpack`).
31 |
32 | Once you have done that, to develop on PyMC4, on GitHub:
33 |
34 | 1. Make a fork of the repository
35 | 2. Create a new branch inside your fork
36 | 3. Copy the branch URL
37 |
38 | Now, in VSCode:
39 |
40 | 1. In the command palette, search for "Remote-Containers: Open Repository in Container...".
41 | 2. Paste in the branch URL
42 | 3. If prompted, create it in a "Unique Volume".
43 |
44 | Happy hacking away!
45 | Because the repo will be cloned into an ephemeral volume,
46 | **don't forget to commit your changes and push them to your branch!**
47 | Then follow the usual pull request workflow back into PyMC4.
48 |
49 | We hope you enjoy the time saved on setting up your development environment!
50 |
-------------------------------------------------------------------------------- /azure-pipelines.yml: --------------------------------------------------------------------------------
1 | trigger:
2 | - master
3 |
4 | pool:
5 |   vmImage: "ubuntu-latest"
6 | strategy:
7 |   matrix:
8 |     Python36:
9 |       python.version: "3.6"
10 |     Python37:
11 |       python.version: "3.7"
12 |
13 | steps:
14 | - task: UsePythonVersion@0
15 |   inputs:
16 |     versionSpec: "$(python.version)"
17 |   displayName: "Use Python $(python.version)"
18 |
19 | - script: |
20 |     python -m pip install --upgrade pip
21 |     python -m pip install .
22 |     python -m pip install -r requirements-dev.txt
23 |     python -m pip freeze
24 |   displayName: "Install dependencies"
25 |
26 | # - script: |
27 | #     python -m pydocstyle --convention=numpy pymc4/ tests/
28 | #   displayName: "pydocstyle"
29 |
30 | - script: |
31 |     python -m black --check pymc4/ tests/
32 |   displayName: "black"
33 |
34 | - script: |
35 |     python -m mypy --ignore-missing-imports pymc4/
36 |   displayName: "mypy"
37 |
38 | - script: |
39 |     python -m pylint pymc4/ tests/
40 |   displayName: "pylint"
41 |
42 | - script: |
43 |     python -m pip install pytest-azurepipelines
44 |     python -m pytest -xv --cov pymc4 --junitxml=junit/test-results.xml --cov-report xml --cov-report term --cov-report html .
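  # `-x` stops pytest at the first failing test; the XML coverage report written here is what the codecov upload step further below consumes.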
45 |   displayName: "pytest"
46 |
47 | - script: |
48 |     python -m pip install ipykernel nbconvert
49 |     python -m ipykernel install --name pymc4-dev --user
50 |     make notebooks
51 |   displayName: "notebooks"
52 |
53 | - script: |
54 |     bash <(curl -s https://codecov.io/bash) -n "$(NAME)" -C $BUILD_SOURCEVERSION
55 |   displayName: "Publish test results to codecov for Python $(python.version)"
56 |   condition: succeededOrFailed()
57 |
-------------------------------------------------------------------------------- /make.bat: --------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | REM Command file to centralize test tasks
4 |
5 | SET ROOT_DIR="%~dp0"
6 | CALL :joinpath %ROOT_DIR% pymc4
7 | SET PACKAGE_DIR=%RESULT%
8 | CALL :joinpath %ROOT_DIR% tests
9 | SET TESTS_DIR=%RESULT%
10 | SET PYTHON=python
11 | SET PIP=pip
12 | SET CONDA=conda
13 |
14 | if "%1" == "" (
15 |     call :help
16 |     exit /B 0
17 | )
18 |
19 | call :%1
20 | EXIT /B %ERRORLEVEL%
21 |
22 |
23 | :help
24 | echo.Usage:
25 | echo.    make help: display help
26 | echo.    make venv: create python virtual environment
27 | echo.    make conda: create conda environment
28 | echo.    make docker: create docker image from application
29 | echo.    make docstyle: check package documentation with pydocstyle
30 | echo.    make format: check package formatting with black
31 | echo.    make style: check package style with pylint
32 | echo.    make types: check type hints with mypy
33 | echo.    make black: apply black formatting to the entire package and tests
34 | echo.    make test: run tests with pytest
35 | echo.    make lint: run all docstyle, format, style and types checks
36 | echo.    make check: run lint, test
37 | echo.    make notebooks: execute jupyter notebooks
38 | EXIT /B 0
39 |
40 |
41 | :joinpath
42 | set Path1=%~1
43 | set Path2=%~2
44 | if {%Path1:~-1,1%}=={\} (set Result="%Path1%%Path2%") else (set Result="%Path1%\%Path2%")
45 | EXIT /B %ERRORLEVEL%
46 |
47 |
48 | :conda
49 | echo.Creating conda environment...
50 | %CONDA% create --yes --name pymc4-env python=3.6
51 | %CONDA% activate pymc4-env
52 | %PIP% install -U pip
53 | %PIP% install -r requirements.txt
54 | %PIP% install -r requirements-dev.txt
55 | %CONDA% deactivate
56 | if %ERRORLEVEL%==0 (
57 |     echo.Conda environment created! Run conda activate pymc4-env to activate it.
58 | ) else (
59 |     echo.Failed to create conda environment.
60 | )
61 | EXIT /B %ERRORLEVEL%
62 |
63 |
64 | :venv
65 | echo.Creating Python virtual environment...
66 | rmdir /s /q pymc4-venv
67 | %PYTHON% -m venv pymc4-venv
68 | pymc4-venv\Scripts\activate
69 | %PIP% install -U pip
70 | %PIP% install -r requirements.txt
71 | %PIP% install -r requirements-dev.txt
72 | deactivate
73 | if %ERRORLEVEL%==0 (
74 |     echo.Virtual environment created! Run pymc4-venv\Scripts\activate to activate it.
75 | ) else (
76 |     echo.Failed to create virtual environment.
77 | )
78 | EXIT /B %ERRORLEVEL%
79 |
80 |
81 | :docker
82 | echo.Creating Docker image...
83 | scripts\container --build
84 | if %ERRORLEVEL%==0 (
85 |     echo.Successfully built docker image.
86 | ) else (
87 |     echo.Failed to build docker image.
88 | )
89 | EXIT /B %ERRORLEVEL%
90 |
91 |
92 | :docstyle
93 | echo.Checking documentation with pydocstyle...
94 | %PYTHON% -m pydocstyle %PACKAGE_DIR%
95 | if %ERRORLEVEL%==0 (
96 |     echo.Pydocstyle passes!
97 | ) else (
98 |     echo.Pydocstyle failed!
99 | )
100 | EXIT /B %ERRORLEVEL%
101 |
102 |
103 | :format
104 | echo.Checking code format with black...
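REM `--check` makes black exit non-zero when files would be reformatted; `--diff` prints the proposed changes without modifying any files.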
105 | %PYTHON% -m black --check --diff %PACKAGE_DIR% %TESTS_DIR%
106 | if %ERRORLEVEL%==0 (
107 |     echo.Black passes!
108 | ) else (
109 |     echo.Black failed!
110 | )
111 | EXIT /B %ERRORLEVEL%
112 |
113 |
114 | :style
115 | echo.Checking style with pylint...
116 | %PYTHON% -m pylint %PACKAGE_DIR%
117 | if %ERRORLEVEL%==0 (
118 |     echo.Pylint passes!
119 | ) else (
120 |     echo.Pylint failed!
121 | )
122 | EXIT /B %ERRORLEVEL%
123 |
124 |
125 | :types
126 | echo.Checking type hints with mypy...
127 | %PYTHON% -m mypy --ignore-missing-imports %PACKAGE_DIR%
128 | if %ERRORLEVEL%==0 (
129 |     echo.Mypy passes!
130 | ) else (
131 |     echo.Mypy failed!
132 | )
133 | EXIT /B %ERRORLEVEL%
134 |
135 |
136 | :black
137 | %PYTHON% -m black %PACKAGE_DIR% %TESTS_DIR%
138 | EXIT /B %ERRORLEVEL%
139 |
140 |
141 | :notebooks
142 | jupyter nbconvert --config nbconfig.py --execute --ExecutePreprocessor.kernel_name="pymc4-dev" --ExecutePreprocessor.timeout=1200 --to html
143 | del %ROOT_DIR%notebooks\*.html
144 | EXIT /B %ERRORLEVEL%
145 |
146 |
147 | :test
148 | %PYTHON% -m pytest -v %PACKAGE_DIR% %TESTS_DIR% --doctest-modules --html=testing-report.html --self-contained-html
149 | EXIT /B %ERRORLEVEL%
150 |
151 |
152 | :lint
153 | call :docstyle && call :format && call :style && call :types
154 | EXIT /B %ERRORLEVEL%
155 |
156 |
157 | :check
158 | call :lint && call :test
159 | EXIT /B %ERRORLEVEL%
160 |
-------------------------------------------------------------------------------- /nbconfig.py: --------------------------------------------------------------------------------
1 | """We have notebooks placed under CI execution
2 | as an integration test of sorts for the notebook docs.
3 |
4 | Avoid placing long-running notebooks here,
5 | as they will clog up the CI.
6 | Instead, prioritize the smaller notebooks that newcomers might encounter at first.
7 | That way, we can ensure that any breaking API changes
8 | that may affect how newcomers interact with the library
9 | can be caught as soon as possible.
10 | """
11 |
12 | c = get_config()
13 |
14 | c.NbConvertApp.notebooks = [
15 |     "notebooks/baseball.ipynb",
16 |     "notebooks/basic-usage.ipynb",
17 |     "notebooks/rugby_analytics.ipynb",
18 |     # will reinstate in a later PR
19 |     # "notebooks/radon_hierarchical.ipynb",
20 | ]
21 |
-------------------------------------------------------------------------------- /notebooks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pymc-devs/pymc4/10b5854219786dac3337ad2127683c62d891e1ea/notebooks/__init__.py -------------------------------------------------------------------------------- /notebooks/context_design/README.md: --------------------------------------------------------------------------------
1 | # Context design
2 | These notebooks served as documentation and demonstrations of an earlier
3 | design of PyMC4. While they won't run, they are being left here for reference.
4 | 5 | -------------------------------------------------------------------------------- /notebooks/data/efron-morris-75-data.tsv: -------------------------------------------------------------------------------- 1 | FirstName LastName At-Bats Hits BattingAverage RemainingAt-Bats RemainingAverage SeasonAt-Bats SeasonHits SeasonAverage 2 | Roberto Clemente 45 18 0.4 367 0.346 412 145 0.352 3 | Frank Robinson 45 17 0.378 426 0.2981 471 144 0.306 4 | Frank Howard 45 16 0.356 521 0.2764 566 160 0.283 5 | Jay Johnstone 45 15 0.333 275 0.2218 320 76 0.238 6 | Ken Berry 45 14 0.311 418 0.2727 463 128 0.276 7 | Jim Spencer 45 14 0.311 466 0.2704 511 140 0.274 8 | Don Kessinger 45 13 0.289 586 0.2645 631 168 0.266 9 | Luis Alvarado 45 12 0.267 138 0.2101 183 41 0.224 10 | Ron Santo 45 11 0.244 510 0.2686 555 148 0.267 11 | Ron Swaboda 45 11 0.244 200 0.23 245 57 0.233 12 | Rico Petrocelli 45 10 0.222 538 0.2639 583 152 0.261 13 | Ellie Rodriguez 45 10 0.222 186 0.2258 231 52 0.225 14 | George Scott 45 10 0.222 435 0.3034 480 142 0.296 15 | Del Unser 45 10 0.222 277 0.2635 322 83 0.258 16 | Billy Williams 45 10 0.222 591 0.3299 636 205 0.251 17 | Bert Campaneris 45 9 0.2 558 0.2849 603 168 0.279 18 | Thurman Munson 45 8 0.178 408 0.3162 453 137 0.302 19 | Max Alvis 45 7 0.156 70 0.2 115 21 0.183 20 | -------------------------------------------------------------------------------- /notebooks/data/rugby.csv: -------------------------------------------------------------------------------- 1 | ,home_team,away_team,home_score,away_score,year 2 | 0,Wales,Italy,23,15,2014 3 | 1,France,England,26,24,2014 4 | 2,Ireland,Scotland,28,6,2014 5 | 3,Ireland,Wales,26,3,2014 6 | 4,Scotland,England,0,20,2014 7 | 5,France,Italy,30,10,2014 8 | 6,Wales,France,27,6,2014 9 | 7,Italy,Scotland,20,21,2014 10 | 8,England,Ireland,13,10,2014 11 | 9,Ireland,Italy,46,7,2014 12 | 10,Scotland,France,17,19,2014 13 | 11,England,Wales,29,18,2014 14 | 12,Italy,England,11,52,2014 15 | 13,Wales,Scotland,51,3,2014 16 | 14,France,Ireland,20,22,2014 17 | 15,Wales,England,16,21,2015 18 | 16,Italy,Ireland,3,26,2015 19 | 17,France,Scotland,15,8,2015 20 | 18,England,Italy,47,17,2015 21 | 19,Ireland,France,18,11,2015 22 | 20,Scotland,Wales,23,26,2015 23 | 21,Scotland,Italy,19,22,2015 24 | 22,France,Wales,13,20,2015 25 | 23,Ireland,England,19,9,2015 26 | 24,Wales,Ireland,23,16,2015 27 | 25,England,Scotland,25,13,2015 28 | 26,Italy,France,0,29,2015 29 | 27,Italy,Wales,20,61,2015 30 | 28,Scotland,Ireland,10,40,2015 31 | 29,England,France,55,35,2015 32 | 30,France,Italy,23,21,2016 33 | 31,Scotland,England,9,15,2016 34 | 32,Ireland,Wales,16,16,2016 35 | 33,France,Ireland,10,9,2016 36 | 34,Wales,Scotland,27,23,2016 37 | 35,Italy,England,9,40,2016 38 | 36,Wales,France,19,10,2016 39 | 37,Italy,Scotland,20,36,2016 40 | 38,England,Ireland,21,10,2016 41 | 39,Ireland,Italy,58,15,2016 42 | 40,England,Wales,25,21,2016 43 | 41,Scotland,France,29,18,2016 44 | 42,Wales,Italy,67,14,2016 45 | 43,Ireland,Scotland,35,25,2016 46 | 44,France,England,21,31,2016 47 | 45,Scotland,Ireland,27,22,2017 48 | 46,England,France,19,16,2017 49 | 47,Italy,Wales,7,33,2017 50 | 48,Italy,Ireland,10,63,2017 51 | 49,Wales,England,16,21,2017 52 | 50,France,Scotland,22,16,2017 53 | 51,Scotland,Wales,29,13,2017 54 | 52,Ireland,France,19,9,2017 55 | 53,England,Italy,36,15,2017 56 | 54,Wales,Ireland,22,9,2017 57 | 55,Italy,France,18,40,2017 58 | 56,England,Scotland,61,21,2017 59 | 57,Scotland,Italy,29,0,2017 60 | 58,France,Wales,20,18,2017 61 | 59,Ireland,England,13,9,2017 62 | 
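The rugby results above drive the `rugby_analytics.ipynb` notebook. A minimal sketch of how the file might be loaded and index-encoded for a hierarchical attack/defence model (the pandas loading code is an illustrative assumption, not repository code):

```python
import pandas as pd

# The first, unnamed column in the CSV above is the match index.
df = pd.read_csv("notebooks/data/rugby.csv", index_col=0)

# Map team names to integer indices so a hierarchical model can hold
# one attack/defence parameter per team.
teams = sorted(set(df.home_team) | set(df.away_team))
team_idx = {team: i for i, team in enumerate(teams)}
home_team = df.home_team.map(team_idx).to_numpy()
away_team = df.away_team.map(team_idx).to_numpy()
print(f"{len(teams)} teams, {len(df)} matches")  # 6 teams, 60 matches
```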
-------------------------------------------------------------------------------- /notebooks/utils.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | 4 | 5 | def plot_samples(x, batched_samples, labels, names, ylim=None): 6 | if not isinstance(batched_samples, np.ndarray): 7 | batched_samples = np.asarray(batched_samples) 8 | n_samples = batched_samples.shape[0] 9 | if ylim is not None: 10 | ymin, ymax = ylim 11 | else: 12 | ymin, ymax = batched_samples.min() - 0.2, batched_samples.max() + 0.2 13 | fig, ax = plt.subplots(n_samples, 1, figsize=(14, n_samples * 3)) 14 | if isinstance(labels, (list, tuple)): 15 | labels = [np.asarray(label) for label in labels] 16 | else: 17 | labels = np.asarray(labels) 18 | for i in range(len(ax)): 19 | samples = batched_samples[i] 20 | axi = ax[i] 21 | if isinstance(labels, (list, tuple)): 22 | lab = names[0] + "=" + str(labels[0][i]) 23 | for l, name in zip(labels[1:], names[1:]): 24 | lab += ", " + name + "=" + str(l[i]) 25 | else: 26 | lab = names + "=" + str(labels[i]) 27 | for sample in samples: 28 | axi.plot(x, sample, label=lab) 29 | axi.set_ylim(ymin=ymin, ymax=ymax) 30 | axi.set_title(lab) 31 | plt.show() 32 | 33 | 34 | def plot_cov_matrix(k, X, labels, names, vlim=None, cmap="inferno", interpolation="none"): 35 | cov = k(X, X) 36 | cov = np.asarray(cov) 37 | if vlim is not None: 38 | vmin, vmax = vlim 39 | else: 40 | vmin, vmax = cov.min(), cov.max() 41 | if isinstance(labels, (list, tuple)): 42 | labels = [np.asarray(label) for label in labels] 43 | n_samples = len(labels[0]) 44 | else: 45 | labels = np.asarray(labels) 46 | n_samples = 1 47 | fig, ax = plt.subplots(1, n_samples, figsize=(5 * n_samples, 4)) 48 | if not isinstance(ax, np.ndarray): 49 | ax = np.asarray([ax]) 50 | for i in range(ax.shape[0]): 51 | axi = ax[i] 52 | if isinstance(labels, (list, tuple)): 53 | lab = names[0] + "=" + str(labels[0][i]) 54 | for l, name in zip(labels[1:], names[1:]): 55 | lab += ", " + name + "=" + str(l[i]) 56 | else: 57 | lab = names + "=" + str(labels[i]) 58 | m = axi.imshow(cov[i], cmap=cmap, interpolation=interpolation) 59 | m.set_clim(vmin=vmin, vmax=vmax) 60 | plt.colorbar(m, ax=axi) 61 | axi.grid(False) 62 | axi.set_title(lab) 63 | plt.show() 64 | -------------------------------------------------------------------------------- /pymc4/__init__.py: -------------------------------------------------------------------------------- 1 | """PyMC4.""" 2 | from . import utils 3 | from .coroutine_model import Model, model 4 | from .scopes import name_scope, variable_name 5 | from . import coroutine_model 6 | from . import distributions 7 | from . import flow 8 | from .flow import ( 9 | evaluate_model_transformed, 10 | evaluate_model, 11 | evaluate_model_posterior_predictive, 12 | evaluate_meta_model, 13 | evaluate_meta_posterior_predictive_model, 14 | ) 15 | from . import inference 16 | from .distributions import * 17 | from .forward_sampling import sample_prior_predictive, sample_posterior_predictive 18 | from .inference.sampling import sample 19 | from .mcmc.samplers import * 20 | from . import gp 21 | from . 
import mcmc
22 | from .variational import *
23 |
24 |
25 | __version__ = "4.0a2"
26 |
-------------------------------------------------------------------------------- /pymc4/coroutine_model.py: --------------------------------------------------------------------------------
1 | """Main model functionality."""
2 | import functools
3 | import types
4 | from typing import Optional, Union
5 |
6 | from pymc4.scopes import name_scope
7 | from pymc4.utils import biwrap, NameParts
8 |
9 |
10 | # we need this indicator to distinguish between an explicit None and the no-value-provided case
11 | _no_name_provided = object()
12 |
13 |
14 | @biwrap
15 | def model(genfn, *, name=_no_name_provided, keep_auxiliary=True, keep_return=True, method=False):
16 |     """Flexibly wrap a generator function into a Model template."""
17 |     if method:
18 |         # What is this block for?
19 |         template = ModelTemplate(
20 |             genfn, name=name, keep_auxiliary=keep_auxiliary, keep_return=keep_return
21 |         )
22 |
23 |         @functools.wraps(genfn)
24 |         def wrapped(*args, **kwargs):
25 |             return template(*args, **kwargs)
26 |
27 |         return wrapped
28 |     else:
29 |         template = ModelTemplate(
30 |             genfn, name=name, keep_auxiliary=keep_auxiliary, keep_return=keep_return
31 |         )
32 |         return template
33 |
34 |
35 | def get_name(default, base_fn, name) -> Optional[str]:
36 |     """Parse the name of an rv from arguments.
37 |
38 |     Parameters
39 |     ----------
40 |     default : _no_name_provided, str, or None
41 |         Default to fall back to if it is not _no_name_provided
42 |     base_fn : callable
43 |         In case the random variable has a name attribute
44 |         and default is _no_name_provided, use that
45 |     name : _no_name_provided, str, or None
46 |         Provided argument
47 |
48 |     Returns
49 |     -------
50 |     str or None
51 |     """
52 |     if name is _no_name_provided:
53 |         if default is not _no_name_provided:
54 |             name = default
55 |         elif hasattr(base_fn, "name"):
56 |             name = getattr(base_fn, "name")
57 |         elif hasattr(base_fn, "__name__"):
58 |             name = base_fn.__name__
59 |     return name
60 |
61 |
62 | class ModelTemplate:
63 |     """Model Template -- generative model with metadata.
64 |
65 |     ModelTemplate is a callable object that represents a generative process. A generative process samples
66 |     from prior distributions and allows them to interact in arbitrarily-complex, user-defined ways.
67 |
68 |     Parameters
69 |     ----------
70 |     template : callable
71 |         Generative process that accepts any arguments as conditioners and returns realizations, if any.
72 |     keep_auxiliary : bool
73 |         Generative process may require some auxiliary variables to be created, but they probably will not be used
74 |         anywhere else. In that case it is useful to tell the PyMC4 engine that we can get rid of auxiliary variables
75 |         as long as they are not needed any more.
76 |     keep_return : bool
77 |         Whether the return value of the model will be recorded.
78 |     """
79 |
80 |     def __init__(self, template, *, name=None, keep_auxiliary=True, keep_return=True):
81 |         self.template = template
82 |         self.name = name
83 |         self.keep_auxiliary = keep_auxiliary
84 |         self.keep_return = keep_return
85 |
86 |     def __call__(
87 |         self, *args, name=_no_name_provided, keep_auxiliary=None, keep_return=None, **kwargs
88 |     ):
89 |         """
90 |         Evaluate the model.
91 |
92 |         Model evaluation usually comes with the :code:`yield` keyword; see Examples below.
93 |
94 |         Parameters
95 |         ----------
96 |         name : str
97 |             The desired name for the model. By default it is inferred from the model declaration context,
98 |             but can be set just once.
99 |         keep_auxiliary : bool
100 |             Whether to override the default variable for `keep_auxiliary`
101 |         keep_return : bool
102 |             Whether to override the default variable for `keep_return`
103 |         args : tuple
104 |             positional conditioners for the generative process
105 |         kwargs : dict
106 |             keyword conditioners for the generative process
107 |
108 |         Returns
109 |         -------
110 |         Model
111 |             The conditioned generative process, for which we can obtain a generator (the generative process) with :code:`iter`
112 |
113 |         Examples
114 |         --------
115 |         >>> import pymc4 as pm
116 |         >>> from pymc4 import distributions as dist
117 |
118 |         >>> @pm.model  # keep_return is True by default
119 |         ... def nested_model(cond):
120 |         ...     norm = yield dist.Normal("n", cond, 1)
121 |         ...     return norm
122 |
123 |         >>> @pm.model
124 |         ... def main_model():
125 |         ...     norm = yield dist.Normal("n", 0, 1)
126 |         ...     result = yield nested_model(norm, name="a")
127 |         ...     return result
128 |         >>> ret, state = pm.evaluate_model(main_model())
129 |         >>> print(sorted(state.untransformed_values))
130 |         ['main_model/a/n', 'main_model/n']
131 |
132 |         The model return values are stored in ``state.deterministics_values``
133 |
134 |         >>> print(sorted(list(state.deterministics_values)))
135 |         ['main_model', 'main_model/a']
136 |
137 |         By setting :code:`keep_return=False` for the nested model we can remove
138 |         ``'main_model/a'`` from the output state's deterministics
139 |
140 |         >>> @pm.model
141 |         ... def main_model():
142 |         ...     norm = yield dist.Normal("n", 0, 1)
143 |         ...     result = yield nested_model(norm, name="a", keep_return=False)
144 |         ...     return result
145 |         >>> ret, state = pm.evaluate_model(main_model())
146 |         >>> print(sorted(state.deterministics_values))
147 |         ['main_model']
148 |
149 |         We can also observe some variables by passing :code:`observed` to a distribution
150 |
151 |         >>> @pm.model  # keep_return is True by default
152 |         ... def main_model():
153 |         ...     norm = yield dist.Normal("n", 0, 1, observed=0.)
154 |         ...     result = yield nested_model(norm, name="a")
155 |         ...     return result
156 |         >>> ret, state = pm.evaluate_model(main_model())
157 |         >>> print(sorted(state.untransformed_values))
158 |         ['main_model/a/n']
159 |         >>> print(sorted(state.observed_values))
160 |         ['main_model/n']
161 |         """
162 |         genfn = functools.partial(self.template, *args, **kwargs)
163 |         name = get_name(self.name, self.template, name)
164 |         if name is not None and not NameParts.is_valid_untransformed_name(name):
165 |             # throw an informative message to fix a name
166 |             raise ValueError(NameParts.UNTRANSFORMED_NAME_ERROR_MESSAGE)
167 |         if keep_auxiliary is None:
168 |             keep_auxiliary = self.keep_auxiliary
169 |         if keep_return is None:
170 |             keep_return = self.keep_return
171 |
172 |         return Model(genfn, name=name, keep_auxiliary=keep_auxiliary, keep_return=keep_return)
173 |
174 |
175 | def unpack(arg):
176 |     """Convert an argument into a generator or a value."""
177 |     if isinstance(arg, (Model, types.GeneratorType)):
178 |         return (yield arg)
179 |     else:
180 |         return arg
181 |
182 |
183 | class Model:
184 |     """Base coroutine object.
185 |
186 |     Supports iteration over random variables via `.control_flow`.
187 | """ 188 | 189 | # this is gonna be used for generator-like objects, 190 | # prohibit modification of this dict wrapping it into a MappingProxy 191 | default_model_info = types.MappingProxyType( 192 | dict(keep_auxiliary=True, keep_return=False, scope=name_scope(None), name=None) 193 | ) 194 | 195 | @staticmethod 196 | def validate_name(name: Optional[Union[int, str]]) -> Optional[str]: 197 | """Validate the type of the name argument.""" 198 | if name is not None and not isinstance(name, (int, str)): 199 | raise ValueError("name should be either `str` or `int`, got type {}".format(type(name))) 200 | elif name is None: 201 | return None 202 | else: 203 | return str(name) 204 | 205 | def __init__(self, genfn, *, name=None, keep_auxiliary=True, keep_return=True): 206 | self.genfn = genfn 207 | self.name = self.validate_name(name) 208 | self.model_info = dict( 209 | keep_auxiliary=keep_auxiliary, 210 | keep_return=keep_return, 211 | scope=name_scope(self.name), 212 | name=self.name, 213 | ) 214 | 215 | def control_flow(self): 216 | """Iterate over the random variables in the model.""" 217 | return (yield from self.genfn()) 218 | -------------------------------------------------------------------------------- /pymc4/distributions/__init__.py: -------------------------------------------------------------------------------- 1 | from .continuous import * 2 | from .discrete import * 3 | from .multivariate import * 4 | from .timeseries import * 5 | from .distribution import Potential, Deterministic 6 | from .mixture import Mixture 7 | from . import transforms 8 | from .mixture import * 9 | from .state_functions import * 10 | -------------------------------------------------------------------------------- /pymc4/distributions/batchstack.py: -------------------------------------------------------------------------------- 1 | """The BatchStacker distribution class.""" 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | from typing import Optional, Union, Tuple 8 | 9 | # Dependency imports 10 | import numpy as np 11 | 12 | # import tensorflow.compat.v2 as tf 13 | import tensorflow as tf 14 | from tensorflow_probability import distributions as tfd 15 | 16 | from tensorflow_probability.python.distributions import distribution as distribution_lib 17 | from tensorflow_probability.python.distributions import kullback_leibler 18 | from tensorflow_probability.python.internal import assert_util 19 | from tensorflow_probability.python.internal import distribution_util 20 | from tensorflow_probability.python.internal import prefer_static 21 | from tensorflow_probability.python.internal import tensorshape_util 22 | 23 | 24 | def _make_summary_statistic(attr): # noqa 25 | """Build functions that compute summary statistics, eg, mean, stddev, mode.""" 26 | 27 | def _fn(self, **kwargs): 28 | """Implement summary statistic, eg, mean, stddev, mode.""" 29 | x = getattr(self.distribution, attr)(**kwargs) 30 | shape = prefer_static.concat( 31 | [ 32 | prefer_static.ones( 33 | prefer_static.rank_from_shape(self.batch_stack), 34 | dtype=self.batch_stack.dtype, 35 | ), 36 | self.distribution.batch_shape_tensor(), 37 | self.distribution.event_shape_tensor(), 38 | ], 39 | axis=0, 40 | ) 41 | x = tf.reshape(x, shape=shape) 42 | shape = prefer_static.concat( 43 | [ 44 | self.batch_stack, 45 | self.distribution.batch_shape_tensor(), 46 | self.distribution.event_shape_tensor(), 47 | ], 48 | axis=0, 49 | ) 50 | return tf.broadcast_to(x, shape) 51 | 
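    # Note: `_fn` above evaluates the statistic on the base distribution, reshapes it
    # with singleton axes in place of the `batch_stack` dims, and then broadcasts it
    # across the stacked batch shape; the factory returns this closure so it can be
    # bound as the corresponding method (`_mean`, `_stddev`, etc.) on `BatchStacker`.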
52 |     return _fn
53 |
54 |
55 | class BatchStacker(distribution_lib.Distribution):
56 |     r"""
57 |     BatchStacker distribution via independent draws.
58 |
59 |     This distribution is useful for stacking collections of independent,
60 |     identical draws. It is otherwise identical to the input distribution.
61 |
62 |     The probability function is
63 |
64 |     .. math::
65 |         p(x) = \prod_{i = 0}^{n - 1} p(x_i)
66 |
67 |     Examples
68 |     --------
69 |     Example 1: Five scalar draws.
70 |
71 |     >>> from tensorflow_probability import distributions as tfd
72 |     >>> s = BatchStacker(
73 |     ...     tfd.Normal(loc=0, scale=1),
74 |     ...     batch_stack=5)
75 |     >>> x = s.sample()
76 |     >>> x.shape.as_list()
77 |     [5]
78 |     >>> lp = s.log_prob(x)
79 |     >>> lp.shape.as_list()
80 |     [5]
81 |
82 |     Example 2: `[5, 4]`-draws of a bivariate Normal.
83 |
84 |     >>> s = BatchStacker(
85 |     ...     tfd.Independent(tfd.Normal(loc=tf.zeros([3, 2]), scale=1),
86 |     ...                     reinterpreted_batch_ndims=1),
87 |     ...     batch_stack=[5, 4])
88 |     >>> x = s.sample([6, 1])
89 |     >>> x.shape.as_list()
90 |     [6, 1, 5, 4, 3, 2]
91 |     >>> lp = s.log_prob(x)
92 |     >>> lp.shape.as_list()
93 |     [6, 1, 5, 4, 3]
94 |     """
95 |
96 |     def __init__(
97 |         self,
98 |         distribution: tfd.Distribution,
99 |         batch_stack: Union[int, Tuple[int, ...], tf.Tensor],
100 |         validate_args: bool = False,
101 |         name: Optional[str] = None,
102 |     ):
103 |         r"""
104 |         Construct the `BatchStacker` distribution.
105 |
106 |         Parameters
107 |         ----------
108 |         distribution : tfd.Distribution
109 |             The base distribution instance to transform. Typically an instance of
110 |             ``tensorflow_probability.distributions.Distribution``.
111 |         batch_stack : Union[int, Tuple[int, ...], tf.Tensor]
112 |             Shape of the stack of distributions. To be more precise, ``distribution`` has
113 |             its underlying ``batch_shape``; what ``batch_stack`` effectively does is
114 |             add independent replications of ``distribution`` as extra ``batch_shape`` axes,
115 |             to the left of the original ``batch_shape``. The resulting batch shape is
116 |             ``tf.TensorShape(batch_stack) + distribution.batch_shape``.
117 |         validate_args : bool
118 |             Whether to validate input with asserts. If ``validate_args`` is ``False``
119 |             and the inputs are invalid, correct behavior is not guaranteed.
120 |         name : Optional[str]
121 |             The name for ops managed by the distribution. If ``None``, defaults to
122 |             ``"BatchStacker" + distribution.name``.
123 | """ 124 | parameters = dict(locals()) 125 | name = name or "BatchStacker" + distribution.name 126 | self._distribution = distribution 127 | with tf.name_scope(name) as name: 128 | batch_stack = distribution_util.expand_to_vector( 129 | tf.convert_to_tensor(batch_stack, dtype_hint=tf.int32, name="batch_stack") 130 | ) 131 | self._batch_stack = batch_stack 132 | super(BatchStacker, self).__init__( 133 | dtype=self._distribution.dtype, 134 | reparameterization_type=self._distribution.reparameterization_type, 135 | validate_args=validate_args, 136 | allow_nan_stats=self._distribution.allow_nan_stats, 137 | parameters=parameters, 138 | name=name, 139 | ) 140 | 141 | @property 142 | def distribution(self): 143 | """Get the underlying tensorflow_probability.Distribution.""" 144 | return self._distribution 145 | 146 | @property 147 | def batch_stack(self): 148 | """Get this instance's batch_stack value.""" 149 | return self._batch_stack 150 | 151 | def _batch_shape_tensor(self): 152 | return prefer_static.concat( 153 | [ 154 | self.batch_stack, 155 | self.distribution.batch_shape_tensor(), 156 | ], 157 | axis=0, 158 | ) 159 | 160 | def _batch_shape(self): 161 | batch_stack = tf.TensorShape(tf.get_static_value(self.batch_stack)) 162 | if ( 163 | tensorshape_util.rank(batch_stack) is None 164 | or tensorshape_util.rank(self.distribution.event_shape) is None 165 | ): 166 | return tf.TensorShape(None) 167 | return tensorshape_util.concatenate(batch_stack, self.distribution.batch_shape) 168 | 169 | def _event_shape_tensor(self): 170 | return self.distribution.event_shape_tensor() 171 | 172 | def _event_shape(self): 173 | return self.distribution.event_shape 174 | 175 | def _sample_n(self, n, seed, **kwargs): 176 | return self.distribution.sample( 177 | prefer_static.concat([[n], self.batch_stack], axis=0), seed=seed, **kwargs 178 | ) 179 | 180 | def _log_prob(self, x, **kwargs): 181 | batch_ndims = prefer_static.rank_from_shape( 182 | self.distribution.batch_shape_tensor, self.distribution.batch_shape 183 | ) 184 | extra_batch_ndims = prefer_static.rank_from_shape(self.batch_stack) 185 | event_ndims = prefer_static.rank_from_shape( 186 | self.distribution.event_shape_tensor, self.distribution.event_shape 187 | ) 188 | ndims = prefer_static.rank(x) 189 | # (1) Expand x's dims. 190 | d = ndims - extra_batch_ndims - batch_ndims - event_ndims 191 | x = tf.reshape( 192 | x, 193 | shape=tf.pad( 194 | tf.shape(x), 195 | paddings=[[prefer_static.maximum(0, -d), 0]], 196 | constant_values=1, 197 | ), 198 | ) 199 | # (2) Compute x's log_prob. 200 | return self.distribution.log_prob(x, **kwargs) 201 | 202 | def _entropy(self, **kwargs): 203 | return self.distribution.entropy(**kwargs) 204 | 205 | _mean = _make_summary_statistic("mean") 206 | _stddev = _make_summary_statistic("stddev") 207 | _variance = _make_summary_statistic("variance") 208 | _mode = _make_summary_statistic("mode") 209 | 210 | 211 | @kullback_leibler.RegisterKL(BatchStacker, BatchStacker) 212 | def _kl_sample(a: BatchStacker, b: BatchStacker, name: str = "kl_sample") -> tf.Tensor: 213 | r""" 214 | Batched KL divergence :math:`KL(a || b)` for :class:`~.BatchStacker` distributions. 215 | 216 | We can leverage the fact that: 217 | 218 | .. math:: 219 | KL(BatchStacker(a) || BatchStacker(b)) = \sum(KL(a || b)) 220 | 221 | where the :math:`\sum` is over the ``batch_stack`` dims. 222 | 223 | Parameters 224 | ---------- 225 | a : BatchStacker 226 | Instance of ``BatchStacker`` distribution. 
227 |     b : BatchStacker
228 |         Instance of ``BatchStacker`` distribution.
229 |     name : str
230 |         Name to use for created ops.
231 |
232 |     Returns
233 |     -------
234 |     kldiv : tf.Tensor
235 |         Batchwise :math:`KL(a || b)`.
236 |
237 |     Raises
238 |     ------
239 |     ValueError
240 |         If the ``batch_stack`` of ``a`` and ``b`` don't match.
241 |     """
242 |     assertions = []
243 |     a_ss = tf.get_static_value(a.batch_stack)
244 |     b_ss = tf.get_static_value(b.batch_stack)
245 |     msg = "`a.batch_stack` must be identical to `b.batch_stack`."
246 |     if a_ss is not None and b_ss is not None:
247 |         if not np.array_equal(a_ss, b_ss):
248 |             raise ValueError(msg)
249 |     elif a.validate_args or b.validate_args:
250 |         assertions.append(assert_util.assert_equal(a.batch_stack, b.batch_stack, message=msg))
251 |     with tf.control_dependencies(assertions):
252 |         return kullback_leibler.kl_divergence(a.distribution, b.distribution, name=name)
253 |
-------------------------------------------------------------------------------- /pymc4/distributions/mixture.py: --------------------------------------------------------------------------------
1 | """PyMC4 Distribution of a random variable consisting of a mixture of other
2 | distributions.
3 |
4 | Wraps tfd.Mixture as pm.Mixture
5 | """
6 |
7 | import collections
8 | from typing import Union, Tuple, List
9 |
10 | import tensorflow as tf
11 | from tensorflow_probability import distributions as tfd
12 | from pymc4.distributions.distribution import Distribution
13 |
14 |
15 | class Mixture(Distribution):
16 |     r"""
17 |     Mixture random variable.
18 |     Often used to model subpopulation heterogeneity.
19 |     .. math:: f(x \mid w, \theta) = \sum_{i = 1}^n w_i f_i(x \mid \theta_i)
20 |     ========  ============================================
21 |     Support   :math:`\cap_{i = 1}^n \textrm{support}(f_i)`
22 |     Mean      :math:`\sum_{i = 1}^n w_i \mu_i`
23 |     ========  ============================================
24 |     Parameters
25 |     ----------
26 |     p : tf.Tensor
27 |         p >= 0 and p <= 1
28 |         The mixture weights, in the form of probabilities,
29 |         must sum to one on the last (i.e., right-most) axis.
30 |     distributions : pm.Distribution|sequence of pm.Distribution
31 |         Multi-dimensional PyMC4 distribution (e.g. `pm.Poisson(...)`)
32 |         or iterable of one-dimensional PyMC4 distributions
33 |         :math:`f_1, \ldots, f_n`
34 |
35 |     Examples
36 |     --------
37 |     Let's define a simple two-component Gaussian mixture:
38 |
39 |     >>> import tensorflow as tf
40 |     >>> import pymc4 as pm
41 |     >>> @pm.model
42 |     ... def mixture(dat):
43 |     ...     p = tf.constant([0.5, 0.5])
44 |     ...     m = yield pm.Normal("means", loc=tf.constant([0.0, 0.0]), scale=1.0)
45 |     ...     comps = pm.Normal("comps", m, scale=1.0)
46 |     ...     obs = yield pm.Mixture("mix", p=p, distributions=comps, observed=dat)
47 |     ...     return obs
48 |
49 |     The above implementation only allows components of the same family of distributions.
50 |     In order to allow for different families, we need a more verbose implementation:
51 |
52 |     >>> @pm.model
53 |     ... def mixture(dat):
54 |     ...     p = tf.constant([0.5, 0.5])
55 |     ...     m = yield pm.Normal("means", loc=tf.constant([0.0, 0.0]), scale=1.0)
56 |     ...     comp1 = pm.Normal("comp1", m[..., 0], scale=1.0)
57 |     ...     comp2 = pm.StudentT("comp2", m[..., 1], scale=1.0, df=3)
58 |     ...     obs = yield pm.Mixture("mix", p=p, distributions=[comp1, comp2], observed=dat)
59 |     ...     return obs
60 |
61 |     We can also, as usual with TensorFlow, use higher-dimensional parameters:
62 |
63 |     >>> @pm.model
64 |     ... def mixture(dat):
p = tf.constant([[0.8, 0.2], [0.4, 0.6], [0.5, 0.5]]) 66 | ... m = yield pm.Normal("means", loc=tf.constant([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]), scale=1.0) 67 | ... comp1 = pm.Normal("d1", m[..., 0], scale=1.0) 68 | ... comp2 = pm.StudentT("d2", m[..., 1], scale=1.0, df=3) 69 | ... obs = yield pm.Mixture("mix", p=p, distributions=[comp1, comp2], observed=dat) 70 | ... return obs 71 | 72 | Note that in the last implementation the mixing weights need to sum to one 73 | on the right-most axis (to ensure correct parameterization use `validate_args=True`). 74 | """ 75 | 76 | def __init__( 77 | self, 78 | name: str, 79 | p: tf.Tensor, 80 | distributions: Union[Distribution, List[Distribution], Tuple[Distribution]], 81 | **kwargs, 82 | ): 83 | super().__init__(name, p=p, distributions=distributions, **kwargs) 84 | 85 | @staticmethod 86 | def _init_distribution(conditions, **kwargs): 87 | p, d = conditions["p"], conditions["distributions"] 88 | # if 'd' is a sequence of pymc distributions, then use the underlying 89 | # tfp distributions for the mixture 90 | if isinstance(d, collections.abc.Sequence): 91 | if any(not isinstance(el, Distribution) for el in d): 92 | raise TypeError( 93 | "every element in 'distributions' needs to be a pymc4.Distribution object" 94 | ) 95 | distr = [el._distribution for el in d] 96 | return tfd.Mixture( 97 | tfd.Categorical(probs=p, **kwargs), 98 | distr, 99 | **kwargs, 100 | use_static_graph=True, 101 | ) 102 | # else if 'd' is a pymc distribution with batch_size > 1 103 | elif isinstance(d, Distribution): 104 | return tfd.MixtureSameFamily( 105 | tfd.Categorical(probs=p, **kwargs), d._distribution, **kwargs 106 | ) 107 | else: 108 | raise TypeError( 109 | "'distributions' needs to be a pymc4.Distribution object or a sequence of distributions" 110 | ) 111 | -------------------------------------------------------------------------------- /pymc4/distributions/state_functions.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from typing import Optional, List, Union, Any 3 | import tensorflow as tf 4 | import tensorflow_probability as tfp 5 | from tensorflow_probability.python.mcmc.internal import util as mcmc_util 6 | from tensorflow_probability.python.internal import samplers 7 | 8 | tfd = tfp.distributions 9 | 10 | __all__ = ["categorical_uniform_fn", "bernoulli_fn", "gaussian_round_fn"] 11 | 12 | # TODO: We can further optimize the proposal function. 13 | # For now, if a user provides two sampler functions with 14 | # the same proposal function but a different scale parameter, 15 | # the sampler code will treat them as separate sampling kernels, 16 | # which will increase the graph size. 17 | 18 | 19 | class Proposal(metaclass=abc.ABCMeta): 20 | def __init__(self, name: Optional[str] = None): 21 | if name: 22 | self._name = name 23 | 24 | @abc.abstractmethod 25 | def _fn(self, state_parts: List[tf.Tensor], seed: Optional[int]) -> List[tf.Tensor]: 26 | """ 27 | Proposal function that is passed as the argument 28 | to the RWM kernel 29 | 30 | Parameters 31 | ---------- 32 | state_parts : List[tf.Tensor] 33 | A list of `Tensor`s of any shape and real dtype representing 34 | the state of the `current_state` of the Markov chain 35 | seed: Optional[int] 36 | The random seed for this `Op`. If `None`, no seed is 37 | applied 38 | Default value: `None` 39 | 40 | Returns 41 | ------- 42 | List[tf.Tensor] 43 | A Python `list` of `Tensor`s. Has the same 44 | shape and type as the `state_parts`.
45 | 46 | Raises 47 | ------ 48 | ValueError: if `scale` does not broadcast with `state_parts`. 49 | """ 50 | pass 51 | 52 | @abc.abstractmethod 53 | def __eq__(self, other) -> bool: 54 | """ 55 | Comparison operator overload of each proposal sub-class. 56 | The operator is required to distinguish proposal functions so that separate 57 | samplers can be created in `Compound step` 58 | 59 | Parameters 60 | ---------- 61 | other: pm.distributions.Proposal 62 | Another instance of `Proposal` sub-class. 63 | 64 | Returns 65 | ------- 66 | bool 67 | True/False for equality of instances 68 | """ 69 | pass 70 | 71 | def __call__(self): 72 | return self._fn 73 | 74 | 75 | class CategoricalUniformFn(Proposal): 76 | """ 77 | Categorical proposal sub-class whose `_fn` samples a new proposal 78 | from a categorical distribution with uniform probabilities. 79 | 80 | Parameters 81 | ---------- 82 | classes: int 83 | Number of classes for the categorical distribution 84 | name: Optional[str] 85 | Python `str` name prefixed to Ops created by this function. 86 | Default value: 'categorical_uniform_fn'. 87 | """ 88 | 89 | _name = "categorical_uniform_fn" 90 | 91 | def __init__(self, classes: int, name: Optional[str] = None): 92 | super().__init__(name) 93 | self.classes = classes 94 | 95 | def _fn(self, state_parts: List[tf.Tensor], seed: Optional[int]) -> List[tf.Tensor]: 96 | with tf.name_scope(self._name or "categorical_uniform_fn"): 97 | part_seeds = samplers.split_seed(seed, n=len(state_parts), salt="CategoricalUniformFn") 98 | deltas = tf.nest.map_structure( 99 | lambda x, s: tfd.Categorical(logits=tf.ones(self.classes)).sample( 100 | seed=s, sample_shape=tf.shape(x) 101 | ), 102 | state_parts, 103 | part_seeds, 104 | ) 105 | return deltas 106 | 107 | def __eq__(self, other) -> bool: 108 | return self._name == other._name and self.classes == other.classes 109 | 110 | 111 | class BernoulliFn(Proposal): 112 | """ 113 | Bernoulli proposal sub-class whose `_fn` samples a new proposal 114 | from a Bernoulli distribution with p=0.5. 115 | 116 | Parameters 117 | ---------- 118 | name: Optional[str] 119 | Python `str` name prefixed to Ops created by this function. 120 | Default value: 'bernoulli_fn'. 121 | """ 122 | 123 | _name = "bernoulli_fn" 124 | 125 | def __init__(self, name: Optional[str] = None): 126 | super().__init__(name) 127 | 128 | def _fn(self, state_parts: List[tf.Tensor], seed: Optional[int]) -> List[tf.Tensor]: 129 | with tf.name_scope(self._name or "bernoulli_fn"): 130 | part_seeds = samplers.split_seed(seed, n=len(state_parts), salt="BernoulliFn") 131 | 132 | def generate_bernoulli(state_part, part_seed): 133 | delta = tfd.Bernoulli( 134 | probs=tf.ones_like(state_part, dtype=tf.float32) * 0.5, dtype=state_part.dtype 135 | ).sample(seed=part_seed) 136 | state_part = (state_part + delta) % tf.constant(2, dtype=state_part.dtype) 137 | return state_part 138 | 139 | new_state = tf.nest.map_structure(generate_bernoulli, state_parts, part_seeds) 140 | return new_state 141 | 142 | def __eq__(self, other) -> bool: 143 | return self._name == other._name 144 | 145 |
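# NOTE (illustrative sketch, not part of the original file): the proposal
# classes above are designed to slot into TFP's random-walk Metropolis kernel.
# Calling an instance returns its `_fn`, which has the signature expected by
# `new_state_fn`. The target log-probability below is a toy one used purely
# for illustration.
# >>> import tensorflow as tf
# >>> import tensorflow_probability as tfp
# >>> proposal = BernoulliFn()
# >>> kernel = tfp.mcmc.RandomWalkMetropolis(
# ...     target_log_prob_fn=lambda x: -tf.reduce_sum(x * x),
# ...     new_state_fn=proposal(),  # `__call__` returns the `_fn` callable
# ... )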
146 | class GaussianRoundFn(Proposal): 147 | """ 148 | Gaussian-round proposal sub-class whose `_fn` samples a new proposal 149 | from a scaled normal distribution and rounds the values. 150 | 151 | Parameters 152 | ---------- 153 | scale: Union[List[Any], Any] 154 | a `Tensor` or Python `list` of `Tensor`s of any shapes and `dtypes` 155 | controlling the scale of the proposal distribution. 156 | name: Optional[str] 157 | Python `str` name prefixed to Ops created by this function. 158 | Default value: 'gaussian_round_fn'. 159 | """ 160 | 161 | _name = "gaussian_round_fn" 162 | 163 | def __init__(self, scale: Union[List[Any], Any] = 1.0, name: Optional[str] = None): 164 | super().__init__(name) 165 | self.scale = scale 166 | 167 | def _fn(self, state_parts: List[tf.Tensor], seed: Optional[int]) -> List[tf.Tensor]: 168 | scale = self.scale 169 | with tf.name_scope(self._name or "gaussian_round_fn"): 170 | scales = scale if mcmc_util.is_list_like(scale) else [scale] 171 | if len(scales) == 1: 172 | scales *= len(state_parts) 173 | if len(state_parts) != len(scales): 174 | raise ValueError("`scale` must broadcast with `state_parts`") 175 | 176 | part_seeds = samplers.split_seed(seed, n=len(state_parts), salt="GaussianRoundFn") 177 | 178 | def generate_rounded_normal(state_part, scale_part, part_seed): 179 | delta = tfd.Normal(0.0, tf.ones_like(state_part) * scale_part).sample(seed=part_seed) 180 | state_part += delta 181 | return tf.round(state_part) 182 | 183 | new_state = tf.nest.map_structure( 184 | generate_rounded_normal, state_parts, scales, part_seeds 185 | ) 186 | return new_state 187 | 188 | def __eq__(self, other) -> bool: 189 | return self._name == other._name and self.scale == other.scale 190 | 191 | 192 | categorical_uniform_fn = CategoricalUniformFn 193 | bernoulli_fn = BernoulliFn 194 | gaussian_round_fn = GaussianRoundFn 195 | -------------------------------------------------------------------------------- /pymc4/distributions/timeseries.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import tensorflow as tf 3 | from tensorflow_probability import sts 4 | from tensorflow_probability import distributions as tfd 5 | from pymc4.distributions.distribution import ContinuousDistribution 6 | 7 | 8 | class AR(ContinuousDistribution): 9 | r"""Autoregressive process with `order` lags. 10 | 11 | Parameters 12 | ---------- 13 | num_timesteps : int Tensor 14 | Total number of timesteps to model. 15 | coefficients : float Tensor 16 | Autoregressive coefficients of shape `concat(batch_shape, [order])`. 17 | order = coefficients.shape[-1] (order>0) 18 | level_scale : Scalar float Tensor 19 | Standard deviation of the transition noise at each step 20 | (any additional dimensions are treated as batch 21 | dimensions). 22 | initial_state : (Optional) float Tensor 23 | Corresponding values of size `order` for 24 | imagined timesteps before the initial step. 25 | initial_step : (Optional) int Tensor 26 | Starting timestep (Default value: 0). 27 | 28 | Examples 29 | -------- 30 | >>> import pymc4 as pm 31 | >>> @pm.model 32 | ... def model(): 33 | ... x = yield pm.AR('x', num_timesteps=50, coefficients=[0.2, -0.8], level_scale=0.2) 34 | """ 35 | 36 | def __init__( 37 | self, 38 | name, 39 | num_timesteps, 40 | coefficients, 41 | level_scale, 42 | initial_state=None, 43 | initial_step=0, 44 | **kwargs, 45 | ): 46 | super().__init__( 47 | name, 48 | num_timesteps=num_timesteps, 49 | coefficients=coefficients, 50 | level_scale=level_scale, 51 | initial_state=initial_state, 52 | initial_step=initial_step, 53 | **kwargs, 54 | ) 55 | 56 | @classmethod 57 | def unpack_conditions(cls, **kwargs): 58 | conditions, base_parameters = super().unpack_conditions(**kwargs) 59 | warnings.warn( 60 | "At the moment, the Autoregressive distribution does not accept the initialization " 61 | "arguments: dtype, allow_nan_stats or validate_args.
Any of those keyword arguments " 62 | "passed during initialization will be ignored." 63 | ) 64 | return conditions, {} 65 | 66 | @staticmethod 67 | def _init_distribution(conditions: dict, **kwargs): 68 | num_timesteps = conditions["num_timesteps"] 69 | coefficients = conditions["coefficients"] 70 | level_scale = conditions["level_scale"] 71 | initial_state = conditions["initial_state"] 72 | initial_step = conditions["initial_step"] 73 | 74 | coefficients = tf.convert_to_tensor(value=coefficients, name="coefficients") 75 | order = tf.compat.dimension_value(coefficients.shape[-1]) 76 | 77 | time_series_object = sts.Autoregressive(order=order) 78 | distribution = time_series_object.make_state_space_model( 79 | num_timesteps=num_timesteps, 80 | param_vals={"coefficients": coefficients, "level_scale": level_scale}, 81 | initial_state_prior=tfd.MultivariateNormalDiag( 82 | loc=initial_state, scale_diag=[1e-6] * order 83 | ), 84 | initial_step=initial_step, 85 | ) 86 | return distribution 87 | -------------------------------------------------------------------------------- /pymc4/distributions/transforms.py: -------------------------------------------------------------------------------- 1 | import enum 2 | from typing import Optional 3 | 4 | from tensorflow_probability import bijectors as tfb 5 | 6 | __all__ = ["Log", "Sigmoid", "LowerBound", "UpperBound", "Interval"] 7 | 8 | 9 | class JacobianPreference(enum.Enum): 10 | Forward = "Forward" 11 | Backward = "Backward" 12 | 13 | 14 | class Transform: 15 | name: Optional[str] = None 16 | jacobian_preference = JacobianPreference.Forward 17 | 18 | def forward(self, x): 19 | """ 20 | Forward of a bijector. 21 | 22 | Applies transformation forward to input variable `x`. 23 | When transform is used on some distribution `p`, it will transform the random variable `x` after sampling 24 | from `p`. 25 | 26 | Parameters 27 | ---------- 28 | x : tensor 29 | Input tensor to be transformed. 30 | 31 | Returns 32 | ------- 33 | tensor 34 | Transformed tensor. 35 | """ 36 | raise NotImplementedError 37 | 38 | def inverse(self, z): 39 | """ 40 | Backward of a bijector. 41 | 42 | Applies inverse of transformation to input variable `z`. 43 | When transform is used on some distribution `p`, which has observed values `z`, it is used to 44 | transform the values of `z` correctly to the support of `p`. 45 | 46 | Parameters 47 | ---------- 48 | z : tensor 49 | Input tensor to be inverse transformed. 50 | 51 | Returns 52 | ------- 53 | tensor 54 | Inverse transformed tensor. 55 | """ 56 | raise NotImplementedError 57 | 58 | def forward_log_det_jacobian(self, x): 59 | """ 60 | Calculate logarithm of the absolute value of the Jacobian determinant for input `x`. 61 | 62 | Parameters 63 | ---------- 64 | x : tensor 65 | Input to calculate Jacobian determinant of. 66 | 67 | Returns 68 | ------- 69 | tensor 70 | The log abs Jacobian determinant of `x` w.r.t. this transform. 71 | """ 72 | raise NotImplementedError 73 | 74 | def inverse_log_det_jacobian(self, z): 75 | """ 76 | Calculate logarithm of the absolute value of the Jacobian determinant for output `z`. 77 | 78 | Parameters 79 | ---------- 80 | z : tensor 81 | Output to calculate Jacobian determinant of. 82 | 83 | Returns 84 | ------- 85 | tensor 86 | The log abs Jacobian determinant of `z` w.r.t. this transform. 
87 | 88 | Notes 89 | ----- 90 | Subclasses may wish to override this default implementation with a more efficient one. 91 | """ 92 | return -self.forward_log_det_jacobian(self.inverse(z)) 93 | 94 | 95 | class Invert(Transform): 96 | def __init__(self, transform): 97 | if transform.jacobian_preference == JacobianPreference.Forward: 98 | self.jacobian_preference = JacobianPreference.Backward 99 | else: 100 | self.jacobian_preference = JacobianPreference.Forward 101 | self._transform = transform 102 | 103 | def forward(self, x): 104 | return self._transform.inverse(x) 105 | 106 | def inverse(self, z): 107 | return self._transform.forward(z) 108 | 109 | def forward_log_det_jacobian(self, x): 110 | return self._transform.inverse_log_det_jacobian(x) 111 | 112 | def inverse_log_det_jacobian(self, z): 113 | return self._transform.forward_log_det_jacobian(z) 114 | 115 | 116 | class BackwardTransform(Transform): 117 | """Base class for Transforms with Jacobian Preference as Backward""" 118 | 119 | jacobian_preference = JacobianPreference.Backward 120 | 121 | def __init__(self, transform): 122 | self._transform = transform 123 | 124 | def forward(self, x): 125 | return self._transform.inverse(x) 126 | 127 | def inverse(self, z): 128 | return self._transform.forward(z) 129 | 130 | def forward_log_det_jacobian(self, x): 131 | return self._transform.inverse_log_det_jacobian(x, self._transform.inverse_min_event_ndims) 132 | 133 | def inverse_log_det_jacobian(self, z): 134 | return self._transform.forward_log_det_jacobian(z, self._transform.forward_min_event_ndims) 135 | 136 | 137 | class Log(BackwardTransform): 138 | name = "log" 139 | 140 | def __init__(self): 141 | # NOTE: do we actually need the inverse here to match PyMC3? 142 | transform = tfb.Exp() 143 | super().__init__(transform) 144 | 145 | 146 | class Sigmoid(BackwardTransform): 147 | name = "sigmoid" 148 | 149 | def __init__(self): 150 | transform = tfb.Sigmoid() 151 | super().__init__(transform) 152 | 153 | 154 | class LowerBound(BackwardTransform): 155 | """Transformation to the interval [lower_limit, inf)""" 156 | 157 | name = "lowerbound" 158 | 159 | def __init__(self, lower_limit): 160 | transform = tfb.Chain([tfb.Shift(lower_limit), tfb.Exp()]) 161 | super().__init__(transform) 162 | 163 | 164 | class UpperBound(BackwardTransform): 165 | """Transformation to the interval (-inf, upper_limit]""" 166 | 167 | name = "upperbound" 168 | 169 | def __init__(self, upper_limit): 170 | transform = tfb.Chain([tfb.Shift(upper_limit), tfb.Scale(-1), tfb.Exp()]) 171 | super().__init__(transform) 172 | 173 | 174 | class Interval(BackwardTransform): 175 | """Transformation to the interval [lower_limit, upper_limit]""" 176 | 177 | name = "interval" 178 | 179 | def __init__(self, lower_limit, upper_limit): 180 | transform = tfb.Sigmoid(low=lower_limit, high=upper_limit) 181 | super().__init__(transform) 182 |
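# NOTE (illustrative sketch, not part of the original module): with the
# "backward" convention above, `forward` maps a constrained value into the
# unconstrained space and `inverse` maps it back. For example, assuming the
# classes defined in this file:
# >>> import tensorflow as tf
# >>> t = LowerBound(0.0)
# >>> z = t.forward(tf.constant(2.0))  # log(2.0 - 0.0), unconstrained
# >>> x = t.inverse(z)                 # exp(z) + 0.0, recovers 2.0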
"MetaPosteriorPredictiveSamplingExecutor", 13 | "evaluate_model", 14 | "evaluate_model_transformed", 15 | "evaluate_model_posterior_predictive", 16 | "evaluate_meta_model", 17 | "evaluate_meta_posterior_predictive_model", 18 | ] 19 | 20 | evaluate_model = SamplingExecutor() 21 | evaluate_model_transformed = TransformedSamplingExecutor() 22 | evaluate_model_posterior_predictive = PosteriorPredictiveSamplingExecutor() 23 | evaluate_meta_model = MetaSamplingExecutor() 24 | evaluate_meta_posterior_predictive_model = MetaPosteriorPredictiveSamplingExecutor() 25 | -------------------------------------------------------------------------------- /pymc4/flow/meta_executor.py: -------------------------------------------------------------------------------- 1 | """Execute graph with test values to extract a model's meta-information. 2 | Specifically, we wish to extract: 3 | - All variable's core shapes 4 | - All observed, deterministic, and unobserved variables (both transformed and 5 | untransformed. 6 | """ 7 | from typing import Tuple, Any, Union 8 | import tensorflow as tf 9 | from pymc4 import scopes 10 | from pymc4.distributions import distribution 11 | from pymc4.flow.executor import ( 12 | SamplingState, 13 | EvaluationError, 14 | observed_value_in_evaluation, 15 | assert_values_compatible_with_distribution, 16 | ) 17 | from pymc4.flow.transformed_executor import TransformedSamplingExecutor 18 | from pymc4.flow.posterior_predictive_executor import PosteriorPredictiveSamplingExecutor 19 | 20 | 21 | __all__ = ["MetaSamplingExecutor", "MetaPosteriorPredictiveSamplingExecutor"] 22 | 23 | 24 | class MetaSamplingExecutor(TransformedSamplingExecutor): 25 | """Do a forward pass through the model only using distribution test values.""" 26 | 27 | def proceed_distribution( 28 | self, 29 | dist: distribution.Distribution, 30 | state: SamplingState, 31 | sample_shape: Union[int, Tuple[int], tf.TensorShape] = None, 32 | ) -> Tuple[Any, SamplingState]: 33 | if dist.is_anonymous: 34 | raise EvaluationError("Attempting to create an anonymous Distribution") 35 | scoped_name = scopes.variable_name(dist.name) 36 | if scoped_name is None: 37 | raise EvaluationError("Attempting to create an anonymous Distribution") 38 | 39 | if ( 40 | scoped_name in state.discrete_distributions 41 | or scoped_name in state.continuous_distributions 42 | or scoped_name in state.deterministics_values 43 | ): 44 | raise EvaluationError( 45 | "Attempting to create a duplicate variable {!r}, " 46 | "this may happen if you forget to use `pm.name_scope()` when calling same " 47 | "model/function twice without providing explicit names. 
If you see this " 48 | "error message and the function being called is not wrapped with " 49 | "`pm.model`, you should wrap it to provide an explicit name for this model".format( 50 | scoped_name 51 | ) 52 | ) 53 | if scoped_name in state.observed_values or dist.is_observed: 54 | observed_variable = observed_value_in_evaluation(scoped_name, dist, state) 55 | if observed_variable is None: 56 | # None indicates we pass None to the state.observed_values dict; this might be 57 | # posterior predictive sampling or a programmatic override that exchanges an observed variable for a latent one 58 | if scoped_name not in state.untransformed_values: 59 | # posterior predictive 60 | if dist.is_root: 61 | return_value = state.untransformed_values[ 62 | scoped_name 63 | ] = dist.get_test_sample(sample_shape=sample_shape) 64 | else: 65 | return_value = state.untransformed_values[ 66 | scoped_name 67 | ] = dist.get_test_sample() 68 | else: 69 | # replace observed variable with a custom one 70 | return_value = state.untransformed_values[scoped_name] 71 | # We also store the name in posterior_predictives just to keep 72 | # track of the variables used in posterior predictive sampling 73 | state.posterior_predictives.add(scoped_name) 74 | state.observed_values.pop(scoped_name) 75 | else: 76 | if scoped_name in state.untransformed_values: 77 | raise EvaluationError( 78 | EvaluationError.OBSERVED_VARIABLE_IS_NOT_SUPPRESSED_BUT_ADDITIONAL_VALUE_PASSED.format( 79 | scoped_name 80 | ) 81 | ) 82 | assert_values_compatible_with_distribution(scoped_name, observed_variable, dist) 83 | return_value = state.observed_values[scoped_name] = observed_variable 84 | elif scoped_name in state.untransformed_values: 85 | return_value = state.untransformed_values[scoped_name] 86 | else: 87 | if dist.is_root: 88 | return_value = state.untransformed_values[scoped_name] = dist.get_test_sample( 89 | sample_shape=sample_shape 90 | ) 91 | else: 92 | return_value = state.untransformed_values[scoped_name] = dist.get_test_sample() 93 | if dist._grad_support: 94 | state.continuous_distributions[scoped_name] = dist 95 | else: 96 | state.discrete_distributions[scoped_name] = dist 97 | return return_value, state 98 | 99 | 100 | class MetaPosteriorPredictiveSamplingExecutor( 101 | MetaSamplingExecutor, PosteriorPredictiveSamplingExecutor 102 | ): 103 | """Do a forward pass through the model only using distribution test values. 104 | Also modify the distributions to make them suitable for posterior predictive sampling. 105 | """ 106 | 107 | # Everything is done in the parent classes 108 | pass 109 | -------------------------------------------------------------------------------- /pymc4/flow/posterior_predictive_executor.py: -------------------------------------------------------------------------------- 1 | """Execute graph in a transformed space and change the observed distribution's shape. 2 | 3 | Specifically, we wish to transform the observed distributions' shape to make 4 | it aware of the observed values, in order to later draw posterior predictive 5 | samples.
6 | """ 7 | from typing import Mapping, Any 8 | import tensorflow as tf 9 | from pymc4 import scopes 10 | from pymc4.distributions.distribution import Distribution 11 | from pymc4.flow.executor import ( 12 | EvaluationError, 13 | ModelType, 14 | SamplingExecutor, 15 | SamplingState, 16 | observed_value_in_evaluation, 17 | get_observed_tensor_shape, 18 | assert_values_compatible_with_distribution, 19 | ) 20 | from pymc4.flow.transformed_executor import transform_dist_if_necessary 21 | 22 | 23 | class PosteriorPredictiveSamplingExecutor(SamplingExecutor): 24 | """Execute the probabilistic model for posterior predictive sampling. 25 | 26 | This means that the model will be evaluated in the same way as the 27 | TransformedSamplingExecutor evaluates it. All unobserved distributions 28 | will be left as they are. All observed distributions will modified in 29 | the following way: 30 | 1) The distribution's shape (batch_shape + event_shape) will be checked 31 | for consitency with the supplied observed value's shape. 32 | 2) If they are inconsistent, an EvaluationError will be raised. 33 | 3) If they are consistent the distribution's observed values shape 34 | will be broadcasted with the distribution's shape to construct a new 35 | Distribution instance with no observations. This distribution will be 36 | used for posterior predictive sampling 37 | """ 38 | 39 | def validate_state(self, state: SamplingState): 40 | """Validate that the model is not in a bad state.""" 41 | return 42 | 43 | def modify_distribution( 44 | self, dist: ModelType, model_info: Mapping[str, Any], state: SamplingState 45 | ) -> ModelType: 46 | """Remove the observed distribution values but keep their shapes. 47 | 48 | Modify observed Distribution instances in the following way: 49 | 1) The distribution's shape (batch_shape + event_shape) will be checked 50 | for consitency with the supplied observed value's shape. 51 | 2) If they are inconsistent, an EvaluationError will be raised. 52 | 3) If they are consistent the distribution's observed values' shape 53 | will be broadcasted with the distribution's shape to construct a new 54 | Distribution instance with no observations. 55 | 4) This distribution will be yielded instead of the original incoming 56 | dist, and it will be used for posterior predictive sampling 57 | 58 | Parameters 59 | ---------- 60 | dist: Union[types.GeneratorType, pymc4.coroutine_model.Model] 61 | The 62 | model_info: Mapping[str, Any] 63 | Either ``dist.model_info`` or 64 | ``pymc4.coroutine_model.Model.default_model_info`` if ``dist`` is not a 65 | ``pymc4.courutine_model.Model`` instance. 66 | state: SamplingState 67 | The model's evaluation state. 68 | 69 | Returns 70 | ------- 71 | model: Union[types.GeneratorType, pymc4.coroutine_model.Model] 72 | The original ``dist`` if it was not an observed ``Distribution`` or 73 | the ``Distribution`` with the changed ``batch_shape`` and observations 74 | set to ``None``. 
75 | 76 | Raises 77 | ------ 78 | EvaluationError 79 | When ``dist`` and its passed observed value don't have a consistent 80 | shape 81 | """ 82 | dist = super().modify_distribution(dist, model_info, state) 83 | # We only modify the shape of Distribution instances that have observed 84 | # values 85 | dist = transform_dist_if_necessary(dist, state, allow_transformed_and_untransformed=False) 86 | if not isinstance(dist, Distribution): 87 | return dist 88 | scoped_name = scopes.variable_name(dist.name) 89 | if scoped_name is None: 90 | raise EvaluationError("Attempting to create an anonymous Distribution") 91 | 92 | observed_value = observed_value_in_evaluation(scoped_name, dist, state) 93 | if observed_value is None: 94 | return dist 95 | 96 | # We set the state's observed value to None to explicitly override 97 | # any previously given observed and at the same time, have the 98 | # scope_name added to the posterior_predictives set in 99 | # self.proceed_distribution 100 | state.observed_values[scoped_name] = None 101 | 102 | # We first check the TFP distribution's shape and compare it with the 103 | # observed_value's shape 104 | assert_values_compatible_with_distribution(scoped_name, observed_value, dist) 105 | 106 | # Now we get the broadcasted shape between the observed value and the distribution 107 | observed_shape = get_observed_tensor_shape(observed_value) 108 | dist_shape = dist.batch_shape + dist.event_shape 109 | new_dist_shape = tf.broadcast_static_shape(observed_shape, dist_shape) 110 | extra_batch_stack = new_dist_shape[: len(new_dist_shape) - len(dist_shape)] 111 | 112 | # Now we construct and return the same distribution but setting 113 | # observed to None and setting a batch_size that matches the result of 114 | # broadcasting the observed and distribution shape 115 | batch_stack = extra_batch_stack + (dist.batch_stack if dist.batch_stack is not None else ()) 116 | if len(batch_stack) > 0: 117 | reinterpreted_batch_ndims = dist.reinterpreted_batch_ndims 118 | if dist.event_stack: 119 | reinterpreted_batch_ndims += len(extra_batch_stack) 120 | new_dist = type(dist)( 121 | name=dist.name, 122 | transform=dist.transform, 123 | observed=None, 124 | batch_stack=batch_stack, 125 | conditionally_independent=dist.conditionally_independent, 126 | event_stack=dist.event_stack, 127 | reinterpreted_batch_ndims=reinterpreted_batch_ndims, 128 | **dist.conditions, 129 | ) 130 | else: 131 | new_dist = type(dist)( 132 | name=dist.name, 133 | transform=dist.transform, 134 | observed=None, 135 | batch_stack=None, 136 | conditionally_independent=dist.conditionally_independent, 137 | event_stack=dist.event_stack, 138 | reinterpreted_batch_ndims=dist.reinterpreted_batch_ndims, 139 | **dist.conditions, 140 | ) 141 | return new_dist 142 | -------------------------------------------------------------------------------- /pymc4/flow/transformed_executor.py: -------------------------------------------------------------------------------- 1 | """Execute graph in a transformed state. 2 | 3 | Specifically, we wish to transform distributions whose support is bounded on 4 | one or both sides to distributions that are supported for all real numbers. 
5 | """ 6 | import functools 7 | 8 | 9 | from typing import Mapping, Any 10 | from pymc4 import scopes, distributions 11 | from pymc4.distributions import distribution 12 | from pymc4.distributions.transforms import JacobianPreference 13 | from pymc4.flow.executor import ( 14 | SamplingExecutor, 15 | EvaluationError, 16 | observed_value_in_evaluation, 17 | ModelType, 18 | SamplingState, 19 | ) 20 | 21 | 22 | class TransformedSamplingExecutor(SamplingExecutor): 23 | """Perform inference in an unconstrained space.""" 24 | 25 | def validate_state(self, state): 26 | """Validate that the model is not in a bad state.""" 27 | return 28 | 29 | def modify_distribution( 30 | self, dist: ModelType, model_info: Mapping[str, Any], state: SamplingState 31 | ) -> ModelType: 32 | """Apply transformations to a distribution.""" 33 | dist = super().modify_distribution(dist, model_info, state) 34 | if not isinstance(dist, distribution.Distribution): 35 | return dist 36 | 37 | return transform_dist_if_necessary(dist, state, allow_transformed_and_untransformed=True) 38 | 39 | 40 | def make_untransformed_model(dist, transform, state): 41 | # we gonna sample here, but logp should be computed for the transformed space 42 | # 0. as explained above we indicate we already performed autotransform 43 | dist.model_info["autotransformed"] = True 44 | # 1. sample a value, as we've checked there is no state provided 45 | # we need `dist.model_info["autotransformed"] = True` here not to get in a trouble 46 | # the return value is not yet user facing 47 | sampled_untransformed_value = yield dist 48 | sampled_transformed_value = transform.forward(sampled_untransformed_value) 49 | # already stored untransformed value via yield 50 | # state.values[scoped_name] = sampled_untransformed_value 51 | transformed_scoped_name = scopes.transformed_variable_name(transform.name, dist.name) 52 | state.transformed_values[transformed_scoped_name] = sampled_transformed_value 53 | # 2. increment the potential 54 | if transform.jacobian_preference == JacobianPreference.Forward: 55 | potential_fn = functools.partial( 56 | transform.forward_log_det_jacobian, sampled_untransformed_value 57 | ) 58 | coef = -1.0 59 | else: 60 | potential_fn = functools.partial( 61 | transform.inverse_log_det_jacobian, sampled_transformed_value 62 | ) 63 | coef = 1.0 64 | yield distributions.Potential(potential_fn, coef=coef) 65 | # 3. return value to the user 66 | return sampled_untransformed_value 67 | 68 | 69 | def make_transformed_model(dist, transform, state): 70 | # 1. now compute all the variables: in the transformed and untransformed space 71 | scoped_name = scopes.variable_name(dist.name) 72 | transformed_scoped_name = scopes.transformed_variable_name(transform.name, dist.name) 73 | state.untransformed_values[scoped_name] = transform.inverse( 74 | state.transformed_values[transformed_scoped_name] 75 | ) 76 | # disable sampling and save cached results to store for yield dist 77 | 78 | # once we are done with variables we can yield the value in untransformed space 79 | # to the user and also increment the potential 80 | 81 | # Important: 82 | # I have no idea yet, how to make that beautiful. 83 | # Here we indicate the distribution is already autotransformed not to get in the infinite loop 84 | dist.model_info["autotransformed"] = True 85 | 86 | # 2. 
here we decide on the logdet computation; this may be more efficient 87 | # with the transformed value than with the untransformed one 88 | # this information is stored in the transform.jacobian_preference class attribute 89 | # we postpone the computation of the logdet as it might have some overhead 90 | if transform.jacobian_preference == JacobianPreference.Forward: 91 | potential_fn = functools.partial( 92 | transform.forward_log_det_jacobian, state.untransformed_values[scoped_name] 93 | ) 94 | coef = -1.0 95 | else: 96 | potential_fn = functools.partial( 97 | transform.inverse_log_det_jacobian, 98 | state.transformed_values[transformed_scoped_name], 99 | ) 100 | coef = 1.0 101 | yield distributions.Potential(potential_fn, coef=coef) 102 | # 3. the final return+yield will return untransformed_value 103 | # as it is stored in state.values 104 | # Note: we need yield here to run further checks on name duplicates, etc. 105 | return (yield dist) 106 | 107 | 108 | def transform_dist_if_necessary(dist, state, *, allow_transformed_and_untransformed): 109 | if dist.transform is None or dist.model_info.get("autotransformed", False): 110 | return dist 111 | scoped_name = scopes.variable_name(dist.name) 112 | transform = dist.transform 113 | transformed_scoped_name = scopes.transformed_variable_name(transform.name, dist.name) 114 | if observed_value_in_evaluation(scoped_name, dist, state) is not None: 115 | # do not modify a distribution if it is observed 116 | # same for programmatically observed 117 | # but not for programmatically set to unobserved (when value is None) 118 | # but raise if we have a transformed value passed in the dict 119 | if transformed_scoped_name in state.transformed_values: 120 | raise EvaluationError( 121 | EvaluationError.OBSERVED_VARIABLE_IS_NOT_SUPPRESSED_BUT_ADDITIONAL_TRANSFORMED_VALUE_PASSED.format( 122 | scoped_name, transformed_scoped_name 123 | ) 124 | ) 125 | if scoped_name in state.untransformed_values: 126 | raise EvaluationError( 127 | EvaluationError.OBSERVED_VARIABLE_IS_NOT_SUPPRESSED_BUT_ADDITIONAL_VALUE_PASSED.format( 128 | scoped_name, scoped_name 129 | ) 130 | ) 131 | return dist 132 | 133 | if transformed_scoped_name in state.transformed_values: 134 | if (not allow_transformed_and_untransformed) and scoped_name in state.untransformed_values: 135 | state.untransformed_values.pop(scoped_name) 136 | return make_transformed_model(dist, transform, state) 137 | else: 138 | return make_untransformed_model(dist, transform, state) 139 | -------------------------------------------------------------------------------- /pymc4/gp/__init__.py: -------------------------------------------------------------------------------- 1 | from . import cov 2 | from . import mean 3 | from . import util 4 | from . import gp 5 | from .gp import * 6 | -------------------------------------------------------------------------------- /pymc4/gp/mean.py: -------------------------------------------------------------------------------- 1 | """Mean functions for PyMC4's Gaussian Process Module.""" 2 | 3 | from typing import Union 4 | 5 | import tensorflow as tf 6 | from tensorflow_probability.python.internal import dtype_util 7 | 8 | from .util import ArrayLike, TfTensor, _inherit_docs 9 | 10 | __all__ = ["Mean", "Zero", "Constant"] 11 | 12 | 13 | class Mean: 14 | r""" 15 | Base class for all the mean functions in GP. 16 | 17 | Parameters 18 | ---------- 19 | feature_ndims : int, optional 20 | The number of feature dimensions to be absorbed during 21 | the computation.
(default=1) 22 | """ 23 | 24 | def __init__(self, feature_ndims=1): 25 | self.feature_ndims = feature_ndims 26 | 27 | def __call__(self, X: ArrayLike) -> TfTensor: 28 | r""" 29 | Evaluate the mean function at a point. 30 | 31 | Parameters 32 | ---------- 33 | X : array_like 34 | Tensor or array of points at which to evaluate 35 | the mean function. 36 | 37 | Returns 38 | ------- 39 | mu : tensorflow.Tensor 40 | Mean evaluated at points ``X``. 41 | """ 42 | raise NotImplementedError("Your mean function should override this method.") 43 | 44 | def __add__(self, mean2): 45 | return MeanAdd(self, mean2) 46 | 47 | def __mul__(self, mean2): 48 | return MeanProd(self, mean2) 49 | 50 | 51 | class MeanAdd(Mean): 52 | r""" 53 | Addition of two or more mean functions. 54 | 55 | Parameters 56 | ---------- 57 | mean1 : Mean 58 | First mean function 59 | mean2 : Mean 60 | Second mean function 61 | """ 62 | 63 | def __init__(self, mean1: Mean, mean2: Mean): 64 | if mean1.feature_ndims != mean2.feature_ndims: 65 | raise ValueError("Cannot combine means with different feature_ndims.") 66 | self.mean1 = mean1 67 | self.mean2 = mean2 68 | 69 | @_inherit_docs(Mean.__call__) 70 | def __call__(self, X: ArrayLike) -> TfTensor: 71 | return self.mean1(X) + self.mean2(X) 72 | 73 | 74 | class MeanProd(Mean): 75 | r""" 76 | Product of two or more mean functions. 77 | 78 | Parameters 79 | ---------- 80 | mean1 : Mean 81 | First mean function 82 | mean2 : Mean 83 | Second mean function 84 | """ 85 | 86 | def __init__(self, mean1: Mean, mean2: Mean): 87 | if mean1.feature_ndims != mean2.feature_ndims: 88 | raise ValueError("Cannot combine means with different feature_ndims.") 89 | self.mean1 = mean1 90 | self.mean2 = mean2 91 | 92 | @_inherit_docs(Mean.__call__) 93 | def __call__(self, X: ArrayLike) -> TfTensor: 94 | return self.mean1(X) * self.mean2(X) 95 | 96 | 97 | class Zero(Mean): 98 | r""" 99 | Zero mean function. 100 | 101 | Parameters 102 | ---------- 103 | feature_ndims : int, optional 104 | number of rightmost dims to include in mean computation. (default=1) 105 | """ 106 | 107 | @_inherit_docs(Mean.__call__) 108 | def __call__(self, X: ArrayLike) -> TfTensor: 109 | dtype = dtype_util.common_dtype([X]) 110 | X = tf.convert_to_tensor(X, dtype=dtype) 111 | return tf.zeros(X.shape[: -self.feature_ndims], dtype=dtype) 112 | 113 | 114 | class Constant(Mean): 115 | r""" 116 | Constant mean function. 117 | 118 | Parameters 119 | ---------- 120 | coef : array_like, optional 121 | co-efficient to scale the mean. (default=1) 122 | feature_ndims : int, optional 123 | number of rightmost dims to include in mean computation. 
(default=1) 124 | """ 125 | 126 | def __init__(self, coef: Union[ArrayLike, float] = 1, feature_ndims: int = 1): 127 | self.coef = coef 128 | super().__init__(feature_ndims=feature_ndims) 129 | 130 | @_inherit_docs(Mean.__call__) 131 | def __call__(self, X: ArrayLike) -> TfTensor: 132 | dtype = dtype_util.common_dtype([X, self.coef]) 133 | X = tf.convert_to_tensor(X, dtype=dtype) 134 | return tf.ones(X.shape[: -self.feature_ndims], dtype=dtype) * self.coef 135 | -------------------------------------------------------------------------------- /pymc4/gp/util.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import warnings 3 | import re 4 | from typing import Union 5 | 6 | import tensorflow as tf 7 | import numpy as np 8 | 9 | 10 | ArrayLike = Union[np.ndarray, tf.Tensor] 11 | TfTensor = tf.Tensor 12 | FreeRV = ArrayLike 13 | 14 | 15 | def stabilize(K, shift=None): 16 | r"""Add a diagonal shift to a covariance matrix.""" 17 | K = tf.convert_to_tensor(K) 18 | diag = tf.linalg.diag_part(K) 19 | if shift is None: 20 | shift = 1e-6 if K.dtype == tf.float64 else 1e-4 21 | return tf.linalg.set_diag(K, diag + shift) 22 | 23 | 24 | def _inherit_docs(frommeth): 25 | r"""Decorate a method or class to inherit docs from `frommeth`.""" 26 | 27 | def inherit(tometh): 28 | methdocs = frommeth.__doc__ 29 | if methdocs is None: 30 | raise ValueError("No docs to inherit!") 31 | tometh.__doc__ = methdocs 32 | return tometh 33 | 34 | return inherit 35 | 36 | 37 | def _build_docs(meth_or_cls): 38 | r"""Decorate a method or class to build its doc strings.""" 39 | pattern = re.compile("\%\(.*\)") 40 | modname = inspect.getmodule(meth_or_cls) 41 | docs = meth_or_cls.__doc__ 42 | while pattern.search(docs) is not None: 43 | docname = pattern.search(docs).group(0)[2:-1] 44 | try: 45 | docstr = getattr(modname, docname) 46 | except AttributeError: 47 | warnings.warn( 48 | f"While documenting {meth_or_cls.__name__}, arrtibute {docname} not found.", 49 | SyntaxWarning, 50 | ) 51 | # FIXME: This should continue execution by skipping 52 | # the docs not found. Instead, currently, it just stops 53 | # execution! 54 | break 55 | docs = pattern.sub(docstr, docs, count=1) 56 | meth_or_cls.__doc__ = docs 57 | return meth_or_cls 58 | -------------------------------------------------------------------------------- /pymc4/inference/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
-------------------------------------------------------------------------------- /pymc4/inference/__init__.py: -------------------------------------------------------------------------------- 1 | from . import sampling 2 | -------------------------------------------------------------------------------- /pymc4/inference/sampling.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Dict, Any, List 2 | from pymc4.coroutine_model import Model 3 | from pymc4 import flow 4 | from pymc4.mcmc.samplers import reg_samplers, _log 5 | from pymc4.mcmc.utils import initialize_state, scope_remove_transformed_part_if_required 6 | import logging 7 | 8 | MYPY = False 9 | 10 | if not MYPY: 11 | logging._warn_preinit_stderr = 0 12 | 13 | 14 | def check_proposal_functions( 15 | model: Model, 16 | state: Optional[flow.SamplingState] = None, 17 | observed: Optional[dict] = None, 18 | ) -> bool: 19 | """ 20 | Check for non-default proposal generation functions 21 | 22 | Parameters 23 | ---------- 24 | model : pymc4.Model 25 | Model to sample posterior for 26 | state : Optional[flow.SamplingState] 27 | Current state 28 | observed : Optional[Dict[str, Any]] 29 | Observed values (optional) 30 | """ 31 | (_, state, _, _, continuous_distrs, discrete_distrs) = initialize_state( 32 | model, observed=observed, state=state 33 | ) 34 | init = state.all_unobserved_values 35 | init_state = list(init.values()) 36 | init_keys = list(init.keys()) 37 | 38 | for i, state_part in enumerate(init_state): 39 | untrs_var, unscoped_tr_var = scope_remove_transformed_part_if_required( 40 | init_keys[i], state.transformed_values 41 | ) 42 | # get the distribution for the random variable name 43 | distr = continuous_distrs.get(untrs_var, None) 44 | if distr is None: 45 | distr = discrete_distrs[untrs_var] 46 | func = distr._default_new_state_part 47 | if callable(func): 48 | return True 49 | return False 50 | 51 | 52 | def sample( 53 | model: Model, 54 | sampler_type: Optional[str] = None, 55 | num_samples: int = 1000, 56 | num_chains: int = 10, 57 | burn_in: int = 100, 58 | observed: Optional[Dict[str, Any]] = None, 59 | state: Optional[flow.SamplingState] = None, 60 | xla: bool = False, 61 | use_auto_batching: bool = True, 62 | sampler_methods: Optional[List] = None, 63 | trace_discrete: Optional[List[str]] = None, 64 | seed: Optional[int] = None, 65 | include_log_likelihood: bool = False, 66 | **kwargs, 67 | ): 68 | """ 69 | Perform MCMC sampling using NUTS (for now). 70 | 71 | Parameters 72 | ---------- 73 | model : pymc4.Model 74 | Model to sample posterior for 75 | sampler_type : Optional[str] 76 | The step method type for the model 77 | num_samples : int 78 | Num samples in a chain 79 | num_chains : int 80 | Num chains to run 81 | burn_in : int 82 | Length of burn-in period 83 | observed : Optional[Dict[str, Any]] 84 | New observed values (optional) 85 | state : Optional[pymc4.flow.SamplingState] 86 | Alternative way to specify initial values and observed values 87 | xla : bool 88 | Enable experimental XLA 89 | **kwargs: Dict[str, Any] 90 | All kwargs for kernel, adaptive_step_kernel, chain_sample method 91 | use_auto_batching : bool 92 | WARNING: This is an advanced user feature. If you are not sure how to use this, please use 93 | the default ``True`` value. 94 | If ``True``, the model's total ``log_prob`` will be automatically vectorized to work across 95 | multiple independent chains using ``tf.vectorized_map``. If ``False``, the model is assumed 96 | to be defined in a vectorized way.
This means that every distribution has the proper 97 | ``batch_shape`` and ``event_shape``s so that all the outputs from each distribution's 98 | ``log_prob`` will broadcast with each other, and that the forward passes through the model 99 | (prior and posterior predictive sampling) all work on values with any value of 100 | ``batch_shape``. Achieving this is a hard task, but it enables the model to be safely 101 | evaluated in parallel across all chains in MCMC, so sampling will be faster than in the 102 | automatically batched scenario. 103 | trace_discrete : Optional[List[str]] 104 | INFO: This is an advanced user feature. 105 | The Python list of variables that should be cast to tf.int32 after sampling is completed 106 | seed : Optional[int] 107 | A seed for reproducible sampling 108 | include_log_likelihood : bool, default=False 109 | Include log likelihood in trace 110 | Returns 111 | ------- 112 | Trace : InferenceDataType 113 | An ArviZ ``InferenceData`` object with the groups: posterior, sample_stats and observed_data 114 | Examples 115 | -------- 116 | Let's start with a simple model. We'll need some imports to experiment with it. 117 | >>> import pymc4 as pm 118 | >>> import numpy as np 119 | This particular model has a latent variable `sd` 120 | >>> @pm.model 121 | ... def nested_model(cond): 122 | ... sd = yield pm.HalfNormal("sd", 1.) 123 | ... norm = yield pm.Normal("n", cond, sd, observed=np.random.randn(10)) 124 | ... return norm 125 | Now, we may want to perform sampling from this model. We already observed some variables and we 126 | now need to fix the condition. 127 | >>> conditioned = nested_model(cond=2.) 128 | Passing ``cond=2.`` we condition our model for future evaluation. Now we go to sampling. 129 | Nothing special is required but passing the model to ``pm.sample``; the rest of the 130 | configuration is handled by PyMC4. 131 | >>> trace = sample(conditioned) 132 | Notes 133 | ----- 134 | One thing still under discussion is overriding observed variables. The API 135 | for that may look like 136 | >>> new_observed = {"nested_model/n": np.random.randn(10) + 1} 137 | >>> trace = sample(conditioned, observed=new_observed) 138 | This will give a trace with new observed variables. This way is considered to be explicit. 139 | """ 140 | # assign a sampler if no sampler_type is passed 141 | sampler_assigned: str = auto_assign_sampler(model, sampler_type) 142 | 143 | try: 144 | sampler = reg_samplers[sampler_assigned] 145 | except KeyError: 146 | _log.warning( 147 | "The given sampler doesn't exist.
Please choose samplers from: {}".format( 148 | list(reg_samplers.keys()) 149 | ) 150 | ) 151 | raise 152 | 153 | # TODO: keep num_adaptation_steps for nuts/hmc with 154 | # adaptive step but later should be removed because of ambiguity 155 | if any(x in sampler_assigned for x in ["nuts", "hmc"]): 156 | kwargs["num_adaptation_steps"] = burn_in 157 | 158 | sampler = sampler(model, **kwargs) 159 | 160 | # If some distributions in the model have non default proposal 161 | # generation functions then we lanuch compound step instead of rwm 162 | if sampler_assigned == "rwm": 163 | compound_required = check_proposal_functions(model, state=state, observed=observed) 164 | if compound_required: 165 | sampler_assigned = "compound" 166 | sampler = reg_samplers[sampler_assigned](model, **kwargs) 167 | 168 | if sampler_assigned == "compound": 169 | sampler._assign_default_methods( 170 | sampler_methods=sampler_methods, state=state, observed=observed 171 | ) 172 | 173 | return sampler( 174 | num_samples=num_samples, 175 | num_chains=num_chains, 176 | burn_in=burn_in, 177 | observed=observed, 178 | state=state, 179 | use_auto_batching=use_auto_batching, 180 | xla=xla, 181 | seed=seed, 182 | trace_discrete=trace_discrete, 183 | include_log_likelihood=include_log_likelihood, 184 | ) 185 | 186 | 187 | def auto_assign_sampler( 188 | model: Model, 189 | sampler_type: Optional[str] = None, 190 | ): 191 | """ 192 | The toy implementation of sampler assigner 193 | Parameters 194 | ---------- 195 | model : pymc4.Model 196 | Model to sample posterior for 197 | sampler_type : Optional[str] 198 | The step method type for the model 199 | Returns 200 | ------- 201 | sampler_type : str 202 | Sampler type name 203 | """ 204 | if sampler_type: 205 | _log.info("Working with {} sampler".format(reg_samplers[sampler_type].__name__)) 206 | return sampler_type 207 | 208 | _, _, free_disc_names, free_cont_names, _, _ = initialize_state(model) 209 | if not free_disc_names: 210 | _log.info("Auto-assigning NUTS sampler") 211 | return "nuts" 212 | else: 213 | _log.info("The model contains discrete distributions. " "\nCompound step is chosen.") 214 | return "compound" 215 | -------------------------------------------------------------------------------- /pymc4/mcmc/__init__.py: -------------------------------------------------------------------------------- 1 | from . import utils 2 | from . 
import samplers 3 | -------------------------------------------------------------------------------- /pymc4/mcmc/tf_support.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import collections 3 | import numpy as np 4 | 5 | import tensorflow as tf 6 | from tensorflow_probability.python.mcmc import kernel as kernel_base 7 | from tensorflow_probability.python.mcmc.internal import util as mcmc_util 8 | 9 | CompoundGibbsStepResults = collections.namedtuple("CompoundGibbsStepResults", ["compound_results"]) 10 | 11 | 12 | def _target_log_prob_fn_part_compound(*state_part, idx, len_, state, target_log_prob_fn): 13 | sl = slice(idx, idx + len_) 14 | temp_value = state[sl] 15 | state[sl] = state_part 16 | log_prob = target_log_prob_fn(*state) 17 | state[sl] = temp_value 18 | return log_prob 19 | 20 | 21 | def _target_log_prob_fn_part_gibbs(*state_part, idx, len_, state, target_log_prob_fn): 22 | sl = slice(idx, idx + len_) 23 | state[sl] = state_part 24 | log_prob = target_log_prob_fn(*state) 25 | return log_prob 26 | 27 | 28 | def kernel_create_object( 29 | sampleri, 30 | curr_indx, 31 | setli, 32 | current_state, 33 | target_log_prob_fn, 34 | target_log_prob_fn_part, 35 | ): 36 | mkf = sampleri[0] 37 | kernel = mkf.kernel( 38 | target_log_prob_fn=functools.partial( 39 | target_log_prob_fn_part, 40 | idx=curr_indx, 41 | len_=setli, 42 | state=current_state, 43 | target_log_prob_fn=target_log_prob_fn, 44 | ), 45 | **{**sampleri[1], **mkf.kernel_kwargs}, 46 | ) 47 | if mkf.adaptive_kernel: 48 | kernel = mkf.adaptive_kernel(inner_kernel=kernel, **mkf.adaptive_kwargs) 49 | return kernel 50 | 51 | 52 | class _CompoundGibbsStepTF(kernel_base.TransitionKernel): 53 | def __init__( 54 | self, 55 | target_log_prob_fn, 56 | compound_samplers, 57 | compound_set_lengths, 58 | name=None, 59 | ): 60 | """ 61 | Initializes the compound step transition kernel 62 | 63 | Args: 64 | target_log_prob_fn: Python callable which takes an argument like 65 | `current_state` (or `*current_state` if it's a list) and returns its 66 | (possibly unnormalized) log-density under the target distribution. 67 | compound_samplers: List of the pymc4.mcmc.samplers._BaseSampler sub-classes 68 | that are used for each subset of the free variables 69 | compound_set_lengths: List of the sizes of each subset of variables 70 | name: Python `str` name prefixed to Ops created by this function. 
71 | """ 72 | 73 | with tf.name_scope(name or "CompoundSampler") as name: 74 | self._target_log_prob_fn = target_log_prob_fn 75 | self._compound_samplers = [ 76 | (sampler[0]._default_kernel_maker(), sampler[1]) for sampler in compound_samplers 77 | ] 78 | self._compound_set_lengths = compound_set_lengths 79 | self._cumulative_lengths = np.cumsum(compound_set_lengths) - compound_set_lengths 80 | self._name = name 81 | self._parameters = dict(target_log_prob_fn=target_log_prob_fn, name=name) 82 | 83 | @property 84 | def target_log_prob_fn(self): 85 | return self._target_log_prob_fn 86 | 87 | @property 88 | def parameters(self): 89 | return self._parameters 90 | 91 | @property 92 | def name(self): 93 | return self._name 94 | 95 | @property 96 | def is_calibrated(self): 97 | return True 98 | 99 | def one_step(self, current_state, previous_kernel_results, seed=None): 100 | with tf.name_scope(mcmc_util.make_name(self.name, "compound", "one_step")): 101 | unwrap_state_list = not tf.nest.is_nested(current_state) 102 | if unwrap_state_list: 103 | current_state = [current_state] 104 | # TODO: can't use any efficient type of structure here 105 | # like tf.TensorArray..., `next_state` can have multiple tf.dtype and 106 | # `next_results` stores namedtuple with the set of tf.Tensor's with 107 | # multiple tf.dtype too. 108 | next_state = [] 109 | next_results = [] 110 | previous_kernel_results = previous_kernel_results.compound_results 111 | 112 | for sampleri, setli, resulti, curri in zip( 113 | self._compound_samplers, 114 | self._compound_set_lengths, 115 | previous_kernel_results, 116 | self._cumulative_lengths, 117 | ): 118 | kernel = self.kernel_create_object( 119 | sampleri, curri, setli, current_state, self._target_log_prob_fn 120 | ) 121 | next_state_, next_result_ = kernel.one_step( 122 | current_state[slice(curri, curri + setli)], resulti, seed=seed 123 | ) 124 | # concat state results for flattened list 125 | next_state += next_state_ 126 | # save current results 127 | next_results.append(next_result_) 128 | return [next_state, CompoundGibbsStepResults(compound_results=next_results)] 129 | 130 | def bootstrap_results(self, init_state): 131 | """ 132 | Returns an object with the same type as returned by `one_step(...)[1]` 133 | Compound bootrstrap step 134 | """ 135 | with tf.name_scope(mcmc_util.make_name(self.name, "compound", "bootstrap_results")): 136 | if not mcmc_util.is_list_like(init_state): 137 | init_state = [init_state] 138 | init_state = [tf.convert_to_tensor(x) for x in init_state] 139 | 140 | init_results = [] 141 | for sampleri, setli, curri in zip( 142 | self._compound_samplers, 143 | self._compound_set_lengths, 144 | self._cumulative_lengths, 145 | ): 146 | kernel = self.kernel_create_object( 147 | sampleri, curri, setli, init_state, self._target_log_prob_fn 148 | ) 149 | # bootstrap results in listj 150 | init_results.append( 151 | kernel.bootstrap_results(init_state[slice(curri, curri + setli)]) 152 | ) 153 | 154 | return CompoundGibbsStepResults(compound_results=init_results) 155 | 156 | 157 | class _CompoundStepTF(_CompoundGibbsStepTF): 158 | def __init__(self, *args, **kwargs): 159 | super().__init__(*args, **kwargs) 160 | self.kernel_create_object = functools.partial( 161 | kernel_create_object, 162 | target_log_prob_fn_part=_target_log_prob_fn_part_compound, 163 | ) 164 | 165 | 166 | class _GibbsStepTF(_CompoundGibbsStepTF): 167 | def __init__(self, *args, **kwargs): 168 | raise NotImplementedError 169 | 
-------------------------------------------------------------------------------- /pymc4/mcmc/utils.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import numpy as np 3 | import arviz as az 4 | from typing import Optional, Tuple, List, Dict, Any 5 | 6 | KERNEL_KWARGS_SET = collections.namedtuple( 7 | "KERNEL_ARGS_SET", ["kernel", "adaptive_kernel", "kernel_kwargs", "adaptive_kwargs"] 8 | ) 9 | 10 | from pymc4 import Model, flow 11 | from pymc4.distributions import distribution 12 | 13 | 14 | def initialize_sampling_state( 15 | model: Model, 16 | observed: Optional[dict] = None, 17 | state: Optional[flow.SamplingState] = None, 18 | ) -> Tuple[flow.SamplingState, List[str]]: 19 | """ 20 | Initialize the model with the provided state and/or observed variables. 21 | Parameters 22 | ---------- 23 | model : pymc4.Model 24 | observed : Optional[dict] 25 | state : Optional[flow.SamplingState] 26 | Returns 27 | ------- 28 | state: pymc4.flow.SamplingState 29 | The model's sampling state 30 | deterministic_names: List[str] 31 | The list of names of the model's deterministics_values 32 | """ 33 | _, state = flow.evaluate_meta_model(model, observed=observed, state=state) 34 | deterministic_names = list(state.deterministics_values) 35 | state, transformed_names = state.as_sampling_state() 36 | return state, deterministic_names + transformed_names 37 | 38 | 39 | def initialize_state( 40 | model: Model, 41 | observed: Optional[dict] = None, 42 | state: Optional[flow.SamplingState] = None, 43 | ) -> Tuple[ 44 | flow.SamplingState, 45 | flow.SamplingState, 46 | List[str], 47 | List[str], 48 | Dict[str, distribution.Distribution], 49 | Dict[str, distribution.Distribution], 50 | ]: 51 | """ 52 | Get the lists of discrete/continuous distributions 53 | Parameters 54 | ---------- 55 | model : pymc4.Model 56 | observed : Optional[dict] 57 | state : Optional[flow.SamplingState] 58 | Returns 59 | ------- 60 | state: flow.SamplingState 61 | The model's full (unsampled) evaluation state 62 | sampling_state: flow.SamplingState 63 | The model's sampling state 64 | free_discrete_names: List[str] 65 | The list of free discrete variables 66 | free_continuous_names: List[str] 67 | The list of free continuous variables 68 | cont_distrs: Dict[str, distribution.Distribution] 69 | The mapping of all continuous distributions by name 70 | disc_distrs: Dict[str, distribution.Distribution] 71 | The mapping of all discrete distributions by name 72 | """ 73 | _, state = flow.evaluate_model_transformed(model) 74 | free_discrete_names, free_continuous_names = ( 75 | list(state.discrete_distributions), 76 | list(state.continuous_distributions), 77 | ) 78 | observed_rvs = list(state.observed_values.keys()) 79 | free_discrete_names = list(filter(lambda x: x not in observed_rvs, free_discrete_names)) 80 | free_continuous_names = list(filter(lambda x: x not in observed_rvs, free_continuous_names)) 81 | sampling_state = None 82 | cont_distrs = state.continuous_distributions 83 | disc_distrs = state.discrete_distributions 84 | sampling_state, _ = state.as_sampling_state() 85 | return ( 86 | state, 87 | sampling_state, 88 | free_discrete_names, 89 | free_continuous_names, 90 | cont_distrs, 91 | disc_distrs, 92 | ) 93 | 94 | 95 | def trace_to_arviz( 96 | trace=None, 97 | sample_stats=None, 98 | observed_data=None, 99 | prior_predictive=None, 100 | posterior_predictive=None, 101 | log_likelihood=None, 102 | inplace=True, 103 | ): 104 | """ 105 | TensorFlow to ArviZ trace converter.
106 | Create an ArviZ ``InferenceData`` object with inference, prediction and/or sampling data 107 | generated by PyMC4. 108 | Parameters 109 | ---------- 110 | trace : dict or InferenceData 111 | sample_stats : dict 112 | observed_data : dict 113 | prior_predictive : dict 114 | posterior_predictive : dict 115 | log_likelihood : dict 116 | inplace : bool 117 | Returns 118 | ------- 119 | ArviZ InferenceData object 120 | """ 121 | if trace is not None and isinstance(trace, dict): 122 | trace = {k: np.swapaxes(v.numpy(), 1, 0) for k, v in trace.items() if "/" in k} 123 | if sample_stats is not None and isinstance(sample_stats, dict): 124 | sample_stats = {k: v.numpy().T for k, v in sample_stats.items()} 125 | if prior_predictive is not None and isinstance(prior_predictive, dict): 126 | prior_predictive = {k: v[np.newaxis] for k, v in prior_predictive.items()} 127 | if posterior_predictive is not None and isinstance(posterior_predictive, dict): 128 | if isinstance(trace, az.InferenceData) and inplace: 129 | return trace + az.from_dict(posterior_predictive=posterior_predictive) 130 | else: 131 | trace = None 132 | if log_likelihood is not None and isinstance(log_likelihood, dict): 133 | log_likelihood = { 134 | k: np.swapaxes(v.numpy(), 1, 0) for k, v in log_likelihood.items() if "/" in k 135 | } 136 | 137 | return az.from_dict( 138 | posterior=trace, 139 | sample_stats=sample_stats, 140 | prior_predictive=prior_predictive, 141 | posterior_predictive=posterior_predictive, 142 | log_likelihood=log_likelihood, 143 | observed_data=observed_data, 144 | ) 145 | 146 | 147 | def scope_remove_transformed_part_if_required(name: str, transformed_values: Dict[str, Any]): 148 | name_split = name.split("/") 149 | if transformed_values and name in transformed_values: 150 | name_split[-1] = name_split[-1][2:][name_split[-1][2:].find("_") + 1 :]  # strip the "__{transform}_" prefix 151 | return "/".join(name_split), name_split[-1] 152 | -------------------------------------------------------------------------------- /pymc4/plots/__init__.py: -------------------------------------------------------------------------------- 1 | from .gp_plots import plot_gp_dist 2 | -------------------------------------------------------------------------------- /pymc4/plots/gp_plots.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | 5 | def plot_gp_dist( 6 | ax, 7 | samples, 8 | x, 9 | plot_samples=True, 10 | palette="Reds", 11 | fill_alpha=0.8, 12 | samples_alpha=0.1, 13 | fill_kwargs=None, 14 | samples_kwargs=None, 15 | ): 16 | """ 17 | Plot 1D GP posteriors from trace. 18 | 19 | Parameters 20 | ---------- 21 | ax : axes 22 | Matplotlib axes. 23 | samples : trace or list of traces 24 | Trace(s) or posterior predictive sample from a GP. 25 | x : array 26 | Grid of X values corresponding to the samples. 27 | plot_samples: bool 28 | Whether to plot the GP samples along with the posterior (defaults to True). 29 | palette: str 30 | Palette for coloring output (defaults to "Reds"). 31 | fill_alpha : float 32 | Alpha value for the posterior interval fill (defaults to 0.8). 33 | samples_alpha : float 34 | Alpha value for the sample lines (defaults to 0.1). 35 | fill_kwargs : dict 36 | Additional arguments for the posterior interval fill (fill_between). 37 | samples_kwargs : dict 38 | Additional keyword arguments for the samples plot.
39 | 40 | Returns 41 | ------- 42 | ax : Matplotlib axes 43 | """ 44 | if fill_kwargs is None: 45 | fill_kwargs = {} 46 | if samples_kwargs is None: 47 | samples_kwargs = {} 48 | 49 | cmap = plt.get_cmap(palette) 50 | percs = np.linspace(51, 99, 40) 51 | colors = (percs - np.min(percs)) / (np.max(percs) - np.min(percs)) 52 | samples = samples.T 53 | x = x.flatten() 54 | for i, p in enumerate(percs[::-1]): 55 | upper = np.percentile(samples, p, axis=1) 56 | lower = np.percentile(samples, 100 - p, axis=1) 57 | color_val = colors[i] 58 | ax.fill_between(x, upper, lower, color=cmap(color_val), alpha=fill_alpha, **fill_kwargs) 59 | if plot_samples: 60 | # plot a few randomly chosen samples 61 | idx = np.random.randint(0, samples.shape[1], 30) 62 | ax.plot(x, samples[:, idx], color=cmap(0.9), lw=1, alpha=samples_alpha, **samples_kwargs) 63 | 64 | return ax 65 | -------------------------------------------------------------------------------- /pymc4/scopes.py: -------------------------------------------------------------------------------- 1 | """Module that defines scope classes to easily create name scopes or any other kind of scope.""" 2 | from typing import Optional, List, Any, Callable, Generator, Union 3 | import threading 4 | 5 | 6 | class Scope(object): 7 | """ 8 | General purpose variable scoping. 9 | 10 | PyMC4 scopes are intended to store useful information during the forward pass of the model. 11 | They are meant to provide more functionality than plain name scoping, 12 | so this class should be a starting point for further development. 13 | 14 | The class absorbs any keyword arguments passed to it. Accessing any attribute returns 15 | either None or the value passed for that keyword. :func:`Scope.chain` will return all 16 | attributes for the context, starting from the first one (the deepest one is the last one). 17 | 18 | Examples 19 | -------- 20 | >>> with Scope(var=1): 21 | ... with Scope(var=3): 22 | ... print(list(Scope.chain("var"))) 23 | [1, 3] 24 | """ 25 | 26 | _leaf = object() 27 | context = threading.local() 28 | 29 | def __init__(self, **kwargs): 30 | self.__dict__.update(kwargs) 31 | 32 | def __enter__(self): 33 | """Enter a new scope context appending it to the ``Scope.context.stack``.""" 34 | type(self).get_contexts().append(self) 35 | return self 36 | 37 | def __exit__(self, typ, value, traceback): 38 | """Remove the last scope context from the ``Scope.context.stack``.""" 39 | type(self).get_contexts().pop() 40 | 41 | def __getattr__(self, item): 42 | """Get a ``Scope`` instance's attribute ``item``.""" 43 | return self.__dict__.get(item) 44 | 45 | @classmethod 46 | def get_contexts(cls) -> List: 47 | """Get the ``Scope`` class's context stack list. 48 | 49 | Returns 50 | ------- 51 | contexts : List 52 | If this is called from outside of a ``with Scope(...)`` statement, 53 | an empty list is returned. In any other case, it returns the 54 | nested ``list`` of ``Scope`` instances that make up the current context stack. 55 | 56 | """ 57 | # no race-condition here, cls.context is a thread-local object 58 | # be sure not to override contexts in a subclass however! 59 | if not hasattr(cls.context, "stack"): 60 | cls.context.stack = [] 61 | return cls.context.stack 62 | 63 | @classmethod 64 | def chain( 65 | cls, 66 | attr: str, 67 | *, 68 | predicate: Callable[[Any], bool] = lambda _: True, 69 | drop_none: bool = False, 70 | leaf: Any = _leaf, 71 | ) -> Generator[Any, None, None]: 72 | """Yield all the values of a scoped attribute starting from the first to the deepest.
73 | 74 | Each ``Scope`` context manager can be used to add any given attribute's value to the context. 75 | This method explores the entire context stack and iterates through the values defined for a 76 | given attribute. It goes through the context stack from the first defined open context 77 | (the outermost scope) to the last ``Scope`` that was entered (the innermost scope). 78 | 79 | Parameters 80 | ---------- 81 | cls : Scope 82 | The ``Scope`` subclass. 83 | attr : str 84 | The name of the ``Scope`` attribute to get. 85 | predicate : Callable[[Any], bool] 86 | A function used to filter scope instances encountered in the context stack. Its signature 87 | must take a single input argument and return ``True`` or ``False``. If it returns 88 | ``True``, the ``Scope`` will be processed further. This means that the ``attr`` 89 | attribute's value will be read from the ``Scope`` instance, and said value will be yielded 90 | (depending on ``drop_none``). If ``False``, the encountered scope will be skipped. 91 | By default, all encountered ``Scope`` instances are accepted for further processing. 92 | drop_none : bool 93 | If ``True`` and the ``attr`` value that is retrieved is ``None``, it is skipped. 94 | If ``False``, ``None`` will be yielded. 95 | leaf : Any 96 | A value to yield after having iterated through the entire context stack. By default, 97 | no value is yielded. 98 | 99 | Yields 100 | ------ 101 | Any 102 | The values of the attribute ``attr`` that are defined in the context stack and, 103 | optionally, the ``leaf`` value. 104 | 105 | Example 106 | ------- 107 | If we nest several ``Scope`` instances, ``Scope.chain`` can iterate through the 108 | context stack looking for an attribute's value. 109 | 110 | >>> with Scope(var=1): 111 | ... with Scope(var=3): 112 | ... print(list(Scope.chain("var"))) 113 | [1, 3] 114 | 115 | If one of the nested ``Scope`` instances doesn't define an attribute's value or defines 116 | it as ``None``, it can be excluded from the yielded values with ``drop_none=True``. 117 | 118 | >>> with Scope(var=1, name="A"): 119 | ... with Scope(var=3): 120 | ... print(list(Scope.chain("name", drop_none=True))) 121 | ['A'] 122 | 123 | If we provide a ``leaf`` value, it will be yielded last, as long as it isn't ``None`` 124 | while ``drop_none=True`` is passed at the same time. 125 | 126 | >>> with Scope(var=1, name="A"): 127 | ... with Scope(var=3): 128 | ... print(list(Scope.chain("name", leaf="leaf", drop_none=True))) 129 | ['A', 'leaf'] 130 | 131 | """ 132 | for c in cls.get_contexts(): 133 | if predicate(c): 134 | val = getattr(c, attr) 135 | if drop_none and val is None: 136 | continue 137 | else: 138 | yield val 139 | if leaf is not cls._leaf: 140 | if not (drop_none and leaf is None): 141 | yield leaf 142 | 143 | @classmethod 144 | def variable_name(cls, name: Optional[str]) -> Optional[str]: 145 | """ 146 | Generate a PyMC4 variable name based on the name scope we are currently in. 147 | 148 | Parameters 149 | ---------- 150 | name : Union[str, None] 151 | The desired target name for a variable. If ``None``, it will simply 152 | return the chained scope's ``name`` attribute. 153 | 154 | Returns 155 | ------- 156 | scoped_name : Union[str, None] 157 | If ``name`` is ``None`` and no scope defines the ``name`` attribute, this 158 | function returns ``None``. 159 | 160 | Examples 161 | -------- 162 | >>> with Scope(name="inner"): 163 | ... print(Scope.variable_name("leaf")) 164 | inner/leaf 165 | >>> with Scope(name="inner"): 166 | ... with Scope(): 167 | ...
print(Scope.variable_name("leaf1")) 168 | inner/leaf1 169 | 170 | an empty name results in a ``None`` name 171 | >>> assert Scope.variable_name(None) is None 172 | >>> assert Scope.variable_name("") is None 173 | """ 174 | value = "/".join(map(str, cls.chain("name", leaf=name, drop_none=True))) 175 | if not value: 176 | return None 177 | else: 178 | return value 179 | 180 | @classmethod 181 | def transformed_variable_name(cls, transform_name: str, name: str) -> Optional[str]: 182 | """ 183 | Generate a PyMC4 transformed variable name based on the name scope we are currently in. 184 | 185 | Parameters 186 | ---------- 187 | transform_name : str 188 | The name of the transformation. 189 | name : str 190 | The plain name of the variable. 191 | 192 | Returns 193 | ------- 194 | str : scoped name 195 | This is equivalent to calling :meth:`~.variable_name` with the input 196 | ``"__{transform_name}_{name}"``. 197 | 198 | Examples 199 | -------- 200 | >>> with Scope(name="inner"): 201 | ... print(Scope.transformed_variable_name("log", "leaf")) 202 | inner/__log_leaf 203 | >>> with Scope(name="inner"): 204 | ... with Scope(): 205 | ... print(Scope.transformed_variable_name("log", "leaf1")) 206 | inner/__log_leaf1 207 | 208 | outside of any scope, only the transform prefix is applied 209 | >>> assert Scope.transformed_variable_name("log", "x") == "__log_x" 210 | >>> assert Scope.variable_name(None) is None 211 | """ 212 | return cls.variable_name("__{}_{}".format(transform_name, name)) 213 | 214 | def __repr__(self) -> str: 215 | """Return the string representation of a ``Scope`` instance. 216 | 217 | Returns 218 | ------- 219 | str: 220 | Returns ``"Scope({self.__dict__})"``. 221 | """ 222 | return "Scope({})".format(self.__dict__) 223 | 224 | 225 | def name_scope(name: Union[str, None]) -> Scope: 226 | """Create a :class:`~.Scope` instance with a ``name`` attribute set to the provided ``name``. 227 | 228 | Parameters 229 | ---------- 230 | name : Union[str, None] 231 | The value that will be set to the ``Scope.name`` attribute. 232 | 233 | Returns 234 | ------- 235 | scope : Scope 236 | A scope instance that only defines the ``name`` attribute. 237 | """ 238 | return Scope(name=name) 239 | 240 | 241 | variable_name = Scope.variable_name 242 | transformed_variable_name = Scope.transformed_variable_name 243 | -------------------------------------------------------------------------------- /pymc4/utils.py: -------------------------------------------------------------------------------- 1 | """Miscellaneous utility functions.""" 2 | import functools 3 | import re 4 | from typing import Callable, Sequence, Optional 5 | import io 6 | import pkgutil 7 | import os 8 | 9 | 10 | def biwrap(wrapper) -> Callable: # noqa 11 | """Allow for optional keyword arguments in lower level decorators. 12 | 13 | Notes 14 | ----- 15 | Currently this is only used to wrap pm.Model to capture model runtime flags such as 16 | keep_auxiliary and keep_return.
See pm.Model for all possible keyword parameters. 17 | 18 | """ 19 | 20 | @functools.wraps(wrapper) 21 | def enhanced(*args, **kwargs) -> Callable: 22 | 23 | # Check if decorated method is bound to a class 24 | is_bound_method = hasattr(args[0], wrapper.__name__) if args else False 25 | if is_bound_method: 26 | # If bound to a class, `self` will be an argument 27 | count = 1 28 | else: 29 | count = 0 30 | if len(args) > count: 31 | # If the lower level decorator is not called, the user model will be an argument: 32 | # fill in parameters and call pm.Model 33 | newfn = wrapper(*args, **kwargs) 34 | return newfn 35 | else: 36 | # If the lower level decorator is called, the user model will not be passed in as an argument: 37 | # prefill args and kwargs but do not call pm.Model 38 | newwrapper = functools.partial(wrapper, *args, **kwargs) 39 | return newwrapper 40 | 41 | return enhanced 42 | 43 | 44 | class NameParts: 45 | """Class that stores names segmented into their three parts. 46 | 47 | A given name is made up by three parts: 48 | 1. ``path``: usually represents the context or scope under which a name 49 | was defined. For example, a distribution inside a model will have its 50 | name's path equal to the model's full name. To conveniently store nested 51 | contexts, the path is usually stored as a tuple of strings. 52 | 2. ``transform``: the name of the transformation that is applied to a 53 | distribution (can be an empty string, meaning no transformation). 54 | 3. ``untransformed_name``: the name, stripped from its path and transform. 55 | This represents, for example, a distribution's plain name, without 56 | its model's scope or any transformation name. 57 | """ 58 | 59 | NAME_RE = re.compile(r"^(?:__(?P<transform>[^_]+)_)?(?P<name>[^_].*)$") 60 | NAME_ERROR_MESSAGE = ( 61 | "Invalid name: `{}`, the correct one should look like: `__transform_name` or `name`, " 62 | "note only one underscore between the transform and actual name" 63 | ) 64 | UNTRANSFORMED_NAME_ERROR_MESSAGE = ( 65 | "Invalid name: `{}`, the correct one should look like: " "`name` without leading underscore" 66 | ) 67 | __slots__ = ("path", "transform_name", "untransformed_name") 68 | 69 | @classmethod 70 | def is_valid_untransformed_name(cls, name: str) -> bool: 71 | """Test if a name can be used as an untransformed random variable. 72 | 73 | This function attempts to test whether the supplied name, by accident, 74 | matches the naming pattern used for auto transformed random variables. 75 | If it does not, it is assumed to be a potentially valid name. 76 | 77 | Parameters 78 | ---------- 79 | name : str 80 | The name to test. 81 | 82 | Returns 83 | ------- 84 | bool 85 | ``False`` if the ``name`` matches the pattern used to give names 86 | to auto transformed variables. ``True`` otherwise. 87 | 88 | """ 89 | match = cls.NAME_RE.match(name) 90 | return match is not None and match["transform"] is None 91 | 92 | @classmethod 93 | def is_valid_name(cls, name: str) -> bool: 94 | """Test if a name doesn't contain forbidden symbols. 95 | 96 | Parameters 97 | ---------- 98 | name : str 99 | The name to test. 100 | 101 | Returns 102 | ------- 103 | bool 104 | ``True`` if the ``name`` doesn't have forbidden symbols, ``False`` 105 | otherwise. 106 | 107 | """ 108 | match = cls.NAME_RE.match(name) 109 | return match is not None 110 | 111 | def __init__( 112 | self, 113 | path: Sequence[str], 114 | transform_name: Optional[str], 115 | untransformed_name: str, 116 | ): 117 | """Initialize a ``NameParts`` instance from its parts.
118 | 119 | Parameters 120 | ---------- 121 | path : Sequence[str] 122 | The path part of the name. This is a sequence of 123 | strings, each indicating a deeper layer in the path hierarchy. 124 | transform_name : Optional[str] 125 | The name of the applied transformation. ``None`` means no 126 | transformation was applied. 127 | untransformed_name : str 128 | The plain part of the name. 129 | 130 | """ 131 | self.path = tuple(path) 132 | self.untransformed_name = untransformed_name 133 | self.transform_name = transform_name 134 | 135 | @classmethod 136 | def from_name(cls, name: str) -> "NameParts": 137 | """Split a provided name into its parts and return them as ``NameParts``. 138 | 139 | Parameters 140 | ---------- 141 | name : str 142 | The name that must be segmented into parts. 143 | 144 | Raises 145 | ------ 146 | ValueError 147 | If the provided name is not valid. 148 | 149 | Returns 150 | ------- 151 | NameParts 152 | The parts of the provided ``name`` are used to construct a 153 | ``NameParts`` instance which is returned. 154 | 155 | """ 156 | split = name.split("/") 157 | path, original_name = split[:-1], split[-1] 158 | match = cls.NAME_RE.match(original_name) 159 | if not cls.is_valid_name(name): 160 | raise ValueError(cls.NAME_ERROR_MESSAGE.format(name)) 161 | return cls(path, match["transform"], match["name"]) # type: ignore 162 | 163 | @property 164 | def original_name(self) -> str: 165 | """Return the name of the distribution without its preceding path. 166 | 167 | Returns 168 | ------- 169 | str 170 | The original name. This will include the transform and the 171 | untransformed parts of the name. 172 | """ 173 | if self.is_transformed: 174 | return "__{}_{}".format(self.transform_name, self.untransformed_name) 175 | else: 176 | return self.untransformed_name 177 | 178 | @property 179 | def full_original_name(self) -> str: 180 | """Return the full name of the distribution with all three parts. 181 | 182 | Returns 183 | ------- 184 | str 185 | The full name. This will include the path, transform and the 186 | untransformed parts of the name. 187 | """ 188 | return "/".join(self.path + (self.original_name,)) 189 | 190 | @property 191 | def full_untransformed_name(self) -> str: 192 | """Return the name of the distribution without its transform part. 193 | 194 | Returns 195 | ------- 196 | str 197 | The path and the untransformed_name joined by a slash. 198 | """ 199 | return "/".join(self.path + (self.untransformed_name,)) 200 | 201 | @property 202 | def is_transformed(self) -> bool: 203 | """Return ``True`` if the ``transform`` part of the name is not ``None``.""" 204 | return self.transform_name is not None 205 | 206 | def __repr__(self) -> str: 207 | """Return the ``NameParts`` ``full_original_name`` string representation.""" 208 | return "<NameParts of {}>".format(self.full_original_name) 209 | 210 | def replace_transform(self, transform_name): 211 | """Replace the transform part of the name and return a new NameParts instance.""" 212 | return self.__class__(self.path, transform_name, self.untransformed_name) 213 | 214 | 215 | def get_data(filename): 216 | """Return a BytesIO object for a package data file.
217 | 218 | Parameters 219 | ---------- 220 | filename : str 221 | file to load 222 | Returns 223 | ------- 224 | BytesIO of the data 225 | """ 226 | data_pkg = "notebooks" 227 | return io.BytesIO(pkgutil.get_data(data_pkg, os.path.join("data", filename))) 228 | -------------------------------------------------------------------------------- /pymc4/variational/__init__.py: -------------------------------------------------------------------------------- 1 | """Tools for Variational Inference.""" 2 | from .approximations import * 3 | from .updates import * 4 | -------------------------------------------------------------------------------- /pymc4/variational/approximations.py: -------------------------------------------------------------------------------- 1 | """Implements ADVI approximations.""" 2 | from typing import Optional, Union 3 | from collections import namedtuple 4 | 5 | import arviz as az 6 | import numpy as np 7 | import tensorflow as tf 8 | import tensorflow_probability as tfp 9 | from tensorflow_probability.python.internal import dtype_util 10 | 11 | from pymc4 import flow 12 | from pymc4.coroutine_model import Model 13 | from pymc4.mcmc.utils import initialize_sampling_state 14 | from pymc4.mcmc.samplers import calculate_log_likelihood 15 | from pymc4.utils import NameParts 16 | from pymc4.variational import updates 17 | from pymc4.variational.util import ArrayOrdering 18 | 19 | tfd = tfp.distributions 20 | tfb = tfp.bijectors 21 | ADVIFit = namedtuple("ADVIFit", "approximation, losses") 22 | 23 | 24 | class Approximation(tf.Module): 25 | """Base Approximation class.""" 26 | 27 | def __init__(self, model: Optional[Model] = None, random_seed: Optional[int] = None): 28 | if not isinstance(model, Model): 29 | raise TypeError( 30 | "`fit` function only supports `pymc4.Model` objects, but you've passed `{}`".format( 31 | type(model) 32 | ) 33 | ) 34 | 35 | self.model = model 36 | self._seed = random_seed 37 | self.state, self.deterministic_names = initialize_sampling_state(model) 38 | if not self.state.all_unobserved_values: 39 | raise ValueError( 40 | f"Can not calculate a log probability: the model {model.name or ''} has no unobserved values." 
" 41 | ) 42 | 43 | self.order = ArrayOrdering(self.state.all_unobserved_values) 44 | self.unobserved_keys = self.state.all_unobserved_values.keys() 45 | ( 46 | self.target_log_prob, 47 | self.deterministics_callback, 48 | ) = self._build_logp_and_deterministic_fn() 49 | self.approx = self._build_posterior() 50 | 51 | def _build_logp_and_deterministic_fn(self): 52 | """Build vectorized logp and deterministic functions.""" 53 | 54 | @tf.function(autograph=False) 55 | def logpfn(*values): 56 | split_view = self.order.split(values[0]) 57 | _, st = flow.evaluate_meta_model(self.model, values=split_view) 58 | return st.collect_log_prob() 59 | 60 | @tf.function(autograph=False) 61 | def deterministics_callback(q_samples): 62 | st = flow.SamplingState.from_values( 63 | q_samples, observed_values=self.state.observed_values 64 | ) 65 | _, st = flow.evaluate_model_transformed(self.model, state=st) 66 | for transformed_name in st.transformed_values: 67 | untransformed_name = NameParts.from_name(transformed_name).full_untransformed_name 68 | st.deterministics_values[untransformed_name] = st.untransformed_values.pop( 69 | untransformed_name 70 | ) 71 | return st.deterministics_values 72 | 73 | def vectorize_function(function): 74 | def vectorizedfn(*q_samples): 75 | return tf.vectorized_map(lambda samples: function(*samples), q_samples) 76 | 77 | return vectorizedfn 78 | 79 | return vectorize_function(logpfn), vectorize_function(deterministics_callback) 80 | 81 | def _build_posterior(self): 82 | raise NotImplementedError 83 | 84 | def sample(self, n: int = 500, include_log_likelihood: bool = False) -> az.InferenceData: 85 | """Generate samples from the posterior distribution.""" 86 | samples = self.approx.sample(n) 87 | q_samples = self.order.split_samples(samples, n) 88 | q_samples = dict(**q_samples, **self.deterministics_callback(q_samples)) 89 | 90 | # Add a new axis so that n_chains=1 for InferenceData: handles shape issues 91 | trace = {k: v.numpy()[np.newaxis] for k, v in q_samples.items()} 92 | log_likelihood_dict = dict() 93 | if include_log_likelihood: 94 | log_likelihood_dict = calculate_log_likelihood(self.model, trace, self.state) 95 | 96 | trace = az.from_dict( 97 | trace, 98 | observed_data=self.state.observed_values, 99 | log_likelihood=log_likelihood_dict if include_log_likelihood else None, 100 | ) 101 | return trace 102 | 103 | 104 | class MeanField(Approximation): 105 | """ 106 | Mean Field ADVI. 107 | 108 | This class implements Mean Field Automatic Differentiation Variational Inference. It posits a spherical 109 | Gaussian family to fit the posterior and assumes the parameters to be uncorrelated. 110 | 111 | References 112 | ---------- 113 | - Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., 114 | and Blei, D. M. (2016). Automatic Differentiation Variational 115 | Inference. arXiv preprint arXiv:1603.00788. 116 | """ 117 | 118 | def _build_posterior(self): 119 | flattened_shape = self.order.size 120 | dtype = dtype_util.common_dtype( 121 | self.state.all_unobserved_values.values(), dtype_hint=tf.float64 122 | ) 123 | loc = tf.Variable(tf.random.normal([flattened_shape], dtype=dtype), name="mu") 124 | cov_param = tfp.util.TransformedVariable( 125 | tf.ones(flattened_shape, dtype=dtype), tfb.Softplus(), name="sigma" 126 | ) 127 | advi_approx = tfd.MultivariateNormalDiag(loc=loc, scale_diag=cov_param) 128 | return advi_approx 129 | 130 | 131 | class FullRank(Approximation): 132 | """ 133 | Full Rank ADVI.
134 | 135 | This class implements Full Rank Automatic Differentiation Variational Inference. It posits a multivariate 136 | Gaussian family to fit the posterior and estimates a full covariance matrix. As a result, it comes with 137 | higher computation costs. 138 | 139 | References 140 | ---------- 141 | - Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., 142 | and Blei, D. M. (2016). Automatic Differentiation Variational 143 | Inference. arXiv preprint arXiv:1603.00788. 144 | """ 145 | 146 | def _build_posterior(self): 147 | flattened_shape = self.order.size 148 | dtype = dtype_util.common_dtype( 149 | self.state.all_unobserved_values.values(), dtype_hint=tf.float64 150 | ) 151 | loc = tf.Variable(tf.random.normal([flattened_shape], dtype=dtype), name="mu") 152 | scale_tril = tfb.FillScaleTriL( 153 | diag_bijector=tfb.Chain( 154 | [ 155 | tfb.Shift(tf.cast(1e-3, dtype)), # diagonal offset 156 | tfb.Softplus(), 157 | tfb.Shift(tf.cast(np.log(np.expm1(1.0)), dtype)), # initial scale 158 | ] 159 | ), 160 | diag_shift=None, 161 | ) 162 | 163 | cov_matrix = tfp.util.TransformedVariable( 164 | tf.eye(flattened_shape, dtype=dtype), scale_tril, name="sigma" 165 | ) 166 | return tfd.MultivariateNormalTriL(loc=loc, scale_tril=cov_matrix) 167 | 168 | 169 | class LowRank(Approximation): 170 | """Low Rank Automatic Differentiation Variational Inference (Low Rank ADVI).""" 171 | 172 | def _build_posterior(self): 173 | raise NotImplementedError 174 | 175 | 176 | def fit( 177 | model: Optional[Model] = None, 178 | method: Union[str, MeanField, FullRank] = "advi", 179 | num_steps: int = 10000, 180 | sample_size: int = 1, 181 | random_seed: Optional[int] = None, 182 | optimizer=None, 183 | **kwargs, 184 | ): 185 | """ 186 | Fit an approximating distribution to the log_prob of the model. 187 | 188 | Parameters 189 | ---------- 190 | model : Optional[:class:`Model`] 191 | Model to fit the posterior against 192 | method : Union[str, :class:`Approximation`] 193 | Method to fit the model using VI 194 | 195 | - 'advi' for :class:`MeanField` 196 | - 'fullrank_advi' for :class:`FullRank` 197 | - 'lowrank_advi' for :class:`LowRank` 198 | - or directly pass in an :class:`Approximation` instance 199 | num_steps : int 200 | Number of iterations to run the optimizer 201 | sample_size : int 202 | Number of Monte Carlo samples used for the approximation 203 | random_seed : Optional[int] 204 | Seed for the tensorflow random number generator 205 | optimizer : TF1-style | TF2-style | from pymc4/variational/updates 206 | TensorFlow optimizer to use 207 | kwargs : Optional[Dict[str, Any]] 208 | Pass extra non-default arguments to 209 | ``tensorflow_probability.vi.fit_surrogate_posterior`` 210 | 211 | Returns 212 | ------- 213 | ADVIFit : collections.namedtuple 214 | Named tuple including the approximation and the ELBO losses (depending on the ``trace_fn``) 215 | """ 216 | _select = dict(advi=MeanField, fullrank_advi=FullRank) 217 | 218 | if isinstance(method, str): 219 | # Here we assume that the `model` parameter is provided by the user. 220 | try: 221 | inference = _select[method.lower()](model, random_seed) 222 | except KeyError: 223 | raise KeyError( 224 | "method should be one of %s or Approximation instance" % set(_select.keys()) 225 | ) 226 | 227 | elif isinstance(method, Approximation): 228 | # Here we assume that the `model` parameter is not provided by the user 229 | # as the :class:`Approximation` itself contains a :class:`Model` instance.
230 | inference = method 231 | 232 | else: 233 | raise TypeError( 234 | "method should be one of %s or Approximation instance" % set(_select.keys()) 235 | ) 236 | 237 | # Defining `opt = optimizer or updates.adam()` 238 | # leads to optimizer initialization issues from tf. 239 | if optimizer: 240 | opt = optimizer 241 | else: 242 | opt = updates.adam() 243 | 244 | @tf.function(autograph=False) 245 | def run_approximation(): 246 | losses = tfp.vi.fit_surrogate_posterior( 247 | target_log_prob_fn=inference.target_log_prob, 248 | surrogate_posterior=inference.approx, 249 | num_steps=num_steps, 250 | sample_size=sample_size, 251 | seed=random_seed, 252 | optimizer=opt, 253 | **kwargs, 254 | ) 255 | return losses 256 | 257 | return ADVIFit(inference, run_approximation()) 258 | -------------------------------------------------------------------------------- /pymc4/variational/updates.py: -------------------------------------------------------------------------------- 1 | """Optimizers for ELBO convergence. 2 | 3 | These optimizers wrap tf.optimizers with defaults from PyMC3. 4 | """ 5 | import tensorflow as tf 6 | 7 | 8 | def adadelta( 9 | learning_rate: float = 1.0, rho: float = 0.95, epsilon: float = 1e-6, **kwargs 10 | ) -> tf.optimizers.Adadelta: 11 | r"""Adadelta optimizer. 12 | 13 | Parameters 14 | ---------- 15 | learning_rate : float 16 | Learning rate 17 | rho : float 18 | Squared gradient moving average decay factor 19 | epsilon : float 20 | Small value added for numerical stability 21 | 22 | Returns 23 | ------- 24 | tf.optimizers.Adadelta 25 | 26 | Notes 27 | ----- 28 | rho should be between 0 and 1. A value of rho close to 1 will decay the 29 | moving average slowly and a value close to 0 will decay the moving average 30 | fast. 31 | 32 | rho = 0.95 and epsilon=1e-6 are suggested in the paper and reported to 33 | work for multiple datasets (MNIST, speech). 34 | 35 | In the paper, no learning rate is considered (so learning_rate=1.0). 36 | Probably best to keep it at this value. 37 | epsilon is important for the very first update (so the numerator does 38 | not become 0). 39 | 40 | Using the step size eta and a decay factor rho the learning rate is 41 | calculated as: 42 | 43 | .. math:: 44 | r_t &= \rho r_{t-1} + (1-\rho)*g^2\\ 45 | \eta_t &= \eta \frac{\sqrt{s_{t-1} + \epsilon}} 46 | {\sqrt{r_t + \epsilon}}\\ 47 | s_t &= \rho s_{t-1} + (1-\rho)*(\eta_t*g)^2 48 | 49 | References 50 | ---------- 51 | .. [1] Zeiler, M. D. (2012): 52 | ADADELTA: An Adaptive Learning Rate Method. 53 | arXiv Preprint arXiv:1212.5701. 54 | """ 55 | return tf.optimizers.Adadelta(learning_rate=learning_rate, rho=rho, epsilon=epsilon, **kwargs) 56 | 57 | 58 | def adagrad(learning_rate: float = 1.0, epsilon: float = 1e-6, **kwargs) -> tf.optimizers.Adagrad: 59 | r"""Adagrad optimizer. 60 | 61 | Parameters 62 | ---------- 63 | learning_rate : float or symbolic scalar 64 | Learning rate 65 | epsilon : float or symbolic scalar 66 | Small value added for numerical stability 67 | 68 | Returns 69 | ------- 70 | tf.optimizers.Adagrad 71 | 72 | Notes 73 | ----- 74 | Using step size eta Adagrad calculates the learning rate for feature i at 75 | time step t as: 76 | 77 | .. math:: \eta_{t,i} = \frac{\eta} 78 | {\sqrt{\sum^t_{t^\prime} g^2_{t^\prime,i}+\epsilon}} g_{t,i} 79 | 80 | as such the learning rate is monotonically decreasing. 81 | 82 | Epsilon is not included in the typical formula, see [2]_. 83 | 84 | References 85 | ---------- 86 | .. [1] Duchi, J., Hazan, E., & Singer, Y. 
(2011): 87 | Adaptive subgradient methods for online learning and stochastic 88 | optimization. JMLR, 12:2121-2159. 89 | 90 | .. [2] Chris Dyer: 91 | Notes on AdaGrad. http://www.ark.cs.cmu.edu/cdyer/adagrad.pdf 92 | """ 93 | return tf.optimizers.Adagrad(learning_rate=learning_rate, epsilon=epsilon, **kwargs) 94 | 95 | 96 | def adam( 97 | learning_rate: float = 0.001, 98 | beta_1: float = 0.9, 99 | beta_2: float = 0.999, 100 | epsilon: float = 1e-8, 101 | **kwargs, 102 | ) -> tf.optimizers.Adam: 103 | """Adam optimizer. 104 | 105 | Parameters 106 | ---------- 107 | learning_rate : float 108 | Learning rate 109 | beta_1 : float 110 | Exponential decay rate for the first moment estimates 111 | beta_2 : float 112 | Exponential decay rate for the second moment estimates 113 | epsilon : float 114 | Constant for numerical stability 115 | 116 | Returns 117 | ------- 118 | tf.optimizers.Adam 119 | 120 | Notes 121 | ----- 122 | The paper [1]_ includes an additional hyperparameter lambda. This is only 123 | needed to prove convergence of the algorithm and has no practical use 124 | (personal communication with the authors), it is therefore omitted here. 125 | 126 | References 127 | ---------- 128 | .. [1] Kingma, Diederik, and Jimmy Ba (2014): 129 | Adam: A Method for Stochastic Optimization. 130 | arXiv preprint arXiv:1412.6980. 131 | """ 132 | return tf.optimizers.Adam( 133 | learning_rate=learning_rate, 134 | beta_1=beta_1, 135 | beta_2=beta_2, 136 | epsilon=epsilon, 137 | **kwargs, 138 | ) 139 | 140 | 141 | def adamax( 142 | learning_rate: float = 0.002, 143 | beta_1: float = 0.9, 144 | beta_2: float = 0.999, 145 | epsilon: float = 1e-8, 146 | **kwargs, 147 | ) -> tf.optimizers.Adamax: 148 | """Adamax optimizer. 149 | 150 | Parameters 151 | ---------- 152 | learning_rate : float 153 | Learning rate 154 | beta_1 : float 155 | Exponential decay rate for the first moment estimates 156 | beta_2 : float 157 | Exponential decay rate for the second moment estimates 158 | epsilon : float 159 | Constant for numerical stability 160 | 161 | Returns 162 | ------- 163 | tf.optimizers.Adamax 164 | 165 | References 166 | ---------- 167 | .. [1] Kingma, Diederik, and Jimmy Ba (2014): 168 | Adam: A Method for Stochastic Optimization. 169 | arXiv preprint arXiv:1412.6980. 170 | """ 171 | return tf.optimizers.Adamax( 172 | learning_rate=learning_rate, 173 | beta_1=beta_1, 174 | beta_2=beta_2, 175 | epsilon=epsilon, 176 | **kwargs, 177 | ) 178 | 179 | 180 | def sgd(learning_rate: float = 1e-3, **kwargs) -> tf.optimizers.SGD: 181 | """SGD optimizer. 182 | 183 | Generates update expressions of the form: 184 | * ``param := param - learning_rate * gradient`` 185 | 186 | Parameters 187 | ---------- 188 | learning_rate : float 189 | Learning rate 190 | 191 | Returns 192 | ------- 193 | tf.optimizers.SGD 194 | """ 195 | return tf.optimizers.SGD(learning_rate=learning_rate, **kwargs) 196 | -------------------------------------------------------------------------------- /pymc4/variational/util.py: -------------------------------------------------------------------------------- 1 | """Utility functions/classes for Variational Inference.""" 2 | import collections 3 | import numpy as np 4 | import tensorflow as tf 5 | from typing import Dict 6 | 7 | VarMap = collections.namedtuple("VarMap", "var, slc, shp, dtyp") 8 | 9 | 10 | class ArrayOrdering: 11 | """ 12 | An ordering for an array space. 
13 | 14 | Parameters 15 | ---------- 16 | free_rvs : dict 17 | Free random variables of the model 18 | """ 19 | 20 | def __init__(self, free_rvs): 21 | self.free_rvs = free_rvs 22 | self.by_name = {} 23 | self.size = 0 24 | 25 | for name, tensor in free_rvs.items(): 26 | flat_shape = int(np.prod(tensor.shape.as_list())) 27 | slc = slice(self.size, self.size + flat_shape) 28 | self.by_name[name] = VarMap(name, slc, tensor.shape, tensor.dtype) 29 | self.size += flat_shape 30 | 31 | def flatten(self) -> tf.Tensor: 32 | """Flattened view of parameters.""" 33 | flattened_tensor = [tf.reshape(var, shape=[-1]) for var in self.free_rvs.values()] 34 | return tf.concat(flattened_tensor, axis=0) 35 | 36 | def split(self, flatten_tensor: tf.Tensor) -> Dict[str, tf.Tensor]: 37 | """Split view of parameters used to calculate log probability.""" 38 | split_view = dict() 39 | for param in self.free_rvs: 40 | _, slc, shape, dtype = self.by_name[param] 41 | split_view[param] = tf.cast(tf.reshape(flatten_tensor[slc], shape), dtype) 42 | return split_view 43 | 44 | def split_samples(self, samples: tf.Tensor, n: int): 45 | """Split view of samples after drawing samples from posterior.""" 46 | q_samples = dict() 47 | for param in self.free_rvs.keys(): 48 | _, slc, shp, dtype = self.by_name[param] 49 | q_samples[param] = tf.cast( 50 | tf.reshape(samples[..., slc], tf.TensorShape([n] + shp.as_list())), 51 | dtype, 52 | ) 53 | return q_samples 54 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 100 3 | target-version = ['py36', 'py37'] -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | mypy 2 | pydocstyle 3 | pytest 4 | pytest-cov 5 | pytest-html 6 | pylint 7 | black 8 | mock 9 | jupyter 10 | jupyterlab 11 | 12 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | arviz>=0.5.1 2 | gast>=0.3.2 3 | tf-nightly 4 | tfp-nightly 5 | pymc3 6 | scipy>=0.18.1 7 | -------------------------------------------------------------------------------- /scripts/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM conda/miniconda3 2 | 3 | LABEL maintainer="PyMC Devs https://github.com/pymc-devs/pymc4" 4 | 5 | ARG SRC_DIR 6 | ARG PYTHON_VERSION 7 | 8 | ENV PYTHON_VERSION=${PYTHON_VERSION} 9 | 10 | 11 | # Change behavior of create_test.sh script 12 | ENV DOCKER_BUILD=true 13 | 14 | # For Sphinx documentation builds 15 | ENV LC_ALL=C.UTF-8 16 | ENV LANG=C.UTF-8 17 | 18 | # Update container 19 | RUN apt-get update && apt-get install -y git build-essential pandoc vim \ 20 | && rm -rf /var/lib/apt/lists/* 21 | 22 | 23 | # Copy requirements and environment installation scripts 24 | COPY $SRC_DIR/requirements.txt opt/pymc4/ 25 | COPY $SRC_DIR/requirements-dev.txt opt/pymc4/ 26 | COPY $SRC_DIR/scripts/ opt/pymc4/scripts 27 | WORKDIR /opt/pymc4 28 | 29 | 30 | # Create conda environment. Defaults to Python 3.6 31 | RUN ./scripts/create_testenv.sh 32 | 33 | 34 | # Set automatic conda activation in non interactive and shells 35 | ENV BASH_ENV="/root/activate_conda.sh" 36 | RUN echo ". 
/root/activate_conda.sh" > /root/.bashrc 37 | 38 | 39 | # Remove conda cache 40 | RUN conda clean --all 41 | 42 | COPY $SRC_DIR /opt/pymc4 43 | 44 | # Clear any cached files from copied from host filesystem 45 | RUN find -type d -name __pycache__ -exec rm -rf {} + 46 | 47 | -------------------------------------------------------------------------------- /scripts/README.md: -------------------------------------------------------------------------------- 1 | # Docker FAQ 2 | 3 | * Image can be built locally using the command `make docker` or the command 4 | `./scripts/container.sh --build` from the root `pymc4` directory 5 | 6 | * After image is built an interactive bash session can be run 7 | `docker run -it pymc4 bash` 8 | 9 | * Command can be issued to the container such as linting and testing 10 | without interactive session 11 | * `docker run pymc4 bash -c "pytest pymc4/tests"` 12 | * `docker run pymc4 bash -c "./scripts/lint.sh"` 13 | -------------------------------------------------------------------------------- /scripts/container.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | for %%a in ("%%~dp0:~0,-1") do set "SRC_DIR=%%~dpa" 4 | 5 | for %%a in (%*) do ( 6 | if "%%a" == "--build" ( 7 | docker build -t pymc4 ^ 8 | -f %SRC_DIR%\scripts\Dockerfile ^ 9 | --build-arg SRC_DIR=. %SRC_DIR% ^ 10 | --rm 11 | ) 12 | ) 13 | 14 | for %%a in (%*) do ( 15 | if "%%a" == "--clear_cache" ( 16 | for /R %cd% %%G IN (__pycache__) do rmdir /s /q %%G 17 | ) 18 | ) 19 | 20 | for %%a in (%*) do ( 21 | if "%%a" == "--test" ( 22 | docker run --mount type=bind,source=%cd%,target=/opt/pymc4/ pymc4:latest bash -c ^ 23 | "pytest -v pymc4 tests --doctest-modules --cov=pymc4/" 24 | ) 25 | ) 26 | 27 | EXIT /B 0 -------------------------------------------------------------------------------- /scripts/container.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | SRC_DIR=${SRC_DIR:-`pwd`} 3 | 4 | # Build container for use of testing 5 | if [[ $* == *--build* ]]; then 6 | echo "Building Docker Image" 7 | docker build \ 8 | -t pymc4 \ 9 | -f $SRC_DIR/scripts/Dockerfile \ 10 | --build-arg SRC_DIR=. $SRC_DIR \ 11 | --rm 12 | fi 13 | 14 | if [[ $* == *--clear_cache* ]]; then 15 | echo "Removing cached files" 16 | find -type d -name __pycache__ -exec rm -rf {} + 17 | fi 18 | 19 | if [[ $* == *--test* ]]; then 20 | echo "Testing PyMC4" 21 | docker run --mount type=bind,source="$(pwd)",target=/opt/pymc4/ pymc4:latest bash -c \ 22 | "pytest -v pymc4 tests --doctest-modules --cov=pymc4/" 23 | fi 24 | -------------------------------------------------------------------------------- /scripts/create_testenv.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -ex # fail on first error, print commands 4 | 5 | command -v conda >/dev/null 2>&1 || { 6 | echo "Requires conda but it is not installed. Run install_miniconda.sh." 
>&2; 7 | exit 1; 8 | } 9 | 10 | # if no python specified, use Travis version, or else 3.6 11 | PYTHON_VERSION=${PYTHON_VERSION:-${TRAVIS_PYTHON_VERSION:-3.6}} 12 | 13 | 14 | if [[ $* != *--global* ]]; then 15 | ENVNAME="testenv_${PYTHON_VERSION}" 16 | 17 | if conda env list | grep -q ${ENVNAME} 18 | then 19 | echo "Environment ${ENVNAME} already exists, keeping up to date" 20 | else 21 | echo "Creating environment ${ENVNAME}" 22 | conda create -n ${ENVNAME} --yes pip python=${PYTHON_VERSION} 23 | fi 24 | 25 | # Activate environment immediately 26 | source activate ${ENVNAME} 27 | 28 | if [ "$DOCKER_BUILD" = true ] ; then 29 | # Also add it to root bash settings to set default if used later 30 | 31 | echo "Creating .bashrc profile for docker image" 32 | echo "set conda_env=${ENVNAME}" > /root/activate_conda.sh 33 | echo "source activate ${ENVNAME}" >> /root/activate_conda.sh 34 | 35 | 36 | fi 37 | fi 38 | 39 | 40 | # Install PyMC4 dependencies 41 | pip install --upgrade pip 42 | 43 | 44 | # Install editable using the setup.py 45 | pip install --no-cache-dir -r requirements.txt 46 | pip install --no-cache-dir -r requirements-dev.txt 47 | -------------------------------------------------------------------------------- /scripts/lint.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | set -ex # fail on first error, print commands 4 | 5 | SRC_DIR=${SRC_DIR:-$(pwd)} 6 | 7 | echo "Skipping documentation check. Re-enabling this would be a helpful contribution!" 8 | # echo "Checking documentation..." 9 | # python -m pydocstyle --convention=numpy "${SRC_DIR}"/pymc4/ 10 | echo "Success!" 11 | 12 | echo "Checking code style with black..." 13 | python -m black -l 100 --check "${SRC_DIR}"/pymc4/ "${SRC_DIR}"/tests/ 14 | echo "Success!" 15 | 16 | echo "Type checking with mypy..." 17 | python -m mypy --ignore-missing-imports "${SRC_DIR}"/pymc4/ 18 | 19 | echo "Checking code style with pylint..." 20 | python -m pylint "${SRC_DIR}"/pymc4/ "${SRC_DIR}"/tests/ 21 | echo "Success!" 22 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [pydocstyle] 2 | # Ignore errors for missing docstrings. 3 | # Ignore D202 (No blank lines allowed after function docstring) 4 | # due to bug in black: https://github.com/ambv/black/issues/355 5 | add-ignore = D100,D101,D102,D103,D104,D105,D106,D107,D202 6 | convention = numpy 7 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import codecs 2 | import re 3 | from pathlib import Path 4 | 5 | from setuptools import setup, find_packages 6 | 7 | 8 | PROJECT_ROOT = Path(__file__).resolve().parent 9 | REQUIREMENTS_FILE = PROJECT_ROOT / "requirements.txt" 10 | REQUIREMENTS_DEV_FILE = PROJECT_ROOT / "requirements-dev.txt" 11 | README_FILE = PROJECT_ROOT / "README.md" 12 | VERSION_FILE = PROJECT_ROOT / "pymc4" / "__init__.py" 13 | 14 | NAME = "pymc4" 15 | DESCRIPTION = "A Python probabilistic programming interface to TensorFlow, for Bayesian modelling and machine learning." 
16 | AUTHOR = "PyMC Developers" 17 | AUTHOR_EMAIL = "pymc.devs@gmail.com" 18 | URL = "https://github.com/pymc-devs/pymc4" 19 | LICENSE = "Apache License, Version 2.0" 20 | 21 | CLASSIFIERS = [ 22 | "Programming Language :: Python", 23 | "Programming Language :: Python :: 3", 24 | "Programming Language :: Python :: 3.6", 25 | "Programming Language :: Python :: 3.7", 26 | "License :: OSI Approved :: Apache Software License", 27 | "Intended Audience :: Science/Research", 28 | "Topic :: Scientific/Engineering", 29 | "Topic :: Scientific/Engineering :: Mathematics", 30 | "Operating System :: OS Independent", 31 | ] 32 | 33 | 34 | def get_requirements(path): 35 | with codecs.open(path) as buff: 36 | return buff.read().splitlines() 37 | 38 | 39 | def get_long_description(): 40 | with codecs.open(README_FILE, "rt") as buff: 41 | return buff.read() 42 | 43 | 44 | def get_version(): 45 | lines = open(VERSION_FILE, "rt").readlines() 46 | version_regex = r"^__version__ = ['\"]([^'\"]*)['\"]" 47 | for line in lines: 48 | mo = re.search(version_regex, line, re.M) 49 | if mo: 50 | return mo.group(1) 51 | raise RuntimeError("Unable to find version in %s." % (VERSION_FILE,)) 52 | 53 | 54 | if __name__ == "__main__": 55 | setup( 56 | name=NAME, 57 | version=get_version(), 58 | description=DESCRIPTION, 59 | license=LICENSE, 60 | author=AUTHOR, 61 | author_email=AUTHOR_EMAIL, 62 | url=URL, 63 | classifiers=CLASSIFIERS, 64 | packages=find_packages(), 65 | install_requires=get_requirements(REQUIREMENTS_FILE), 66 | tests_require=get_requirements(REQUIREMENTS_DEV_FILE), 67 | long_description=get_long_description(), 68 | long_description_content_type="text/markdown", 69 | include_package_data=True, 70 | ) 71 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pymc-devs/pymc4/10b5854219786dac3337ad2127683c62d891e1ea/tests/__init__.py -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | """PyMC4 test configuration.""" 2 | import pytest 3 | import pymc4 as pm 4 | import numpy as np 5 | import tensorflow as tf 6 | import itertools 7 | 8 | # Tensor shapes on which the GP model will be tested 9 | BATCH_AND_FEATURE_SHAPES = [ 10 | (1,), 11 | (2,), 12 | ( 13 | 2, 14 | 2, 15 | ), 16 | ] 17 | SAMPLE_SHAPE = [(1,), (3,)] 18 | 19 | 20 | @pytest.fixture(scope="function", autouse=True) 21 | def tf_seed(): 22 | tf.random.set_seed(37208) # random.org 23 | yield 24 | 25 | 26 | @pytest.fixture(scope="function") 27 | def simple_model(): 28 | @pm.model() 29 | def simple_model(): 30 | norm = yield pm.Normal("norm", 0, 1) 31 | return norm 32 | 33 | return simple_model 34 | 35 | 36 | @pytest.fixture(scope="function") 37 | def simple_model_with_deterministic(simple_model): 38 | @pm.model() 39 | def simple_model_with_deterministic(): 40 | norm = yield simple_model() 41 | determ = yield pm.Deterministic("determ", norm * 2) 42 | return determ 43 | 44 | return simple_model_with_deterministic 45 | 46 | 47 | @pytest.fixture(scope="function") 48 | def simple_model_no_free_rvs(): 49 | @pm.model() 50 | def simple_model_no_free_rvs(): 51 | norm = yield pm.Normal("norm", 0, 1, observed=1) 52 | return norm 53 | 54 | return simple_model_no_free_rvs 55 | 56 | 57 | @pytest.fixture( 58 | scope="function", 59 | params=itertools.product( 60 | [(), (3,), (3, 2)], 
[(), (2,), (4,), (5, 4)], [(), (1,), (10,), (10, 10)] 61 | ), 62 | ids=str, 63 | ) 64 | def unvectorized_model(request): 65 | norm_shape, observed_shape, batch_size = request.param 66 | observed = np.ones(observed_shape) 67 | 68 | @pm.model() 69 | def unvectorized_model(): 70 | norm = yield pm.Normal("norm", 0, 1, batch_stack=norm_shape) 71 | determ = yield pm.Deterministic("determ", tf.reduce_max(norm)) 72 | output = yield pm.Normal("output", determ, 1, observed=observed) 73 | 74 | return unvectorized_model, norm_shape, observed, batch_size 75 | 76 | 77 | @pytest.fixture(scope="module", params=["XLA", "noXLA"], ids=str) 78 | def xla_fixture(request): 79 | return request.param == "XLA" 80 | 81 | 82 | @pytest.fixture(scope="function") 83 | def deterministics_in_nested_models(): 84 | @pm.model 85 | def nested_model(cond): 86 | x = yield pm.Normal("x", cond, 1) 87 | dx = yield pm.Deterministic("dx", x + 1) 88 | return dx 89 | 90 | @pm.model 91 | def outer_model(): 92 | cond = yield pm.HalfNormal("cond", 1) 93 | dcond = yield pm.Deterministic("dcond", cond * 2) 94 | dx = yield nested_model(dcond) 95 | ddx = yield pm.Deterministic("ddx", dx) 96 | return ddx 97 | 98 | expected_untransformed = { 99 | "outer_model", 100 | "outer_model/cond", 101 | "outer_model/nested_model", 102 | "outer_model/nested_model/x", 103 | } 104 | expected_transformed = {"outer_model/__log_cond"} 105 | expected_deterministics = { 106 | "outer_model/dcond", 107 | "outer_model/ddx", 108 | "outer_model/nested_model/dx", 109 | } 110 | deterministic_mapping = { 111 | "outer_model/dcond": (["outer_model/__log_cond"], lambda x: np.exp(x) * 2), 112 | "outer_model/ddx": (["outer_model/nested_model/dx"], lambda x: x), 113 | "outer_model/nested_model/dx": ( 114 | ["outer_model/nested_model/x"], 115 | lambda x: x + 1, 116 | ), 117 | } 118 | 119 | return ( 120 | outer_model, 121 | expected_untransformed, 122 | expected_transformed, 123 | expected_deterministics, 124 | deterministic_mapping, 125 | ) 126 | 127 | 128 | @pytest.fixture(scope="module", params=["auto_batch", "trust_manual_batching"], ids=str) 129 | def use_auto_batching_fixture(request): 130 | return request.param == "auto_batch" 131 | 132 | 133 | @pytest.fixture(scope="function", params=["unvectorized_model", "vectorized_model"], ids=str) 134 | def vectorized_model_fixture(request): 135 | is_vectorized_model = request.param == "vectorized_model" 136 | observed = np.zeros((5, 4), dtype="float32") 137 | core_shapes = { 138 | "model/mu": (4,), 139 | "model/__log_scale": (), 140 | } 141 | if is_vectorized_model: 142 | # A model where we pay great attention to making each distribution 143 | # have exactly the right event_shape, and assure that when we sample 144 | # from its prior, the requested `sample_shape` gets sent to the 145 | # conditionally independent variables, and expect that shape to go 146 | # through the conditionally dependent variables as batch_shapes 147 | @pm.model 148 | def model(): 149 | mu = yield pm.Normal( 150 | "mu", 151 | tf.zeros(4), 152 | 1, 153 | conditionally_independent=True, 154 | reinterpreted_batch_ndims=1, 155 | ) 156 | scale = yield pm.HalfNormal("scale", 1, conditionally_independent=True) 157 | x = yield pm.Normal( 158 | "x", 159 | mu, 160 | scale[..., None], 161 | observed=observed, 162 | reinterpreted_batch_ndims=1, 163 | event_stack=5, 164 | ) 165 | 166 | else: 167 | 168 | @pm.model 169 | def model(): 170 | mu = yield pm.Normal("mu", tf.zeros(4), 1) 171 | scale = yield pm.HalfNormal("scale", 1) 172 | x = yield pm.Normal("x", mu, scale, 
batch_stack=5, observed=observed) 173 | 174 | return model, is_vectorized_model, core_shapes 175 | 176 | 177 | @pytest.fixture(scope="module", params=BATCH_AND_FEATURE_SHAPES, ids=str) 178 | def get_batch_shape(request): 179 | return request.param 180 | 181 | 182 | @pytest.fixture(scope="module", params=SAMPLE_SHAPE, ids=str) 183 | def get_sample_shape(request): 184 | return request.param 185 | 186 | 187 | @pytest.fixture(scope="module", params=BATCH_AND_FEATURE_SHAPES, ids=str) 188 | def get_feature_shape(request): 189 | return request.param 190 | 191 | 192 | @pytest.fixture(scope="module") 193 | def get_data(get_batch_shape, get_sample_shape, get_feature_shape): 194 | X = tf.random.normal(get_batch_shape + get_sample_shape + get_feature_shape) 195 | return get_batch_shape, get_sample_shape, get_feature_shape, X 196 | -------------------------------------------------------------------------------- /tests/pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | filterwarnings= 3 | ignore:tostring.*is deprecated 4 | -------------------------------------------------------------------------------- /tests/test_8schools.py: -------------------------------------------------------------------------------- 1 | import pymc4 as pm4 2 | import numpy as np 3 | import tensorflow as tf 4 | 5 | from numpy.testing import assert_almost_equal 6 | 7 | J = 8 8 | y = np.array([28, 8, -3, 7, -1, 1, 18, 12], dtype=np.float32) 9 | sigma = np.array([15, 10, 16, 11, 9, 11, 10, 18], dtype=np.float32) 10 | 11 | 12 | @pm4.model 13 | def schools_pm4(): 14 | eta = yield pm4.Normal("eta", 0, 1, batch_stack=J) 15 | mu = yield pm4.Normal("mu", 1, 10) 16 | tau = yield pm4.HalfNormal("tau", 1 * 2.0) 17 | 18 | theta = mu + tau * eta 19 | 20 | obs = yield pm4.Normal("obs", theta, sigma, observed=y) 21 | return obs 22 | 23 | 24 | def test_model_logp(): 25 | """Make sure log probability matches standard. 
26 | 27 | Recreate this with 28 | 29 | ```python 30 | import pymc3 as pm 31 | import numpy as np 32 | 33 | J = 8 34 | y = np.array([28, 8, -3, 7, -1, 1, 18, 12]) 35 | sigma = np.array([15, 10, 16, 11, 9, 11, 10, 18]) 36 | 37 | 38 | with pm.Model() as eight_schools: 39 | eta = pm.Normal("eta", 0, 1, shape=J) 40 | mu = pm.Normal("mu", 1, 10) 41 | tau = pm.HalfNormal("tau", 2.0) 42 | 43 | theta = mu + tau * eta 44 | 45 | pm.Normal("obs", theta, sigma, observed=y) 46 | 47 | print(eight_schools.logp({'eta': np.zeros(8), 'mu': 0, 'tau_log__': 1}).astype(np.float32)) 48 | ``` 49 | """ 50 | logp, *_ = pm4.mcmc.samplers.build_logp_and_deterministic_functions( 51 | schools_pm4(), observed={"obs": y}, state=None 52 | ) 53 | init_value = logp( 54 | **{ 55 | "schools_pm4/eta": tf.zeros(8), 56 | "schools_pm4/mu": tf.zeros(()), 57 | "schools_pm4/__log_tau": tf.ones(()), 58 | } 59 | ).numpy() 60 | 61 | assert_almost_equal(init_value, -42.876_114) 62 | 63 | 64 | def test_sample_no_xla(): 65 | # TODO: better test, compare to a golden standard chain from pymc3, 66 | # for now it is only to verify it is runnable 67 | chains, samples = 4, 100 68 | trace = pm4.inference.sampling.sample( 69 | schools_pm4(), 70 | step_size=0.28, 71 | num_chains=chains, 72 | num_samples=samples, 73 | burn_in=50, 74 | xla=False, 75 | ).posterior 76 | for var_name in ("eta", "mu", "tau", "__log_tau"): 77 | assert f"schools_pm4/{var_name}" in trace.keys() 78 | -------------------------------------------------------------------------------- /tests/test_compound.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import pymc4 as pm 3 | import numpy as np 4 | import tensorflow as tf 5 | from pymc4.mcmc.samplers import RandomWalkM 6 | from pymc4.mcmc.samplers import reg_samplers 7 | 8 | 9 | @pytest.fixture(scope="function") 10 | def simple_model(): 11 | @pm.model() 12 | def simple_model(): 13 | var1 = yield pm.Normal("var1", 0, 1) 14 | return var1 15 | 16 | return simple_model 17 | 18 | 19 | @pytest.fixture(scope="function") 20 | def compound_model(): 21 | @pm.model() 22 | def compound_model(): 23 | var1 = yield pm.Normal("var1", 0, 1) 24 | var2 = yield pm.Bernoulli("var2", 0.5) 25 | return var2 26 | 27 | return compound_model 28 | 29 | 30 | @pytest.fixture(scope="function") 31 | def categorical_same_shape(): 32 | @pm.model 33 | def categorical_same_shape(): 34 | var1 = yield pm.Categorical("var1", probs=[0.2, 0.4, 0.4]) 35 | var1 = yield pm.Categorical("var2", probs=[0.1, 0.3, 0.6]) 36 | var1 = yield pm.Categorical("var3", probs=[0.1, 0.1, 0.8]) 37 | 38 | return categorical_same_shape 39 | 40 | 41 | @pytest.fixture(scope="function") 42 | def model_symmetric(): 43 | @pm.model 44 | def model_symmetric(): 45 | var1 = yield pm.Bernoulli("var1", 0.1) 46 | var2 = yield pm.Bernoulli("var2", 1 - 0.1) 47 | 48 | return model_symmetric 49 | 50 | 51 | @pytest.fixture(scope="function") 52 | def categorical_different_shape(): 53 | @pm.model 54 | def categorical_different_shape(): 55 | var1 = yield pm.Categorical("var1", probs=[0.2, 0.4, 0.4]) 56 | var1 = yield pm.Categorical("var2", probs=[0.1, 0.3, 0.1, 0.5]) 57 | var1 = yield pm.Categorical("var3", probs=[0.1, 0.1, 0.1, 0.2, 0.5]) 58 | 59 | return categorical_different_shape 60 | 61 | 62 | @pytest.fixture(scope="module", params=["XLA", "noXLA"], ids=str) 63 | def xla_fixture(request): 64 | return request.param == "XLA" 65 | 66 | 67 | @pytest.fixture(scope="module", params=[3, 5]) 68 | def seed(request): 69 | return request.param 70 | 71 | 72 | 
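# NOTE (illustrative sketch, not part of the original test suite): the string
# names parametrized by the fixtures here map to sampler kernels registered in
# pymc4.mcmc.samplers.reg_samplers. A rough standalone equivalent of what the
# tests below exercise, using the `compound_model` fixture's inner function and
# the same keyword names the tests pass to pm.sample, would be:
#
#     model = compound_model()
#     trace = pm.sample(model, sampler_type="compound", num_chains=4, num_samples=100)
#     trace.posterior["compound_model/var2"]  # int32 draws from the Bernoulli
#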
@pytest.fixture(scope="module", params=["hmc", "nuts", "rwm", "compound"])
def sampler_type(request):
    return request.param


@pytest.fixture(scope="module", params=["rwm", "compound"])
def discrete_support_sampler_type(request):
    return request.param


@pytest.fixture(scope="module", params=["nuts_simple", "hmc_simple"])
def expanded_sampler_type(request):
    return request.param


def test_samplers_on_compound_model(compound_model, seed, xla_fixture, sampler_type):
    def _execute():
        model = compound_model()
        trace = pm.sample(model, sampler_type=sampler_type, xla=xla_fixture, seed=seed)
        var1 = round(trace.posterior["compound_model/var1"].mean().item(), 1)
        # var2 is an int32 variable, so average it through a sum
        # (1000 * 10 assumes the pm.sample defaults of 1000 samples and 10 chains)
        var2 = tf.reduce_sum(trace.posterior["compound_model/var2"]) / (1000 * 10)
        np.testing.assert_allclose(var1, 0.0, atol=0.1)
        np.testing.assert_allclose(var2, 0.5, atol=0.1)

    if sampler_type in ["compound", "rwm"]:
        # execute normally if the sampler supports discrete distributions
        _execute()
    else:
        # otherwise check for the exception thrown
        with pytest.raises(ValueError):
            _execute()


def test_compound_model_sampler_method(
    compound_model, seed, xla_fixture, discrete_support_sampler_type
):
    model = compound_model()
    trace = pm.sample(
        model,
        sampler_type=discrete_support_sampler_type,
        sampler_methods=[("var2", RandomWalkM)],
        xla=xla_fixture,
        seed=seed,
    )
    var1 = round(trace.posterior["compound_model/var1"].mean().item(), 1)
    # var2 is an int32 variable, so average it through a sum
    var2 = tf.reduce_sum(trace.posterior["compound_model/var2"]) / (1000 * 10)
    np.testing.assert_allclose(var1, 0.0, atol=0.1)
    np.testing.assert_allclose(var2, 0.5, atol=0.1)


def test_samplers_on_simple_model(simple_model, xla_fixture, sampler_type):
    model = simple_model()
    trace = pm.sample(model, sampler_type=sampler_type, xla=xla_fixture)
    var1 = round(trace.posterior["simple_model/var1"].mean().item(), 1)
    np.testing.assert_allclose(var1, 0.0, atol=0.1)


def test_extended_samplers_on_simple_model(simple_model, xla_fixture, expanded_sampler_type):
    model = simple_model()
    trace = pm.sample(model, sampler_type=expanded_sampler_type, xla=xla_fixture)
    var1 = round(trace.posterior["simple_model/var1"].mean().item(), 1)
    np.testing.assert_allclose(var1, 0.0, atol=0.1)


def test_simple_seed(simple_model, xla_fixture, seed):
    model = simple_model()
    trace1 = pm.sample(model, xla=xla_fixture, seed=seed)
    trace2 = pm.sample(model, xla=xla_fixture, seed=seed)
    np.testing.assert_allclose(
        tf.norm(trace1.posterior["simple_model/var1"] - trace2.posterior["simple_model/var1"]),
        0.0,
        atol=1e-6,
    )

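# Editor's note (assumption, inferred from the assertions rather than from any
# sampler documentation): the seed tests here and below rely on pm.sample(seed=...)
# making the whole chain deterministic, so two runs with the same seed should be
# numerically identical; the atol=1e-6 on the norm of the difference is just a
# float-friendly way of asserting "effectively zero".
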
def test_compound_seed(compound_model, xla_fixture, seed):
    model = compound_model()
    trace1 = pm.sample(model, xla=xla_fixture, seed=seed)
    trace2 = pm.sample(model, xla=xla_fixture, seed=seed)
    np.testing.assert_allclose(
        tf.norm(
            tf.cast(
                trace1.posterior["compound_model/var1"] - trace2.posterior["compound_model/var1"],
                dtype=tf.float32,
            )
        ),
        0.0,
        atol=1e-6,
    )
    np.testing.assert_allclose(
        tf.norm(
            tf.cast(
                trace1.posterior["compound_model/var2"] - trace2.posterior["compound_model/var2"],
                dtype=tf.float32,
            )
        ),
        0.0,
        atol=1e-6,
    )


def test_sampler_merging(categorical_same_shape, categorical_different_shape):
    model_same = categorical_same_shape()
    model_diff = categorical_different_shape()
    sampler = reg_samplers["compound"]
    sampler1 = sampler(model_same)
    sampler1._assign_default_methods()
    sampler2 = sampler(model_diff)
    sampler2._assign_default_methods()
    assert len(sampler1.kernel_kwargs["compound_samplers"]) == 1
    assert len(sampler2.kernel_kwargs["compound_samplers"]) == 3
    sampler_methods1 = [("var1", RandomWalkM)]
    sampler_methods2 = [
        ("var1", RandomWalkM, {"new_state_fn": pm.categorical_uniform_fn(classes=3)})
    ]
    sampler_methods3 = [
        (
            "var1",
            RandomWalkM,
            {"new_state_fn": pm.categorical_uniform_fn(classes=3, name="smth_different")},
        )
    ]

    sampler_methods4 = [
        (
            "var1",
            RandomWalkM,
            {"new_state_fn": pm.categorical_uniform_fn(classes=3, name="smth_different")},
        ),
        (
            "var3",
            RandomWalkM,
            {"new_state_fn": pm.categorical_uniform_fn(classes=3, name="smth_different")},
        ),
    ]

    sampler_methods5 = [
        (
            "var1",
            RandomWalkM,
        ),
        (
            "var2",
            RandomWalkM,
        ),
        (
            "var3",
            RandomWalkM,
        ),
    ]

    sampler_ = sampler(model_same)
    sampler_._assign_default_methods(sampler_methods=sampler_methods1)
    assert len(sampler_.kernel_kwargs["compound_samplers"]) == 1
    sampler_._assign_default_methods(sampler_methods=sampler_methods2)
    assert len(sampler_.kernel_kwargs["compound_samplers"]) == 1
    sampler_._assign_default_methods(sampler_methods=sampler_methods3)
    assert len(sampler_.kernel_kwargs["compound_samplers"]) == 2
    sampler_._assign_default_methods(sampler_methods=sampler_methods4)
    assert len(sampler_.kernel_kwargs["compound_samplers"]) == 2
    sampler_ = sampler(model_diff)
    sampler_._assign_default_methods(sampler_methods=sampler_methods5)
    assert len(sampler_.kernel_kwargs["compound_samplers"]) == 3


def test_other_samplers(simple_model, xla_fixture, seed):
    model = simple_model()
    trace1 = pm.sample(model, sampler_type="nuts_simple", xla=xla_fixture, seed=seed)
    trace2 = pm.sample(model, sampler_type="hmc_simple", xla=xla_fixture, seed=seed)
    np.testing.assert_allclose(tf.reduce_mean(trace1.posterior["simple_model/var1"]), 0.0, atol=0.1)
    np.testing.assert_allclose(tf.reduce_mean(trace2.posterior["simple_model/var1"]), 0.0, atol=0.1)


def test_compound_symmetric(model_symmetric, seed):
    model = model_symmetric()
    trace = pm.sample(model, seed=seed)
    np.testing.assert_allclose(
        tf.reduce_mean(tf.cast(trace.posterior["model_symmetric/var1"], dtype=tf.float32)),
        0.1,
        atol=0.1,
    )
    np.testing.assert_allclose(
        tf.reduce_mean(tf.cast(trace.posterior["model_symmetric/var2"], dtype=tf.float32)),
        1.0 - 0.1,
        atol=0.1,
    )

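# Editor's note on test_sampler_merging above (inferred from its assertions, not
# from sampler documentation): variables appear to share one compound kernel only
# when their assigned method and its kernel kwargs compare equal. Two
# categorical_uniform_fn instances that differ only in `name` stop merging, which
# is why sampler_methods3 yields 2 compound samplers while sampler_methods1 and
# sampler_methods2 each yield 1.
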
--------------------------------------------------------------------------------
/tests/test_discrete.py:
--------------------------------------------------------------------------------
import pytest
import pymc4 as pm
import numpy as np


@pytest.fixture(scope="function")
def model_with_discrete_categorical():
    @pm.model()
    def model_with_discrete_categorical():
        disc = yield pm.Categorical("disc", probs=[0.1, 0.9])
        return disc

    return model_with_discrete_categorical


@pytest.fixture(scope="function")
def model_with_discrete_bernoulli():
    @pm.model()
    def model_with_discrete_bernoulli():
        disc = yield pm.Bernoulli("disc", 0.9)
        return disc

    return model_with_discrete_bernoulli


@pytest.fixture(scope="function")
def model_with_discrete_and_continuous():
    @pm.model()
    def model_with_discrete_and_continuous():
        disc = yield pm.Categorical("disc", probs=[0.1, 0.9])
        norm = yield pm.Normal("mu", 0, 1)
        return norm

    return model_with_discrete_and_continuous


@pytest.fixture(scope="module", params=["XLA", "noXLA"], ids=str)
def xla_fixture(request):
    return request.param == "XLA"


@pytest.fixture(scope="module", params=[3, 5, 7])
def seed(request):
    return request.param


def test_discrete_sampling_categorical(model_with_discrete_categorical, xla_fixture, seed):
    model = model_with_discrete_categorical()
    trace = pm.sample(model=model, sampler_type="compound", xla=xla_fixture, seed=seed)
    round_value = round(trace.posterior["model_with_discrete_categorical/disc"].mean().item(), 1)
    # the posterior mean should match the categorical prob parameter
    np.testing.assert_allclose(round_value, 0.9, atol=0.1)


def test_discrete_sampling_bernoulli(model_with_discrete_bernoulli, xla_fixture, seed):
    model = model_with_discrete_bernoulli()
    trace = pm.sample(model=model, sampler_type="compound", xla=xla_fixture, seed=seed)
    round_value = round(trace.posterior["model_with_discrete_bernoulli/disc"].mean().item(), 1)
    # the posterior mean should match the bernoulli prob parameter
    np.testing.assert_allclose(round_value, 0.9, atol=0.1)


def test_compound_sampling(model_with_discrete_and_continuous, xla_fixture, seed):
    model = model_with_discrete_and_continuous()
    trace = pm.sample(model=model, sampler_type="compound", xla=xla_fixture, seed=seed)
    round_value = round(trace.posterior["model_with_discrete_and_continuous/disc"].mean().item(), 1)
    np.testing.assert_allclose(round_value, 0.9, atol=0.1)

--------------------------------------------------------------------------------
/tests/test_gp.py:
--------------------------------------------------------------------------------
"""Test suite for GP Module"""

import tensorflow as tf

import pymc4 as pm

import pytest

# Test all the GP models using only one particular mean function and one
# covariance function, but with varying tensor shapes.
# NOTE: the mean and covariance functions used here must be present in
# `MEAN_FUNCS` and `COV_FUNCS` respectively.
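# Editor's example (hypothetical entry; the names are taken from pm.gp.mean and
# pm.gp.cov as exercised elsewhere in this suite): another configuration could
# be added to GP_MODELS below like
#
#     ("LatentGP", {"mean_fn": ("Constant", {"coef": 1.0}),
#                   "cov_fn": ("ExpQuad", {"amplitude": 2.0, "length_scale": 0.5})}),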
GP_MODELS = [
    (
        "LatentGP",
        {
            "mean_fn": ("Zero", {}),
            "cov_fn": ("ExpQuad", {"amplitude": 1.0, "length_scale": 1.0}),
        },
    ),
]


@pytest.fixture(scope="module", params=GP_MODELS, ids=str)
def get_gp_model(request):
    return request.param


def build_model(model_name, model_kwargs, feature_ndims):
    """Create a GP model from an element of the `GP_MODELS` list"""
    # First, create the mean function
    name = model_kwargs["mean_fn"][0]
    kwargs = model_kwargs["mean_fn"][1]
    MeanClass = getattr(pm.gp.mean, name)
    mean_fn = MeanClass(**kwargs, feature_ndims=feature_ndims)
    # Then, create the kernel function
    name = model_kwargs["cov_fn"][0]
    kwargs = model_kwargs["cov_fn"][1]
    KernelClass = getattr(pm.gp.cov, name)
    cov_fn = KernelClass(**kwargs, feature_ndims=feature_ndims)
    # Now, create the model and return
    GPModel = getattr(pm.gp, model_name)
    model = GPModel(mean_fn=mean_fn, cov_fn=cov_fn)
    return model


def test_gp_models_prior(tf_seed, get_data, get_gp_model):
    """Test the prior method of a GP model, if present"""
    batch_shape, sample_shape, feature_shape, X = get_data
    gp_model = build_model(get_gp_model[0], get_gp_model[1], len(feature_shape))
    # @pm.model
    # def model(gp, X):
    #     yield gp.prior('f', X)
    try:
        # sampling_model = model(gp_model, X)
        # trace = pm.sample(sampling_model, num_samples=3, num_chains=1, burn_in=10)
        # trace = np.asarray(trace.posterior["model/f"])
        prior_dist = gp_model.prior("prior", X)
    except NotImplementedError:
        pytest.skip("Skipping: prior not implemented")
    # if sample_shape == (1,):
    #     assert trace.shape == (1, 3,) + batch_shape
    # else:
    #     assert trace.shape == (1, 3,) + batch_shape + sample_shape
    if sample_shape == (1,):
        assert prior_dist.sample(1).shape == (1,) + batch_shape
    else:
        assert prior_dist.sample(1).shape == (1,) + batch_shape + sample_shape


def test_gp_models_conditional(tf_seed, get_data, get_gp_model):
    """Test the conditional method of a GP model, if present"""
    batch_shape, sample_shape, feature_shape, X = get_data
    gp_model = build_model(get_gp_model[0], get_gp_model[1], len(feature_shape))
    Xnew = tf.random.normal(batch_shape + sample_shape + feature_shape)

    @pm.model
    def model(gp, X, Xnew):
        f = yield gp.prior("f", X)
        yield gp.conditional("fcond", Xnew, given={"X": X, "f": f})

    try:
        # sampling_model = model(gp_model, X, Xnew)
        # trace = pm.sample(sampling_model, num_samples=3, num_chains=1, burn_in=10)
        # trace = np.asarray(trace.posterior["model/fcond"])
        f = gp_model.prior("f", X).sample(1)[0]
        cond_dist = gp_model.conditional("fcond", Xnew, given={"X": X, "f": f})
        cond_samples = cond_dist.sample(3)
    except NotImplementedError:
        pytest.skip("Skipping: conditional not implemented")
    # if sample_shape == (1,):
    #     assert trace.shape == (1, 3,) + batch_shape
    # else:
    #     assert trace.shape == (1, 3,) + batch_shape + sample_shape
    if sample_shape == (1,):
        assert cond_samples.shape == (3,) + batch_shape
    else:
        assert cond_samples.shape == (3,) + batch_shape + sample_shape
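
# Editor's sketch of build_model in isolation (hedged; assumes a 1-D feature
# space and mirrors the single entry of GP_MODELS above):
#
#     gp = build_model(
#         "LatentGP",
#         {"mean_fn": ("Zero", {}),
#          "cov_fn": ("ExpQuad", {"amplitude": 1.0, "length_scale": 1.0})},
#         feature_ndims=1,
#     )
#     f_dist = gp.prior("f", tf.random.normal((5, 1)))
#     f_dist.sample(1)  # expected shape: (1, 5)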

--------------------------------------------------------------------------------
/tests/test_gp_mean.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import numpy as np

import pymc4 as pm

import pytest

# Test all the mean functions in the pm.gp module
MEAN_FUNCS = [
    (
        "Zero",
        {
            "test_point": np.array([[1.0], [2.0]], dtype=np.float64),
            "expected": np.array([0.0, 0.0], dtype=np.float64),
            "feature_ndims": 1,
        },
    ),
    (
        "Constant",
        {
            "coef": 5.0,
            "test_point": np.array([[1.0], [2.0]], dtype=np.float64),
            "expected": np.array([5.0, 5.0], dtype=np.float64),
            "feature_ndims": 1,
        },
    ),
]


@pytest.fixture(scope="module", params=MEAN_FUNCS, ids=str)
def get_mean_func(request):
    return request.param


def test_mean_funcs(tf_seed, get_data, get_mean_func):
    """Test the mean functions present in the MEAN_FUNCS list"""
    # Build the mean function
    attr_name = get_mean_func[0]
    kwargs = get_mean_func[1]
    test_point = kwargs.pop("test_point", None)
    expected = kwargs.pop("expected", None)
    feature_ndims = kwargs.pop("feature_ndims", 1)
    MeanClass = getattr(pm.gp.mean, attr_name)

    # Get data to compute on.
    batch_shape, sample_shape, feature_shape, X = get_data

    # Build and evaluate the mean function.
    mean_func = MeanClass(**kwargs, feature_ndims=len(feature_shape))
    val = mean_func(X)

    # Test 1 : Tensor shape evaluations
    assert val.shape == batch_shape + sample_shape

    # Test 2 : Point evaluations
    if test_point is not None:
        mean_func = MeanClass(**kwargs, feature_ndims=feature_ndims)
        val = mean_func(test_point).numpy()

        # We need to be careful about the dtypes: even though tensorflow defaults
        # to float32, the function should not break for other dtypes either.
        assert val.dtype == expected.dtype
        assert val.shape == expected.shape
        assert np.allclose(val, expected)

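# Editor's sketch of the mean-function algebra exercised by test_mean_combination
# below (hedged; the values mirror the Constant entry of MEAN_FUNCS above):
#
#     m = pm.gp.mean.Constant(coef=5.0, feature_ndims=1)
#     (m + m)(np.array([[1.0], [2.0]]))  # -> [10., 10.]
#     (m * m)(np.array([[1.0], [2.0]]))  # -> [25., 25.]
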
def test_mean_combination(tf_seed, get_mean_func):
    """Test if combinations of mean functions yield consistent results"""
    # Data to compute on.
    batch_shape, sample_shape, feature_shape, X = (
        (2,),
        (2,),
        (2,),
        tf.random.normal((2, 2, 2)),
    )
    attr_name = get_mean_func[0]
    kwargs = get_mean_func[1]
    test_point = kwargs.pop("test_point", None)
    expected = kwargs.pop("expected", None)
    feature_ndims = kwargs.pop("feature_ndims", 1)
    MeanClass = getattr(pm.gp.mean, attr_name)

    # Build and evaluate the mean function.
    mean_func = MeanClass(**kwargs, feature_ndims=len(feature_shape))

    # Get the combinations of the mean functions
    mean_add = mean_func + mean_func
    mean_prod = mean_func * mean_func

    # Evaluate the combinations
    mean_add_val = mean_add(X)
    mean_prod_val = mean_prod(X)

    # Test 1 : Shape evaluations
    assert mean_add_val.shape == batch_shape + sample_shape
    assert mean_prod_val.shape == batch_shape + sample_shape

    # Test 2 : Point evaluations
    if test_point is not None:
        mean_func = MeanClass(**kwargs, feature_ndims=feature_ndims)

        # Get the combinations of the mean functions
        mean_add = mean_func + mean_func
        mean_prod = mean_func * mean_func

        # Evaluate the combinations
        mean_add_val = mean_add(test_point).numpy()
        mean_prod_val = mean_prod(test_point).numpy()

        # We need to be careful about the dtypes: even though tensorflow defaults
        # to float32, the function should not break for other dtypes either.
        assert mean_add_val.dtype == expected.dtype
        assert mean_add_val.shape == expected.shape
        assert mean_prod_val.dtype == expected.dtype
        assert mean_prod_val.shape == expected.shape
        assert np.allclose(mean_add_val, 2 * expected)
        assert np.allclose(mean_prod_val, expected ** 2)

--------------------------------------------------------------------------------
/tests/test_gp_util.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import numpy as np
import pymc4 as pm

import pytest

doc_string = "Func doc"


def test_stabilize_default_shift():
    data = tf.constant([[1.0, 2.0], [3.0, 4.0]])
    shifted = pm.gp.util.stabilize(data)
    expected = tf.constant([[1.0001, 2.0], [3.0, 4.0001]])
    assert np.allclose(shifted, expected, rtol=1e-18)
    data = tf.constant([[1.0, 2.0], [3.0, 4.0]], dtype=tf.float64)
    shifted = pm.gp.util.stabilize(data)
    expected = tf.constant([[1.000001, 2.0], [3.0, 4.000001]], dtype=tf.float64)
    # NOTE: the original test built this expectation but never asserted it;
    # the float64 check is restored here.
    assert np.allclose(shifted, expected, rtol=1e-18)


def test_stabilize():
    data = tf.constant([[1.0, 2.0], [3.0, 4.0]])
    shifted = pm.gp.util.stabilize(data, shift=1.0)
    expected = tf.constant([[2.0, 2.0], [3.0, 5.0]])
    assert np.allclose(shifted, expected, rtol=1e-18)


def test_inherit_docs():
    def func():
        """
        Func docs.
        """
        pass

    @pm.gp.util._inherit_docs(func)
    def other_func():
        pass

    assert other_func.__doc__ == func.__doc__


def test_inherit_docs_exception():
    def func():
        pass

    with pytest.raises(ValueError, match=r"No docs to inherit"):

        @pm.gp.util._inherit_docs(func)
        def other_func():
            pass


def test_build_docs():
    @pm.gp.util._build_docs
    def func():
        """%(doc_string)"""
        pass

    assert func.__doc__ == doc_string


def test_build_doc_warning():
    # NOTE: "arrtibute" matches the (misspelled) wording of the warning raised
    # by pm.gp.util._build_docs, so the regex must keep the typo verbatim.
    with pytest.warns(SyntaxWarning, match=r"arrtibute nodoc not found"):

        @pm.gp.util._build_docs
        def func():
            """%(nodoc)"""
            pass
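
# Editor's note (standard GP practice, not asserted by these tests): stabilize
# adds a small diagonal shift so that downstream Cholesky factorizations of
# kernel matrices remain numerically positive definite, e.g.
#
#     K = cov_fn(X, X)
#     chol = tf.linalg.cholesky(pm.gp.util.stabilize(K))
#
# The default shift appears dtype-dependent (~1e-4 for float32, ~1e-6 for
# float64), matching the expectations in test_stabilize_default_shift above.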

--------------------------------------------------------------------------------
/tests/test_mixture.py:
--------------------------------------------------------------------------------
"""
Tests for PyMC4 mixture distribution
"""

import numpy as np
import pytest
import tensorflow as tf
import tensorflow_probability as tfp

import pymc4 as pm
from pymc4.coroutine_model import ModelTemplate

tfd = tfp.distributions

distribution_conditions = {
    "two_components": {
        "n": 1,
        "k": 2,
        "p": np.array([0.5, 0.5], dtype="float32"),
        "loc": np.array([0.0, 0.0], dtype="float32"),
        "scale": 1.0,
    },
    "three_components": {
        "n": 1,
        "k": 3,
        "p": np.array([0.5, 0.25, 0.25], dtype="float32"),
        "loc": np.array([0.0, 0.0, 0.0], dtype="float32"),
        "scale": 1.0,
    },
    "two_components_three_distributions": {
        "n": 3,
        "k": 2,
        "p": np.array([[0.5, 0.5], [0.8, 0.2], [0.7, 0.3]], dtype="float32"),
        "loc": np.array([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], dtype="float32"),
        "scale": 1.0,
    },
    "three_components_three_distributions": {
        "n": 3,
        "k": 3,
        "p": np.array([[0.5, 0.25, 0.25], [0.8, 0.1, 0.1], [0.2, 0.5, 0.3]], dtype="float32"),
        "loc": np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], dtype="float32"),
        "scale": 1.0,
    },
}


@pytest.fixture(scope="function", params=list(distribution_conditions), ids=str)
def mixture_components(request):
    par = distribution_conditions[request.param]
    return par["n"], par["k"], par["p"], par["loc"], par["scale"]


def _mixture(k, p, loc, scale, dat):
    m = yield pm.Normal("means", loc=loc, scale=scale)
    distributions = [pm.Normal("d" + str(i), loc=m[..., i], scale=scale) for i in range(k)]
    obs = yield pm.Mixture(
        "mixture", p=p, distributions=distributions, validate_args=True, observed=dat
    )
    return obs


def _mixture_same_family(k, p, loc, scale, dat):
    m = yield pm.Normal("means", loc=loc, scale=scale)
    distribution = pm.Normal("d", loc=m, scale=scale)
    obs = yield pm.Mixture(
        "mixture", p=p, distributions=distribution, validate_args=True, observed=dat
    )
    return obs


@pytest.fixture(scope="function", params=[_mixture_same_family, _mixture], ids=str)
def mixture(mixture_components, request):
    n, k, p, loc, scale = mixture_components
    dat = tfd.Normal(loc=np.zeros(n), scale=1).sample(100).numpy()
    if n == 1:
        dat = dat.reshape(-1)
    model = ModelTemplate(request.param, name="mixture", keep_auxiliary=True, keep_return=True)
    model = model(k, p, loc, scale, dat)
    return model, n, k


def test_wrong_distribution_argument_batched_fails():
    with pytest.raises(TypeError, match=r"sequence of distributions"):
        pm.Mixture("mix", p=[0.5, 0.5], distributions=tfd.Normal(0, 1))


def test_wrong_distribution_argument_in_list_fails():
    with pytest.raises(TypeError, match=r"every element in 'distribution' "):
        pm.Mixture(
            "mix",
            p=[0.5, 0.5],
            distributions=[
                pm.Normal("comp1", loc=0.0, scale=1.0),
                "not a distribution",
            ],
        )


def test_sampling(mixture, xla_fixture):
    model, n, k = mixture
    if xla_fixture:
        with pytest.raises(tf.errors.InvalidArgumentError):
            pm.sample(model, num_samples=100, num_chains=2, xla=xla_fixture)
    else:
        trace = pm.sample(model, num_samples=100, num_chains=2, xla=xla_fixture)
        if n == 1:
            assert trace.posterior["mixture/means"].shape == (2, 100, k)
        else:
            assert trace.posterior["mixture/means"].shape == (2, 100, n, k)


def test_prior_predictive(mixture):
    model, n, _ = mixture
    ppc = pm.sample_prior_predictive(model, sample_shape=100).prior_predictive
    if n == 1:
        assert ppc["mixture/mixture"].shape == (1, 100)
    else:
        assert ppc["mixture/mixture"].shape == (1, 100, n)


def test_posterior_predictive(mixture):
    model, n, _ = mixture
    trace = pm.sample(model, num_samples=100, num_chains=2)
    ppc = pm.sample_posterior_predictive(model, trace).posterior_predictive
    if n == 1:
        assert ppc["mixture/mixture"].shape == (2, 100, 100)
    else:
        assert ppc["mixture/mixture"].shape == (2, 100, 100, n)
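
# Editor's note on the two construction styles above: pm.Mixture accepts either
# a Python list of PyMC4 distributions (the _mixture helper) or a single batched
# PyMC4 distribution whose trailing batch axis indexes the components (the
# _mixture_same_family helper). A raw tfd.Normal is rejected with the TypeError
# checked in test_wrong_distribution_argument_batched_fails.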

--------------------------------------------------------------------------------
/tests/test_plots.py:
--------------------------------------------------------------------------------
import numpy as np
from pymc4.plots import plot_gp_dist


def test_gp_plot(tf_seed):
    """Test that plot_gp_dist returns consistent results"""
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots()
    ax = plot_gp_dist(ax, np.random.randn(2, 2), x=np.random.randn(2, 1), plot_samples=True)
    assert ax is not None

--------------------------------------------------------------------------------
/tests/test_sampling.py:
--------------------------------------------------------------------------------
import pytest
import pymc4 as pm
import numpy as np
from scipy import stats
import tensorflow as tf


def test_sample_deterministics(simple_model_with_deterministic, xla_fixture):
    model = simple_model_with_deterministic()
    trace = pm.sample(
        model=model,
        num_samples=10,
        num_chains=4,
        burn_in=100,
        step_size=0.1,
        xla=xla_fixture,
    )
    norm = "simple_model_with_deterministic/simple_model/norm"
    determ = "simple_model_with_deterministic/determ"
    np.testing.assert_allclose(trace.posterior[determ], trace.posterior[norm] * 2)


def test_vectorize_log_prob_det_function(unvectorized_model):
    model, norm_shape, observed, batch_size = unvectorized_model
    model = model()
    (
        logpfn,
        all_unobserved_values,
        deterministics_callback,
        deterministic_names,
        state,
    ) = pm.mcmc.samplers.build_logp_and_deterministic_functions(model)
    for _ in range(len(batch_size)):
        logpfn = pm.mcmc.samplers.vectorize_logp_function(logpfn)
        deterministics_callback = pm.mcmc.samplers.vectorize_logp_function(deterministics_callback)

    # Test function inputs and initial values are as expected
    assert set(all_unobserved_values) <= {"unvectorized_model/norm"}
    assert all_unobserved_values["unvectorized_model/norm"].numpy().shape == norm_shape
    assert set(deterministic_names) <= {"unvectorized_model/determ"}

    # Setup inputs to vectorized functions
    inputs = np.random.normal(size=batch_size + norm_shape).astype("float32")
    input_tensor = tf.convert_to_tensor(inputs)

    # Test deterministic part
    expected_deterministic = np.max(np.reshape(inputs, batch_size + (-1,)), axis=-1)
    deterministics_callback_output = deterministics_callback(input_tensor)[0].numpy()
    assert deterministics_callback_output.shape == batch_size
    np.testing.assert_allclose(deterministics_callback_output, expected_deterministic, rtol=1e-5)

    # Test log_prob part
    expected_log_prob = np.sum(
        np.reshape(stats.norm.logpdf(inputs), batch_size + (-1,)), axis=-1
    ) + np.sum(  # norm.log_prob
        stats.norm.logpdf(observed.flatten(), loc=expected_deterministic[..., None], scale=1),
        axis=-1,
    )  # output.log_prob
    logpfn_output = logpfn(input_tensor).numpy()
    assert logpfn_output.shape == batch_size
    np.testing.assert_allclose(logpfn_output, expected_log_prob, rtol=1e-5)

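# Editor's note (behavioral reading of the test above, not API documentation):
# each vectorize_logp_function wrap appears to lift the log-prob callable over
# one extra leading batch axis, which is why the test wraps once per element of
# batch_size before feeding inputs of shape batch_size + norm_shape.
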
def test_sampling_with_deterministics_in_nested_models(
    deterministics_in_nested_models, xla_fixture
):
    (
        model,
        expected_untransformed,
        expected_transformed,
        expected_deterministics,
        deterministic_mapping,
    ) = deterministics_in_nested_models
    trace = pm.sample(
        model=model(),
        num_samples=10,
        num_chains=4,
        burn_in=100,
        step_size=0.1,
        xla=xla_fixture,
    )
    for deterministic, (inputs, op) in deterministic_mapping.items():
        np.testing.assert_allclose(
            trace.posterior[deterministic],
            op(*[trace.posterior[i] for i in inputs]),
            rtol=1e-6,
        )


def test_sampling_with_no_free_rvs(simple_model_no_free_rvs):
    model = simple_model_no_free_rvs()
    with pytest.raises(ValueError):
        trace = pm.sample(model=model, num_samples=1, num_chains=1, burn_in=1)


def test_sample_auto_batching(vectorized_model_fixture, xla_fixture, use_auto_batching_fixture):
    model, is_vectorized_model, core_shapes = vectorized_model_fixture
    num_samples = 10
    num_chains = 4
    if not is_vectorized_model and not use_auto_batching_fixture:
        with pytest.raises(Exception):
            pm.sample(
                model=model(),
                num_samples=num_samples,
                num_chains=num_chains,
                burn_in=1,
                step_size=0.1,
                xla=xla_fixture,
                use_auto_batching=use_auto_batching_fixture,
            )
    else:
        trace = pm.sample(
            model=model(),
            num_samples=num_samples,
            num_chains=num_chains,
            burn_in=1,
            step_size=0.1,
            xla=xla_fixture,
            use_auto_batching=use_auto_batching_fixture,
        )
        posterior = trace.posterior
        for rv_name, core_shape in core_shapes.items():
            assert posterior[rv_name].shape == (num_chains, num_samples) + core_shape


def test_beta_sample():
    @pm.model
    def model():
        dist = yield pm.Beta("beta", 0, 1)
        return dist

    trace = pm.sample(model(), num_samples=1, burn_in=1)

    assert trace.posterior["model/beta"] is not None
    assert trace.posterior["model/__sigmoid_beta"] is not None


def test_sampling_unknown_sampler(simple_model):
    model = simple_model()
    with pytest.raises(KeyError):
        trace = pm.sample(model=model, sampler_type="unknown")


def test_sampling_log_likelihood(vectorized_model_fixture):
    model, is_vectorized_model, core_shapes = vectorized_model_fixture
    num_samples = 10
    num_chains = 4
    trace = pm.sample(
        model=model(),
        num_samples=num_samples,
        num_chains=num_chains,
        burn_in=1,
        step_size=0.1,
        include_log_likelihood=True,
    )

    if is_vectorized_model:
        # only one log likelihood matrix
        assert trace.log_likelihood["model/x"].shape == (num_chains, num_samples)
    else:
        state, _ = pm.initialize_sampling_state(model())
        assert trace.log_likelihood["model/x"].shape == (
            num_chains,
            num_samples,
            *state.observed_values["model/x"].shape,
        )

--------------------------------------------------------------------------------
/tests/test_utils.py:
--------------------------------------------------------------------------------
"""
Test PyMC4 Utils
"""

import pymc4 as pm
import pytest
from mock import Mock


@pytest.fixture(scope="function")
def mock_biwrap_functools_call(monkeypatch):
    """Mock functools.partial to test the execution path of the pm.model decorator
    when used in both the called and uncalled configuration"""
    _functools = Mock()

    def _partial(*args, **kwargs):
        raise Exception("Mocked functools partial")

    _functools.partial = _partial

    monkeypatch.setattr(pm.utils, "functools", _functools)


def test_biwrap_and_mocked_functools_raises_exception_with_called_decorator(
    mock_biwrap_functools_call,
):
    """Test the code path for the called decorator by adding an exception to
    pm.utils.functools.partial"""

    with pytest.raises(Exception) as e:

        @pm.model()
        def fake_model():
            yield None

    assert "Mocked functools partial" in str(e)

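# Editor's sketch of the "biwrap" behavior under test (hedged; inferred from the
# fixture docstring above): pm.model works both called and uncalled --
#
#     @pm.model            # uncalled: decorates the generator directly
#     def m1():
#         yield pm.Normal("x", 0, 1)
#
#     @pm.model()          # called: routes through functools.partial, the path
#     def m2():            # the mocked partial above turns into an exception
#         yield pm.Normal("x", 0, 1)
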
def test_biwrap_with_uncalled_decorator(mock_biwrap_functools_call):
    """Test the code path not taken by verifying that the exception is not raised
    by pm.utils.functools.partial"""
    with pytest.raises(Exception) as e:
        # Verify that functools.partial has been mocked correctly.
        # If this section fails, then the test suite is configured incorrectly.
        pm.utils.functools.partial()
    assert "Mocked functools partial" in str(e)

    @pm.model
    def fake_model():
        yield None

--------------------------------------------------------------------------------
/tests/test_variational.py:
--------------------------------------------------------------------------------
import pytest
import pymc4 as pm
import numpy as np
from scipy.stats import norm


@pytest.fixture(scope="function")
def conjugate_normal_model():
    unknown_mean = -5
    known_sigma = 3
    data_points = 1000
    data = np.random.normal(unknown_mean, known_sigma, size=data_points)
    prior_mean = 4
    prior_sigma = 2

    # Reference - http://patricklam.org/teaching/conjugacy_print.pdf
    precision = 1 / prior_sigma ** 2 + data_points / known_sigma ** 2
    estimated_mean = (
        prior_mean / prior_sigma ** 2 + (data_points * np.mean(data) / known_sigma ** 2)
    ) / precision

    @pm.model
    def model():
        mu = yield pm.Normal("mu", prior_mean, prior_sigma)
        ll = yield pm.Normal("ll", mu, known_sigma, observed=data)
        return ll

    return dict(estimated_mean=estimated_mean, known_sigma=known_sigma, data=data, model=model)


# fmt: off
_test_kwargs = {
    "ADVI": {
        "method": pm.MeanField,
        "fit_kwargs": {},
        "sample_kwargs": {"n": 500, "include_log_likelihood": True},
    },
    "FullRank ADVI": {
        "method": pm.FullRank,
        "fit_kwargs": {}
    },
    "FullRank ADVI: sample_size=2": {
        "method": pm.FullRank,
        "fit_kwargs": {"sample_size": 2}
    }
}
# fmt: on


@pytest.fixture(scope="function", params=list(_test_kwargs), ids=str)
def approximation(request):
    return request.param


def test_fit(approximation, conjugate_normal_model):
    model = conjugate_normal_model["model"]()
    approx = _test_kwargs[approximation]
    advi = pm.fit(method=approx["method"](model), **approx["fit_kwargs"])
    assert advi is not None
    assert advi.losses.numpy().shape == (approx["fit_kwargs"].get("num_steps", 10000),)

    q_samples = advi.approximation.sample(**approx.get("sample_kwargs", {"n": 1000}))

    # Calculate the mean from all draws and compare it to the analytical one
    calculated_mean = q_samples.posterior["model/mu"].mean(dim=("chain", "draw"))
    np.testing.assert_allclose(calculated_mean, conjugate_normal_model["estimated_mean"], rtol=0.05)

    if "sample_kwargs" in approx and approx["sample_kwargs"].get("include_log_likelihood"):
        sample_mean = q_samples.posterior["model/mu"].sel(chain=0, draw=0)  # Single draw
        ll_from_scipy = norm.logpdf(
            conjugate_normal_model["data"], sample_mean, conjugate_normal_model["known_sigma"]
        )
        ll_from_pymc4 = q_samples.log_likelihood["model/ll"].sel(chain=0, draw=0)
        assert ll_from_scipy.shape == ll_from_pymc4.shape
        np.testing.assert_allclose(ll_from_scipy, ll_from_pymc4, rtol=1e-4)

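# Editor's note: test_fit above validates ADVI against a closed-form target. For
# a N(m0, s0^2) prior on mu and n observations with known sigma, the posterior
# is N(m_post, 1/prec) with
#
#     prec   = 1/s0^2 + n/sigma^2
#     m_post = (m0/s0^2 + n*ybar/sigma^2) / prec
#
# which is exactly the `estimated_mean` computed in conjugate_normal_model.
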
@pytest.fixture(scope="function")
def bivariate_gaussian():
    mu = np.zeros(2, dtype=np.float32)
    cov = np.array([[1, 0.8], [0.8, 1]], dtype=np.float32)

    @pm.model
    def bivariate_gaussian():
        density = yield pm.MvNormal("density", loc=mu, covariance_matrix=cov)
        return density

    return bivariate_gaussian


def test_bivariate_shapes(bivariate_gaussian):
    advi = pm.fit(bivariate_gaussian(), num_steps=5000)
    assert advi.losses.numpy().shape == (5000,)

    samples = advi.approximation.sample(5000)
    assert samples.posterior["bivariate_gaussian/density"].values.shape == (1, 5000, 2)


def test_advi_with_deterministics(simple_model_with_deterministic):
    advi = pm.fit(simple_model_with_deterministic(), num_steps=1000)
    samples = advi.approximation.sample(100)
    norm = "simple_model_with_deterministic/simple_model/norm"
    determ = "simple_model_with_deterministic/determ"
    np.testing.assert_allclose(samples.posterior[determ], samples.posterior[norm] * 2)


def test_advi_with_deterministics_in_nested_models(deterministics_in_nested_models):
    (
        model,
        *_,
        deterministic_mapping,
    ) = deterministics_in_nested_models
    advi = pm.fit(model(), num_steps=1000)
    samples = advi.approximation.sample(100)
    for deterministic, (inputs, op) in deterministic_mapping.items():
        np.testing.assert_allclose(
            samples.posterior[deterministic], op(*[samples.posterior[i] for i in inputs]), rtol=1e-6
        )

--------------------------------------------------------------------------------