├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ ├── feature_request.md │ └── usage-question.md ├── pull_request_template.md └── workflows │ ├── post-release.yml │ ├── publish-to-test-pypi.yml │ └── test.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .readthedocs.yaml ├── CHANGELOG.md ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── docs ├── Makefile ├── _templates │ └── autosummary │ │ ├── class.rst │ │ └── function.rst ├── api │ └── index.rst ├── changelog.rst ├── conf.py ├── contributing.rst ├── examples.rst ├── examples │ ├── gallery │ │ └── sbc.md │ └── img │ │ └── sbc.png ├── index.rst ├── installation.rst └── make.bat ├── ecdf.png ├── environment-dev.yml ├── environment.yml ├── hist.png ├── pyproject.toml ├── requirements-dev.txt ├── requirements-docs.txt └── simuk ├── __init__.py ├── sbc.py └── tests └── test_sbc.py /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behaviour. Ideally a self-contained snippet of code, or link to a notebook or external code. Please include screenshots/images produced with simuk here, or the stack trace including `simuk` code to help. 15 | 16 | **Expected behavior** 17 | A clear and concise description of what you expected to happen. 18 | 19 | **Additional context** 20 | Versions of `simuk` and other libraries used, operating system used, and anything else that may be useful.
21 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | ## Tell us about it 11 | 12 | The more specific the better. 13 | 14 | ## Thoughts on implementation 15 | 16 | Not required, but if you have thoughts on how to implement the feature, that can be helpful! In case there are academic references, we welcome those here too. 17 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/usage-question.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Usage Question 3 | about: General questions about simuk usage 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | ## Short Description 11 | 12 | Let us know what you're trying to do, and we can point you in the right direction. Screenshots/plots or stack traces that include `simuk` are particularly helpful. 13 | 14 | ## Code Example or link 15 | 16 | Please provide a minimal, self-contained, and reproducible example demonstrating what you're trying to do. Ideally, it will be a code snippet, a link to a notebook, or a link to code that can be run on another user's computer. 17 | 18 | Also include the simuk version and the version of any other relevant packages. 
19 | 20 | ## Relevant documentation or public examples 21 | 22 | Please provide documentation, public examples, or any additional information which may be relevant to your question 23 | -------------------------------------------------------------------------------- /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | ## Description 2 | 14 | 15 | ## Checklist 16 | 17 | 18 | - [ ] Code style is correct (follows ruff and black guidelines) 19 | - [ ] Includes new or updated tests to cover the new feature 20 | - [ ] New features are properly documented (with an example if appropriate) 21 | - [ ] Includes a sample plot to visually illustrate the changes (only for plot-related functions) 22 | 23 | 34 | -------------------------------------------------------------------------------- /.github/workflows/post-release.yml: -------------------------------------------------------------------------------- 1 | name: Post-release 2 | on: 3 | release: 4 | types: [published, released] 5 | workflow_dispatch: 6 | 7 | jobs: 8 | changelog: 9 | name: Update changelog 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v4 13 | with: 14 | ref: main 15 | - uses: rhysd/changelog-from-release/action@v3 16 | with: 17 | file: CHANGELOG.md 18 | github_token: ${{ secrets.GITHUB_TOKEN }} 19 | commit_summary_template: 'update changelog for %s changes' 20 | -------------------------------------------------------------------------------- /.github/workflows/publish-to-test-pypi.yml: -------------------------------------------------------------------------------- 1 | name: Publish tagged releases to PyPI and TestPyPI 2 | 3 | on: 4 | release: 5 | types: 6 | - created 7 | 8 | jobs: 9 | build-n-publish: 10 | name: Build and publish Python distributions to PyPI and TestPyPI 11 | runs-on: ubuntu-latest 12 | environment: 13 | name: pypi 14 | url: https://pypi.org/p/simuk 15 | permissions: 16 | id-token: write 17 | steps: 18 
| - uses: actions/checkout@v4 19 | - name: Set up Python 3.11 20 | uses: actions/setup-python@v5 21 | with: 22 | python-version: "3.11" 23 | 24 | - name: Install pypa/build 25 | run: >- 26 | python -m 27 | pip install 28 | build 29 | --user 30 | - name: Build a binary wheel and a source tarball 31 | run: >- 32 | python -m 33 | build 34 | --sdist 35 | --wheel 36 | --outdir dist/ 37 | 38 | - name: Publish package distributions to PyPI 39 | uses: pypa/gh-action-pypi-publish@release/v1 40 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Run tests 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | paths-ignore: 8 | - "docs/**" 9 | pull_request: 10 | paths-ignore: 11 | - "docs/**" 12 | 13 | jobs: 14 | test: 15 | runs-on: ubuntu-latest 16 | strategy: 17 | matrix: 18 | python-version: ["3.11", "3.12", "3.13"] 19 | 20 | name: Set up Python ${{ matrix.python-version }} 21 | steps: 22 | - uses: actions/checkout@v4 23 | with: 24 | fetch-depth: 0 25 | 26 | - name: Set up Python ${{ matrix.python-version }} 27 | uses: conda-incubator/setup-miniconda@v3 28 | with: 29 | channels: conda-forge, defaults 30 | channel-priority: true 31 | python-version: ${{ matrix.python-version }} 32 | auto-update-conda: true 33 | 34 | - name: Install simuk 35 | shell: bash -l {0} 36 | run: | 37 | conda install pip 38 | pip install -r requirements-dev.txt 39 | pip install . 40 | python --version 41 | conda list 42 | pip freeze 43 | - name: Run linters 44 | shell: bash -l {0} 45 | run: | 46 | python -m black simuk --check 47 | echo "Success!" 48 | echo "Checking code style with ruff..." 
49 | ruff check simuk/ 50 | - name: Run tests 51 | shell: bash -l {0} 52 | run: | 53 | python -m pytest -vv --cov=simuk --cov-report=term --cov-report=xml simuk/tests 54 | env: 55 | PYTHON_VERSION: ${{ matrix.python-version }} 56 | 57 | - name: Upload coverage to Codecov 58 | uses: codecov/codecov-action@v5 59 | with: 60 | env_vars: OS,PYTHON 61 | name: codecov-umbrella 62 | fail_ci_if_error: false 63 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled files 2 | __pycache__/ 3 | *.py[cod] 4 | 5 | # Distribution / packaging 6 | .Python 7 | build/ 8 | develop-eggs/ 9 | dist/ 10 | .eggs/ 11 | lib/ 12 | lib64/ 13 | parts/ 14 | sdist/ 15 | var/ 16 | *.egg-info/ 17 | .installed.cfg 18 | *.egg 19 | 20 | # Virtual environment 21 | venv/ 22 | ENV/ 23 | env/ 24 | env.bak/ 25 | venv.bak/ 26 | 27 | # Jupyter Notebook checkpoints 28 | .ipynb_checkpoints 29 | 30 | # Pytest cache 31 | .pytest_cache/ 32 | 33 | # Ruff Cache 34 | .ruff_cache/ 35 | 36 | # Coverage reports 37 | htmlcov/ 38 | .coverage 39 | .coverage.* 40 | .cache 41 | nosetests.xml 42 | coverage.xml 43 | *.cover 44 | 45 | # IDE specific files 46 | .idea/ 47 | .vscode/ 48 | *.swp 49 | *.swo 50 | *.sublime-project 51 | *.sublime-workspace 52 | 53 | # Operating system files 54 | .DS_Store 55 | Thumbs.db 56 | 57 | # Sphinx documentation 58 | docs/_build/ 59 | _build 60 | jupyter_execute 61 | docs/api/generated 62 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | fail_fast: true 2 | 3 | repos: 4 | - repo: https://github.com/ambv/black 5 | rev: 22.3.0 6 | hooks: 7 | - id: black 8 | language_version: python3 9 | 10 | - repo: https://github.com/astral-sh/ruff-pre-commit 11 | rev: v0.6.8 12 | hooks: 13 | - id: ruff 14 | args: [ --fix, 
--exit-non-zero-on-fix ] 15 | - id: ruff-format 16 | types_or: [ python, pyi ] 17 | - repo: https://github.com/MarcoGorelli/madforhooks 18 | rev: 0.3.0 19 | hooks: 20 | - id: no-print-statements 21 | 22 | - repo: https://github.com/pre-commit/pre-commit-hooks 23 | rev: v4.5.0 24 | hooks: 25 | - id: check-added-large-files 26 | args: [--maxkb=1500] 27 | - id: end-of-file-fixer 28 | - id: trailing-whitespace 29 | - id: mixed-line-ending 30 | args: [--fix=lf] 31 | 32 | 33 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # Read the Docs configuration file 2 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 3 | version: 2 4 | 5 | build: 6 | os: ubuntu-22.04 7 | tools: 8 | python: "3.11" 9 | 10 | sphinx: 11 | configuration: docs/conf.py 12 | 13 | python: 14 | install: 15 | - requirements: requirements-docs.txt 16 | - method: pip 17 | path: . 
18 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | 2 | # [0.2.0](https://github.com/arviz-devs/simuk/releases/tag/0.2.0) - 2025-03-20 3 | 4 | ## What's Changed 5 | * Use do operator by [@aloctavodia](https://github.com/aloctavodia) in [#25](https://github.com/arviz-devs/simuk/pull/25) 6 | * remove bambi as a dependency and update tests by [@aloctavodia](https://github.com/aloctavodia) in [#26](https://github.com/arviz-devs/simuk/pull/26) 7 | * Use observe, allow multiple observed variables by [@aloctavodia](https://github.com/aloctavodia) in [#27](https://github.com/arviz-devs/simuk/pull/27) 8 | * Add Documentation by [@rohanbabbar04](https://github.com/rohanbabbar04) in [#28](https://github.com/arviz-devs/simuk/pull/28) 9 | * Add python3.13 to tests by [@rohanbabbar04](https://github.com/rohanbabbar04) in [#31](https://github.com/arviz-devs/simuk/pull/31) 10 | * Remove custom plots, use arviz instead by [@aloctavodia](https://github.com/aloctavodia) in [#32](https://github.com/arviz-devs/simuk/pull/32) 11 | * Add support for numpyro models in SBC by [@rohanbabbar04](https://github.com/rohanbabbar04) in [#30](https://github.com/arviz-devs/simuk/pull/30) 12 | * Use numpyro converter from arviz_base by [@aloctavodia](https://github.com/aloctavodia) in [#33](https://github.com/arviz-devs/simuk/pull/33) 13 | 14 | 15 | **Full Changelog**: https://github.com/arviz-devs/simuk/compare/0.1.1...0.2.0 16 | 17 | [Changes][0.2.0] 18 | 19 | 20 | 21 | # [0.1.1](https://github.com/arviz-devs/simuk/releases/tag/0.1.1) - 2025-02-13 22 | 23 | ## What's Changed 24 | * enhance pre-commit configuration with additional hooks and language Version by [@Advaitgaur004](https://github.com/Advaitgaur004) in [#22](https://github.com/arviz-devs/simuk/pull/22) 25 | * Add tests by [@rohanbabbar04](https://github.com/rohanbabbar04) in 
[#24](https://github.com/arviz-devs/simuk/pull/24) 26 | 27 | ## New Contributors 28 | * [@Advaitgaur004](https://github.com/Advaitgaur004) made their first contribution in [#22](https://github.com/arviz-devs/simuk/pull/22) 29 | * [@rohanbabbar04](https://github.com/rohanbabbar04) made their first contribution in [#24](https://github.com/arviz-devs/simuk/pull/24) 30 | 31 | **Full Changelog**: https://github.com/arviz-devs/simuk/compare/0.1.0...0.1.1 32 | 33 | [Changes][0.1.1] 34 | 35 | 36 | 37 | # [0.1.0](https://github.com/arviz-devs/simuk/releases/tag/0.1.0) - 2025-02-12 38 | 39 | ## What's Changed 40 | * Fixing up the readme by [@springcoil](https://github.com/springcoil) in [#1](https://github.com/arviz-devs/simuk/pull/1) 41 | * fix reference to my_model in pseudocode by [@psteinb](https://github.com/psteinb) in [#2](https://github.com/arviz-devs/simuk/pull/2) 42 | * Make it work with pymc 5.x by [@aloctavodia](https://github.com/aloctavodia) in [#3](https://github.com/arviz-devs/simuk/pull/3) 43 | * update readme by [@aloctavodia](https://github.com/aloctavodia) in [#5](https://github.com/arviz-devs/simuk/pull/5) 44 | * Add figsize argument to plots by [@aloctavodia](https://github.com/aloctavodia) in [#7](https://github.com/arviz-devs/simuk/pull/7) 45 | * fix bug with bambi models by [@aloctavodia](https://github.com/aloctavodia) in [#8](https://github.com/arviz-devs/simuk/pull/8) 46 | * simultaneous confidence band by [@aloctavodia](https://github.com/aloctavodia) in [#9](https://github.com/arviz-devs/simuk/pull/9) 47 | * Remove observed_vars argument by [@aloctavodia](https://github.com/aloctavodia) in [#10](https://github.com/arviz-devs/simuk/pull/10) 48 | * Update README.md by [@ColCarroll](https://github.com/ColCarroll) in [#18](https://github.com/arviz-devs/simuk/pull/18) 49 | * setup package conf by [@aloctavodia](https://github.com/aloctavodia) in [#20](https://github.com/arviz-devs/simuk/pull/20) 50 | * add templates and workflows by 
[@aloctavodia](https://github.com/aloctavodia) in [#21](https://github.com/arviz-devs/simuk/pull/21) 51 | * Update workflow by [@aloctavodia](https://github.com/aloctavodia) in [#23](https://github.com/arviz-devs/simuk/pull/23) 52 | 53 | **Full Changelog**: https://github.com/arviz-devs/simuk/commits/0.1.0 54 | 55 | [Changes][0.1.0] 56 | 57 | 58 | [0.2.0]: https://github.com/arviz-devs/simuk/compare/0.1.1...0.2.0 59 | [0.1.1]: https://github.com/arviz-devs/simuk/compare/0.1.0...0.1.1 60 | [0.1.0]: https://github.com/arviz-devs/simuk/tree/0.1.0 61 | 62 | 63 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Simuk Community Code of Conduct 2 | 3 | Simuk adopts the NumFOCUS Code of Conduct directly. In other words, we 4 | expect our community to treat others with kindness and understanding. 5 | 6 | 7 | # THE SHORT VERSION 8 | Be kind to others. Do not insult or put down others. 9 | Behave professionally. Remember that harassment and sexist, racist, 10 | or exclusionary jokes are not appropriate. 11 | 12 | All communication should be appropriate for a professional audience 13 | including people of many different backgrounds. Sexual language and 14 | imagery are not appropriate. 15 | 16 | Simuk is dedicated to providing a harassment-free community for everyone, 17 | regardless of gender, sexual orientation, gender identity, and 18 | expression, disability, physical appearance, body size, race, 19 | or religion. We do not tolerate harassment of community members 20 | in any form. 21 | 22 | Thank you for helping make this a welcoming, friendly community for all. 23 | 24 | 25 | # How to Submit a Report 26 | If you feel that there has been a Code of Conduct violation an anonymous 27 | reporting form is available. 
28 | **If you feel your safety is in jeopardy or the situation is an 29 | emergency, we urge you to contact local law enforcement before making 30 | a report. (In the U.S., dial 911.)** 31 | 32 | We are committed to promptly addressing any reported issues. 33 | If you have experienced or witnessed behavior that violates this 34 | Code of Conduct, please complete the form below to 35 | make a report. 36 | 37 | **REPORTING FORM:** https://numfocus.typeform.com/to/ynjGdT 38 | 39 | Reports are sent to the NumFOCUS Code of Conduct Enforcement Team 40 | (see below). 41 | 42 | You can view the Privacy Policy and Terms of Service for TypeForm here. 43 | The NumFOCUS Privacy Policy is here: 44 | https://www.numfocus.org/privacy-policy 45 | 46 | 47 | # Full Code of Conduct 48 | The full text of the NumFOCUS/Simuk Code of Conduct can be found on 49 | NumFOCUS's website 50 | https://numfocus.org/code-of-conduct 51 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to Simuk 2 | This document outlines only the most common contributions. 3 | Please see the [Contributing guide](https://arviz-devs.github.io/arviz/contributing/index.html) 4 | on our documentation for a better view of how you can contribute to Simuk. 5 | We welcome a wide range of contributions, not only code! 6 | 7 | ## Reporting issues 8 | If you encounter any bug or incorrect behaviour while using Simuk, 9 | please report an issue to our [issue tracker](https://github.com/arviz-devs/simuk/issues). 10 | Please include any supporting information, in particular the version of 11 | Simuk that you are using. 12 | The issue tracker has several templates available to help in writing the issue 13 | and including useful supporting information. 14 | 15 | ## Contributing code 16 | Thanks for your interest in contributing code to Simuk!
17 | 18 | * If this is your first time contributing to a project on GitHub, please read through our step-by-step guide to contributing to Simuk. 19 | * If you have contributed to other projects on GitHub, you can go straight to our [development workflow](https://arviz-devs.github.io/arviz/contributing/index.html). 20 | 21 | ### Adding new features 22 | If you are interested in adding a new feature to Simuk, 23 | first open an issue using the "Feature request" issue template so the community can 24 | discuss its place and implementation within Simuk. 25 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Colin Carroll 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Simulation Based Calibration 2 | 3 | A [PyMC](http://docs.pymc.io), [Bambi](https://bambinos.github.io/bambi/), and [NumPyro](https://num.pyro.ai/) implementation of the algorithms from: 4 | 5 | Sean Talts, Michael Betancourt, Daniel Simpson, Aki Vehtari, Andrew Gelman: “Validating Bayesian Inference Algorithms with Simulation-Based Calibration”, 2018; [arXiv:1804.06788](http://arxiv.org/abs/1804.06788) 6 | 7 | Many thanks to the authors for providing open, reproducible code and implementations in `rstan` and `PyStan` ([link](https://github.com/seantalts/simulation-based-calibration)). 8 | 9 | 10 | ## Installation 11 | 12 | Install from PyPI with pip: 13 | 14 | ```bash 15 | pip install simuk 16 | ``` 17 | 18 | ## Quickstart 19 | 20 | 1. Define a PyMC or Bambi model. For example, the centered eight schools model: 21 | 22 | ```python 23 | import numpy as np 24 | import pymc as pm 25 | 26 | data = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0]) 27 | sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0]) 28 | 29 | with pm.Model() as centered_eight: 30 | mu = pm.Normal('mu', mu=0, sigma=5) 31 | tau = pm.HalfCauchy('tau', beta=5) 32 | theta = pm.Normal('theta', mu=mu, sigma=tau, shape=8) 33 | y_obs = pm.Normal('y', mu=theta, sigma=sigma, observed=data) 34 | ``` 35 | 2. Pass the model to the `SBC` class (`from simuk import SBC`) and run the simulations. This will take a while, as it runs the model many times. 36 | ```python 37 | sbc = SBC(centered_eight, 38 | num_simulations=100, # ideally this should be higher, like 1000 39 | sample_kwargs={'draws': 25, 'tune': 50}) 40 | 41 | sbc.run_simulations() 42 | ``` 43 | ```python 44 | 79%|███████▉ | 79/100 [05:36<01:29, 4.27s/it] 45 | ``` 46 | 47 | 3. Plot the empirical CDF for the difference between prior and posterior.
The lines 48 | should be close to uniform and within the envelope. 49 | ```python 50 | from arviz_plots import plot_ecdf_pit 51 | plot_ecdf_pit(sbc.simulations) 52 | ``` 53 | 54 | ![Simulation based calibration plots, ecdf](ecdf.png) 55 | 56 | 57 | ## What is going on here? 58 | 59 | The [paper on the arXiv](http://arxiv.org/abs/1804.06788) is very well written and explains the algorithm clearly. 60 | 61 | Conceptually, the example below is exactly what this library does, but it generalizes to more complicated models. Start with a model: 62 | 63 | ```python 64 | with pm.Model() as model: 65 | x = pm.Normal('x') 66 | pm.Normal('y', mu=x, observed=y) 67 | ``` 68 | 69 | The library then computes: 70 | 71 | ```python 72 | with model: 73 | prior_samples = pm.sample_prior_predictive(num_trials) 74 | 75 | simulations = {'x': []} 76 | for idx in range(num_trials): 77 | y_tilde = prior_samples['y'][idx] 78 | x_tilde = prior_samples['x'][idx] 79 | with model(y=y_tilde): 80 | idata = pm.sample() 81 | simulations['x'].append((idata.posterior['x'] < x_tilde).sum()) 82 | ``` 83 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = # "-W" treats warnings as errors 6 | SPHINXBUILD ?= sphinx-multiversion 7 | SOURCEDIR = . 8 | BUILDDIR ?= _build 9 | 10 | # Put it first so that "make" without argument is like "make help". 11 | help: 12 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 13 | 14 | .PHONY: help Makefile 15 | 16 | clean: 17 | rm -rf $(BUILDDIR)/* 18 | 19 | # For local build 20 | local: 21 | sphinx-build "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 22 | 23 | # Catch-all target: route all unknown targets to Sphinx using the new 24 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
25 | %: Makefile 26 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 27 | -------------------------------------------------------------------------------- /docs/_templates/autosummary/class.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline }} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. autoclass:: {{ objname }} 6 | 7 | .. rubric:: Methods 8 | 9 | {% for item in methods if item != "__init__" %} 10 | .. automethod:: {{ objname }}.{{ item }} 11 | {%- endfor %} 12 | -------------------------------------------------------------------------------- /docs/_templates/autosummary/function.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline }} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. autofunction:: {{ objname }} 6 | -------------------------------------------------------------------------------- /docs/api/index.rst: -------------------------------------------------------------------------------- 1 | API Reference 2 | ============= 3 | 4 | This reference provides detailed documentation for user functions in the current release of Simuk. 5 | 6 | Simulation based calibration 7 | ---------------------------- 8 | 9 | .. currentmodule:: simuk 10 | 11 | .. autosummary:: 12 | :toctree: generated/ 13 | 14 | SBC 15 | -------------------------------------------------------------------------------- /docs/changelog.rst: -------------------------------------------------------------------------------- 1 | Changelog 2 | ========= 3 | 4 | .. 
include:: ../CHANGELOG.md 5 | :parser: myst_parser.sphinx_ -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # ruff: noqa 2 | # -*- coding: utf-8 -*- 3 | import os 4 | import sys 5 | 6 | from simuk import __version__ 7 | 8 | sys.path.insert(0, os.path.abspath("../simuk")) 9 | 10 | # -- Project information ----------------------------------------------------- 11 | 12 | project = "Simuk" 13 | author = "ArviZ contributors" 14 | copyright = f"2025, {author}" 15 | 16 | # The short X.Y version 17 | version = __version__ 18 | # The full version, including alpha/beta/rc tags 19 | release = __version__ 20 | 21 | 22 | # -- General configuration --------------------------------------------------- 23 | extensions = [ 24 | "sphinx.ext.autodoc", 25 | "sphinx.ext.autosummary", 26 | "sphinx.ext.viewcode", 27 | "sphinx.ext.napoleon", 28 | "sphinx.ext.mathjax", 29 | "sphinx_copybutton", 30 | "myst_nb", 31 | "matplotlib.sphinxext.plot_directive", 32 | "sphinx_tabs.tabs", 33 | "sphinx_design", 34 | "numpydoc", 35 | "jupyter_sphinx", 36 | ] 37 | 38 | # -- Extension configuration ------------------------------------------------- 39 | nb_execution_mode = "auto" 40 | nb_execution_excludepatterns = ["*.ipynb"] 41 | nb_kernel_rgx_aliases = {".*": "python3"} 42 | myst_enable_extensions = ["colon_fence", "deflist", "dollarmath"] 43 | autosummary_generate = True 44 | autodoc_member_order = "bysource" 45 | numpydoc_show_class_members = False 46 | numpydoc_show_inherited_class_members = False 47 | numpydoc_class_members_toctree = False 48 | 49 | source_suffix = ".rst" 50 | 51 | master_doc = "index" 52 | language = "en" 53 | templates_path = ["_templates"] 54 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", "**.ipynb_checkpoints"] 55 | pygments_style = "sphinx" 56 | html_css_files = ["custom.css"] 57 | html_title = "Simuk" 58 | html_short_title = "Simuk" 59 | 
html_theme = "pydata_sphinx_theme" 60 | 61 | html_theme_options = { 62 | "collapse_navigation": True, 63 | "show_toc_level": 2, 64 | "navigation_depth": 4, 65 | "search_bar_text": "Search the docs...", 66 | "icon_links": [ 67 | { 68 | "name": "GitHub", 69 | "url": "https://github.com/arviz-devs/simuk", 70 | "icon": "fa-brands fa-github", 71 | }, 72 | ], 73 | # "logo": { 74 | # "image_light": "Simuk_flat.png", 75 | # "image_dark": "Simuk_flat_white.png", 76 | # }, 77 | } 78 | html_context = { 79 | "github_user": "arviz-devs", 80 | "github_repo": "simuk", 81 | "github_version": "main", 82 | "doc_path": "docs/", 83 | "default_mode": "light", 84 | } 85 | 86 | 87 | # -- Options for HTMLHelp output --------------------------------------------- 88 | 89 | htmlhelp_basename = "simukdoc" 90 | 91 | 92 | # -- Options for LaTeX output ------------------------------------------------ 93 | 94 | latex_documents = [ 95 | (master_doc, "simuk.tex", "simuk Documentation", "The developers of simuk", "manual"), 96 | ] 97 | 98 | 99 | # -- Options for manual page output ------------------------------------------ 100 | 101 | man_pages = [(master_doc, "simuk", "simuk Documentation", [author], 1)] 102 | 103 | 104 | # -- Options for Texinfo output ---------------------------------------------- 105 | 106 | texinfo_documents = [ 107 | ( 108 | master_doc, 109 | "simuk", 110 | "simuk Documentation", 111 | author, 112 | "simuk", 113 | "One line description of project.", 114 | "Miscellaneous", 115 | ), 116 | ] 117 | 118 | 119 | # -- Options for Epub output ------------------------------------------------- 120 | 121 | # Bibliographic Dublin Core info. 
122 | epub_title = project 123 | 124 | epub_exclude_files = ["search.html"] 125 | -------------------------------------------------------------------------------- /docs/contributing.rst: -------------------------------------------------------------------------------- 1 | Contributing 2 | ============ 3 | 4 | We welcome contributions from interested individuals or groups! For information about contributing to Simuk, check out our instructions, policies, and guidelines `here `_. 5 | 6 | See the `GitHub contributor page `_. 7 | -------------------------------------------------------------------------------- /docs/examples.rst: -------------------------------------------------------------------------------- 1 | Examples 2 | ======== 3 | 4 | The gallery below presents examples that demonstrate the use of Simuk. 5 | 6 | .. grid:: 1 2 3 3 7 | :gutter: 2 2 3 3 8 | 9 | .. grid-item-card:: 10 | :link: ./examples/gallery/sbc.html 11 | :text-align: center 12 | :shadow: none 13 | :class-card: example-gallery 14 | 15 | .. image:: examples/img/sbc.png 16 | :alt: SBC 17 | 18 | +++ 19 | SBC 20 | -------------------------------------------------------------------------------- /docs/examples/gallery/sbc.md: -------------------------------------------------------------------------------- 1 | --- 2 | jupytext: 3 | text_representation: 4 | extension: .md 5 | format_name: myst 6 | kernelspec: 7 | display_name: Python 3 8 | language: python 9 | name: python3 10 | --- 11 | 12 | # Simulation based calibration 13 | 14 | This example demonstrates how to use the `SBC` class for simulation-based calibration with PyMC, Bambi, and NumPyro models. 15 | 16 | ```{jupyter-execute} 17 | 18 | from arviz_plots import plot_ecdf_pit, style 19 | import numpy as np 20 | import simuk 21 | style.use("arviz-variat") 22 | ``` 23 | 24 | ::::::{tab-set} 25 | :class: full-width 26 | 27 | :::::{tab-item} PyMC 28 | :sync: pymc 29 | 30 | First, define a PyMC model.
In this example, we will use the centered eight schools model. 31 | 32 | ```{jupyter-execute} 33 | 34 | import pymc as pm 35 | 36 | data = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0]) 37 | sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0]) 38 | 39 | with pm.Model() as centered_eight: 40 | mu = pm.Normal('mu', mu=0, sigma=5) 41 | tau = pm.HalfCauchy('tau', beta=5) 42 | theta = pm.Normal('theta', mu=mu, sigma=tau, shape=8) 43 | y_obs = pm.Normal('y', mu=theta, sigma=sigma, observed=data) 44 | ``` 45 | 46 | Pass the model to the `SBC` class, set the number of simulations to 100, and run the simulations. This process may take 47 | some time since the model runs multiple times (100 in this example). 48 | 49 | ```{jupyter-execute} 50 | 51 | sbc = simuk.SBC(centered_eight, 52 | num_simulations=100, 53 | sample_kwargs={'draws': 25, 'tune': 50}) 54 | 55 | sbc.run_simulations(); 56 | ``` 57 | 58 | To compare the prior and posterior distributions, we will plot the results from the simulations, 59 | using the ArviZ function `plot_ecdf_pit`. 60 | We expect a uniform distribution; the gray envelope corresponds to the 94% credible interval. 61 | 62 | ```{jupyter-execute} 63 | 64 | plot_ecdf_pit(sbc.simulations, 65 | pc_kwargs={'col_wrap':4}, 66 | plot_kwargs={"xlabel":False}, 67 | ) 68 | ``` 69 | 70 | ::::: 71 | 72 | :::::{tab-item} Bambi 73 | :sync: bambi 74 | 75 | Now, we define a Bambi model. 76 | 77 | ```{jupyter-execute} 78 | 79 | import bambi as bmb 80 | import pandas as pd 81 | 82 | x = np.random.normal(0, 1, 200) 83 | y = 2 + np.random.normal(x, 1) 84 | df = pd.DataFrame({"x": x, "y": y}) 85 | bmb_model = bmb.Model("y ~ x", df) 86 | ``` 87 | 88 | Pass the model to the `SBC` class, set the number of simulations to 100, and run the simulations.
89 | This process may take some time, as the model runs multiple times (100 in this example). 90 | 91 | ```{jupyter-execute} 92 | 93 | sbc = simuk.SBC(bmb_model, 94 | num_simulations=100, 95 | sample_kwargs={'draws': 25, 'tune': 50}) 96 | 97 | sbc.run_simulations(); 98 | ``` 99 | 100 | To compare the prior and posterior distributions, we will plot the results from the simulations. 101 | We expect the rank statistics to be uniformly distributed; the gray envelope corresponds to the 94% credible interval. 102 | 103 | ```{jupyter-execute} 104 | plot_ecdf_pit(sbc.simulations) 105 | ``` 106 | 107 | ::::: 108 | 109 | :::::{tab-item} Numpyro 110 | :sync: numpyro 111 | 112 | Now, we define a NumPyro model. Again, we use the centered eight schools model. 113 | 114 | ```{jupyter-execute} 115 | import numpyro 116 | import numpyro.distributions as dist 117 | from jax import random 118 | from numpyro.infer import NUTS 119 | 120 | y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0]) 121 | sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0]) 122 | 123 | def eight_schools_cauchy_prior(J, sigma, y=None): 124 | mu = numpyro.sample("mu", dist.Normal(0, 5)) 125 | tau = numpyro.sample("tau", dist.HalfCauchy(5)) 126 | with numpyro.plate("J", J): 127 | theta = numpyro.sample("theta", dist.Normal(mu, tau)) 128 | numpyro.sample("y", dist.Normal(theta, sigma), obs=y) 129 | 130 | # We use the NUTS sampler 131 | nuts_kernel = NUTS(eight_schools_cauchy_prior) 132 | ``` 133 | 134 | Pass the NUTS kernel to the `SBC` class, set the number of simulations to 100, and run the simulations. For a NumPyro model, 135 | we also pass the model's arguments, including the observed data ``y``, through the ``data_dir`` parameter. 136 | 137 | ```{jupyter-execute} 138 | sbc = simuk.SBC(nuts_kernel, 139 | sample_kwargs={"num_warmup": 50, "num_samples": 75}, 140 | num_simulations=100, 141 | data_dir={"J": 8, "sigma": sigma, "y": y}, 142 | ) 143 | sbc.run_simulations() 144 | ``` 145 | 146 | To compare the prior and posterior distributions, we will plot the results.
147 | We expect the rank statistics to be uniformly distributed; the gray envelope corresponds to the 94% credible interval. 148 | 149 | ```{jupyter-execute} 150 | plot_ecdf_pit(sbc.simulations, 151 | pc_kwargs={'col_wrap':4}, 152 | plot_kwargs={"xlabel":False} 153 | ) 154 | ``` 155 | 156 | ::::: 157 | 158 | :::::: 159 | -------------------------------------------------------------------------------- /docs/examples/img/sbc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arviz-devs/simuk/26d769f578a54546f23d8a48df24567e30fdd19a/docs/examples/img/sbc.png -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | Overview 2 | ======== 3 | 4 | Simuk is a Python library for simulation-based calibration (SBC) and the generation of synthetic data. 5 | Simulation-Based Calibration (SBC) is a method for validating Bayesian inference by checking whether the 6 | posterior distributions align with the expected theoretical results derived from the prior. 7 | 8 | Quickstart 9 | ---------- 10 | 11 | This quickstart guide provides a simple example to help you get started. If you're looking for more examples 12 | and use cases, be sure to check out the :doc:`examples` section. 13 | 14 | To use SBC, you need to define a model function that generates simulated data and corresponding prior predictive 15 | samples, then compare them to posterior samples obtained through inference. 16 | 17 | In our case, we will take a PyMC model and pass it into our ``SBC`` class. 18 | 19 | ..
code-block:: python 20 | 21 | import numpy as np 22 | import pymc as pm 23 | import simuk 24 | data = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0]) 25 | sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0]) 26 | 27 | with pm.Model() as centered_eight: 28 | mu = pm.Normal('mu', mu=0, sigma=5) 29 | tau = pm.HalfCauchy('tau', beta=5) 30 | theta = pm.Normal('theta', mu=mu, sigma=tau, shape=8) 31 | y_obs = pm.Normal('y', mu=theta, sigma=sigma, observed=data) 32 | 33 | # Pass it into the SBC class 34 | sbc = simuk.SBC(centered_eight, num_simulations=100, sample_kwargs={'draws': 25, 'tune': 50}) 35 | 36 | Now, we use the ``run_simulations`` method to generate and analyze simulated data, running the model multiple times to 37 | compare prior and posterior distributions. 38 | 39 | .. code-block:: python 40 | 41 | sbc.run_simulations() 42 | 43 | Plot the empirical CDF of the rank statistics with ``plot_ecdf_pit`` from ``arviz_plots`` to compare the prior and posterior. 44 | 45 | .. code-block:: python 46 | 47 | plot_ecdf_pit(sbc.simulations) 48 | 49 | The ECDF lines should stay within the envelope, which indicates that the rank statistics are close to uniform. This suggests that the prior and posterior distributions 50 | are properly aligned and that there are no significant biases or issues with the model or the inference. 51 | 52 | .. toctree:: 53 | :maxdepth: 1 54 | :hidden: 55 | :caption: Getting Started 56 | 57 | Overview 58 | installation 59 | 60 | .. toctree:: 61 | :maxdepth: 2 62 | :hidden: 63 | :caption: API documentation 64 | 65 | api/index.rst 66 | 67 | .. toctree:: 68 | :maxdepth: 2 69 | :hidden: 70 | :caption: Examples 71 | 72 | examples 73 | 74 | .. toctree:: 75 | :maxdepth: 1 76 | :hidden: 77 | :caption: References 78 | 79 | contributing 80 | changelog 81 | 82 | References 83 | ---------- 84 | 85 | - Talts, Sean, Michael Betancourt, Daniel Simpson, Aki Vehtari, and Andrew Gelman. 2018. “Validating Bayesian Inference Algorithms with Simulation-Based Calibration.” `arXiv:1804.06788 <https://arxiv.org/abs/1804.06788>`_. 86 | - Modrák, M., Moon, A.
H., Kim, S., Bürkner, P., Huurre, N., Faltejsková, K., … & Vehtari, A. (2023). Simulation-based calibration checking for Bayesian computation: The choice of test quantities shapes sensitivity. Bayesian Analysis, advance publication, DOI: `10.1214/23-BA1404 <https://doi.org/10.1214/23-BA1404>`_ 87 | -------------------------------------------------------------------------------- /docs/installation.rst: -------------------------------------------------------------------------------- 1 | Installation 2 | ============ 3 | 4 | Simuk is tested and compatible with **Python 3.11 and later**. It depends on ArviZ (``arviz_base``) and tqdm; PyMC, Bambi or NumPyro must be installed to use the corresponding model types. For specific version details, 5 | check the `pyproject.toml <https://github.com/arviz-devs/simuk/blob/main/pyproject.toml>`_ file. 6 | 7 | The latest release of Simuk can be installed using ``pip``: 8 | 9 | Using pip 10 | --------- 11 | 12 | .. code-block:: bash 13 | 14 | pip install simuk 15 | 16 | 17 | Development Version 18 | ------------------- 19 | 20 | The latest development version can be installed from the main branch using ``pip``: 21 | 22 | .. code-block:: bash 23 | 24 | pip install git+https://github.com/arviz-devs/simuk.git 25 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo.
23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /ecdf.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arviz-devs/simuk/26d769f578a54546f23d8a48df24567e30fdd19a/ecdf.png -------------------------------------------------------------------------------- /environment-dev.yml: -------------------------------------------------------------------------------- 1 | name: simuk-dev 2 | channels: 3 | - defaults 4 | - conda-forge 5 | dependencies: 6 | - python >= 3.11 7 | - pip 8 | - python >= 3.11 9 | - pymc>=5.20.1 10 | - bambi>=0.13.0 11 | - arviz>=0.20.0 12 | - black=22.3.0 13 | - click=8.0.4 14 | - pytest-cov>=2.6.1 15 | - pytest>=4.4.0 16 | - pre-commit>=2.19 17 | - ruff==0.9.1 18 | 19 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: simuk 2 | channels: 3 | - defaults 4 | - conda-forge 5 | dependencies: 6 | - python >= 3.11 7 | - pymc>=5.20.1 8 | - arviz>=0.20.0 9 | -------------------------------------------------------------------------------- /hist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arviz-devs/simuk/26d769f578a54546f23d8a48df24567e30fdd19a/hist.png -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["flit_core >=3.4,<4"] 3 | build-backend = "flit_core.buildapi" 4 | 5 | [project] 6 | name 
= "simuk" 7 | readme = "README.md" 8 | requires-python = ">=3.11" 9 | license = {file = "LICENSE"} 10 | authors = [ 11 | {name = "ArviZ team", email = "arviz.devs@gmail.com"} 12 | ] 13 | classifiers = [ 14 | "Development Status :: 3 - Alpha", 15 | "Intended Audience :: Science/Research", 16 | "Intended Audience :: Education", 17 | "License :: OSI Approved :: Apache Software License", 18 | "Operating System :: OS Independent", 19 | "Programming Language :: Python", 20 | "Programming Language :: Python :: 3", 21 | "Programming Language :: Python :: 3.11", 22 | "Programming Language :: Python :: 3.12", 23 | ] 24 | dynamic = ["version"] 25 | description = "Simulation based calibration and generation of synthetic data." 26 | dependencies = [ 27 | "arviz_base>=0.5.0", 28 | "tqdm" 29 | ] 30 | 31 | [tool.flit.module] 32 | name = "simuk" 33 | 34 | [project.urls] 35 | source = "https://github.com/arviz-devs/simuk" 36 | tracker = "https://github.com/arviz-devs/simuk/issues" 37 | funding = "https://opencollective.com/arviz" 38 | 39 | 40 | [tool.black] 41 | line-length = 100 42 | 43 | [tool.isort] 44 | profile = "black" 45 | include_trailing_comma = true 46 | use_parentheses = true 47 | multi_line_output = 3 48 | line_length = 100 49 | 50 | [tool.pydocstyle] 51 | convention = "numpy" 52 | 53 | [tool.pytest.ini_options] 54 | testpaths = [ 55 | "tests", 56 | ] 57 | 58 | [tool.ruff] 59 | line-length = 100 60 | 61 | [tool.ruff.lint] 62 | select = [ 63 | "F", # Pyflakes 64 | "E", # Pycodestyle 65 | "W", # Pycodestyle 66 | "D", # pydocstyle 67 | "NPY", # numpy specific rules 68 | "UP", # pyupgrade 69 | "I", # isort 70 | "PL", # Pylint 71 | "TID", # Absolute imports 72 | ] 73 | ignore = [ 74 | "PLR0912", # too many branches 75 | "PLR0913", # too many arguments 76 | "PLR2004", # magic value comparison 77 | "PLR0915", # too many statements 78 | "NPY002", # Replace legacy `np.random.randn` call with `np.random.Generator` 79 | "D1" # Missing docstring 80 | ] 81 | 82 | 
[tool.ruff.lint.per-file-ignores] 83 | "docs/source/**/*.ipynb" = ["D", "E", "F", "I", "NPY", "PL", "TID", "UP", "W"] 84 | "simuk/__init__.py" = ["I", "F401", "E402", "F403"] 85 | "simuk/tests/**/*" = ["D", "PLR2004", "TID252"] 86 | "simuk/tests/**/*.ipynb" = ["E", "F"] 87 | 88 | [tool.ruff.lint.pydocstyle] 89 | convention = "numpy" 90 | 91 | [tool.ruff.lint.flake8-tidy-imports] 92 | ban-relative-imports = "all" # Disallow all relative imports. 93 | 94 | [tool.ruff.format] 95 | docstring-code-format = false 96 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | black==22.3.0 2 | click==8.0.4 3 | pytest-cov>=2.6.1 4 | pytest>=4.4.0 5 | pre-commit>=2.19 6 | ipytest==0.13.0 7 | pymc>=5.20.1 8 | bambi>=0.13.0 9 | arviz_base>=0.5.0 10 | ruff==0.9.1 11 | numpyro>=0.17.0 12 | -------------------------------------------------------------------------------- /requirements-docs.txt: -------------------------------------------------------------------------------- 1 | pydata-sphinx-theme>=0.6.3 2 | myst-nb 3 | pymc>=5.20.1 4 | bambi>=0.15.0 5 | arviz_plots @ git+https://github.com/arviz-devs/arviz-plots@main 6 | sphinx>=4 7 | sphinx-copybutton 8 | sphinx_tabs 9 | sphinx-design 10 | numpydoc 11 | jupyter-sphinx 12 | numpyro>=0.17.0 13 | -------------------------------------------------------------------------------- /simuk/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Simuk. 3 | 4 | Simulation based calibration and other tools for evaluation using synthetic data. 5 | """ 6 | 7 | from simuk.sbc import SBC 8 | 9 | 10 | __version__ = "0.2.0" 11 | -------------------------------------------------------------------------------- /simuk/sbc.py: -------------------------------------------------------------------------------- 1 | """Simulation based calibration (Talts et. al. 
2018) in PyMC.""" 2 | 3 | import logging 4 | from copy import copy 5 | from importlib.metadata import version 6 | 7 | try: 8 | import pymc as pm 9 | except ImportError: 10 | pass 11 | try: 12 | import jax 13 | from numpyro.handlers import seed, trace 14 | from numpyro.infer import MCMC, Predictive 15 | from numpyro.infer.mcmc import MCMCKernel 16 | except ImportError: 17 | pass 18 | 19 | import numpy as np 20 | from arviz_base import extract, from_dict, from_numpyro 21 | from tqdm import tqdm 22 | 23 | 24 | class quiet_logging: 25 | """Turn off logging for PyMC, Bambi and PyTensor.""" 26 | 27 | def __init__(self, *libraries): 28 | self.loggers = [logging.getLogger(library) for library in libraries] 29 | 30 | def __call__(self, func): 31 | def wrapped(cls, *args, **kwargs): 32 | levels = [] 33 | for logger in self.loggers: 34 | levels.append(logger.level) 35 | logger.setLevel(logging.CRITICAL) 36 | res = func(cls, *args, **kwargs) 37 | for logger, level in zip(self.loggers, levels): 38 | logger.setLevel(level) 39 | return res 40 | 41 | return wrapped 42 | 43 | 44 | class SBC: 45 | """Set up class for doing SBC. 46 | 47 | Parameters 48 | ---------- 49 | model : pymc.Model, bambi.Model or numpyro.infer.mcmc.MCMCKernel 50 | A PyMC, Bambi model or Numpyro MCMC kernel. If a PyMC model the data needs to be defined as 51 | mutable data. 52 | num_simulations : int 53 | How many simulations to run 54 | sample_kwargs : dict[str] -> Any 55 | Arguments passed to pymc.sample or bambi.Model.fit 56 | seed : int (optional) 57 | Random seed. This persists even if running the simulations is 58 | paused for whatever reason. 59 | data_dir : dict 60 | Keyword arguments passed to numpyro model, intended for use when providing 61 | an MCMC Kernel model. 62 | 63 | Example 64 | ------- 65 | 66 | .. 
code-block :: python 67 | 68 | with pm.Model() as model: 69 | x = pm.Normal('x') 70 | y = pm.Normal('y', mu=2 * x, observed=obs) 71 | 72 | sbc = SBC(model) 73 | sbc.run_simulations() 74 | sbc.plot_results() 75 | 76 | """ 77 | 78 | def __init__(self, model, num_simulations=1000, sample_kwargs=None, seed=None, data_dir=None): 79 | if hasattr(model, "basic_RVs") and isinstance(model, pm.Model): 80 | self.engine = "pymc" 81 | self.model = model 82 | elif hasattr(model, "formula"): 83 | self.engine = "bambi" 84 | model.build() 85 | self.bambi_model = model 86 | self.model = model.backend.model 87 | self.formula = model.formula 88 | self.new_data = copy(model.data) 89 | elif isinstance(model, MCMCKernel): 90 | self.engine = "numpyro" 91 | self.numpyro_model = model 92 | self.model = self.numpyro_model.model 93 | self.run_simulations = self._run_simulations_numpyro 94 | self.data_dir = data_dir 95 | else: 96 | raise ValueError( 97 | "model should be one of pymc.Model, bambi.Model, or numpyro.infer.mcmc.MCMCKernel" 98 | ) 99 | self.num_simulations = num_simulations 100 | if sample_kwargs is None: 101 | sample_kwargs = {} 102 | if self.engine == "numpyro": 103 | sample_kwargs.setdefault("num_warmup", 1000) 104 | sample_kwargs.setdefault("num_samples", 1000) 105 | sample_kwargs.setdefault("progress_bar", False) 106 | else: 107 | sample_kwargs.setdefault("progressbar", False) 108 | sample_kwargs.setdefault("compute_convergence_checks", False) 109 | self.sample_kwargs = sample_kwargs 110 | self.seed = seed 111 | self._seeds = self._get_seeds() 112 | self._extract_variable_names() 113 | self.simulations = {name: [] for name in self.var_names} 114 | self._simulations_complete = 0 115 | 116 | def _extract_variable_names(self): 117 | """Extract observed and free variables from the model.""" 118 | if self.engine == "numpyro": 119 | with trace() as tr: 120 | with seed(rng_seed=int(self._seeds[0])): 121 | self.numpyro_model.model(**self.data_dir) 122 | self.var_names = [ 123 | name 
124 | for name, site in tr.items() 125 | if site["type"] == "sample" and not site.get("is_observed", False) 126 | ] 127 | self.observed_vars = [ 128 | name 129 | for name, site in tr.items() 130 | if site["type"] == "sample" and site.get("is_observed", False) 131 | ] 132 | else: 133 | self.observed_vars = [obs.name for obs in self.model.observed_RVs] 134 | self.var_names = [v.name for v in self.model.free_RVs] 135 | 136 | def _get_seeds(self): 137 | """Set the random seed, and generate seeds for all the simulations.""" 138 | rng = np.random.default_rng(self.seed) 139 | return rng.integers(0, 2**30, size=self.num_simulations) 140 | 141 | def _get_prior_predictive_samples(self): 142 | """Generate samples to use for the simulations.""" 143 | with self.model: 144 | idata = pm.sample_prior_predictive( 145 | samples=self.num_simulations, random_seed=self._seeds[0] 146 | ) 147 | prior_pred = extract(idata, group="prior_predictive", keep_dataset=True) 148 | prior = extract(idata, group="prior", keep_dataset=True) 149 | return prior, prior_pred 150 | 151 | def _get_prior_predictive_samples_numpyro(self): 152 | """Generate samples to use for the simulations using numpyro.""" 153 | predictive = Predictive(self.model, num_samples=self.num_simulations) 154 | free_vars_data = {k: v for k, v in self.data_dir.items() if k not in self.observed_vars} 155 | samples = predictive(jax.random.PRNGKey(self._seeds[0]), **free_vars_data) 156 | prior = {k: v for k, v in samples.items() if k not in self.observed_vars} 157 | prior_pred = {k: v for k, v in samples.items() if k in self.observed_vars} 158 | return prior, prior_pred 159 | 160 | def _get_posterior_samples(self, prior_predictive_draw): 161 | """Generate posterior samples conditioned to a prior predictive sample.""" 162 | new_model = pm.observe(self.model, prior_predictive_draw) 163 | with new_model: 164 | check = pm.sample( 165 | **self.sample_kwargs, random_seed=self._seeds[self._simulations_complete] 166 | ) 167 | 168 | posterior 
= extract(check, group="posterior", keep_dataset=True) 169 | return posterior 170 | 171 | def _get_posterior_samples_numpyro(self, prior_predictive_draw): 172 | """Generate posterior samples using numpyro conditioned to a prior predictive sample.""" 173 | mcmc = MCMC(self.numpyro_model, **self.sample_kwargs) 174 | rng_seed = jax.random.PRNGKey(self._seeds[self._simulations_complete]) 175 | free_vars_data = {k: v for k, v in self.data_dir.items() if k not in self.observed_vars} 176 | mcmc.run(rng_seed, **free_vars_data, **prior_predictive_draw) 177 | return from_numpyro(mcmc)["posterior"] 178 | 179 | def _convert_to_datatree(self): 180 | self.simulations = from_dict( 181 | {"prior_sbc": self.simulations}, 182 | attrs={ 183 | "/": { 184 | "inferece_library": self.engine, 185 | "inferece_library_version": version(self.engine), 186 | "modeling_interface": "simuk", 187 | "modeling_interface_version": version("simuk"), 188 | } 189 | }, 190 | ) 191 | 192 | @quiet_logging("pymc", "pytensor.gof.compilelock", "bambi") 193 | def run_simulations(self): 194 | """Run all the simulations. 195 | 196 | This function can be stopped and restarted on the same instance, so you can 197 | keyboard interrupt part way through, look at the plot, and then resume. If a 198 | seed was passed initially, it will still be respected (that is, the resulting 199 | simulations will be identical to running without pausing in the middle). 
200 | """ 201 | prior, prior_pred = self._get_prior_predictive_samples() 202 | 203 | progress = tqdm( 204 | initial=self._simulations_complete, 205 | total=self.num_simulations, 206 | ) 207 | try: 208 | while self._simulations_complete < self.num_simulations: 209 | idx = self._simulations_complete 210 | prior_predictive_draw = { 211 | var_name: prior_pred[var_name].sel(chain=0, draw=idx).values 212 | for var_name in self.observed_vars 213 | } 214 | 215 | posterior = self._get_posterior_samples(prior_predictive_draw) 216 | for name in self.var_names: 217 | self.simulations[name].append( 218 | (posterior[name] < prior[name].sel(chain=0, draw=idx)).sum("sample").values 219 | ) 220 | self._simulations_complete += 1 221 | progress.update() 222 | finally: 223 | self.simulations = { 224 | k: np.stack(v[: self._simulations_complete])[None, :] 225 | for k, v in self.simulations.items() 226 | } 227 | self._convert_to_datatree() 228 | progress.close() 229 | 230 | @quiet_logging("numpyro") 231 | def _run_simulations_numpyro(self): 232 | """Run all the simulations for Numpyro Model.""" 233 | prior, prior_pred = self._get_prior_predictive_samples_numpyro() 234 | progress = tqdm( 235 | initial=self._simulations_complete, 236 | total=self.num_simulations, 237 | ) 238 | try: 239 | while self._simulations_complete < self.num_simulations: 240 | idx = self._simulations_complete 241 | prior_predictive_draw = {k: v[idx] for k, v in prior_pred.items()} 242 | posterior = self._get_posterior_samples_numpyro(prior_predictive_draw) 243 | for name in self.var_names: 244 | self.simulations[name].append( 245 | (posterior[name].sel(chain=0) < prior[name][idx]).sum(axis=0).values 246 | ) 247 | self._simulations_complete += 1 248 | progress.update() 249 | finally: 250 | self.simulations = { 251 | k: np.stack(v[: self._simulations_complete])[None, :] 252 | for k, v in self.simulations.items() 253 | } 254 | self._convert_to_datatree() 255 | progress.close() 256 | 
-------------------------------------------------------------------------------- /simuk/tests/test_sbc.py: -------------------------------------------------------------------------------- 1 | import bambi as bmb 2 | import numpy as np 3 | import numpyro 4 | import numpyro.distributions as dist 5 | import pandas as pd 6 | import pymc as pm 7 | import pytest 8 | from numpyro.infer import NUTS 9 | 10 | import simuk 11 | 12 | np.random.seed(1234) 13 | 14 | data = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0]) 15 | sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0]) 16 | 17 | with pm.Model() as centered_eight: 18 | mu = pm.Normal("mu", mu=0, sigma=5) 19 | tau = pm.HalfCauchy("tau", beta=5) 20 | theta = pm.Normal("theta", mu=mu, sigma=tau, shape=8) 21 | y_obs = pm.Normal("y", mu=theta, sigma=sigma, observed=data) 22 | 23 | x = np.random.normal(0, 1, 20) 24 | y = 2 + np.random.normal(x, 1) 25 | df = pd.DataFrame({"x": x, "y": y}) 26 | bmb_model = bmb.Model("y ~ x", df) 27 | 28 | 29 | @pytest.mark.parametrize("model", [centered_eight, bmb_model]) 30 | def test_sbc(model): 31 | sbc = simuk.SBC( 32 | model, 33 | num_simulations=10, 34 | sample_kwargs={"draws": 5, "tune": 5}, 35 | ) 36 | sbc.run_simulations() 37 | assert "prior_sbc" in sbc.simulations 38 | 39 | 40 | def test_sbc_numpyro(): 41 | y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0]) 42 | sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0]) 43 | 44 | def eight_schools_cauchy_prior(J, sigma, y=None): 45 | mu = numpyro.sample("mu", dist.Normal(0, 5)) 46 | tau = numpyro.sample("tau", dist.HalfCauchy(5)) 47 | with numpyro.plate("J", J): 48 | theta = numpyro.sample("theta", dist.Normal(mu, tau)) 49 | numpyro.sample("y", dist.Normal(theta, sigma), obs=y) 50 | 51 | sbc = simuk.SBC( 52 | NUTS(eight_schools_cauchy_prior), 53 | data_dir={"J": 8, "sigma": sigma, "y": y}, 54 | num_simulations=10, 55 | sample_kwargs={"num_warmup": 50, "num_samples": 25}, 56 | ) 57 | 
sbc.run_simulations() 58 | assert "prior_sbc" in sbc.simulations 59 | --------------------------------------------------------------------------------