├── .github
│   ├── ISSUE_TEMPLATE
│   │   ├── bug_report.md
│   │   ├── feature_request.md
│   │   └── usage-question.md
│   ├── pull_request_template.md
│   └── workflows
│       ├── post-release.yml
│       ├── publish-to-test-pypi.yml
│       └── test.yml
├── .gitignore
├── .pre-commit-config.yaml
├── .readthedocs.yaml
├── CHANGELOG.md
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── docs
│   ├── Makefile
│   ├── _templates
│   │   └── autosummary
│   │       ├── class.rst
│   │       └── function.rst
│   ├── api
│   │   └── index.rst
│   ├── changelog.rst
│   ├── conf.py
│   ├── contributing.rst
│   ├── examples.rst
│   ├── examples
│   │   ├── gallery
│   │   │   └── sbc.md
│   │   └── img
│   │       └── sbc.png
│   ├── index.rst
│   ├── installation.rst
│   └── make.bat
├── ecdf.png
├── environment-dev.yml
├── environment.yml
├── hist.png
├── pyproject.toml
├── requirements-dev.txt
├── requirements-docs.txt
└── simuk
    ├── __init__.py
    ├── sbc.py
    └── tests
        └── test_sbc.py
/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug report
3 | about: Create a report to help us improve
4 | title: ''
5 | labels: ''
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Describe the bug**
11 | A clear and concise description of what the bug is.
12 |
13 | **To Reproduce**
14 | Steps to reproduce the behaviour. Ideally a self-contained snippet of code, or link to a notebook or external code. Please include screenshots/images produced with simuk here, or the stack trace including `simuk` code to help.
15 |
16 | **Expected behavior**
17 | A clear and concise description of what you expected to happen.
18 |
19 | **Additional context**
20 | Versions of `simuk` and other libraries used, operating system used, and anything else that may be useful.
21 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Feature request
3 | about: Suggest an idea for this project
4 | title: ''
5 | labels: ''
6 | assignees: ''
7 |
8 | ---
9 |
10 | ## Tell us about it
11 |
12 | The more specific the better.
13 |
14 | ## Thoughts on implementation
15 |
16 | Not required, but if you have thoughts on how to implement the feature, that can be helpful! In case there are academic references, we welcome those here too.
17 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/usage-question.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Usage Question
3 | about: General questions about simuk usage
4 | title: ''
5 | labels: ''
6 | assignees: ''
7 |
8 | ---
9 |
10 | ## Short Description
11 |
12 | Let us know what you're trying to do, and we can point you in the right direction. Screenshots/plots or stack traces that include `simuk` are particularly helpful.
13 |
14 | ## Code Example or link
15 |
16 | Please provide a minimal, self-contained, and reproducible example demonstrating what you're trying to do. Ideally, it will be a code snippet, a link to a notebook, or a link to code that can be run on another user's computer.
17 |
18 | Also include the simuk version and the version of any other relevant packages.
19 |
20 | ## Relevant documentation or public examples
21 |
22 | Please provide documentation, public examples, or any additional information which may be relevant to your question.
23 |
--------------------------------------------------------------------------------
/.github/pull_request_template.md:
--------------------------------------------------------------------------------
1 | ## Description
2 |
14 |
15 | ## Checklist
16 |
17 |
18 | - [ ] Code style is correct (follows ruff and black guidelines)
19 | - [ ] Includes new or updated tests to cover the new feature
20 | - [ ] New features are properly documented (with an example if appropriate)
21 | - [ ] Includes a sample plot to visually illustrate the changes (only for plot-related functions)
22 |
23 |
34 |
--------------------------------------------------------------------------------
/.github/workflows/post-release.yml:
--------------------------------------------------------------------------------
1 | name: Post-release
2 | on:
3 | release:
4 | types: [published, released]
5 | workflow_dispatch:
6 |
7 | jobs:
8 | changelog:
9 | name: Update changelog
10 | runs-on: ubuntu-latest
11 | steps:
12 | - uses: actions/checkout@v4
13 | with:
14 | ref: main
15 | - uses: rhysd/changelog-from-release/action@v3
16 | with:
17 | file: CHANGELOG.md
18 | github_token: ${{ secrets.GITHUB_TOKEN }}
19 | commit_summary_template: 'update changelog for %s changes'
20 |
--------------------------------------------------------------------------------
/.github/workflows/publish-to-test-pypi.yml:
--------------------------------------------------------------------------------
1 | name: Publish tagged releases to PyPI and TestPyPI
2 |
3 | on:
4 | release:
5 | types:
6 | - created
7 |
8 | jobs:
9 | build-n-publish:
10 | name: Build and publish Python distributions to PyPI and TestPyPI
11 | runs-on: ubuntu-latest
12 | environment:
13 | name: pypi
14 | url: https://pypi.org/p/simuk
15 | permissions:
16 | id-token: write
17 | steps:
18 | - uses: actions/checkout@v4
19 | - name: Set up Python 3.11
20 | uses: actions/setup-python@v5
21 | with:
22 | python-version: "3.11"
23 |
24 | - name: Install pypa/build
25 | run: >-
26 | python -m
27 | pip install
28 | build
29 | --user
30 | - name: Build a binary wheel and a source tarball
31 | run: >-
32 | python -m
33 | build
34 | --sdist
35 | --wheel
36 | --outdir dist/
37 |
38 | - name: Publish package distributions to PyPI
39 | uses: pypa/gh-action-pypi-publish@release/v1
40 |
--------------------------------------------------------------------------------
/.github/workflows/test.yml:
--------------------------------------------------------------------------------
1 | name: Run tests
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 | paths-ignore:
8 | - "docs/**"
9 | pull_request:
10 | paths-ignore:
11 | - "docs/**"
12 |
13 | jobs:
14 | test:
15 | runs-on: ubuntu-latest
16 | strategy:
17 | matrix:
18 | python-version: ["3.11", "3.12", "3.13"]
19 |
20 | name: Set up Python ${{ matrix.python-version }}
21 | steps:
22 | - uses: actions/checkout@v4
23 | with:
24 | fetch-depth: 0
25 |
26 | - name: Set up Python ${{ matrix.python-version }}
27 | uses: conda-incubator/setup-miniconda@v3
28 | with:
29 | channels: conda-forge, defaults
30 | channel-priority: true
31 | python-version: ${{ matrix.python-version }}
32 | auto-update-conda: true
33 |
34 | - name: Install simuk
35 | shell: bash -l {0}
36 | run: |
37 | conda install pip
38 | pip install -r requirements-dev.txt
39 | pip install .
40 | python --version
41 | conda list
42 | pip freeze
43 | - name: Run linters
44 | shell: bash -l {0}
45 | run: |
46 | python -m black simuk --check
47 | echo "Success!"
48 | echo "Checking code style with ruff..."
49 | ruff check simuk/
50 | - name: Run tests
51 | shell: bash -l {0}
52 | run: |
53 | python -m pytest -vv --cov=simuk --cov-report=term --cov-report=xml simuk/tests
54 | env:
55 | PYTHON_VERSION: ${{ matrix.python-version }}
56 |
57 | - name: Upload coverage to Codecov
58 | uses: codecov/codecov-action@v5
59 | with:
60 | env_vars: OS,PYTHON
61 | name: codecov-umbrella
62 | fail_ci_if_error: false
63 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled files
2 | __pycache__/
3 | *.py[cod]
4 |
5 | # Distribution / packaging
6 | .Python
7 | build/
8 | develop-eggs/
9 | dist/
10 | .eggs/
11 | lib/
12 | lib64/
13 | parts/
14 | sdist/
15 | var/
16 | *.egg-info/
17 | .installed.cfg
18 | *.egg
19 |
20 | # Virtual environment
21 | venv/
22 | ENV/
23 | env/
24 | env.bak/
25 | venv.bak/
26 |
27 | # Jupyter Notebook checkpoints
28 | .ipynb_checkpoints
29 |
30 | # Pytest cache
31 | .pytest_cache/
32 |
33 | # Ruff Cache
34 | .ruff_cache/
35 |
36 | # Coverage reports
37 | htmlcov/
38 | .coverage
39 | .coverage.*
40 | .cache
41 | nosetests.xml
42 | coverage.xml
43 | *.cover
44 |
45 | # IDE specific files
46 | .idea/
47 | .vscode/
48 | *.swp
49 | *.swo
50 | *.sublime-project
51 | *.sublime-workspace
52 |
53 | # Operating system files
54 | .DS_Store
55 | Thumbs.db
56 |
57 | # Sphinx documentation
58 | docs/_build/
59 | _build
60 | jupyter_execute
61 | docs/api/generated
62 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | fail_fast: true
2 |
3 | repos:
4 | - repo: https://github.com/ambv/black
5 | rev: 22.3.0
6 | hooks:
7 | - id: black
8 | language_version: python3
9 |
10 | - repo: https://github.com/astral-sh/ruff-pre-commit
11 | rev: v0.6.8
12 | hooks:
13 | - id: ruff
14 | args: [ --fix, --exit-non-zero-on-fix ]
15 | - id: ruff-format
16 | types_or: [ python, pyi ]
17 | - repo: https://github.com/MarcoGorelli/madforhooks
18 | rev: 0.3.0
19 | hooks:
20 | - id: no-print-statements
21 |
22 | - repo: https://github.com/pre-commit/pre-commit-hooks
23 | rev: v4.5.0
24 | hooks:
25 | - id: check-added-large-files
26 | args: [--maxkb=1500]
27 | - id: end-of-file-fixer
28 | - id: trailing-whitespace
29 | - id: mixed-line-ending
30 | args: [--fix=lf]
31 |
32 |
33 |
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | # Read the Docs configuration file
2 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
3 | version: 2
4 |
5 | build:
6 | os: ubuntu-22.04
7 | tools:
8 | python: "3.11"
9 |
10 | sphinx:
11 | configuration: docs/conf.py
12 |
13 | python:
14 | install:
15 | - requirements: requirements-docs.txt
16 | - method: pip
17 | path: .
18 |
--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
1 |
2 | # [0.2.0](https://github.com/arviz-devs/simuk/releases/tag/0.2.0) - 2025-03-20
3 |
4 | ## What's Changed
5 | * Use do operator by [@aloctavodia](https://github.com/aloctavodia) in [#25](https://github.com/arviz-devs/simuk/pull/25)
6 | * remove bambi as a dependency and update tests by [@aloctavodia](https://github.com/aloctavodia) in [#26](https://github.com/arviz-devs/simuk/pull/26)
7 | * Use observe, allow multiple observed variables by [@aloctavodia](https://github.com/aloctavodia) in [#27](https://github.com/arviz-devs/simuk/pull/27)
8 | * Add Documentation by [@rohanbabbar04](https://github.com/rohanbabbar04) in [#28](https://github.com/arviz-devs/simuk/pull/28)
9 | * Add python3.13 to tests by [@rohanbabbar04](https://github.com/rohanbabbar04) in [#31](https://github.com/arviz-devs/simuk/pull/31)
10 | * Remove custom plots, use arviz instead by [@aloctavodia](https://github.com/aloctavodia) in [#32](https://github.com/arviz-devs/simuk/pull/32)
11 | * Add support for numpyro models in SBC by [@rohanbabbar04](https://github.com/rohanbabbar04) in [#30](https://github.com/arviz-devs/simuk/pull/30)
12 | * Use numpyro converter from arviz_base by [@aloctavodia](https://github.com/aloctavodia) in [#33](https://github.com/arviz-devs/simuk/pull/33)
13 |
14 |
15 | **Full Changelog**: https://github.com/arviz-devs/simuk/compare/0.1.1...0.2.0
16 |
17 | [Changes][0.2.0]
18 |
19 |
20 |
21 | # [0.1.1](https://github.com/arviz-devs/simuk/releases/tag/0.1.1) - 2025-02-13
22 |
23 | ## What's Changed
24 | * enhance pre-commit configuration with additional hooks and language Version by [@Advaitgaur004](https://github.com/Advaitgaur004) in [#22](https://github.com/arviz-devs/simuk/pull/22)
25 | * Add tests by [@rohanbabbar04](https://github.com/rohanbabbar04) in [#24](https://github.com/arviz-devs/simuk/pull/24)
26 |
27 | ## New Contributors
28 | * [@Advaitgaur004](https://github.com/Advaitgaur004) made their first contribution in [#22](https://github.com/arviz-devs/simuk/pull/22)
29 | * [@rohanbabbar04](https://github.com/rohanbabbar04) made their first contribution in [#24](https://github.com/arviz-devs/simuk/pull/24)
30 |
31 | **Full Changelog**: https://github.com/arviz-devs/simuk/compare/0.1.0...0.1.1
32 |
33 | [Changes][0.1.1]
34 |
35 |
36 |
37 | # [0.1.0](https://github.com/arviz-devs/simuk/releases/tag/0.1.0) - 2025-02-12
38 |
39 | ## What's Changed
40 | * Fixing up the readme by [@springcoil](https://github.com/springcoil) in [#1](https://github.com/arviz-devs/simuk/pull/1)
41 | * fix reference to my_model in pseudocode by [@psteinb](https://github.com/psteinb) in [#2](https://github.com/arviz-devs/simuk/pull/2)
42 | * Make it work with pymc 5.x by [@aloctavodia](https://github.com/aloctavodia) in [#3](https://github.com/arviz-devs/simuk/pull/3)
43 | * update readme by [@aloctavodia](https://github.com/aloctavodia) in [#5](https://github.com/arviz-devs/simuk/pull/5)
44 | * Add figsize argument to plots by [@aloctavodia](https://github.com/aloctavodia) in [#7](https://github.com/arviz-devs/simuk/pull/7)
45 | * fix bug with bambi models by [@aloctavodia](https://github.com/aloctavodia) in [#8](https://github.com/arviz-devs/simuk/pull/8)
46 | * simultaneous confidence band by [@aloctavodia](https://github.com/aloctavodia) in [#9](https://github.com/arviz-devs/simuk/pull/9)
47 | * Remove observed_vars argument by [@aloctavodia](https://github.com/aloctavodia) in [#10](https://github.com/arviz-devs/simuk/pull/10)
48 | * Update README.md by [@ColCarroll](https://github.com/ColCarroll) in [#18](https://github.com/arviz-devs/simuk/pull/18)
49 | * setup package conf by [@aloctavodia](https://github.com/aloctavodia) in [#20](https://github.com/arviz-devs/simuk/pull/20)
50 | * add templates and workflows by [@aloctavodia](https://github.com/aloctavodia) in [#21](https://github.com/arviz-devs/simuk/pull/21)
51 | * Update workflow by [@aloctavodia](https://github.com/aloctavodia) in [#23](https://github.com/arviz-devs/simuk/pull/23)
52 |
53 | **Full Changelog**: https://github.com/arviz-devs/simuk/commits/0.1.0
54 |
55 | [Changes][0.1.0]
56 |
57 |
58 | [0.2.0]: https://github.com/arviz-devs/simuk/compare/0.1.1...0.2.0
59 | [0.1.1]: https://github.com/arviz-devs/simuk/compare/0.1.0...0.1.1
60 | [0.1.0]: https://github.com/arviz-devs/simuk/tree/0.1.0
61 |
62 |
63 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Simuk Community Code of Conduct
2 |
3 | Simuk adopts the NumFOCUS Code of Conduct directly. In other words, we
4 | expect our community to treat others with kindness and understanding.
5 |
6 |
7 | # THE SHORT VERSION
8 | Be kind to others. Do not insult or put down others.
9 | Behave professionally. Remember that harassment and sexist, racist,
10 | or exclusionary jokes are not appropriate.
11 |
12 | All communication should be appropriate for a professional audience
13 | including people of many different backgrounds. Sexual language and
14 | imagery are not appropriate.
15 |
16 | Simuk is dedicated to providing a harassment-free community for everyone,
17 | regardless of gender, sexual orientation, gender identity, and
18 | expression, disability, physical appearance, body size, race,
19 | or religion. We do not tolerate harassment of community members
20 | in any form.
21 |
22 | Thank you for helping make this a welcoming, friendly community for all.
23 |
24 |
25 | # How to Submit a Report
26 | If you feel that there has been a Code of Conduct violation, an anonymous
27 | reporting form is available.
28 | **If you feel your safety is in jeopardy or the situation is an
29 | emergency, we urge you to contact local law enforcement before making
30 | a report. (In the U.S., dial 911.)**
31 |
32 | We are committed to promptly addressing any reported issues.
33 | If you have experienced or witnessed behavior that violates this
34 | Code of Conduct, please complete the form below to
35 | make a report.
36 |
37 | **REPORTING FORM:** https://numfocus.typeform.com/to/ynjGdT
38 |
39 | Reports are sent to the NumFOCUS Code of Conduct Enforcement Team
40 | (see below).
41 |
42 | You can view the Privacy Policy and Terms of Service for TypeForm here.
43 | The NumFOCUS Privacy Policy is here:
44 | https://www.numfocus.org/privacy-policy
45 |
46 |
47 | # Full Code of Conduct
48 | The full text of the NumFOCUS/Simuk Code of Conduct can be found on
49 | NumFOCUS's website:
50 | https://numfocus.org/code-of-conduct
51 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing to Simuk
2 | This document outlines only the most common contributions.
3 | Please see the [Contributing guide](https://arviz-devs.github.io/arviz/contributing/index.html)
4 | in our documentation for a better view of how you can contribute to Simuk.
5 | We welcome a wide range of contributions, not only code!
6 |
7 | ## Reporting issues
8 | If you encounter any bug or incorrect behaviour while using Simuk,
9 | please report an issue to our [issue tracker](https://github.com/arviz-devs/simuk/issues).
10 | Please include any supporting information, in particular the version of
11 | Simuk that you are using.
12 | The issue tracker has several templates available to help in writing the issue
13 | and including useful supporting information.
14 |
15 | ## Contributing code
16 | Thanks for your interest in contributing code to Simuk!
17 |
18 | * If this is your first time contributing to a project on GitHub, please read through our step-by-step guide to contributing to Simuk
19 | * If you have contributed to other projects on GitHub you can go straight to our [development workflow]()
20 |
21 | ### Adding new features
22 | If you are interested in adding a new feature to Simuk,
23 | first submit an issue using the "Feature request" template for the community
24 | to discuss its place and implementation within Simuk.
25 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Colin Carroll
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Simulation Based Calibration
2 |
3 | A [PyMC](http://docs.pymc.io), [Bambi](https://bambinos.github.io/bambi/), and [NumPyro](https://num.pyro.ai) implementation of the algorithms from:
4 |
5 | Sean Talts, Michael Betancourt, Daniel Simpson, Aki Vehtari, Andrew Gelman: “Validating Bayesian Inference Algorithms with Simulation-Based Calibration”, 2018; [arXiv:1804.06788](http://arxiv.org/abs/1804.06788)
6 |
7 | Many thanks to the authors for providing open, reproducible code and implementations in `rstan` and `PyStan` ([link](https://github.com/seantalts/simulation-based-calibration)).
8 |
9 |
10 | ## Installation
11 |
12 | Install from PyPI with pip:
13 |
14 | ```bash
15 | pip install simuk
16 | ```
17 |
18 | ## Quickstart
19 |
20 | 1. Define a PyMC, Bambi, or NumPyro model. For example, the centered eight schools model in PyMC:
21 |
22 | ```python
23 | import numpy as np
24 | import pymc as pm
25 |
26 | data = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
27 | sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])
28 |
29 | with pm.Model() as centered_eight:
30 | mu = pm.Normal('mu', mu=0, sigma=5)
31 | tau = pm.HalfCauchy('tau', beta=5)
32 | theta = pm.Normal('theta', mu=mu, sigma=tau, shape=8)
33 | y_obs = pm.Normal('y', mu=theta, sigma=sigma, observed=data)
34 | ```
35 | 2. Pass the model to the `SBC` class (imported with `from simuk import SBC`) and run the simulations. This will take a while, as the model is fitted once per simulation.
36 | ```python
37 | sbc = SBC(centered_eight,
38 | num_simulations=100, # ideally this should be higher, like 1000
39 | sample_kwargs={'draws': 25, 'tune': 50})
40 |
41 | sbc.run_simulations()
42 | ```
43 | ```python
44 | 79%|███████▉ | 79/100 [05:36<01:29, 4.27s/it]
45 | ```
46 |
47 | 3. Plot the empirical CDF of the rank statistics with ArviZ's `plot_ecdf_pit` (`from arviz_plots import plot_ecdf_pit`).
48 | The lines should be close to uniform and within the envelope.
49 |
50 | ```python
51 | plot_ecdf_pit(sbc.simulations)
52 | ```
53 |
54 | 
55 |
56 |
57 | ## What is going on here?
58 |
59 | The [paper on the arXiv](http://arxiv.org/abs/1804.06788) is very well written, and explains the algorithm quite well.
60 |
61 | Conceptually, for a simple model like the one below, this library does exactly the following; the same procedure generalizes to more complicated models:
62 |
63 | ```python
64 | with pm.Model() as model:
65 | x = pm.Normal('x')
66 | pm.Normal('y', mu=x, observed=y)
67 | ```
68 |
69 | Then what this library does is compute
70 |
71 | ```python
72 | with model:
73 | prior_samples = pm.sample_prior_predictive(num_trials)
74 |
75 | simulations = {'x': []}
76 | for idx in range(num_trials):
77 | y_tilde = prior_samples['y'][idx]
78 | x_tilde = prior_samples['x'][idx]
79 | with model(y=y_tilde):
80 | idata = pm.sample()
81 | simulations['x'].append((idata.posterior['x'] < x_tilde).sum())
82 | ```
83 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line.
5 | SPHINXOPTS = # "-W" treats warnings as errors
6 | SPHINXBUILD ?= sphinx-multiversion
7 | SOURCEDIR = .
8 | BUILDDIR ?= _build
9 |
10 | # Put it first so that "make" without argument is like "make help".
11 | help:
12 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
13 |
14 | .PHONY: help Makefile
15 |
16 | clean:
17 | rm -rf $(BUILDDIR)/*
18 |
19 | # For local build
20 | local:
21 | sphinx-build "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
22 |
23 | # Catch-all target: route all unknown targets to Sphinx using the new
24 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
25 | %: Makefile
26 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
27 |
--------------------------------------------------------------------------------
/docs/_templates/autosummary/class.rst:
--------------------------------------------------------------------------------
1 | {{ fullname | escape | underline }}
2 |
3 | .. currentmodule:: {{ module }}
4 |
5 | .. autoclass:: {{ objname }}
6 |
7 | .. rubric:: Methods
8 |
9 | {% for item in methods if item != "__init__" %}
10 | .. automethod:: {{ objname }}.{{ item }}
11 | {%- endfor %}
12 |
--------------------------------------------------------------------------------
/docs/_templates/autosummary/function.rst:
--------------------------------------------------------------------------------
1 | {{ fullname | escape | underline }}
2 |
3 | .. currentmodule:: {{ module }}
4 |
5 | .. autofunction:: {{ objname }}
6 |
--------------------------------------------------------------------------------
/docs/api/index.rst:
--------------------------------------------------------------------------------
1 | API Reference
2 | =============
3 |
4 | This reference provides detailed documentation for user functions in the current release of Simuk.
5 |
6 | Simulation based calibration
7 | ----------------------------
8 |
9 | .. currentmodule:: simuk
10 |
11 | .. autosummary::
12 | :toctree: generated/
13 |
14 | SBC
15 |
--------------------------------------------------------------------------------
/docs/changelog.rst:
--------------------------------------------------------------------------------
1 | Changelog
2 | =========
3 |
4 | .. include:: ../CHANGELOG.md
5 | :parser: myst_parser.sphinx_
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
1 | # ruff: noqa
2 | # -*- coding: utf-8 -*-
3 | import os
4 | import sys
5 |
6 | from simuk import __version__
7 |
8 | sys.path.insert(0, os.path.abspath("../simuk"))
9 |
10 | # -- Project information -----------------------------------------------------
11 |
12 | project = "Simuk"
13 | author = "ArviZ contributors"
14 | copyright = f"2025, {author}"
15 |
16 | # The short X.Y version
17 | version = __version__
18 | # The full version, including alpha/beta/rc tags
19 | release = __version__
20 |
21 |
22 | # -- General configuration ---------------------------------------------------
23 | extensions = [
24 | "sphinx.ext.autodoc",
25 | "sphinx.ext.autosummary",
26 | "sphinx.ext.viewcode",
27 | "sphinx.ext.napoleon",
28 | "sphinx.ext.mathjax",
29 | "sphinx_copybutton",
30 | "myst_nb",
31 | "matplotlib.sphinxext.plot_directive",
32 | "sphinx_tabs.tabs",
33 | "sphinx_design",
34 | "numpydoc",
35 | "jupyter_sphinx",
36 | ]
37 |
38 | # -- Extension configuration -------------------------------------------------
39 | nb_execution_mode = "auto"
40 | nb_execution_excludepatterns = ["*.ipynb"]
41 | nb_kernel_rgx_aliases = {".*": "python3"}
42 | myst_enable_extensions = ["colon_fence", "deflist", "dollarmath"]
43 | autosummary_generate = True
44 | autodoc_member_order = "bysource"
45 | numpydoc_show_class_members = False
46 | numpydoc_show_inherited_class_members = False
47 | numpydoc_class_members_toctree = False
48 |
49 | source_suffix = ".rst"
50 |
51 | master_doc = "index"
52 | language = "en"
53 | templates_path = ["_templates"]
54 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", "**.ipynb_checkpoints"]
55 | pygments_style = "sphinx"
56 | html_css_files = ["custom.css"]
57 | html_title = "Simuk"
58 | html_short_title = "Simuk"
59 | html_theme = "pydata_sphinx_theme"
60 |
61 | html_theme_options = {
62 | "collapse_navigation": True,
63 | "show_toc_level": 2,
64 | "navigation_depth": 4,
65 | "search_bar_text": "Search the docs...",
66 | "icon_links": [
67 | {
68 | "name": "GitHub",
69 | "url": "https://github.com/arviz-devs/simuk",
70 | "icon": "fa-brands fa-github",
71 | },
72 | ],
73 | # "logo": {
74 | # "image_light": "Simuk_flat.png",
75 | # "image_dark": "Simuk_flat_white.png",
76 | # },
77 | }
78 | html_context = {
79 | "github_user": "arviz-devs",
80 | "github_repo": "simuk",
81 | "github_version": "main",
82 | "doc_path": "docs/",
83 | "default_mode": "light",
84 | }
85 |
86 |
87 | # -- Options for HTMLHelp output ---------------------------------------------
88 |
89 | htmlhelp_basename = "simukdoc"
90 |
91 |
92 | # -- Options for LaTeX output ------------------------------------------------
93 |
94 | latex_documents = [
95 | (master_doc, "simuk.tex", "simuk Documentation", "The developers of simuk", "manual"),
96 | ]
97 |
98 |
99 | # -- Options for manual page output ------------------------------------------
100 |
101 | man_pages = [(master_doc, "simuk", "simuk Documentation", [author], 1)]
102 |
103 |
104 | # -- Options for Texinfo output ----------------------------------------------
105 |
106 | texinfo_documents = [
107 | (
108 | master_doc,
109 | "simuk",
110 | "simuk Documentation",
111 | author,
112 | "simuk",
113 | "One line description of project.",
114 | "Miscellaneous",
115 | ),
116 | ]
117 |
118 |
119 | # -- Options for Epub output -------------------------------------------------
120 |
121 | # Bibliographic Dublin Core info.
122 | epub_title = project
123 |
124 | epub_exclude_files = ["search.html"]
125 |
--------------------------------------------------------------------------------
/docs/contributing.rst:
--------------------------------------------------------------------------------
1 | Contributing
2 | ============
3 |
4 | We welcome contributions from interested individuals or groups! For information about contributing to Simuk, check out our instructions, policies, and guidelines `here <https://github.com/arviz-devs/simuk/blob/main/CONTRIBUTING.md>`_.
5 |
6 | See the `GitHub contributor page <https://github.com/arviz-devs/simuk/graphs/contributors>`_.
7 |
--------------------------------------------------------------------------------
/docs/examples.rst:
--------------------------------------------------------------------------------
1 | Examples
2 | ========
3 |
4 | The gallery below presents examples that demonstrate the use of Simuk.
5 |
6 | .. grid:: 1 2 3 3
7 | :gutter: 2 2 3 3
8 |
9 | .. grid-item-card::
10 | :link: ./examples/gallery/sbc.html
11 | :text-align: center
12 | :shadow: none
13 | :class-card: example-gallery
14 |
15 | .. image:: examples/img/sbc.png
16 | :alt: SBC
17 |
18 | +++
19 | SBC
20 |
--------------------------------------------------------------------------------
/docs/examples/gallery/sbc.md:
--------------------------------------------------------------------------------
1 | ---
2 | jupytext:
3 | text_representation:
4 | extension: .md
5 | format_name: myst
6 | kernelspec:
7 | display_name: Python 3
8 | language: python
9 | name: python3
10 | ---
11 |
12 | # Simulation based calibration
13 |
14 | This example demonstrates how to use the `SBC` class for simulation-based calibration, supporting PyMC, Bambi, and NumPyro models.
15 |
16 | ```{jupyter-execute}
17 |
18 | from arviz_plots import plot_ecdf_pit, style
19 | import numpy as np
20 | import simuk
21 | style.use("arviz-variat")
22 | ```
23 |
24 | ::::::{tab-set}
25 | :class: full-width
26 |
27 | :::::{tab-item} PyMC
28 | :sync: pymc
29 |
30 | First, define a PyMC model. In this example, we will use the centered eight schools model.
31 |
32 | ```{jupyter-execute}
33 |
34 | import pymc as pm
35 |
36 | data = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
37 | sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])
38 |
39 | with pm.Model() as centered_eight:
40 | mu = pm.Normal('mu', mu=0, sigma=5)
41 | tau = pm.HalfCauchy('tau', beta=5)
42 | theta = pm.Normal('theta', mu=mu, sigma=tau, shape=8)
43 | y_obs = pm.Normal('y', mu=theta, sigma=sigma, observed=data)
44 | ```
45 |
46 | Pass the model to the SBC class, set the number of simulations to 100, and run the simulations. This process may take
47 | some time since the model runs multiple times (100 in this example).
48 |
49 | ```{jupyter-execute}
50 |
51 | sbc = simuk.SBC(centered_eight,
52 | num_simulations=100,
53 | sample_kwargs={'draws': 25, 'tune': 50})
54 |
55 | sbc.run_simulations();
56 | ```
57 |
58 | To compare the prior and posterior distributions, we will plot the results from the simulations,
59 | using the ArviZ function `plot_ecdf_pit`.
60 | We expect a uniform distribution; the gray envelope corresponds to the 94% credible interval.
61 |
62 | ```{jupyter-execute}
63 |
64 | plot_ecdf_pit(sbc.simulations,
65 | pc_kwargs={'col_wrap':4},
66 | plot_kwargs={"xlabel":False},
67 | )
68 | ```
69 |
70 | :::::
71 |
72 | :::::{tab-item} Bambi
73 | :sync: bambi
74 |
Next, we define a Bambi model.
76 |
77 | ```{jupyter-execute}
78 |
79 | import bambi as bmb
80 | import pandas as pd
81 |
82 | x = np.random.normal(0, 1, 200)
83 | y = 2 + np.random.normal(x, 1)
84 | df = pd.DataFrame({"x": x, "y": y})
85 | bmb_model = bmb.Model("y ~ x", df)
86 | ```
87 |
Pass the model to the `SBC` class, set the number of simulations to 100, and run the simulations.
This process may take some time, as the model is fitted once per simulation.
90 |
91 | ```{jupyter-execute}
92 |
93 | sbc = simuk.SBC(bmb_model,
94 | num_simulations=100,
95 | sample_kwargs={'draws': 25, 'tune': 50})
96 |
97 | sbc.run_simulations();
98 | ```
99 |
To compare the prior and posterior distributions, we plot the results of the simulations.
If the inference is well calibrated, the ranks are uniformly distributed; the gray envelope shows the 94% simultaneous confidence band for a uniform distribution.
102 |
103 | ```{jupyter-execute}
104 | plot_ecdf_pit(sbc.simulations)
105 | ```
106 |
107 | :::::
108 |
109 | :::::{tab-item} Numpyro
110 | :sync: numpyro
111 |
112 | We define a Numpyro Model, we use the centered eight schools model.
113 |
114 | ```{jupyter-execute}
115 | import numpyro
116 | import numpyro.distributions as dist
117 | from jax import random
118 | from numpyro.infer import NUTS
119 |
120 | y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
121 | sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])
122 |
def eight_schools_cauchy_prior(J, sigma, y=None):
    """Centered eight schools model with a half-Cauchy prior on ``tau``."""
    mu = numpyro.sample("mu", dist.Normal(loc=0.0, scale=5.0))
    tau = numpyro.sample("tau", dist.HalfCauchy(scale=5.0))
    with numpyro.plate("J", J):
        theta = numpyro.sample("theta", dist.Normal(loc=mu, scale=tau))
        # ``obs=y`` conditions the site on ``y`` when it is provided.
        numpyro.sample("y", dist.Normal(loc=theta, scale=sigma), obs=y)
129 |
130 | # We use the NUTS sampler
131 | nuts_kernel = NUTS(eight_schools_cauchy_prior)
132 | ```
133 |
Pass the NUTS kernel to the `SBC` class, set the number of simulations to 100, and run the simulations. For NumPyro
models, the keyword arguments of the model function are passed through the `data_dir` parameter.
136 |
137 | ```{jupyter-execute}
138 | sbc = simuk.SBC(nuts_kernel,
139 | sample_kwargs={"num_warmup": 50, "num_samples": 75},
140 | num_simulations=100,
141 | data_dir={"J": 8, "sigma": sigma, "y": y},
142 | )
143 | sbc.run_simulations()
144 | ```
145 |
To compare the prior and posterior distributions, we plot the results.
If the inference is well calibrated, the ranks are uniformly distributed; the gray envelope shows the 94% simultaneous confidence band for a uniform distribution.
148 |
149 | ```{jupyter-execute}
150 | plot_ecdf_pit(sbc.simulations,
151 | pc_kwargs={'col_wrap':4},
152 | plot_kwargs={"xlabel":False}
153 | )
154 | ```
155 |
--------------------------------------------------------------------------------
/docs/examples/img/sbc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arviz-devs/simuk/26d769f578a54546f23d8a48df24567e30fdd19a/docs/examples/img/sbc.png
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | Overview
2 | ========
3 |
4 | Simuk is a Python library for simulation-based calibration (SBC) and the generation of synthetic data.
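Simulation-Based Calibration (SBC) is a method for validating Bayesian inference. It repeatedly simulates parameters and
data from the prior predictive distribution, fits the model to each simulated dataset, and checks that the rank of each
true parameter value among its posterior draws is uniformly distributed, as it must be when inference is correct.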
7 |
8 | Quickstart
9 | ----------
10 |
11 | This quickstart guide provides a simple example to help you get started. If you're looking for more examples
12 | and use cases, be sure to check out the :doc:`examples` section.
13 |
To use SBC, you only need a model. Simuk draws parameters and simulated data from the model's prior predictive
distribution, fits the model to each simulated dataset, and compares the posterior samples with the parameter values
that generated the data.
16 |
17 | In our case, we will take a PyMC model and pass it into our ``SBC`` class.
18 |
19 | .. code-block:: python
20 |
    import numpy as np
    import pymc as pm
    import simuk
23 |
24 | data = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
25 | sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])
26 |
27 | with pm.Model() as centered_eight:
28 | mu = pm.Normal('mu', mu=0, sigma=5)
29 | tau = pm.HalfCauchy('tau', beta=5)
30 | theta = pm.Normal('theta', mu=mu, sigma=tau, shape=8)
31 | y_obs = pm.Normal('y', mu=theta, sigma=sigma, observed=data)
32 |
33 | # Pass it into the SBC class
34 | sbc = simuk.SBC(centered_eight, num_simulations=100, sample_kwargs={'draws': 25, 'tune': 50})
35 |
36 | Now, we use the ``run_simulations`` method to generate and analyze simulated data, running the model multiple times to
37 | compare prior and posterior distributions.
38 |
39 | .. code-block:: python
40 |
41 | sbc.run_simulations()
42 |
Plot the empirical CDF of the rank statistics with ArviZ's ``plot_ecdf_pit`` to compare the prior and posterior.

.. code-block:: python

    from arviz_plots import plot_ecdf_pit

    plot_ecdf_pit(sbc.simulations)
48 |
The ECDF lines should stay within the oval envelope. If they do, the prior and posterior distributions are properly
aligned, and there is no evidence of significant biases or other issues with the model or the inference.
51 |
52 | .. toctree::
53 | :maxdepth: 1
54 | :hidden:
55 | :caption: Getting Started
56 |
57 | Overview
58 | installation
59 |
60 | .. toctree::
61 | :maxdepth: 2
62 | :hidden:
63 | :caption: API documentation
64 |
65 | api/index.rst
66 |
67 | .. toctree::
68 | :maxdepth: 2
69 | :hidden:
70 | :caption: Examples
71 |
72 | examples
73 |
74 | .. toctree::
75 | :maxdepth: 1
76 | :hidden:
77 | :caption: References
78 |
79 | contributing
80 | changelog
81 |
82 | References
83 | ----------
84 |
- Talts, Sean, Michael Betancourt, Daniel Simpson, Aki Vehtari, and Andrew Gelman. 2018. “Validating Bayesian Inference Algorithms with Simulation-Based Calibration.” `arXiv:1804.06788 <https://arxiv.org/abs/1804.06788>`_.
- Modrák, M., Moon, A. H., Kim, S., Bürkner, P., Huurre, N., Faltejsková, K., … & Vehtari, A. (2023). Simulation-based calibration checking for Bayesian computation: The choice of test quantities shapes sensitivity. Bayesian Analysis, advance publication, DOI: `10.1214/23-BA1404 <https://doi.org/10.1214/23-BA1404>`_
87 |
--------------------------------------------------------------------------------
/docs/installation.rst:
--------------------------------------------------------------------------------
1 | Installation
2 | ============
3 |
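Simuk is tested and compatible with **Python 3.11 and later**. Its required dependencies are ArviZ-base and tqdm; to run
simulations you also need the library your model is written in (PyMC, Bambi, or NumPyro). For specific version details,
check the `pyproject.toml <https://github.com/arviz-devs/simuk/blob/main/pyproject.toml>`_ file.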
6 |
The latest release of simuk can be installed using ``pip``:
8 |
9 | Using pip
10 | ---------
11 |
12 | .. code-block:: bash
13 |
14 | pip install simuk
15 |
16 |
17 | Development Version
18 | -------------------
19 |
20 | The latest development version can be installed from the main branch using ``pip``:
21 |
22 | .. code-block:: bash
23 |
24 | pip install git+https://github.com/arviz-devs/simuk.git
25 |
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
@ECHO OFF

REM Run from this script's directory so the relative paths below resolve; popd restores it.
pushd %~dp0

REM Command file for Sphinx documentation

REM Use sphinx-build from PATH unless the caller already set SPHINXBUILD.
if "%SPHINXBUILD%" == "" (
    set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=.
set BUILDDIR=_build

REM With no target argument, show Sphinx's list of available targets.
if "%1" == "" goto help

REM Probe for sphinx-build; errorlevel 9009 is what cmd reports when the command is not found.
%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
    echo.
    echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
    echo.installed, then set the SPHINXBUILD environment variable to point
    echo.to the full path of the 'sphinx-build' executable. Alternatively you
    echo.may add the Sphinx directory to PATH.
    echo.
    echo.If you don't have Sphinx installed, grab it from
    echo.http://sphinx-doc.org/
    exit /b 1
)

REM Build the requested target (first argument, e.g. "html") using Sphinx's make mode.
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
goto end

:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%

:end
popd
36 |
--------------------------------------------------------------------------------
/ecdf.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arviz-devs/simuk/26d769f578a54546f23d8a48df24567e30fdd19a/ecdf.png
--------------------------------------------------------------------------------
/environment-dev.yml:
--------------------------------------------------------------------------------
1 | name: simuk-dev
2 | channels:
3 | - defaults
4 | - conda-forge
5 | dependencies:
6 | - python >= 3.11
  - pip
9 | - pymc>=5.20.1
10 | - bambi>=0.13.0
11 | - arviz>=0.20.0
12 | - black=22.3.0
13 | - click=8.0.4
14 | - pytest-cov>=2.6.1
15 | - pytest>=4.4.0
16 | - pre-commit>=2.19
17 | - ruff==0.9.1
18 |
19 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: simuk
2 | channels:
3 | - defaults
4 | - conda-forge
5 | dependencies:
6 | - python >= 3.11
7 | - pymc>=5.20.1
8 | - arviz>=0.20.0
9 |
--------------------------------------------------------------------------------
/hist.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arviz-devs/simuk/26d769f578a54546f23d8a48df24567e30fdd19a/hist.png
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["flit_core >=3.4,<4"]
3 | build-backend = "flit_core.buildapi"
4 |
5 | [project]
6 | name = "simuk"
7 | readme = "README.md"
8 | requires-python = ">=3.11"
9 | license = {file = "LICENSE"}
10 | authors = [
11 | {name = "ArviZ team", email = "arviz.devs@gmail.com"}
12 | ]
13 | classifiers = [
14 | "Development Status :: 3 - Alpha",
15 | "Intended Audience :: Science/Research",
16 | "Intended Audience :: Education",
17 | "License :: OSI Approved :: Apache Software License",
18 | "Operating System :: OS Independent",
19 | "Programming Language :: Python",
20 | "Programming Language :: Python :: 3",
21 | "Programming Language :: Python :: 3.11",
22 | "Programming Language :: Python :: 3.12",
23 | ]
24 | dynamic = ["version"]
25 | description = "Simulation based calibration and generation of synthetic data."
26 | dependencies = [
27 | "arviz_base>=0.5.0",
28 | "tqdm"
29 | ]
30 |
31 | [tool.flit.module]
32 | name = "simuk"
33 |
34 | [project.urls]
35 | source = "https://github.com/arviz-devs/simuk"
36 | tracker = "https://github.com/arviz-devs/simuk/issues"
37 | funding = "https://opencollective.com/arviz"
38 |
39 |
40 | [tool.black]
41 | line-length = 100
42 |
43 | [tool.isort]
44 | profile = "black"
45 | include_trailing_comma = true
46 | use_parentheses = true
47 | multi_line_output = 3
48 | line_length = 100
49 |
50 | [tool.pydocstyle]
51 | convention = "numpy"
52 |
53 | [tool.pytest.ini_options]
54 | testpaths = [
    "simuk/tests",
56 | ]
57 |
58 | [tool.ruff]
59 | line-length = 100
60 |
61 | [tool.ruff.lint]
62 | select = [
63 | "F", # Pyflakes
64 | "E", # Pycodestyle
65 | "W", # Pycodestyle
66 | "D", # pydocstyle
67 | "NPY", # numpy specific rules
68 | "UP", # pyupgrade
69 | "I", # isort
70 | "PL", # Pylint
71 | "TID", # Absolute imports
72 | ]
73 | ignore = [
74 | "PLR0912", # too many branches
75 | "PLR0913", # too many arguments
76 | "PLR2004", # magic value comparison
77 | "PLR0915", # too many statements
78 | "NPY002", # Replace legacy `np.random.randn` call with `np.random.Generator`
79 | "D1" # Missing docstring
80 | ]
81 |
82 | [tool.ruff.lint.per-file-ignores]
83 | "docs/source/**/*.ipynb" = ["D", "E", "F", "I", "NPY", "PL", "TID", "UP", "W"]
84 | "simuk/__init__.py" = ["I", "F401", "E402", "F403"]
85 | "simuk/tests/**/*" = ["D", "PLR2004", "TID252"]
86 | "simuk/tests/**/*.ipynb" = ["E", "F"]
87 |
88 | [tool.ruff.lint.pydocstyle]
89 | convention = "numpy"
90 |
91 | [tool.ruff.lint.flake8-tidy-imports]
92 | ban-relative-imports = "all" # Disallow all relative imports.
93 |
94 | [tool.ruff.format]
95 | docstring-code-format = false
96 |
--------------------------------------------------------------------------------
/requirements-dev.txt:
--------------------------------------------------------------------------------
1 | black==22.3.0
2 | click==8.0.4
3 | pytest-cov>=2.6.1
4 | pytest>=4.4.0
5 | pre-commit>=2.19
6 | ipytest==0.13.0
7 | pymc>=5.20.1
8 | bambi>=0.13.0
9 | arviz_base>=0.5.0
10 | ruff==0.9.1
11 | numpyro>=0.17.0
12 |
--------------------------------------------------------------------------------
/requirements-docs.txt:
--------------------------------------------------------------------------------
1 | pydata-sphinx-theme>=0.6.3
2 | myst-nb
3 | pymc>=5.20.1
4 | bambi>=0.15.0
5 | arviz_plots @ git+https://github.com/arviz-devs/arviz-plots@main
6 | sphinx>=4
7 | sphinx-copybutton
8 | sphinx_tabs
9 | sphinx-design
10 | numpydoc
11 | jupyter-sphinx
12 | numpyro>=0.17.0
13 |
--------------------------------------------------------------------------------
/simuk/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Simuk.
3 |
4 | Simulation based calibration and other tools for evaluation using synthetic data.
5 | """
6 |
7 | from simuk.sbc import SBC
8 |
9 |
# Package version. pyproject.toml declares ``version`` as dynamic, so the flit build backend reads it from here.
__version__ = "0.2.0"
11 |
--------------------------------------------------------------------------------
/simuk/sbc.py:
--------------------------------------------------------------------------------
"""Simulation based calibration (Talts et al. 2018) for PyMC, Bambi and NumPyro models."""
2 |
import logging
from copy import copy
from functools import wraps
from importlib.metadata import PackageNotFoundError, version
6 |
7 | try:
8 | import pymc as pm
9 | except ImportError:
10 | pass
11 | try:
12 | import jax
13 | from numpyro.handlers import seed, trace
14 | from numpyro.infer import MCMC, Predictive
15 | from numpyro.infer.mcmc import MCMCKernel
16 | except ImportError:
17 | pass
18 |
19 | import numpy as np
20 | from arviz_base import extract, from_dict, from_numpyro
21 | from tqdm import tqdm
22 |
23 |
class quiet_logging:
    """Decorator that silences the given loggers while the decorated method runs.

    Each logger is set to ``logging.CRITICAL`` for the duration of the call and its previous
    level is restored afterwards, even when the call raises (for example a
    ``KeyboardInterrupt`` used to pause ``SBC.run_simulations``).

    Parameters
    ----------
    *libraries : str
        Names of the loggers to silence, e.g. ``"pymc"`` or ``"bambi"``.
    """

    def __init__(self, *libraries):
        self.loggers = [logging.getLogger(library) for library in libraries]

    def __call__(self, func):
        # ``wraps`` keeps the method's name and docstring visible to Sphinx/autodoc and help().
        @wraps(func)
        def wrapped(cls, *args, **kwargs):
            levels = [logger.level for logger in self.loggers]
            for logger in self.loggers:
                logger.setLevel(logging.CRITICAL)
            try:
                return func(cls, *args, **kwargs)
            finally:
                # Restore on every exit path so an interrupted run does not leave logging muted.
                for logger, level in zip(self.loggers, levels):
                    logger.setLevel(level)

        return wrapped
42 |
43 |
class SBC:
    """Set up class for doing SBC.

    Parameters
    ----------
    model : pymc.Model, bambi.Model or numpyro.infer.mcmc.MCMCKernel
        A PyMC, Bambi model or Numpyro MCMC kernel. If a PyMC model the data needs to be defined as
        mutable data.
    num_simulations : int
        How many simulations to run. Must be at least 1.
    sample_kwargs : dict[str] -> Any
        Arguments passed to ``pymc.sample`` (PyMC and Bambi models) or to
        ``numpyro.infer.MCMC`` (NumPyro kernels). The dict is copied, so the defaults filled
        in here never modify the caller's dict.
    seed : int (optional)
        Random seed. This persists even if running the simulations is
        paused for whatever reason.
    data_dir : dict (optional)
        Keyword arguments passed to numpyro model, intended for use when providing
        an MCMC Kernel model. Defaults to no arguments.

    Attributes
    ----------
    simulations : dict or DataTree
        Before any simulation has finished, a dict mapping each free variable name to an empty
        list. Once :meth:`run_simulations` returns (or is interrupted after at least one finished
        simulation), a DataTree whose ``prior_sbc`` group holds, for every simulation, the number
        of posterior draws that are smaller than the prior draw that generated the data.

    Raises
    ------
    ValueError
        If ``model`` is not a supported model type or ``num_simulations`` is less than 1.

    Examples
    --------

    .. code-block :: python

        from arviz_plots import plot_ecdf_pit

        with pm.Model() as model:
            x = pm.Normal('x')
            y = pm.Normal('y', mu=2 * x, observed=obs)

        sbc = SBC(model)
        sbc.run_simulations()
        plot_ecdf_pit(sbc.simulations)

    """

    def __init__(self, model, num_simulations=1000, sample_kwargs=None, seed=None, data_dir=None):
        if self._is_pymc_model(model):
            self.engine = "pymc"
            self.model = model
        elif hasattr(model, "formula"):
            self.engine = "bambi"
            model.build()
            self.bambi_model = model
            self.model = model.backend.model
            self.formula = model.formula
            self.new_data = copy(model.data)
        elif self._is_numpyro_kernel(model):
            self.engine = "numpyro"
            self.numpyro_model = model
            self.model = self.numpyro_model.model
            # NumPyro needs its own driver loop: shadow the PyMC/Bambi method on this instance.
            self.run_simulations = self._run_simulations_numpyro
        else:
            raise ValueError(
                "model should be one of pymc.Model, bambi.Model, or numpyro.infer.mcmc.MCMCKernel"
            )
        if num_simulations < 1:
            raise ValueError("num_simulations must be at least 1")
        self.num_simulations = num_simulations
        # Only the NumPyro engine reads it; an empty dict lets ``model(**data_dir)`` work.
        self.data_dir = {} if data_dir is None else data_dir
        # Copy so the defaults below never leak into the dict the caller passed in.
        sample_kwargs = {} if sample_kwargs is None else dict(sample_kwargs)
        if self.engine == "numpyro":
            sample_kwargs.setdefault("num_warmup", 1000)
            sample_kwargs.setdefault("num_samples", 1000)
            sample_kwargs.setdefault("progress_bar", False)
        else:
            sample_kwargs.setdefault("progressbar", False)
            sample_kwargs.setdefault("compute_convergence_checks", False)
        self.sample_kwargs = sample_kwargs
        self.seed = seed
        self._seeds = self._get_seeds()
        self._extract_variable_names()
        # Public view of the results; replaced by a DataTree once simulations finish.
        self.simulations = {name: [] for name in self.var_names}
        # Rank statistics accumulated per variable, kept separately so runs can be resumed.
        self._ranks = {name: [] for name in self.var_names}
        self._simulations_complete = 0

    @staticmethod
    def _is_pymc_model(model):
        """Return whether ``model`` is a ``pymc.Model`` (always False when PyMC is missing)."""
        if not hasattr(model, "basic_RVs"):
            return False
        try:
            import pymc
        except ImportError:
            return False
        return isinstance(model, pymc.Model)

    @staticmethod
    def _is_numpyro_kernel(model):
        """Return whether ``model`` is a NumPyro MCMC kernel (always False when NumPyro is missing)."""
        try:
            from numpyro.infer.mcmc import MCMCKernel
        except ImportError:
            return False
        return isinstance(model, MCMCKernel)

    @staticmethod
    def _package_version(name):
        """Return the installed version of ``name``, or ``"unknown"`` if it is not installed."""
        try:
            return version(name)
        except PackageNotFoundError:
            return "unknown"

    def _extract_variable_names(self):
        """Extract observed and free variables from the model."""
        if self.engine == "numpyro":
            # Run the model once under a seeded trace to discover its sample sites.
            with trace() as tr:
                with seed(rng_seed=int(self._seeds[0])):
                    self.numpyro_model.model(**self.data_dir)
            self.var_names = [
                name
                for name, site in tr.items()
                if site["type"] == "sample" and not site.get("is_observed", False)
            ]
            self.observed_vars = [
                name
                for name, site in tr.items()
                if site["type"] == "sample" and site.get("is_observed", False)
            ]
        else:
            self.observed_vars = [obs.name for obs in self.model.observed_RVs]
            self.var_names = [v.name for v in self.model.free_RVs]

    def _get_seeds(self):
        """Set the random seed, and generate seeds for all the simulations."""
        rng = np.random.default_rng(self.seed)
        return rng.integers(0, 2**30, size=self.num_simulations)

    def _get_prior_predictive_samples(self):
        """Generate samples to use for the simulations."""
        with self.model:
            idata = pm.sample_prior_predictive(
                draws=self.num_simulations, random_seed=self._seeds[0]
            )
        prior_pred = extract(idata, group="prior_predictive", keep_dataset=True)
        prior = extract(idata, group="prior", keep_dataset=True)
        return prior, prior_pred

    def _get_prior_predictive_samples_numpyro(self):
        """Generate samples to use for the simulations using numpyro."""
        predictive = Predictive(self.model, num_samples=self.num_simulations)
        free_vars_data = {k: v for k, v in self.data_dir.items() if k not in self.observed_vars}
        samples = predictive(jax.random.PRNGKey(self._seeds[0]), **free_vars_data)
        prior = {k: v for k, v in samples.items() if k not in self.observed_vars}
        prior_pred = {k: v for k, v in samples.items() if k in self.observed_vars}
        return prior, prior_pred

    def _get_posterior_samples(self, prior_predictive_draw):
        """Generate posterior samples conditioned to a prior predictive sample."""
        new_model = pm.observe(self.model, prior_predictive_draw)
        with new_model:
            check = pm.sample(
                **self.sample_kwargs, random_seed=self._seeds[self._simulations_complete]
            )

        posterior = extract(check, group="posterior", keep_dataset=True)
        return posterior

    def _get_posterior_samples_numpyro(self, prior_predictive_draw):
        """Generate posterior samples using numpyro conditioned to a prior predictive sample."""
        mcmc = MCMC(self.numpyro_model, **self.sample_kwargs)
        rng_seed = jax.random.PRNGKey(self._seeds[self._simulations_complete])
        free_vars_data = {k: v for k, v in self.data_dir.items() if k not in self.observed_vars}
        mcmc.run(rng_seed, **free_vars_data, **prior_predictive_draw)
        return from_numpyro(mcmc)["posterior"]

    def _convert_to_datatree(self):
        """Wrap the stacked rank arrays in ``self.simulations`` into a DataTree."""
        self.simulations = from_dict(
            {"prior_sbc": self.simulations},
            attrs={
                "/": {
                    "inference_library": self.engine,
                    "inference_library_version": self._package_version(self.engine),
                    "modeling_interface": "simuk",
                    "modeling_interface_version": self._package_version("simuk"),
                }
            },
        )

    def _record_ranks(self, ranks):
        """Store the rank statistics of one finished simulation and advance the counter."""
        for name in self.var_names:
            self._ranks[name].append(ranks[name])
        self._simulations_complete += 1

    def _finalize_simulations(self):
        """Rebuild ``self.simulations`` as a DataTree from the ranks recorded so far."""
        n_done = self._simulations_complete
        for ranks in self._ranks.values():
            # Drop a simulation that was only partially recorded (interrupt between appends),
            # so a later resume stays aligned with ``_simulations_complete``.
            del ranks[n_done:]
        if n_done == 0:
            # Nothing finished yet; np.stack cannot handle empty lists.
            return
        self.simulations = {name: np.stack(ranks)[None, :] for name, ranks in self._ranks.items()}
        self._convert_to_datatree()

    @quiet_logging("pymc", "pytensor.gof.compilelock", "bambi")
    def run_simulations(self):
        """Run all the simulations.

        This function can be stopped and restarted on the same instance, so you can
        keyboard interrupt part way through, look at the plot, and then resume. If a
        seed was passed initially, it will still be respected (that is, the resulting
        simulations will be identical to running without pausing in the middle).

        When it returns, or is interrupted after at least one simulation has finished,
        ``self.simulations`` is a DataTree holding the rank statistics in its ``prior_sbc`` group.
        """
        prior, prior_pred = self._get_prior_predictive_samples()

        progress = tqdm(
            initial=self._simulations_complete,
            total=self.num_simulations,
        )
        try:
            while self._simulations_complete < self.num_simulations:
                idx = self._simulations_complete
                prior_predictive_draw = {
                    var_name: prior_pred[var_name].sel(chain=0, draw=idx).values
                    for var_name in self.observed_vars
                }

                posterior = self._get_posterior_samples(prior_predictive_draw)
                ranks = {
                    name: (posterior[name] < prior[name].sel(chain=0, draw=idx)).sum("sample").values
                    for name in self.var_names
                }
                self._record_ranks(ranks)
                progress.update()
        finally:
            progress.close()
            self._finalize_simulations()

    @quiet_logging("numpyro")
    def _run_simulations_numpyro(self):
        """Run all the simulations for Numpyro Model.

        Resumable in the same way as :meth:`run_simulations`. Ranks are computed from the
        first chain (``chain=0``) of each posterior.
        """
        prior, prior_pred = self._get_prior_predictive_samples_numpyro()
        progress = tqdm(
            initial=self._simulations_complete,
            total=self.num_simulations,
        )
        try:
            while self._simulations_complete < self.num_simulations:
                idx = self._simulations_complete
                prior_predictive_draw = {k: v[idx] for k, v in prior_pred.items()}
                posterior = self._get_posterior_samples_numpyro(prior_predictive_draw)
                ranks = {
                    name: (posterior[name].sel(chain=0) < prior[name][idx]).sum(axis=0).values
                    for name in self.var_names
                }
                self._record_ranks(ranks)
                progress.update()
        finally:
            progress.close()
            self._finalize_simulations()
--------------------------------------------------------------------------------
/simuk/tests/test_sbc.py:
--------------------------------------------------------------------------------
1 | import bambi as bmb
2 | import numpy as np
3 | import numpyro
4 | import numpyro.distributions as dist
5 | import pandas as pd
6 | import pymc as pm
7 | import pytest
8 | from numpyro.infer import NUTS
9 |
10 | import simuk
11 |
np.random.seed(1234)

# Eight schools data; ``sigma`` is used as the fixed likelihood scale below.
data = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])

# Centered eight schools model, exercised through the PyMC engine.
with pm.Model() as centered_eight:
    mu = pm.Normal("mu", mu=0, sigma=5)
    tau = pm.HalfCauchy("tau", beta=5)
    theta = pm.Normal("theta", mu=mu, sigma=tau, shape=8)
    y_obs = pm.Normal("y", mu=theta, sigma=sigma, observed=data)

# Linear regression on synthetic data, exercised through the Bambi engine.
# Both draws consume the global RNG seeded above, so their order must not change.
x = np.random.normal(0, 1, 20)
y = 2 + np.random.normal(x, 1)
df = pd.DataFrame(dict(x=x, y=y))
bmb_model = bmb.Model("y ~ x", df)
27 |
28 |
@pytest.mark.parametrize("model", [centered_eight, bmb_model])
def test_sbc(model):
    """SBC runs end to end on PyMC and Bambi models and produces a ``prior_sbc`` group."""
    sampler_settings = {"draws": 5, "tune": 5}
    sbc = simuk.SBC(model, num_simulations=10, sample_kwargs=sampler_settings)
    sbc.run_simulations()
    assert "prior_sbc" in sbc.simulations
38 |
39 |
def test_sbc_numpyro():
    """SBC runs end to end on a NumPyro NUTS kernel and produces a ``prior_sbc`` group."""
    y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
    sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])

    def eight_schools_cauchy_prior(J, sigma, y=None):
        mu = numpyro.sample("mu", dist.Normal(0, 5))
        tau = numpyro.sample("tau", dist.HalfCauchy(5))
        with numpyro.plate("J", J):
            theta = numpyro.sample("theta", dist.Normal(mu, tau))
            numpyro.sample("y", dist.Normal(theta, sigma), obs=y)

    kernel = NUTS(eight_schools_cauchy_prior)
    model_kwargs = {"J": 8, "sigma": sigma, "y": y}
    sbc = simuk.SBC(
        kernel,
        num_simulations=10,
        sample_kwargs={"num_warmup": 50, "num_samples": 25},
        data_dir=model_kwargs,
    )
    sbc.run_simulations()
    assert "prior_sbc" in sbc.simulations
59 |
--------------------------------------------------------------------------------