├── .dockerignore ├── .git-blame-ignore-revs ├── .gitattributes ├── .github ├── ISSUE_TEMPLATE │ ├── bug-report.yml │ ├── config.yml │ ├── developer.yml │ ├── documentation.yml │ └── feature-request.yml ├── PULL_REQUEST_TEMPLATE.md ├── config.yml ├── dependabot.yml ├── release.yml └── workflows │ ├── devcontainer-docker-image.yml │ ├── docker-image.yml │ ├── mypy.yml │ ├── pr-auto-label.yml │ ├── release.yml │ ├── rtd-link-preview.yml │ ├── slash_dispatch.yml │ ├── tests.yml │ └── zizmor.yml ├── .gitignore ├── .gitpod.yml ├── .pre-commit-config.yaml ├── .readthedocs.yml ├── ARCHITECTURE.md ├── CITATION.cff ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── GOVERNANCE.md ├── LICENSE ├── MANIFEST.in ├── Makefile ├── README.rst ├── benchmarks ├── asv.conf.json └── benchmarks │ ├── __init__.py │ └── benchmarks.py ├── binder ├── apt.txt ├── requirements.txt └── trigger_binder.sh ├── codecov.yml ├── conda-envs ├── environment-alternative-backends.yml ├── environment-dev.yml ├── environment-docs.yml ├── environment-test.yml ├── windows-environment-dev.yml └── windows-environment-test.yml ├── docs ├── Architecture.png ├── community_diagram.png ├── logos │ ├── PyMC.ai │ ├── PyMC.eps │ ├── PyMC.ico │ ├── PyMC.jpg │ ├── PyMC.pdf │ ├── PyMC.png │ ├── sponsors │ │ ├── numfocus.png │ │ ├── odsc.png │ │ └── pymc-labs.png │ └── svg │ │ ├── PyMC_banner.svg │ │ ├── PyMC_circle.svg │ │ └── PyMC_square.svg └── source │ ├── 404.md │ ├── _templates │ ├── autosummary │ │ └── class.rst │ └── distribution.rst │ ├── api.rst │ ├── api │ ├── backends.rst │ ├── data.rst │ ├── distributions.rst │ ├── distributions │ │ ├── censored.rst │ │ ├── continuous.rst │ │ ├── custom.rst │ │ ├── discrete.rst │ │ ├── mixture.rst │ │ ├── multivariate.rst │ │ ├── simulator.rst │ │ ├── timeseries.rst │ │ ├── transforms.rst │ │ ├── truncated.rst │ │ └── utilities.rst │ ├── gp.rst │ ├── gp │ │ ├── cov.rst │ │ ├── implementations.rst │ │ ├── mean.rst │ │ └── util.rst │ ├── logprob.rst │ ├── math.rst │ ├── misc.rst │ 
├── model.rst │ ├── model │ │ ├── conditioning.rst │ │ ├── core.rst │ │ ├── fgraph.rst │ │ └── optimization.rst │ ├── ode.rst │ ├── pytensorf.rst │ ├── samplers.rst │ ├── shape_utils.rst │ ├── smc.rst │ ├── testing.rst │ ├── tuning.rst │ └── vi.rst │ ├── conf.py │ ├── contributing │ ├── build_docs.md │ ├── developer_guide.md │ ├── docker_container.md │ ├── gitpod │ │ ├── gitpod_integration.png │ │ └── gitpod_workspace.png │ ├── implementing_distribution.md │ ├── index.md │ ├── jupyter_style.md │ ├── pr_checklist.md │ ├── pr_tutorial.md │ ├── python_style.md │ ├── release_checklist.md │ ├── review_pr_pymc_examples.md │ ├── running_the_test_suite.md │ ├── using_gitpod.md │ └── versioning_schemes_explanation.md │ ├── glossary.md │ ├── guides │ ├── Gaussian_Processes.rst │ └── Probability_Distributions.rst │ ├── images │ ├── forestplot.png │ └── model_to_graphviz.png │ ├── index.md │ ├── installation.md │ ├── learn.md │ └── learn │ ├── books.md │ ├── consulting.md │ ├── core_notebooks │ ├── GLM_linear.ipynb │ ├── Gaussian_Processes.rst │ ├── dimensionality.ipynb │ ├── index.md │ ├── model_comparison.ipynb │ ├── posterior_predictive.ipynb │ ├── pymc_overview.ipynb │ └── pymc_pytensor.ipynb │ ├── usage_overview.rst │ └── videos_and_podcasts.md ├── pymc ├── __init__.py ├── _version.py ├── backends │ ├── __init__.py │ ├── arviz.py │ ├── base.py │ ├── mcbackend.py │ ├── ndarray.py │ ├── report.py │ └── zarr.py ├── blocking.py ├── data.py ├── distributions │ ├── __init__.py │ ├── censored.py │ ├── continuous.py │ ├── custom.py │ ├── discrete.py │ ├── dist_math.py │ ├── distribution.py │ ├── mixture.py │ ├── moments │ │ ├── __init__.py │ │ └── means.py │ ├── multivariate.py │ ├── shape_utils.py │ ├── simulator.py │ ├── timeseries.py │ ├── transforms.py │ └── truncated.py ├── exceptions.py ├── func_utils.py ├── gp │ ├── __init__.py │ ├── cov.py │ ├── gp.py │ ├── hsgp_approx.py │ ├── mean.py │ └── util.py ├── initial_point.py ├── logprob │ ├── LICENSE_AEPPL.txt │ ├── 
__init__.py │ ├── abstract.py │ ├── basic.py │ ├── binary.py │ ├── censoring.py │ ├── checks.py │ ├── cumsum.py │ ├── linalg.py │ ├── mixture.py │ ├── order.py │ ├── rewriting.py │ ├── scan.py │ ├── tensor.py │ ├── transform_value.py │ ├── transforms.py │ └── utils.py ├── math.py ├── model │ ├── __init__.py │ ├── core.py │ ├── fgraph.py │ └── transform │ │ ├── __init__.py │ │ ├── basic.py │ │ ├── conditioning.py │ │ └── optimization.py ├── model_graph.py ├── ode │ ├── __init__.py │ ├── ode.py │ └── utils.py ├── plots │ └── __init__.py ├── printing.py ├── pytensorf.py ├── sampling │ ├── __init__.py │ ├── deterministic.py │ ├── forward.py │ ├── jax.py │ ├── mcmc.py │ ├── parallel.py │ └── population.py ├── smc │ ├── __init__.py │ ├── kernels.py │ └── sampling.py ├── stats │ ├── __init__.py │ ├── convergence.py │ └── log_density.py ├── step_methods │ ├── __init__.py │ ├── arraystep.py │ ├── compound.py │ ├── hmc │ │ ├── __init__.py │ │ ├── base_hmc.py │ │ ├── hmc.py │ │ ├── integration.py │ │ ├── nuts.py │ │ └── quadpotential.py │ ├── metropolis.py │ ├── slicer.py │ ├── state.py │ └── step_sizes.py ├── testing.py ├── tuning │ ├── __init__.py │ ├── scaling.py │ └── starting.py ├── util.py ├── variational │ ├── __init__.py │ ├── approximations.py │ ├── callbacks.py │ ├── inference.py │ ├── minibatch_rv.py │ ├── operators.py │ ├── opvi.py │ ├── stein.py │ ├── test_functions.py │ └── updates.py └── vartypes.py ├── pyproject.toml ├── requirements-dev.txt ├── requirements.txt ├── scripts ├── Dockerfile ├── check_all_tests_are_covered.py ├── dev.Dockerfile ├── docker_container.sh ├── generate_pip_deps_from_conda.py ├── run_mypy.py ├── slowest_tests │ ├── extract-slow-tests.py │ └── update-slowest-times-issue.sh └── test.sh ├── setup.py └── tests ├── __init__.py ├── backends ├── __init__.py ├── fixtures.py ├── test_arviz.py ├── test_base.py ├── test_mcbackend.py ├── test_ndarray.py └── test_zarr.py ├── conftest.py ├── distributions ├── __init__.py ├── moments │ ├── 
__init__.py │ └── test_means.py ├── test_censored.py ├── test_continuous.py ├── test_custom.py ├── test_discrete.py ├── test_dist_math.py ├── test_distribution.py ├── test_mixture.py ├── test_multivariate.py ├── test_random_alternative_backends.py ├── test_shape_utils.py ├── test_simulator.py ├── test_timeseries.py ├── test_transform.py └── test_truncated.py ├── gp ├── __init__.py ├── test_cov.py ├── test_gp.py ├── test_hsgp_approx.py ├── test_mean.py └── test_util.py ├── helpers.py ├── logprob ├── __init__.py ├── test_abstract.py ├── test_basic.py ├── test_binary.py ├── test_censoring.py ├── test_checks.py ├── test_composite_logprob.py ├── test_cumsum.py ├── test_linalg.py ├── test_mixture.py ├── test_order.py ├── test_rewriting.py ├── test_scan.py ├── test_tensor.py ├── test_transform_value.py ├── test_transforms.py ├── test_utils.py └── utils.py ├── model ├── __init__.py ├── test_core.py ├── test_fgraph.py └── transform │ ├── __init__.py │ ├── test_basic.py │ ├── test_conditioning.py │ └── test_optimization.py ├── models.py ├── ode ├── __init__.py ├── test_ode.py └── test_utils.py ├── sampler_fixtures.py ├── sampling ├── __init__.py ├── test_deterministic.py ├── test_forward.py ├── test_jax.py ├── test_mcmc.py ├── test_mcmc_external.py ├── test_parallel.py └── test_population.py ├── smc ├── __init__.py └── test_smc.py ├── stats ├── __init__.py ├── test_convergence.py └── test_log_density.py ├── step_methods ├── __init__.py ├── hmc │ ├── __init__.py │ ├── test_hmc.py │ ├── test_nuts.py │ └── test_quadpotential.py ├── test_compound.py ├── test_metropolis.py ├── test_slicer.py └── test_state.py ├── test_data.py ├── test_func_utils.py ├── test_initial_point.py ├── test_math.py ├── test_model_graph.py ├── test_printing.py ├── test_pytensorf.py ├── test_testing.py ├── test_util.py ├── tuning ├── __init__.py ├── test_scaling.py └── test_starting.py └── variational ├── __init__.py ├── test_approximations.py ├── test_callbacks.py ├── test_inference.py ├── 
test_minibatch_rv.py ├── test_opvi.py └── test_updates.py /.dockerignore: -------------------------------------------------------------------------------- 1 | .git 2 | .mypy_cache 3 | -------------------------------------------------------------------------------- /.git-blame-ignore-revs: -------------------------------------------------------------------------------- 1 | # Updated Apache v2 license header for year 2023 on entire codebase 2 | 3ea470ac964f4bd5c7207ce08a583a2e6aa7ae8a 3 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | pymc/_version.py export-subst 2 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug-report.yml: -------------------------------------------------------------------------------- 1 | # This issue template was adapted from the NumPy project 2 | # under the BSD 3-Clause "New" or "Revised" License. 3 | # Copyright (c) 2005-2022, NumPy Developers. 4 | # All rights reserved. 5 | 6 | name: Bug report 7 | description: Report a bug. For security vulnerabilities, see Report a security vulnerability in the templates. 8 | title: "BUG: " 9 | labels: [bug] 10 | 11 | body: 12 | - type: markdown 13 | attributes: 14 | value: > 15 | Thank you for taking the time to file a bug report. Before creating a new 16 | issue, please make sure to take a few minutes to check the issue tracker 17 | for existing issues about the bug. 18 | 19 | - type: textarea 20 | attributes: 21 | label: "Describe the issue:" 22 | validations: 23 | required: true 24 | 25 | - type: textarea 26 | attributes: 27 | label: "Reproducible code example:" 28 | description: > 29 | A short code example that reproduces the problem/missing feature. It 30 | should be self-contained, i.e., can be copy-pasted into the Python 31 | interpreter or run as-is via `python myproblem.py`.
32 | placeholder: | 33 | import pymc as pm 34 | << your code here >> 35 | render: python 36 | validations: 37 | required: true 38 | 39 | - type: textarea 40 | attributes: 41 | label: "Error message:" 42 | description: > 43 | Please include full error message, if any. 44 | placeholder: | 45 |
46 | Full traceback starting from `Traceback: ...` 47 |
48 | render: shell 49 | 50 | - type: textarea 51 | attributes: 52 | label: "PyMC version information:" 53 | description: > 54 | PyMC/PyMC3 Version: 55 | PyTensor/Aesara Version: 56 | Python Version: 57 | Operating system: 58 | How did you install PyMC/PyMC3: (conda/pip) 59 | placeholder: | 60 |
61 | configuration information 62 |
63 | validations: 64 | required: true 65 | 66 | - type: textarea 67 | attributes: 68 | label: "Context for the issue:" 69 | description: | 70 | Please explain how this issue affects your work or why it should be prioritized. 71 | placeholder: | 72 | << your explanation here >> 73 | validations: 74 | required: false 75 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: false 2 | contact_links: 3 | - name: PyMC Discourse 4 | url: https://discourse.pymc.io/ 5 | about: Ask installation and usage questions about PyMC/PyMC3 6 | - name: Example notebook error report 7 | url: https://github.com/pymc-devs/pymc-examples/issues 8 | about: Please report errors or desired extensions to the tutorials and examples here. 9 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/developer.yml: -------------------------------------------------------------------------------- 1 | name: Developer issue 2 | description: This template is for developers only! 3 | 4 | body: 5 | - type: textarea 6 | attributes: 7 | label: Description 8 | validations: 9 | required: true 10 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/documentation.yml: -------------------------------------------------------------------------------- 1 | # This issue template was adapted from the NumPy project 2 | # under the BSD 3-Clause "New" or "Revised" License. 3 | # Copyright (c) 2005-2022, NumPy Developers. 4 | # All rights reserved. 5 | 6 | 7 | name: Documentation 8 | description: Report an issue related to the PyMC documentation. 
9 | title: "DOC: " 10 | labels: [docs] 11 | 12 | body: 13 | - type: textarea 14 | attributes: 15 | label: "Issue with current documentation:" 16 | description: > 17 | Please make sure to leave a reference to the document/code you're 18 | referring to. You can also check the development version of the 19 | documentation and see if this issue has already been addressed at 20 | https://www.pymc.io/projects/docs/en/latest/api.html. 21 | 22 | - type: textarea 23 | attributes: 24 | label: "Idea or request for content:" 25 | description: > 26 | Please describe as clearly as possible what topics you think are missing 27 | from the current documentation. 28 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature-request.yml: -------------------------------------------------------------------------------- 1 | # This issue template was adapted from the NumPy project 2 | # under the BSD 3-Clause "New" or "Revised" License. 3 | # Copyright (c) 2005-2022, NumPy Developers. 4 | # All rights reserved. 5 | 6 | name: Feature request 7 | description: Make a specific, well-motivated proposal for a feature. 8 | title: "ENH: " 9 | labels: [feature request] 10 | 11 | 12 | body: 13 | - type: markdown 14 | attributes: 15 | value: > 16 | If you're looking to request a new feature or change in functionality, 17 | including adding or changing the meaning of arguments to an existing 18 | function, please post your idea first as a [Discussion](https://github.com/pymc-devs/pymc/discussions) 19 | to validate it and bring attention to it. After validation, 20 | you can open this issue for a more technical developer discussion. 21 | Check the [Contributor Guide](https://github.com/pymc-devs/pymc/blob/main/CONTRIBUTING.md) 22 | if you need more information. 23 | 24 | - type: textarea 25 | attributes: 26 | label: "Before" 27 | description: > 28 | Please fill in the code snippet: how did you work around your problem or frequent use case?
29 | Leave empty if you found no workaround. 30 | render: python 31 | validations: 32 | required: false 33 | 34 | - type: textarea 35 | attributes: 36 | label: "After" 37 | description: > 38 | Describe how you see it implemented with a high-level API, without going into details. 39 | render: python 40 | validations: 41 | required: false 42 | 43 | - type: textarea 44 | attributes: 45 | label: "Context for the issue:" 46 | description: | 47 | Please explain how this issue affects your work, why it should be prioritized 48 | or add any information that did not fit the Before/After template. 49 | placeholder: | 50 | << your explanation here >> 51 | validations: 52 | required: false 53 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | ## Description 7 | 8 | 9 | ## Related Issue 10 | 11 | 12 | - [ ] Closes # 13 | - [ ] Related to # 14 | 15 | ## Checklist 16 | 17 | 18 | - [ ] Checked that [the pre-commit linting/style checks pass](https://docs.pymc.io/en/latest/contributing/python_style.html) 19 | - [ ] Included tests that prove the fix is effective or that the new feature works 20 | - [ ] Added necessary documentation (docstrings and/or example notebooks) 21 | - [ ] If you are a pro: each commit corresponds to a [relevant logical change](https://wiki.openstack.org/wiki/GitCommitMessages#Structural_split_of_changes) 22 | 23 | 24 | ## Type of change 25 | 26 | - [ ] New feature / enhancement 27 | - [ ] Bug fix 28 | - [ ] Documentation 29 | - [ ] Maintenance 30 | - [ ] Other (please specify): 31 | 32 | -------------------------------------------------------------------------------- /.github/config.yml: -------------------------------------------------------------------------------- 1 | # Comment to be posted on first-time issues 2 | newIssueWelcomeComment: > 3 | ![Welcome
Banner](https://raw.githubusercontent.com/pymc-devs/brand/main/welcome-bot/BannerWelcome.jpg) 4 | 5 | :tada: Welcome to _PyMC_! :tada: 6 | We're really excited to have your input into the project! :sparkling_heart: 7 | 8 |
If you haven't done so already, please make sure you check out our [Contributing Guidelines](https://www.pymc.io/projects/docs/en/latest/contributing/index.html) and [Code of Conduct](https://github.com/pymc-devs/pymc/blob/main/CODE_OF_CONDUCT.md). 9 | 10 | 11 | # Comment to be posted on PRs from first-time contributors to this repository 12 | newPRWelcomeComment: > 13 | ![Thank You Banner](https://raw.githubusercontent.com/pymc-devs/brand/main/welcome-bot/BannerThanks.jpg) 14 | 15 | :sparkling_heart: Thanks for opening this pull request! :sparkling_heart: 16 | The _PyMC_ community really appreciates your time and effort to contribute to the project. 17 | Please make sure you have read our [Contributing Guidelines](https://www.pymc.io/projects/docs/en/latest/contributing/index.html) and filled in our pull request template to the best of your ability. 18 | 19 | 20 | # Comment to be posted on pull requests merged by a first-time contributor 21 | firstPRMergeComment: > 22 | ![Congratulations Banner](https://raw.githubusercontent.com/pymc-devs/brand/main/welcome-bot/BannerCongratulations.jpg) 23 | 24 | Congrats on merging your first pull request! :tada: 25 | We here at _PyMC_ are proud of you! :sparkling_heart: 26 | Thank you so much for your contribution :gift: 27 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | # Maintain dependencies for GitHub Actions 4 | - package-ecosystem: "github-actions" 5 | directory: "/" 6 | schedule: 7 | interval: "weekly" 8 | labels: 9 | - "Github CI/CD" 10 | - "no releasenotes" 11 | -------------------------------------------------------------------------------- /.github/release.yml: -------------------------------------------------------------------------------- 1 | # This file contains configuration for the automatic generation of release notes in GitHub.
2 | # It's not perfect, but it makes it a little less laborious to write informative release notes. 3 | # Also see https://docs.github.com/en/repositories/releasing-projects-on-github/automatically-generated-release-notes 4 | changelog: 5 | exclude: 6 | labels: 7 | - no releasenotes 8 | categories: 9 | - title: Major Changes 🛠 10 | labels: 11 | - major 12 | - title: New Features 🎉 13 | labels: 14 | - enhancements 15 | - feature request 16 | - title: Bugfixes 🪲 17 | labels: 18 | - bug 19 | - title: Documentation 📖 20 | labels: 21 | - docs 22 | - title: Maintenance 🔧 23 | labels: 24 | - "*" 25 | -------------------------------------------------------------------------------- /.github/workflows/devcontainer-docker-image.yml: -------------------------------------------------------------------------------- 1 | name: devcontainer-docker-image 2 | 3 | on: 4 | workflow_dispatch: 5 | schedule: 6 | - cron: "48 19 * * 5" # Fridays at 19:48 UTC 7 | release: 8 | types: [published] 9 | 10 | env: 11 | REGISTRY: ghcr.io 12 | IMAGE_NAME: ${{ github.repository }}-devcontainer # pymc-devs/pymc-devcontainer 13 | 14 | jobs: 15 | build-container: 16 | runs-on: ubuntu-latest 17 | 18 | # Set permissions for GitHub token 19 | # 20 | permissions: 21 | contents: read 22 | packages: write 23 | 24 | steps: 25 | - name: Checkout source 26 | uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 27 | with: 28 | persist-credentials: false 29 | 30 | - name: Setup Docker buildx 31 | uses: docker/setup-buildx-action@v3.10.0 32 | 33 | - name: Prepare metadata 34 | id: meta 35 | uses: docker/metadata-action@902fa8ec7d6ecbf8d84d538b9b233a880e428804 36 | with: 37 | images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} 38 | tags: | 39 | type=sha,enable=true,prefix=git- 40 | type=raw,value=latest 41 | 42 | - name: Log into registry ${{ env.REGISTRY }} 43 | uses: docker/login-action@74a5d142397b4f367a81961eba4e8cd7edddf772 44 | with: 45 | registry: ${{ env.REGISTRY }} 46 | username: ${{ github.actor 
}} 47 | password: ${{ secrets.GITHUB_TOKEN }} 48 | 49 | - name: Build and push Docker image 50 | id: docker_build 51 | uses: docker/build-push-action@14487ce63c7a62a4a324b0bfb37086795e31c6c1 52 | with: 53 | context: . 54 | file: scripts/dev.Dockerfile 55 | platforms: linux/amd64 # ,linux/arm64 56 | push: true 57 | tags: ${{ steps.meta.outputs.tags }} 58 | labels: ${{ steps.meta.outputs.labels }} 59 | cache-from: type=registry,ref=${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:latest 60 | cache-to: type=inline 61 | -------------------------------------------------------------------------------- /.github/workflows/docker-image.yml: -------------------------------------------------------------------------------- 1 | name: docker-image 2 | 3 | on: 4 | release: 5 | types: 6 | - created 7 | 8 | env: 9 | CONTAINER_NAME: build-test 10 | 11 | jobs: 12 | build-container: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - name: Checkout code 16 | uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 17 | with: 18 | persist-credentials: false 19 | 20 | - name: Login to Docker Hub 21 | uses: docker/login-action@74a5d142397b4f367a81961eba4e8cd7edddf772 22 | with: 23 | username: ${{ secrets.DOCKERHUB_USERNAME }} 24 | password: ${{ secrets.DOCKERHUB_TOKEN }} 25 | 26 | - name: Extract metadata (tags, labels) for Docker 27 | id: meta 28 | uses: docker/metadata-action@902fa8ec7d6ecbf8d84d538b9b233a880e428804 29 | with: 30 | images: | 31 | name=pymc/pymc,enable=true 32 | tags: | 33 | type=ref,event=branch 34 | type=ref,event=pr 35 | type=semver,pattern={{version}} 36 | type=semver,pattern={{major}}.{{minor}} 37 | 38 | - name: Build and load image 39 | uses: docker/build-push-action@14487ce63c7a62a4a324b0bfb37086795e31c6c1 40 | with: 41 | context: . 
42 | file: scripts/Dockerfile 43 | load: true 44 | tags: ${{ env.CONTAINER_NAME }} 45 | 46 | - name: Test importing pymc 47 | run: | 48 | docker run --rm ${{ env.CONTAINER_NAME }} conda run -n pymc-dev python -c 'import pymc;print(pymc.__version__)' 49 | 50 | - name: Build and push 51 | uses: docker/build-push-action@14487ce63c7a62a4a324b0bfb37086795e31c6c1 52 | with: 53 | context: . 54 | push: true 55 | file: scripts/Dockerfile 56 | tags: ${{ steps.meta.outputs.tags }} 57 | labels: ${{ steps.meta.outputs.labels }} 58 | -------------------------------------------------------------------------------- /.github/workflows/mypy.yml: -------------------------------------------------------------------------------- 1 | name: mypy 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: [main] 7 | 8 | defaults: 9 | run: 10 | shell: bash -leo pipefail {0} 11 | jobs: 12 | mypy: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 16 | with: 17 | persist-credentials: false 18 | - uses: mamba-org/setup-micromamba@v2 19 | with: 20 | environment-file: conda-envs/environment-test.yml 21 | create-args: >- 22 | python=3.10 23 | environment-name: pymc-test 24 | init-shell: bash 25 | cache-environment: true 26 | - name: Install-pymc and mypy dependencies 27 | run: | 28 | pip install -e . 29 | python --version 30 | - name: Run mypy 31 | run: | 32 | python ./scripts/run_mypy.py --verbose 33 | -------------------------------------------------------------------------------- /.github/workflows/pr-auto-label.yml: -------------------------------------------------------------------------------- 1 | name: "Pull Request Labeler" 2 | on: 3 | # The labeler doesn't execute any contributed code, so it should be fairly safe. 
4 | - pull_request_target # zizmor: ignore[dangerous-triggers] 5 | 6 | jobs: 7 | sync: 8 | permissions: 9 | pull-requests: write 10 | runs-on: ubuntu-latest 11 | steps: 12 | - name: Sync labels with closing issues 13 | uses: williambdean/closing-labels@v0.0.4 14 | with: 15 | exclude: "help wanted,needs info,beginner friendly" 16 | env: 17 | GH_TOKEN: ${{ github.token }} 18 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: release-pipeline 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | release: 8 | types: 9 | - published 10 | 11 | jobs: 12 | build-package: 13 | runs-on: ubuntu-latest 14 | permissions: 15 | # write attestations and id-token are necessary for attest-build-provenance-github 16 | attestations: write 17 | id-token: write 18 | steps: 19 | - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 20 | with: 21 | fetch-depth: 0 22 | persist-credentials: false 23 | - uses: hynek/build-and-inspect-python-package@b5076c307dc91924a82ad150cdd1533b444d3310 # v2.12.0 24 | with: 25 | # Prove that the packages were built in the context of this workflow. 26 | attest-build-provenance-github: true 27 | 28 | publish-package: 29 | # Don't publish from forks 30 | if: github.repository_owner == 'pymc-devs' && github.event_name == 'release' && github.event.action == 'published' 31 | # Use the `release` GitHub environment to protect the Trusted Publishing (OIDC) 32 | # workflow by requiring signoff from a maintainer. 
33 | environment: release 34 | needs: build-package 35 | runs-on: ubuntu-latest 36 | permissions: 37 | # write id-token is necessary for trusted publishing (OIDC) 38 | id-token: write 39 | steps: 40 | - name: Download Distribution Artifacts 41 | uses: actions/download-artifact@d3f86a106a0bac45b974a628896c90dbdf5c8093 # v4.3.0 42 | with: 43 | # The build-and-inspect-python-package action invokes upload-artifact. 44 | # These are the correct arguments from that action. 45 | name: Packages 46 | path: dist 47 | - name: Publish Package to PyPI 48 | uses: pypa/gh-action-pypi-publish@76f52bc884231f62b9a034ebfe128415bbaabdfc # v1.12.4 49 | # Implicitly attests that the packages were uploaded in the context of this workflow. 50 | -------------------------------------------------------------------------------- /.github/workflows/rtd-link-preview.yml: -------------------------------------------------------------------------------- 1 | name: Read the Docs Pull Request Preview 2 | on: 3 | # See 4 | pull_request_target: # zizmor: ignore[dangerous-triggers] 5 | types: 6 | - opened 7 | 8 | jobs: 9 | documentation-links: 10 | runs-on: ubuntu-latest 11 | permissions: 12 | pull-requests: write 13 | steps: 14 | - uses: readthedocs/actions/preview@v1 15 | with: 16 | project-slug: "pymc" 17 | -------------------------------------------------------------------------------- /.github/workflows/slash_dispatch.yml: -------------------------------------------------------------------------------- 1 | name: Slash Command Dispatch 2 | on: 3 | issue_comment: 4 | types: [created] 5 | jobs: 6 | slashCommandDispatch: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - name: Slash Command Dispatch 10 | uses: peter-evans/slash-command-dispatch@v4 11 | with: 12 | token: ${{ secrets.ACTION_TRIGGER_TOKEN }} 13 | issue-type: pull-request 14 | commands: | 15 | pre-commit-run 16 | -------------------------------------------------------------------------------- /.github/workflows/zizmor.yml: 
-------------------------------------------------------------------------------- 1 | # https://github.com/woodruffw/zizmor 2 | name: zizmor GHA analysis 3 | 4 | on: 5 | push: 6 | branches: ["main"] 7 | pull_request: 8 | branches: ["**"] 9 | 10 | jobs: 11 | zizmor: 12 | name: zizmor latest via PyPI 13 | runs-on: ubuntu-latest 14 | permissions: 15 | security-events: write 16 | steps: 17 | - name: Checkout repository 18 | uses: actions/checkout@v4 19 | with: 20 | persist-credentials: false 21 | 22 | - uses: hynek/setup-cached-uv@v2 23 | 24 | - name: Run zizmor 🌈 25 | run: uvx zizmor --format sarif . > results.sarif 26 | env: 27 | GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} 28 | 29 | - name: Upload SARIF file 30 | uses: github/codeql-action/upload-sarif@v3 31 | with: 32 | # Path to SARIF file relative to the root of the repository 33 | sarif_file: results.sarif 34 | # Optional category for the results 35 | # Used to differentiate multiple results for one commit 36 | category: zizmor 37 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.sw[op] 3 | examples/*.png 4 | nb_examples/ 5 | nb_tutorials/ 6 | build/* 7 | dist/* 8 | *.egg-info/ 9 | .ipynb_checkpoints 10 | tmtags 11 | tags 12 | .DS_Store 13 | .cache 14 | # IntelliJ IDE 15 | .idea 16 | *.iml 17 | 18 | # Sphinx 19 | _build 20 | docs/_build 21 | docs/build 22 | docs/jupyter_execute 23 | docs/.jupyter_cache 24 | docs/**/generated/* 25 | 26 | # Merge tool 27 | *.orig 28 | 29 | # Docker development 30 | # notebooks/ 31 | 32 | # air speed velocity (asv) 33 | benchmarks/env/ 34 | benchmarks/html/ 35 | benchmarks/results/ 36 | .pytest_cache/ 37 | 38 | # Visual Studio / VSCode 39 | .vs/ 40 | .vscode/ 41 | .mypy_cache 42 | 43 | pytestdebug.log 44 | .dir-locals.el 45 | .pycheckers 46 | 47 | # Codespaces 48 | pythonenv* 49 | env/ 50 | venv/ 51 | .venv/ 52 | pixi.toml 53 | pixi.lock 54 | .pixi/ 
55 | 56 | .jupyter/ 57 | .claude/ 58 | -------------------------------------------------------------------------------- /.gitpod.yml: -------------------------------------------------------------------------------- 1 | image: ghcr.io/pymc-devs/pymc-devcontainer:latest 2 | tasks: 3 | - name: initialize 4 | init: | 5 | # General devcontainer initialization, e.g. pre-commit 6 | _dev-init.sh 7 | 8 | # Create an empty object for .vscode/settings.json if the file doesn't exist 9 | mkdir -p .vscode 10 | [ -f .vscode/settings.json ] || echo "{}" > .vscode/settings.json 11 | 12 | # Add vscode settings 13 | jq ' 14 | .["python.defaultInterpreterPath"] = "/opt/conda/bin/python" 15 | ' .vscode/settings.json | sponge .vscode/settings.json 16 | jq ' 17 | .["terminal.integrated.defaultProfile.linux"] = "bash" 18 | ' .vscode/settings.json | sponge .vscode/settings.json 19 | jq ' 20 | .["git.autofetch"] = true 21 | ' .vscode/settings.json | sponge .vscode/settings.json 22 | 23 | # Install dependencies 24 | sudo chown "$(id -u):$(id -g)" /opt/conda/conda-meta/history 25 | (micromamba install --yes --name base --file conda-envs/environment-dev.yml; pip install -e .) 
&> /tmp/install-init.log & 26 | 27 | command: | 28 | # Reinitialize devcontainer for good measure 29 | _dev-init.sh 30 | 31 | # Install the pre-commit hooks in the background if not already installed 32 | pre-commit install --install-hooks &> /tmp/pre-commit-init-output.log & 33 | 34 | vscode: 35 | extensions: 36 | - eamodio.gitlens 37 | - ms-python.python 38 | - ms-pyright.pyright 39 | - ms-toolsai.jupyter 40 | - donjayamanne.githistory 41 | 42 | github: 43 | prebuilds: 44 | # enable for master branch 45 | master: true 46 | # enable for other branches (defaults to false) 47 | branches: true 48 | # enable for pull requests coming from this repo (defaults to true) 49 | pullRequests: true 50 | # enable for pull requests coming from forks (defaults to false) 51 | pullRequestsFromForks: false 52 | # add a check to pull requests (defaults to true) 53 | addCheck: true 54 | # add a "Review in Gitpod" button as a comment to pull requests (defaults to false) 55 | addComment: false 56 | # add a "Review in Gitpod" button to the pull request's description (defaults to false) 57 | addBadge: false 58 | # add a label once the prebuild is ready to pull requests (defaults to false) 59 | addLabel: false 60 | -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | sphinx: 4 | configuration: docs/source/conf.py 5 | 6 | python: 7 | install: 8 | - method: pip 9 | path: . 
10 | 11 | conda: 12 | environment: "conda-envs/environment-docs.yml" 13 | 14 | build: 15 | os: "ubuntu-22.04" 16 | tools: 17 | python: "mambaforge-4.10" 18 | 19 | search: 20 | ranking: 21 | _sources/*: -10 22 | _modules/*: -5 23 | genindex.html: -9 24 | '*__init__.html': -3 25 | '*dist.html': -3 26 | 27 | ignore: 28 | - 404.html 29 | - search.html 30 | - index.html 31 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: If you use this software, please cite it using the metadata from this file. 3 | title: PyMC 4 | authors: 5 | - name: PyMC-Devs 6 | repository-code: "https://github.com/pymc-devs/pymc" 7 | url: "https://www.pymc.io" 8 | abstract: Bayesian Modeling and Probabilistic Programming in Python 9 | license: Apache-2.0 10 | 11 | preferred-citation: 12 | type: article 13 | title: "PyMC: a modern, and comprehensive probabilistic programming framework in Python" 14 | journal: PeerJ Comput. Sci. 15 | database: peerj.com 16 | issn: 2376-5992 17 | languages: 18 | - en 19 | pages: e1516 20 | volume: 9 21 | url: "https://peerj.com/articles/cs-1516" 22 | date-published: 2023-09-01 23 | doi: 10.7717/peerj-cs.1516 24 | authors: 25 | - family-names: Abril-Pla 26 | given-names: Oriol 27 | - family-names: Andreani 28 | given-names: Virgile 29 | - family-names: Carroll 30 | given-names: Colin 31 | - family-names: Dong 32 | given-names: Larry 33 | - family-names: Fonnesbeck 34 | given-names: Christopher J. 35 | - family-names: Kochurov 36 | given-names: Maxim 37 | - family-names: Kumar 38 | given-names: Ravin 39 | - family-names: Lao 40 | given-names: Junpeng 41 | - family-names: Luhmann 42 | given-names: Christian C. 43 | - family-names: Martin 44 | given-names: Osvaldo A. 
45 | - family-names: Osthege 46 | given-names: Michael 47 | - family-names: Vieira 48 | given-names: Ricardo 49 | - family-names: Wiecki 50 | given-names: Thomas 51 | - family-names: Zinkov 52 | given-names: Robert 53 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # PyMC Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to making participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to 
take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies both within project spaces and in public spaces 49 | when an individual is representing the project or its community. Examples of 50 | representing a project or community include using an official project e-mail 51 | address, posting via an official social media account, or acting as an appointed 52 | representative at an online or offline event. Representation of a project may be 53 | further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting PyMC developer Christopher Fonnesbeck via email 59 | (fonnesbeck@gmail.com) or phone (615-955-0380). Alternatively, you 60 | may also contact NumFOCUS Executive Director Leah Silen (512-222-5449), as PyMC 61 | is a member of NumFOCUS and subscribes to their code of conduct as a 62 | precondition for continued membership. All complaints will be reviewed and 63 | investigated and will result in a response that is deemed necessary and 64 | appropriate to the circumstances. The project team is obligated to maintain 65 | confidentiality with regard to the reporter of an incident. Further details of 66 | specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 
71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Guidelines for Contributing 2 | 3 | Thank you for being interested in contributing to PyMC. PyMC is an open-source, collective effort, and everyone is welcome to contribute. There are many ways in which you can help make it better. Please check the latest information on contributing to the PyMC project in [these guidelines](https://docs.pymc.io/en/latest/contributing/index.html). 4 | 5 | Quick links 6 | ----------- 7 | 8 | * [Pull request (PR) step-by-step](https://docs.pymc.io/en/latest/contributing/pr_tutorial.html) 9 | * [Pull request (PR) checklist](https://docs.pymc.io/en/latest/contributing/pr_checklist.html) 10 | * [Python style guide with pre-commit](https://docs.pymc.io/en/latest/contributing/python_style.html) 11 | * [Running the test suite](https://docs.pymc.io/en/latest/contributing/running_the_test_suite.html) 12 | * [Running PyMC in Docker](https://docs.pymc.io/en/latest/contributing/docker_container.html) 13 | * [Submitting a bug report or feature request](https://github.com/pymc-devs/pymc/issues) 14 | 15 | For a complete list, visit [the contributing section of the documentation](https://docs.pymc.io/en/latest/contributing/index.html).
16 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include requirements.txt 2 | include *.md *.rst 3 | include scripts/*.sh 4 | include LICENSE 5 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # Makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXBUILD = sphinx-build 6 | SOURCEDIR = docs/source 7 | BUILDDIR = docs/build 8 | 9 | rtd: export READTHEDOCS=true 10 | 11 | # User-friendly check for sphinx-build 12 | ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) 13 | $(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/) 14 | endif 15 | 16 | .PHONY: help clean html rtd view 17 | 18 | help: 19 | @echo "Please use \`make ' where is one of" 20 | @echo " html to make standalone HTML files" 21 | @echo " rtd to build the website without any cache" 22 | @echo " clean to clean cache and intermediate files" 23 | @echo " view to open the built html files" 24 | 25 | clean: 26 | rm -rf $(BUILDDIR)/* 27 | rm -rf $(SOURCEDIR)/api/generated 28 | rm -rf $(SOURCEDIR)/api/**/generated 29 | rm -rf $(SOURCEDIR)/api/**/classmethods 30 | rm -rf docs/jupyter_execute 31 | 32 | html: 33 | $(SPHINXBUILD) $(SOURCEDIR) $(BUILDDIR) -b html 34 | @echo 35 | @echo "Build finished. The HTML pages are in $(BUILDDIR)." 36 | 37 | rtd: clean 38 | $(SPHINXBUILD) $(SOURCEDIR) $(BUILDDIR) -b html -E 39 | @echo 40 | @echo "Build finished. The HTML pages are in $(BUILDDIR)." 
41 | 42 | view: 43 | python -m webbrowser $(BUILDDIR)/index.html 44 | -------------------------------------------------------------------------------- /benchmarks/asv.conf.json: -------------------------------------------------------------------------------- 1 | { 2 | // The version of the config file format. Do not change, unless 3 | // you know what you are doing. 4 | "version": 1, 5 | 6 | // The name of the project being benchmarked 7 | "project": "pymc", 8 | 9 | // The project's homepage 10 | "project_url": "https://pymc-devs.github.io/pymc/", 11 | 12 | // The URL or local path of the source code repository for the 13 | // project being benchmarked 14 | "repo": "..", 15 | 16 | // List of branches to benchmark. If not provided, defaults to "main" 17 | // (for git) or "tip" (for mercurial). 18 | "branches": ["main"], 19 | 20 | // The DVCS being used. If not set, it will be automatically 21 | // determined from "repo" by looking at the protocol in the URL 22 | // (if remote), or by looking for special directories, such as 23 | // ".git" (if local). 24 | "dvcs": "git", 25 | 26 | // The tool to use to create environments. May be "conda", 27 | // "virtualenv" or other value depending on the plugins in use. 28 | // If missing or the empty string, the tool will be automatically 29 | // determined by looking for tools on the PATH environment 30 | // variable. 31 | "environment_type": "conda", 32 | 33 | // the base URL to show a commit for the project. 34 | "show_commit_url": "https://github.com/pymc-devs/pymc/commit/", 35 | 36 | // The Pythons you'd like to test against. If not provided, defaults 37 | // to the current version of Python used to run `asv`. 38 | "pythons": ["3.6"], 39 | 40 | // The matrix of dependencies to test. Each key is the name of a 41 | // package (in PyPI) and the values are version numbers. An empty 42 | // list indicates to just test against the default (latest) 43 | // version. 
44 | "matrix": {}, 45 | 46 | // The directory (relative to the current directory) that benchmarks are 47 | // stored in. If not provided, defaults to "benchmarks" 48 | "benchmark_dir": "benchmarks", 49 | 50 | // The directory (relative to the current directory) to cache the Python 51 | // environments in. If not provided, defaults to "env" 52 | "env_dir": "env", 53 | 54 | // The directory (relative to the current directory) that raw benchmark 55 | // results are stored in. If not provided, defaults to "results". 56 | "results_dir": "results", 57 | 58 | // The directory (relative to the current directory) that the html tree 59 | // should be written to. If not provided, defaults to "html". 60 | "html_dir": "html", 61 | 62 | // The number of characters to retain in the commit hashes. 63 | // "hash_length": 8, 64 | 65 | // `asv` will cache wheels of the recent builds in each 66 | // environment, making them faster to install next time. This is 67 | // number of builds to keep, per environment. 68 | "build_cache_size": 2, 69 | 70 | // The commits after which the regression search in `asv publish` 71 | // should start looking for regressions. Dictionary whose keys are 72 | // regexps matching to benchmark names, and values corresponding to 73 | // the commit (exclusive) after which to start looking for 74 | // regressions. The default is to start from the first commit 75 | // with results. If the commit is `null`, regression detection is 76 | // skipped for the matching benchmark. 
77 | // 78 | // "regressions_first_commits": { 79 | // "some_benchmark": "352cdf", // Consider regressions only after this commit 80 | // "another_benchmark": null, // Skip regression detection altogether 81 | // } 82 | } 83 | -------------------------------------------------------------------------------- /benchmarks/benchmarks/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 - present The PyMC Developers 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Benchmarks for PyMC.""" 16 | -------------------------------------------------------------------------------- /binder/apt.txt: -------------------------------------------------------------------------------- 1 | graphviz 2 | -------------------------------------------------------------------------------- /binder/requirements.txt: -------------------------------------------------------------------------------- 1 | -r ../requirements-dev.txt 2 | # this installs pymc itself. Note that "." is resolved relative to the current working directory, 3 | # whereas ../requirements-dev.txt above is resolved relative to this file. 4 | .
5 | -------------------------------------------------------------------------------- /binder/trigger_binder.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Usage: trigger_binder.sh BINDER_BUILD_URL -- request the URL so Binder starts a build 3 | function trigger_binder() { 4 | local URL="${1}" 5 | 6 | curl -L --connect-timeout 10 --max-time 30 "${URL}" 7 | local curl_return=$? 8 | 9 | # Exit code 28 is curl's timeout error: normally --max-time was reached, but an expired --connect-timeout also returns 28 10 | if [ "${curl_return}" -eq 0 ] || [ "${curl_return}" -eq 28 ]; then 11 | if [[ "${curl_return}" -eq 28 ]]; then 12 | printf "\nBinder build started.\nCheck back soon.\n" 13 | fi 14 | else 15 | return "${curl_return}" 16 | fi 17 | 18 | return 0 19 | } 20 | 21 | function main() { 22 | # 1: the Binder build API URL to curl (quoted so it is never word-split or glob-expanded) 23 | trigger_binder "${1}" 24 | } 25 | 26 | main "$@" || exit 1 27 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | codecov: 2 | require_ci_to_pass: yes 3 | notify: 4 | after_n_builds: 15 # This should be updated if number of test jobs changes 5 | 6 | coverage: 7 | precision: 2 8 | round: down 9 | range: "70...100" 10 | status: 11 | project: 12 | default: 13 | # basic 14 | target: auto 15 | threshold: 1% 16 | base: auto 17 | patch: 18 | default: 19 | # basic 20 | target: 50% 21 | threshold: 1% 22 | base: auto 23 | 24 | ignore: 25 | - "tests/*" 26 | - "pymc/_version.py" 27 | 28 | comment: 29 | layout: "reach, diff, flags, files" 30 | behavior: default 31 | require_changes: false # if true: only post the comment if coverage changes 32 | require_base: no # [yes :: must have a base report to post] 33 | require_head: yes # [yes :: must have a head report to post] 34 | branches: null # branch names that can post comment 35 | -------------------------------------------------------------------------------- /conda-envs/environment-alternative-backends.yml:
-------------------------------------------------------------------------------- 1 | # "test" conda envs are used to set up our CI environment in GitHub actions 2 | name: pymc-test 3 | channels: 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | # Base dependencies 8 | - arviz>=0.13.0 9 | - blas 10 | - cachetools>=4.2.1 11 | - cloudpickle 12 | - zarr>=2.5.0,<3 13 | - numba 14 | - nutpie >= 0.13.4 15 | # Jaxlib version must not be greater than jax version! 16 | - blackjax>=1.2.2 17 | - jax>=0.4.28 18 | - jaxlib>=0.4.28 19 | - libblas=*=*mkl 20 | - mkl-service 21 | - numpy>=1.25.0 22 | - numpyro>=0.8.0 23 | - pandas>=0.24.0 24 | - pip 25 | - pytensor>=2.31.2,<2.32 26 | - python-graphviz 27 | - networkx 28 | - rich>=13.7.1 29 | - threadpoolctl>=3.1.0 30 | # JAX is only compatible with Scipy 1.13.0 from >=0.4.26 31 | - scipy>=1.13.0 32 | - typing-extensions>=3.7.4 33 | # Extra dependencies for testing 34 | - ipython>=7.16 35 | - pre-commit>=2.8.0 36 | - pytest-cov>=2.5 37 | - pytest>=3.0 38 | - mypy=1.15.0 39 | - types-cachetools 40 | - pip: 41 | - numdifftools>=0.9.40 42 | - mcbackend>=0.4.0 43 | -------------------------------------------------------------------------------- /conda-envs/environment-dev.yml: -------------------------------------------------------------------------------- 1 | # "dev" conda envs are to be used by devs in setting their local environments 2 | name: pymc-dev 3 | channels: 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | # Base dependencies 8 | - arviz>=0.13.0 9 | - blas 10 | - cachetools>=4.2.1 11 | - cloudpickle 12 | - numpy>=1.25.0 13 | - pandas>=0.24.0 14 | - pip 15 | - pytensor>=2.31.2,<2.32 16 | - python-graphviz 17 | - networkx 18 | - scipy>=1.4.1 19 | - typing-extensions>=3.7.4 20 | - threadpoolctl>=3.1.0 21 | - zarr>=2.5.0,<3 22 | # Extra dependencies for dev, testing and docs build 23 | - ipython>=7.16 24 | - jax 25 | - jupyter-sphinx 26 | - myst-nb<=1.0.0 27 | - numpydoc 28 | - pre-commit>=2.8.0 29 | - polyagamma 30 | - 
pytest-cov>=2.5 31 | - pytest>=3.0 32 | - rich>=13.7.1 33 | - sphinx-copybutton 34 | - sphinx-design 35 | - sphinx-notfound-page 36 | - sphinx>=1.5 37 | - sphinxext-rediraffe 38 | - watermark 39 | - sphinx-remove-toctrees 40 | - mypy=1.15.0 41 | - types-cachetools 42 | - pip: 43 | - git+https://github.com/pymc-devs/pymc-sphinx-theme 44 | - numdifftools>=0.9.40 45 | - mcbackend>=0.4.0 46 | -------------------------------------------------------------------------------- /conda-envs/environment-docs.yml: -------------------------------------------------------------------------------- 1 | # "docs" conda envs are used to build the documentation 2 | name: pymc-docs 3 | channels: 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | # Base dependencies 8 | - arviz>=0.13.0 9 | - cachetools>=4.2.1 10 | - cloudpickle 11 | - numpy>=1.25.0 12 | - pandas>=0.24.0 13 | - pip 14 | - pytensor>=2.31.2,<2.32 15 | - python-graphviz 16 | - rich>=13.7.1 17 | - scipy>=1.4.1 18 | - typing-extensions>=3.7.4 19 | - threadpoolctl>=3.1.0 20 | - zarr>=2.5.0,<3 21 | # Extra dependencies for docs build 22 | - ipython>=7.16 23 | - jax 24 | - jupyter-sphinx 25 | - myst-nb<=1.0.0 26 | - numpydoc 27 | - polyagamma 28 | - pre-commit>=2.8.0 29 | - pymc-sphinx-theme>=0.16 30 | - sphinx-copybutton 31 | - sphinx-design 32 | - sphinx-notfound-page 33 | - sphinx-sitemap 34 | - sphinx>=5 35 | - sphinxext-rediraffe 36 | - watermark 37 | - sphinx-remove-toctrees 38 | - pip: 39 | - numdifftools>=0.9.40 40 | -------------------------------------------------------------------------------- /conda-envs/environment-test.yml: -------------------------------------------------------------------------------- 1 | # "test" conda envs are used to set up our CI environment in GitHub actions 2 | name: pymc-test 3 | channels: 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | # Base dependencies 8 | - arviz>=0.13.0 9 | - blas 10 | - cachetools>=4.2.1 11 | - cloudpickle 12 | - jax 13 | -
numpy>=1.25.0 14 | - pandas>=0.24.0 15 | - pip 16 | - polyagamma 17 | - pytensor>=2.31.2,<2.32 18 | - python-graphviz 19 | - networkx 20 | - rich>=13.7.1 21 | - scipy>=1.4.1 22 | - typing-extensions>=3.7.4 23 | - threadpoolctl>=3.1.0 24 | - zarr>=2.5.0,<3 25 | # Extra dependencies for testing 26 | - ipython>=7.16 27 | - pre-commit>=2.8.0 28 | - pytest-cov>=2.5 29 | - pytest>=3.0 30 | - mypy=1.15.0 31 | - types-cachetools 32 | - pip: 33 | - numdifftools>=0.9.40 34 | - mcbackend>=0.4.0 35 | -------------------------------------------------------------------------------- /conda-envs/windows-environment-dev.yml: -------------------------------------------------------------------------------- 1 | # "dev" conda envs are to be used by devs in setting their local environments 2 | name: pymc-dev 3 | channels: 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | # Base dependencies (see install guide for Windows) 8 | - arviz>=0.13.0 9 | - blas 10 | - cachetools>=4.2.1 11 | - cloudpickle 12 | - numpy>=1.25.0 13 | - pandas>=0.24.0 14 | - pip 15 | - pytensor>=2.31.2,<2.32 16 | - python-graphviz 17 | - networkx 18 | - rich>=13.7.1 19 | - scipy>=1.4.1 20 | - typing-extensions>=3.7.4 21 | - threadpoolctl>=3.1.0 22 | - zarr>=2.5.0,<3 23 | # Extra dependencies for dev, testing and docs build 24 | - ipython>=7.16 25 | - myst-nb<=1.0.0 26 | - numpydoc 27 | - polyagamma 28 | - pre-commit>=2.8.0 29 | - pytest-cov>=2.5 30 | - pytest>=3.0 31 | - sphinx-autobuild>=0.7 32 | - sphinx-copybutton 33 | - sphinx-design 34 | - sphinx-notfound-page 35 | - sphinx>=1.5 36 | - watermark 37 | - sphinx-remove-toctrees 38 | - mypy=1.15.0 39 | - types-cachetools 40 | - pip: 41 | - git+https://github.com/pymc-devs/pymc-sphinx-theme 42 | - numdifftools>=0.9.40 43 | - mcbackend>=0.4.0 44 | -------------------------------------------------------------------------------- /conda-envs/windows-environment-test.yml: -------------------------------------------------------------------------------- 1 | # "test" 
conda envs are used to set up our CI environment in GitHub actions 2 | name: pymc-test 3 | channels: 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | # Base dependencies (see install guide for Windows) 8 | - arviz>=0.13.0 9 | - blas 10 | - cachetools>=4.2.1 11 | - cloudpickle 12 | - libpython 13 | - mkl-service>=2.3.0 14 | - numpy>=1.25.0 15 | - pandas>=0.24.0 16 | - pip 17 | - polyagamma 18 | - pytensor>=2.31.2,<2.32 19 | - python-graphviz 20 | - networkx 21 | - rich>=13.7.1 22 | - scipy>=1.4.1 23 | - typing-extensions>=3.7.4 24 | - threadpoolctl>=3.1.0 25 | - zarr>=2.5.0,<3 26 | # Extra dependencies for testing 27 | - ipython>=7.16 28 | - pre-commit>=2.8.0 29 | - pytest-cov>=2.5 30 | - pytest>=3.0 31 | - mypy=1.15.0 32 | - types-cachetools 33 | - pip: 34 | - numdifftools>=0.9.40 35 | - mcbackend>=0.4.0 36 | -------------------------------------------------------------------------------- /docs/Architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pymc-devs/pymc/360cb6edde9ccba306c0e046d9576c936fa4e571/docs/Architecture.png -------------------------------------------------------------------------------- /docs/community_diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pymc-devs/pymc/360cb6edde9ccba306c0e046d9576c936fa4e571/docs/community_diagram.png -------------------------------------------------------------------------------- /docs/logos/PyMC.ai: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pymc-devs/pymc/360cb6edde9ccba306c0e046d9576c936fa4e571/docs/logos/PyMC.ai -------------------------------------------------------------------------------- /docs/logos/PyMC.ico: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/pymc-devs/pymc/360cb6edde9ccba306c0e046d9576c936fa4e571/docs/logos/PyMC.ico -------------------------------------------------------------------------------- /docs/logos/PyMC.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pymc-devs/pymc/360cb6edde9ccba306c0e046d9576c936fa4e571/docs/logos/PyMC.jpg -------------------------------------------------------------------------------- /docs/logos/PyMC.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pymc-devs/pymc/360cb6edde9ccba306c0e046d9576c936fa4e571/docs/logos/PyMC.pdf -------------------------------------------------------------------------------- /docs/logos/PyMC.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pymc-devs/pymc/360cb6edde9ccba306c0e046d9576c936fa4e571/docs/logos/PyMC.png -------------------------------------------------------------------------------- /docs/logos/sponsors/numfocus.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pymc-devs/pymc/360cb6edde9ccba306c0e046d9576c936fa4e571/docs/logos/sponsors/numfocus.png -------------------------------------------------------------------------------- /docs/logos/sponsors/odsc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pymc-devs/pymc/360cb6edde9ccba306c0e046d9576c936fa4e571/docs/logos/sponsors/odsc.png -------------------------------------------------------------------------------- /docs/logos/sponsors/pymc-labs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pymc-devs/pymc/360cb6edde9ccba306c0e046d9576c936fa4e571/docs/logos/sponsors/pymc-labs.png 
-------------------------------------------------------------------------------- /docs/source/404.md: -------------------------------------------------------------------------------- 1 | --- 2 | orphan: true 3 | --- 4 | 5 | # Page not found 6 | 7 | **Sorry, we could not find this page** 8 | 9 | Click on the navigation bar on top of the page to go to the right section 10 | of the default docs, or alternatively: 11 | 12 | * Go to the current [PyMC website homepage](https://www.pymc.io/) 13 | -------------------------------------------------------------------------------- /docs/source/_templates/autosummary/class.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline}} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. autoclass:: {{ objname }} 6 | 7 | {% block methods %} 8 | {% if methods %} 9 | 10 | .. rubric:: Methods 11 | 12 | .. autosummary:: 13 | :toctree: classmethods 14 | 15 | {% for item in methods %} 16 | {{ objname }}.{{ item }} 17 | {%- endfor %} 18 | {% endif %} 19 | {% endblock %} 20 | 21 | {% block attributes %} 22 | {% if attributes %} 23 | .. rubric:: Attributes 24 | 25 | .. autosummary:: 26 | {% for item in attributes %} 27 | ~{{ name }}.{{ item }} 28 | {%- endfor %} 29 | {% endif %} 30 | {% endblock %} 31 | -------------------------------------------------------------------------------- /docs/source/_templates/distribution.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline}} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | {% if objtype == "class" %} 6 | .. autoclass:: {{ objname }} 7 | 8 | .. rubric:: {{ _('Methods') }} 9 | 10 | .. autosummary:: 11 | :toctree: classmethods 12 | 13 | {{ objname }}.dist 14 | {% else %} 15 | .. 
autofunction:: {{ objname }} 16 | {% endif %} 17 | -------------------------------------------------------------------------------- /docs/source/api.rst: -------------------------------------------------------------------------------- 1 | .. _api: 2 | 3 | *** 4 | API 5 | *** 6 | 7 | .. toctree:: 8 | :maxdepth: 1 9 | 10 | api/distributions 11 | api/gp 12 | api/model 13 | api/samplers 14 | api/vi 15 | api/smc 16 | api/data 17 | api/ode 18 | api/logprob 19 | api/tuning 20 | api/math 21 | api/pytensorf 22 | api/shape_utils 23 | api/backends 24 | api/misc 25 | api/testing 26 | 27 | ------------------ 28 | Dimensionality 29 | ------------------ 30 | PyMC provides numerous methods, and syntactic sugar, to easily specify the dimensionality of 31 | Random Variables in modeling. Refer to :ref:`dimensionality` notebook to see examples 32 | demonstrating the functionality. 33 | 34 | -------------- 35 | API extensions 36 | -------------- 37 | 38 | Plots, stats and diagnostics 39 | ---------------------------- 40 | Plots, stats and diagnostics are delegated to the 41 | :doc:`ArviZ `. 42 | library, a general purpose library for 43 | "exploratory analysis of Bayesian models". 44 | 45 | * Functions from the ``arviz.plots`` module are available through ``pymc.`` or ``pymc.plots.``, 46 | but for their API documentation please refer to the :ref:`ArviZ documentation `. 47 | 48 | * Functions from the ``arviz.stats`` module are available through ``pymc.`` or ``pymc.stats.``, 49 | but for their API documentation please refer to the :ref:`ArviZ documentation `. 50 | 51 | ArviZ is a dependency of PyMC and so, in addition to the locations described above, 52 | importing ArviZ and using ``arviz.`` will also work without any extra installation. 53 | 54 | Generalized Linear Models (GLMs) 55 | -------------------------------- 56 | 57 | Generalized Linear Models are delegated to the 58 | `Bambi `_. 59 | library, a high-level Bayesian model-building 60 | interface built on top of PyMC. 
61 | 62 | Bambi is not a dependency of PyMC and should be installed in addition to PyMC 63 | to use it to generate PyMC models via formula syntax. 64 | -------------------------------------------------------------------------------- /docs/source/api/backends.rst: -------------------------------------------------------------------------------- 1 | Storage backends 2 | **************** 3 | 4 | .. currentmodule:: pymc 5 | 6 | .. autosummary:: 7 | :toctree: generated/ 8 | 9 | to_inference_data 10 | predictions_to_inference_data 11 | 12 | Internal structures 13 | ------------------- 14 | 15 | .. automodule:: pymc.backends 16 | 17 | .. autosummary:: 18 | :toctree: generated/ 19 | 20 | NDArray 21 | base.BaseTrace 22 | base.MultiTrace 23 | zarr.ZarrTrace 24 | zarr.ZarrChain 25 | -------------------------------------------------------------------------------- /docs/source/api/data.rst: -------------------------------------------------------------------------------- 1 | Data 2 | **** 3 | 4 | .. currentmodule:: pymc 5 | 6 | .. autosummary:: 7 | :toctree: generated/ 8 | 9 | ConstantData 10 | MutableData 11 | get_data 12 | Data 13 | Minibatch 14 | -------------------------------------------------------------------------------- /docs/source/api/distributions.rst: -------------------------------------------------------------------------------- 1 | .. _api_distributions: 2 | 3 | ************* 4 | Distributions 5 | ************* 6 | 7 | .. 
toctree:: 8 | :maxdepth: 2 9 | 10 | distributions/continuous 11 | distributions/discrete 12 | distributions/multivariate 13 | distributions/mixture 14 | distributions/timeseries 15 | distributions/truncated 16 | distributions/censored 17 | distributions/custom 18 | distributions/simulator 19 | distributions/transforms 20 | distributions/utilities 21 | -------------------------------------------------------------------------------- /docs/source/api/distributions/censored.rst: -------------------------------------------------------------------------------- 1 | ******** 2 | Censored 3 | ******** 4 | 5 | .. 6 | Manually follow the template in _templates/distribution.rst. 7 | If at any point, multiple objects are listed here, 8 | the pattern should instead be modified to that of the 9 | other API files such as api/distributions/continuous.rst 10 | 11 | .. currentmodule:: pymc 12 | 13 | .. autoclass:: Censored 14 | 15 | .. rubric:: Methods 16 | 17 | .. autosummary:: 18 | :toctree: classmethods 19 | 20 | Censored.dist 21 | -------------------------------------------------------------------------------- /docs/source/api/distributions/continuous.rst: -------------------------------------------------------------------------------- 1 | ********** 2 | Continuous 3 | ********** 4 | 5 | .. currentmodule:: pymc 6 | .. 
autosummary:: 7 | :toctree: generated/ 8 | :template: distribution.rst 9 | 10 | AsymmetricLaplace 11 | Beta 12 | Cauchy 13 | ChiSquared 14 | ExGaussian 15 | Exponential 16 | Flat 17 | Gamma 18 | Gumbel 19 | HalfCauchy 20 | HalfFlat 21 | HalfNormal 22 | HalfStudentT 23 | Interpolated 24 | InverseGamma 25 | Kumaraswamy 26 | Laplace 27 | Logistic 28 | LogitNormal 29 | LogNormal 30 | Moyal 31 | Normal 32 | Pareto 33 | PolyaGamma 34 | Rice 35 | SkewNormal 36 | SkewStudentT 37 | StudentT 38 | Triangular 39 | TruncatedNormal 40 | Uniform 41 | VonMises 42 | Wald 43 | Weibull 44 | -------------------------------------------------------------------------------- /docs/source/api/distributions/custom.rst: -------------------------------------------------------------------------------- 1 | ********** 2 | CustomDist 3 | ********** 4 | 5 | .. 6 | Manually follow the template in _templates/distribution.rst. 7 | If at any point, multiple objects are listed here, 8 | the pattern should instead be modified to that of the 9 | other API files such as api/distributions/continuous.rst 10 | 11 | .. currentmodule:: pymc 12 | 13 | .. autoclass:: CustomDist 14 | 15 | .. rubric:: Methods 16 | 17 | .. autosummary:: 18 | :toctree: classmethods 19 | 20 | CustomDist.dist 21 | -------------------------------------------------------------------------------- /docs/source/api/distributions/discrete.rst: -------------------------------------------------------------------------------- 1 | ******** 2 | Discrete 3 | ******** 4 | 5 | .. currentmodule:: pymc 6 | .. autosummary:: 7 | :toctree: generated 8 | :template: distribution.rst 9 | 10 | Bernoulli 11 | BetaBinomial 12 | Binomial 13 | Categorical 14 | DiscreteUniform 15 | DiscreteWeibull 16 | Geometric 17 | HyperGeometric 18 | NegativeBinomial 19 | OrderedLogistic 20 | OrderedProbit 21 | Poisson 22 | 23 | .. 
note:: 24 | 25 | **OrderedLogistic and OrderedProbit:** 26 | The ``OrderedLogistic`` and ``OrderedProbit`` distributions expect the observed values to be 0-based, i.e., they should range from ``0`` to ``K-1``. Using 1-based indexing (like ``1, 2, 3,...K``) can result in errors. 27 | -------------------------------------------------------------------------------- /docs/source/api/distributions/mixture.rst: -------------------------------------------------------------------------------- 1 | ******* 2 | Mixture 3 | ******* 4 | 5 | .. currentmodule:: pymc 6 | .. autosummary:: 7 | :toctree: generated 8 | :template: distribution.rst 9 | 10 | Mixture 11 | NormalMixture 12 | ZeroInflatedBinomial 13 | ZeroInflatedNegativeBinomial 14 | ZeroInflatedPoisson 15 | HurdlePoisson 16 | HurdleNegativeBinomial 17 | HurdleGamma 18 | HurdleLogNormal 19 | -------------------------------------------------------------------------------- /docs/source/api/distributions/multivariate.rst: -------------------------------------------------------------------------------- 1 | ************ 2 | Multivariate 3 | ************ 4 | 5 | .. currentmodule:: pymc 6 | .. autosummary:: 7 | :toctree: generated 8 | :template: distribution.rst 9 | 10 | CAR 11 | Dirichlet 12 | DirichletMultinomial 13 | ICAR 14 | KroneckerNormal 15 | LKJCholeskyCov 16 | LKJCorr 17 | MatrixNormal 18 | Multinomial 19 | MvNormal 20 | MvStudentT 21 | OrderedMultinomial 22 | StickBreakingWeights 23 | Wishart 24 | WishartBartlett 25 | ZeroSumNormal 26 | -------------------------------------------------------------------------------- /docs/source/api/distributions/simulator.rst: -------------------------------------------------------------------------------- 1 | ********* 2 | Simulator 3 | ********* 4 | 5 | .. 6 | Manually follow the template in _templates/distribution.rst. 
7 | If at any point, multiple objects are listed here, 8 | the pattern should instead be modified to that of the 9 | other API files such as api/distributions/continuous.rst 10 | 11 | .. currentmodule:: pymc 12 | 13 | .. autoclass:: Simulator 14 | 15 | .. rubric:: Methods 16 | 17 | .. autosummary:: 18 | :toctree: classmethods 19 | 20 | Simulator.dist 21 | -------------------------------------------------------------------------------- /docs/source/api/distributions/timeseries.rst: -------------------------------------------------------------------------------- 1 | ********** 2 | Timeseries 3 | ********** 4 | 5 | .. currentmodule:: pymc 6 | .. autosummary:: 7 | :toctree: generated 8 | :template: distribution.rst 9 | 10 | AR 11 | EulerMaruyama 12 | GARCH11 13 | GaussianRandomWalk 14 | MvGaussianRandomWalk 15 | MvStudentTRandomWalk 16 | -------------------------------------------------------------------------------- /docs/source/api/distributions/truncated.rst: -------------------------------------------------------------------------------- 1 | ********* 2 | Truncated 3 | ********* 4 | 5 | .. 6 | Manually follow the template in _templates/distribution.rst. 7 | If at any point, multiple objects are listed here, 8 | the pattern should instead be modified to that of the 9 | other API files such as api/distributions/continuous.rst 10 | 11 | .. currentmodule:: pymc 12 | 13 | .. autoclass:: Truncated 14 | 15 | .. rubric:: Methods 16 | 17 | .. autosummary:: 18 | :toctree: classmethods 19 | 20 | Truncated.dist 21 | -------------------------------------------------------------------------------- /docs/source/api/distributions/utilities.rst: -------------------------------------------------------------------------------- 1 | ********************** 2 | Distribution utilities 3 | ********************** 4 | 5 | .. currentmodule:: pymc 6 | .. 
autosummary:: 7 | :toctree: generated/ 8 | 9 | Continuous 10 | Discrete 11 | Distribution 12 | SymbolicRandomVariable 13 | DiracDelta 14 | -------------------------------------------------------------------------------- /docs/source/api/gp.rst: -------------------------------------------------------------------------------- 1 | Gaussian Processes 2 | ------------------ 3 | 4 | .. automodule:: pymc.gp 5 | 6 | .. toctree:: 7 | :maxdepth: 2 8 | 9 | gp/implementations 10 | gp/mean 11 | gp/cov 12 | gp/util 13 | -------------------------------------------------------------------------------- /docs/source/api/gp/cov.rst: -------------------------------------------------------------------------------- 1 | ******************** 2 | Covariance Functions 3 | ******************** 4 | 5 | .. automodule:: pymc.gp.cov 6 | .. autosummary:: 7 | :toctree: generated 8 | 9 | Constant 10 | WhiteNoise 11 | ExpQuad 12 | RatQuad 13 | Exponential 14 | Matern52 15 | Matern32 16 | Linear 17 | Polynomial 18 | Cosine 19 | Periodic 20 | WarpedInput 21 | Gibbs 22 | Coregion 23 | ScaledCov 24 | Kron 25 | -------------------------------------------------------------------------------- /docs/source/api/gp/implementations.rst: -------------------------------------------------------------------------------- 1 | *************** 2 | Implementations 3 | *************** 4 | 5 | .. currentmodule:: pymc.gp 6 | .. autosummary:: 7 | :toctree: generated 8 | 9 | HSGP 10 | HSGPPeriodic 11 | Latent 12 | LatentKron 13 | Marginal 14 | MarginalKron 15 | MarginalApprox 16 | TP 17 | -------------------------------------------------------------------------------- /docs/source/api/gp/mean.rst: -------------------------------------------------------------------------------- 1 | ************** 2 | Mean Functions 3 | ************** 4 | 5 | .. automodule:: pymc.gp.mean 6 | .. 
autosummary:: 7 | :toctree: generated 8 | 9 | Zero 10 | Constant 11 | Linear 12 | -------------------------------------------------------------------------------- /docs/source/api/gp/util.rst: -------------------------------------------------------------------------------- 1 | ************ 2 | GP Utilities 3 | ************ 4 | 5 | .. automodule:: pymc.gp.util 6 | .. autosummary:: 7 | :toctree: generated 8 | 9 | plot_gp_dist 10 | -------------------------------------------------------------------------------- /docs/source/api/logprob.rst: -------------------------------------------------------------------------------- 1 | *********** 2 | Probability 3 | *********** 4 | 5 | .. currentmodule:: pymc 6 | 7 | .. autosummary:: 8 | :toctree: generated/ 9 | 10 | logp 11 | logcdf 12 | icdf 13 | 14 | Conditional probability 15 | ----------------------- 16 | 17 | .. currentmodule:: pymc.logprob 18 | 19 | .. autosummary:: 20 | :toctree: generated/ 21 | 22 | conditional_logp 23 | transformed_conditional_logp 24 | -------------------------------------------------------------------------------- /docs/source/api/math.rst: -------------------------------------------------------------------------------- 1 | ==== 2 | Math 3 | ==== 4 | 5 | This submodule contains various mathematical functions. Most of them are imported directly 6 | from pytensor.tensor (see there for more details). Doing any kind of math with PyMC random 7 | variables, or defining custom likelihoods or priors requires you to use these PyTensor 8 | expressions rather than NumPy or Python code. 9 | 10 | .. currentmodule:: pymc 11 | 12 | Functions exposed in pymc namespace 13 | ----------------------------------- 14 | .. autosummary:: 15 | :toctree: generated/ 16 | 17 | expand_packed_triangular 18 | logit 19 | invlogit 20 | probit 21 | invprobit 22 | logaddexp 23 | logsumexp 24 | 25 | 26 | Functions exposed in pymc.math 27 | ------------------------------ 28 | 29 | .. automodule:: pymc.math 30 | .. 
autosummary:: 31 | :toctree: generated/ 32 | 33 | abs 34 | prod 35 | dot 36 | eq 37 | neq 38 | ge 39 | gt 40 | le 41 | lt 42 | exp 43 | log 44 | sgn 45 | sqr 46 | sqrt 47 | sum 48 | ceil 49 | floor 50 | sin 51 | sinh 52 | arcsin 53 | arcsinh 54 | cos 55 | cosh 56 | arccos 57 | arccosh 58 | tan 59 | tanh 60 | arctan 61 | arctanh 62 | cumprod 63 | cumsum 64 | matmul 65 | and_ 66 | broadcast_to 67 | clip 68 | concatenate 69 | flatten 70 | or_ 71 | stack 72 | switch 73 | where 74 | flatten_list 75 | constant 76 | max 77 | maximum 78 | mean 79 | min 80 | minimum 81 | round 82 | erf 83 | erfc 84 | erfcinv 85 | erfinv 86 | log1pexp 87 | log1mexp 88 | logaddexp 89 | logsumexp 90 | logdiffexp 91 | logit 92 | invlogit 93 | probit 94 | invprobit 95 | sigmoid 96 | softmax 97 | log_softmax 98 | logbern 99 | full 100 | full_like 101 | ones 102 | ones_like 103 | zeros 104 | zeros_like 105 | kronecker 106 | cartesian 107 | kron_dot 108 | kron_solve_lower 109 | kron_solve_upper 110 | kron_diag 111 | flat_outer 112 | expand_packed_triangular 113 | batched_diag 114 | block_diagonal 115 | matrix_inverse 116 | logdet 117 | -------------------------------------------------------------------------------- /docs/source/api/misc.rst: -------------------------------------------------------------------------------- 1 | Other utils 2 | *********** 3 | 4 | .. currentmodule:: pymc 5 | 6 | .. autosummary:: 7 | :toctree: generated/ 8 | 9 | compute_log_likelihood 10 | compute_log_prior 11 | find_constrained_prior 12 | DictToArrayBijection 13 | -------------------------------------------------------------------------------- /docs/source/api/model.rst: -------------------------------------------------------------------------------- 1 | Model 2 | ------ 3 | 4 | .. automodule:: pymc.model 5 | 6 | .. 
toctree:: 7 | :maxdepth: 2 8 | 9 | model/core 10 | model/conditioning 11 | model/optimization 12 | model/fgraph 13 | -------------------------------------------------------------------------------- /docs/source/api/model/conditioning.rst: -------------------------------------------------------------------------------- 1 | Model Conditioning 2 | ------------------ 3 | 4 | .. currentmodule:: pymc.model.transform.conditioning 5 | .. autosummary:: 6 | :toctree: generated/ 7 | 8 | do 9 | observe 10 | change_value_transforms 11 | remove_value_transforms 12 | -------------------------------------------------------------------------------- /docs/source/api/model/core.rst: -------------------------------------------------------------------------------- 1 | Model creation and inspection 2 | ----------------------------- 3 | 4 | .. currentmodule:: pymc.model.core 5 | .. autosummary:: 6 | :toctree: generated/ 7 | 8 | Model 9 | modelcontext 10 | 11 | Others 12 | ------ 13 | 14 | .. currentmodule:: pymc.model.core 15 | .. autosummary:: 16 | :toctree: generated/ 17 | 18 | Deterministic 19 | Potential 20 | set_data 21 | Point 22 | compile_fn 23 | 24 | 25 | Graph visualization 26 | ------------------- 27 | 28 | .. currentmodule:: pymc.model_graph 29 | .. autosummary:: 30 | :toctree: generated/ 31 | 32 | model_to_networkx 33 | model_to_graphviz 34 | -------------------------------------------------------------------------------- /docs/source/api/model/fgraph.rst: -------------------------------------------------------------------------------- 1 | FunctionGraph 2 | ------------- 3 | 4 | .. currentmodule:: pymc.model.fgraph 5 | .. 
autosummary:: 6 | :toctree: generated/ 7 | 8 | clone_model 9 | fgraph_from_model 10 | model_from_fgraph 11 | -------------------------------------------------------------------------------- /docs/source/api/model/optimization.rst: -------------------------------------------------------------------------------- 1 | Model Optimization 2 | ------------------ 3 | .. currentmodule:: pymc.model.transform.optimization 4 | .. autosummary:: 5 | :toctree: generated/ 6 | 7 | freeze_dims_and_data 8 | -------------------------------------------------------------------------------- /docs/source/api/ode.rst: -------------------------------------------------------------------------------- 1 | ************************************** 2 | Ordinary differential equations (ODEs) 3 | ************************************** 4 | 5 | 6 | .. automodule:: pymc.ode 7 | 8 | .. autosummary:: 9 | :toctree: generated/ 10 | 11 | DifferentialEquation 12 | -------------------------------------------------------------------------------- /docs/source/api/pytensorf.rst: -------------------------------------------------------------------------------- 1 | PyTensor utils 2 | ************** 3 | 4 | .. currentmodule:: pymc.pytensorf 5 | 6 | .. autosummary:: 7 | :toctree: generated/ 8 | 9 | compile 10 | gradient 11 | hessian 12 | hessian_diag 13 | jacobian 14 | inputvars 15 | cont_inputs 16 | floatX 17 | intX 18 | constant_fold 19 | CallableTensor 20 | join_nonshared_inputs 21 | make_shared_replacements 22 | convert_data 23 | -------------------------------------------------------------------------------- /docs/source/api/samplers.rst: -------------------------------------------------------------------------------- 1 | Samplers 2 | ======== 3 | 4 | This submodule contains functions for MCMC and forward sampling. 5 | 6 | 7 | .. currentmodule:: pymc 8 | 9 | .. 
autosummary:: 10 | :toctree: generated/ 11 | 12 | sample 13 | sample_prior_predictive 14 | sample_posterior_predictive 15 | draw 16 | compute_deterministics 17 | init_nuts 18 | sampling.jax.sample_blackjax_nuts 19 | sampling.jax.sample_numpyro_nuts 20 | 21 | 22 | Step methods 23 | ************ 24 | 25 | HMC family 26 | ---------- 27 | .. currentmodule:: pymc.step_methods.hmc 28 | 29 | .. autosummary:: 30 | :toctree: generated/ 31 | 32 | NUTS 33 | HamiltonianMC 34 | 35 | Metropolis family 36 | ----------------- 37 | .. currentmodule:: pymc.step_methods 38 | 39 | .. autosummary:: 40 | :toctree: generated/ 41 | 42 | BinaryGibbsMetropolis 43 | BinaryMetropolis 44 | CategoricalGibbsMetropolis 45 | CauchyProposal 46 | DEMetropolis 47 | DEMetropolisZ 48 | LaplaceProposal 49 | Metropolis 50 | MultivariateNormalProposal 51 | NormalProposal 52 | PoissonProposal 53 | UniformProposal 54 | 55 | Other step methods 56 | ------------------ 57 | .. currentmodule:: pymc.step_methods 58 | 59 | .. autosummary:: 60 | :toctree: generated/ 61 | 62 | CompoundStep 63 | Slice 64 | -------------------------------------------------------------------------------- /docs/source/api/shape_utils.rst: -------------------------------------------------------------------------------- 1 | *********** 2 | shape_utils 3 | *********** 4 | 5 | This submodule contains various functions that apply numpy's broadcasting rules to shape tuples, and also to samples drawn from probability distributions. 6 | 7 | The main challenge when broadcasting samples drawn from a generative model, is that each random variate has a core shape. When we draw many i.i.d samples from a given RV, for example if we ask for ``size_tuple`` i.i.d draws, the result usually is a ``size_tuple + RV_core_shape``. In the generative model's hierarchy, the downstream RVs that are conditionally dependent on our above sampled values, will get an array with a shape that is inconsistent with the core shape they expect to see for their parameters. 
This can be a problem because it prevents regular broadcasting in complex hierarchical models, and thus makes prior and posterior predictive sampling difficult. 8 | 9 | This module introduces functions that are aware of the requested ``size_tuple`` of i.i.d samples and perform the broadcasting on the core shapes, transparently ignoring or moving the prepended i.i.d ``size_tuple`` axes around. 10 | 11 | .. currentmodule:: pymc.distributions.shape_utils 12 | 13 | .. autosummary:: 14 | :toctree: generated/ 15 | 16 | to_tuple 17 | rv_size_is_none 18 | change_dist_size 19 | -------------------------------------------------------------------------------- /docs/source/api/smc.rst: -------------------------------------------------------------------------------- 1 | Sequential Monte Carlo 2 | ********************** 3 | 4 | .. automodule:: pymc.smc 5 | 6 | .. autosummary:: 7 | :toctree: generated/ 8 | 9 | sample_smc 10 | 11 | .. _smc_kernels: 12 | 13 | SMC kernels 14 | ----------- 15 | 16 | .. currentmodule:: pymc.smc.kernels 17 | .. autosummary:: 18 | :toctree: generated/ 19 | 20 | SMC_KERNEL 21 | IMH 22 | MH 23 | -------------------------------------------------------------------------------- /docs/source/api/testing.rst: -------------------------------------------------------------------------------- 1 | ======= 2 | Testing 3 | ======= 4 | 5 | This submodule contains tools to help with testing PyMC code. 6 | 7 | 8 | .. currentmodule:: pymc.testing 9 | 10 | .. autosummary:: 11 | :toctree: generated/ 12 | 13 | mock_sample 14 | mock_sample_setup_and_teardown 15 | -------------------------------------------------------------------------------- /docs/source/api/tuning.rst: -------------------------------------------------------------------------------- 1 | Tuning 2 | ------ 3 | 4 | .. currentmodule:: pymc 5 | 6 | ..
autosummary:: 7 | :toctree: generated/ 8 | 9 | find_hessian 10 | find_MAP 11 | -------------------------------------------------------------------------------- /docs/source/api/vi.rst: -------------------------------------------------------------------------------- 1 | ********************* 2 | Variational Inference 3 | ********************* 4 | 5 | .. currentmodule:: pymc 6 | 7 | .. autosummary:: 8 | :toctree: generated/ 9 | 10 | ADVI 11 | ASVGD 12 | SVGD 13 | FullRankADVI 14 | ImplicitGradient 15 | Inference 16 | KLqp 17 | fit 18 | 19 | Approximations 20 | -------------- 21 | 22 | .. autosummary:: 23 | :toctree: generated/ 24 | 25 | Empirical 26 | FullRank 27 | MeanField 28 | sample_approx 29 | 30 | OPVI 31 | ---- 32 | 33 | .. autosummary:: 34 | :toctree: generated/ 35 | 36 | Approximation 37 | Group 38 | 39 | Operators 40 | --------- 41 | 42 | .. automodule:: pymc.variational.operators 43 | .. autosummary:: 44 | :toctree: generated/ 45 | 46 | KL 47 | KSD 48 | 49 | Special 50 | ------- 51 | .. currentmodule:: pymc 52 | .. autosummary:: 53 | :toctree: generated/ 54 | 55 | Stein 56 | adadelta 57 | adagrad 58 | adagrad_window 59 | adam 60 | adamax 61 | apply_momentum 62 | apply_nesterov_momentum 63 | momentum 64 | nesterov_momentum 65 | norm_constraint 66 | rmsprop 67 | sgd 68 | total_norm_constraint 69 | -------------------------------------------------------------------------------- /docs/source/contributing/build_docs.md: -------------------------------------------------------------------------------- 1 | # Build documentation locally 2 | 3 | :::{warning} 4 | Docs build is not supported on Windows. 5 | To build docs on Windows we recommend running inside a Docker container. 6 | ::: 7 | 8 | To build the docs, run these commands at PyMC repository root: 9 | 10 | ## Installing dependencies 11 | 12 | ```shell 13 | conda env create -f conda-envs/environment-docs.yml # then activate the new env; or make sure all dependencies listed here are installed 14 | pip install -e .
# Install local pymc version as installable package 15 | ``` 16 | 17 | ## Building the documentation 18 | There is a `Makefile` in the pymc repo to help with the doc building process. 19 | 20 | ```shell 21 | make clean 22 | make html 23 | ``` 24 | 25 | `make html` is the command that builds the documentation with `sphinx-build`. 26 | `make clean` deletes caches and intermediate files. 27 | 28 | The `make clean` step is not always necessary. If you are working on a specific page, 29 | for example, you can rebuild the docs without the clean step and everything should 30 | work fine. If you are restructuring the content or editing toctrees, then you'll need 31 | to execute `make clean`. 32 | 33 | A good approach is to skip `make clean` by default, which makes 34 | `make html` faster, and check how everything looks. 35 | If something looks strange, run `make clean` and `make html` one after the other 36 | to see if it fixes the issue before checking anything else. 37 | 38 | ### Emulate building on readthedocs 39 | The target `rtd` is also available to chain `make clean` with `sphinx-build`, 40 | also setting some extra options and environment variables that tell 41 | Sphinx to emulate a readthedocs build as closely as possible. 42 | 43 | ```shell 44 | make rtd 45 | ``` 46 | 47 | :::{important} 48 | This won't reinstall or update any dependencies, unlike on readthedocs where 49 | all dependencies are installed in a clean env before each build. 50 | 51 | But it will execute all notebooks inside the `core_notebooks` folder, 52 | which by default are not executed. Executing the notebooks will add several minutes 53 | to the doc build, as there are 6 notebooks which take between 20s and 5 minutes 54 | to run. 55 | ::: 56 | 57 | ## View the generated docs 58 | 59 | ```shell 60 | make view 61 | ``` 62 | 63 | This will use Python's `webbrowser` module to open the generated website in your browser.
64 | The generated website is static, so there is no need to set up a server to preview it. 65 | -------------------------------------------------------------------------------- /docs/source/contributing/docker_container.md: -------------------------------------------------------------------------------- 1 | (docker_container)= 2 | # Running PyMC in Docker 3 | 4 | We provide a Dockerfile, which helps with isolating build problems and with local development. 5 | Install [Docker](https://www.docker.com/) for your operating system, clone this repo, then 6 | run the following commands to build a `pymc` docker image. 7 | 8 | ```bash 9 | cd pymc 10 | bash scripts/docker_container.sh build 11 | ``` 12 | 13 | After successfully building the docker image, you can start a local docker container called `pymc` either with a `bash` shell or with a [`jupyter`](http://jupyter.org/) notebook server running on port 8888. 14 | 15 | ```bash 16 | bash scripts/docker_container.sh bash # running the container with bash 17 | bash scripts/docker_container.sh jupyter # running the container with jupyter notebook 18 | ``` 19 | -------------------------------------------------------------------------------- /docs/source/contributing/gitpod/gitpod_integration.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pymc-devs/pymc/360cb6edde9ccba306c0e046d9576c936fa4e571/docs/source/contributing/gitpod/gitpod_integration.png -------------------------------------------------------------------------------- /docs/source/contributing/gitpod/gitpod_workspace.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pymc-devs/pymc/360cb6edde9ccba306c0e046d9576c936fa4e571/docs/source/contributing/gitpod/gitpod_workspace.png -------------------------------------------------------------------------------- /docs/source/contributing/pr_checklist.md:
-------------------------------------------------------------------------------- 1 | (pr_checklist)= 2 | # Pull request checklist 3 | 4 | We recommend that your contribution complies with the following guidelines before you submit a pull request: 5 | 6 | * If your pull request addresses an issue, use the pull request title to describe the issue and mention the issue number in the pull request _description_. 7 | This will make sure a link back to the original issue is created. 8 | 9 | :::{caution} 10 | Adding the related issue in the PR title generates no link and is therefore 11 | not useful as nobody knows issue numbers. Please mention all related 12 | issues in the PR but do so only in the PR description. 13 | ::: 14 | 15 | * All public methods must have informative docstrings with sample usage when appropriate. 16 | Docstrings should follow the [numpydoc style](https://numpydoc.readthedocs.io/en/latest/format.html) 17 | 18 | * Please select "Create draft pull request" in the dropdown menu when opening your pull request to indicate a work in progress. This is to avoid duplicated work, to get early input on implementation details or API/functionality, or to seek collaborators. 19 | 20 | * Documentation and high-coverage tests are necessary for enhancements to be accepted. 21 | * When adding additional functionality, consider also adding an example notebook to [pymc-examples](https://github.com/pymc-devs/pymc-examples). 22 | Open a [proposal issue](https://github.com/pymc-devs/pymc-examples/issues/new/choose) in the example repo to discuss the specific scope of the notebook. 23 | 24 | * Run any of the pre-existing examples in [pymc-examples](https://github.com/pymc-devs/pymc-examples) that contain analyses that would be affected by your changes to ensure that nothing breaks. This is a useful opportunity to not only check your work for bugs that might not be revealed by unit tests, but also to show how your contribution improves PyMC for end users.
25 | 26 | * **No `pre-commit` errors:** see the {ref}`python_style` and {ref}`jupyter_style` pages for how to install and run it. 27 | 28 | * All other tests pass when everything is rebuilt from scratch. See {ref}`running_the_test_suite` 29 | -------------------------------------------------------------------------------- /docs/source/contributing/python_style.md: -------------------------------------------------------------------------------- 1 | (python_style)= 2 | # Python style guide 3 | 4 | ## Pre-commit checks 5 | 6 | Some code-quality checks are performed during continuous integration. The easiest way to check that they pass locally, 7 | before submitting your pull request, is by using [pre-commit](https://pre-commit.com/). 8 | 9 | Steps to get set up are (run these within your virtual environment): 10 | 11 | 1. install: 12 | 13 | ```bash 14 | pip install pre-commit 15 | ``` 16 | 17 | 2. enable: 18 | 19 | ```bash 20 | pre-commit install 21 | ``` 22 | 23 | Now, whenever you stage some file, when you run `git commit -m "<descriptive message>"`, `pre-commit` will run 24 | the checks defined in `.pre-commit-config.yaml` and will block your commit if any of them fail. If any hook fails, you 25 | should fix it (if necessary), run `git add <files>` again, and then re-run `git commit -m "<descriptive message>"`. 26 | 27 | You can skip `pre-commit` using `--no-verify`, e.g. 28 | 29 | ```bash 30 | git commit -m "wip lol" --no-verify 31 | ``` 32 | 33 | To skip one particular hook, you can set the `SKIP` environment variable. E.g. (on Linux): 34 | 35 | ```bash 36 | SKIP=ruff git commit -m "<descriptive message>" 37 | ``` 38 | 39 | You can manually run all `pre-commit` hooks on all files with 40 | 41 | ```bash 42 | pre-commit run --all-files 43 | ``` 44 | 45 | or, if you just want to manually run them on a subset of files, 46 | 47 | ```bash 48 | pre-commit run --files <file_1> <file_2> ...
49 | ``` 50 | 51 | ## Gotchas & Troubleshooting 52 | __Pre-commit runs on staged files__ 53 | 54 | If you have some `git` changes staged and other unstaged, the `pre-commit` will only run on the staged files. 55 | 56 | __Pre-commit repeatedly complains about the same formatting changes__ 57 | 58 | Check the unstaged changes (see previous point). 59 | 60 | __Whitespace changes in the `environment-dev.yml` files__ 61 | 62 | On Windows, there are some bugs in pre-commit hooks that can lead to changes in some environment YAML files. 63 | Until this is fixed upstream, you should __ignore these changes__. 64 | To actually make the commit, deactivate the automated `pre-commit` with `pre-commit uninstall` and make sure to run it manually with `pre-commit run --all`. 65 | 66 | __Failures in the `mypy` step__ 67 | 68 | We are running static type checks with `mypy` to continuously improve the reliability and type safety of the PyMC codebase. 69 | However, there are many files with unresolved type problems, which is why we are allowing some files to fail the `mypy` check. 70 | 71 | If you are seeing the `mypy` step complain, chances are that you are in one of the following two situations: 72 | * 😕 Your changes introduced type problems in a file that was previously free of type problems. 73 | * 🥳 Your changes fixed type problems. 74 | 75 | In any case __read the logging output of the `mypy` hook__, because it contains the instructions how to proceed. 76 | 77 | You can also run the `mypy` check manually with `python scripts/run_mypy.py [--verbose]`. 
78 | -------------------------------------------------------------------------------- /docs/source/contributing/release_checklist.md: -------------------------------------------------------------------------------- 1 | # PyMC Release workflow 2 | + Track all relevant issues and PRs via a **version-specific [milestone](https://github.com/pymc-devs/pymc/milestones)** 3 | + Make sure that there are no major known bugs that should not be released 4 | + Make a PR to **bump the version number** in `__init__.py` and edit the `RELEASE-NOTES.md`. 5 | + :::{important} 6 | Please don't name it after the release itself, and remember to push to your own fork like an ordinary citizen. 7 | ::: 8 | + Create a new "vNext" section at the top 9 | + Edit the header with the release version and date 10 | + Add a line to credit the release manager like in previous releases 11 | + After merging the PR, check that the CI pipelines on master are all ✔ 12 | + Create a Release with the Tag as `v1.2.3` and a human-readable title like those of previous releases 13 | 14 | After the last step, the [GitHub Action "release-pipeline"](https://github.com/pymc-devs/pymc/blob/master/.github/workflows/release.yml) triggers and automatically builds and publishes the new version to PyPI. 15 | 16 | ## Troubleshooting 17 | + If for some reason, the release must be "unpublished", this is possible by manually deleting it on PyPI and GitHub. HOWEVER, PyPI will not accept another release with the same version number! 18 | + The `release-pipeline` has a `test-install-job`, which can fail if the PyPI index did not update fast enough.
19 | 20 | ## Post-release steps 21 | + Head over to [Zenodo](https://zenodo.org/record/4603970) and copy the version-specific DOI badge into the [release notes](https://github.com/pymc-devs/pymc/releases) 22 | + Rename and close the release milestone and open a new "vNext" milestone 23 | + Monitor the [conda-forge/pymc-feedstock](https://github.com/conda-forge/pymc-feedstock) repository for new PRs. The bots should automatically pick up the new version and open a PR to update it. Manual intervention may be required though (see the repo's PR history for examples). 24 | + Re-run notebooks with the new release (see https://github.com/pymc-devs/pymc-examples) 25 | + Make sure the new version appears on the website and that [`docs.pymc.io/en/stable`](https://docs.pymc.io/en/stable) points to it. 26 | -------------------------------------------------------------------------------- /docs/source/contributing/running_the_test_suite.md: -------------------------------------------------------------------------------- 1 | (running_the_test_suite)= 2 | # Running the test suite 3 | The first step to run tests is the installation of additional dependencies that are needed for testing: 4 | 5 | ```bash 6 | pip install -r requirements-dev.txt 7 | ``` 8 | 9 | The PyMC test suite uses `pytest` as the testing framework. 10 | If you are unfamiliar with `pytest`, check out [this short video series](https://calmcode.io/pytest/introduction.html). 11 | 12 | With the optional dependencies installed, you can start running tests. 13 | Below are some examples of how you might want to run certain parts of the test suite. 14 | 15 | ```{attention} 16 | Running the entire test suite will take hours. 17 | Therefore, we recommend running only the specific tests that target the parts of the codebase you're working on.
18 | ``` 19 | 20 | To run all tests from a single file: 21 | ```bash 22 | pytest -v tests/model/test_core.py 23 | ``` 24 | 25 | ```{tip} 26 | The `-v` flag is shorthand for `--verbose` and prints the names of the test cases that are currently running. 27 | ``` 28 | 29 | Often, you'll want to focus on just a few test cases first. 30 | By using the `-k` flag, you can filter for test cases that match a certain pattern. 31 | For example, the following command runs all test cases from `test_core.py` that have "coord" in their name: 32 | 33 | ```bash 34 | pytest -v tests/model/test_core.py -k coord 35 | ``` 36 | 37 | 38 | To get a coverage report, you can pass `--cov=pymc`, optionally with `--cov-report term-missing` to get a printout of the line numbers that were *not* executed by the invoked tests. 39 | Note that because you are not running the entire test suite, the overall coverage will be low. 40 | But you can still check whether the specific lines of code that you're working on are covered. 41 | 42 | ```bash 43 | pytest -v --cov=pymc --cov-report term-missing tests/model/test_core.py 44 | ``` 45 | 46 | When you are reasonably confident about the changes you made, you can push the changes and open a pull request. 47 | Our GitHub Actions pipeline will run the entire test suite; if there are failures, you can go back and run the failing tests on your local machine. 
48 | -------------------------------------------------------------------------------- /docs/source/images/forestplot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pymc-devs/pymc/360cb6edde9ccba306c0e046d9576c936fa4e571/docs/source/images/forestplot.png -------------------------------------------------------------------------------- /docs/source/images/model_to_graphviz.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pymc-devs/pymc/360cb6edde9ccba306c0e046d9576c936fa4e571/docs/source/images/model_to_graphviz.png -------------------------------------------------------------------------------- /docs/source/index.md: -------------------------------------------------------------------------------- 1 | --- 2 | sd_hide_title: true 3 | --- 4 | 5 | # PyMC versioned Documentation 6 | 7 | :::{toctree} 8 | :hidden: 9 | 10 | Home 11 | Examples 12 | Learn 13 | api 14 | Community 15 | contributing/index 16 | ::: 17 | -------------------------------------------------------------------------------- /docs/source/installation.md: -------------------------------------------------------------------------------- 1 | (installation)= 2 | # Installation 3 | 4 | We recommend using [Anaconda](https://www.anaconda.com/) (or [Miniforge](https://github.com/conda-forge/miniforge)) to install Python on your local machine, which allows for packages to be installed using its `conda` utility. 5 | 6 | Once you have installed one of the above, PyMC can be installed into a new conda environment as follows: 7 | 8 | ```console 9 | conda create -c conda-forge -n pymc_env "pymc>=5" 10 | conda activate pymc_env 11 | ``` 12 | If you like, replace the name `pymc_env` with whatever environment name you prefer. 
13 | 14 | :::{seealso} 15 | See the [conda-forge tips & tricks](https://conda-forge.org/docs/user/tipsandtricks.html#using-multiple-channels) page for how to avoid installation 16 | issues when using multiple conda channels (e.g. defaults and conda-forge). 17 | ::: 18 | 19 | ## JAX sampling 20 | 21 | If you wish to enable sampling using the JAX backend via NumPyro, 22 | you need to install NumPyro manually, as it is an optional dependency: 23 | 24 | ```console 25 | conda install numpyro 26 | ``` 27 | 28 | Similarly, to use the BlackJAX sampler instead: 29 | 30 | ```console 31 | conda install blackjax 32 | ``` 33 | 34 | ## Nutpie sampling 35 | 36 | You can also enable sampling with [nutpie](https://github.com/pymc-devs/nutpie). 37 | Nutpie uses numba as the compiler and a sampler written in Rust for faster performance. 38 | 39 | ```console 40 | conda install -c conda-forge nutpie 41 | ``` 42 | -------------------------------------------------------------------------------- /docs/source/learn.md: -------------------------------------------------------------------------------- 1 | (learn)= 2 | # Learn PyMC & Bayesian modeling 3 | 4 | :::{toctree} 5 | :maxdepth: 1 6 | installation 7 | learn/core_notebooks/index 8 | learn/books 9 | learn/videos_and_podcasts 10 | learn/consulting 11 | glossary 12 | ::: 13 | 14 | ## At a glance 15 | ### Beginner 16 | - Book: [Bayesian Analysis with Python](http://bap.com.ar/) 17 | - Book: [Bayesian Methods for Hackers](https://github.com/CamDavidsonPilon/Probabilistic-Programming-and-Bayesian-Methods-for-Hackers) 18 | 19 | 20 | ### Intermediate 21 | - {ref}`pymc_overview` shows PyMC 4.0 code in action 22 | - Example notebooks: {doc}`nb:gallery` 23 | - {ref}`GLM_linear` 24 | - {ref}`posterior_predictive` 25 | - Comparing models: {ref}`model_comparison` 26 | - Shapes and dimensionality: {ref}`dimensionality` 27 | - {ref}`videos_and_podcasts` 28 | - Book: [Bayesian Modeling and Computation in Python](https://bayesiancomputationbook.com/welcome.html) 29 | 30 | ### 
Advanced 31 | - {octicon}`plug;1em;sd-text-info` Experimental and cutting-edge functionality: {doc}`pmx:index` library 32 | - {octicon}`gear;1em;sd-text-info` PyMC internals guides (To be outlined and referenced here once [pymc#5538](https://github.com/pymc-devs/pymc/issues/5538) 33 | is addressed) 34 | -------------------------------------------------------------------------------- /docs/source/learn/books.md: -------------------------------------------------------------------------------- 1 | (books)= 2 | # Books 3 | :::::{container} full-width 4 | ::::{grid} 1 2 2 3 5 | :gutter: 3 6 | 7 | :::{grid-item-card} Bayesian Modeling and Computation in Python 8 | :img-top: https://bayesiancomputationbook.com/_images/Cover.jpg 9 | 10 | By Osvaldo Martin, Ravin Kumar and Junpeng Lao 11 | 12 | A hands-on approach with PyMC and ArviZ focusing on the practice of applied statistics. 13 | 14 | [Website + code](https://bayesiancomputationbook.com/welcome.html) 15 | 16 | ::: 17 | 18 | :::{grid-item-card} Bayesian Methods for Hackers 19 | :img-top: https://www.pearson.com/hipassets/assets/hip/images/bigcovers/0133902838.jpg 20 | By Cameron Davidson-Pilon 21 | 22 | The "hacker" in the title means learn-as-you-code. This hands-on introduction teaches intuitive definitions of the Bayesian approach to statistics, workflow and decision-making by applying them using PyMC. 23 | 24 | [GitHub repo](https://github.com/CamDavidsonPilon/Probabilistic-Programming-and-Bayesian-Methods-for-Hackers) 25 | 26 | [Project homepage](http://camdavidsonpilon.github.io/Probabilistic-Programming-and-Bayesian-Methods-for-Hackers/) 27 | 28 | ::: 29 | 30 | :::{grid-item-card} Bayesian Analysis with Python 31 | :img-top: https://aloctavodia.github.io/img/BAP.png 32 | 33 | By Osvaldo Martin 34 | 35 | 36 | A great introductory book written by a maintainer of PyMC. It provides a hands-on introduction to the main concepts of Bayesian statistics using synthetic and real data sets. 
Mastering the concepts in this book is a great foundation to pursue more advanced knowledge. 37 | 38 | [Book website](https://www.packtpub.com/big-data-and-business-intelligence/bayesian-analysis-python-second-edition) 39 | 40 | [Code and errata in PyMC 3.x](https://github.com/aloctavodia/BAP) 41 | ::: 42 | 43 | :::{grid-item-card} Doing Bayesian Data Analysis 44 | :img-top: https://jkkweb.sitehost.iu.edu/DoingBayesianDataAnalysis/DBDA2Ecover.png 45 | 46 | By John K. Kruschke 47 | 48 | 49 | Principled introduction to Bayesian data analysis, with practical exercises. The book's original examples are coded in R, but notebooks with a PyMC port of the code are available through the links below. 50 | 51 | [Book website](https://sites.google.com/site/doingbayesiandataanalysis/home) 52 | 53 | [PyMC port of the second edition's code](https://github.com/cluhmann/DBDA-python) 54 | 55 | ::: 56 | 57 | :::{grid-item-card} Statistical Rethinking 58 | :img-top: https://xcelab.net/rm/sr2edcover-1-187x300.png 59 | 60 | By Richard McElreath 61 | 62 | A Bayesian Course with Examples in R and Stan. 63 | 64 | [Book website](http://xcelab.net/rm/statistical-rethinking/) 65 | 66 | [PyMC 3.x port of the code](https://github.com/pymc-devs/resources/tree/master/Rethinking) 67 | 68 | ::: 69 | 70 | :::{grid-item-card} Bayesian Cognitive Modeling: A Practical Course 71 | :img-top: https://images-na.ssl-images-amazon.com/images/I/51K33XI2I8L._SX330_BO1,204,203,200_.jpg 72 | 73 | By Michael Lee and Eric-Jan Wagenmakers 74 | 75 | Focused on using Bayesian statistics in cognitive modeling. 
76 | 77 | [Book website](https://bayesmodels.com/) 78 | 79 | [PyMC 3.x implementations](https://github.com/pymc-devs/resources/tree/master/BCM) 80 | ::: 81 | 82 | :::{grid-item-card} Bayesian Data Analysis 83 | :img-top: https://www.stat.columbia.edu/~gelman/book/bda_cover.png 84 | 85 | By Andrew Gelman, John Carlin, Hal Stern, David Dunson, Aki Vehtari, and Donald Rubin 86 | 87 | A comprehensive, standard, and wonderful textbook on Bayesian methods. 88 | 89 | [Book website](https://www.stat.columbia.edu/~gelman/book/) 90 | 91 | [Examples and exercises implemented in PyMC 3.x](https://github.com/pymc-devs/resources/tree/master/BDA3) 92 | 93 | :::: 94 | ::::: 95 | -------------------------------------------------------------------------------- /docs/source/learn/consulting.md: -------------------------------------------------------------------------------- 1 | (consulting)= 2 | # Consulting 3 | 4 |
5 |
6 | 7 | If you need professional help with your PyMC model, [PyMC Labs](https://www.pymc-labs.io) is a Bayesian consultancy consisting of [members of the PyMC core development team](https://www.pymc-labs.io/team/). Work we typically do includes: 8 | * Model speed-ups (reparameterizations, JAX, [GPU sampling](https://www.pymc-labs.io/blog-posts/pymc-stan-benchmark/)) 9 | * Improving models (adding hierarchy, time-series structure etc) 10 | * Building new models to solve applied business problems 11 | * [Bayesian Media Mix Models](https://www.pymc-labs.io/blog-posts/bayesian-media-mix-modeling-for-marketing-optimization/) for marketing attribution 12 | 13 | Interested? Send us an email at [info@pymc-labs.io](mailto:info@pymc-labs.io). 14 | -------------------------------------------------------------------------------- /docs/source/learn/core_notebooks/index.md: -------------------------------------------------------------------------------- 1 | (core_notebooks)= 2 | # Notebooks on core features 3 | 4 | :::{toctree} 5 | :maxdepth: 1 6 | 7 | pymc_overview 8 | GLM_linear 9 | model_comparison 10 | posterior_predictive 11 | dimensionality 12 | pymc_pytensor 13 | Gaussian_Processes 14 | ::: 15 | 16 | :::{note} 17 | The notebooks above are executed with each version of the library 18 | (available on the navigation bar). In addition, a much larger gallery 19 | of example notebooks is available at the {doc}`"Examples" tab `. 20 | These are executed more sparsely and independently. 21 | They include a watermark to show which versions were used to run them. 
22 | ::: 23 | -------------------------------------------------------------------------------- /docs/source/learn/usage_overview.rst: -------------------------------------------------------------------------------- 1 | TODO: incorporate the useful bits of this page into the learning section 2 | 3 | ************** 4 | Usage Overview 5 | ************** 6 | 7 | For a detailed overview of building models in PyMC, please read the appropriate sections in the rest of the documentation. For a flavor of what PyMC models look like, here is a quick example. 8 | 9 | First, let's import PyMC and :doc:`ArviZ ` (which handles plotting and diagnostics): 10 | 11 | :: 12 | 13 | import arviz as az 14 | import numpy as np 15 | import pymc as pm 16 | 17 | Models are defined using a context manager (``with`` statement). The model is specified declaratively inside the context manager, instantiating model variables and transforming them as necessary. Here is an example of a model for a bioassay experiment: 18 | 19 | :: 20 | 21 | # Set style 22 | az.style.use("arviz-darkgrid") 23 | 24 | # Data 25 | n = np.ones(4)*5 26 | y = np.array([0, 1, 3, 5]) 27 | dose = np.array([-.86,-.3,-.05,.73]) 28 | 29 | with pm.Model() as bioassay_model: 30 | 31 | # Prior distributions for latent variables 32 | alpha = pm.Normal('alpha', 0, sigma=10) 33 | beta = pm.Normal('beta', 0, sigma=1) 34 | 35 | # Linear combination of parameters 36 | theta = pm.invlogit(alpha + beta * dose) 37 | 38 | # Model likelihood 39 | deaths = pm.Binomial('deaths', n=n, p=theta, observed=y) 40 | 41 | Save this file, then from a python shell (or another file in the same directory), call: 42 | 43 | :: 44 | 45 | with bioassay_model: 46 | 47 | # Draw samples 48 | idata = pm.sample(1000, tune=2000, cores=2) 49 | # Plot two parameters 50 | az.plot_forest(idata, var_names=['alpha', 'beta'], r_hat=True) 51 | 52 | This example will generate 1000 posterior samples on each of two cores using the NUTS algorithm, preceded by 2000 tuning samples 
(these are good default numbers for most models). 53 | 54 | :: 55 | 56 | Auto-assigning NUTS sampler... 57 | Initializing NUTS using jitter+adapt_diag... 58 | Multiprocess sampling (2 chains in 2 jobs) 59 | NUTS: [beta, alpha] 60 | |██████████████████████████████████████| 100.00% [6000/6000 00:04<00:00 Sampling 2 chains, 0 divergences] 61 | 62 | The sample is returned as arrays inside a ``MultiTrace`` object, which is then passed to the plotting function. The resulting graph shows a forest plot of the random variables in the model, along with a convergence diagnostic (R-hat) that indicates our model has converged. 63 | 64 | .. image:: ./images/forestplot.png 65 | :width: 1000px 66 | 67 | See also 68 | ======== 69 | 70 | * `Tutorials `__ 71 | * `Examples `__ 72 | 73 | 74 | .. |NumFOCUS| image:: https://numfocus.org/wp-content/uploads/2017/07/NumFocus_LRG.png 75 | :target: http://www.numfocus.org/ 76 | :height: 120px 77 | .. |PyMCLabs| image:: https://raw.githubusercontent.com/pymc-devs/pymc/main/docs/pymc-labs-logo.png 78 | :target: https://pymc-labs.io 79 | :height: 120px 80 | -------------------------------------------------------------------------------- /docs/source/learn/videos_and_podcasts.md: -------------------------------------------------------------------------------- 1 | (videos_and_podcasts)= 2 | # Videos and Podcasts 3 | 4 | :::{card} PyMC Developers Youtube channel 5 | 6 | [See all videos here](https://www.youtube.com/c/PyMCDevelopers/videos) 7 | ::: 8 | 9 | :::{card} PyMC talks 10 | 11 | Actively curated [YouTube playlist](https://www.youtube.com/playlist?list=PL1Ma_1DBbE82OVW8Fz_6Ts1oOeyOAiovy) of PyMC talks 12 | ::: 13 | 14 | :::{card} PyMC Labs Youtube channel 15 | 16 | [See all videos here](https://www.youtube.com/c/PyMCLabs/videos) 17 | ::: 18 | 19 | :::{card} PyMCon 2020 talks 20 | 21 | [See all videos here](https://www.youtube.com/playlist?list=PLD1x-BW9UdeG68AQj6rDRfGiFFrpZ3cgu) 22 | ::: 23 | 24 | :::{card} Learning Bayesian Statistics 
podcast 25 | 26 | [See all videos here](https://www.youtube.com/channel/UCAwVseuhVrpJFfik_cMHrhQ/videos) 27 | ::: 28 | -------------------------------------------------------------------------------- /pymc/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 - present The PyMC Developers 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | """PyMC: Bayesian Modeling and Probabilistic Programming in Python.""" 17 | 18 | import logging 19 | 20 | _log = logging.getLogger(__name__) 21 | 22 | if not logging.root.handlers: 23 | _log.setLevel(logging.INFO) 24 | if len(_log.handlers) == 0: 25 | handler = logging.StreamHandler() 26 | _log.addHandler(handler) 27 | 28 | 29 | def __set_compiler_flags(): 30 | # Workarounds for PyTensor compiler problems on various platforms 31 | import pytensor 32 | 33 | current = pytensor.config.gcc__cxxflags 34 | augmented = f"{current} -Wno-c++11-narrowing" 35 | 36 | # Work around compiler bug in GCC < 8.4 related to structured exception 37 | # handling registers on Windows. 38 | # See https://gcc.gnu.org/bugzilla/show_bug.cgi?id=65782 for details. 39 | # First disable C++ exception handling altogether since it's not needed 40 | # for the C extensions that we generate. 41 | augmented = f"{augmented} -fno-exceptions" 42 | # Now disable the generation of stack unwinding tables. 
43 | augmented = f"{augmented} -fno-unwind-tables -fno-asynchronous-unwind-tables" 44 | 45 | pytensor.config.gcc__cxxflags = augmented 46 | 47 | 48 | __set_compiler_flags() 49 | 50 | from pymc import _version, gp, ode, sampling 51 | from pymc.backends import * 52 | from pymc.blocking import * 53 | from pymc.data import * 54 | from pymc.distributions import * 55 | from pymc.exceptions import * 56 | from pymc.func_utils import find_constrained_prior 57 | from pymc.logprob import * 58 | from pymc.math import ( 59 | expand_packed_triangular, 60 | invlogit, 61 | invprobit, 62 | logaddexp, 63 | logit, 64 | logsumexp, 65 | probit, 66 | ) 67 | from pymc.model.core import * 68 | from pymc.model.transform.conditioning import do, observe 69 | from pymc.model_graph import model_to_graphviz, model_to_networkx 70 | from pymc.plots import * 71 | from pymc.printing import * 72 | from pymc.pytensorf import * 73 | from pymc.sampling import * 74 | from pymc.smc import * 75 | from pymc.stats import * 76 | from pymc.step_methods import * 77 | from pymc.tuning import * 78 | from pymc.util import drop_warning_stat 79 | from pymc.variational import * 80 | from pymc.vartypes import * 81 | 82 | __version__ = _version.get_versions()["version"] 83 | -------------------------------------------------------------------------------- /pymc/backends/report.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 - present The PyMC Developers 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import dataclasses 16 | import itertools 17 | import logging 18 | 19 | from pymc.stats.convergence import _LEVELS, SamplerWarning 20 | 21 | logger = logging.getLogger(__name__) 22 | 23 | 24 | class SamplerReport: 25 | """Bundle warnings, convergence stats and metadata of a sampling run.""" 26 | 27 | def __init__(self) -> None: 28 | self._chain_warnings: dict[int, list[SamplerWarning]] = {} 29 | self._global_warnings: list[SamplerWarning] = [] 30 | self._n_tune = None 31 | self._n_draws = None 32 | self._t_sampling = None 33 | 34 | @property 35 | def _warnings(self): 36 | chains = list(itertools.chain.from_iterable(self._chain_warnings.values())) 37 | return chains + self._global_warnings 38 | 39 | @property 40 | def ok(self): 41 | """Whether the automatic convergence checks found no serious problems (no warning at or above the "warn" level).""" 42 | return all(_LEVELS[warn.level] < _LEVELS["warn"] for warn in self._warnings) 43 | 44 | @property 45 | def n_tune(self) -> int | None: 46 | """Number of tune iterations - not necessarily kept in trace.""" 47 | return self._n_tune 48 | 49 | @property 50 | def n_draws(self) -> int | None: 51 | """Number of draw iterations.""" 52 | return self._n_draws 53 | 54 | @property 55 | def t_sampling(self) -> float | None: 56 | """ 57 | Number of seconds that the sampling procedure took. 58 | 59 | (Includes parallelization overhead.) 
60 | """ 61 | return self._t_sampling 62 | 63 | def raise_ok(self, level="error"): 64 | errors = [warn for warn in self._warnings if _LEVELS[warn.level] >= _LEVELS[level]] 65 | if errors: 66 | raise ValueError("Serious convergence issues during sampling.") 67 | 68 | def _add_warnings(self, warnings, chain=None): 69 | if chain is None: 70 | warn_list = self._global_warnings 71 | else: 72 | warn_list = self._chain_warnings.setdefault(chain, []) 73 | warn_list.extend(warnings) 74 | 75 | def _slice(self, start, stop, step): 76 | report = SamplerReport() 77 | 78 | def filter_warns(warnings): 79 | filtered = [] 80 | for warn in warnings: 81 | if warn.step is None: 82 | filtered.append(warn) 83 | elif start <= warn.step < stop and (warn.step - start) % step == 0: 84 | warn = dataclasses.replace(warn, step=warn.step - start) 85 | filtered.append(warn) 86 | return filtered 87 | 88 | report._add_warnings(filter_warns(self._global_warnings)) 89 | for chain in self._chain_warnings: 90 | report._add_warnings(filter_warns(self._chain_warnings[chain]), chain) 91 | 92 | return report 93 | -------------------------------------------------------------------------------- /pymc/distributions/moments/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 - present The PyMC Developers 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Moments dispatchers for pymc random variables.""" 16 | 17 | from pymc.distributions.moments.means import mean 18 | 19 | __all__ = ["mean"] 20 | -------------------------------------------------------------------------------- /pymc/exceptions.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 - present The PyMC Developers 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | __all__ = [ 16 | "ImputationWarning", 17 | "IncorrectArgumentsError", 18 | "SamplingError", 19 | "ShapeError", 20 | "ShapeWarning", 21 | "TraceDirectoryError", 22 | ] 23 | 24 | 25 | class SamplingError(RuntimeError): 26 | pass 27 | 28 | 29 | class IncorrectArgumentsError(ValueError): 30 | pass 31 | 32 | 33 | class TraceDirectoryError(ValueError): 34 | """Error from trying to load a trace from an incorrectly-structured directory.""" 35 | 36 | pass 37 | 38 | 39 | class ImputationWarning(UserWarning): 40 | """Warning that there are missing values that will be imputed.""" 41 | 42 | pass 43 | 44 | 45 | class ShapeWarning(UserWarning): 46 | """Something that could lead to shape problems down the line.""" 47 | 48 | pass 49 | 50 | 51 | class ShapeError(Exception): 52 | """Error that the shape of a variable is incorrect.""" 53 | 54 | def __init__(self, message, actual=None, expected=None): 55 | if actual is not None and expected is not None: 56 | super().__init__(f"{message} (actual {actual} != 
expected {expected})") 57 | elif actual is not None and expected is None: 58 | super().__init__(f"{message} (actual {actual})") 59 | elif actual is None and expected is not None: 60 | super().__init__(f"{message} (expected {expected})") 61 | else: 62 | super().__init__(message) 63 | 64 | 65 | class DtypeError(TypeError): 66 | """Error that the dtype of a variable is incorrect.""" 67 | 68 | def __init__(self, message, actual=None, expected=None): 69 | if actual is not None and expected is not None: 70 | super().__init__(f"{message} (actual {actual} != expected {expected})") 71 | elif actual is not None and expected is None: 72 | super().__init__(f"{message} (actual {actual})") 73 | elif actual is None and expected is not None: 74 | super().__init__(f"{message} (expected {expected})") 75 | else: 76 | super().__init__(message) 77 | 78 | 79 | class TruncationError(RuntimeError): 80 | """Exception for errors generated from truncated graphs.""" 81 | 82 | 83 | class NotConstantValueError(ValueError): 84 | pass 85 | 86 | 87 | class BlockModelAccessError(RuntimeError): 88 | pass 89 | 90 | 91 | class UndefinedMomentException(Exception): 92 | pass 93 | -------------------------------------------------------------------------------- /pymc/gp/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 - present The PyMC Developers 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Gaussian Processes.""" 16 | 17 | from pymc.gp import cov, mean, util 18 | from pymc.gp.gp import ( 19 | TP, 20 | Latent, 21 | LatentKron, 22 | Marginal, 23 | MarginalApprox, 24 | MarginalKron, 25 | MarginalSparse, 26 | ) 27 | from pymc.gp.hsgp_approx import HSGP, HSGPPeriodic 28 | -------------------------------------------------------------------------------- /pymc/gp/mean.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 - present The PyMC Developers 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import pytensor.tensor as pt 16 | 17 | __all__ = ["Constant", "Linear", "Zero"] 18 | 19 | 20 | class Mean: 21 | """Base class for mean functions.""" 22 | 23 | def __call__(self, X): 24 | R""" 25 | Evaluate the mean function. 26 | 27 | Parameters 28 | ---------- 29 | X: The training inputs to the mean function. 30 | """ 31 | raise NotImplementedError 32 | 33 | def __add__(self, other): 34 | return Add(self, other) 35 | 36 | def __mul__(self, other): 37 | return Prod(self, other) 38 | 39 | 40 | class Zero(Mean): 41 | """Zero mean function for Gaussian process.""" 42 | 43 | def __call__(self, X): 44 | return pt.alloc(0.0, X.shape[0]) 45 | 46 | 47 | class Constant(Mean): 48 | """ 49 | Constant mean function for Gaussian process. 
50 | 51 | Parameters 52 | ---------- 53 | c: variable, array or integer 54 | Constant mean value 55 | """ 56 | 57 | def __init__(self, c=0): 58 | super().__init__() 59 | self.c = c 60 | 61 | def __call__(self, X): 62 | return pt.alloc(1.0, X.shape[0]) * self.c 63 | 64 | 65 | class Linear(Mean): 66 | """ 67 | Linear mean function for Gaussian process. 68 | 69 | Parameters 70 | ---------- 71 | coeffs: variables 72 | Linear coefficients 73 | intercept: variable, array or integer 74 | Intercept for linear function (Defaults to zero) 75 | """ 76 | 77 | def __init__(self, coeffs, intercept=0): 78 | super().__init__() 79 | self.b = intercept 80 | self.A = coeffs 81 | 82 | def __call__(self, X): 83 | return pt.squeeze(pt.dot(X, self.A) + self.b) 84 | 85 | 86 | class Add(Mean): 87 | def __init__(self, first_mean, second_mean): 88 | super().__init__() 89 | self.m1 = first_mean 90 | self.m2 = second_mean 91 | 92 | def __call__(self, X): 93 | return pt.add(self.m1(X), self.m2(X)) 94 | 95 | 96 | class Prod(Mean): 97 | def __init__(self, first_mean, second_mean): 98 | super().__init__() 99 | self.m1 = first_mean 100 | self.m2 = second_mean 101 | 102 | def __call__(self, X): 103 | return pt.mul(self.m1(X), self.m2(X)) 104 | -------------------------------------------------------------------------------- /pymc/logprob/LICENSE_AEPPL.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021-2022 aesara-devs 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this 
permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /pymc/logprob/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 - present The PyMC Developers 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | # MIT License 16 | # 17 | # Copyright (c) 2021-2022 aesara-devs 18 | # 19 | # Permission is hereby granted, free of charge, to any person obtaining a copy 20 | # of this software and associated documentation files (the "Software"), to deal 21 | # in the Software without restriction, including without limitation the rights 22 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 23 | # copies of the Software, and to permit persons to whom the Software is 24 | # furnished to do so, subject to the following conditions: 25 | # 26 | # The above copyright notice and this permission notice shall be included in all 27 | # copies or substantial portions of the Software. 28 | # 29 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 30 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 31 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 32 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 33 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 34 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 35 | # SOFTWARE. 
36 | 37 | """Conversion of PyMC graphs into logp graphs.""" 38 | 39 | from pymc.logprob.basic import ( 40 | conditional_logp, 41 | icdf, 42 | logcdf, 43 | logp, 44 | transformed_conditional_logp, 45 | ) 46 | 47 | # Add rewrites to the DBs 48 | import pymc.logprob.binary 49 | import pymc.logprob.censoring 50 | import pymc.logprob.cumsum 51 | import pymc.logprob.checks 52 | import pymc.logprob.linalg 53 | import pymc.logprob.mixture 54 | import pymc.logprob.order 55 | import pymc.logprob.scan 56 | import pymc.logprob.tensor 57 | import pymc.logprob.transforms 58 | 59 | 60 | __all__ = ( 61 | "icdf", 62 | "logcdf", 63 | "logp", 64 | ) 65 | -------------------------------------------------------------------------------- /pymc/logprob/linalg.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 - present The PyMC Developers 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | import pytensor.tensor as pt 15 | 16 | from pytensor.graph.rewriting.basic import node_rewriter 17 | from pytensor.tensor.math import _matrix_matrix_matmul 18 | 19 | from pymc.logprob.abstract import MeasurableBlockwise, MeasurableOp, _logprob, _logprob_helper 20 | from pymc.logprob.rewriting import measurable_ir_rewrites_db 21 | from pymc.logprob.utils import check_potential_measurability, filter_measurable_variables 22 | 23 | 24 | class MeasurableMatMul(MeasurableBlockwise): 25 | """Measurable matrix multiplication operation.""" 26 | 27 | right_measurable: bool 28 | 29 | def __init__(self, measurable_right: bool, **kwargs): 30 | self.right_measurable = measurable_right 31 | super().__init__(**kwargs) 32 | 33 | 34 | @_logprob.register(MeasurableMatMul) 35 | def logprob_measurable_matmul(op, values, l, r): # noqa: E741 36 | [y_value] = values 37 | if op.right_measurable: 38 | A, x = l, r 39 | x_value = pt.linalg.solve(A, y_value) 40 | else: 41 | x, A = l, r 42 | x_value = pt.linalg.solve(A.mT, y_value.mT).mT 43 | 44 | x_logp = _logprob_helper(x, x_value) 45 | 46 | # The operation has a support dimensionality of 2 47 | # We need to reduce it if it's still present in the base logp 48 | if x_logp.type.ndim == x_value.type.ndim: 49 | x_logp = pt.sum(x_logp, axis=(-1, -2)) 50 | elif x_logp.type.ndim == x_value.type.ndim - 1: 51 | x_logp = pt.sum(x_logp, axis=-1) 52 | 53 | _, log_abs_jac_det = pt.linalg.slogdet(A) 54 | 55 | return x_logp - log_abs_jac_det 56 | 57 | 58 | @node_rewriter(tracks=[_matrix_matrix_matmul]) 59 | def find_measurable_matmul(fgraph, node): 60 | """Find measurable matrix-matrix multiplication operations.""" 61 | if isinstance(node.op, MeasurableOp): 62 | return None 63 | 64 | [out] = node.outputs 65 | [l, r] = node.inputs # noqa: E741 66 | 67 | # Check that not both a and r are measurable 68 | measurable_inputs = filter_measurable_variables([l, r]) 69 | if len(measurable_inputs) != 1: 70 | return None 71 | 72 | [measurable_input] = 
measurable_inputs 73 | 74 | # Check the measurable input is not broadcasted 75 | if measurable_input.type.broadcastable[:-2] != out.type.broadcastable[:-2]: 76 | return None 77 | 78 | measurable_right = measurable_input is r 79 | A = l if measurable_right else r 80 | 81 | # Check if the static shape already reveals a non-square matrix, 82 | if ( 83 | A.type.shape[-1] is not None 84 | and A.type.shape[-2] is not None 85 | and A.type.shape[-1] != A.type.shape[-2] 86 | ): 87 | return None 88 | 89 | # Check the other input is not potentially measurable 90 | if check_potential_measurability([A]): 91 | return None 92 | 93 | measurable_matmul = MeasurableMatMul(measurable_right=measurable_right, **node.op._props_dict()) 94 | return [measurable_matmul(l, r)] 95 | 96 | 97 | measurable_ir_rewrites_db.register( 98 | find_measurable_matmul.__name__, 99 | find_measurable_matmul, 100 | "basic", 101 | "linalg", 102 | ) 103 | -------------------------------------------------------------------------------- /pymc/model/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 - present The PyMC Developers 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Model object.""" 16 | 17 | from pymc.model.core import * 18 | from pymc.model.core import ValueGradFunction 19 | -------------------------------------------------------------------------------- /pymc/model/transform/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 - present The PyMC Developers 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Model transforms.""" 16 | -------------------------------------------------------------------------------- /pymc/model/transform/basic.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 - present The PyMC Developers 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | from collections.abc import Sequence 15 | 16 | from pytensor import Variable, clone_replace 17 | from pytensor.graph import ancestors 18 | from pytensor.graph.fg import FunctionGraph 19 | 20 | from pymc.data import MinibatchOp 21 | from pymc.model.core import Model 22 | from pymc.model.fgraph import ( 23 | ModelObservedRV, 24 | ModelVar, 25 | fgraph_from_model, 26 | model_from_fgraph, 27 | ) 28 | 29 | ModelVariable = Variable | str 30 | 31 | 32 | def prune_vars_detached_from_observed(model: Model) -> Model: 33 | """Prune model variables that are not related to any observed variable in the Model.""" 34 | # Potentials are ambiguous as whether they correspond to likelihood or prior terms, 35 | # We simply raise for now 36 | if model.potentials: 37 | raise NotImplementedError("Pruning not implemented for models with Potentials") 38 | 39 | fgraph, _ = fgraph_from_model(model, inlined_views=True) 40 | observed_vars = ( 41 | out 42 | for node in fgraph.apply_nodes 43 | if isinstance(node.op, ModelObservedRV) 44 | for out in node.outputs 45 | ) 46 | ancestor_nodes = {var.owner for var in ancestors(observed_vars)} 47 | nodes_to_remove = { 48 | node 49 | for node in fgraph.apply_nodes 50 | if isinstance(node.op, ModelVar) and node not in ancestor_nodes 51 | } 52 | for node_to_remove in nodes_to_remove: 53 | fgraph.remove_node(node_to_remove) 54 | return model_from_fgraph(fgraph, mutate_fgraph=True) 55 | 56 | 57 | def parse_vars(model: Model, vars: ModelVariable | Sequence[ModelVariable]) -> list[Variable]: 58 | if isinstance(vars, list | tuple): 59 | vars_seq = vars 60 | else: 61 | vars_seq = (vars,) 62 | return [model[var] if isinstance(var, str) else var for var in vars_seq] 63 | 64 | 65 | def remove_minibatched_nodes(model: Model) -> Model: 66 | """Remove all uses of pm.Minibatch in the Model.""" 67 | fgraph, _ = fgraph_from_model(model) 68 | 69 | replacements = {} 70 | for var in fgraph.apply_nodes: 71 | if isinstance(var.op, MinibatchOp): 72 | for inp, out in 
zip(var.inputs, var.outputs): 73 | replacements[out] = inp 74 | 75 | old_outs, old_coords, old_dim_lengths = fgraph.outputs, fgraph._coords, fgraph._dim_lengths # type: ignore[attr-defined] 76 | # Using `rebuild_strict=False` means all coords, names, and dim information is lost 77 | # So we need to restore it from the old fgraph 78 | new_outs = clone_replace(old_outs, replacements, rebuild_strict=False) # type: ignore[arg-type] 79 | for old_out, new_out in zip(old_outs, new_outs): 80 | new_out.name = old_out.name 81 | fgraph = FunctionGraph(outputs=new_outs, clone=False) 82 | fgraph._coords = old_coords # type: ignore[attr-defined] 83 | fgraph._dim_lengths = old_dim_lengths # type: ignore[attr-defined] 84 | return model_from_fgraph(fgraph, mutate_fgraph=True) 85 | -------------------------------------------------------------------------------- /pymc/ode/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 - present The PyMC Developers 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Contains tools used to perform inference on ordinary differential equations. 16 | 17 | Due to the nature of the model (as well as included solvers), ODE solution may perform slowly. 18 | Another library based on PyMC--sunode--has implemented Adams' method and BDF (backward differentation formula) using the very fast SUNDIALS suite of ODE and PDE solvers. 
19 | It is much faster than the ``pm.ode`` implementation. 20 | More information about ``sunode`` is available at: https://github.com/aseyboldt/sunode. 21 | """ 22 | # Copyright 2020 The PyMC Developers 23 | # 24 | # Licensed under the Apache License, Version 2.0 (the "License"); 25 | # you may not use this file except in compliance with the License. 26 | # You may obtain a copy of the License at 27 | # 28 | # http://www.apache.org/licenses/LICENSE-2.0 29 | # 30 | # Unless required by applicable law or agreed to in writing, software 31 | # distributed under the License is distributed on an "AS IS" BASIS, 32 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 33 | # See the License for the specific language governing permissions and 34 | # limitations under the License. 35 | 36 | from pymc.ode import utils 37 | from pymc.ode.ode import DifferentialEquation 38 | -------------------------------------------------------------------------------- /pymc/plots/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 - present The PyMC Developers 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Alias for the `plots` submodule from ArviZ. 16 | 17 | Plots are delegated to the ArviZ library, a general purpose library for 18 | "exploratory analysis of Bayesian models." 19 | See https://arviz-devs.github.io/arviz/ for details on plots. 
20 | """ 21 | 22 | import functools 23 | import sys 24 | import warnings 25 | 26 | import arviz as az 27 | 28 | # Makes this module as identical to arviz.plots as possible 29 | for attr in az.plots.__all__: 30 | obj = getattr(az.plots, attr) 31 | if not attr.startswith("__"): 32 | setattr(sys.modules[__name__], attr, obj) 33 | 34 | 35 | def alias_deprecation(func, alias: str): 36 | original = func.__name__ 37 | 38 | @functools.wraps(func) 39 | def wrapped(*args, **kwargs): 40 | raise FutureWarning( 41 | f"The function `{alias}` from PyMC was an alias for `{original}` from ArviZ. " 42 | "It was removed in PyMC 4.0. " 43 | f"Switch to `pymc.{original}` or `arviz.{original}`." 44 | ) 45 | 46 | return wrapped 47 | 48 | 49 | # Aliases of ArviZ functions 50 | autocorrplot = alias_deprecation(az.plot_autocorr, alias="autocorrplot") 51 | forestplot = alias_deprecation(az.plot_forest, alias="forestplot") 52 | kdeplot = alias_deprecation(az.plot_kde, alias="kdeplot") 53 | energyplot = alias_deprecation(az.plot_energy, alias="energyplot") 54 | densityplot = alias_deprecation(az.plot_density, alias="densityplot") 55 | pairplot = alias_deprecation(az.plot_pair, alias="pairplot") 56 | traceplot = alias_deprecation(az.plot_trace, alias="traceplot") 57 | compareplot = alias_deprecation(az.plot_compare, alias="compareplot") 58 | 59 | 60 | __all__ = ( 61 | *az.plots.__all__, 62 | "autocorrplot", 63 | "compareplot", 64 | "forestplot", 65 | "kdeplot", 66 | "traceplot", 67 | "energyplot", 68 | "densityplot", 69 | "pairplot", 70 | ) 71 | -------------------------------------------------------------------------------- /pymc/sampling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 - present The PyMC Developers 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """MCMC samplers.""" 16 | 17 | from pymc.sampling.deterministic import compute_deterministics 18 | from pymc.sampling.forward import * 19 | from pymc.sampling.mcmc import * 20 | -------------------------------------------------------------------------------- /pymc/smc/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 - present The PyMC Developers 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Sequential Monte Carlo samplers.""" 16 | 17 | from pymc.smc.kernels import IMH, MH 18 | from pymc.smc.sampling import sample_smc 19 | 20 | __all__ = ("sample_smc",) 21 | -------------------------------------------------------------------------------- /pymc/stats/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 - present The PyMC Developers 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Alias for the `stats` submodule from ArviZ. 16 | 17 | Diagnostics and auxiliary statistical functions are delegated to the ArviZ library, a general 18 | purpose library for "exploratory analysis of Bayesian models." 19 | See https://arviz-devs.github.io/arviz/ for details. 
20 | """ 21 | 22 | import sys 23 | 24 | import arviz as az 25 | 26 | for attr in az.stats.__all__: 27 | obj = getattr(az.stats, attr) 28 | if not attr.startswith("__"): 29 | setattr(sys.modules[__name__], attr, obj) 30 | 31 | from pymc.stats.log_density import compute_log_likelihood, compute_log_prior 32 | 33 | __all__ = ("compute_log_likelihood", "compute_log_prior", *az.stats.__all__) 34 | -------------------------------------------------------------------------------- /pymc/step_methods/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 - present The PyMC Developers 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Step methods.""" 16 | 17 | from pymc.step_methods.compound import BlockedStep, CompoundStep 18 | from pymc.step_methods.hmc import NUTS, HamiltonianMC 19 | from pymc.step_methods.metropolis import ( 20 | BinaryGibbsMetropolis, 21 | BinaryMetropolis, 22 | CategoricalGibbsMetropolis, 23 | CauchyProposal, 24 | DEMetropolis, 25 | DEMetropolisZ, 26 | LaplaceProposal, 27 | Metropolis, 28 | MultivariateNormalProposal, 29 | NormalProposal, 30 | PoissonProposal, 31 | UniformProposal, 32 | ) 33 | from pymc.step_methods.slicer import Slice 34 | 35 | # Other step methods can be added by appending to this list 36 | STEP_METHODS: list[type[BlockedStep]] = [ 37 | NUTS, 38 | HamiltonianMC, 39 | Metropolis, 40 | BinaryMetropolis, 41 | BinaryGibbsMetropolis, 42 | Slice, 43 | CategoricalGibbsMetropolis, 44 | ] 45 | -------------------------------------------------------------------------------- /pymc/step_methods/hmc/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 - present The PyMC Developers 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Hamiltonian Monte Carlo.""" 16 | 17 | from pymc.step_methods.hmc.hmc import HamiltonianMC 18 | from pymc.step_methods.hmc.nuts import NUTS 19 | -------------------------------------------------------------------------------- /pymc/tuning/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 - present The PyMC Developers 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tuning phase.""" 16 | 17 | from pymc.tuning.scaling import find_hessian, guess_scaling, trace_cov 18 | from pymc.tuning.starting import find_MAP 19 | -------------------------------------------------------------------------------- /pymc/variational/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 - present The PyMC Developers 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Variational Monte Carlo.""" 16 | 17 | # commonly used 18 | from pymc.variational import ( 19 | approximations, 20 | callbacks, 21 | inference, 22 | operators, 23 | opvi, 24 | test_functions, 25 | updates, 26 | ) 27 | from pymc.variational.approximations import ( 28 | Empirical, 29 | FullRank, 30 | MeanField, 31 | sample_approx, 32 | ) 33 | from pymc.variational.inference import ( 34 | ADVI, 35 | ASVGD, 36 | SVGD, 37 | FullRankADVI, 38 | ImplicitGradient, 39 | Inference, 40 | KLqp, 41 | fit, 42 | ) 43 | from pymc.variational.opvi import Approximation, Group 44 | 45 | # special 46 | from pymc.variational.stein import Stein 47 | from pymc.variational.updates import ( 48 | adadelta, 49 | adagrad, 50 | adagrad_window, 51 | adam, 52 | adamax, 53 | apply_momentum, 54 | apply_nesterov_momentum, 55 | momentum, 56 | nesterov_momentum, 57 | norm_constraint, 58 | rmsprop, 59 | sgd, 60 | total_norm_constraint, 61 | ) 62 | -------------------------------------------------------------------------------- /pymc/variational/stein.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 - present The PyMC Developers 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import pytensor.tensor as pt 16 | 17 | from pytensor.graph.replace import graph_replace 18 | 19 | from pymc.pytensorf import floatX 20 | from pymc.util import WithMemoization, locally_cachedmethod 21 | from pymc.variational.opvi import node_property 22 | from pymc.variational.test_functions import rbf 23 | 24 | __all__ = ["Stein"] 25 | 26 | 27 | class Stein(WithMemoization): 28 | def __init__(self, approx, kernel=rbf, use_histogram=True, temperature=1): 29 | self.approx = approx 30 | self.temperature = floatX(temperature) 31 | self._kernel_f = kernel 32 | self.use_histogram = use_histogram 33 | 34 | @property 35 | def input_joint_matrix(self): 36 | if self.use_histogram: 37 | return self.approx.joint_histogram 38 | else: 39 | return self.approx.symbolic_random 40 | 41 | @node_property 42 | def approx_symbolic_matrices(self): 43 | if self.use_histogram: 44 | return self.approx.collect("histogram") 45 | else: 46 | return self.approx.symbolic_randoms 47 | 48 | @node_property 49 | def dlogp(self): 50 | logp = self.logp_norm.sum() 51 | grad = pt.grad(logp, self.approx_symbolic_matrices) 52 | 53 | def flatten2(tensor): 54 | return tensor.flatten(2) 55 | 56 | return pt.concatenate(list(map(flatten2, grad)), -1) 57 | 58 | @node_property 59 | def grad(self): 60 | n = floatX(self.input_joint_matrix.shape[0]) 61 | temperature = self.temperature 62 | svgd_grad = self.density_part_grad / temperature + self.repulsive_part_grad 63 | return svgd_grad / n 64 | 65 | @node_property 66 | def density_part_grad(self): 67 | Kxy = self.Kxy 68 | dlogpdx = self.dlogp 69 | return pt.dot(Kxy, dlogpdx) 70 | 71 | @node_property 72 | def repulsive_part_grad(self): 73 | t = self.approx.symbolic_normalizing_constant 74 | dxkxy = self.dxkxy 75 | return dxkxy / t 76 | 77 | @property 78 | def Kxy(self): 79 | return self._kernel()[0] 80 | 81 | @property 82 | def dxkxy(self): 83 | return self._kernel()[1] 84 | 85 | @node_property 86 | def logp_norm(self): 87 | sized_symbolic_logp = 
self.approx.sized_symbolic_logp 88 | if self.use_histogram: 89 | sized_symbolic_logp = graph_replace( 90 | sized_symbolic_logp, 91 | dict(zip(self.approx.symbolic_randoms, self.approx.collect("histogram"))), 92 | strict=False, 93 | ) 94 | return sized_symbolic_logp / self.approx.symbolic_normalizing_constant 95 | 96 | @locally_cachedmethod 97 | def _kernel(self): 98 | return self._kernel_f(self.input_joint_matrix) 99 | -------------------------------------------------------------------------------- /pymc/variational/test_functions.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 - present The PyMC Developers 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from pytensor import tensor as pt 16 | 17 | from pymc.pytensorf import floatX 18 | from pymc.variational.opvi import TestFunction 19 | 20 | __all__ = ["rbf"] 21 | 22 | 23 | class Kernel(TestFunction): 24 | r""" 25 | Dummy base class for kernel SVGD in case we implement more. 26 | 27 | .. 
math:: 28 | 29 | f(x) -> (k(x,.), \nabla_x k(x,.)) 30 | 31 | """ 32 | 33 | 34 | class RBF(Kernel): 35 | def __call__(self, X): 36 | XY = X.dot(X.T) 37 | x2 = pt.sum(X**2, axis=1).dimshuffle(0, "x") 38 | X2e = pt.repeat(x2, X.shape[0], axis=1) 39 | H = X2e + X2e.T - 2.0 * XY 40 | 41 | V = pt.sort(H.flatten()) 42 | length = V.shape[0] 43 | # median distance 44 | m = pt.switch( 45 | pt.eq((length % 2), 0), 46 | # if even vector 47 | pt.mean(V[((length // 2) - 1) : ((length // 2) + 1)]), 48 | # if odd vector 49 | V[length // 2], 50 | ) 51 | 52 | h = 0.5 * m / pt.log(floatX(H.shape[0]) + floatX(1)) 53 | 54 | # RBF 55 | Kxy = pt.exp(-H / h / 2.0) 56 | 57 | # Derivative 58 | dxkxy = -pt.dot(Kxy, X) 59 | sumkxy = pt.sum(Kxy, axis=-1, keepdims=True) 60 | dxkxy = pt.add(dxkxy, pt.mul(X, sumkxy)) / h 61 | 62 | return Kxy, dxkxy 63 | 64 | 65 | rbf = RBF() 66 | -------------------------------------------------------------------------------- /pymc/vartypes.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 - present The PyMC Developers 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | __all__ = [ 16 | "bool_types", 17 | "complex_types", 18 | "continuous_types", 19 | "discrete_types", 20 | "float_types", 21 | "int_types", 22 | "isgenerator", 23 | "typefilter", 24 | ] 25 | 26 | bool_types = {"int8"} 27 | 28 | int_types = {"int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64"} 29 | float_types = {"float32", "float64"} 30 | complex_types = {"complex64", "complex128"} 31 | continuous_types = float_types | complex_types 32 | discrete_types = bool_types | int_types 33 | 34 | string_types = str 35 | 36 | 37 | def typefilter(vars, types): 38 | # Returns variables of type `types` from `vars` 39 | return [v for v in vars if v.dtype in types] 40 | 41 | 42 | def isgenerator(obj): 43 | return hasattr(obj, "__next__") 44 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "versioneer[toml]==0.29"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [tool.pytest.ini_options] 6 | testpaths = ["tests"] 7 | minversion = "6.0" 8 | xfail_strict = true 9 | addopts = ["--color=yes"] 10 | 11 | [tool.versioneer] 12 | VCS = "git" 13 | style = "pep440" 14 | versionfile_source = "pymc/_version.py" 15 | versionfile_build = "pymc/_version.py" 16 | tag_prefix = "v" 17 | 18 | [tool.mypy] 19 | python_version = "3.10" 20 | no_implicit_optional = false 21 | strict_optional = true 22 | warn_redundant_casts = false 23 | check_untyped_defs = false 24 | disallow_untyped_calls = false 25 | disallow_incomplete_defs = false 26 | disallow_untyped_defs = false 27 | disallow_untyped_decorators = false 28 | ignore_missing_imports = true 29 | warn_unused_ignores = false 30 | 31 | [tool.ruff] 32 | line-length = 100 33 | target-version = "py310" 34 | extend-exclude = ["_version.py"] 35 | 36 | [tool.ruff.format] 37 | docstring-code-format = true 38 | 39 | [tool.ruff.lint] 40 | select = 
["C4", "D", "E", "F", "I", "UP", "W", "RUF", "T20", "TID"] 41 | ignore = [ 42 | "E501", 43 | "F841", # Local variable name is assigned to but never used 44 | "RUF001", # String contains ambiguous character (such as Greek letters) 45 | "RUF002", # Docstring contains ambiguous character (such as Greek letters) 46 | "RUF012", # Mutable class attributes should be annotated with `typing.ClassVar` 47 | "D100", # Missing docstring in public module 48 | "D101", # Missing docstring in public class 49 | "D102", # Missing docstring in public method 50 | "D103", # Missing docstring in public function 51 | "D105", # Missing docstring in magic method 52 | ] 53 | 54 | [tool.ruff.lint.pydocstyle] 55 | convention = "numpy" 56 | 57 | [tool.ruff.lint.isort] 58 | lines-between-types = 1 59 | 60 | [tool.ruff.lint.extend-per-file-ignores] 61 | "__init__.py" = [ 62 | "F401", # Module imported but unused 63 | "F403", # 'from module import *' used; unable to detect undefined names 64 | ] 65 | "docs/source/*" = ["D"] 66 | "pymc/__init__.py" = [ 67 | "E402", # Module level import not at top of file 68 | ] 69 | "pymc/stats/__init__.py" = [ 70 | "E402", # Module level import not at top of file 71 | ] 72 | "pymc/logprob/__init__.py" = [ 73 | "I001", # Import block is un-sorted or un-formatted 74 | ] 75 | "tests/*" = ["D"] 76 | "scripts/run_mypy.py" = [ 77 | "T201", # No print statements 78 | ] 79 | "*.ipynb" = [ 80 | "T201", # No print statements 81 | ] 82 | 83 | [tool.coverage.report] 84 | exclude_lines = [ 85 | "pragma: nocover", 86 | "raise NotImplementedError", 87 | "if TYPE_CHECKING:", 88 | ] 89 | 90 | [tool.coverage.run] 91 | omit = ["*examples*"] 92 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | # This file is auto-generated by scripts/generate_pip_deps_from_conda.py, do not modify. 
2 | # See that file for comments about the need/usage of each dependency. 3 | 4 | arviz>=0.13.0 5 | cachetools>=4.2.1 6 | cloudpickle 7 | git+https://github.com/pymc-devs/pymc-sphinx-theme 8 | ipython>=7.16 9 | jupyter-sphinx 10 | mcbackend>=0.4.0 11 | mypy==1.15.0 12 | myst-nb<=1.0.0 13 | numdifftools>=0.9.40 14 | numpy>=1.25.0 15 | numpydoc 16 | pandas>=0.24.0 17 | polyagamma 18 | pre-commit>=2.8.0 19 | pytensor>=2.31.2,<2.32 20 | pytest-cov>=2.5 21 | pytest>=3.0 22 | rich>=13.7.1 23 | scipy>=1.4.1 24 | sphinx-copybutton 25 | sphinx-design 26 | sphinx-notfound-page 27 | sphinx-remove-toctrees 28 | sphinx>=1.5 29 | sphinxext-rediraffe 30 | threadpoolctl>=3.1.0 31 | types-cachetools 32 | typing-extensions>=3.7.4 33 | watermark 34 | zarr>=2.5.0,<3 35 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | arviz>=0.13.0 2 | cachetools>=4.2.1 3 | cloudpickle 4 | numpy>=1.25.0 5 | pandas>=0.24.0 6 | pytensor>=2.31.2,<2.32 7 | rich>=13.7.1 8 | scipy>=1.4.1 9 | threadpoolctl>=3.1.0,<4.0.0 10 | typing-extensions>=3.7.4 11 | -------------------------------------------------------------------------------- /scripts/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM jupyter/base-notebook:python-3.9.12 2 | 3 | LABEL name="pymc" 4 | LABEL description="Environment for PyMC version 4" 5 | 6 | ENV LANG=C.UTF-8 LC_ALL=C.UTF-8 7 | 8 | # Switch to jovyan to avoid container runs as root 9 | USER $NB_UID 10 | 11 | COPY /conda-envs/environment-dev.yml . 12 | RUN mamba env create -f environment-dev.yml && \ 13 | /bin/bash -c ". 
activate pymc-dev && \ 14 | mamba install -c conda-forge -y pymc" && \ 15 | conda clean --all -f -y 16 | 17 | # Fix PkgResourcesDeprecationWarning 18 | RUN pip install --upgrade --user setuptools==58.3.0 19 | 20 | #Setup working folder 21 | WORKDIR /home/jovyan/work 22 | 23 | # For running from bash 24 | SHELL ["/bin/bash","-c"] 25 | RUN echo "conda activate pymc-dev" >> ~/.bashrc && \ 26 | source ~/.bashrc 27 | 28 | # For running from jupyter notebook 29 | EXPOSE 8888 30 | CMD ["conda", "run", "--no-capture-output", "-n", "pymc-dev", "jupyter","notebook","--ip=0.0.0.0","--port=8888","--no-browser"] 31 | -------------------------------------------------------------------------------- /scripts/dev.Dockerfile: -------------------------------------------------------------------------------- 1 | FROM ghcr.io/mamba-org/micromamba-devcontainer:latest 2 | 3 | COPY --chown=${MAMBA_USER}:${MAMBA_USER} conda-envs/environment-dev.yml /tmp/environment-dev.yml 4 | RUN : \ 5 | && micromamba install --yes --name base --file /tmp/environment-dev.yml \ 6 | && micromamba clean --all --yes \ 7 | && rm /tmp/environment-dev.yml \ 8 | && sudo chmod -R a+rwx /opt/conda \ 9 | ; 10 | 11 | # Run subsequent commands in an activated Conda environment 12 | ARG MAMBA_DOCKERFILE_ACTIVATE=1 13 | -------------------------------------------------------------------------------- /scripts/docker_container.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | COMMAND="${1:-jupyter}" 4 | SRC_DIR=${SRC_DIR:-`pwd`} 5 | CONTAINER_NAME=${CONTAINER_NAME:-pymc} 6 | PORT=${PORT:-8888} 7 | 8 | # stop and remove previous instances of the pymc container to avoid naming conflicts 9 | if [[ $(docker ps -aq -f name=${CONTAINER_NAME}) ]]; then 10 | echo "Shutting down and removing previous instance of ${CONTAINER_NAME} container..." 
11 | docker rm -f ${CONTAINER_NAME} 12 | fi 13 | 14 | # $COMMAND can be either `build` or `bash` or `jupyter` 15 | if [[ $COMMAND = 'build' ]]; then 16 | docker build \ 17 | -t ${CONTAINER_NAME} \ 18 | -f $SRC_DIR/scripts/Dockerfile $SRC_DIR 19 | 20 | elif [[ $COMMAND = 'bash' ]]; then 21 | docker run -it -v $SRC_DIR:/home/jovyan/work --rm --name ${CONTAINER_NAME} ${CONTAINER_NAME} bash 22 | else 23 | docker run -it -p $PORT:8888 -v $SRC_DIR:/home/jovyan/work --rm --name ${CONTAINER_NAME} ${CONTAINER_NAME} 24 | fi 25 | -------------------------------------------------------------------------------- /scripts/slowest_tests/extract-slow-tests.py: -------------------------------------------------------------------------------- 1 | """Parse the GitHub action log for test times. 2 | 3 | Taken from https://github.com/pymc-labs/pymc-marketing/tree/main/scripts/slowest_tests/extract-slow-tests.py 4 | 5 | """ 6 | 7 | import re 8 | import sys 9 | 10 | from pathlib import Path 11 | 12 | start_pattern = re.compile(r"==== slow") 13 | separator_pattern = re.compile(r"====") 14 | time_pattern = re.compile(r"(\d+\.\d+)s ") 15 | 16 | 17 | def extract_lines(lines: list[str]) -> list[str]: 18 | times = [] 19 | 20 | in_section = False 21 | for line in lines: 22 | detect_start = start_pattern.search(line) 23 | detect_end = separator_pattern.search(line) 24 | 25 | if detect_start: 26 | in_section = True 27 | 28 | if in_section: 29 | times.append(line) 30 | 31 | if not detect_start and in_section and detect_end: 32 | break 33 | 34 | return times 35 | 36 | 37 | def trim_up_to_match(pattern, string: str) -> str: 38 | match = pattern.search(string) 39 | if not match: 40 | return "" 41 | 42 | return string[match.start() :] 43 | 44 | 45 | def trim(pattern, lines: list[str]) -> list[str]: 46 | return [trim_up_to_match(pattern, line) for line in lines] 47 | 48 | 49 | def strip_ansi(text: str) -> str: 50 | ansi_escape = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])") 51 | return 
ansi_escape.sub("", text) 52 | 53 | 54 | def format_times(times: list[str]) -> list[str]: 55 | return ( 56 | trim(separator_pattern, times[:1]) 57 | + trim(time_pattern, times[1:-1]) 58 | + [strip_ansi(line) for line in trim(separator_pattern, times[-1:])] 59 | ) 60 | 61 | 62 | def read_lines_from_stdin(): 63 | return sys.stdin.read().splitlines() 64 | 65 | 66 | def read_from_file(file: Path): 67 | """For testing purposes.""" 68 | return file.read_text().splitlines() 69 | 70 | 71 | def main(read_lines): 72 | lines = read_lines() 73 | times = extract_lines(lines) 74 | parsed_times = format_times(times) 75 | print("\n".join(parsed_times)) # noqa: T201 76 | 77 | 78 | if __name__ == "__main__": 79 | read_lines = read_lines_from_stdin 80 | main(read_lines) 81 | -------------------------------------------------------------------------------- /scripts/test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | _FLOATX=${FLOATX:=float64} 6 | PYTENSOR_FLAGS="floatX=${_FLOATX},gcc__cxxflags='-march=core2'" pytest -v --cov=pymc --cov-report=xml "$@" --cov-report term 7 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 - present The PyMC Developers 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | from codecs import open 17 | from os.path import dirname, join, realpath 18 | 19 | import versioneer 20 | 21 | from setuptools import find_packages, setup 22 | 23 | DESCRIPTION = "Probabilistic Programming in Python: Bayesian Modeling and Probabilistic Machine Learning with PyTensor" 24 | AUTHOR = "PyMC Developers" 25 | AUTHOR_EMAIL = "pymc.devs@gmail.com" 26 | URL = "http://github.com/pymc-devs/pymc" 27 | LICENSE = "Apache License, Version 2.0" 28 | 29 | classifiers = [ 30 | "Development Status :: 5 - Production/Stable", 31 | "Programming Language :: Python", 32 | "Programming Language :: Python :: 3", 33 | "Programming Language :: Python :: 3.10", 34 | "Programming Language :: Python :: 3.11", 35 | "Programming Language :: Python :: 3.12", 36 | "Programming Language :: Python :: 3.13", 37 | "License :: OSI Approved :: Apache Software License", 38 | "Intended Audience :: Science/Research", 39 | "Topic :: Scientific/Engineering", 40 | "Topic :: Scientific/Engineering :: Mathematics", 41 | "Operating System :: OS Independent", 42 | ] 43 | 44 | PROJECT_ROOT = dirname(realpath(__file__)) 45 | 46 | # Get the long description from the README file 47 | with open(join(PROJECT_ROOT, "README.rst"), encoding="utf-8") as buff: 48 | LONG_DESCRIPTION = buff.read() 49 | 50 | REQUIREMENTS_FILE = join(PROJECT_ROOT, "requirements.txt") 51 | 52 | with open(REQUIREMENTS_FILE) as f: 53 | install_reqs = f.read().splitlines() 54 | 55 | test_reqs = ["pytest", "pytest-cov"] 56 | 57 | if __name__ == "__main__": 58 | setup( 59 | name="pymc", 60 | version=versioneer.get_version(), 61 | cmdclass=versioneer.get_cmdclass(), 62 | maintainer=AUTHOR, 63 | maintainer_email=AUTHOR_EMAIL, 64 | description=DESCRIPTION, 65 | license=LICENSE, 66 | url=URL, 67 | long_description=LONG_DESCRIPTION, 68 | long_description_content_type="text/x-rst", 69 | packages=find_packages(exclude=["tests*"]), 70 | # because of an upload-size limit by PyPI, we're temporarily removing docs from the tarball. 
71 | # Also see MANIFEST.in 72 | # package_data={'docs': ['*']}, 73 | classifiers=classifiers, 74 | python_requires=">=3.10", 75 | install_requires=install_reqs, 76 | tests_require=test_reqs, 77 | ) 78 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 - present The PyMC Developers 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import pymc as pm 16 | 17 | _log = pm._log 18 | -------------------------------------------------------------------------------- /tests/backends/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 - present The PyMC Developers 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | -------------------------------------------------------------------------------- /tests/backends/test_base.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 - present The PyMC Developers 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import numpy as np 15 | import pytest 16 | 17 | import pymc as pm 18 | 19 | from pymc.backends import _init_trace 20 | from pymc.backends.base import _choose_chains 21 | 22 | 23 | @pytest.mark.parametrize( 24 | "n_points, tune, expected_length, expected_n_traces", 25 | [ 26 | ((5, 2, 2), 0, 2, 3), 27 | ((6, 1, 1), 1, 6, 1), 28 | ], 29 | ) 30 | def test_choose_chains(n_points, tune, expected_length, expected_n_traces): 31 | trace_0 = np.arange(n_points[0]) 32 | trace_1 = np.arange(n_points[1]) 33 | trace_2 = np.arange(n_points[2]) 34 | traces, length = _choose_chains([trace_0, trace_1, trace_2], tune=tune) 35 | assert length == expected_length 36 | assert expected_n_traces == len(traces) 37 | 38 | 39 | class TestInitTrace: 40 | def test_init_trace_continuation_unsupported(self): 41 | with pm.Model() as pmodel: 42 | A = pm.Normal("A") 43 | B = pm.Uniform("B") 44 | strace = pm.backends.ndarray.NDArray(vars=[A, B]) 45 | strace.setup(10, 0) 46 | strace.record({"A": 2, "B_interval__": 0.1}) 47 | assert len(strace) == 1 48 | with pytest.raises(ValueError, match="Continuation of traces"): 49 | _init_trace( 50 | expected_length=20, 51 | 
stats_dtypes=pm.Metropolis().stats_dtypes, 52 | chain_number=0, 53 | trace=strace, 54 | model=pmodel, 55 | ) 56 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 - present The PyMC Developers 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import warnings 15 | 16 | import numpy as np 17 | import pytensor 18 | import pytest 19 | 20 | 21 | @pytest.fixture(scope="function", autouse=True) 22 | def pytensor_config(): 23 | config = pytensor.config.change_flags(on_opt_error="raise") 24 | with config: 25 | yield 26 | 27 | 28 | @pytest.fixture(scope="function", autouse=True) 29 | def exception_verbosity(): 30 | config = pytensor.config.change_flags(exception_verbosity="high") 31 | with config: 32 | yield 33 | 34 | 35 | @pytest.fixture(scope="function", autouse=False) 36 | def strict_float32(): 37 | if pytensor.config.floatX == "float32": 38 | config = pytensor.config.change_flags(warn_float64="raise") 39 | with config: 40 | yield 41 | else: 42 | yield 43 | 44 | 45 | @pytest.fixture(scope="function", autouse=False) 46 | def seeded_test(): 47 | np.random.seed(20160911) 48 | 49 | 50 | @pytest.fixture 51 | def fail_on_warning(): 52 | with warnings.catch_warnings(): 53 | warnings.simplefilter("error") 54 | yield 55 | 
-------------------------------------------------------------------------------- /tests/distributions/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 - present The PyMC Developers 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /tests/distributions/moments/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 - present The PyMC Developers 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | -------------------------------------------------------------------------------- /tests/distributions/test_random_alternative_backends.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 - present The PyMC Developers 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from contextlib import nullcontext 15 | 16 | import numpy as np 17 | import pytest 18 | 19 | import pymc as pm 20 | 21 | from pymc import DirichletMultinomial, MvStudentT 22 | from pymc.model.transform.optimization import freeze_dims_and_data 23 | 24 | 25 | @pytest.fixture(params=["FAST_RUN", "JAX", "NUMBA"]) 26 | def mode(request): 27 | mode_param = request.param 28 | if mode_param != "FAST_RUN": 29 | pytest.importorskip(mode_param.lower()) 30 | return mode_param 31 | 32 | 33 | def test_dirichlet_multinomial(mode): 34 | """Test we can draw from a DM in the JAX backend if the shape is constant.""" 35 | dm = DirichletMultinomial.dist(n=5, a=np.eye(3) * 1e6 + 0.01) 36 | dm_draws = pm.draw(dm, mode=mode) 37 | np.testing.assert_equal(dm_draws, np.eye(3) * 5) 38 | 39 | 40 | def test_dirichlet_multinomial_dims(mode): 41 | """Test we can draw from a DM with a shape defined by dims in the JAX backend, 42 | after freezing those dims. 
43 | """ 44 | with pm.Model(coords={"trial": range(3), "item": range(3)}) as m: 45 | dm = DirichletMultinomial("dm", n=5, a=np.eye(3) * 1e6 + 0.01, dims=("trial", "item")) 46 | 47 | # JAX does not allow us to JIT a function with dynamic shape 48 | expected_ctxt = pytest.raises(TypeError) if mode == "JAX" else nullcontext() 49 | with expected_ctxt: 50 | pm.draw(dm, mode=mode) 51 | 52 | # Should be fine after freezing the dims that specify the shape 53 | frozen_dm = freeze_dims_and_data(m)["dm"] 54 | dm_draws = pm.draw(frozen_dm, mode=mode) 55 | np.testing.assert_equal(dm_draws, np.eye(3) * 5) 56 | 57 | 58 | def test_mvstudentt(mode): 59 | mvt = MvStudentT.dist(nu=100, mu=[1, 2, 3], scale=np.eye(3) * [0.01, 1, 100], shape=(10_000, 3)) 60 | draws = pm.draw(mvt, mode=mode) 61 | np.testing.assert_allclose(draws.mean(0), [1, 2, 3], rtol=0.1) 62 | np.testing.assert_allclose(draws.std(0), np.sqrt([0.01, 1, 100]), rtol=0.1) 63 | 64 | 65 | def test_repeated_arguments(mode): 66 | # Regression test for a failure in Numba mode when a RV had repeated arguments 67 | v = 0.5 * 1e5 68 | x = pm.Beta.dist(v, v) 69 | x_draw = pm.draw(x, mode=mode) 70 | np.testing.assert_allclose(x_draw, 0.5, rtol=0.01) 71 | -------------------------------------------------------------------------------- /tests/gp/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 - present The PyMC Developers 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /tests/gp/test_mean.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 - present The PyMC Developers 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import numpy as np 16 | import numpy.testing as npt 17 | 18 | import pymc as pm 19 | 20 | 21 | class TestZeroMean: 22 | def test_value(self): 23 | X = np.linspace(0, 1, 10)[:, None] 24 | with pm.Model() as model: 25 | zero_mean = pm.gp.mean.Zero() 26 | M = zero_mean(X).eval() 27 | assert np.all(M == 0) 28 | assert M.shape == (10,) 29 | 30 | 31 | class TestConstantMean: 32 | def test_value(self): 33 | X = np.linspace(0, 1, 10)[:, None] 34 | with pm.Model() as model: 35 | const_mean = pm.gp.mean.Constant(6) 36 | M = const_mean(X).eval() 37 | assert np.all(M == 6) 38 | assert M.shape == (10,) 39 | 40 | 41 | class TestLinearMean: 42 | def test_value(self): 43 | X = np.linspace(0, 1, 10)[:, None] 44 | with pm.Model() as model: 45 | linear_mean = pm.gp.mean.Linear(2, 0.5) 46 | M = linear_mean(X).eval() 47 | npt.assert_allclose(M[1], 0.7222, atol=1e-3) 48 | assert M.shape == (10,) 49 | 50 | 51 | class TestAddProdMean: 52 | def test_add(self): 53 | X = np.linspace(0, 1, 10)[:, None] 54 | with pm.Model() as model: 55 | mean1 = 
pm.gp.mean.Linear(coeffs=2, intercept=0.5) 56 | mean2 = pm.gp.mean.Constant(2) 57 | mean = mean1 + mean2 + mean2 58 | M = mean(X).eval() 59 | npt.assert_allclose(M[1], 0.7222 + 2 + 2, atol=1e-3) 60 | 61 | def test_prod(self): 62 | X = np.linspace(0, 1, 10)[:, None] 63 | with pm.Model() as model: 64 | mean1 = pm.gp.mean.Linear(coeffs=2, intercept=0.5) 65 | mean2 = pm.gp.mean.Constant(2) 66 | mean = mean1 * mean2 * mean2 67 | M = mean(X).eval() 68 | npt.assert_allclose(M[1], 0.7222 * 2 * 2, atol=1e-3) 69 | 70 | def test_add_multid(self): 71 | X = np.linspace(0, 1, 30).reshape(10, 3) 72 | A = np.array([1, 2, 3]) 73 | b = 10 74 | with pm.Model() as model: 75 | mean1 = pm.gp.mean.Linear(coeffs=A, intercept=b) 76 | mean2 = pm.gp.mean.Constant(2) 77 | mean = mean1 + mean2 + mean2 78 | M = mean(X).eval() 79 | npt.assert_allclose(M[1], 10.8965 + 2 + 2, atol=1e-3) 80 | 81 | def test_prod_multid(self): 82 | X = np.linspace(0, 1, 30).reshape(10, 3) 83 | A = np.array([1, 2, 3]) 84 | b = 10 85 | with pm.Model() as model: 86 | mean1 = pm.gp.mean.Linear(coeffs=A, intercept=b) 87 | mean2 = pm.gp.mean.Constant(2) 88 | mean = mean1 * mean2 * mean2 89 | M = mean(X).eval() 90 | npt.assert_allclose(M[1], 10.8965 * 2 * 2, atol=1e-3) 91 | -------------------------------------------------------------------------------- /tests/gp/test_util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 - present The PyMC Developers 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import numpy as np 16 | import numpy.testing as npt 17 | import pytensor.tensor as pt 18 | import pytest 19 | 20 | import pymc as pm 21 | 22 | 23 | class TestPlotGP: 24 | def test_plot_gp_dist(self): 25 | """Test that the plotting helper works with the stated input shapes.""" 26 | import matplotlib.pyplot as plt 27 | 28 | X = 100 29 | S = 500 30 | fig, ax = plt.subplots() 31 | pm.gp.util.plot_gp_dist( 32 | ax, x=np.linspace(0, 50, X), samples=np.random.normal(np.arange(X), size=(S, X)) 33 | ) 34 | plt.close() 35 | pass 36 | 37 | def test_plot_gp_dist_warn_nan(self): 38 | """Test that the plotting helper works with the stated input shapes.""" 39 | import matplotlib.pyplot as plt 40 | 41 | X = 100 42 | S = 500 43 | samples = np.random.normal(np.arange(X), size=(S, X)) 44 | samples[15, 3] = np.nan 45 | fig, ax = plt.subplots() 46 | with pytest.warns(UserWarning): 47 | pm.gp.util.plot_gp_dist(ax, x=np.linspace(0, 50, X), samples=samples) 48 | plt.close() 49 | pass 50 | 51 | 52 | class TestKmeansInducing: 53 | def setup_method(self): 54 | self.centers = (-5, 5) 55 | self.x = np.concatenate( 56 | (self.centers[0] + np.random.randn(500), self.centers[1] + np.random.randn(500)) 57 | ) 58 | 59 | def test_kmeans(self): 60 | X = self.x[:, None] 61 | Xu = pm.gp.util.kmeans_inducing_points(2, X).flatten() 62 | npt.assert_allclose(np.asarray(self.centers), np.sort(Xu), rtol=0.05) 63 | 64 | X = pt.as_tensor_variable(self.x[:, None]) 65 | Xu = pm.gp.util.kmeans_inducing_points(2, X).flatten() 66 | npt.assert_allclose(np.asarray(self.centers), np.sort(Xu), rtol=0.05) 67 | 68 | def test_kmeans_raises(self): 69 | with pytest.raises(TypeError): 70 | Xu = pm.gp.util.kmeans_inducing_points(2, "str is the wrong type").flatten() 71 | 72 | 73 | class TestReplaceWithValues: 74 | def test_basic_replace(self): 75 | with pm.Model() as model: 76 | a = pm.Normal("a") 77 | b 
= pm.Normal("b", mu=a) 78 | c = a * b 79 | 80 | (c_val,) = pm.gp.util.replace_with_values( 81 | [c], replacements={"a": 2, "b": 3, "x": 100}, model=model 82 | ) 83 | assert c_val == np.array(6.0) 84 | 85 | def test_replace_no_inputs_needed(self): 86 | with pm.Model() as model: 87 | a = pt.as_tensor_variable(2.0) 88 | b = 1.0 + a 89 | c = a * b 90 | (c_val,) = pm.gp.util.replace_with_values([c], replacements={"x": 100}) 91 | assert c_val == np.array(6.0) 92 | 93 | def test_missing_input(self): 94 | with pm.Model() as model: 95 | a = pm.Normal("a") 96 | b = pm.Normal("b", mu=a) 97 | c = a * b 98 | 99 | with pytest.raises(ValueError): 100 | (c_val,) = pm.gp.util.replace_with_values( 101 | [c], replacements={"a": 2, "x": 100}, model=model 102 | ) 103 | -------------------------------------------------------------------------------- /tests/logprob/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 - present The PyMC Developers 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /tests/logprob/test_linalg.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 - present The PyMC Developers 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import numpy as np 15 | import pytest 16 | 17 | from pytensor.tensor.type import tensor 18 | 19 | from pymc.distributions import MatrixNormal, MvNormal, Normal 20 | from pymc.logprob.basic import logp 21 | 22 | 23 | @pytest.mark.parametrize("univariate", [True, False]) 24 | @pytest.mark.parametrize("batch_shape", [(), (3,)]) 25 | def test_matrix_vector_transform(univariate, batch_shape): 26 | rng = np.random.default_rng(755) 27 | 28 | μ = rng.normal(size=(*batch_shape, 2)) 29 | if univariate: 30 | σ = np.abs(rng.normal(size=(*batch_shape, 2))) 31 | Σ = np.eye(2) * (σ**2)[..., None] 32 | x = Normal.dist(mu=μ, sigma=σ) 33 | else: 34 | A = rng.normal(size=(*batch_shape, 2, 2)) 35 | Σ = np.swapaxes(A, -1, -2) @ A 36 | x = MvNormal.dist(mu=μ, cov=Σ) 37 | 38 | c = rng.normal(size=(*batch_shape, 2)) 39 | B = rng.normal(size=(*batch_shape, 2, 2)) 40 | y = c + (B @ x[..., None]).squeeze(-1) 41 | 42 | # An affine transformed MvNormal is still a MvNormal 43 | # https://en.wikipedia.org/wiki/Multivariate_normal_distribution#Affine_transformation 44 | ref_dist = MvNormal.dist( 45 | mu=c + (B @ μ[..., None]).squeeze(-1), cov=B @ Σ @ np.swapaxes(B, -1, -2) 46 | ) 47 | test_y = rng.normal(size=(*batch_shape, 2)) 48 | np.testing.assert_allclose( 49 | logp(y, test_y).eval(), 50 | logp(ref_dist, test_y).eval(), 51 | ) 52 | 53 | 54 | def test_matrix_matrix_transform(): 55 | rng = np.random.default_rng(46) 56 | 57 | n, p = 2, 3 58 | M = rng.normal(size=(n, p)) 59 | A = rng.normal(size=(n, n)) * 0.1 60 | U = A.T @ A 61 | B = 
rng.normal(size=(p, p)) * 0.1 62 | V = B.T @ B 63 | X = MatrixNormal.dist(mu=M, rowcov=U, colcov=V) 64 | 65 | D = rng.normal(size=(n, n)) 66 | C = rng.normal(size=(p, p)) 67 | Y = D @ X @ C 68 | 69 | # A linearly transformed MatrixNormal is still a MatrixNormal 70 | # https://en.wikipedia.org/wiki/Matrix_normal_distribution#Transformation 71 | ref_dist = MatrixNormal.dist(mu=D @ M @ C, rowcov=D @ U @ D.T, colcov=C.T @ V @ C) 72 | test_Y = rng.normal(size=(n, p)) 73 | np.testing.assert_allclose( 74 | logp(Y, test_Y).eval(), 75 | logp(ref_dist, test_Y).eval(), 76 | rtol=1e-5, 77 | ) 78 | 79 | 80 | def test_broadcasted_matmul_fails(): 81 | x = Normal.dist(size=(3, 2)) 82 | A = tensor("A", shape=(4, 3, 3)) 83 | y = A @ x 84 | with pytest.raises(NotImplementedError): 85 | logp(y, y.type()) 86 | -------------------------------------------------------------------------------- /tests/model/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 - present The PyMC Developers 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | -------------------------------------------------------------------------------- /tests/model/transform/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 - present The PyMC Developers 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /tests/model/transform/test_basic.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 - present The PyMC Developers 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
import numpy as np

import pymc as pm

from pymc.model.transform.basic import prune_vars_detached_from_observed, remove_minibatched_nodes


def test_prune_vars_detached_from_observed():
    """Variables with no path to an observed variable are pruned from the model."""
    with pm.Model() as m:
        # Chain that feeds the observed variable: a0 -> a1 -> a2 -> obs (observing obs_data).
        obs_data = pm.Data("obs_data", 0)
        a0 = pm.Data("a0", 0)
        a1 = pm.Normal("a1", a0)
        a2 = pm.Normal("a2", a1)
        pm.Normal("obs", a2, observed=obs_data)

        # Detached chain with no path to any observed variable: d0 -> d1.
        d0 = pm.Data("d0", 0)
        pm.Normal("d1", d0)

    assert set(m.named_vars.keys()) == {"obs_data", "a0", "a1", "a2", "obs", "d0", "d1"}
    pruned_m = prune_vars_detached_from_observed(m)
    assert set(pruned_m.named_vars.keys()) == {"obs_data", "a0", "a1", "a2", "obs"}


def test_remove_minibatches():
    """Removing minibatched nodes restores the full data length and keeps the model coords."""
    data_size = 100
    data = np.zeros((data_size,))
    batch_size = 10
    with pm.Model(coords={"d": range(5)}) as m1:
        mb = pm.Minibatch(data, batch_size=batch_size)
        # Uses the "d" dim declared in the model coords.
        pm.Normal("mu", dims="d")
        x = pm.Normal("x")
        # total_size is the full dataset size behind the minibatch; reuse data_size
        # instead of repeating the literal so the two can never drift apart.
        pm.Normal("y", x, observed=mb, total_size=data_size)

    m2 = remove_minibatched_nodes(m1)
    assert m1.y.shape[0].eval() == batch_size
    assert m2.y.shape[0].eval() == data_size
    assert m1.coords == m2.coords
    assert m1.dim_lengths["d"].eval() == m2.dim_lengths["d"].eval()


# ---- /tests/ode/__init__.py ----
# Copyright 2024 - present The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /tests/ode/test_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 - present The PyMC Developers 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 

import numpy as np
import scipy.integrate as ode

from pymc.ode.utils import augment_system


def test_gradients():
    """Compare ODE sensitivities derived from the PyTensor graph with analytic values.

    The system is ``dy/dt = exp(-t) - a * y``. Its sensitivities with respect to the
    initial condition ``y0`` and the rate ``a`` are known in closed form.
    """

    def rhs(y, t, theta):
        # Right-hand side of the ODE. NOTE(review): presumably augment_system hands
        # only the ODE parameters (not y0) to this function, so theta[0] is `a`.
        return np.exp(-t) - theta[0] * y[0]

    # Build the ODE-plus-sensitivities system from the PyTensor graph of `rhs`.
    rhs_with_sens = augment_system(rhs, n_states=1, n_theta=1)

    def full_system(state, t, params):
        # state[:1] is y; state[1:] holds the running sensitivities dy/dy0 and dy/da.
        dy_dt, dsens_dt = rhs_with_sens(state[:1], t, params, state[1:])
        return np.concatenate([dy_dt, dsens_dt])

    # Problem setup; the parameter vector stacks [y0, a].
    y0 = 0.0
    t = np.arange(0, 12, 0.25).reshape(-1, 1)
    a = 0.472
    p = np.array([y0, a])

    # Closed-form sensitivities: solve the ODE with y0 treated as a parameter, then
    # differentiate the solution (derived with a computer algebra system).
    dy_dy0 = np.exp(-a * t)
    dy_da = (
        -(np.exp(t * (a - 1)) - 1 + (a - 1) * (y0 * a - y0 - 1) * t) * np.exp(-a * t) / (a - 1) ** 2
    )
    expected = np.c_[dy_dy0, dy_da]

    # Initial sensitivities at t=0: dy/dy0 = 1 and dy/da = 0.
    solution = ode.odeint(func=full_system, y0=[y0, 1, 0], t=t.ravel(), args=(p,))

    np.testing.assert_allclose(expected, solution[:, 1:], rtol=1e-5)


# ---- /tests/sampling/__init__.py ----
# Copyright 2024 - present The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /tests/sampling/test_deterministic.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 - present The PyMC Developers 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
import numpy as np
import pytest

from numpy.testing import assert_allclose

from pymc.distributions import Normal
from pymc.model.core import Deterministic, Model
from pymc.sampling.deterministic import compute_deterministics
from pymc.sampling.forward import sample_prior_predictive

# Turn all warnings into errors for this module
pytestmark = pytest.mark.filterwarnings("error")


def test_compute_deterministics():
    """compute_deterministics recomputes Deterministic variables from free-variable draws."""
    with Model(coords={"group": (0, 2, 4)}) as m:
        mu_raw = Normal("mu_raw", 0, 1, dims="group")
        Deterministic("mu", mu_raw.cumsum(), dims="group")

        sigma_raw = Normal("sigma_raw", 0, 1)
        Deterministic("sigma", sigma_raw.exp())

    # Draw only the free variables, so the Deterministics are absent from `dataset`.
    dataset = sample_prior_predictive(
        draws=5, model=m, var_names=["mu_raw", "sigma_raw"], random_seed=22
    ).prior

    # Test default
    with m:
        all_dets = compute_deterministics(dataset)
    assert set(all_dets.data_vars.variables) == {"mu", "sigma"}
    assert all_dets["mu"].dims == ("chain", "draw", "group")
    assert all_dets["sigma"].dims == ("chain", "draw")
    assert_allclose(all_dets["mu"], dataset["mu_raw"].cumsum("group"))
    assert_allclose(all_dets["sigma"], np.exp(dataset["sigma_raw"]))

    # Test custom arguments
    extended_with_mu = compute_deterministics(
        dataset,
        var_names=["mu"],
        merge_dataset=True,
        model=m,
        compile_kwargs={"mode": "FAST_COMPILE"},
        progressbar=False,
    )
    assert set(extended_with_mu.data_vars.variables) == {"mu_raw", "sigma_raw", "mu"}
    assert extended_with_mu["mu"].dims == ("chain", "draw", "group")
    assert_allclose(extended_with_mu["mu"], dataset["mu_raw"].cumsum("group"))

    only_sigma = compute_deterministics(dataset, var_names=["sigma"], model=m, progressbar=False)
    assert set(only_sigma.data_vars.variables) == {"sigma"}
    assert only_sigma["sigma"].dims == ("chain", "draw")
    assert_allclose(only_sigma["sigma"], np.exp(dataset["sigma_raw"]))


def test_docstring_example():
    """Mirror the compute_deterministics docstring example end to end."""
    import pymc as pm

    with pm.Model(coords={"group": (0, 2, 4)}) as m:
        mu_raw = pm.Normal("mu_raw", 0, 1, dims="group")
        pm.Deterministic("mu", mu_raw.cumsum(), dims="group")

        # Seeded so that any warning (turned into an error by `pytestmark`) is reproducible.
        trace = pm.sample(var_names=["mu_raw"], chains=2, tune=5, draws=5, random_seed=2024)

    assert "mu" not in trace.posterior

    with m:
        trace.posterior = pm.compute_deterministics(trace.posterior, merge_dataset=True)

    assert "mu" in trace.posterior


# ---- /tests/sampling/test_mcmc_external.py ----
# Copyright 2024 - present The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import numpy.testing as npt
import pytest

from pymc import Data, Model, Normal, sample


@pytest.mark.parametrize("nuts_sampler", ["pymc", "nutpie", "blackjax", "numpyro"])
def test_external_nuts_sampler(recwarn, nuts_sampler):
    """External NUTS samplers are reproducible and yield InferenceData shaped like PyMC's."""
    if nuts_sampler != "pymc":
        pytest.importorskip(nuts_sampler)

    with Model():
        x = Normal("x", 100, 5)
        y = Data("y", [1, 2, 3, 4])
        Data("z", [100, 190, 310, 405])

        Normal("L", mu=x, sigma=0.1, observed=y)

        kwargs = {
            "nuts_sampler": nuts_sampler,
            "random_seed": 123,
            "chains": 2,
            "tune": 500,
            "draws": 500,
            "progressbar": False,
            "initvals": {"x": 0.0},
        }

        # Two identically seeded runs must give identical draws.
        idata1 = sample(**kwargs)
        idata2 = sample(**kwargs)

        # Reference run with PyMC's own NUTS, used to compare posterior attrs.
        reference_kwargs = kwargs.copy()
        reference_kwargs["nuts_sampler"] = "pymc"
        idata_reference = sample(**reference_kwargs)

    # Ignore noisy Future/Deprecation/Runtime warnings; every other warning must be expected.
    warns = {
        (warn.category, warn.message.args[0])
        for warn in recwarn
        if warn.category not in (FutureWarning, DeprecationWarning, RuntimeWarning)
    }
    expected = set()
    if nuts_sampler == "nutpie":
        expected.add(
            (
                UserWarning,
                "`initvals` are currently not passed to nutpie sampler. "
                "Use `init_mean` kwarg following nutpie specification instead.",
            )
        )
    assert warns == expected
    assert "y" in idata1.constant_data
    assert "z" in idata1.constant_data
    assert "L" in idata1.observed_data
    assert idata1.posterior.chain.size == 2
    assert idata1.posterior.draw.size == 500
    assert idata1.posterior.tuning_steps == 500
    np.testing.assert_array_equal(idata1.posterior.x, idata2.posterior.x)

    assert idata_reference.posterior.attrs.keys() == idata1.posterior.attrs.keys()


def test_step_args():
    """``target_accept`` reaches the numpyro NUTS sampler (mean acceptance close to 0.5)."""
    # numpyro is an optional dependency: skip, rather than error, when it is missing.
    pytest.importorskip("numpyro")

    with Model():
        Normal("a")
        idata = sample(
            nuts_sampler="numpyro",
            target_accept=0.5,
            nuts={"max_treedepth": 10},
            random_seed=1411,
            progressbar=False,
        )

    npt.assert_almost_equal(idata.sample_stats.acceptance_rate.mean(), 0.5, decimal=1)


# ---- /tests/sampling/test_population.py ----
# Copyright 2024 - present The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

import pymc as pm

from pymc.step_methods.metropolis import DEMetropolis


class TestPopulationSamplers:
    # Population-based step methods exercised by every test in this class.
    steppers = [DEMetropolis]

    def test_checks_population_size(self):
        """Test that population samplers check the population size."""
        with pm.Model():
            pm.Normal("n", mu=0, sigma=1)
            for stepper in TestPopulationSamplers.steppers:
                step = stepper()
                with pytest.raises(ValueError, match="requires at least 3 chains"):
                    pm.sample(draws=10, tune=10, chains=1, cores=1, step=step)
                # don't parallelize to make test faster
                pm.sample(
                    draws=10,
                    tune=10,
                    chains=4,
                    cores=1,
                    step=step,
                    compute_convergence_checks=False,
                )

    def test_demcmc_warning_on_small_populations(self):
        """Test that a warning is raised when n_chains <= n_dims"""
        with pm.Model():
            pm.Normal("n", mu=0, sigma=1, size=(2, 3))
            with pytest.warns(UserWarning, match="more chains than dimensions"):
                pm.sample(
                    draws=5,
                    tune=5,
                    chains=6,
                    step=DEMetropolis(),
                    # make tests faster by not parallelizing; disable convergence warning
                    cores=1,
                    compute_convergence_checks=False,
                )

    @staticmethod
    def _assert_chains_are_random(cores, label):
        """Sample 4 chains per stepper with `cores` processes and require distinct chains.

        `label` only prefixes the assertion message, so a failure names the mode under test.
        """
        with pm.Model():
            pm.Normal("x", 0, 1)
            for stepper in TestPopulationSamplers.steppers:
                step = stepper()
                idata = pm.sample(
                    chains=4,
                    cores=cores,
                    draws=20,
                    tune=0,
                    step=step,
                    compute_convergence_checks=False,
                )
                # Draw index 5 of every chain: all 4 values must differ.
                samples = idata.posterior["x"].values[:, 5]

                assert len(set(samples)) == 4, f"{label} {stepper} chains are identical."

    def test_nonparallelized_chains_are_random(self):
        """Chains sampled sequentially in one process must not be identical."""
        self._assert_chains_are_random(cores=1, label="Non-parallelized")

    def test_parallelized_chains_are_random(self):
        """Chains sampled across several processes must not be identical."""
        self._assert_chains_are_random(cores=4, label="Parallelized")


# ---- /tests/smc/__init__.py ----
# Copyright 2024 - present The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# ---- /tests/stats/__init__.py ----
# Copyright 2024 - present The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /tests/stats/test_convergence.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 - present The PyMC Developers 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 

import logging

import arviz
import numpy as np

from pymc.stats import convergence


def test_warn_divergences():
    """warn_divergences reports how many divergences occurred after tuning."""
    diverging = np.array([[1, 0, 1, 0], [0, 0, 0, 0]]).astype(bool)
    idata = arviz.from_dict(sample_stats={"diverging": diverging})

    found = convergence.warn_divergences(idata)

    assert len(found) == 1
    assert "2 divergences after tuning" in found[0].message


def test_warn_treedepth():
    """warn_treedepth names the chain that hit the maximum tree depth."""
    hit_max = np.array([[0, 0, 0], [0, 1, 0]]).astype(bool)
    idata = arviz.from_dict(sample_stats={"reached_max_treedepth": hit_max})

    found = convergence.warn_treedepth(idata)

    assert len(found) == 1
    assert "Chain 1 reached the maximum tree depth" in found[0].message


def test_warn_treedepth_multiple_samplers():
    """Multiple NUTS samplers each reporting max_treedepth yield one warning per affected chain."""
    # Assumed layout (chain, draw, sampler): chain 0 and chain 2 each hit the limit once.
    hit_max = np.zeros((3, 2, 2), dtype=bool)
    hit_max[0, 0, 0] = True
    hit_max[2, 1, 1] = True
    idata = arviz.from_dict(sample_stats={"reached_max_treedepth": hit_max})

    found = convergence.warn_treedepth(idata)

    assert len(found) == 2
    assert "Chain 0 reached the maximum tree depth" in found[0].message
    assert "Chain 2 reached the maximum tree depth" in found[1].message


def test_log_warning_stats(caplog):
    """Plain-string "warning" stats are logged in order (captured at WARNING level)."""
    # We have a list of stats dicts, because there might be several samplers involved.
    stats = [
        {"warning": "Temperature too low!"},
        {"warning": "Temperature too high!"},
    ]

    with caplog.at_level(logging.WARNING):
        convergence.log_warning_stats(stats)

    records = caplog.records
    assert "too low" in records[0].message
    assert "too high" in records[1].message


def test_log_warning_stats_knows_SamplerWarning(caplog):
    """A debug-level SamplerWarning stat is captured from the "pymc" logger at DEBUG level."""
    sampler_warning = convergence.SamplerWarning(
        convergence.WarningType.BAD_ENERGY,
        "Not that interesting",
        "debug",
    )

    with caplog.at_level(logging.DEBUG, logger="pymc"):
        convergence.log_warning_stats([{"warning": sampler_warning}])

    assert "Not that interesting" in caplog.records[0].message


# ---- /tests/step_methods/__init__.py ----
# Copyright 2024 - present The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# ---- /tests/step_methods/hmc/__init__.py ----
# Copyright 2024 - present The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /tests/step_methods/hmc/test_hmc.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 - present The PyMC Developers 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 

import warnings

import numpy as np
import numpy.testing as npt
import pytest

import pymc as pm

from pymc.blocking import DictToArrayBijection, RaveledVars
from pymc.pytensorf import floatX
from pymc.step_methods.hmc import HamiltonianMC
from pymc.step_methods.hmc.base_hmc import BaseHMC
from tests import models
from tests.helpers import RVsAssignmentStepsTester, StepMethodTester


class TestStepHamiltonianMC(StepMethodTester):
    @pytest.mark.parametrize(
        "step_fn, draws",
        [
            (lambda C, _: HamiltonianMC(scaling=C, is_cov=True, blocked=False), 1000),
            (lambda C, _: HamiltonianMC(scaling=C, is_cov=True), 1000),
        ],
    )
    def test_step_continuous(self, step_fn, draws):
        """Run the shared continuous-sampling check with blocked=False and default HMC."""
        self.step_continuous(step_fn, draws)


class TestRVsAssignmentHamiltonianMC(RVsAssignmentStepsTester):
    @pytest.mark.parametrize("step, step_kwargs", [(HamiltonianMC, {})])
    def test_continuous_steps(self, step, step_kwargs):
        """Run the shared RV-assignment check for HamiltonianMC."""
        self.continuous_steps(step, step_kwargs)


def test_leapfrog_reversible():
    """Integrating n steps forward then n steps with -epsilon returns to the start state."""
    n = 3
    # Seeds the legacy global NumPy RNG; the random draws below depend on call order.
    np.random.seed(42)
    start, model, _ = models.non_normal(n)
    # Distinct loop name so it does not read as if it rebinds the model size `n`.
    size = sum(start[value_var.name].size for value_var in model.value_vars)
    scaling = floatX(np.random.rand(size))

    class HMC(BaseHMC):
        # Minimal concrete subclass: only the integrator is exercised, never a full step.
        def _hamiltonian_step(self, *args, **kwargs):
            pass

    step = HMC(vars=model.value_vars, model=model, scaling=scaling)

    astart = DictToArrayBijection.map(start)
    p = RaveledVars(floatX(step.potential.random()), astart.point_map_info)
    q = floatX(np.random.randn(size))
    start = step.integrator.compute_state(p, q)
    for epsilon in [0.01, 0.1]:
        for n_steps in [1, 2, 3, 4, 20]:
            state = start
            for _ in range(n_steps):
                state = step.integrator.step(epsilon, state)
            # Retrace the trajectory with a negated step size.
            for _ in range(n_steps):
                state = step.integrator.step(-epsilon, state)
            npt.assert_allclose(state.q.data, start.q.data, rtol=1e-5)
            npt.assert_allclose(state.p.data, start.p.data, rtol=1e-5)


def test_nuts_tuning():
    """After sampling, NUTS has stopped tuning and reuses the last tuned step size."""
    with pm.Model():
        pm.Normal("mu", mu=0, sigma=1)
        step = pm.NUTS()
        with warnings.catch_warnings():
            # Deliberately tiny run: silence the small-sample warning.
            warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning)
            idata = pm.sample(
                10, step=step, tune=5, discard_tuned_samples=False, progressbar=False, chains=1
            )

    assert not step.tune
    # Chain 0: the last warmup step size must equal every posterior step size.
    ss_tuned = idata.warmup_sample_stats["step_size"][0, -1]
    ss_posterior = idata.sample_stats["step_size"][0, :]
    np.testing.assert_array_equal(ss_posterior, ss_tuned)


# ---- /tests/step_methods/test_slicer.py ----
# Copyright 2024 - present The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import pytest

from pymc.step_methods.slicer import Slice
from tests import sampler_fixtures as sf
from tests.helpers import RVsAssignmentStepsTester, StepMethodTester

SEED = 20240920


class TestSliceUniform(sf.SliceFixture, sf.UniformFixture):
    # Settings consumed by the shared sampler fixtures (sampling budget and
    # tolerances); their exact use is defined in tests.sampler_fixtures.
    n_samples = 10000
    tune = 1000
    burn = 0
    chains = 4
    min_n_eff = 5000
    rtol = 0.1
    atol = 0.05
    step_args = {"rng": np.random.default_rng(SEED)}


class TestStepSlicer(StepMethodTester):
    @pytest.mark.parametrize(
        "step_fn, draws",
        [
            (lambda *_: Slice(), 2000),
            (lambda *_: Slice(blocked=True), 2000),
        ],
        # Explicit ids: `ids=str` on lambdas embeds memory addresses, which makes the
        # test IDs differ between runs/processes (breaks pytest-xdist and `-k`).
        ids=["default", "blocked"],
    )
    def test_step_continuous(self, step_fn, draws):
        """Run the shared continuous-sampling check with default and blocked Slice."""
        self.step_continuous(step_fn, draws)


class TestRVsAssignmentSlicer(RVsAssignmentStepsTester):
    @pytest.mark.parametrize("step, step_kwargs", [(Slice, {})])
    def test_continuous_steps(self, step, step_kwargs):
        """Run the shared RV-assignment check for Slice."""
        self.continuous_steps(step, step_kwargs)


# ---- /tests/test_testing.py ----
# Copyright 2024 - present The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import nullcontext as does_not_raise

import pytest

import pymc as pm

from pymc.testing import Domain, mock_sample, mock_sample_setup_and_teardown
from tests.models import simple_normal


@pytest.mark.parametrize(
    "values, edges, expectation",
    [
        ([], None, pytest.raises(IndexError)),
        ([], (0, 0), pytest.raises(ValueError)),
        ([0], None, pytest.raises(ValueError)),
        ([0], (0, 0), does_not_raise()),
        ([-1, 1], None, pytest.raises(ValueError)),
        ([-1, 0, 1], None, does_not_raise()),
    ],
)
def test_domain(values, edges, expectation):
    """Domain construction raises exactly when ``expectation`` says it should."""
    with expectation:
        Domain(values, edges=edges)


@pytest.mark.parametrize(
    "args, kwargs, expected_size",
    [
        pytest.param((), {}, (1, 10), id="default"),
        pytest.param((100,), {}, (1, 100), id="positional-draws"),
        pytest.param((), {"draws": 100}, (1, 100), id="keyword-draws"),
        pytest.param((100,), {"chains": 6}, (6, 100), id="chains"),
    ],
)
def test_mock_sample(args, kwargs, expected_size) -> None:
    """mock_sample returns only posterior/observed_data groups sized (chains, draws)."""
    expected_chains, expected_draws = expected_size
    _, model, _ = simple_normal(bounded_prior=True)

    with model:
        idata = mock_sample(*args, **kwargs)

    assert "posterior" in idata
    assert "observed_data" in idata
    assert "prior" not in idata
    assert "posterior_predictive" not in idata
    assert "sample_stats" not in idata

    assert idata.posterior.sizes == {"chain": expected_chains, "draw": expected_draws}


# Expose pymc.testing's setup/teardown helper as a per-test fixture. Presumably it
# replaces pm.sample with mock_sample while active; test_fixture relies on that.
mock_pymc_sample = pytest.fixture(scope="function")(mock_sample_setup_and_teardown)


@pytest.fixture(scope="function")
def dummy_model() -> pm.Model:
    """Model with one Flat and one HalfFlat variable."""
    with pm.Model() as model:
        pm.Flat("flat")
        pm.HalfFlat("half_flat")

    return model


def test_fixture(mock_pymc_sample, dummy_model) -> None:
    """With the mock fixture active, pm.sample gives 1 chain x 10 draws, HalfFlat draws >= 0."""
    with dummy_model:
        idata = pm.sample()

    posterior = idata.posterior
    assert posterior.sizes == {"chain": 1, "draw": 10}
    assert (posterior["half_flat"] >= 0).all()


# ---- /tests/tuning/__init__.py ----
# Copyright 2024 - present The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# ---- /tests/tuning/test_scaling.py ----
# Copyright 2024 - present The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14 | import warnings 15 | 16 | import numpy as np 17 | 18 | from pymc.tuning import scaling 19 | from tests import models 20 | 21 | 22 | def test_adjust_precision(): 23 | a = np.array([-10, -0.01, 0, 10, 1e300, -np.inf, np.inf]) 24 | with warnings.catch_warnings(): 25 | warnings.filterwarnings("ignore", "divide by zero encountered in log", RuntimeWarning) 26 | a1 = scaling.adjust_precision(a) 27 | assert all((a1 > 0) & (a1 < 1e200)) 28 | 29 | 30 | def test_guess_scaling(): 31 | start, model, _ = models.non_normal(n=5) 32 | a1 = scaling.guess_scaling(start, model=model) 33 | assert all((a1 > 0) & (a1 < 1e200)) 34 | -------------------------------------------------------------------------------- /tests/variational/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 - present The PyMC Developers 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /tests/variational/test_callbacks.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 - present The PyMC Developers 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import numpy as np 16 | import pytensor 17 | import pytest 18 | 19 | import pymc as pm 20 | 21 | 22 | @pytest.mark.parametrize("diff", ["relative", "absolute"]) 23 | @pytest.mark.parametrize("ord", [1, 2, np.inf]) 24 | def test_callbacks_convergence(diff, ord): 25 | cb = pm.variational.callbacks.CheckParametersConvergence(every=1, diff=diff, ord=ord) 26 | 27 | class _approx: 28 | params = (pytensor.shared(np.asarray([1, 2, 3])),) 29 | 30 | approx = _approx() 31 | 32 | with pytest.raises(StopIteration): 33 | cb(approx, None, 1) 34 | cb(approx, None, 10) 35 | 36 | 37 | def test_tracker_callback(): 38 | import time 39 | 40 | tracker = pm.callbacks.Tracker( 41 | ints=lambda *t: t[-1], 42 | ints2=lambda ap, h, j: j, 43 | time=time.time, 44 | ) 45 | for i in range(10): 46 | tracker(None, None, i) 47 | assert "time" in tracker.hist 48 | assert "ints" in tracker.hist 49 | assert "ints2" in tracker.hist 50 | assert len(tracker["ints"]) == len(tracker["ints2"]) == len(tracker["time"]) == 10 51 | assert tracker["ints"] == tracker["ints2"] == list(range(10)) 52 | tracker = pm.callbacks.Tracker(bad=lambda t: t) # bad signature 53 | with pytest.raises(TypeError): 54 | tracker(None, None, 1) 55 | -------------------------------------------------------------------------------- /tests/variational/test_updates.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 - present The PyMC Developers 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may 
not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import numpy as np 16 | import pytensor 17 | import pytest 18 | 19 | from pymc.variational.updates import ( 20 | adadelta, 21 | adagrad, 22 | adagrad_window, 23 | adam, 24 | adamax, 25 | momentum, 26 | nesterov_momentum, 27 | rmsprop, 28 | sgd, 29 | ) 30 | 31 | _a = pytensor.shared(1.0) 32 | _b = _a * 2 33 | 34 | _m = pytensor.shared(np.empty((10,), pytensor.config.floatX)) 35 | _n = _m.sum() 36 | _m2 = pytensor.shared(np.empty((10, 10, 10), pytensor.config.floatX)) 37 | _n2 = _b + _n + _m2.sum() 38 | 39 | 40 | @pytest.mark.parametrize( 41 | "opt", 42 | [sgd, momentum, nesterov_momentum, adagrad, rmsprop, adadelta, adam, adamax, adagrad_window], 43 | ids=[ 44 | "sgd", 45 | "momentum", 46 | "nesterov_momentum", 47 | "adagrad", 48 | "rmsprop", 49 | "adadelta", 50 | "adam", 51 | "adamax", 52 | "adagrad_window", 53 | ], 54 | ) 55 | @pytest.mark.parametrize( 56 | "getter", 57 | [ 58 | lambda t: t, # all params -> ok 59 | lambda t: (None, t[1]), # missing loss -> fail 60 | lambda t: (t[0], None), # missing params -> fail 61 | lambda t: (None, None), 62 | ], # all missing -> partial 63 | ids=["all_params", "missing_loss", "missing_params", "all_missing"], 64 | ) 65 | @pytest.mark.parametrize("kwargs", [{}, {"learning_rate": 1e-2}], ids=["without_args", "with_args"]) 66 | @pytest.mark.parametrize( 67 | "loss_and_params", 68 | [(_b, [_a]), (_n, [_m]), (_n2, [_a, _m, _m2])], 69 | ids=["scalar", "matrix", "mixed"], 70 | ) 71 | def test_updates_fast(opt, loss_and_params, 
kwargs, getter): 72 | with pytensor.config.change_flags(compute_test_value="ignore"): 73 | loss, param = getter(loss_and_params) 74 | args = {} 75 | args.update(**kwargs) 76 | args.update({"loss_or_grads": loss, "params": param}) 77 | if loss is None and param is None: 78 | updates = opt(**args) 79 | # Here we should get new callable 80 | assert callable(updates) 81 | # And be able to get updates 82 | updates = opt(_b, [_a]) 83 | assert isinstance(updates, dict) 84 | # case when both are None is above 85 | elif loss is None or param is None: 86 | # Here something goes wrong and user provides not full set of [params + loss_or_grads] 87 | # We raise Value error 88 | with pytest.raises(ValueError): 89 | opt(**args) 90 | else: 91 | # Usual call to optimizer, old behaviour 92 | updates = opt(**args) 93 | assert isinstance(updates, dict) 94 | --------------------------------------------------------------------------------