├── .gitattributes ├── .github ├── dependabot.yml └── workflows │ └── ci.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CHANGELOG.md ├── CITATION.cff ├── Cargo.lock ├── Cargo.toml ├── LICENSE ├── README.md ├── benches └── run_tvm_leapfrog.rs_old ├── cliff.toml ├── docs ├── .gitignore ├── _freeze │ ├── index │ │ └── execute-results │ │ │ └── html.json │ ├── nf-adapt │ │ └── execute-results │ │ │ └── html.json │ ├── pymc-usage │ │ ├── execute-results │ │ │ └── html.json │ │ └── figure-html │ │ │ └── cell-5-output-1.png │ ├── sample-stats │ │ ├── execute-results │ │ │ └── html.json │ │ └── figure-html │ │ │ ├── cell-3-output-1.png │ │ │ ├── cell-4-output-1.png │ │ │ ├── cell-5-output-1.png │ │ │ ├── cell-7-output-2.png │ │ │ ├── cell-8-output-1.png │ │ │ └── cell-9-output-1.png │ ├── site_libs │ │ └── clipboard │ │ │ └── clipboard.min.js │ └── stan-usage │ │ ├── execute-results │ │ └── html.json │ │ └── figure-html │ │ └── cell-11-output-1.png ├── _quarto.yml ├── about.qmd ├── index.qmd ├── nf-adapt.qmd ├── pymc-usage.qmd ├── sample-stats.qmd ├── sampling-options.qmd ├── stan-usage.qmd └── styles.css ├── notebooks └── pytensor_logp.md ├── pyproject.toml ├── python └── nutpie │ ├── __init__.py │ ├── compile_pymc.py │ ├── compile_stan.py │ ├── compiled_pyfunc.py │ ├── normalizing_flow.py │ ├── sample.py │ └── transform_adapter.py ├── src ├── lib.rs ├── progress.rs ├── pyfunc.rs ├── pymc.rs ├── stan.rs └── wrapper.rs └── tests ├── reference ├── test_deterministic_sampling_jax.txt ├── test_deterministic_sampling_numba.txt ├── test_deterministic_sampling_stan.txt └── test_normalizing_flow.txt ├── test_pymc.py └── test_stan.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # SCM syntax highlighting 2 | pixi.lock linguist-language=YAML linguist-generated=true 3 | -------------------------------------------------------------------------------- /.github/dependabot.yml: 
-------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | # Maintain dependencies for GitHub Actions 4 | - package-ecosystem: "github-actions" 5 | directory: "/" 6 | schedule: 7 | interval: "weekly" 8 | labels: 9 | - "Github CI/CD" 10 | - "no releasenotes" 11 | - package-ecosystem: "cargo" 12 | directory: "/rust" 13 | schedule: 14 | interval: "weekly" 15 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | - master 8 | tags: 9 | - "*" 10 | pull_request: 11 | workflow_dispatch: 12 | 13 | permissions: 14 | contents: read 15 | 16 | jobs: 17 | linux: 18 | runs-on: ${{ matrix.platform.runner }} 19 | strategy: 20 | fail-fast: false 21 | matrix: 22 | platform: 23 | - runner: ubuntu-22.04 24 | target: x86_64 25 | - runner: ubuntu-22.04-arm 26 | target: aarch64 27 | steps: 28 | - uses: actions/checkout@v4 29 | - name: Install uv 30 | uses: astral-sh/setup-uv@v6 31 | - uses: actions/setup-python@v5 32 | with: 33 | python-version: | 34 | 3.10 35 | 3.11 36 | 3.12 37 | 3.13 38 | - name: Build wheels 39 | uses: PyO3/maturin-action@v1 40 | with: 41 | target: ${{ matrix.platform.target }} 42 | args: --release --out dist --interpreter 3.10 3.11 3.12 3.13 --zig 43 | sccache: ${{ !startsWith(github.ref, 'refs/tags/') }} 44 | manylinux: auto 45 | before-script-linux: | 46 | dnf install -y clang-libs clang || sudo apt install llvm-dev libclang-dev clang 47 | - name: Upload wheels 48 | uses: actions/upload-artifact@v4 49 | with: 50 | name: wheels-linux-${{ matrix.platform.target }} 51 | path: dist 52 | - name: pytest 53 | shell: bash 54 | run: | 55 | set -e 56 | python3 -m venv .venv 57 | source .venv/bin/activate 58 | uv pip install 'nutpie[stan]' --find-links dist --force-reinstall 59 | uv pip install pytest pytest-timeout 
pytest-arraydiff 60 | pytest -m "stan and not flow" --arraydiff 61 | uv pip install 'nutpie[pymc]' --find-links dist --force-reinstall 62 | uv pip install jax 63 | pytest -m "pymc and not flow" --arraydiff 64 | uv pip install 'nutpie[all]' --find-links dist --force-reinstall 65 | pytest -m flow --arraydiff 66 | 67 | # pyarrow doesn't currently seem to work on musllinux 68 | #musllinux: 69 | # runs-on: ${{ matrix.platform.runner }} 70 | # strategy: 71 | # fail-fast: false 72 | # matrix: 73 | # platform: 74 | # - runner: ubuntu-22.04 75 | # target: x86_64 76 | # - runner: ubuntu-22.04 77 | # target: aarch64 78 | # steps: 79 | # - uses: actions/checkout@v4 80 | # - uses: actions/setup-python@v5 81 | # with: 82 | # python-version: "3.12" 83 | # - name: Install uv 84 | # uses: astral-sh/setup-uv@v6 85 | # - name: Build wheels 86 | # uses: PyO3/maturin-action@v1 87 | # with: 88 | # target: ${{ matrix.platform.target }} 89 | # args: --release --out dist --find-interpreter 90 | # sccache: ${{ !startsWith(github.ref, 'refs/tags/') }} 91 | # manylinux: musllinux_1_2 92 | # before-script-linux: | 93 | # dnf install -y clang-libs clang || apt install llvm-dev libclang-dev clang 94 | # - name: Upload wheels 95 | # uses: actions/upload-artifact@v4 96 | # with: 97 | # name: wheels-musllinux-${{ matrix.platform.target }} 98 | # path: dist 99 | # - name: pytest 100 | # if: ${{ startsWith(matrix.platform.target, 'x86_64') }} 101 | # uses: addnab/docker-run-action@v3 102 | # with: 103 | # image: alpine:latest 104 | # options: -v ${{ github.workspace }}:/io -w /io 105 | # run: | 106 | # set -e 107 | # apk add py3-pip py3-virtualenv curl make clang 108 | # curl -LsSf https://astral.sh/uv/install.sh | sh 109 | # source $HOME/.local/bin/env 110 | # python3 -m virtualenv .venv 111 | # source .venv/bin/activate 112 | # # No numba packages for alpine 113 | # uv pip install 'nutpie[stan]' --find-links dist --force-reinstall 114 | # uv pip install pytest 115 | # pytest 116 | # - name: pytest 
117 | # if: ${{ !startsWith(matrix.platform.target, 'x86') }} 118 | # uses: uraimo/run-on-arch-action@v2 119 | # with: 120 | # arch: ${{ matrix.platform.target }} 121 | # distro: alpine_latest 122 | # githubToken: ${{ github.token }} 123 | # install: | 124 | # apk add py3-virtualenv curl make clang 125 | # curl -LsSf https://astral.sh/uv/install.sh | sh 126 | # source $HOME/.local/bin/env 127 | # run: | 128 | # set -e 129 | # python3 -m virtualenv .venv 130 | # source $HOME/.local/bin/env 131 | # source .venv/bin/activate 132 | # uv pip install pytest 133 | # # No numba packages for alpine 134 | # uv pip install 'nutpie[stan]' --find-links dist --force-reinstall 135 | # pytest 136 | 137 | windows: 138 | runs-on: ${{ matrix.platform.runner }} 139 | strategy: 140 | matrix: 141 | platform: 142 | - runner: windows-latest 143 | target: x64 144 | steps: 145 | - uses: actions/checkout@v4 146 | - uses: actions/setup-python@v5 147 | with: 148 | python-version: | 149 | 3.10 150 | 3.11 151 | 3.12 152 | 3.13 153 | architecture: ${{ matrix.platform.target }} 154 | - name: Install uv 155 | uses: astral-sh/setup-uv@v6 156 | - name: Install LLVM and Clang 157 | uses: KyleMayes/install-llvm-action@v2 158 | with: 159 | version: "15.0" 160 | directory: ${{ runner.temp }}/llvm 161 | - name: Set up TBB 162 | if: matrix.os == 'windows-latest' 163 | run: | 164 | Add-Content $env:GITHUB_PATH "$(pwd)/stan/lib/stan_math/lib/tbb" 165 | - name: Build wheels 166 | uses: PyO3/maturin-action@v1 167 | env: 168 | LIBCLANG_PATH: ${{ runner.temp }}/llvm/lib 169 | with: 170 | target: ${{ matrix.platform.target }} 171 | args: --release --out dist --find-interpreter 172 | sccache: ${{ !startsWith(github.ref, 'refs/tags/') }} 173 | - name: Upload wheels 174 | uses: actions/upload-artifact@v4 175 | with: 176 | name: wheels-windows-${{ matrix.platform.target }} 177 | path: dist 178 | - name: pytest 179 | if: ${{ !startsWith(matrix.platform.target, 'aarch64') }} 180 | shell: bash 181 | run: | 182 | set -e 
183 | python3 -m venv .venv 184 | source .venv/Scripts/activate 185 | uv pip install "nutpie[stan]" --find-links dist --force-reinstall 186 | uv pip install pytest pytest-timeout pytest-arraydiff 187 | pytest -m "stan and not flow" --arraydiff 188 | uv pip install "nutpie[pymc]" --find-links dist --force-reinstall 189 | uv pip install jax 190 | pytest -m "pymc and not flow" --arraydiff 191 | uv pip install "nutpie[all]" --find-links dist --force-reinstall 192 | pytest -m flow --arraydiff 193 | 194 | macos: 195 | runs-on: ${{ matrix.platform.runner }} 196 | strategy: 197 | fail-fast: false 198 | matrix: 199 | platform: 200 | - runner: macos-13 201 | target: x86_64 202 | - runner: macos-14 203 | target: aarch64 204 | steps: 205 | - uses: actions/checkout@v4 206 | - uses: actions/setup-python@v5 207 | with: 208 | python-version: | 209 | 3.10 210 | 3.11 211 | 3.12 212 | 3.13 213 | - name: Install uv 214 | uses: astral-sh/setup-uv@v6 215 | - uses: maxim-lobanov/setup-xcode@v1 216 | with: 217 | xcode-version: latest-stable 218 | - name: Build wheels 219 | uses: PyO3/maturin-action@v1 220 | with: 221 | target: ${{ matrix.platform.target }} 222 | args: --release --out dist --find-interpreter 223 | sccache: ${{ !startsWith(github.ref, 'refs/tags/') }} 224 | - name: Upload wheels 225 | uses: actions/upload-artifact@v4 226 | with: 227 | name: wheels-macos-${{ matrix.platform.target }} 228 | path: dist 229 | - name: pytest 230 | run: | 231 | set -e 232 | python3 -m venv .venv 233 | source .venv/bin/activate 234 | uv pip install 'nutpie[stan]' --find-links dist --force-reinstall 235 | uv pip install pytest pytest-timeout pytest-arraydiff 236 | pytest -m "stan and not flow" --arraydiff 237 | uv pip install 'nutpie[pymc]' --find-links dist --force-reinstall 238 | uv pip install jax 239 | pytest -m "pymc and not flow" --arraydiff 240 | uv pip install 'nutpie[all]' --find-links dist --force-reinstall 241 | pytest -m flow --arraydiff 242 | sdist: 243 | runs-on: ubuntu-latest 244 | 
steps: 245 | - uses: actions/checkout@v4 246 | - name: Build sdist 247 | uses: PyO3/maturin-action@v1 248 | with: 249 | command: sdist 250 | args: --out dist 251 | - name: Upload sdist 252 | uses: actions/upload-artifact@v4 253 | with: 254 | name: wheels-sdist 255 | path: dist 256 | 257 | release: 258 | name: Release 259 | runs-on: ubuntu-latest 260 | if: ${{ startsWith(github.ref, 'refs/tags/') || github.event_name == 'workflow_dispatch' }} 261 | needs: [linux, windows, macos, sdist] 262 | environment: 263 | name: pypi 264 | permissions: 265 | # Use to sign the release artifacts 266 | id-token: write 267 | # Used to upload release artifacts 268 | contents: write 269 | # Used to generate artifact attestation 270 | attestations: write 271 | steps: 272 | - uses: actions/download-artifact@v4 273 | - name: Generate artifact attestation 274 | uses: actions/attest-build-provenance@v2 275 | with: 276 | subject-path: "wheels-*/*" 277 | - name: Publish to PyPI 278 | if: ${{ startsWith(github.ref, 'refs/tags/') }} 279 | uses: PyO3/maturin-action@v1 280 | with: 281 | command: upload 282 | args: --non-interactive --skip-existing wheels-*/* 283 | - name: Upload to GitHub Release 284 | uses: softprops/action-gh-release@v2 285 | with: 286 | files: | 287 | wasm-wheels/*.whl 288 | prerelease: ${{ contains(github.ref, 'alpha') || contains(github.ref, 'beta') }} 289 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | __pycache__ 3 | *.ipynb 4 | *.so 5 | tvm_libs/* 6 | .gdb_history 7 | .ipynb_checkpoints 8 | notebooks/*.stan 9 | notebooks/*.csv 10 | notebooks/*.hpp 11 | notebooks/radon* 12 | perf.data* 13 | wheels 14 | .vscode/ 15 | *~ 16 | .zed 17 | .cargo 18 | *traces* 19 | .pyrightconfig.json 20 | *.zarr 21 | book 22 | docs/_site 23 | .quarto 24 | example-iree 25 | posteriordb 26 | .quarto 27 | docs/.quarto 28 | Untitled* 29 | notebooks-local 
30 | pixi.lock 31 | pixi.toml 32 | reports 33 | benchmarks* 34 | reports* 35 | results* 36 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | ci: 2 | autofix_prs: false 3 | 4 | repos: 5 | - repo: https://github.com/pre-commit/pre-commit-hooks 6 | rev: v5.0.0 7 | hooks: 8 | - id: debug-statements 9 | - id: check-merge-conflict 10 | - id: check-toml 11 | - id: check-yaml 12 | - id: debug-statements 13 | - id: end-of-file-fixer 14 | exclude: "docs/_freeze" 15 | - id: no-commit-to-branch 16 | args: [--branch, main] 17 | - id: trailing-whitespace 18 | 19 | - repo: https://github.com/astral-sh/ruff-pre-commit 20 | rev: v0.11.12 21 | hooks: 22 | - id: ruff 23 | args: ["--fix", "--output-format=full"] 24 | - id: ruff-format 25 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | All notable changes to this project will be documented in this file. 
4 | 5 | ## [0.15.1] - 2025-05-28 6 | 7 | ### Bug Fixes 8 | 9 | - Incorrect results with non-contiguous shared variable (Adrian Seyboldt) 10 | 11 | - Allow upper case backend string (Adrian Seyboldt) 12 | 13 | - Allow data named x with unfrozen model (Adrian Seyboldt) 14 | 15 | 16 | ### Styling 17 | 18 | - Fix small typing issue (Adrian Seyboldt) 19 | 20 | 21 | ## [0.15.0] - 2025-05-27 22 | 23 | ### Bug Fixes 24 | 25 | - Use stanio for creating Stan's data JSON (#205) (Brian Ward) 26 | 27 | - Rng for generated quantities (Adrian Seyboldt) 28 | 29 | - Correctly handle tuples in stan traces (Adrian Seyboldt) 30 | 31 | - Allow variables with zero shapes (Adrian Seyboldt) 32 | 33 | - Let rust sampler decide on default num chains (Adrian Seyboldt) 34 | 35 | 36 | ### Documentation 37 | 38 | - Fix section links path (Guspan Tanadi) 39 | 40 | - Link to website (Adrian Seyboldt) 41 | 42 | 43 | ### Features 44 | 45 | - Improvements to normalizing flow (Adrian Seyboldt) 46 | 47 | - Experiment with planar flows (Adrian Seyboldt) 48 | 49 | 50 | ### Miscellaneous Tasks 51 | 52 | - Bump pyo3 in the cargo group across 1 directory (dependabot[bot]) 53 | 54 | - Bump astral-sh/setup-uv from 5 to 6 (#203) (dependabot[bot]) 55 | 56 | - Add entries to gitignore (Adrian Seyboldt) 57 | 58 | - Update dependencies (Adrian Seyboldt) 59 | 60 | - Update changelog (Adrian Seyboldt) 61 | 62 | 63 | ### Styling 64 | 65 | - Fix some clippy warnings (Adrian Seyboldt) 66 | 67 | 68 | ### Testing 69 | 70 | - Check that normalizing flows are reproducible (Adrian Seyboldt) 71 | 72 | - Add low rank tests (Adrian Seyboldt) 73 | 74 | 75 | ### Build 76 | 77 | - Increase optimization level (Adrian Seyboldt) 78 | 79 | 80 | ## [0.14.3] - 2025-03-18 81 | 82 | ### Bug Fixes 83 | 84 | - Fix normalizing flows for 1d posteriors (Adrian Seyboldt) 85 | 86 | - Better initialization of masked flows (Adrian Seyboldt) 87 | 88 | 89 | ### Documentation 90 | 91 | - Fix spelling and grammar (Daniel Saunders) 92 | 93 | 94 | ### 
Features 95 | 96 | - Add masked coupling flow (Adrian Seyboldt) 97 | 98 | - Expose static trajectory length in nuts (Adrian Seyboldt) 99 | 100 | - Make mvscale layer optional (Adrian Seyboldt) 101 | 102 | - Add layer norm in normalizing flow (Adrian Seyboldt) 103 | 104 | - Small improvements to the masked normalizing flow (Adrian Seyboldt) 105 | 106 | 107 | ### Miscellaneous Tasks 108 | 109 | - Update dependencies (Adrian Seyboldt) 110 | 111 | 112 | ### Ci 113 | 114 | - Split some test into sections with optional deps (Adrian Seyboldt) 115 | 116 | 117 | ## [0.14.2] - 2025-03-06 118 | 119 | ### Bug Fixes 120 | 121 | - Handle missing flowjax correctly (Adrian Seyboldt) 122 | 123 | 124 | ### Testing 125 | 126 | - Mark tests as pymc or stan and select in ci (Adrian Seyboldt) 127 | 128 | 129 | ### Ci 130 | 131 | - Use native arm github action runner (Adrian Seyboldt) 132 | 133 | 134 | ## [0.14.1] - 2025-03-05 135 | 136 | ### Ci 137 | 138 | - Update run-on-arch to avoid segfault (Adrian Seyboldt) 139 | 140 | - Repare 0.14.1 (Adrian Seyboldt) 141 | 142 | 143 | ## [0.14.0] - 2025-03-05 144 | 145 | ### Bug Fixes 146 | 147 | - Set 'make_initial_point_fn' in 'from_pyfunc' to None by default (#175) (Tomás Capretto) 148 | 149 | 150 | ### Documentation 151 | 152 | - Add nutpie website source (Adrian Seyboldt) 153 | 154 | - Include frozen cell output in docs (Adrian Seyboldt) 155 | 156 | 157 | ### Features 158 | 159 | - Add normalizing flow adaptation (Adrian Seyboldt) 160 | 161 | 162 | ### Miscellaneous Tasks 163 | 164 | - Bump actions/attest-build-provenance from 1 to 2 (dependabot[bot]) 165 | 166 | - Bump softprops/action-gh-release from 1 to 2 (dependabot[bot]) 167 | 168 | - Bump uraimo/run-on-arch-action from 2 to 3 (dependabot[bot]) 169 | 170 | - Update pre-commit config (Adrian Seyboldt) 171 | 172 | 173 | ### Ci 174 | 175 | - Run python 3.13 in ci (Adrian Seyboldt) 176 | 177 | - Skip slow test on ci if emulating architecture (Adrian Seyboldt) 178 | 179 | 180 | ## [0.13.4] 
- 2025-02-18 181 | 182 | ### Bug Fixes 183 | 184 | - Add lock for pymc init point func (Adrian Seyboldt) 185 | 186 | 187 | ### Ci 188 | 189 | - Make sure all python versions are available in the builds (Adrian Seyboldt) 190 | 191 | - Skip python 3.13 for now (Adrian Seyboldt) 192 | 193 | 194 | ## [0.13.3] - 2025-02-12 195 | 196 | ### Bug Fixes 197 | 198 | - Use arrow list with i64 offsets to store trace (Adrian Seyboldt) 199 | 200 | - Use i64 offsets in numba backend (Adrian Seyboldt) 201 | 202 | - Avoid numpy compatibility warning (Adrian Seyboldt) 203 | 204 | - Specify that we currently don't support py313 due to pyo3 (Adrian Seyboldt) 205 | 206 | 207 | ### Features 208 | 209 | - Add support for pymc sampler initialization (jessegrabowski) 210 | 211 | - Use support_point as default init for pymc (Adrian Seyboldt) 212 | 213 | - Add option not to store some deterministics (Adrian Seyboldt) 214 | 215 | - Add option to freeze pymc models (Adrian Seyboldt) 216 | 217 | 218 | ### Miscellaneous Tasks 219 | 220 | - Bump uraimo/run-on-arch-action from 2.7.2 to 2.8.1 (dependabot[bot]) 221 | 222 | - Specify version as dynamic in pyproject (Adrian Seyboldt) 223 | 224 | - Update bridgestan (Adrian Seyboldt) 225 | 226 | - Update pre-commit versions (Adrian Seyboldt) 227 | 228 | 229 | ### Styling 230 | 231 | - Reformat some code (Adrian Seyboldt) 232 | 233 | 234 | ### Build 235 | 236 | - Bump some dependency versions (Adrian Seyboldt) 237 | 238 | 239 | ### Ci 240 | 241 | - Use ubuntu_latest on aarch64 (Adrian Seyboldt) 242 | 243 | - Update CI script using maturin (Adrian Seyboldt) 244 | 245 | 246 | ## [0.13.2] - 2024-07-26 247 | 248 | ### Features 249 | 250 | - Support float32 settings in pytensor (Adrian Seyboldt) 251 | 252 | 253 | ### Miscellaneous Tasks 254 | 255 | - Update dependencies (Adrian Seyboldt) 256 | 257 | 258 | ## [0.13.1] - 2024-07-09 259 | 260 | ### Bug Fixes 261 | 262 | - Fix jax backend with non-identifier variable names (Adrian Seyboldt) 263 | 264 | 265 | ### 
Miscellaneous Tasks 266 | 267 | - Update dependencies (Adrian Seyboldt) 268 | 269 | 270 | ## [0.13.0] - 2024-07-05 271 | 272 | ### Documentation 273 | 274 | - Document low-rank mass matrix parameters (Adrian Seyboldt) 275 | 276 | 277 | ### Features 278 | 279 | - Add low rank modified mass matrix adaptation (Adrian Seyboldt) 280 | 281 | 282 | ### Miscellaneous Tasks 283 | 284 | - Remove releases from changelog (Adrian Seyboldt) 285 | 286 | 287 | ## [0.12.0] - 2024-06-29 288 | 289 | ### Features 290 | 291 | - Add pyfunc backend (Adrian Seyboldt) 292 | 293 | - Add python code for pyfunc backend (Adrian Seyboldt) 294 | 295 | - Add gradient_backend argument for pymc models (Adrian Seyboldt) 296 | 297 | 298 | ### Miscellaneous Tasks 299 | 300 | - Bump version number (Adrian Seyboldt) 301 | 302 | 303 | ### Styling 304 | 305 | - Fix pre-commit issues (Adrian Seyboldt) 306 | 307 | 308 | ### Testing 309 | 310 | - Add tests for jax backend (Adrian Seyboldt) 311 | 312 | 313 | ### Build 314 | 315 | - Add jax as optional dependency (Adrian Seyboldt) 316 | 317 | 318 | ## [0.11.1] - 2024-06-16 319 | 320 | ### Bug Fixes 321 | 322 | - Fix random variables with missing values in pymc deterministics (Adrian Seyboldt) 323 | 324 | 325 | ### Features 326 | 327 | - Add progress bar on terminal (Adrian Seyboldt) 328 | 329 | 330 | ## [0.11.0] - 2024-05-29 331 | 332 | ### Bug Fixes 333 | 334 | - Use clone_replace instead of graph_replace (Adrian Seyboldt) 335 | 336 | - Allow shared vars to differ in expand and logp (Adrian Seyboldt) 337 | 338 | 339 | ### Features 340 | 341 | - Add option to use draw base mass matrix estimate (Adrian Seyboldt) 342 | 343 | - Report detailed progress (Adrian Seyboldt) 344 | 345 | - Show the number of draws in progress overview (Adrian Seyboldt) 346 | 347 | 348 | ### Miscellaneous Tasks 349 | 350 | - Bump KyleMayes/install-llvm-action from 1 to 2 (dependabot[bot]) 351 | 352 | - Bump uraimo/run-on-arch-action from 2.7.1 to 2.7.2 (dependabot[bot]) 353 | 354 | - 
Update dependencies (Adrian Seyboldt) 355 | 356 | - Update python dependencies (Adrian Seyboldt) 357 | 358 | 359 | ### Refactor 360 | 361 | - Move threaded sampling to nuts-rs (Adrian Seyboldt) 362 | 363 | - Specify callback rate (Adrian Seyboldt) 364 | 365 | - Switch to arrow-rs (Adrian Seyboldt) 366 | 367 | 368 | ### Styling 369 | 370 | - Fix formatting and clippy (Adrian Seyboldt) 371 | 372 | 373 | ### Testing 374 | 375 | - Fix incorrect error type in test (Adrian Seyboldt) 376 | 377 | 378 | ### Ci 379 | 380 | - Fix uploads of releases (Adrian Seyboldt) 381 | 382 | - Fix architectures in CI (Adrian Seyboldt) 383 | 384 | 385 | ## [0.10.0] - 2024-03-20 386 | 387 | ### Documentation 388 | 389 | - Mention non-blocking sampling in readme (Adrian Seyboldt) 390 | 391 | 392 | ### Features 393 | 394 | - Allow sampling in the backgound (Adrian Seyboldt) 395 | 396 | - Implement check if background sampling is complete (Adrian Seyboldt) 397 | 398 | - Implement pausing and unpausing of samplers (Adrian Seyboldt) 399 | 400 | - Filter warnings and compile through pymc (Adrian Seyboldt) 401 | 402 | 403 | ### Miscellaneous Tasks 404 | 405 | - Bump actions/setup-python from 4 to 5 (dependabot[bot]) 406 | 407 | - Bump uraimo/run-on-arch-action from 2.5.0 to 2.7.1 (dependabot[bot]) 408 | 409 | - Bump actions/checkout from 3 to 4 (dependabot[bot]) 410 | 411 | - Bump actions/upload-artifact from 3 to 4 (dependabot[bot]) 412 | 413 | - Bump the cargo group across 1 directory with 2 updates (dependabot[bot]) 414 | 415 | - Some major version bumps in rust deps (Adrian Seyboldt) 416 | 417 | - Bump dependency versions (Adrian Seyboldt) 418 | 419 | - Bump version (Adrian Seyboldt) 420 | 421 | - Update changelog (Adrian Seyboldt) 422 | 423 | 424 | ### Performance 425 | 426 | - Set the number of parallel chains dynamically (Adrian Seyboldt) 427 | 428 | 429 | ## [0.9.2] - 2024-02-19 430 | 431 | ### Bug Fixes 432 | 433 | - Allow dims with only length specified (Adrian Seyboldt) 434 | 435 | 436 
| ### Documentation 437 | 438 | - Update suggested mamba commands in README (#70) (Ben Mares) 439 | 440 | - Fix README typo bridgestan→nutpie (#69) (Ben Mares) 441 | 442 | 443 | ### Features 444 | 445 | - Handle missing libraries more robustly (#72) (Ben Mares) 446 | 447 | 448 | ### Miscellaneous Tasks 449 | 450 | - Bump actions/download-artifact from 3 to 4 (dependabot[bot]) 451 | 452 | 453 | ### Ci 454 | 455 | - Make sure the local nutpie is installed (Adrian Seyboldt) 456 | 457 | - Install local nutpie package in all jobs (Adrian Seyboldt) 458 | 459 | 460 | ## [0.9.0] - 2023-09-12 461 | 462 | ### Bug Fixes 463 | 464 | - Better error context for init point errors (Adrian Seyboldt) 465 | 466 | 467 | ### Features 468 | 469 | - Improve error message by providing context (Adrian Seyboldt) 470 | 471 | - Use standard normal to initialize chains (Adrian Seyboldt) 472 | 473 | 474 | ### Miscellaneous Tasks 475 | 476 | - Update deps (Adrian Seyboldt) 477 | 478 | - Rename stan model transpose function (Adrian Seyboldt) 479 | 480 | - Update nutpie (Adrian Seyboldt) 481 | 482 | 483 | ### Styling 484 | 485 | - Fix formatting (Adrian Seyboldt) 486 | 487 | 488 | ## [0.8.0] - 2023-08-18 489 | 490 | ### Bug Fixes 491 | 492 | - Initialize points in uniform(-2, 2) (Adrian Seyboldt) 493 | 494 | - Multidimensional stan variables were stored in incorrect order (Adrian Seyboldt) 495 | 496 | - Fix with_coords for stan models (Adrian Seyboldt) 497 | 498 | 499 | ### Miscellaneous Tasks 500 | 501 | - Update deps (Adrian Seyboldt) 502 | 503 | - Update deps (Adrian Seyboldt) 504 | 505 | - Bump version (Adrian Seyboldt) 506 | 507 | - Update deps (Adrian Seyboldt) 508 | 509 | 510 | ## [0.7.0] - 2023-07-21 511 | 512 | ### Bug Fixes 513 | 514 | - Check logp value in stan interface (Adrian Seyboldt) 515 | 516 | - Make max_energy_error writable (Adrian Seyboldt) 517 | 518 | - Export names of unconstrained parameters (Adrian Seyboldt) 519 | 520 | - Fix return values of logp benchmark function 
(Adrian Seyboldt) 521 | 522 | 523 | ### Features 524 | 525 | - Export more details of divergences (Adrian Seyboldt) 526 | 527 | - Add extra_stanc_args argument to compile_stan_model (Chris Fonnesbeck) 528 | 529 | 530 | ### Miscellaneous Tasks 531 | 532 | - Update dependencies (Adrian Seyboldt) 533 | 534 | - Add changelog (Adrian Seyboldt) 535 | 536 | - Bump version (Adrian Seyboldt) 537 | 538 | 539 | ### Refactor 540 | 541 | - Hide private rust module (Adrian Seyboldt) 542 | 543 | 544 | ## [0.6.0] - 2023-07-03 545 | 546 | ### Features 547 | 548 | - Allow to update dims and coords in stan model (Adrian Seyboldt) 549 | 550 | 551 | ### Miscellaneous Tasks 552 | 553 | - Bump version (Adrian Seyboldt) 554 | 555 | 556 | 557 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | # This CITATION.cff file was generated with cffinit. 2 | # Visit https://bit.ly/cffinit to generate yours today! 3 | 4 | cff-version: 1.2.0 5 | title: nutpie 6 | message: >- 7 | If you use this software, please cite it using the 8 | metadata from this file. 9 | type: software 10 | authors: 11 | - given-names: Adrian 12 | family-names: Seyboldt 13 | email: adrian.seyboldt@gmail.com 14 | affiliation: PyMC Labs 15 | orcid: 'https://orcid.org/0000-0002-4239-4541' 16 | - name: PyMC Developers 17 | website: 'https://github.com/pymc-devs/' 18 | repository-code: 'https://github.com/pymc-devs/nutpie' 19 | abstract: 'A fast sampler for Bayesian posteriors, wrapping nuts-rs.' 
20 | keywords: 21 | - NUTS 22 | - Bayesian inference 23 | - MCMC 24 | license: MIT 25 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "nutpie" 3 | version = "0.15.1" 4 | authors = [ 5 | "Adrian Seyboldt ", 6 | "PyMC Developers ", 7 | ] 8 | edition = "2021" 9 | license = "MIT" 10 | repository = "https://github.com/pymc-devs/nutpie" 11 | keywords = ["statistics", "bayes"] 12 | description = "Python wrapper for nuts-rs -- a NUTS sampler written in Rust." 13 | rust-version = "1.76" 14 | 15 | [features] 16 | extension-module = ["pyo3/extension-module"] 17 | default = ["extension-module"] 18 | 19 | [lib] 20 | name = "_lib" 21 | crate-type = ["cdylib"] 22 | 23 | [dependencies] 24 | nuts-rs = "0.16.1" 25 | numpy = "0.25.0" 26 | rand = "0.9.0" 27 | thiserror = "2.0.3" 28 | rand_chacha = "0.9.0" 29 | rayon = "1.10.0" 30 | # Keep arrow in sync with nuts-rs requirements 31 | arrow = { version = "55.1.0", default-features = false, features = ["ffi"] } 32 | anyhow = "1.0.72" 33 | itertools = "0.14.0" 34 | bridgestan = "2.6.1" 35 | rand_distr = "0.5.0" 36 | smallvec = "1.14.0" 37 | upon = { version = "0.9.0", default-features = false, features = [] } 38 | time-humanize = { version = "0.1.3", default-features = false } 39 | indicatif = "0.17.8" 40 | tch = { version = "0.20.0", optional = true } 41 | 42 | [dependencies.pyo3] 43 | version = "0.25.0" 44 | features = ["extension-module", "anyhow"] 45 | 46 | [dev-dependencies] 47 | criterion = "0.6.0" 48 | 49 | [profile.release] 50 | lto = "fat" 51 | codegen-units = 1 52 | opt-level = 3 53 | 54 | [profile.bench] 55 | debug = true 56 | lto = "fat" 57 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Adrian Seyboldt 4 | 5 
| Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # nutpie: A fast sampler for Bayesian posteriors 2 | 3 | The `nutpie` package provides a fast NUTS sampler for PyMC and Stan models. 4 | 5 | See the [documentation](https://pymc-devs.github.io/nutpie/) for more details. 
6 | 7 | ## Installation 8 | 9 | nutpie can be installed using Conda or Mamba from conda-forge with 10 | 11 | ```bash 12 | mamba install -c conda-forge nutpie 13 | ``` 14 | 15 | Or using pip: 16 | 17 | ```bash 18 | pip install nutpie 19 | ``` 20 | 21 | To install it from source, install a Rust compiler and maturin and then 22 | 23 | ```bash 24 | maturin develop --release 25 | ``` 26 | 27 | If you want to use the nightly SIMD implementation for some of the math functions, 28 | switch to Rust nightly and then install with the `simd_support` feature in the 29 | nutpie directory: 30 | 31 | ```bash 32 | rustup override set nightly 33 | maturin develop --release --features=simd_support 34 | ``` 35 | 36 | ## Usage with PyMC 37 | 38 | First, PyMC and Numba need to be installed, for example using 39 | 40 | ```bash 41 | mamba install -c conda-forge pymc numba 42 | ``` 43 | 44 | We need to create a model: 45 | 46 | ```python 47 | import pymc as pm 48 | import numpy as np 49 | import nutpie 50 | import pandas as pd 51 | import seaborn as sns 52 | 53 | # Load the radon dataset 54 | data = pd.read_csv(pm.get_data("radon.csv")) 55 | data["log_radon"] = data["log_radon"].astype(np.float64) 56 | county_idx, counties = pd.factorize(data.county) 57 | coords = {"county": counties, "obs_id": np.arange(len(county_idx))} 58 | 59 | # Create a simple hierarchical model for the radon dataset 60 | with pm.Model(coords=coords, check_bounds=False) as pymc_model: 61 | intercept = pm.Normal("intercept", sigma=10) 62 | 63 | # County effects 64 | raw = pm.ZeroSumNormal("county_raw", dims="county") 65 | sd = pm.HalfNormal("county_sd") 66 | county_effect = pm.Deterministic("county_effect", raw * sd, dims="county") 67 | 68 | # Global floor effect 69 | floor_effect = pm.Normal("floor_effect", sigma=2) 70 | 71 | # County:floor interaction 72 | raw = pm.ZeroSumNormal("county_floor_raw", dims="county") 73 | sd = pm.HalfNormal("county_floor_sd") 74 | county_floor_effect = pm.Deterministic( 75 | 
"county_floor_effect", raw * sd, dims="county" 76 | ) 77 | 78 | mu = ( 79 | intercept 80 | + county_effect[county_idx] 81 | + floor_effect * data.floor.values 82 | + county_floor_effect[county_idx] * data.floor.values 83 | ) 84 | 85 | sigma = pm.HalfNormal("sigma", sigma=1.5) 86 | pm.Normal( 87 | "log_radon", mu=mu, sigma=sigma, observed=data.log_radon.values, dims="obs_id" 88 | ) 89 | ``` 90 | 91 | We then compile this model and sample form the posterior: 92 | 93 | ```python 94 | compiled_model = nutpie.compile_pymc_model(pymc_model) 95 | trace_pymc = nutpie.sample(compiled_model) 96 | ``` 97 | 98 | `trace_pymc` now contains an ArviZ `InferenceData` object, including sampling 99 | statistics and the posterior of the variables defined above. 100 | 101 | We can also control the sampler in a non-blocking way: 102 | 103 | ```python 104 | # The sampler will now run the the background 105 | sampler = nutpie.sample(compiled_model, blocking=False) 106 | 107 | # Pause and resume the sampling 108 | sampler.pause() 109 | sampler.resume() 110 | 111 | # Wait for the sampler to finish (up to timeout seconds) 112 | sampler.wait(timeout=0.1) 113 | # Note that not passing any timeout to `wait` will 114 | # wait until the sampler finishes, then return the InferenceData object: 115 | idata = sampler.wait() 116 | 117 | # or we can also abort the sampler (and return the incomplete trace) 118 | incomplete_trace = sampler.abort() 119 | 120 | # or cancel and discard all progress: 121 | sampler.cancel() 122 | ``` 123 | 124 | ## Usage with Stan 125 | 126 | In order to sample from Stan model, `bridgestan` needs to be installed. 127 | A pip package is available, but right now this can not be installed using Conda. 
128 | 129 | ```bash 130 | pip install bridgestan 131 | ``` 132 | 133 | When we install nutpie with pip, we can also specify that we want optional 134 | dependencies for Stan models using 135 | 136 | ``` 137 | pip install 'nutpie[stan]' 138 | ``` 139 | 140 | In addition, a C++ compiler needs to be available. For details see 141 | [the Stan docs](https://mc-stan.org/docs/cmdstan-guide/installation.html#cpp-toolchain). 142 | 143 | We can then compile a Stan model, and sample using nutpie: 144 | 145 | ```python 146 | import nutpie 147 | 148 | code = """ 149 | data { 150 | real mu; 151 | } 152 | parameters { 153 | real x; 154 | } 155 | model { 156 | x ~ normal(mu, 1); 157 | } 158 | """ 159 | 160 | compiled = nutpie.compile_stan_model(code=code) 161 | # Provide data 162 | compiled = compiled.with_data(mu=3.) 163 | trace = nutpie.sample(compiled) 164 | ``` 165 | 166 | ## Advantages 167 | 168 | nutpie uses [`nuts-rs`](https://github.com/pymc-devs/nuts-rs), a library written in Rust, that implements NUTS as in 169 | PyMC and Stan, but with a slightly different mass matrix tuning method than 170 | those. It often produces a higher effective sample size per gradient 171 | evaluation, and tends to converge faster and with fewer gradient evaluations. 
172 | -------------------------------------------------------------------------------- /benches/run_tvm_leapfrog.rs_old: -------------------------------------------------------------------------------- 1 | use criterion::{black_box, criterion_group, criterion_main, Criterion}; 2 | 3 | use tvm::runtime::graph_rt::GraphRt; 4 | use tvm::runtime::vm::{Executable, VirtualMachine, VirtualMachineBuilder}; 5 | use tvm::runtime::{Context, Module}; 6 | use tvm::{DataType, NDArray}; 7 | 8 | use ndarray::{Array, Array0, Array1, ArrayD}; 9 | 10 | use std::path::Path; 11 | 12 | fn make_vm>(exe: P, code: P) -> VirtualMachine { 13 | let code = std::fs::read(code).expect("Could not read code."); 14 | let lib = Module::load(&exe).expect("Could not read executable module."); 15 | let exe = Executable::new((&code).into(), lib).expect("Could not build executable"); 16 | let ctx = Context::cpu(0); 17 | VirtualMachineBuilder::new(exe) 18 | .context(ctx) 19 | .build() 20 | .expect("Error building vm") 21 | } 22 | 23 | fn make_graph_rt>(factory: P, ctx: Context) -> GraphRt { 24 | let lib = Module::load(&factory).expect("Could not load graph factory module."); 25 | let ctxs = vec![ctx]; 26 | GraphRt::create_from_factory(lib, "default", ctxs).expect("Could not create graph runtime.") 27 | } 28 | 29 | fn criterion_benchmark(c: &mut Criterion) { 30 | let code_path = Path::new(env!("CARGO_MANIFEST_DIR")).join("tvm_libs/leapfrog_10.code"); 31 | let lib_path = Path::new(env!("CARGO_MANIFEST_DIR")).join("tvm_libs/leapfrog_10.so"); 32 | 33 | let ctx = Context::cpu(0); 34 | const N: usize = 10; 35 | let dt = DataType::float(32, 1); 36 | 37 | /* 38 | let mut vm = make_vm(&lib_path, &code_path); 39 | 40 | 41 | let pos: ArrayD = Array::ones((N,)).into_dyn(); 42 | let diag_mass: ArrayD = Array::ones((N,)).into_dyn(); 43 | let momentum: ArrayD = Array::ones((N,)).into_dyn(); 44 | let epsilon: ArrayD = Array::ones(()).into_dyn(); 45 | let grad: ArrayD = Array::ones((N,)).into_dyn(); 46 | 47 | 48 | let pos 
= NDArray::from_rust_ndarray(&pos, ctx, dt).unwrap(); 49 | let diag_mass = NDArray::from_rust_ndarray(&diag_mass, ctx, dt).unwrap(); 50 | let momentum = NDArray::from_rust_ndarray(&momentum, ctx, dt).unwrap(); 51 | let epsilon = NDArray::from_rust_ndarray(&epsilon, ctx, dt).unwrap(); 52 | let grad = NDArray::from_rust_ndarray(&grad, ctx, dt).unwrap(); 53 | 54 | let func = "main"; 55 | 56 | vm.set_input(func, vec![pos.into(), momentum.into(), grad.into(), epsilon.into(), diag_mass.into()]).unwrap(); 57 | 58 | c.bench_function("run leafprog 10", |b| b.iter(|| { 59 | vm.invoke(&func).unwrap(); 60 | })); 61 | 62 | */ 63 | 64 | let lib_path = Path::new(env!("CARGO_MANIFEST_DIR")).join("tvm_libs/leapfrog_10.so"); 65 | let mut graph_rt = make_graph_rt(&lib_path, ctx); 66 | 67 | let pos: ArrayD = Array::ones((N,)).into_dyn(); 68 | let diag_mass: ArrayD = Array::ones((N,)).into_dyn(); 69 | let momentum: ArrayD = Array::ones((N,)).into_dyn(); 70 | let epsilon: ArrayD = Array::ones(()).into_dyn(); 71 | let grad: ArrayD = Array::ones((N,)).into_dyn(); 72 | 73 | let pos = NDArray::from_rust_ndarray(&pos, ctx, dt).unwrap(); 74 | let diag_mass = NDArray::from_rust_ndarray(&diag_mass, ctx, dt).unwrap(); 75 | let momentum = NDArray::from_rust_ndarray(&momentum, ctx, dt).unwrap(); 76 | let epsilon = NDArray::from_rust_ndarray(&epsilon, ctx, dt).unwrap(); 77 | let grad = NDArray::from_rust_ndarray(&grad, ctx, dt).unwrap(); 78 | 79 | c.bench_function("set graph input 10", |b| { 80 | b.iter(|| { 81 | graph_rt.set_input("position_in", &pos).unwrap(); 82 | graph_rt.set_input("momentum_in", &momentum).unwrap(); 83 | graph_rt.set_input("grad_in", &grad).unwrap(); 84 | graph_rt.set_input("epsilon", &epsilon).unwrap(); 85 | graph_rt.set_input("mass_diag", &diag_mass).unwrap(); 86 | }) 87 | }); 88 | 89 | c.bench_function("set graph input 10 idx", |b| { 90 | b.iter(|| { 91 | graph_rt.set_input_idx(0, &pos).unwrap(); 92 | graph_rt.set_input_idx(1, &momentum).unwrap(); 93 | 
graph_rt.set_input_idx(2, &grad).unwrap(); 94 | graph_rt.set_input_idx(3, &epsilon).unwrap(); 95 | graph_rt.set_input_idx(4, &diag_mass).unwrap(); 96 | }) 97 | }); 98 | 99 | c.bench_function("set graph input 10 idx nocopy", |b| { 100 | b.iter(|| { 101 | graph_rt.set_input_zero_copy_idx(0, &pos).unwrap(); 102 | graph_rt.set_input_zero_copy_idx(1, &momentum).unwrap(); 103 | graph_rt.set_input_zero_copy_idx(2, &grad).unwrap(); 104 | graph_rt.set_input_zero_copy_idx(3, &epsilon).unwrap(); 105 | graph_rt.set_input_zero_copy_idx(4, &diag_mass).unwrap(); 106 | }) 107 | }); 108 | 109 | c.bench_function("run graph 10", |b| { 110 | b.iter(|| { 111 | graph_rt.run().unwrap(); 112 | }) 113 | }); 114 | 115 | let out0_: ArrayD = Array::ones((N,)).into_dyn(); 116 | let out0 = NDArray::from_rust_ndarray(&out0_, ctx, dt).unwrap(); 117 | c.bench_function("run all 10", |b| { 118 | b.iter(|| { 119 | graph_rt.set_input_zero_copy_idx(0, &pos).unwrap(); 120 | graph_rt.set_input_zero_copy_idx(1, &momentum).unwrap(); 121 | graph_rt.set_input_zero_copy_idx(2, &grad).unwrap(); 122 | graph_rt.set_input_zero_copy_idx(3, &epsilon).unwrap(); 123 | graph_rt.set_input_zero_copy_idx(4, &diag_mass).unwrap(); 124 | graph_rt.run().unwrap(); 125 | graph_rt.get_output_into(2, out0.clone()).unwrap(); 126 | }) 127 | }); 128 | } 129 | 130 | criterion_group!(benches, criterion_benchmark); 131 | criterion_main!(benches); 132 | -------------------------------------------------------------------------------- /cliff.toml: -------------------------------------------------------------------------------- 1 | # git-cliff ~ default configuration file 2 | # https://git-cliff.org/docs/configuration 3 | # 4 | # Lines starting with "#" are comments. 5 | # Configuration options are organized into tables and keys. 6 | # See documentation for more information on available options. 
7 | 8 | [changelog] 9 | # changelog header 10 | header = """ 11 | # Changelog\n 12 | All notable changes to this project will be documented in this file.\n 13 | """ 14 | # template for the changelog body 15 | # https://tera.netlify.app/docs 16 | body = """ 17 | {% if version %}\ 18 | ## [{{ version | trim_start_matches(pat="v") }}] - {{ timestamp | date(format="%Y-%m-%d") }} 19 | {% else %}\ 20 | ## [unreleased] 21 | {% endif %}\ 22 | {% for group, commits in commits | group_by(attribute="group") %} 23 | ### {{ group | upper_first }} 24 | {% for commit in commits %} 25 | - {% if commit.breaking %}[**breaking**] {% endif %}{{ commit.message | upper_first }} \ 26 | ({{ commit.author.name }}) 27 | {% endfor %} 28 | {% endfor %}\n 29 | """ 30 | # remove the leading and trailing whitespace from the template 31 | trim = true 32 | # changelog footer 33 | footer = """ 34 | 35 | """ 36 | 37 | [git] 38 | # parse the commits based on https://www.conventionalcommits.org 39 | conventional_commits = true 40 | # filter out the commits that are not conventional 41 | filter_unconventional = true 42 | # process each line of a commit as an individual commit 43 | split_commits = false 44 | # regex for preprocessing the commit messages 45 | commit_preprocessors = [ 46 | # { pattern = '\((\w+\s)?#([0-9]+)\)', replace = "([#${2}](https://github.com/orhun/git-cliff/issues/${2}))"}, # replace issue numbers 47 | ] 48 | commit_parsers = [ 49 | { message = "^feat", group = "Features" }, 50 | { message = "^fix", group = "Bug Fixes" }, 51 | { message = "^doc", group = "Documentation" }, 52 | { message = "^perf", group = "Performance" }, 53 | { message = "^refactor", group = "Refactor" }, 54 | { message = "^style", group = "Styling" }, 55 | { message = "^test", group = "Testing" }, 56 | { message = "^chore: Prepare", skip = true }, 57 | { message = "^chore\\(release\\)", skip = true }, 58 | { message = "^chore", group = "Miscellaneous Tasks" }, 59 | { body = ".*security", group = "Security" }, 
60 | ] # regex for parsing and grouping commits 61 | # protect breaking changes from being skipped due to matching a skipping commit_parser 62 | protect_breaking_commits = false 63 | # filter out the commits that are not matched by commit parsers 64 | filter_commits = false 65 | # glob pattern for matching git tags 66 | tag_pattern = "v[0-9]*" 67 | # regex for skipping tags 68 | skip_tags = "v0.1.0-beta.1" 69 | # regex for ignoring tags 70 | ignore_tags = "" 71 | # sort the tags topologically 72 | topo_order = false 73 | # sort the commits inside sections by oldest/newest order 74 | sort_commits = "oldest" 75 | # limit the number of commits included in the changelog. 76 | # limit_commits = 42 77 | -------------------------------------------------------------------------------- /docs/.gitignore: -------------------------------------------------------------------------------- 1 | /.quarto/ 2 | -------------------------------------------------------------------------------- /docs/_freeze/index/execute-results/html.json: -------------------------------------------------------------------------------- 1 | { 2 | "hash": "94e4388705073729b94725a15410d650", 3 | "result": { 4 | "engine": "jupyter", 5 | "markdown": "---\ntitle: Nutpie Documentation\n---\n\n\n\n`nutpie` is a high-performance library designed for Bayesian inference, that\nprovides efficient sampling algorithms for probabilistic models. It can sample\nmodels that are defined in PyMC or Stan (numpyro and custom hand-coded\nlikelihoods with gradient are coming soon).\n\n- Faster sampling than either the PyMC or Stan default samplers. 
(An average\n ~2x speedup on `posteriordb` compared to Stan)\n- All the diagnostic information of PyMC and Stan and some more.\n- GPU support for PyMC models through jax.\n- A more informative progress bar.\n- Access to the incomplete trace during sampling.\n- *Experimental* normalizing flow adaptation for more efficient sampling of\n difficult posteriors.\n\n## Quickstart: PyMC\n\nInstall `nutpie` with pip, uv, pixi, or conda:\n\nFor usage with pymc:\n\n```bash\n# One of\npip install \"nutpie[pymc]\"\nuv add \"nutpie[pymc]\"\npixi add nutpie pymc numba\nconda install -c conda-forge nutpie pymc numba\n```\n\nAnd then sample with\n\n\n::: {#1c2d97ba .cell execution_count=1}\n``` {.python .cell-code}\nimport nutpie\nimport pymc as pm\n\nwith pm.Model() as model:\n mu = pm.Normal(\"mu\", mu=0, sigma=1)\n obs = pm.Normal(\"obs\", mu=mu, sigma=1, observed=[1, 2, 3])\n\ncompiled = nutpie.compile_pymc_model(model)\ntrace = nutpie.sample(compiled)\n```\n\n::: {.cell-output .cell-output-display}\n```{=html}\n\n\n```\n:::\n\n::: {.cell-output .cell-output-display}\n```{=html}\n\n
\n

Sampler Progress

\n

Total Chains: 6

\n

Active Chains: 0

\n

\n Finished Chains:\n 6\n

\n

Sampling for now

\n

\n Estimated Time to Completion:\n now\n

\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
ProgressDrawsDivergencesStep SizeGradients/Draw
\n \n \n 140001.351
\n \n \n 140001.283
\n \n \n 140001.293
\n \n \n 140001.233
\n \n \n 140001.403
\n \n \n 140001.281
\n
\n```\n:::\n:::\n\n\nFor more information, see the detailed [PyMC usage guide](pymc-usage.qmd).\n\n## Quickstart: Stan\n\nStan needs access to a compiler toolchain, you can find instructions for those\n[here](https://mc-stan.org/docs/cmdstan-guide/installation.html#cpp-toolchain).\nYou can then install nutpie through pip or uv:\n\n```bash\n# One of\npip install \"nutpie[stan]\"\nuv add \"nutpie[stan]\"\n```\n\n\n\n::: {#700ed270 .cell execution_count=3}\n``` {.python .cell-code}\nimport nutpie\n\nmodel = \"\"\"\ndata {\n int N;\n vector[N] y;\n}\nparameters {\n real mu;\n}\nmodel {\n mu ~ normal(0, 1);\n y ~ normal(mu, 1);\n}\n\"\"\"\n\ncompiled = (\n nutpie\n .compile_stan_model(code=model)\n .with_data(N=3, y=[1, 2, 3])\n)\ntrace = nutpie.sample(compiled)\n```\n\n::: {.cell-output .cell-output-display}\n```{=html}\n\n\n```\n:::\n\n::: {.cell-output .cell-output-display}\n```{=html}\n\n
\n

Sampler Progress

\n

Total Chains: 6

\n

Active Chains: 0

\n

\n Finished Chains:\n 6\n

\n

Sampling for now

\n

\n Estimated Time to Completion:\n now\n

\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
ProgressDrawsDivergencesStep SizeGradients/Draw
\n \n \n 140001.291
\n \n \n 140001.273
\n \n \n 140001.343
\n \n \n 140001.331
\n \n \n 140001.413
\n \n \n 140001.293
\n
\n```\n:::\n:::\n\n\nFor more information, see the detailed [Stan usage guide](stan-usage.qmd).\n\n", 6 | "supporting": [ 7 | "index_files" 8 | ], 9 | "filters": [], 10 | "includes": { 11 | "include-in-header": [ 12 | "\n\n\n" 13 | ] 14 | } 15 | } 16 | } -------------------------------------------------------------------------------- /docs/_freeze/nf-adapt/execute-results/html.json: -------------------------------------------------------------------------------- 1 | { 2 | "hash": "99a8749bdb41e64f77fd32de347bbc2a", 3 | "result": { 4 | "engine": "jupyter", 5 | "markdown": "---\ntitle: Adaptation with Normalizing Flows\n---\n\n\n\n**Experimental and subject to change**\n\nNormalizing flow adaptation through Fisher HMC is a new sampling algorithm that\nautomatically reparameterizes a model. It adds some computational cost outside\nmodel log-density evaluations, but allows sampling from much more difficult\nposterior distributions. For models with expensive log-density evaluations, the\nnormalizing flow adaptation can also be much faster, if it can reduce the number\nof log-density evaluations needed to reach a given effective sample size.\n\nThe normalizing flow adaptation works by learning a transformation of the parameter\nspace that makes the posterior distribution more amenable to sampling. This is done\nby fitting a sequence of invertible transformations (the \"flow\") that maps the\noriginal parameter space to a space where the posterior is closer to a standard\nnormal distribution. 
The flow is trained during warmup.\n\nFor more information about the algorithm, see the (still work in progress) paper\n[If only my posterior were normal: Introducing Fisher\nHMC](https://github.com/aseyboldt/covadapt-paper/releases/download/latest/main.pdf).\n\nCurrently, a lot of time is spent on compiling various parts of the normalizing\nflow, and for small models this can take a large amount of the total time.\nHopefully, we will be able to reduce this overhead in the future.\n\n## Requirements\n\nInstall the optional dependencies for normalizing flow adaptation:\n\n```\npip install 'nutpie[nnflow]'\n```\n\nIf you use with PyMC, this will only work if the model is compiled using the jax\nbackend, and if the `gradient_backend` is also set to `jax`.\n\nTraining of the normalizing flow can often be accelerated by using a GPU (even\nif the model itself is written in Stan, without any GPU support). To enable GPU\nyou need to make sure your `jax` installation comes with GPU support, for\ninstance by installing it with `pip install 'jax[cuda12]'`, or selecting the\n`jaxlib` version with GPU support, if you are using conda-forge. You can check if\nyour installation has GPU support by checking the output of:\n\n```python\nimport jax\njax.devices()\n```\n\n### Usage\n\nTo use normalizing flow adaptation in `nutpie`, you need to enable the\n`transform_adapt` option during sampling. 
Here is an example of how we can use\nit to sample from a difficult posterior:\n\n\n::: {#1e499251 .cell execution_count=1}\n``` {.python .cell-code}\nimport pymc as pm\nimport nutpie\nimport numpy as np\nimport arviz\n\n# Define a 100-dimensional funnel model\nwith pm.Model() as model:\n log_sigma = pm.Normal(\"log_sigma\")\n pm.Normal(\"x\", mu=0, sigma=pm.math.exp(log_sigma / 2), shape=100)\n\n# Compile the model with the jax backend\ncompiled = nutpie.compile_pymc_model(\n model, backend=\"jax\", gradient_backend=\"jax\"\n)\n```\n:::\n\n\nIf we sample this model without normalizing flow adaptation, we will encounter\nconvergence issues, often divergences and always low effective sample sizes:\n\n::: {#f7faabf0 .cell execution_count=2}\n``` {.python .cell-code}\n# Sample without normalizing flow adaptation\ntrace_no_nf = nutpie.sample(compiled, seed=1)\nassert (arviz.ess(trace_no_nf) < 100).any().to_array().any()\n```\n\n::: {.cell-output .cell-output-display}\n```{=html}\n\n\n```\n:::\n\n::: {.cell-output .cell-output-display}\n```{=html}\n\n
\n

Sampler Progress

\n

Total Chains: 6

\n

Active Chains: 0

\n

\n Finished Chains:\n 6\n

\n

Sampling for 16 seconds

\n

\n Estimated Time to Completion:\n now\n

\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
ProgressDrawsDivergencesStep SizeGradients/Draw
\n \n \n 140000.457
\n \n \n 140000.3115
\n \n \n 140000.317
\n \n \n 140000.287
\n \n \n 140000.3915
\n \n \n 140000.347
\n
\n```\n:::\n:::\n\n\n::: {#6cfb99bd .cell execution_count=3}\n``` {.python .cell-code}\n# We can add further arguments for the normalizing flow:\ncompiled = compiled.with_transform_adapt(\n num_layers=5, # Number of layers in the normalizing flow\n nn_width=32, # Neural networks with 32 hidden units\n num_diag_windows=6, # Number of windows with a diagonal mass matrix intead of a flow\n verbose=False, # Whether to print details about the adaptation process\n show_progress=False, # Whether to show a progress bar for each optimization step\n)\n\n# Sample with normalizing flow adaptation\ntrace_nf = nutpie.sample(\n compiled,\n transform_adapt=True, # Enable the normalizing flow adaptation\n seed=1,\n chains=2,\n cores=1, # Running chains in parallel can be slow\n window_switch_freq=150, # Optimize the normalizing flow every 150 iterations\n)\nassert trace_nf.sample_stats.diverging.sum() == 0\nassert (arviz.ess(trace_nf) > 1000).all().to_array().all()\n```\n\n::: {.cell-output .cell-output-display}\n```{=html}\n\n\n```\n:::\n\n::: {.cell-output .cell-output-display}\n```{=html}\n\n
\n

Sampler Progress

\n

Total Chains: 2

\n

Active Chains: 0

\n

\n Finished Chains:\n 2\n

\n

Sampling for 18 minutes

\n

\n Estimated Time to Completion:\n now\n

\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
ProgressDrawsDivergencesStep SizeGradients/Draw
\n \n \n 250000.527
\n \n \n 250000.537
\n
\n```\n:::\n:::\n\n\nThe sampler used fewer gradient evaluations with the normalizing flow adaptation,\nbut still converged, and produce a good effective sample size:\n\n::: {#78aaecea .cell execution_count=4}\n``` {.python .cell-code}\nn_steps = int(trace_nf.sample_stats.n_steps.sum() + trace_nf.warmup_sample_stats.n_steps.sum())\ness = float(arviz.ess(trace_nf).min().to_array().min())\nprint(f\"Number of gradient evaluations: {n_steps}\")\nprint(f\"Minimum effective sample size: {ess}\")\n```\n\n::: {.cell-output .cell-output-stdout}\n```\nNumber of gradient evaluations: 42527\nMinimum effective sample size: 1835.9674640023168\n```\n:::\n:::\n\n\nWithout normalizing flow, it used more gradient evaluations, and still wasn't able\nto get a good effective sample size:\n\n::: {#820fea9f .cell execution_count=5}\n``` {.python .cell-code}\nn_steps = int(trace_no_nf.sample_stats.n_steps.sum() + trace_no_nf.warmup_sample_stats.n_steps.sum())\ness = float(arviz.ess(trace_no_nf).min().to_array().min())\nprint(f\"Number of gradient evaluations: {n_steps}\")\nprint(f\"Minimum effective sample size: {ess}\")\n```\n\n::: {.cell-output .cell-output-stdout}\n```\nNumber of gradient evaluations: 124219\nMinimum effective sample size: 31.459420094540565\n```\n:::\n:::\n\n\nThe flow adaptation occurs during warmup, so the number of warmup draws should\nbe large enough to allow the flow to converge. For more complex posteriors, you\nmay need to increase the number of layers (using the `num_layers` argument), or\nyou might want to increase the number of warmup draws.\n\nTo monitor the progress of the flow adaptation, you can set `verbose=True`, or\n`show_progress=True`, but the second should only be used if you sample just one\nchain.\n\nAll losses are on a log-scale. Negative values smaller -2 are a good sign that\nthe adaptation was successful. If the loss stays positive, the flow is either\nnot expressive enough, or the training period is too short. 
The sampler might\nstill converge, but will probably need more gradient evaluations per effective\ndraw. Large losses bigger than 6 tend to indicate that the posterior is too\ndifficult to sample with the current flow, and the sampler will probably not\nconverge.\n\n", 6 | "supporting": [ 7 | "nf-adapt_files/figure-html" 8 | ], 9 | "filters": [], 10 | "includes": { 11 | "include-in-header": [ 12 | "\n\n\n" 13 | ] 14 | } 15 | } 16 | } -------------------------------------------------------------------------------- /docs/_freeze/pymc-usage/figure-html/cell-5-output-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pymc-devs/nutpie/6cbc79f0d984d460f224123f20e8cfcc531f8e6e/docs/_freeze/pymc-usage/figure-html/cell-5-output-1.png -------------------------------------------------------------------------------- /docs/_freeze/sample-stats/figure-html/cell-3-output-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pymc-devs/nutpie/6cbc79f0d984d460f224123f20e8cfcc531f8e6e/docs/_freeze/sample-stats/figure-html/cell-3-output-1.png -------------------------------------------------------------------------------- /docs/_freeze/sample-stats/figure-html/cell-4-output-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pymc-devs/nutpie/6cbc79f0d984d460f224123f20e8cfcc531f8e6e/docs/_freeze/sample-stats/figure-html/cell-4-output-1.png -------------------------------------------------------------------------------- /docs/_freeze/sample-stats/figure-html/cell-5-output-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pymc-devs/nutpie/6cbc79f0d984d460f224123f20e8cfcc531f8e6e/docs/_freeze/sample-stats/figure-html/cell-5-output-1.png 
-------------------------------------------------------------------------------- /docs/_freeze/sample-stats/figure-html/cell-7-output-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pymc-devs/nutpie/6cbc79f0d984d460f224123f20e8cfcc531f8e6e/docs/_freeze/sample-stats/figure-html/cell-7-output-2.png -------------------------------------------------------------------------------- /docs/_freeze/sample-stats/figure-html/cell-8-output-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pymc-devs/nutpie/6cbc79f0d984d460f224123f20e8cfcc531f8e6e/docs/_freeze/sample-stats/figure-html/cell-8-output-1.png -------------------------------------------------------------------------------- /docs/_freeze/sample-stats/figure-html/cell-9-output-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pymc-devs/nutpie/6cbc79f0d984d460f224123f20e8cfcc531f8e6e/docs/_freeze/sample-stats/figure-html/cell-9-output-1.png -------------------------------------------------------------------------------- /docs/_freeze/site_libs/clipboard/clipboard.min.js: -------------------------------------------------------------------------------- 1 | /*! 
2 | * clipboard.js v2.0.11 3 | * https://clipboardjs.com/ 4 | * 5 | * Licensed MIT © Zeno Rocha 6 | */ 7 | !function(t,e){"object"==typeof exports&&"object"==typeof module?module.exports=e():"function"==typeof define&&define.amd?define([],e):"object"==typeof exports?exports.ClipboardJS=e():t.ClipboardJS=e()}(this,function(){return n={686:function(t,e,n){"use strict";n.d(e,{default:function(){return b}});var e=n(279),i=n.n(e),e=n(370),u=n.n(e),e=n(817),r=n.n(e);function c(t){try{return document.execCommand(t)}catch(t){return}}var a=function(t){t=r()(t);return c("cut"),t};function o(t,e){var n,o,t=(n=t,o="rtl"===document.documentElement.getAttribute("dir"),(t=document.createElement("textarea")).style.fontSize="12pt",t.style.border="0",t.style.padding="0",t.style.margin="0",t.style.position="absolute",t.style[o?"right":"left"]="-9999px",o=window.pageYOffset||document.documentElement.scrollTop,t.style.top="".concat(o,"px"),t.setAttribute("readonly",""),t.value=n,t);return e.container.appendChild(t),e=r()(t),c("copy"),t.remove(),e}var f=function(t){var e=1 N;\n vector[N] y;\n}\nparameters {\n real mu;\n}\nmodel {\n mu ~ normal(0, 1);\n y ~ normal(mu, 1);\n}\n\"\"\"\n\ncompiled_model = nutpie.compile_stan_model(code=model_code)\n```\n:::\n\n\n### Sampling\n\nWe can now compile the model and sample from it:\n\n::: {#60b965bf .cell execution_count=3}\n``` {.python .cell-code}\ncompiled_model_with_data = compiled_model.with_data(N=3, y=[1, 2, 3])\ntrace = nutpie.sample(compiled_model_with_data)\n```\n\n::: {.cell-output .cell-output-display}\n```{=html}\n\n\n```\n:::\n\n::: {.cell-output .cell-output-display}\n```{=html}\n\n
\n

Sampler Progress

\n

Total Chains: 6

\n

Active Chains: 0

\n

\n Finished Chains:\n 6\n

\n

Sampling for now

\n

\n Estimated Time to Completion:\n now\n

\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
ProgressDrawsDivergencesStep SizeGradients/Draw
\n \n \n 140001.333
\n \n \n 140001.391
\n \n \n 140001.373
\n \n \n 140001.381
\n \n \n 140001.353
\n \n \n 140001.333
\n
\n```\n:::\n:::\n\n\n### Using Dimensions\n\nWe'll use the radon model from\n[this](https://mc-stan.org/learn-stan/case-studies/radon_cmdstanpy_plotnine.html)\ncase-study from the stan documentation, to show how we can use coordinates and\ndimension names to simplify working with trace objects.\n\nWe follow the same data preparation as in the case-study:\n\n::: {#92d854e3 .cell execution_count=4}\n``` {.python .cell-code}\nimport pandas as pd\nimport numpy as np\nimport arviz as az\nimport seaborn as sns\n\nhome_data = pd.read_csv(\n \"https://github.com/pymc-devs/pymc-examples/raw/refs/heads/main/examples/data/srrs2.dat\",\n index_col=\"idnum\",\n)\ncounty_data = pd.read_csv(\n \"https://github.com/pymc-devs/pymc-examples/raw/refs/heads/main/examples/data/cty.dat\",\n)\n\nradon_data = (\n home_data\n .rename(columns=dict(cntyfips=\"ctfips\"))\n .merge(\n (\n county_data\n .drop_duplicates(['stfips', 'ctfips', 'st', 'cty', 'Uppm'])\n .set_index([\"ctfips\", \"stfips\"])\n ),\n right_index=True,\n left_on=[\"ctfips\", \"stfips\"],\n )\n .assign(log_radon=lambda x: np.log(np.clip(x.activity, 0.1, np.inf)))\n .assign(log_uranium=lambda x: np.log(np.clip(x[\"Uppm\"], 0.1, np.inf)))\n .query(\"state == 'MN'\")\n)\n```\n:::\n\n\nAnd also use the partially pooled model from the case-study:\n\n::: {#ce581edd .cell execution_count=5}\n``` {.python .cell-code}\nmodel_code = \"\"\"\ndata {\n int N; // observations\n int J; // counties\n array[N] int county;\n vector[N] x;\n vector[N] y;\n}\nparameters {\n real mu_alpha;\n real sigma_alpha;\n vector[J] alpha; // non-centered parameterization\n real beta;\n real sigma;\n}\nmodel {\n y ~ normal(alpha[county] + beta * x, sigma);\n alpha ~ normal(mu_alpha, sigma_alpha); // partial-pooling\n beta ~ normal(0, 10);\n sigma ~ normal(0, 10);\n mu_alpha ~ normal(0, 10);\n sigma_alpha ~ normal(0, 10);\n}\ngenerated quantities {\n array[N] real y_rep = normal_rng(alpha[county] + beta * x, sigma);\n}\n\"\"\"\n```\n:::\n\n\nWe collect the 
dataset in the format that the stan model requires,\nand specify the dimensions of each of the non-scalar variables in the model:\n\n::: {#9a29bf02 .cell execution_count=6}\n``` {.python .cell-code}\ncounty_idx, counties = pd.factorize(radon_data[\"county\"], use_na_sentinel=False)\nobservations = radon_data.index\n\ncoords = {\n \"county\": counties,\n \"observation\": observations,\n}\n\ndims = {\n \"alpha\": [\"county\"],\n \"y_rep\": [\"observation\"],\n}\n\ndata = {\n \"N\": len(observations),\n \"J\": len(counties),\n # Stan uses 1-based indexing!\n \"county\": county_idx + 1,\n \"x\": radon_data.log_uranium.values,\n \"y\": radon_data.log_radon.values,\n}\n```\n:::\n\n\nThen, we compile the model and provide the dimensions, coordinates and the\ndataset we just defined:\n\n::: {#fe0286f3 .cell execution_count=7}\n``` {.python .cell-code}\ncompiled_model = (\n nutpie.compile_stan_model(code=model_code)\n .with_data(**data)\n .with_dims(**dims)\n .with_coords(**coords)\n)\n```\n:::\n\n\n::: {#7a704cbf .cell execution_count=8}\n``` {.python .cell-code}\n%%time\ntrace = nutpie.sample(compiled_model, seed=0)\n```\n\n::: {.cell-output .cell-output-display}\n```{=html}\n\n\n```\n:::\n\n::: {.cell-output .cell-output-display}\n```{=html}\n\n
\n

Sampler Progress

\n

Total Chains: 6

\n

Active Chains: 0

\n

\n Finished Chains:\n 6\n

\n

Sampling for now

\n

\n Estimated Time to Completion:\n now\n

\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
ProgressDrawsDivergencesStep SizeGradients/Draw
\n \n \n 140000.3931
\n \n \n 140000.477
\n \n \n 140000.457
\n \n \n 140000.467
\n \n \n 140000.457
\n \n \n 140000.457
\n
\n```\n:::\n\n::: {.cell-output .cell-output-stdout}\n```\nCPU times: user 2.27 s, sys: 39.2 ms, total: 2.31 s\nWall time: 547 ms\n```\n:::\n:::\n\n\nAs some basic convergence checking we verify that all Rhat values are smaller\nthan 1.02, all parameters have at least 500 effective draws and that we have no\ndivergences:\n\n::: {#013fe62f .cell execution_count=9}\n``` {.python .cell-code}\nassert trace.sample_stats.diverging.sum() == 0\nassert az.ess(trace).min().min() > 500\nassert az.rhat(trace).max().max() < 1.02\n```\n:::\n\n\nThanks to the coordinates and dimensions we specified, the resulting trace will\nnow contain labeled data, so that plots based on it have properly set-up labels:\n\n::: {#34452909 .cell execution_count=10}\n``` {.python .cell-code}\nimport arviz as az\nimport seaborn as sns\nimport xarray as xr\n\nsns.catplot(\n data=trace.posterior.alpha.to_dataframe().reset_index(),\n y=\"county\",\n x=\"alpha\",\n kind=\"boxen\",\n height=13,\n aspect=1/2.5,\n showfliers=False,\n)\n```\n\n::: {.cell-output .cell-output-display}\n![](stan-usage_files/figure-html/cell-11-output-1.png){}\n:::\n:::\n\n\n", 6 | "supporting": [ 7 | "stan-usage_files" 8 | ], 9 | "filters": [], 10 | "includes": { 11 | "include-in-header": [ 12 | "\n\n\n" 13 | ] 14 | } 15 | } 16 | } -------------------------------------------------------------------------------- /docs/_freeze/stan-usage/figure-html/cell-11-output-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pymc-devs/nutpie/6cbc79f0d984d460f224123f20e8cfcc531f8e6e/docs/_freeze/stan-usage/figure-html/cell-11-output-1.png -------------------------------------------------------------------------------- /docs/_quarto.yml: -------------------------------------------------------------------------------- 1 | project: 2 | type: website 3 | 4 | website: 5 | title: "Nutpie" 6 | navbar: 7 | left: 8 | - href: index.qmd 9 | text: Home 10 | - href: pymc-usage.qmd 11
| text: Usage with PyMC 12 | - href: stan-usage.qmd 13 | text: Usage with Stan 14 | - href: sampling-options.qmd 15 | text: Sampling Options 16 | - href: nf-adapt.qmd 17 | text: Normalizing flow adaptation 18 | - href: sample-stats.qmd 19 | text: Diagnostic information 20 | - about.qmd 21 | tools: 22 | - icon: github 23 | href: https://github.com/pymc-devs/nutpie 24 | 25 | format: 26 | html: 27 | theme: 28 | - cosmo 29 | - brand 30 | css: styles.css 31 | toc: true 32 | 33 | execute: 34 | freeze: auto 35 | -------------------------------------------------------------------------------- /docs/about.qmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: "About" 3 | --- 4 | 5 | Nutpie is part of the PyMC organization. The PyMC organization develops and 6 | maintains tools for Bayesian statistical modeling and probabilistic machine 7 | learning. 8 | 9 | Nutpie provides a high-performance implementation of the No-U-Turn Sampler 10 | (NUTS) that can be used with models defined in PyMC, Stan and other frameworks. 11 | It was created to enable faster and more efficient Bayesian inference while 12 | maintaining compatibility with existing probabilistic programming tools. 13 | 14 | For more information about the PyMC organization, visit the following links: 15 | 16 | - [PyMC Website](https://www.pymc.io) 17 | - [PyMC GitHub Organization](https://github.com/pymc-devs) 18 | -------------------------------------------------------------------------------- /docs/index.qmd: -------------------------------------------------------------------------------- 1 | # Nutpie Documentation 2 | 3 | `nutpie` is a high-performance library designed for Bayesian inference, that 4 | provides efficient sampling algorithms for probabilistic models. It can sample 5 | models that are defined in PyMC or Stan (numpyro and custom hand-coded 6 | likelihoods with gradient are coming soon). 
7 | 8 | - Faster sampling than either the PyMC or Stan default samplers. (An average 9 | ~2x speedup on `posteriordb` compared to Stan) 10 | - All the diagnostic information of PyMC and Stan and some more. 11 | - GPU support for PyMC models through jax. 12 | - A more informative progress bar. 13 | - Access to the incomplete trace during sampling. 14 | - *Experimental* normalizing flow adaptation for more efficient sampling of 15 | difficult posteriors. 16 | 17 | ## Quickstart: PyMC 18 | 19 | Install `nutpie` with pip, uv, pixi, or conda: 20 | 21 | For usage with pymc: 22 | 23 | ```bash 24 | # One of 25 | pip install "nutpie[pymc]" 26 | uv add "nutpie[pymc]" 27 | pixi add nutpie pymc numba 28 | conda install -c conda-forge nutpie pymc numba 29 | ``` 30 | 31 | And then sample with 32 | 33 | ```{python} 34 | import nutpie 35 | import pymc as pm 36 | 37 | with pm.Model() as model: 38 | mu = pm.Normal("mu", mu=0, sigma=1) 39 | obs = pm.Normal("obs", mu=mu, sigma=1, observed=[1, 2, 3]) 40 | 41 | compiled = nutpie.compile_pymc_model(model) 42 | trace = nutpie.sample(compiled) 43 | ``` 44 | 45 | For more information, see the detailed [PyMC usage guide](pymc-usage.qmd). 46 | 47 | ## Quickstart: Stan 48 | 49 | Stan needs access to a compiler toolchain, you can find instructions for those 50 | [here](https://mc-stan.org/docs/cmdstan-guide/installation.html#cpp-toolchain). 
51 | You can then install nutpie through pip or uv: 52 | 53 | ```bash 54 | # One of 55 | pip install "nutpie[stan]" 56 | uv add "nutpie[stan]" 57 | ``` 58 | 59 | ```{python} 60 | #| echo: false 61 | import os 62 | os.environ["TBB_CXX_TYPE"] = "clang" 63 | ``` 64 | 65 | ```{python} 66 | import nutpie 67 | 68 | model = """ 69 | data { 70 | int N; 71 | vector[N] y; 72 | } 73 | parameters { 74 | real mu; 75 | } 76 | model { 77 | mu ~ normal(0, 1); 78 | y ~ normal(mu, 1); 79 | } 80 | """ 81 | 82 | compiled = ( 83 | nutpie 84 | .compile_stan_model(code=model) 85 | .with_data(N=3, y=[1, 2, 3]) 86 | ) 87 | trace = nutpie.sample(compiled) 88 | ``` 89 | 90 | For more information, see the detailed [Stan usage guide](stan-usage.qmd). 91 | -------------------------------------------------------------------------------- /docs/nf-adapt.qmd: -------------------------------------------------------------------------------- 1 | # Adaptation with Normalizing Flows 2 | 3 | **Experimental and subject to change** 4 | 5 | Normalizing flow adaptation through Fisher HMC is a new sampling algorithm that 6 | automatically reparameterizes a model. It adds some computational cost outside 7 | model log-density evaluations, but allows sampling from much more difficult 8 | posterior distributions. For models with expensive log-density evaluations, the 9 | normalizing flow adaptation can also be much faster, if it can reduce the number 10 | of log-density evaluations needed to reach a given effective sample size. 11 | 12 | The normalizing flow adaptation works by learning a transformation of the parameter 13 | space that makes the posterior distribution more amenable to sampling. This is done 14 | by fitting a sequence of invertible transformations (the "flow") that maps the 15 | original parameter space to a space where the posterior is closer to a standard 16 | normal distribution. The flow is trained during warmup. 
17 | 18 | For more information about the algorithm, see the (still work in progress) paper 19 | [If only my posterior were normal: Introducing Fisher 20 | HMC](https://github.com/aseyboldt/covadapt-paper/releases/download/latest/main.pdf). 21 | 22 | Currently, a lot of time is spent on compiling various parts of the normalizing 23 | flow, and for small models this can take a large amount of the total time. 24 | Hopefully, we will be able to reduce this overhead in the future. 25 | 26 | ## Requirements 27 | 28 | Install the optional dependencies for normalizing flow adaptation: 29 | 30 | ``` 31 | pip install 'nutpie[nnflow]' 32 | ``` 33 | 34 | If you use with PyMC, this will only work if the model is compiled using the jax 35 | backend, and if the `gradient_backend` is also set to `jax`. 36 | 37 | Training of the normalizing flow can often be accelerated by using a GPU (even 38 | if the model itself is written in Stan, without any GPU support). To enable GPU 39 | you need to make sure your `jax` installation comes with GPU support, for 40 | instance by installing it with `pip install 'jax[cuda12]'`, or selecting the 41 | `jaxlib` version with GPU support, if you are using conda-forge. You can check if 42 | your installation has GPU support by checking the output of: 43 | 44 | ```python 45 | import jax 46 | jax.devices() 47 | ``` 48 | 49 | ### Usage 50 | 51 | To use normalizing flow adaptation in `nutpie`, you need to enable the 52 | `transform_adapt` option during sampling. 
Here is an example of how we can use 53 | it to sample from a difficult posterior: 54 | 55 | ```{python} 56 | import pymc as pm 57 | import nutpie 58 | import numpy as np 59 | import arviz 60 | 61 | # Define a 100-dimensional funnel model 62 | with pm.Model() as model: 63 | log_sigma = pm.Normal("log_sigma") 64 | pm.Normal("x", mu=0, sigma=pm.math.exp(log_sigma / 2), shape=100) 65 | 66 | # Compile the model with the jax backend 67 | compiled = nutpie.compile_pymc_model( 68 | model, backend="jax", gradient_backend="jax" 69 | ) 70 | ``` 71 | 72 | If we sample this model without normalizing flow adaptation, we will encounter 73 | convergence issues, often divergences and always low effective sample sizes: 74 | 75 | ```{python} 76 | # Sample without normalizing flow adaptation 77 | trace_no_nf = nutpie.sample(compiled, seed=1) 78 | assert (arviz.ess(trace_no_nf) < 100).any().to_array().any() 79 | ``` 80 | 81 | ```{python} 82 | # We can add further arguments for the normalizing flow: 83 | compiled = compiled.with_transform_adapt( 84 | num_layers=5, # Number of layers in the normalizing flow 85 | nn_width=32, # Neural networks with 32 hidden units 86 | num_diag_windows=6, # Number of windows with a diagonal mass matrix instead of a flow 87 | verbose=False, # Whether to print details about the adaptation process 88 | show_progress=False, # Whether to show a progress bar for each optimization step 89 | ) 90 | 91 | # Sample with normalizing flow adaptation 92 | trace_nf = nutpie.sample( 93 | compiled, 94 | transform_adapt=True, # Enable the normalizing flow adaptation 95 | seed=1, 96 | chains=2, 97 | cores=1, # Running chains in parallel can be slow 98 | window_switch_freq=150, # Optimize the normalizing flow every 150 iterations 99 | ) 100 | assert trace_nf.sample_stats.diverging.sum() == 0 101 | assert (arviz.ess(trace_nf) > 1000).all().to_array().all() 102 | ``` 103 | 104 | The sampler used fewer gradient evaluations with the normalizing flow adaptation, 105 | but still 
converged, and produce a good effective sample size: 106 | 107 | ```{python} 108 | n_steps = int(trace_nf.sample_stats.n_steps.sum() + trace_nf.warmup_sample_stats.n_steps.sum()) 109 | ess = float(arviz.ess(trace_nf).min().to_array().min()) 110 | print(f"Number of gradient evaluations: {n_steps}") 111 | print(f"Minimum effective sample size: {ess}") 112 | ``` 113 | 114 | Without normalizing flow, it used more gradient evaluations, and still wasn't able 115 | to get a good effective sample size: 116 | 117 | ```{python} 118 | n_steps = int(trace_no_nf.sample_stats.n_steps.sum() + trace_no_nf.warmup_sample_stats.n_steps.sum()) 119 | ess = float(arviz.ess(trace_no_nf).min().to_array().min()) 120 | print(f"Number of gradient evaluations: {n_steps}") 121 | print(f"Minimum effective sample size: {ess}") 122 | ``` 123 | 124 | The flow adaptation occurs during warmup, so the number of warmup draws should 125 | be large enough to allow the flow to converge. For more complex posteriors, you 126 | may need to increase the number of layers (using the `num_layers` argument), or 127 | you might want to increase the number of warmup draws. 128 | 129 | To monitor the progress of the flow adaptation, you can set `verbose=True`, or 130 | `show_progress=True`, but the second should only be used if you sample just one 131 | chain. 132 | 133 | All losses are on a log-scale. Negative values smaller -2 are a good sign that 134 | the adaptation was successful. If the loss stays positive, the flow is either 135 | not expressive enough, or the training period is too short. The sampler might 136 | still converge, but will probably need more gradient evaluations per effective 137 | draw. Large losses bigger than 6 tend to indicate that the posterior is too 138 | difficult to sample with the current flow, and the sampler will probably not 139 | converge. 
140 | -------------------------------------------------------------------------------- /docs/pymc-usage.qmd: -------------------------------------------------------------------------------- 1 | # Usage with PyMC models 2 | 3 | This document shows how to use `nutpie` with PyMC models. We will use the 4 | `pymc` package to define a simple model and sample from it using `nutpie`. 5 | 6 | ## Installation 7 | 8 | The recommended way to install `pymc` is through the `conda` ecosystem. A good 9 | package manager for conda packages is `pixi`. See for the [pixi 10 | documentation](https://pixi.sh) for instructions on how to install it. 11 | 12 | We create a new project for this example: 13 | 14 | ```bash 15 | pixi new pymc-example 16 | ``` 17 | 18 | This will create a new directory `pymc-example` with a `pixi.toml` file, that 19 | you can edit to add meta information. 20 | 21 | We then add the `pymc` and `nutpie` packages to the project: 22 | 23 | ```bash 24 | cd pymc-example 25 | pixi add pymc nutpie arviz 26 | ``` 27 | 28 | You can use Visual Studio Code (VSCode) or JupyterLab to write and run our code. 29 | Both are excellent tools for working with Python and data science projects. 30 | 31 | ### Using VSCode 32 | 33 | 1. Open VSCode. 34 | 2. Open the `pymc-example` directory created earlier. 35 | 3. Create a new file named `model.ipynb`. 36 | 4. Select the pixi kernel to run the code. 37 | 38 | ### Using JupyterLab 39 | 40 | 1. Add jupyter labs to the project by running `pixi add jupyterlab`. 41 | 1. Open JupyterLab by running `pixi run jupyter lab` in your terminal. 42 | 3. Create a new Python notebook. 43 | 44 | ## Defining and Sampling a Simple Model 45 | 46 | We will define a simple Bayesian model using `pymc` and sample from it using 47 | `nutpie`. 
48 | 49 | ### Model Definition 50 | 51 | In your `model.ipynb` file or Jupyter notebook, add the following code: 52 | 53 | ```{python} 54 | import pymc as pm 55 | import nutpie 56 | import pandas as pd 57 | 58 | coords = {"observation": range(3)} 59 | 60 | with pm.Model(coords=coords) as model: 61 | # Prior distributions for the intercept and slope 62 | intercept = pm.Normal("intercept", mu=0, sigma=1) 63 | slope = pm.Normal("slope", mu=0, sigma=1) 64 | 65 | # Likelihood (sampling distribution) of observations 66 | x = [1, 2, 3] 67 | 68 | mu = intercept + slope * x 69 | y = pm.Normal("y", mu=mu, sigma=0.1, observed=[1, 2, 3], dims="observation") 70 | ``` 71 | 72 | ### Sampling 73 | 74 | We can now compile the model using the numba backend: 75 | 76 | ```{python} 77 | compiled = nutpie.compile_pymc_model(model) 78 | trace = nutpie.sample(compiled) 79 | ``` 80 | 81 | Alternatively, we can also sample through the `pymc` API: 82 | 83 | ```python 84 | with model: 85 | trace = pm.sample(nuts_sampler="nutpie") 86 | ``` 87 | 88 | While sampling, nutpie shows a progress bar for each chain. It also includes 89 | information about how each chain is doing: 90 | 91 | - It shows the current number of draws 92 | - The step size of the integrator (very small stepsizes are typically a bad 93 | sign) 94 | - The number of divergences (if there are divergences, that means that nutpie is 95 | probably not sampling the posterior correctly) 96 | - The number of gradient evaluations nutpie uses for each draw. Large numbers 97 | (100 to 1000) are a sign that the parameterization of the model is not ideal, 98 | and the sampler is very inefficient. 99 | 100 | After sampling, this returns an `arviz` InferenceData object that you can use to 101 | analyze the trace. 
102 | 103 | For example, we should check the effective sample size: 104 | 105 | ```{python} 106 | import arviz as az 107 | az.ess(trace) 108 | ``` 109 | 110 | and take a look at a trace plot: 111 | 112 | ```{python} 113 | az.plot_trace(trace); 114 | ``` 115 | 116 | ### Choosing the backend 117 | 118 | Right now, we have been using the numba backend. This is the default backend for 119 | `nutpie`, when sampling from pymc models. It tends to have relatively long 120 | compilation times, but samples small models very efficiently. For larger models 121 | the `jax` backend sometimes outperforms `numba`. 122 | 123 | First, we need to install the `jax` package: 124 | 125 | ```bash 126 | pixi add jax 127 | ``` 128 | 129 | We can select the backend by passing the `backend` argument to the `compile_pymc_model`: 130 | 131 | ```python 132 | compiled_jax = nutpie.compile_pymc_model(model, backend="jax") 133 | trace = nutpie.sample(compiled_jax) 134 | ``` 135 | 136 | Or through the pymc API: 137 | 138 | ```python 139 | with model: 140 | trace = pm.sample( 141 | nuts_sampler="nutpie", 142 | nuts_sampler_kwargs={"backend": "jax"}, 143 | ) 144 | ``` 145 | 146 | If you have an nvidia GPU, you can also use the `jax` backend with the `gpu`. We 147 | will have to install the `jaxlib` package with the `cuda` option 148 | 149 | ```bash 150 | pixi add jaxlib --build 'cuda12' 151 | ``` 152 | 153 | Restart the kernel and check that the GPU is available: 154 | 155 | ```python 156 | import jax 157 | 158 | # Should list the cuda device 159 | jax.devices() 160 | ``` 161 | 162 | Sampling again, should now use the GPU, which you can observe by checking the 163 | GPU usage with `nvidia-smi` or `nvtop`. 164 | 165 | ### Changing the dataset without recompilation 166 | 167 | If you want to use the same model with different datasets, you can modify 168 | datasets after compilation. Since jax does not like changes in shapes, this is 169 | only recommended with the numba backend. 
170 | 171 | First, we define the model, but put our dataset in a `pm.Data` structure: 172 | 173 | ```{python} 174 | with pm.Model() as model: 175 | x = pm.Data("x", [1, 2, 3]) 176 | intercept = pm.Normal("intercept", mu=0, sigma=1) 177 | slope = pm.Normal("slope", mu=0, sigma=1) 178 | mu = intercept + slope * x 179 | y = pm.Normal("y", mu=mu, sigma=0.1, observed=[1, 2, 3]) 180 | ``` 181 | 182 | We can now compile the model: 183 | 184 | ```{python} 185 | compiled = nutpie.compile_pymc_model(model) 186 | trace = nutpie.sample(compiled) 187 | ``` 188 | 189 | After compilation, we can change the dataset: 190 | 191 | ```{python} 192 | compiled2 = compiled.with_data(x=[4, 5, 6]) 193 | trace2 = nutpie.sample(compiled2) 194 | ``` 195 | -------------------------------------------------------------------------------- /docs/sample-stats.qmd: -------------------------------------------------------------------------------- 1 | # Understanding Sampler Statistics in Nutpie 2 | 3 | This guide explains the various statistics that nutpie collects during sampling. We'll use Neal's funnel distribution as an example, as it's a challenging model that demonstrates many important sampling concepts. 
4 | 5 | ## Example Model: Neal's Funnel 6 | 7 | Let's start by implementing Neal's funnel in PyMC: 8 | 9 | ```{python} 10 | import pymc as pm 11 | import nutpie 12 | import numpy as np 13 | import matplotlib.pyplot as plt 14 | import seaborn as sns 15 | import pandas as pd 16 | import arviz as az 17 | 18 | # Create the funnel model 19 | with pm.Model() as model: 20 | log_sigma = pm.Normal('log_sigma') 21 | pm.Normal('x', sigma=pm.math.exp(log_sigma), shape=5) 22 | 23 | # Sample with detailed statistics 24 | compiled = nutpie.compile_pymc_model(model) 25 | trace = nutpie.sample( 26 | compiled, 27 | tune=1000, 28 | store_mass_matrix=True, 29 | store_gradient=True, 30 | store_unconstrained=True, 31 | store_divergences=True, 32 | seed=42, 33 | ) 34 | ``` 35 | 36 | ## Sampler Statistics Overview 37 | 38 | The sampler statistics can be grouped into several categories: 39 | 40 | ### Basic HMC Statistics 41 | 42 | These statistics are always collected and are essential for basic diagnostics: 43 | 44 | ```{python} 45 | # Access through trace.sample_stats 46 | basic_stats = [ 47 | 'depth', # Tree depth for current draw 48 | 'maxdepth_reached', # Whether max tree depth was hit 49 | 'logp', # Log probability of current position 50 | 'energy', # Hamiltonian energy 51 | 'diverging', # Whether the transition diverged 52 | 'step_size', # Current step size 53 | 'step_size_bar', # Current estimate of an ideal step size 54 | 'n_steps' # Number of leapfrog steps 55 | 56 | ] 57 | 58 | # Plot step size evolution during warmup 59 | trace.warmup_sample_stats.step_size_bar.plot.line(x="draw", yscale="log") 60 | ``` 61 | 62 | ### Mass Matrix Adaptation 63 | 64 | These statistics track how the mass matrix evolves: 65 | 66 | ```{python} 67 | ( 68 | trace 69 | .warmup_sample_stats 70 | .mass_matrix_inv 71 | .plot 72 | .line( 73 | x="draw", 74 | yscale="log", 75 | col="chain", 76 | col_wrap=2, 77 | ) 78 | ) 79 | ``` 80 | 81 | Variables that are a source of convergence issues, will often show 
high variance 82 | in the final mass matrix estimate across chains. 83 | 84 | The mass matrix will always be fixed for 10% of draws at the end, because we 85 | only run final step size adaptation during that time, but high variance in the 86 | mass matrix before this final window and indicate that more tuning steps might 87 | be needed. 88 | 89 | ### Detailed Diagnostics 90 | 91 | These are only available when explicitly requested: 92 | 93 | ```python 94 | detailed_stats = [ 95 | 'gradient', # Gradient at current position 96 | 'unconstrained_draw', # Parameters in unconstrained space 97 | 'divergence_start', # Position where divergence started 98 | 'divergence_end', # Position where divergence ended 99 | 'divergence_momentum', # Momentum at divergence 100 | 'divergence_message' # Description of divergence 101 | ] 102 | ``` 103 | 104 | #### Identify Divergences 105 | 106 | We can, for instance, use this to identify the sources of divergences: 107 | 108 | ```{python} 109 | import xarray as xr 110 | 111 | draws = ( 112 | trace 113 | .sample_stats 114 | .unconstrained_draw 115 | .assign_coords(kind="draw") 116 | ) 117 | divergence_locations = ( 118 | trace 119 | .sample_stats 120 | .divergence_start 121 | .assign_coords(kind="divergence") 122 | ) 123 | 124 | points = xr.concat([draws, divergence_locations], dim="kind") 125 | points.to_dataset("unconstrained_parameter").plot.scatter(x="log_sigma", y="x_0", hue="kind") 126 | ``` 127 | 128 | #### Covariance of gradients and draws 129 | 130 | TODO this section should really use the transformed gradients and draws, not the 131 | unconstrained ones, as that avoids the manual mass matrix correction. This 132 | is only available for the normalizing flow adaptation at the moment though. 133 | 134 | In models with problematic posterior correlations, the singular value 135 | decomposition of gradients and draws can often point us to the source of the 136 | issue. 
137 | 138 | Let's build a little model with correlations between parameters: 139 | 140 | ```{python} 141 | with pm.Model() as model: 142 | x = pm.Normal('x') 143 | y = pm.Normal("y", mu=x, sigma=0.01) 144 | z = pm.Normal("z", mu=y, shape=100) 145 | 146 | compiled = nutpie.compile_pymc_model(model) 147 | trace = nutpie.sample( 148 | compiled, 149 | tune=1000, 150 | store_gradient=True, 151 | store_unconstrained=True, 152 | store_mass_matrix=True, 153 | seed=42, 154 | ) 155 | ``` 156 | 157 | Now we can compute eigenvalues of the covariance matrix of the gradient and 158 | draws (using the singular value decomposition to avoid quadratic cost): 159 | 160 | ```{python} 161 | def covariance_eigenvalues(x, mass_matrix): 162 | assert x.dims == ("chain", "draw", "unconstrained_parameter") 163 | x = x.stack(sample=["draw", "chain"]) 164 | x = (x - x.mean("sample")) / np.sqrt(mass_matrix) 165 | u, s, v = np.linalg.svd(x.T / np.sqrt(x.shape[1]), full_matrices=False) 166 | print(u.shape, s.shape, v.shape) 167 | s = xr.DataArray( 168 | s, 169 | dims=["eigenvalue"], 170 | coords={"eigenvalue": range(s.size)}, 171 | ) 172 | v = xr.DataArray( 173 | v, 174 | dims=["eigenvalue", "unconstrained_parameter"], 175 | coords={ 176 | "eigenvalue": s.eigenvalue, 177 | "unconstrained_parameter": x.unconstrained_parameter, 178 | }, 179 | ) 180 | return s ** 2, v 181 | 182 | mass_matrix = trace.sample_stats.mass_matrix_inv.isel(draw=-1, chain=0) 183 | draws_eigs, draws_eigv = covariance_eigenvalues(trace.sample_stats.unconstrained_draw, mass_matrix) 184 | grads_eigs, grads_eigv = covariance_eigenvalues(trace.sample_stats.gradient, 1 / mass_matrix) 185 | 186 | draws_eigs.plot.line(x="eigenvalue", yscale="log") 187 | grads_eigs.plot.line(x="eigenvalue", yscale="log") 188 | ``` 189 | 190 | We can see one very large and one very small eigenvalue in both covariances. 191 | Large eigenvalues for the draws, and small eigenvalues for the gradients prevent 192 | the sampler from taking larger steps. 
Small eigenvalues in the draws, and large 193 | eigenvalues in the grads, mean that the sampler has to move far in parameter 194 | space to get independent draws. So both lead to problems during sampling. For 195 | models with many parameters, typically only the large eigenvalues of each are 196 | meaningful, because of estimation issues with the small eigenvalues. 197 | 198 | We can also look at the eigenvectors to see which parameters are responsible for 199 | the correlations: 200 | 201 | ```{python} 202 | ( 203 | draws_eigv 204 | .sel(eigenvalue=0) 205 | .to_pandas() 206 | .sort_values(key=abs) 207 | .tail(10) 208 | .plot.bar(x="unconstrained_parameter") 209 | ) 210 | ``` 211 | 212 | ```{python} 213 | ( 214 | grads_eigv 215 | .sel(eigenvalue=0) 216 | .to_pandas() 217 | .sort_values(key=abs) 218 | .tail(10) 219 | .plot.bar(x="unconstrained_parameter") 220 | ) 221 | ``` 222 | -------------------------------------------------------------------------------- /docs/sampling-options.qmd: -------------------------------------------------------------------------------- 1 | # Sampling Configuration Guide 2 | 3 | This guide covers the configuration options for `nutpie.sample` and provides 4 | practical advice for tuning your sampler. We'll start with basic usage and move 5 | to advanced topics like mass matrix adaptation. 6 | 7 | ## Quick Start 8 | 9 | For most models, don't think too much about the options of the sampler, and just 10 | use the defaults. Most sampling problems can't easily be solved by changing the 11 | sampler, most of the time they require model changes. 
So in most cases, simply use 12 | 13 | ```python 14 | trace = nutpie.sample(compiled_model) 15 | ``` 16 | 17 | ## Core Sampling Parameters 18 | 19 | ### Drawing Samples 20 | 21 | ```python 22 | trace = nutpie.sample( 23 | model, 24 | draws=1000, # Number of post-warmup draws per chain 25 | tune=500, # Number of warmup draws for adaptation 26 | chains=6, # Number of independent chains 27 | cores=None, # Number of chains that are allowed to run simultaneously 28 | seed=12345 # Random seed for reproducibility 29 | ) 30 | ``` 31 | 32 | The number of draws affects both accuracy and computational cost: 33 | - Too few draws (< 500) may not capture the posterior well 34 | - Too many draws (> 10000) may waste computation time 35 | 36 | If a model is sampling without divergences, but with effective sample sizes that 37 | are not as large as necessary to achieve the markov-error for your estimates, 38 | you can increase the number of chains and/or draws. 39 | 40 | If the effective sample size is much smaller than the number of draws, you might 41 | want to consider reparameterizing the model instead, to, for instance, remove 42 | posterior correlations. 43 | 44 | ## Sampler Diagnostics 45 | 46 | You can enable more detailed diagnostics when troubleshooting: 47 | 48 | ```python 49 | trace = nutpie.sample( 50 | model, 51 | save_warmup=True, # Keep warmup draws, default is True 52 | store_divergences=True, # Track divergent transitions 53 | store_unconstrained=True, # Store transformed parameters 54 | store_gradient=True, # Store gradient information 55 | store_mass_matrix=True # Track mass matrix adaptation 56 | ) 57 | ``` 58 | 59 | For each of the `store_*` arguments, additional arrays will be available in the 60 | `trace.sample_stats`. 
61 | 62 | ## Non-blocking sampling 63 | 64 | 65 | 66 | ### Settings for HMC and NUTS 67 | 68 | ```python 69 | trace = nutpie.sample( 70 | model, 71 | target_accept=0.8, # Target acceptance rate 72 | maxdepth=10, # Maximum tree depth 73 | max_energy_error=1000 # Error at which to count the trajectory as a divergent transition 74 | ) 75 | ``` 76 | 77 | The `target_accept` parameter implicitly controls the step size of the leapfrog 78 | steps in the HMC sampler. During tuning, the sampler will try to choose a step 79 | size, such that the acceptance statistic is `target_accept`. It has to be 80 | between 0 and 1. 81 | 82 | The default is 0.8. Larger values will increase the computational cost, but 83 | might avoid divergences during sampling. In many diverging models increasing 84 | `target_accept` will only make divergences less frequent however, and not solve 85 | the underlying problem. 86 | 87 | Lowering the maximum energy error to, for instance, 10 will often increase the 88 | number of divergences, and make it easier to diagnose their cause. With lower 89 | value the divergences often are reported closer to the critical points in the 90 | parameter space, where the model is most likely to diverge. 91 | 92 | ## Mass Matrix Adaptation 93 | 94 | Nutpie offers several strategies for adapting the mass matrix, which determines 95 | how the sampler navigates the parameter space. 96 | 97 | ### Standard Adaptation 98 | 99 | By setting `use_grad_based_mass_matrix=False`, the sampling algorithm will more 100 | closely resemble the algorithm in Stan and PyMC. Usually, this will result in 101 | less efficient sampling, but the total number of effective samples is sometimes 102 | higher. If this is set to `True` (the default), nutpie will use diagonal mass 103 | matrix estimates that are based on the posterior draws and the scores at those 104 | positions. 
105 | 106 | ```python 107 | trace = nutpie.sample( 108 | model, 109 | use_grad_based_mass_matrix=False 110 | ) 111 | ``` 112 | 113 | ### Low-Rank Updates 114 | 115 | For models with strong parameter correlations you can enable a low rank modified 116 | mass matrix. The `mass_matrix_gamma` parameter is a regularization parameter. 117 | More regularization will lead to a smaller effect of the low-rank components, 118 | but might work better for higher dimensional problems. 119 | 120 | `mass_matrix_eigval_cutoff` should be greater than one, and controls how large 121 | an eigenvalue of the full mass matrix has to be, to be included into the 122 | low-rank mass matrix. 123 | 124 | ```python 125 | trace = nutpie.sample( 126 | model, 127 | low_rank_modified_mass_matrix=True, 128 | mass_matrix_eigval_cutoff=3, 129 | mass_matrix_gamma=1e-5 130 | ) 131 | ``` 132 | 133 | ### Experimental Features 134 | 135 | `transform_adapt` is an experimental feature that allows sampling from many 136 | posteriors, where current methods diverge. It is described in more detail 137 | [here](nf-adapt.qmd). 138 | 139 | ```python 140 | trace = nutpie.sample( 141 | model, 142 | transform_adapt=True # Experimental reparameterization 143 | ) 144 | ``` 145 | 146 | ## Progress Monitoring 147 | 148 | Customize the sampling progress display: 149 | 150 | ```python 151 | trace = nutpie.sample( 152 | model, 153 | progress_bar=True, 154 | progress_rate=500, # Update every 500ms 155 | ) 156 | ``` 157 | -------------------------------------------------------------------------------- /docs/stan-usage.qmd: -------------------------------------------------------------------------------- 1 | # Usage with Stan models 2 | 3 | This document shows how to use `nutpie` with Stan models. We will use the 4 | `nutpie` package to define a simple model and sample from it using Stan. 5 | 6 | ## Installation 7 | 8 | For Stan, it is more common to use `pip` or `uv` to install the necessary 9 | packages. 
However, `conda` is also an option if you prefer. 10 | 11 | To install using `pip`: 12 | 13 | ```bash 14 | pip install "nutpie[stan]" 15 | ``` 16 | 17 | To install using `uv`: 18 | 19 | ```bash 20 | uv add "nutpie[stan]" 21 | ``` 22 | 23 | To install using `conda`: 24 | 25 | ```bash 26 | conda install -c conda-forge nutpie 27 | ``` 28 | 29 | ## Compiler Toolchain 30 | 31 | Stan requires a compiler toolchain to be installed on your system. This is 32 | necessary for compiling the Stan models. You can find detailed instructions for 33 | setting up the compiler toolchain in the [CmdStan 34 | Guide](https://mc-stan.org/docs/cmdstan-guide/installation.html#cpp-toolchain). 35 | 36 | Additionally, since Stan uses Intel's Threading Building Blocks (TBB) for 37 | parallelism, you might need to set the `TBB_CXX_TYPE` environment variable to 38 | specify the compiler type. Depending on your system, you can set it to either 39 | `clang` or `gcc`. For example: 40 | 41 | ```{python} 42 | import os 43 | os.environ["TBB_CXX_TYPE"] = "clang" # or 'gcc' 44 | ``` 45 | 46 | Make sure to set this environment variable before compiling your Stan models to ensure proper configuration. 47 | 48 | ## Defining and Sampling a Simple Model 49 | 50 | We will define a simple Bayesian model using Stan and sample from it using 51 | `nutpie`. 
52 | 53 | ### Model Definition 54 | 55 | In your Python script or Jupyter notebook, add the following code: 56 | 57 | ```{python} 58 | import nutpie 59 | 60 | model_code = """ 61 | data { 62 | int N; 63 | vector[N] y; 64 | } 65 | parameters { 66 | real mu; 67 | } 68 | model { 69 | mu ~ normal(0, 1); 70 | y ~ normal(mu, 1); 71 | } 72 | """ 73 | 74 | compiled_model = nutpie.compile_stan_model(code=model_code) 75 | ``` 76 | 77 | ### Sampling 78 | 79 | We can now compile the model and sample from it: 80 | 81 | ```{python} 82 | compiled_model_with_data = compiled_model.with_data(N=3, y=[1, 2, 3]) 83 | trace = nutpie.sample(compiled_model_with_data) 84 | ``` 85 | 86 | ### Using Dimensions 87 | 88 | We'll use the radon model from 89 | [this](https://mc-stan.org/learn-stan/case-studies/radon_cmdstanpy_plotnine.html) 90 | case-study from the stan documentation, to show how we can use coordinates and 91 | dimension names to simplify working with trace objects. 92 | 93 | We follow the same data preparation as in the case-study: 94 | 95 | ```{python} 96 | import pandas as pd 97 | import numpy as np 98 | import arviz as az 99 | import seaborn as sns 100 | 101 | home_data = pd.read_csv( 102 | "https://github.com/pymc-devs/pymc-examples/raw/refs/heads/main/examples/data/srrs2.dat", 103 | index_col="idnum", 104 | ) 105 | county_data = pd.read_csv( 106 | "https://github.com/pymc-devs/pymc-examples/raw/refs/heads/main/examples/data/cty.dat", 107 | ) 108 | 109 | radon_data = ( 110 | home_data 111 | .rename(columns=dict(cntyfips="ctfips")) 112 | .merge( 113 | ( 114 | county_data 115 | .drop_duplicates(['stfips', 'ctfips', 'st', 'cty', 'Uppm']) 116 | .set_index(["ctfips", "stfips"]) 117 | ), 118 | right_index=True, 119 | left_on=["ctfips", "stfips"], 120 | ) 121 | .assign(log_radon=lambda x: np.log(np.clip(x.activity, 0.1, np.inf))) 122 | .assign(log_uranium=lambda x: np.log(np.clip(x["Uppm"], 0.1, np.inf))) 123 | .query("state == 'MN'") 124 | ) 125 | ``` 126 | 127 | And also use the 
partially pooled model from the case-study: 128 | 129 | ```{python} 130 | model_code = """ 131 | data { 132 | int N; // observations 133 | int J; // counties 134 | array[N] int county; 135 | vector[N] x; 136 | vector[N] y; 137 | } 138 | parameters { 139 | real mu_alpha; 140 | real sigma_alpha; 141 | vector[J] alpha; // non-centered parameterization 142 | real beta; 143 | real sigma; 144 | } 145 | model { 146 | y ~ normal(alpha[county] + beta * x, sigma); 147 | alpha ~ normal(mu_alpha, sigma_alpha); // partial-pooling 148 | beta ~ normal(0, 10); 149 | sigma ~ normal(0, 10); 150 | mu_alpha ~ normal(0, 10); 151 | sigma_alpha ~ normal(0, 10); 152 | } 153 | generated quantities { 154 | array[N] real y_rep = normal_rng(alpha[county] + beta * x, sigma); 155 | } 156 | """ 157 | ``` 158 | 159 | We collect the dataset in the format that the stan model requires, 160 | and specify the dimensions of each of the non-scalar variables in the model: 161 | 162 | ```{python} 163 | county_idx, counties = pd.factorize(radon_data["county"], use_na_sentinel=False) 164 | observations = radon_data.index 165 | 166 | coords = { 167 | "county": counties, 168 | "observation": observations, 169 | } 170 | 171 | dims = { 172 | "alpha": ["county"], 173 | "y_rep": ["observation"], 174 | } 175 | 176 | data = { 177 | "N": len(observations), 178 | "J": len(counties), 179 | # Stan uses 1-based indexing! 
180 | "county": county_idx + 1, 181 | "x": radon_data.log_uranium.values, 182 | "y": radon_data.log_radon.values, 183 | } 184 | ``` 185 | 186 | Then, we compile the model and provide the dimensions, coordinates and the 187 | dataset we just defined: 188 | 189 | ```{python} 190 | compiled_model = ( 191 | nutpie.compile_stan_model(code=model_code) 192 | .with_data(**data) 193 | .with_dims(**dims) 194 | .with_coords(**coords) 195 | ) 196 | ``` 197 | 198 | ```{python} 199 | %%time 200 | trace = nutpie.sample(compiled_model, seed=0) 201 | ``` 202 | 203 | As some basic convergance checking we verify that all Rhat values are smaller 204 | than 1.02, all parameters have at least 500 effective draws and that we have no 205 | divergences: 206 | 207 | ```{python} 208 | assert trace.sample_stats.diverging.sum() == 0 209 | assert az.ess(trace).min().min() > 500 210 | assert az.rhat(trace).max().max() > 1.02 211 | ``` 212 | 213 | Thanks to the coordinates and dimensions we specified, the resulting trace will 214 | now contain labeled data, so that plots based on it have properly set-up labels: 215 | 216 | ```{python} 217 | import arviz as az 218 | import seaborn as sns 219 | import xarray as xr 220 | 221 | sns.catplot( 222 | data=trace.posterior.alpha.to_dataframe().reset_index(), 223 | y="county", 224 | x="alpha", 225 | kind="boxen", 226 | height=13, 227 | aspect=1/2.5, 228 | showfliers=False, 229 | ) 230 | ``` 231 | -------------------------------------------------------------------------------- /docs/styles.css: -------------------------------------------------------------------------------- 1 | /* css styles */ 2 | -------------------------------------------------------------------------------- /notebooks/pytensor_logp.md: -------------------------------------------------------------------------------- 1 | --- 2 | jupyter: 3 | jupytext: 4 | formats: ipynb,md 5 | text_representation: 6 | extension: .md 7 | format_name: markdown 8 | format_version: '1.3' 9 | jupytext_version: 
1.13.8 10 | kernelspec: 11 | display_name: pymc4-dev 12 | language: python 13 | name: pymc4-dev 14 | --- 15 | 16 | # Usage example of nutpie 17 | 18 | ```python 19 | # We can control the number cores that are used by an environment variable: 20 | %env RAYON_NUM_THREADS=12 21 | ``` 22 | 23 | ```python 24 | import pytensor 25 | import pytensor.tensor as pt 26 | import pymc as pm 27 | import numpy as np 28 | import nutpie 29 | import arviz 30 | import pandas as pd 31 | import seaborn as sns 32 | import matplotlib.pyplot as plt 33 | ``` 34 | 35 | ## The dataset 36 | 37 | We use the well known radon dataset in this notebook. 38 | 39 | ```python 40 | data = pd.read_csv(pm.get_data("radon.csv")) 41 | data["log_radon"] = data["log_radon"].astype(np.float64) 42 | county_idx, counties = pd.factorize(data.county) 43 | coords = {"county": counties, "obs_id": np.arange(len(county_idx))} 44 | ``` 45 | 46 | ```python 47 | sns.catplot( 48 | data=data, 49 | x="floor", 50 | y="log_radon", 51 | ) 52 | ``` 53 | 54 | ## Use as a sampler for pymc 55 | 56 | ```python 57 | with pm.Model(coords=coords, check_bounds=False) as pymc_model: 58 | intercept = pm.Normal("intercept", sigma=10) 59 | 60 | # County effects 61 | # TODO should be a CenteredNormal 62 | raw = pm.Normal("county_raw", dims="county") 63 | sd = pm.HalfNormal("county_sd") 64 | county_effect = pm.Deterministic("county_effect", raw * sd, dims="county") 65 | 66 | # Global floor effect 67 | floor_effect = pm.Normal("floor_effect", sigma=2) 68 | 69 | # County:floor interaction 70 | # Should also be a CenteredNormal 71 | raw = pm.Normal("county_floor_raw", dims="county") 72 | sd = pm.HalfNormal("county_floor_sd") 73 | county_floor_effect = pm.Deterministic( 74 | "county_floor_effect", raw * sd, dims="county" 75 | ) 76 | 77 | mu = ( 78 | intercept 79 | + county_effect[county_idx] 80 | + floor_effect * data.floor.values 81 | + county_floor_effect[county_idx] * data.floor.values 82 | ) 83 | 84 | sigma = pm.HalfNormal("sigma", 
sigma=1.5) 85 | pm.Normal( 86 | "log_radon", mu=mu, sigma=sigma, observed=data.log_radon.values, dims="obs_id" 87 | ) 88 | ``` 89 | 90 | ```python 91 | %%time 92 | # The compilation time is pretty bad right now, I think this can be improved a lot though 93 | compiled_model = nutpie.compile_pymc_model(pymc_model) 94 | ``` 95 | 96 | ```python 97 | %%time 98 | trace_pymc = nutpie.sample(compiled_model, chains=10) 99 | ``` 100 | 101 | ```python 102 | sns.catplot( 103 | data=( 104 | (trace_pymc.posterior.county_floor_effect + trace_pymc.posterior.floor_effect) 105 | .isel(county=slice(0, 5)) 106 | .to_dataframe("total_county_floor_effect") 107 | .reset_index() 108 | ), 109 | x="total_county_floor_effect", 110 | y="county", 111 | kind="violin", 112 | orient="h", 113 | ) 114 | plt.axvline(0, color="grey", alpha=0.5, zorder=-100) 115 | ``` 116 | 117 | ## Use nutpie as a sampling backend for stan 118 | 119 | ```python 120 | %%file radon_model.stan 121 | data { 122 | int n_counties; 123 | int n_observed; 124 | array[n_observed] int county_idx; 125 | vector[n_observed] is_floor; 126 | vector[n_observed] log_radon; 127 | } 128 | parameters { 129 | real intercept; 130 | 131 | vector[n_counties] county_raw; 132 | real county_sd; 133 | 134 | real floor_effect; 135 | 136 | vector[n_counties] county_floor_raw; 137 | real county_floor_sd; 138 | 139 | real sigma; 140 | } 141 | transformed parameters { 142 | vector[n_counties] county_effect; 143 | vector[n_counties] county_floor_effect; 144 | vector[n_observed] mu; 145 | 146 | county_effect = county_sd * county_raw; 147 | county_floor_effect = county_floor_sd * county_floor_raw; 148 | 149 | mu = ( 150 | intercept 151 | + county_effect[county_idx] 152 | + floor_effect * is_floor 153 | + county_floor_effect[county_idx] .* is_floor 154 | ); 155 | } 156 | model { 157 | intercept ~ normal(0, 10); 158 | 159 | county_raw ~ normal(0, 1); 160 | county_sd ~ normal(0, 1); 161 | 162 | floor_effect ~ normal(0, 2); 163 | 164 | county_floor_raw ~ 
normal(0, 1); 165 | county_floor_sd ~ normal(0, 1); 166 | 167 | sigma ~ normal(0, 1.5); 168 | 169 | log_radon ~ normal(mu, sigma); 170 | } 171 | ``` 172 | 173 | ```python 174 | data_stan = { 175 | "n_counties": len(counties), 176 | "n_observed": len(data), 177 | "county_idx": county_idx + 1, 178 | "is_floor": data.floor.values, 179 | "log_radon": data.log_radon.values, 180 | } 181 | 182 | coords_stan = { 183 | "county": counties, 184 | } 185 | 186 | dims_stan = { 187 | "county_raw": ("county",), 188 | "county_floor_raw": ("county",), 189 | "county_effect": ("county",), 190 | "county_floor_effect": ("county",), 191 | "mu": ("observation",), 192 | } 193 | ``` 194 | 195 | ```python 196 | %%time 197 | stan_model = nutpie.compile_stan_model( 198 | data_stan, 199 | filename="radon_model.stan", 200 | coords=coords_stan, 201 | dims=dims_stan, 202 | cache=False 203 | ) 204 | ``` 205 | 206 | ```python 207 | %%time 208 | trace_stan = nutpie.sample(stan_model, chains=10) 209 | ``` 210 | 211 | ## Comparison with pystan 212 | 213 | ```python 214 | import stan 215 | import nest_asyncio 216 | 217 | nest_asyncio.apply() 218 | ``` 219 | 220 | ```python 221 | %%time 222 | with open("radon_model.stan", "r") as file: 223 | model = stan.build(file.read(), data=data_stan) 224 | ``` 225 | 226 | ```python 227 | %%time 228 | trace_pystan = model.sample(num_chains=10, save_warmup=True) 229 | ``` 230 | 231 | ```python 232 | trace_pystan = arviz.from_pystan(trace_pystan, save_warmup=True) 233 | ``` 234 | 235 | ## Comparison to the pymc sampler 236 | 237 | ```python 238 | %%time 239 | with pymc_model: 240 | trace_py = pm.sample( 241 | init="jitter+adapt_diag_grad", 242 | draws=1000, 243 | chains=10, 244 | cores=10, 245 | idata_kwargs={"log_likelihood": False}, 246 | compute_convergence_checks=False, 247 | target_accept=0.8, 248 | discard_tuned_samples=False, 249 | ) 250 | ``` 251 | 252 | ## Early convergance speed 253 | 254 | ```python 255 | 
plt.plot((trace_pymc.warmup_sample_stats.n_steps).isel(draw=slice(0, 1000)).cumsum("draw").T, np.log(trace_pymc.warmup_sample_stats.energy.isel(draw=slice(0, 1000)).T)); 256 | plt.xlim(0, 10000) 257 | plt.ylabel("log-energy") 258 | plt.xlabel("gradient evaluations"); 259 | ``` 260 | 261 | ```python 262 | trace_cmdstan = arviz.from_cmdstan("output_*.csv", save_warmup=True) 263 | ``` 264 | 265 | ```python 266 | plt.plot((trace_cmdstan.warmup_sample_stats.n_steps).isel(draw=slice(0, 1000)).cumsum("draw").T, np.log(trace_cmdstan.warmup_sample_stats.energy.isel(draw=slice(0, 1000)).T)); 267 | plt.xlim(0, 10000) 268 | plt.ylabel("log-energy") 269 | plt.xlabel("gradient evaluations"); 270 | ``` 271 | 272 | The new implementation only use about a third of gradient evaluations during tuning 273 | 274 | ```python 275 | trace_cmdstan.warmup_sample_stats.n_steps.sum() 276 | ``` 277 | 278 | ```python 279 | trace_stan.warmup_sample_stats.n_steps.sum() 280 | ``` 281 | 282 | ## Comparison to cmdstan 283 | 284 | 285 | Run on the commandline: 286 | ``` 287 | env STAN_THREADS=1 cmdstan_model radon_model.stan 288 | ``` 289 | 290 | ```python 291 | import json 292 | ``` 293 | 294 | ```python 295 | stan.common.simdjson 296 | ``` 297 | 298 | ```python 299 | type({name: int(val) if isinstance(val, int) else list(val) for name, val in data_stan.items()}["county_idx"][0]) 300 | ``` 301 | 302 | ```python 303 | data_json = {} 304 | for name, val in data_stan.items(): 305 | if isinstance(val, int): 306 | data_json[name] = int(val) 307 | continue 308 | 309 | if val.dtype == np.int64: 310 | data_json[name] = list(int(x) for x in val) 311 | continue 312 | 313 | data_json[name] = list(val) 314 | 315 | with open("radon.json", "w") as file: 316 | json.dump(data_json, file) 317 | ``` 318 | 319 | ```python 320 | %%time 321 | out = !./radon_model sample num_chains=10 save_warmup=1 data file=radon.json num_threads=10 322 | ``` 323 | 324 | ```python 325 | trace_cmdstan = arviz.from_cmdstan("output_*.csv", 
save_warmup=True) 326 | ``` 327 | 328 | ## Gradient evals per effective sample 329 | 330 | nutpie uses fewer gradient evaluations per effective sample in this model. 331 | 332 | ```python 333 | trace_cmdstan.sample_stats.n_steps.sum() / arviz.ess(trace_cmdstan).min() 334 | ``` 335 | 336 | ```python 337 | trace_stan.sample_stats.n_steps.sum() / arviz.ess(trace_stan).min() 338 | ``` 339 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["maturin>=1.1,<2.0"] 3 | build-backend = "maturin" 4 | 5 | [project] 6 | name = "nutpie" 7 | description = "Sample Stan or PyMC models" 8 | authors = [{ name = "PyMC Developers", email = "pymc.devs@gmail.com" }] 9 | readme = "README.md" 10 | requires-python = ">=3.10" 11 | license = { text = "MIT" } 12 | classifiers = [ 13 | "Programming Language :: Rust", 14 | "Programming Language :: Python :: Implementation :: CPython", 15 | "Programming Language :: Python :: Implementation :: PyPy", 16 | ] 17 | 18 | dependencies = [ 19 | "pyarrow >= 12.0.0", 20 | "pandas >= 2.0", 21 | "xarray >= 2025.01.2", 22 | "arviz >= 0.20.0", 23 | ] 24 | dynamic = ["version"] 25 | 26 | [project.urls] 27 | Homepage = "https://pymc-devs.github.io/nutpie/" 28 | Repository = "https://github.com/pymc-devs/nutpie" 29 | 30 | [project.optional-dependencies] 31 | stan = ["bridgestan >= 2.6.1", "stanio >= 0.5.1"] 32 | pymc = ["pymc >= 5.20.1", "numba >= 0.60.0"] 33 | pymc-jax = ["pymc >= 5.20.1", "jax >= 0.4.27"] 34 | nnflow = ["flowjax >= 17.1.0", "equinox >= 0.11.12"] 35 | dev = [ 36 | "bridgestan >= 2.6.1", 37 | "stanio >= 0.5.1", 38 | "pymc >= 5.20.1", 39 | "numba >= 0.60.0", 40 | "jax >= 0.4.27", 41 | "flowjax >= 17.0.2", 42 | "pytest", 43 | "pytest-timeout", 44 | "pytest-arraydiff", 45 | ] 46 | all = [ 47 | "bridgestan >= 2.6.1", 48 | "stanio >= 0.5.1", 49 | "pymc >= 5.20.1", 50 | "numba >= 0.60.0", 
51 | "jax >= 0.4.27", 52 | "flowjax >= 17.1.0", 53 | "equinox >= 0.11.12", 54 | ] 55 | 56 | [tool.ruff] 57 | line-length = 88 58 | target-version = "py310" 59 | show-fixes = true 60 | output-format = "full" 61 | 62 | [tool.ruff.lint.flake8-tidy-imports] 63 | ban-relative-imports = "all" 64 | 65 | [tool.ruff.lint.isort] 66 | known-first-party = ["nutpie"] 67 | 68 | [tool.pyright] 69 | venvPath = ".pixi/envs/" 70 | venv = "default" 71 | 72 | [tool.maturin] 73 | module-name = "nutpie._lib" 74 | python-source = "python" 75 | features = ["pyo3/extension-module"] 76 | 77 | [tool.pytest.ini_options] 78 | markers = [ 79 | "flow: tests for normalizing flows", 80 | "stan: tests for Stan models", 81 | "pymc: tests for PyMC models", 82 | ] 83 | -------------------------------------------------------------------------------- /python/nutpie/__init__.py: -------------------------------------------------------------------------------- 1 | from nutpie import _lib 2 | from nutpie.compile_pymc import compile_pymc_model 3 | from nutpie.compile_stan import compile_stan_model 4 | from nutpie.sample import sample 5 | 6 | __version__: str = _lib.__version__ 7 | __all__ = ["__version__", "compile_pymc_model", "compile_stan_model", "sample"] 8 | -------------------------------------------------------------------------------- /python/nutpie/compile_stan.py: -------------------------------------------------------------------------------- 1 | import tempfile 2 | from dataclasses import dataclass, replace 3 | from importlib.util import find_spec 4 | from pathlib import Path 5 | from typing import Any, Optional 6 | 7 | import pandas as pd 8 | from numpy.typing import NDArray 9 | 10 | from nutpie import _lib 11 | from nutpie.sample import CompiledModel 12 | 13 | 14 | @dataclass(frozen=True) 15 | class CompiledStanModel(CompiledModel): 16 | _coords: Optional[dict[str, Any]] 17 | code: str 18 | data: Optional[dict[str, NDArray]] 19 | library: Any 20 | model: Any 21 | model_name: Optional[str] = 
None 22 | _transform_adapt_args: dict | None = None 23 | 24 | def with_data(self, *, seed=None, **updates): 25 | if self.data is None: 26 | data = {} 27 | else: 28 | data = self.data.copy() 29 | 30 | data.update(updates) 31 | 32 | if data is not None: 33 | if find_spec("stanio") is None: 34 | raise ImportError( 35 | "stanio is not installed in the current environment. " 36 | "Please install it with something like " 37 | "'pip install stanio' or 'pip install nutpie[stan]'." 38 | ) 39 | 40 | import stanio 41 | 42 | data_json = stanio.dump_stan_json(data) 43 | else: 44 | data_json = None 45 | 46 | outer_kwargs = self._transform_adapt_args 47 | if outer_kwargs is None: 48 | outer_kwargs = {} 49 | 50 | def make_adapter(*args, **kwargs): 51 | from nutpie.transform_adapter import make_transform_adapter 52 | 53 | return make_transform_adapter(**outer_kwargs)(*args, **kwargs, logp_fn=None) 54 | 55 | model = _lib.StanModel(self.library, seed, data_json, make_adapter) 56 | coords = self._coords 57 | if coords is None: 58 | coords = {} 59 | else: 60 | coords = coords.copy() 61 | coords["unconstrained_parameter"] = pd.Index(model.param_unc_names()) 62 | 63 | return CompiledStanModel( 64 | _coords=coords, 65 | data=data, 66 | code=self.code, 67 | library=self.library, 68 | dims=self.dims, 69 | model=model, 70 | ) 71 | 72 | def with_coords(self, **coords): 73 | if self.coords is None: 74 | coords_new = {} 75 | else: 76 | coords_new = self.coords.copy() 77 | coords_new.update(coords) 78 | return replace(self, _coords=coords_new) 79 | 80 | def with_dims(self, **dims): 81 | if self.dims is None: 82 | dims_new = {} 83 | else: 84 | dims_new = self.dims.copy() 85 | dims_new.update(dims) 86 | return replace(self, dims=dims_new) 87 | 88 | def with_transform_adapt(self, **kwargs): 89 | return replace(self, _transform_adapt_args=kwargs).with_data() 90 | 91 | def _make_model(self, init_mean): 92 | if self.model is None: 93 | return self.with_data().model 94 | return self.model 95 | 96 | def 
_make_sampler(self, settings, init_mean, cores, progress_type): 97 | model = self._make_model(init_mean) 98 | return _lib.PySampler.from_stan( 99 | settings, 100 | cores, 101 | model, 102 | progress_type, 103 | ) 104 | 105 | @property 106 | def n_dim(self): 107 | if self.model is None: 108 | return self.with_data().n_dim 109 | return self.model.ndim() 110 | 111 | @property 112 | def shapes(self): 113 | if self.model is None: 114 | return self.with_data().shapes 115 | return {name: var.shape for name, var in self.model.variables().items()} 116 | 117 | @property 118 | def coords(self): 119 | if self.model is None: 120 | return self.with_data().coords 121 | return self._coords 122 | 123 | 124 | def compile_stan_model( 125 | *, 126 | code: Optional[str] = None, 127 | filename: Optional[str] = None, 128 | extra_compile_args: Optional[list[str]] = None, 129 | extra_stanc_args: Optional[list[str]] = None, 130 | dims: Optional[dict[str, int]] = None, 131 | coords: Optional[dict[str, Any]] = None, 132 | model_name: Optional[str] = None, 133 | cleanup: bool = True, 134 | ) -> CompiledStanModel: 135 | if find_spec("bridgestan") is None: 136 | raise ImportError( 137 | "BridgeStan is not installed in the current environment. " 138 | "Please install it with something like " 139 | "'pip install bridgestan' or 'pip install nutpie[stan]'." 
140 | ) 141 | 142 | import bridgestan 143 | 144 | if dims is None: 145 | dims = {} 146 | if coords is None: 147 | coords = {} 148 | 149 | if code is not None and filename is not None: 150 | raise ValueError("Specify exactly one of `code` and `filename`") 151 | if code is None: 152 | if filename is None: 153 | raise ValueError("Either code or filename have to be specified") 154 | with Path(filename).open() as file: 155 | code = file.read() 156 | 157 | if model_name is None: 158 | model_name = "model" 159 | 160 | basedir = tempfile.TemporaryDirectory(ignore_cleanup_errors=True) 161 | try: 162 | model_path = ( 163 | Path(basedir.name) 164 | .joinpath("name") 165 | .with_name(model_name) # This verifies that it is a valid filename 166 | .with_suffix(".stan") 167 | ) 168 | model_path.write_text(code) 169 | make_args = ["STAN_THREADS=true"] 170 | if extra_compile_args: 171 | make_args.extend(extra_compile_args) 172 | stanc_args = [] 173 | if extra_stanc_args: 174 | stanc_args.extend(extra_stanc_args) 175 | so_path = bridgestan.compile_model( 176 | model_path, make_args=make_args, stanc_args=stanc_args 177 | ) 178 | # Set necessary library loading paths 179 | bridgestan.compile.windows_dll_path_setup() 180 | library = _lib.StanLibrary(so_path) 181 | finally: 182 | try: 183 | if cleanup: 184 | basedir.cleanup() 185 | except Exception: # noqa: BLE001 186 | pass 187 | 188 | return CompiledStanModel( 189 | code=code, 190 | library=library, 191 | dims=dims, 192 | _coords=coords, 193 | model_name=model_name, 194 | model=None, 195 | data=None, 196 | ) 197 | -------------------------------------------------------------------------------- /python/nutpie/compiled_pyfunc.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from dataclasses import dataclass 3 | from functools import partial 4 | from typing import Any, Callable 5 | 6 | import numpy as np 7 | 8 | from nutpie import _lib # type: ignore 9 | from nutpie.sample 
import CompiledModel 10 | 11 | SeedType = int 12 | 13 | 14 | @dataclass(frozen=True) 15 | class PyFuncModel(CompiledModel): 16 | _make_logp_func: Callable 17 | _make_expand_func: Callable 18 | _make_initial_points: Callable[[SeedType], np.ndarray] | None 19 | _shared_data: dict[str, Any] 20 | _n_dim: int 21 | _variables: list[_lib.PyVariable] 22 | _coords: dict[str, Any] 23 | _raw_logp_fn: Callable | None 24 | _transform_adapt_args: dict | None = None 25 | 26 | @property 27 | def shapes(self) -> dict[str, tuple[int, ...]]: 28 | return {var.name: tuple(var.dtype.shape) for var in self._variables} 29 | 30 | @property 31 | def coords(self): 32 | return self._coords 33 | 34 | @property 35 | def n_dim(self): 36 | return self._n_dim 37 | 38 | def with_data(self, **updates): 39 | for name in updates: 40 | if name not in self._shared_data: 41 | raise ValueError(f"Unknown data variable: {name}") 42 | 43 | updated = self._shared_data.copy() 44 | updated.update(**updates) 45 | return dataclasses.replace(self, _shared_data=updated) 46 | 47 | def with_transform_adapt(self, **kwargs): 48 | return dataclasses.replace(self, _transform_adapt_args=kwargs) 49 | 50 | def _make_sampler(self, settings, init_mean, cores, progress_type): 51 | model = self._make_model(init_mean) 52 | return _lib.PySampler.from_pyfunc( 53 | settings, 54 | cores, 55 | model, 56 | progress_type, 57 | ) 58 | 59 | def _make_model(self, init_mean): 60 | def make_logp_func(): 61 | logp_fn = self._make_logp_func() 62 | return partial(logp_fn, **self._shared_data) 63 | 64 | def make_expand_func(seed1, seed2, chain): 65 | expand_fn = self._make_expand_func(seed1, seed2, chain) 66 | return partial(expand_fn, **self._shared_data) 67 | 68 | if self._raw_logp_fn is not None: 69 | outer_kwargs = self._transform_adapt_args 70 | if outer_kwargs is None: 71 | outer_kwargs = {} 72 | 73 | def make_adapter(*args, **kwargs): 74 | from nutpie.transform_adapter import make_transform_adapter 75 | 76 | return 
make_transform_adapter(**outer_kwargs)( 77 | *args, **kwargs, logp_fn=self._raw_logp_fn 78 | ) 79 | 80 | else: 81 | make_adapter = None 82 | 83 | return _lib.PyModel( 84 | make_logp_func, 85 | make_expand_func, 86 | self._variables, 87 | self.n_dim, 88 | init_point_func=self._make_initial_points, 89 | transform_adapter=make_adapter, 90 | ) 91 | 92 | 93 | def from_pyfunc( 94 | ndim: int, 95 | make_logp_fn: Callable, 96 | make_expand_fn: Callable, 97 | expanded_dtypes: list[np.dtype], 98 | expanded_shapes: list[tuple[int, ...]], 99 | expanded_names: list[str], 100 | *, 101 | coords: dict[str, Any] | None = None, 102 | dims: dict[str, tuple[str, ...]] | None = None, 103 | shared_data: dict[str, Any] | None = None, 104 | make_initial_point_fn: Callable[[SeedType], np.ndarray] | None = None, 105 | make_transform_adapter=None, 106 | raw_logp_fn=None, 107 | ): 108 | variables = [] 109 | for name, shape, dtype in zip( 110 | expanded_names, expanded_shapes, expanded_dtypes, strict=True 111 | ): 112 | shape = _lib.TensorShape(list(shape)) 113 | if dtype == np.float64: 114 | dtype = _lib.ExpandDtype.float64_array(shape) 115 | elif dtype == np.float32: 116 | dtype = _lib.ExpandDtype.float32_array(shape) 117 | elif dtype == np.int64: 118 | dtype = _lib.ExpandDtype.int64_array(shape) 119 | variables.append(_lib.PyVariable(name, dtype)) 120 | 121 | if coords is None: 122 | coords = {} 123 | if dims is None: 124 | dims = {} 125 | if shared_data is None: 126 | shared_data = {} 127 | 128 | return PyFuncModel( 129 | _n_dim=ndim, 130 | dims=dims, 131 | _coords=coords, 132 | _make_logp_func=make_logp_fn, 133 | _make_expand_func=make_expand_fn, 134 | _make_initial_points=make_initial_point_fn, 135 | _variables=variables, 136 | _shared_data=shared_data, 137 | _raw_logp_fn=raw_logp_fn, 138 | ) 139 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | mod progress; 
2 | mod pyfunc; 3 | mod pymc; 4 | mod stan; 5 | mod wrapper; 6 | 7 | pub use wrapper::_lib; 8 | -------------------------------------------------------------------------------- /src/progress.rs: -------------------------------------------------------------------------------- 1 | use std::{collections::BTreeMap, sync::Arc, time::Duration}; 2 | 3 | use anyhow::{Context, Result}; 4 | use indicatif::ProgressBar; 5 | use nuts_rs::{ChainProgress, ProgressCallback}; 6 | use pyo3::{Py, PyAny, Python}; 7 | use time_humanize::{Accuracy, Tense}; 8 | use upon::{Engine, Value}; 9 | 10 | pub struct ProgressHandler { 11 | engine: Engine<'static>, 12 | template: String, 13 | callback: Arc>, 14 | rate: Duration, 15 | n_cores: usize, 16 | } 17 | 18 | impl ProgressHandler { 19 | pub fn new(callback: Arc>, rate: Duration, template: String, n_cores: usize) -> Self { 20 | let engine = Engine::new(); 21 | Self { 22 | engine, 23 | callback, 24 | rate, 25 | template, 26 | n_cores, 27 | } 28 | } 29 | 30 | pub fn into_callback(self) -> Result { 31 | let template = self 32 | .engine 33 | .compile(self.template) 34 | .context("Could not compile progress template")?; 35 | 36 | let mut finished = false; 37 | let mut progress_update_count = 0; 38 | 39 | let callback = move |time_sampling, progress: Box<[ChainProgress]>| { 40 | if finished { 41 | return; 42 | } 43 | if progress 44 | .iter() 45 | .all(|chain| chain.finished_draws == chain.total_draws) 46 | { 47 | finished = true; 48 | } 49 | let progress = 50 | progress_to_value(progress_update_count, self.n_cores, time_sampling, progress); 51 | let rendered = template.render_from(&self.engine, &progress).to_string(); 52 | let rendered = rendered.unwrap_or_else(|err| format!("{err}")); 53 | let _ = Python::with_gil(|py| self.callback.call1(py, (rendered,))); 54 | progress_update_count += 1; 55 | }; 56 | 57 | Ok(ProgressCallback { 58 | callback: Box::new(callback), 59 | rate: self.rate, 60 | }) 61 | } 62 | } 63 | 64 | fn progress_to_value( 65 | 
progress_update_count: usize, 66 | n_cores: usize, 67 | time_sampling: Duration, 68 | progress: Box<[ChainProgress]>, 69 | ) -> Value { 70 | let chains: Vec<_> = progress 71 | .iter() 72 | .enumerate() 73 | .map(|(i, chain)| { 74 | let mut values = BTreeMap::new(); 75 | values.insert("chain_index".into(), Value::Integer(i as i64)); 76 | values.insert( 77 | "finished_draws".into(), 78 | Value::Integer(chain.finished_draws as i64), 79 | ); 80 | values.insert( 81 | "total_draws".into(), 82 | Value::Integer(chain.total_draws as i64), 83 | ); 84 | values.insert( 85 | "divergences".into(), 86 | Value::Integer(chain.divergences as i64), 87 | ); 88 | values.insert("tuning".into(), Value::Bool(chain.tuning)); 89 | values.insert("started".into(), Value::Bool(chain.started)); 90 | values.insert( 91 | "finished".into(), 92 | Value::Bool(chain.total_draws == chain.finished_draws), 93 | ); 94 | values.insert( 95 | "latest_num_steps".into(), 96 | Value::Integer(chain.latest_num_steps as i64), 97 | ); 98 | values.insert( 99 | "total_num_steps".into(), 100 | Value::Integer(chain.total_num_steps as i64), 101 | ); 102 | values.insert( 103 | "step_size".into(), 104 | Value::String(format!("{:.2}", chain.step_size)), 105 | ); 106 | values.insert( 107 | "divergent_draws".into(), 108 | Value::List( 109 | chain 110 | .divergent_draws 111 | .iter() 112 | .map(|&idx| Value::Integer(idx as _)) 113 | .collect(), 114 | ), 115 | ); 116 | upon::Value::Map(values) 117 | }) 118 | .collect(); 119 | 120 | let mut map = BTreeMap::new(); 121 | map.insert("chains".into(), Value::List(chains)); 122 | map.insert( 123 | "total_draws".into(), 124 | Value::Integer( 125 | progress 126 | .iter() 127 | .map(|chain| chain.total_draws) 128 | .sum::() as i64, 129 | ), 130 | ); 131 | map.insert( 132 | "total_finished_draws".into(), 133 | Value::Integer( 134 | progress 135 | .iter() 136 | .map(|chain| chain.finished_draws) 137 | .sum::() as i64, 138 | ), 139 | ); 140 | map.insert( 141 | "time_sampling".into(), 142 
| Value::String( 143 | time_humanize::HumanTime::from(time_sampling) 144 | .to_text_en(Accuracy::Rough, Tense::Present), 145 | ), 146 | ); 147 | 148 | let remaining = estimate_remaining_time(n_cores, time_sampling, &progress); 149 | map.insert( 150 | "time_remaining_estimate".into(), 151 | match remaining { 152 | Some(remaining) => Value::String( 153 | time_humanize::HumanTime::from(remaining) 154 | .to_text_en(Accuracy::Rough, Tense::Present), 155 | ), 156 | None => Value::None, 157 | }, 158 | ); 159 | 160 | map.insert("num_cores".into(), Value::Integer(n_cores as _)); 161 | 162 | let finished_chains = progress 163 | .iter() 164 | .map(|chain| (chain.finished_draws == chain.total_draws) as u64) 165 | .sum::(); 166 | map.insert( 167 | "finished_chains".into(), 168 | Value::Integer(finished_chains as _), 169 | ); 170 | map.insert( 171 | "running_chains".into(), 172 | Value::Integer( 173 | progress 174 | .iter() 175 | .map(|chain| (chain.started & (chain.finished_draws < chain.total_draws)) as u64) 176 | .sum::() as i64, 177 | ), 178 | ); 179 | map.insert("num_chains".into(), Value::Integer(progress.len() as _)); 180 | map.insert( 181 | "finished".into(), 182 | Value::Bool(progress.len() == finished_chains as usize), 183 | ); 184 | map.insert( 185 | "progress_update_count".into(), 186 | Value::Integer(progress_update_count as i64), 187 | ); 188 | 189 | Value::Map(map) 190 | } 191 | 192 | fn estimate_remaining_time( 193 | n_cores: usize, 194 | time_sampling: Duration, 195 | progress: &[ChainProgress], 196 | ) -> Option { 197 | let finished_draws: u64 = progress 198 | .iter() 199 | .map(|chain| chain.finished_draws as u64) 200 | .sum(); 201 | if finished_draws == 0 { 202 | return None; 203 | } 204 | 205 | let finished_draws = finished_draws as f64; 206 | 207 | // TODO this assumes that so far all cores were used all the time 208 | let time_per_draw = time_sampling.mul_f64((n_cores as f64) / finished_draws); 209 | 210 | let mut core_times = vec![Duration::ZERO; 
n_cores]; 211 | 212 | progress 213 | .iter() 214 | .map(|chain| time_per_draw.mul_f64((chain.total_draws - chain.finished_draws) as f64)) 215 | .for_each(|time| { 216 | let min_index = core_times 217 | .iter() 218 | .enumerate() 219 | .min_by_key(|&(_, v)| v) 220 | .unwrap() 221 | .0; 222 | core_times[min_index] += time; 223 | }); 224 | 225 | Some(core_times.into_iter().max().unwrap_or(Duration::ZERO)) 226 | } 227 | 228 | pub struct IndicatifHandler { 229 | rate: Duration, 230 | } 231 | 232 | impl IndicatifHandler { 233 | pub fn new(rate: Duration) -> Self { 234 | Self { rate } 235 | } 236 | 237 | pub fn into_callback(self) -> Result { 238 | let mut finished = false; 239 | let mut last_draws = 0; 240 | let mut bar = None; 241 | 242 | let callback = move |_time_sampling, progress: Box<[ChainProgress]>| { 243 | let total: u64 = progress.iter().map(|chain| chain.total_draws as u64).sum(); 244 | 245 | if bar.is_none() { 246 | bar = Some(ProgressBar::new(total)); 247 | } 248 | 249 | let Some(ref bar) = bar else { unreachable!() }; 250 | 251 | if finished { 252 | return; 253 | } 254 | if progress 255 | .iter() 256 | .all(|chain| chain.finished_draws == chain.total_draws) 257 | { 258 | finished = true; 259 | bar.set_position(total); 260 | bar.finish(); 261 | } 262 | 263 | let finished_draws: u64 = progress 264 | .iter() 265 | .map(|chain| chain.finished_draws as u64) 266 | .sum(); 267 | 268 | let delta = finished_draws.saturating_sub(last_draws); 269 | if delta > 0 { 270 | bar.set_position(finished_draws); 271 | last_draws = finished_draws; 272 | } 273 | }; 274 | 275 | Ok(ProgressCallback { 276 | callback: Box::new(callback), 277 | rate: self.rate, 278 | }) 279 | } 280 | } 281 | -------------------------------------------------------------------------------- /src/pymc.rs: -------------------------------------------------------------------------------- 1 | use std::{ffi::c_void, fmt::Display, sync::Arc}; 2 | 3 | use anyhow::{bail, Context, Result}; 4 | use arrow::{ 5 | 
array::{Array, Float64Array, LargeListArray, StructArray}, 6 | buffer::OffsetBuffer, 7 | datatypes::{DataType, Field, Fields}, 8 | }; 9 | use itertools::{izip, Itertools}; 10 | use numpy::PyReadonlyArray1; 11 | use nuts_rs::{CpuLogpFunc, CpuMath, DrawStorage, LogpError, Model, Settings}; 12 | use pyo3::{ 13 | pyclass, pymethods, 14 | types::{PyAnyMethods, PyList}, 15 | Bound, Py, PyAny, PyObject, PyResult, Python, 16 | }; 17 | 18 | use rand_distr::num_traits::CheckedEuclid; 19 | use thiserror::Error; 20 | 21 | type UserData = *const std::ffi::c_void; 22 | 23 | type RawLogpFunc = unsafe extern "C" fn( 24 | usize, 25 | *const f64, 26 | *mut f64, 27 | *mut f64, 28 | *const std::ffi::c_void, 29 | ) -> std::os::raw::c_int; 30 | 31 | type RawExpandFunc = unsafe extern "C" fn( 32 | usize, 33 | usize, 34 | *const f64, 35 | *mut f64, 36 | *const std::ffi::c_void, 37 | ) -> std::os::raw::c_int; 38 | 39 | #[pyclass] 40 | #[derive(Clone)] 41 | pub(crate) struct LogpFunc { 42 | func: RawLogpFunc, 43 | _keep_alive: Arc, 44 | user_data_ptr: UserData, 45 | dim: usize, 46 | } 47 | 48 | unsafe impl Send for LogpFunc {} 49 | unsafe impl Sync for LogpFunc {} 50 | 51 | #[pymethods] 52 | impl LogpFunc { 53 | #[new] 54 | fn new(dim: usize, ptr: usize, user_data_ptr: usize, keep_alive: PyObject) -> Self { 55 | let func = 56 | unsafe { std::mem::transmute::<*const c_void, RawLogpFunc>(ptr as *const c_void) }; 57 | Self { 58 | func, 59 | _keep_alive: Arc::new(keep_alive), 60 | user_data_ptr: user_data_ptr as UserData, 61 | dim, 62 | } 63 | } 64 | } 65 | 66 | #[pyclass] 67 | #[derive(Clone)] 68 | pub(crate) struct ExpandFunc { 69 | func: RawExpandFunc, 70 | _keep_alive: Arc, 71 | user_data_ptr: UserData, 72 | dim: usize, 73 | expanded_dim: usize, 74 | } 75 | 76 | #[pymethods] 77 | impl ExpandFunc { 78 | #[new] 79 | fn new( 80 | dim: usize, 81 | expanded_dim: usize, 82 | ptr: usize, 83 | user_data_ptr: usize, 84 | keep_alive: PyObject, 85 | ) -> Self { 86 | let func = 87 | unsafe { 
std::mem::transmute::<*const c_void, RawExpandFunc>(ptr as *const c_void) }; 88 | Self { 89 | dim, 90 | expanded_dim, 91 | _keep_alive: Arc::new(keep_alive), 92 | user_data_ptr: user_data_ptr as UserData, 93 | func, 94 | } 95 | } 96 | } 97 | 98 | unsafe impl Send for ExpandFunc {} 99 | unsafe impl Sync for ExpandFunc {} 100 | 101 | #[derive(Error, Debug)] 102 | pub(crate) struct ErrorCode(std::os::raw::c_int); 103 | 104 | impl Display for ErrorCode { 105 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 106 | write!(f, "Logp function returned error code {}", self.0) 107 | } 108 | } 109 | 110 | impl LogpError for ErrorCode { 111 | fn is_recoverable(&self) -> bool { 112 | self.0 > 0 113 | } 114 | } 115 | 116 | impl CpuLogpFunc for &LogpFunc { 117 | type LogpError = ErrorCode; 118 | type TransformParams = (); 119 | 120 | fn dim(&self) -> usize { 121 | self.dim 122 | } 123 | 124 | fn logp(&mut self, position: &[f64], gradient: &mut [f64]) -> Result { 125 | let mut logp = 0f64; 126 | let logp_ptr = (&mut logp) as *mut f64; 127 | assert!(position.len() == self.dim); 128 | assert!(gradient.len() == self.dim); 129 | let retcode = unsafe { 130 | (self.func)( 131 | self.dim, 132 | position.as_ptr(), 133 | gradient.as_mut_ptr(), 134 | logp_ptr, 135 | self.user_data_ptr, 136 | ) 137 | }; 138 | if retcode == 0 { 139 | return Ok(logp); 140 | } 141 | Err(ErrorCode(retcode)) 142 | } 143 | } 144 | 145 | #[derive(Clone)] 146 | pub(crate) struct PyMcTrace<'model> { 147 | dim: usize, 148 | data: Vec>, 149 | var_sizes: Vec, 150 | var_names: Vec, 151 | expand: &'model ExpandFunc, 152 | count: usize, 153 | } 154 | 155 | impl<'model> DrawStorage for PyMcTrace<'model> { 156 | fn append_value(&mut self, point: &[f64]) -> Result<()> { 157 | assert!(point.len() == self.dim); 158 | 159 | let point = self 160 | .expand_draw(point) 161 | .context("Could not compute deterministic variables")?; 162 | 163 | let mut start: usize = 0; 164 | for (&size, data) in 
self.var_sizes.iter().zip_eq(self.data.iter_mut()) { 165 | let end = start.checked_add(size).unwrap(); 166 | let vals = &point[start..end]; 167 | data.extend_from_slice(vals); 168 | start = end; 169 | } 170 | self.count += 1; 171 | 172 | Ok(()) 173 | } 174 | 175 | fn finalize(self) -> Result> { 176 | let (fields, arrays): (Vec<_>, _) = izip!(self.data, self.var_names, self.var_sizes) 177 | .map(|(data, name, size)| { 178 | let (num_arrays, rem) = data 179 | .len() 180 | .checked_div_rem_euclid(&size) 181 | .unwrap_or((self.count, 0)); 182 | assert!(rem == 0); 183 | assert!(num_arrays == self.count); 184 | let data = Float64Array::from(data); 185 | let item_field = Arc::new(Field::new("item", DataType::Float64, false)); 186 | let offsets = OffsetBuffer::from_lengths((0..num_arrays).map(|_| size)); 187 | let array = LargeListArray::new(item_field.clone(), offsets, Arc::new(data), None); 188 | let field = Field::new(name, DataType::LargeList(item_field), false); 189 | (Arc::new(field), Arc::new(array) as Arc) 190 | }) 191 | .unzip(); 192 | 193 | let fields = Fields::from(fields); 194 | Ok(Arc::new( 195 | StructArray::try_new(fields, arrays, None).context("Could not create arrow struct")?, 196 | )) 197 | } 198 | 199 | fn inspect(&self) -> Result> { 200 | self.clone().finalize() 201 | } 202 | } 203 | 204 | impl<'model> PyMcTrace<'model> { 205 | fn new(model: &'model PyMcModel, settings: &impl Settings) -> Self { 206 | let draws = settings.hint_num_draws() + settings.hint_num_tune(); 207 | Self { 208 | dim: model.dim, 209 | data: model 210 | .var_sizes 211 | .iter() 212 | .map(|&size| Vec::with_capacity(size * draws)) 213 | .collect(), 214 | var_sizes: model.var_sizes.clone(), 215 | var_names: model.var_names.clone(), 216 | expand: &model.expand, 217 | count: 0, 218 | } 219 | } 220 | 221 | fn expand_draw(&mut self, point: &[f64]) -> Result> { 222 | let mut out = vec![0f64; self.expand.expanded_dim].into_boxed_slice(); 223 | let retcode = unsafe { 224 | 
(self.expand.func)( 225 | self.expand.dim, 226 | self.expand.expanded_dim, 227 | point.as_ptr(), 228 | out.as_mut_ptr(), 229 | self.expand.user_data_ptr, 230 | ) 231 | }; 232 | if retcode == 0 { 233 | Ok(out) 234 | } else { 235 | Err(anyhow::Error::msg("Failed to expand a draw.")) 236 | } 237 | } 238 | } 239 | 240 | #[pyclass] 241 | #[derive(Clone)] 242 | pub(crate) struct PyMcModel { 243 | dim: usize, 244 | density: LogpFunc, 245 | expand: ExpandFunc, 246 | init_func: Arc>, 247 | var_sizes: Vec, 248 | var_names: Vec, 249 | } 250 | 251 | #[pymethods] 252 | impl PyMcModel { 253 | #[new] 254 | fn new<'py>( 255 | dim: usize, 256 | density: LogpFunc, 257 | expand: ExpandFunc, 258 | init_func: Py, 259 | var_sizes: &Bound<'py, PyList>, 260 | var_names: &Bound<'py, PyList>, 261 | ) -> PyResult { 262 | Ok(Self { 263 | dim, 264 | density, 265 | expand, 266 | init_func: init_func.into(), 267 | var_names: var_names.extract()?, 268 | var_sizes: var_sizes.extract()?, 269 | }) 270 | } 271 | 272 | /* 273 | fn benchmark_logp<'py>( 274 | &self, 275 | py: Python<'py>, 276 | point: PyReadonlyArray1<'py, f64>, 277 | cores: usize, 278 | evals: usize, 279 | ) -> PyResult<&'py PyList> { 280 | let point = point.to_vec()?; 281 | let durations = py.allow_threads(|| Model::benchmark_logp(self, &point, cores, evals))?; 282 | let out = PyList::new( 283 | py, 284 | durations 285 | .into_iter() 286 | .map(|inner| PyList::new(py, inner.into_iter().map(|d| d.as_secs_f64()))), 287 | ); 288 | Ok(out) 289 | } 290 | */ 291 | } 292 | 293 | impl Model for PyMcModel { 294 | type Math<'model> = CpuMath<&'model LogpFunc>; 295 | 296 | type DrawStorage<'model, S: Settings> = PyMcTrace<'model>; 297 | 298 | fn math(&self) -> Result> { 299 | Ok(CpuMath::new(&self.density)) 300 | } 301 | 302 | fn init_position( 303 | &self, 304 | rng: &mut R, 305 | position: &mut [f64], 306 | ) -> Result<()> { 307 | let seed = rng.next_u64(); 308 | 309 | Python::with_gil(|py| { 310 | let init_point = self 311 | .init_func 312 | 
.call1(py, (seed,)) 313 | .context("Failed to initialize point")?; 314 | 315 | let init_point: PyReadonlyArray1 = init_point 316 | .extract(py) 317 | .context("Initializition array returned incorrect argument")?; 318 | 319 | let init_point = init_point 320 | .as_slice() 321 | .context("Initial point must be contiguous")?; 322 | 323 | if init_point.len() != position.len() { 324 | bail!("Initial point has incorrect length"); 325 | } 326 | 327 | position.copy_from_slice(init_point); 328 | Ok(()) 329 | })?; 330 | Ok(()) 331 | } 332 | 333 | fn new_trace<'model, S: Settings, R: rand::prelude::Rng + ?Sized>( 334 | &'model self, 335 | _rng: &mut R, 336 | _chain_id: u64, 337 | settings: &'model S, 338 | ) -> Result> { 339 | Ok(PyMcTrace::new(self, settings)) 340 | } 341 | } 342 | -------------------------------------------------------------------------------- /tests/reference/test_deterministic_sampling_jax.txt: -------------------------------------------------------------------------------- 1 | 0.941959 2 | 0.559649 3 | 0.534203 4 | 0.561444 5 | 0.561444 6 | 0.418685 7 | 0.827896 8 | 0.847014 9 | 0.738508 10 | 0.961291 11 | 0.923931 12 | 1.00584 13 | 1.16386 14 | 1.10065 15 | 1.6348 16 | 1.13139 17 | 0.993458 18 | 0.993458 19 | 0.966241 20 | 1.10922 21 | 1.10922 22 | 1.05723 23 | 1.05723 24 | 2.32492 25 | 0.0700824 26 | 0.0860656 27 | 1.36431 28 | 0.829624 29 | 0.584658 30 | 0.531506 31 | 0.507961 32 | 0.543701 33 | 0.510104 34 | 2.46898 35 | 0.820341 36 | 0.490474 37 | 0.343958 38 | 0.300549 39 | 2.60267 40 | 0.588131 41 | 0.430013 42 | 0.618032 43 | 1.27527 44 | 1.80449 45 | 1.80449 46 | 0.855217 47 | 0.556106 48 | 1.77619 49 | 2.03761 50 | 1.02106 51 | 0.774811 52 | 1.78438 53 | 1.61398 54 | 0.712683 55 | 1.04966 56 | 1.17936 57 | 1.5425 58 | 1.5425 59 | 1.26262 60 | 1.39659 61 | 0.337024 62 | 0.177694 63 | 0.0424286 64 | 0.180403 65 | 0.140553 66 | 0.367095 67 | 0.348732 68 | 0.341436 69 | 1.82764 70 | 0.692738 71 | 0.629186 72 | 0.245706 73 | 0.732305 74 | 0.56873 75 
| 0.498757 76 | 0.204131 77 | 0.417031 78 | 0.184895 79 | 0.208768 80 | 0.238139 81 | 1.95089 82 | 1.95089 83 | 0.593379 84 | 0.593379 85 | 0.750063 86 | 0.69929 87 | 0.490359 88 | 0.478709 89 | 0.361632 90 | 0.346159 91 | 0.728965 92 | 1.58228 93 | 0.985676 94 | 1.58468 95 | 0.709012 96 | 0.700483 97 | 0.805006 98 | 1.70347 99 | 1.26293 100 | 1.24837 101 | 0.23989 102 | 0.881025 103 | 1.39084 104 | 1.37812 105 | 0.969265 106 | 0.969265 107 | 0.938487 108 | 0.846447 109 | 1.61945 110 | 0.108473 111 | 0.173496 112 | 0.897353 113 | 0.455899 114 | 0.571886 115 | 0.891672 116 | 0.891672 117 | 0.864419 118 | 0.739099 119 | 1.49009 120 | 1.49009 121 | 0.385499 122 | 0.228701 123 | 1.83156 124 | 1.83156 125 | 0.947635 126 | 0.805623 127 | 0.714762 128 | 0.853477 129 | 1.45906 130 | 0.908818 131 | 0.540951 132 | 1.40995 133 | 1.22564 134 | 0.26496 135 | 0.159994 136 | 0.423836 137 | 0.350158 138 | 0.388884 139 | 1.39507 140 | 0.727701 141 | 1.80674 142 | 0.466389 143 | 1.61574 144 | 1.61574 145 | 0.42774 146 | 0.217983 147 | 0.14579 148 | 1.01321 149 | 1.01321 150 | 1.19713 151 | 0.390791 152 | 0.223687 153 | 0.149019 154 | 0.103866 155 | 0.153768 156 | 0.12942 157 | 0.346371 158 | 0.814553 159 | 2.41042 160 | 0.42739 161 | 0.322291 162 | 0.248911 163 | 0.854404 164 | 1.35372 165 | 1.35372 166 | 2.00546 167 | 0.0457881 168 | 0.0415644 169 | 0.0797551 170 | 0.0913076 171 | 0.070948 172 | 0.00993872 173 | 0.421448 174 | 0.550377 175 | 0.609387 176 | 0.490487 177 | 2.6607 178 | 0.32804 179 | 0.385999 180 | 0.497294 181 | 1.67109 182 | 1.14328 183 | 1.14328 184 | 0.903063 185 | 0.903063 186 | 0.903063 187 | 0.691269 188 | 2.00151 189 | 0.587672 190 | 0.79679 191 | 1.35563 192 | 0.598471 193 | 0.681826 194 | 0.818296 195 | 1.14265 196 | 0.113094 197 | 0.250861 198 | 0.284491 199 | 0.00420445 200 | 0.00566936 201 | -------------------------------------------------------------------------------- /tests/reference/test_deterministic_sampling_numba.txt: 
-------------------------------------------------------------------------------- 1 | 0.862203 2 | 0.743827 3 | 0.985284 4 | 0.864159 5 | 1.11537 6 | 1.46228 7 | 1.46228 8 | 0.731645 9 | 0.618394 10 | 0.70658 11 | 1.58816 12 | 1.58816 13 | 1.58816 14 | 1.58816 15 | 1.02597 16 | 1.02597 17 | 2.38965 18 | 0.0442154 19 | 0.0556998 20 | 1.20147 21 | 0.878239 22 | 0.595919 23 | 0.542086 24 | 0.520452 25 | 0.56279 26 | 0.539904 27 | 0.129453 28 | 0.136407 29 | 0.408806 30 | 0.34263 31 | 0.929525 32 | 0.947864 33 | 0.947864 34 | 1.94444 35 | 0.911973 36 | 0.429576 37 | 0.776378 38 | 0.452981 39 | 0.985476 40 | 1.74745 41 | 1.74095 42 | 1.74095 43 | 0.9855 44 | 0.886535 45 | 0.617313 46 | 0.86405 47 | 2.00577 48 | 0.839407 49 | 0.745118 50 | 1.49611 51 | 1.74491 52 | 1.40854 53 | 0.631877 54 | 1.95302 55 | 1.01379 56 | 1.1063 57 | 0.930275 58 | 0.315935 59 | 0.225544 60 | 0.136821 61 | 0.180021 62 | 0.498635 63 | 0.462448 64 | 0.445633 65 | 0.0878991 66 | 0.105731 67 | 0.355683 68 | 0.750934 69 | 0.750934 70 | 0.874486 71 | 1.15119 72 | 0.657067 73 | 0.500027 74 | 1.28332 75 | 1.28332 76 | 0.919994 77 | 1.09658 78 | 1.73803 79 | 1.13439 80 | 1.21956 81 | 0.643106 82 | 0.329788 83 | 0.456239 84 | 0.596018 85 | 0.180103 86 | 0.388767 87 | 1.03772 88 | 1.03192 89 | 1.03192 90 | 1.04759 91 | 1.04759 92 | 1.13558 93 | 0.673716 94 | 0.871073 95 | 0.50739 96 | 0.625146 97 | 0.999657 98 | 1.00779 99 | 2.06182 100 | 0.707917 101 | 0.107437 102 | 0.0772623 103 | 0.10719 104 | 0.36616 105 | 0.14863 106 | 0.0333724 107 | 0.0295763 108 | 0.0205304 109 | 0.127619 110 | 0.164319 111 | 0.241143 112 | 0.376838 113 | 0.87369 114 | 1.64165 115 | 0.106128 116 | 0.170459 117 | 0.916833 118 | 0.458599 119 | 0.575215 120 | 0.894488 121 | 0.894488 122 | 0.865427 123 | 0.739365 124 | 0.681649 125 | 0.72888 126 | 1.38352 127 | 1.38352 128 | 2.28238 129 | 2.28238 130 | 2.28238 131 | 0.567775 132 | 0.41864 133 | 1.41709 134 | 1.41709 135 | 1.41709 136 | 1.41709 137 | 0.600311 138 | 0.598689 139 | 
0.627731 140 | 0.460137 141 | 1.86219 142 | 1.81783 143 | 1.78092 144 | 1.78092 145 | 1.78092 146 | 0.492732 147 | 1.37953 148 | 1.16762 149 | 0.597573 150 | 0.627465 151 | 0.617661 152 | 0.649115 153 | 0.608255 154 | 0.685365 155 | 0.685365 156 | 0.685365 157 | 0.685365 158 | 0.685365 159 | 2.2227 160 | 0.971606 161 | 0.4219 162 | 0.879055 163 | 0.74434 164 | 2.08679 165 | 1.34952 166 | 1.34952 167 | 1.34952 168 | 1.34952 169 | 0.513284 170 | 0.16734 171 | 0.174037 172 | 0.626756 173 | 0.913504 174 | 0.271423 175 | 0.200176 176 | 0.132462 177 | 0.465497 178 | 0.406755 179 | 0.493296 180 | 0.0175891 181 | 0.0234891 182 | 0.0220327 183 | 0.132404 184 | 0.0788943 185 | 0.0949265 186 | 0.103031 187 | 0.0760492 188 | 0.377155 189 | 1.90599 190 | 1.58063 191 | 1.58063 192 | 1.17038 193 | 0.556726 194 | 0.55085 195 | 0.24632 196 | 0.375951 197 | 0.339243 198 | 0.747524 199 | 1.82921 200 | 0.794344 201 | -------------------------------------------------------------------------------- /tests/reference/test_deterministic_sampling_stan.txt: -------------------------------------------------------------------------------- 1 | 1.21572 1.03376 1.60518 1.60518 1.59553 1.35023 0.761056 1.41688 1.41688 1.41688 2 | 0.252389 0.999663 0.999663 0.999663 0.740026 0.387763 0.944247 0.289785 1.52909 0.683129 3 | -------------------------------------------------------------------------------- /tests/reference/test_normalizing_flow.txt: -------------------------------------------------------------------------------- 1 | 0.324871 2 | 1.16777 3 | 0.102039 4 | 0.0579082 5 | 0.985197 6 | 0.550663 7 | 0.929168 8 | 0.543959 9 | 0.166275 10 | 0.359855 11 | 0.764495 12 | 1.77769 13 | 0.462447 14 | 0.984399 15 | 0.490158 16 | 1.02799 17 | 0.702622 18 | 0.473246 19 | 0.0127807 20 | 1.13249 21 | 1.0929 22 | 0.357403 23 | 1.27519 24 | 0.842248 25 | 1.00152 26 | 2.38541 27 | 0.854202 28 | 0.00735577 29 | 0.218296 30 | 2.20921 31 | 1.74756 32 | 0.119245 33 | 1.74756 34 | 0.119245 35 | 0.381874 36 | 
1.82536 37 | 1.29837 38 | 2.52243 39 | 0.86956 40 | 0.0373546 41 | 0.105311 42 | 0.706774 43 | 0.217778 44 | 0.700421 45 | 0.858623 46 | 1.65319 47 | 0.499877 48 | 0.832728 49 | 2.06511 50 | 1.26955 51 | 0.334041 52 | 0.681329 53 | 1.12741 54 | 0.21517 55 | 1.04719 56 | 0.269313 57 | 0.512924 58 | 2.31191 59 | 0.374169 60 | 0.633086 61 | 0.374169 62 | 0.633086 63 | 0.43021 64 | 1.19342 65 | 0.101336 66 | 0.323738 67 | 0.147408 68 | 1.16919 69 | 0.0614138 70 | 1.33695 71 | 1.54323 72 | 0.199492 73 | 0.351604 74 | 0.396807 75 | 1.05526 76 | 1.62499 77 | 0.266035 78 | 0.54486 79 | 0.861217 80 | 0.621417 81 | 0.416124 82 | 0.64435 83 | 0.430111 84 | 0.634412 85 | 0.614077 86 | 0.986153 87 | 0.76649 88 | 0.549664 89 | 0.353417 90 | 1.19541 91 | 0.467103 92 | 1.8358 93 | 1.04574 94 | 0.438734 95 | 0.641016 96 | 0.699005 97 | 1.69749 98 | 0.317435 99 | 0.175226 100 | 0.739452 101 | -------------------------------------------------------------------------------- /tests/test_pymc.py: -------------------------------------------------------------------------------- 1 | from importlib.util import find_spec 2 | import time 3 | import pytest 4 | 5 | if find_spec("pymc") is None: 6 | pytest.skip("Skip pymc tests", allow_module_level=True) 7 | 8 | import numpy as np 9 | import pymc as pm 10 | import pytest 11 | 12 | import nutpie 13 | import nutpie.compile_pymc 14 | 15 | parameterize_backends = pytest.mark.parametrize( 16 | "backend, gradient_backend", 17 | [("numba", None), ("jax", "pytensor"), ("jax", "jax")], 18 | ) 19 | 20 | 21 | @pytest.mark.pymc 22 | @parameterize_backends 23 | def test_pymc_model(backend, gradient_backend): 24 | with pm.Model() as model: 25 | pm.Normal("a") 26 | 27 | compiled = nutpie.compile_pymc_model( 28 | model, backend=backend, gradient_backend=gradient_backend 29 | ) 30 | trace = nutpie.sample(compiled, chains=1) 31 | trace.posterior.a # noqa: B018 32 | 33 | 34 | @pytest.mark.pymc 35 | @parameterize_backends 36 | def test_name_x(backend, 
gradient_backend): 37 | with pm.Model() as model: 38 | x = pm.Data("x", 1.0) 39 | a = pm.Normal("a", mu=x) 40 | pm.Deterministic("z", x * a) 41 | 42 | compiled = nutpie.compile_pymc_model( 43 | model, backend=backend, gradient_backend=gradient_backend, freeze_model=False 44 | ) 45 | trace = nutpie.sample(compiled, chains=1) 46 | trace.posterior.a # noqa: B018 47 | 48 | 49 | @pytest.mark.pymc 50 | def test_order_shared(): 51 | a_val = np.array([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]) 52 | with pm.Model() as model: 53 | a = pm.Data("a", np.copy(a_val, order="C")) 54 | b = pm.Normal("b", shape=(2, 5)) 55 | pm.Deterministic("c", (a[:, None, :] * b[:, :, None]).sum(-1)) 56 | 57 | compiled = nutpie.compile_pymc_model(model, backend="numba") 58 | trace = nutpie.sample(compiled) 59 | np.testing.assert_allclose( 60 | ( 61 | trace.posterior.b.values[:, :, :, :, None] * a_val[None, None, :, None, :] 62 | ).sum(-1), 63 | trace.posterior.c.values, 64 | ) 65 | 66 | with pm.Model() as model: 67 | a = pm.Data("a", np.copy(a_val, order="F")) 68 | b = pm.Normal("b", shape=(2, 5)) 69 | pm.Deterministic("c", (a[:, None, :] * b[:, :, None]).sum(-1)) 70 | 71 | compiled = nutpie.compile_pymc_model(model, backend="numba") 72 | trace = nutpie.sample(compiled) 73 | np.testing.assert_allclose( 74 | ( 75 | trace.posterior.b.values[:, :, :, :, None] * a_val[None, None, :, None, :] 76 | ).sum(-1), 77 | trace.posterior.c.values, 78 | ) 79 | 80 | 81 | @pytest.mark.pymc 82 | @parameterize_backends 83 | def test_low_rank(backend, gradient_backend): 84 | with pm.Model() as model: 85 | pm.Normal("a") 86 | 87 | compiled = nutpie.compile_pymc_model( 88 | model, backend=backend, gradient_backend=gradient_backend 89 | ) 90 | trace = nutpie.sample(compiled, chains=1, low_rank_modified_mass_matrix=True) 91 | trace.posterior.a # noqa: B018 92 | 93 | 94 | @pytest.mark.pymc 95 | @parameterize_backends 96 | def test_low_rank_half_normal(backend, gradient_backend): 97 | with pm.Model() as model: 98 | 
pm.HalfNormal("a", shape=13) 99 | 100 | compiled = nutpie.compile_pymc_model( 101 | model, backend=backend, gradient_backend=gradient_backend 102 | ) 103 | trace = nutpie.sample(compiled, chains=1, low_rank_modified_mass_matrix=True) 104 | trace.posterior.a # noqa: B018 105 | 106 | 107 | @pytest.mark.pymc 108 | @parameterize_backends 109 | def test_zero_size(backend, gradient_backend): 110 | import pytensor.tensor as pt 111 | 112 | with pm.Model() as model: 113 | a = pm.Normal("a", shape=(0, 0, 10)) 114 | pm.Deterministic("b", pt.exp(a)) 115 | 116 | compiled = nutpie.compile_pymc_model( 117 | model, backend=backend, gradient_backend=gradient_backend 118 | ) 119 | trace = nutpie.sample(compiled, chains=1, draws=17, tune=100) 120 | assert trace.posterior.a.shape == (1, 17, 0, 0, 10) 121 | assert trace.posterior.b.shape == (1, 17, 0, 0, 10) 122 | 123 | 124 | @pytest.mark.pymc 125 | @parameterize_backends 126 | def test_pymc_model_float32(backend, gradient_backend): 127 | import pytensor 128 | 129 | with pytensor.config.change_flags(floatX="float32"): 130 | with pm.Model() as model: 131 | pm.Normal("a") 132 | 133 | compiled = nutpie.compile_pymc_model( 134 | model, backend=backend, gradient_backend=gradient_backend 135 | ) 136 | trace = nutpie.sample(compiled, chains=1) 137 | trace.posterior.a # noqa: B018 138 | 139 | 140 | @pytest.mark.pymc 141 | @parameterize_backends 142 | def test_pymc_model_no_prior(backend, gradient_backend): 143 | with pm.Model() as model: 144 | a = pm.Flat("a") 145 | pm.Normal("b", mu=a, observed=0.0) 146 | 147 | compiled = nutpie.compile_pymc_model( 148 | model, backend=backend, gradient_backend=gradient_backend 149 | ) 150 | trace = nutpie.sample(compiled, chains=1) 151 | trace.posterior.a # noqa: B018 152 | 153 | 154 | @pytest.mark.pymc 155 | @parameterize_backends 156 | def test_blocking(backend, gradient_backend): 157 | with pm.Model() as model: 158 | pm.Normal("a") 159 | 160 | compiled = nutpie.compile_pymc_model( 161 | model, 
backend=backend, gradient_backend=gradient_backend 162 | ) 163 | sampler = nutpie.sample(compiled, chains=1, blocking=False) 164 | trace = sampler.wait() 165 | trace.posterior.a # noqa: B018 166 | 167 | 168 | @pytest.mark.pymc 169 | @parameterize_backends 170 | @pytest.mark.timeout(20) 171 | def test_wait_timeout(backend, gradient_backend): 172 | with pm.Model() as model: 173 | pm.Normal("a", shape=100_000) 174 | compiled = nutpie.compile_pymc_model( 175 | model, backend=backend, gradient_backend=gradient_backend 176 | ) 177 | start = time.time() 178 | sampler = nutpie.sample(compiled, chains=1, blocking=False) 179 | with pytest.raises(TimeoutError): 180 | sampler.wait(timeout=0.1) 181 | sampler.cancel() 182 | assert start - time.time() < 5 183 | 184 | 185 | @pytest.mark.pymc 186 | @parameterize_backends 187 | @pytest.mark.timeout(20) 188 | def test_pause(backend, gradient_backend): 189 | with pm.Model() as model: 190 | pm.Normal("a", shape=100_000) 191 | compiled = nutpie.compile_pymc_model( 192 | model, backend=backend, gradient_backend=gradient_backend 193 | ) 194 | start = time.time() 195 | sampler = nutpie.sample(compiled, chains=1, blocking=False) 196 | sampler.pause() 197 | sampler.resume() 198 | sampler.cancel() 199 | assert start - time.time() < 5 200 | 201 | 202 | @pytest.mark.pymc 203 | @parameterize_backends 204 | def test_pymc_model_with_coordinate(backend, gradient_backend): 205 | with pm.Model() as model: 206 | model.add_coord("foo", length=5) 207 | pm.Normal("a", dims="foo") 208 | 209 | compiled = nutpie.compile_pymc_model( 210 | model, backend=backend, gradient_backend=gradient_backend 211 | ) 212 | trace = nutpie.sample(compiled, chains=1) 213 | trace.posterior.a # noqa: B018 214 | 215 | 216 | @pytest.mark.pymc 217 | @parameterize_backends 218 | def test_pymc_model_store_extra(backend, gradient_backend): 219 | with pm.Model() as model: 220 | model.add_coord("foo", length=5) 221 | pm.Normal("a", dims="foo") 222 | 223 | compiled = 
nutpie.compile_pymc_model( 224 | model, backend=backend, gradient_backend=gradient_backend 225 | ) 226 | trace = nutpie.sample( 227 | compiled, 228 | chains=1, 229 | store_mass_matrix=True, 230 | store_divergences=True, 231 | store_unconstrained=True, 232 | store_gradient=True, 233 | ) 234 | trace.posterior.a # noqa: B018 235 | _ = trace.sample_stats.unconstrained_draw 236 | _ = trace.sample_stats.gradient 237 | _ = trace.sample_stats.divergence_start 238 | _ = trace.sample_stats.mass_matrix_inv 239 | 240 | 241 | @pytest.mark.pymc 242 | @parameterize_backends 243 | def test_trafo(backend, gradient_backend): 244 | with pm.Model() as model: 245 | pm.Uniform("a") 246 | 247 | compiled = nutpie.compile_pymc_model( 248 | model, backend=backend, gradient_backend=gradient_backend 249 | ) 250 | trace = nutpie.sample(compiled, chains=1) 251 | trace.posterior.a # noqa: B018 252 | 253 | 254 | @pytest.mark.pymc 255 | @parameterize_backends 256 | def test_det(backend, gradient_backend): 257 | with pm.Model() as model: 258 | a = pm.Uniform("a", shape=2) 259 | pm.Deterministic("b", 2 * a) 260 | 261 | compiled = nutpie.compile_pymc_model( 262 | model, backend=backend, gradient_backend=gradient_backend 263 | ) 264 | trace = nutpie.sample(compiled, chains=1) 265 | assert trace.posterior.a.shape[-1] == 2 266 | assert trace.posterior.b.shape[-1] == 2 267 | 268 | 269 | @pytest.mark.pymc 270 | @parameterize_backends 271 | def test_non_identifier_names(backend, gradient_backend): 272 | with pm.Model() as model: 273 | a = pm.Uniform("a/b", shape=2) 274 | with pm.Model("foo"): 275 | c = pm.Data("c", np.array([2.0, 3.0])) 276 | pm.Deterministic("b", c * a) 277 | 278 | compiled = nutpie.compile_pymc_model( 279 | model, backend=backend, gradient_backend=gradient_backend 280 | ) 281 | trace = nutpie.sample(compiled, chains=1) 282 | assert trace.posterior["a/b"].shape[-1] == 2 283 | assert trace.posterior["foo::b"].shape[-1] == 2 284 | 285 | 286 | @pytest.mark.pymc 287 | @parameterize_backends 
288 | def test_pymc_model_shared(backend, gradient_backend): 289 | with pm.Model() as model: 290 | mu = pm.Data("mu", -0.1) 291 | sigma = pm.Data("sigma", np.ones(3)) 292 | pm.Normal("a", mu=mu, sigma=sigma, shape=3) 293 | 294 | compiled = nutpie.compile_pymc_model( 295 | model, 296 | backend=backend, 297 | gradient_backend=gradient_backend, 298 | freeze_model=False, 299 | ) 300 | trace = nutpie.sample(compiled, chains=1, seed=1) 301 | np.testing.assert_allclose(trace.posterior.a.mean().values, -0.1, atol=0.05) 302 | 303 | compiled2 = compiled.with_data(mu=10.0, sigma=3 * np.ones(3)) 304 | trace2 = nutpie.sample(compiled2, chains=1, seed=1) 305 | np.testing.assert_allclose(trace2.posterior.a.mean().values, 10.0, atol=0.5) 306 | 307 | compiled3 = compiled.with_data(mu=0.5, sigma=3 * np.ones(4)) 308 | with pytest.raises(RuntimeError): 309 | nutpie.sample(compiled3, chains=1) 310 | 311 | 312 | @pytest.mark.pymc 313 | @parameterize_backends 314 | def test_pymc_var_names(backend, gradient_backend): 315 | with pm.Model() as model: 316 | mu = pm.Data("mu", -0.1) 317 | sigma = pm.Data("sigma", np.ones(3)) 318 | a = pm.Normal("a", mu=mu, sigma=sigma, shape=3) 319 | 320 | b = pm.Deterministic("b", mu * a) 321 | pm.Deterministic("c", mu * b) 322 | 323 | compiled = nutpie.compile_pymc_model( 324 | model, 325 | backend=backend, 326 | gradient_backend=gradient_backend, 327 | var_names=None, 328 | ) 329 | trace = nutpie.sample(compiled, chains=1, seed=1) 330 | 331 | # Check that variables are stored 332 | assert hasattr(trace.posterior, "b") 333 | assert hasattr(trace.posterior, "c") 334 | 335 | compiled = nutpie.compile_pymc_model( 336 | model, 337 | backend=backend, 338 | gradient_backend=gradient_backend, 339 | var_names=[], 340 | ) 341 | trace = nutpie.sample(compiled, chains=1, seed=1) 342 | 343 | # Check that variables are stored 344 | assert not hasattr(trace.posterior, "b") 345 | assert not hasattr(trace.posterior, "c") 346 | 347 | compiled = nutpie.compile_pymc_model( 
348 | model, 349 | backend=backend, 350 | gradient_backend=gradient_backend, 351 | var_names=["b"], 352 | ) 353 | trace = nutpie.sample(compiled, chains=1, seed=1) 354 | 355 | # Check that variables are stored 356 | assert hasattr(trace.posterior, "b") 357 | assert not hasattr(trace.posterior, "c") 358 | 359 | 360 | # TODO For some reason, the sampling results with jax are 361 | # not reproducible accross operating systems. Figure this 362 | # out and add the array_compare marker. 363 | # @pytest.mark.array_compare 364 | @pytest.mark.pymc 365 | @pytest.mark.flow 366 | def test_normalizing_flow(): 367 | with pm.Model() as model: 368 | pm.HalfNormal("x", shape=2) 369 | 370 | compiled = nutpie.compile_pymc_model( 371 | model, backend="jax", gradient_backend="jax" 372 | ).with_transform_adapt( 373 | verbose=True, 374 | num_layers=2, 375 | ) 376 | trace = nutpie.sample( 377 | compiled, 378 | chains=1, 379 | transform_adapt=True, 380 | window_switch_freq=128, 381 | seed=1, 382 | draws=500, 383 | ) 384 | assert float(trace.sample_stats.fisher_distance.mean()) < 0.1 385 | # return trace.posterior.x.isel(draw=slice(-50, None)).values.ravel() 386 | 387 | 388 | @pytest.mark.pymc 389 | @pytest.mark.parametrize( 390 | ("backend", "gradient_backend"), 391 | [ 392 | ("numba", None), 393 | pytest.param( 394 | "jax", 395 | "pytensor", 396 | marks=pytest.mark.xfail( 397 | reason="https://github.com/pymc-devs/pytensor/issues/853" 398 | ), 399 | ), 400 | pytest.param( 401 | "jax", 402 | "jax", 403 | marks=pytest.mark.xfail( 404 | reason="https://github.com/pymc-devs/pytensor/issues/853" 405 | ), 406 | ), 407 | ], 408 | ) 409 | def test_missing(backend, gradient_backend): 410 | with pm.Model(coords={"obs": range(4)}) as model: 411 | mu = pm.Normal("mu") 412 | y = pm.Normal("y", mu, observed=[0, -1, 1, np.nan], dims="obs") 413 | pm.Deterministic("y2", 2 * y, dims="obs") 414 | 415 | compiled = nutpie.compile_pymc_model( 416 | model, backend=backend, gradient_backend=gradient_backend 417 
| ) 418 | tr = nutpie.sample(compiled, chains=1, seed=1) 419 | print(tr.posterior) 420 | assert hasattr(tr.posterior, "y_unobserved") 421 | 422 | 423 | @pytest.mark.pymc 424 | @pytest.mark.array_compare 425 | def test_deterministic_sampling_numba(): 426 | with pm.Model() as model: 427 | pm.HalfNormal("a") 428 | 429 | compiled = nutpie.compile_pymc_model(model, backend="numba") 430 | trace = nutpie.sample(compiled, chains=2, seed=123, draws=100, tune=100) 431 | return trace.posterior.a.values.ravel() 432 | 433 | 434 | @pytest.mark.pymc 435 | @pytest.mark.array_compare 436 | def test_deterministic_sampling_jax(): 437 | with pm.Model() as model: 438 | pm.HalfNormal("a") 439 | 440 | compiled = nutpie.compile_pymc_model(model, backend="jax", gradient_backend="jax") 441 | trace = nutpie.sample(compiled, chains=2, seed=123, draws=100, tune=100) 442 | return trace.posterior.a.values.ravel() 443 | -------------------------------------------------------------------------------- /tests/test_stan.py: -------------------------------------------------------------------------------- 1 | from importlib.util import find_spec 2 | import pytest 3 | 4 | if find_spec("bridgestan") is None: 5 | pytest.skip("Skip stan tests", allow_module_level=True) 6 | 7 | import numpy as np 8 | import pytest 9 | 10 | import nutpie 11 | 12 | 13 | @pytest.mark.stan 14 | def test_stan_model(): 15 | model = """ 16 | data {} 17 | parameters { 18 | real a; 19 | } 20 | model { 21 | a ~ normal(0, 1); 22 | } 23 | """ 24 | 25 | compiled_model = nutpie.compile_stan_model(code=model) 26 | trace = nutpie.sample(compiled_model) 27 | trace.posterior.a # noqa: B018 28 | 29 | 30 | @pytest.mark.stan 31 | def test_stan_model_low_rank(): 32 | model = """ 33 | data {} 34 | parameters { 35 | real a; 36 | } 37 | model { 38 | a ~ normal(0, 1); 39 | } 40 | """ 41 | 42 | compiled_model = nutpie.compile_stan_model(code=model) 43 | trace = nutpie.sample(compiled_model, low_rank_modified_mass_matrix=True) 44 | trace.posterior.a # 
noqa: B018 45 | 46 | 47 | @pytest.mark.stan 48 | def test_empty(): 49 | model = """ 50 | data {} 51 | parameters { 52 | array[0] real a; 53 | } 54 | model { 55 | a ~ normal(0, 1); 56 | } 57 | """ 58 | 59 | compiled_model = nutpie.compile_stan_model(code=model) 60 | nutpie.sample(compiled_model) 61 | # TODO: Variable `a` is missing because of this bridgestan issue: 62 | # https://github.com/roualdes/bridgestan/issues/278 63 | # assert trace.posterior.a.shape == (0, 1000) 64 | 65 | 66 | @pytest.mark.stan 67 | def test_seed(): 68 | model = """ 69 | data {} 70 | parameters { 71 | real a; 72 | } 73 | model { 74 | a ~ normal(0, 1); 75 | } 76 | generated quantities { 77 | real b = normal_rng(0, 1); 78 | } 79 | """ 80 | 81 | compiled_model = nutpie.compile_stan_model(code=model) 82 | trace = nutpie.sample(compiled_model, seed=42) 83 | trace2 = nutpie.sample(compiled_model, seed=42) 84 | trace3 = nutpie.sample(compiled_model, seed=43) 85 | 86 | assert np.allclose(trace.posterior.a, trace2.posterior.a) 87 | assert np.allclose(trace.posterior.b, trace2.posterior.b) 88 | 89 | assert not np.allclose(trace.posterior.a, trace3.posterior.a) 90 | assert not np.allclose(trace.posterior.b, trace3.posterior.b) 91 | # Check that all chains are pairwise different 92 | for i in range(len(trace.posterior.a)): 93 | for j in range(i + 1, len(trace.posterior.a)): 94 | assert not np.allclose(trace.posterior.a[i], trace.posterior.a[j]) 95 | assert not np.allclose(trace.posterior.b[i], trace.posterior.b[j]) 96 | # Check that all chains are pairwise different between seeds 97 | for i in range(len(trace.posterior.a)): 98 | for j in range(len(trace3.posterior.a)): 99 | assert not np.allclose(trace.posterior.a[i], trace3.posterior.a[j]) 100 | assert not np.allclose(trace.posterior.b[i], trace3.posterior.b[j]) 101 | 102 | 103 | @pytest.mark.stan 104 | def test_nested(): 105 | # Adapted from 106 | # https://github.com/stan-dev/stanio/blob/main/test/data/tuples/output.stan 107 | model = """ 108 | 
parameters { 109 | real a; 110 | } 111 | model { 112 | a ~ normal(0, 1); 113 | } 114 | generated quantities { 115 | real base = normal_rng(0, 1); 116 | int base_i = to_int(normal_rng(10, 10)); 117 | 118 | tuple(real, real) pair = (base, base * 2); 119 | 120 | tuple(real, tuple(int, complex)) nested = (base * 3, (base_i, base * 4.0i)); 121 | array[2] tuple(real, real) arr_pair = {pair, (base * 5, base * 6)}; 122 | 123 | array[3] tuple(tuple(real, tuple(int, complex)), real) arr_very_nested 124 | = {(nested, base*7), ((base*8, (base_i*2, base*9.0i)), base * 10), (nested, base*11)}; 125 | 126 | array[3,2] tuple(real, real) arr_2d_pair = {{(base * 12, base * 13), (base * 14, base * 15)}, 127 | {(base * 16, base * 17), (base * 18, base * 19)}, 128 | {(base * 20, base * 21), (base * 22, base * 23)}}; 129 | 130 | real basep1 = base + 1, basep2 = base + 2; 131 | real basep3 = base + 3, basep4 = base + 4, basep5 = base + 5; 132 | array[2,3] tuple(array[2] tuple(real, vector[2]), matrix[4,5]) ultimate = 133 | { 134 | {( 135 | {(base, [base *2, base *3]'), (base *4, [base*5, base*6]')}, 136 | to_matrix(linspaced_vector(20, 7, 11), 4, 5) * base 137 | ), 138 | ( 139 | {(basep1, [basep1 *2, basep1 *3]'), (basep1 *4, [basep1*5, basep1*6]')}, 140 | to_matrix(linspaced_vector(20, 7, 11), 4, 5) * basep1 141 | ), 142 | ( 143 | {(basep2, [basep2 *2, basep2 *3]'), (basep2 *4, [basep2*5, basep2*6]')}, 144 | to_matrix(linspaced_vector(20, 7, 11), 4, 5) * basep2 145 | ) 146 | }, 147 | {( 148 | {(basep3, [basep3 *2, basep3 *3]'), (basep3 *4, [basep3*5, basep3*6]')}, 149 | to_matrix(linspaced_vector(20, 7, 11), 4, 5) * basep3 150 | ), 151 | ( 152 | {(basep4, [basep4 *2, basep4 *3]'), (basep4 *4, [basep4*5, basep4*6]')}, 153 | to_matrix(linspaced_vector(20, 7, 11), 4, 5) * basep4 154 | ), 155 | ( 156 | {(basep5, [basep5 *2, basep5 *3]'), (basep5 *4, [basep5*5, basep5*6]')}, 157 | to_matrix(linspaced_vector(20, 7, 11), 4, 5) * basep5 158 | ) 159 | }}; 160 | } 161 | """ 162 | 163 | compiled = 
nutpie.compile_stan_model(code=model) 164 | tr = nutpie.sample(compiled, chains=6) 165 | base = tr.posterior.base 166 | 167 | assert np.allclose(tr.posterior["nested:2:2.imag"], 4 * base) 168 | assert np.allclose(tr.posterior["nested:2:2.real"], 0.0) 169 | 170 | assert np.allclose(tr.posterior["ultimate.1.1:1.1:1"], base) 171 | assert np.allclose(tr.posterior["ultimate.1.2:1.1:1"], base + 1) 172 | assert np.allclose(tr.posterior["ultimate.1.3:1.1:1"], base + 2) 173 | assert np.allclose(tr.posterior["ultimate.2.1:1.1:1"], base + 3) 174 | assert np.allclose(tr.posterior["ultimate.2.2:1.1:1"], base + 4) 175 | assert np.allclose(tr.posterior["ultimate.2.3:1.1:1"], base + 5) 176 | 177 | assert tr.posterior["ultimate.2.1:1.1:2"].shape == (6, 1000, 2) 178 | assert np.allclose( 179 | tr.posterior["ultimate.2.3:1.1:2"].values[:, :, 0], 2 * (base + 5) 180 | ) 181 | assert np.allclose( 182 | tr.posterior["ultimate.2.3:1.1:2"].values[:, :, 1], 3 * (base + 5) 183 | ) 184 | assert np.allclose(tr.posterior["base_i"], tr.posterior.base_i.astype(int)) 185 | 186 | 187 | @pytest.mark.stan 188 | def test_stan_model_data(): 189 | model = """ 190 | data { 191 | complex x; 192 | } 193 | parameters { 194 | real a; 195 | } 196 | model { 197 | a ~ normal(0, 1); 198 | } 199 | """ 200 | 201 | compiled_model = nutpie.compile_stan_model(code=model) 202 | with pytest.raises(RuntimeError): 203 | trace = nutpie.sample(compiled_model) 204 | trace = nutpie.sample(compiled_model.with_data(x=np.array(3.0j))) 205 | trace.posterior.a # noqa: B018 206 | 207 | 208 | @pytest.mark.stan 209 | def test_stan_memory_order(): 210 | model = """ 211 | data { 212 | real x; 213 | } 214 | parameters { 215 | real a; 216 | } 217 | model { 218 | a ~ normal(0, 1); 219 | } 220 | generated quantities { 221 | array[2, 3] matrix[5, 7] b; 222 | real count = 0; 223 | for (i in 1:2) 224 | for (j in 1:3) { 225 | for (k in 1:5) { 226 | for (n in 1:7) { 227 | b[i, j][k, n] = count; 228 | count = count + 1; 229 | } 230 | } 231 | } 
232 | } 233 | """ 234 | 235 | compiled_model = nutpie.compile_stan_model(code=model) 236 | with pytest.raises(RuntimeError): 237 | trace = nutpie.sample(compiled_model) 238 | trace = nutpie.sample(compiled_model.with_data(x=np.array(3.0))) 239 | trace.posterior.a # noqa: B018 240 | assert trace.posterior.b.shape == (6, 1000, 2, 3, 5, 7) 241 | b = trace.posterior.b.isel(chain=0, draw=0) 242 | count = 0 243 | for i in range(2): 244 | for j in range(3): 245 | for k in range(5): 246 | for n in range(7): 247 | assert float(b[i, j, k, n]) == count 248 | count += 1 249 | 250 | 251 | @pytest.mark.flow 252 | @pytest.mark.stan 253 | def test_stan_flow(): 254 | model = """ 255 | parameters { 256 | array[5] real a; 257 | real b; 258 | } 259 | model { 260 | a ~ normal(0, 1); 261 | b ~ normal(0, 1); 262 | } 263 | """ 264 | import jax 265 | 266 | old = jax.config.update("jax_enable_x64", True) 267 | try: 268 | compiled_model = nutpie.compile_stan_model(code=model).with_transform_adapt( 269 | num_layers=2, 270 | nn_width=4, 271 | ) 272 | trace = nutpie.sample(compiled_model, transform_adapt=True, tune=2000, chains=1) 273 | assert float(trace.sample_stats.fisher_distance.mean()) < 0.1 274 | trace.posterior.a # noqa: B018 275 | finally: 276 | jax.config.update("jax_enable_x64", old) 277 | 278 | 279 | # TODO: There are small numerical differences between linux and windows. 280 | # We should figure out if they originate in stan or in nutpie. 
281 | @pytest.mark.array_compare(atol=1e-4) 282 | @pytest.mark.stan 283 | def test_deterministic_sampling_stan(): 284 | model = """ 285 | parameters { 286 | real a; 287 | } 288 | model { 289 | a ~ normal(0, 1); 290 | } 291 | generated quantities { 292 | real b = normal_rng(0, 1) + a; 293 | } 294 | """ 295 | 296 | compiled_model = nutpie.compile_stan_model(code=model) 297 | trace = nutpie.sample(compiled_model, chains=2, seed=123, draws=100, tune=100) 298 | trace2 = nutpie.sample(compiled_model, chains=2, seed=123, draws=100, tune=100) 299 | np.testing.assert_allclose(trace.posterior.a.values, trace2.posterior.a.values) 300 | np.testing.assert_allclose(trace.posterior.b.values, trace2.posterior.b.values) 301 | return trace.posterior.a.isel(draw=slice(None, 10)).values 302 | --------------------------------------------------------------------------------