├── .gitattributes
├── .github
├── dependabot.yml
└── workflows
│ ├── hypothesis.yml
│ ├── post-release.yml
│ ├── publish.yml
│ ├── rtd-preview.yml
│ └── test.yml
├── .gitignore
├── .pre-commit-config.yaml
├── .pylintrc
├── .readthedocs.yaml
├── CHANGELOG.md
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── References.md
├── docs
├── source
│ ├── _static
│ │ ├── ArviZ.png
│ │ ├── ArviZ_white.png
│ │ ├── bokeh-logo-dark.svg
│ │ ├── bokeh-logo-light.svg
│ │ ├── custom.css
│ │ ├── favicon.ico
│ │ ├── matplotlib-logo-dark.svg
│ │ ├── matplotlib-logo-light.svg
│ │ ├── none-logo-light.png
│ │ ├── plotly-logo-dark.png
│ │ └── plotly-logo-light.png
│ ├── _templates
│ │ └── name.html
│ ├── api
│ │ ├── backend
│ │ │ ├── bokeh.part.rst
│ │ │ ├── index.rst
│ │ │ ├── interface.template.rst
│ │ │ ├── matplotlib.part.rst
│ │ │ ├── none.part.rst
│ │ │ └── plotly.part.rst
│ │ ├── helpers.rst
│ │ ├── index.md
│ │ ├── managers.rst
│ │ ├── plots.rst
│ │ └── visuals.rst
│ ├── conf.py
│ ├── contributing
│ │ ├── docs.md
│ │ ├── new_plot.md
│ │ └── testing.md
│ ├── gallery
│ │ ├── distribution
│ │ │ ├── 00_plot_dist_ecdf.py
│ │ │ ├── 01_plot_dist_hist.py
│ │ │ ├── 02_plot_dist_kde.py
│ │ │ ├── 04_plot_forest.py
│ │ │ ├── 05_plot_forest_shade.py
│ │ │ ├── 06_plot_prior_posterior.py
│ │ │ └── 07_plot_pairs_focus_distribution.py
│ │ ├── inference_diagnostics
│ │ │ ├── 00_plot_rank.py
│ │ │ ├── 01_plot_trace.py
│ │ │ ├── 02_plot_ess_evolution.py
│ │ │ ├── 03_plot_ess_local.py
│ │ │ ├── 04_plot_ess_quantile.py
│ │ │ ├── 05_plot_ess_models.py
│ │ │ ├── 05_plot_mcse.py
│ │ │ ├── 06_plot_convergence_dist.py
│ │ │ ├── 07_plot_autocorr.py
│ │ │ ├── 08_plot_energy.py
│ │ │ └── 09_plot_pairs_focus.py
│ │ ├── mixed
│ │ │ ├── 00_plot_rank_dist.py
│ │ │ ├── 01_plot_trace_dist.py
│ │ │ ├── 02_plot_forest_ess.py
│ │ │ └── 03_combine_plots.py
│ │ ├── model_comparison
│ │ │ ├── 00_plot_compare.py
│ │ │ └── 99_plot_bf.py
│ │ ├── posterior_comparison
│ │ │ ├── 00_plot_dist_models.py
│ │ │ └── 01_plot_forest_models.py
│ │ ├── predictive_checks
│ │ │ ├── 00_plot_ppc_dist.py
│ │ │ ├── 01_plot_ppc_rootogram.py
│ │ │ ├── 03_plot_pava_calibration.py
│ │ │ ├── 04_plot_ppc_pit.py
│ │ │ ├── 05_plot_ppc_coverage.py
│ │ │ ├── 06_plot_loo_pit.py
│ │ │ ├── 07_plot_ppc_tstat.py
│ │ │ └── 99_plot_forest_pp_obs.py
│ │ ├── prior_and_likelihood_sensitivity_checks
│ │ │ ├── 00_plot_psense.py
│ │ │ └── 01_plot_psense_quantities.py
│ │ ├── sbc
│ │ │ ├── 00_plot_ecdf_pit.py
│ │ │ └── 01_plot_ecdf_coverage.py
│ │ └── utils
│ │ │ ├── 00_add_reference_lines.py
│ │ │ └── 01_add_reference_bands.py
│ ├── glossary.md
│ ├── index.md
│ └── tutorials
│ │ ├── compose_own_plot.ipynb
│ │ ├── intro_to_plotcollection.ipynb
│ │ ├── overview.ipynb
│ │ └── plots_intro.ipynb
└── sphinxext
│ └── gallery_generator.py
├── pyproject.toml
├── src
└── arviz_plots
│ ├── __init__.py
│ ├── _version.py
│ ├── backend
│ ├── __init__.py
│ ├── bokeh
│ │ ├── __init__.py
│ │ └── legend.py
│ ├── matplotlib
│ │ ├── __init__.py
│ │ └── legend.py
│ ├── none
│ │ ├── __init__.py
│ │ └── legend.py
│ └── plotly
│ │ ├── __init__.py
│ │ ├── legend.py
│ │ └── templates.py
│ ├── plot_collection.py
│ ├── plot_matrix.py
│ ├── plots
│ ├── __init__.py
│ ├── autocorr_plot.py
│ ├── bf_plot.py
│ ├── combine.py
│ ├── compare_plot.py
│ ├── convergence_dist_plot.py
│ ├── dist_plot.py
│ ├── ecdf_plot.py
│ ├── energy_plot.py
│ ├── ess_plot.py
│ ├── evolution_plot.py
│ ├── forest_plot.py
│ ├── loo_pit_plot.py
│ ├── mcse_plot.py
│ ├── pairs_focus_plot.py
│ ├── pava_calibration_plot.py
│ ├── ppc_dist_plot.py
│ ├── ppc_pit_plot.py
│ ├── ppc_rootogram_plot.py
│ ├── ppc_tstat.py
│ ├── prior_posterior_plot.py
│ ├── psense_dist_plot.py
│ ├── psense_quantities_plot.py
│ ├── rank_dist_plot.py
│ ├── rank_plot.py
│ ├── ridge_plot.py
│ ├── trace_dist_plot.py
│ ├── trace_plot.py
│ └── utils.py
│ ├── py.typed
│ ├── style.py
│ ├── styles
│ ├── arviz-cetrino.mplstyle
│ ├── arviz-cetrino.yml
│ ├── arviz-variat.mplstyle
│ ├── arviz-variat.yml
│ ├── arviz-vibrant.mplstyle
│ └── arviz-vibrant.yml
│ └── visuals
│ └── __init__.py
├── tests
├── __init__.py
├── conftest.py
├── test_fixtures.py
├── test_hypothesis_plots.py
├── test_plot_collection.py
├── test_plot_matrix.py
└── test_plots.py
└── tox.ini
/.gitattributes:
--------------------------------------------------------------------------------
1 | # SCM syntax highlighting & preventing 3-way merges
2 | pixi.lock merge=binary linguist-language=YAML linguist-generated=true
3 |
--------------------------------------------------------------------------------
/.github/dependabot.yml:
--------------------------------------------------------------------------------
1 | version: 2
2 | updates:
3 | - package-ecosystem: "github-actions"
4 | directory: "/"
5 | schedule:
6 | interval: "weekly"
7 |
--------------------------------------------------------------------------------
/.github/workflows/hypothesis.yml:
--------------------------------------------------------------------------------
1 | name: Run extended tests with hypothesis
2 | on:
3 | schedule:
4 | - cron: '17 5 * * 1'
5 | workflow_dispatch:
6 |
7 | permissions:
8 | issues: write
9 |
10 | jobs:
11 | hypothesis_testing:
12 | runs-on: ubuntu-latest
13 | steps:
14 | - uses: actions/checkout@v4
15 | - name: Set up Python
16 | uses: actions/setup-python@v5
17 | with:
18 | python-version: '3.11'
19 | - name: Install arviz-plots
20 | run: |
21 | python -m pip install ".[test]"
22 | - name: Execute tests
23 | run: |
24 | pytest --hypothesis-profile chron -k hypothesis
25 | echo "DATE=$(date +'%Y-%m-%d %H:%M %z')" >> ${GITHUB_ENV}
26 | - name: Comment on issue if failed
27 | if: failure()
28 | uses: peter-evans/create-or-update-comment@v4
29 | with:
30 | issue-number: 43
31 | body: |
32 | The extended tests with hypothesis failed.
33 |
34 | * Branch: ${{ github.ref_name }}
35 | * Date: ${{ env.DATE }}
36 |
37 | See [workflow logs](${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}) for details on which tests failed and why.
38 |
--------------------------------------------------------------------------------
/.github/workflows/post-release.yml:
--------------------------------------------------------------------------------
1 | name: Post-release
2 | on:
3 | release:
4 | types: [published, released]
5 | workflow_dispatch:
6 |
7 | jobs:
8 | changelog:
9 | name: Update changelog
10 | runs-on: ubuntu-latest
11 | steps:
12 | - uses: actions/checkout@v4
13 | with:
14 | ref: main
15 | - uses: rhysd/changelog-from-release/action@v3
16 | with:
17 | file: CHANGELOG.md
18 | github_token: ${{ secrets.GITHUB_TOKEN }}
19 | commit_summary_template: 'update changelog for %s changes'
20 | pull_request: true
21 |
--------------------------------------------------------------------------------
/.github/workflows/publish.yml:
--------------------------------------------------------------------------------
1 | name: Publish library
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 | tags:
8 | # Don't try to be smart about PEP 440 compliance,
9 | # see https://www.python.org/dev/peps/pep-0440/#appendix-b-parsing-version-strings-with-regular-expressions
10 | - v*
11 |
12 | jobs:
13 | build-package:
14 | runs-on: ubuntu-latest
15 | permissions:
16 | # write attestations and id-token are necessary for attest-build-provenance-github
17 | attestations: write
18 | id-token: write
19 | steps:
20 | - uses: actions/checkout@v4
21 | with:
22 | fetch-depth: 0
23 | persist-credentials: false
24 | - uses: hynek/build-and-inspect-python-package@v2
25 | with:
26 | # Prove that the packages were built in the context of this workflow.
27 | attest-build-provenance-github: true
28 | publish:
29 | runs-on: ubuntu-latest
30 | if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags')
31 | # Use the `publish` GitHub environment to protect the Trusted Publishing (OIDC)
32 | # workflow by requiring signoff from a maintainer.
33 | environment:
34 | name: publish
35 | url: https://pypi.org/p/arviz-plots
36 | needs: build-package
37 | permissions:
38 | # write id-token is necessary for trusted publishing (OIDC)
39 | id-token: write
40 | steps:
41 | - name: Download Distribution Artifacts
42 | uses: actions/download-artifact@v4
43 | with:
44 | # The build-and-inspect-python-package action invokes upload-artifact.
45 | # These are the correct arguments from that action.
46 | name: Packages
47 | path: dist
48 | - name: Publish to PyPI
49 | uses: pypa/gh-action-pypi-publish@release/v1
50 |
--------------------------------------------------------------------------------
/.github/workflows/rtd-preview.yml:
--------------------------------------------------------------------------------
1 | name: Read the Docs Pull Request Preview
2 | on:
3 | pull_request_target:
4 | types:
5 | - opened
6 |
7 | permissions:
8 | pull-requests: write
9 |
10 | jobs:
11 | documentation-links:
12 | runs-on: ubuntu-latest
13 | steps:
14 | - uses: readthedocs/actions/preview@v1
15 | with:
16 | project-slug: "arviz-plots"
17 |
--------------------------------------------------------------------------------
/.github/workflows/test.yml:
--------------------------------------------------------------------------------
1 | name: Run tests
2 | on:
3 | pull_request:
4 | push:
5 | branches: [main]
6 | paths-ignore:
7 | - "docs/**"
8 |
9 | jobs:
10 | test:
11 | runs-on: ubuntu-latest
12 | strategy:
13 | matrix:
14 | python-version: ["3.11", "3.12"]
15 | fail-fast: false
16 | steps:
17 | - uses: actions/checkout@v4
18 | - name: Set up Python ${{ matrix.python-version }}
19 | uses: actions/setup-python@v5
20 | with:
21 | python-version: ${{ matrix.python-version }}
22 | - name: Install dependencies
23 | run: |
24 | python -m pip install --upgrade pip
25 | pip install tox tox-gh-actions
26 | - name: Test with tox
27 | run: tox
28 | - name: Upload coverage to Codecov
29 | uses: codecov/codecov-action@v5
30 | with:
31 | name: Python ${{ matrix.python-version }}
32 | fail_ci_if_error: false
33 | token: ${{ secrets.CODECOV_TOKEN }}
34 | verbose: true
35 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 |
2 | # Created by https://www.toptal.com/developers/gitignore/api/python,jupyternotebooks
3 | # Edit at https://www.toptal.com/developers/gitignore?templates=python,jupyternotebooks
4 |
5 | ### JupyterNotebooks ###
6 | # gitignore template for Jupyter Notebooks
7 | # website: http://jupyter.org/
8 |
9 | .ipynb_checkpoints
10 | */.ipynb_checkpoints/*
11 | .virtual_documents
12 | */.virtual_documents/*
13 |
14 | # IPython
15 | profile_default/
16 | ipython_config.py
17 |
18 | # Remove previous ipynb_checkpoints
19 | # git rm -r .ipynb_checkpoints/
20 |
21 | ### Python ###
22 | # Byte-compiled / optimized / DLL files
23 | __pycache__/
24 | *.py[cod]
25 | *$py.class
26 |
27 | # C extensions
28 | *.so
29 |
30 | # Distribution / packaging
31 | .Python
32 | build/
33 | develop-eggs/
34 | dist/
35 | downloads/
36 | eggs/
37 | .eggs/
38 | lib/
39 | lib64/
40 | parts/
41 | sdist/
42 | var/
43 | wheels/
44 | share/python-wheels/
45 | *.egg-info/
46 | .installed.cfg
47 | *.egg
48 | MANIFEST
49 |
50 | # PyInstaller
51 | # Usually these files are written by a python script from a template
52 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
53 | *.manifest
54 | *.spec
55 |
56 | # Installer logs
57 | pip-log.txt
58 | pip-delete-this-directory.txt
59 |
60 | # Unit test / coverage reports
61 | htmlcov/
62 | .tox/
63 | .nox/
64 | .coverage
65 | .coverage.*
66 | .cache
67 | nosetests.xml
68 | coverage.xml
69 | *.cover
70 | *.py,cover
71 | .hypothesis/
72 | .pytest_cache/
73 | cover/
74 |
75 | # default test image folder
76 | test_images/
77 |
78 | # Translations
79 | *.mo
80 | *.pot
81 |
82 | # Django stuff:
83 | *.log
84 | local_settings.py
85 | db.sqlite3
86 | db.sqlite3-journal
87 |
88 | # Flask stuff:
89 | instance/
90 | .webassets-cache
91 |
92 | # Scrapy stuff:
93 | .scrapy
94 |
95 | # Sphinx documentation
96 | docs/_build/
97 | docs/build
98 | docs/source/api/**/generated
99 | docs/source/gallery/_images
100 | docs/source/gallery/_scripts
101 | docs/source/gallery/*.md
102 | docs/source/gallery/backreferences.json
103 | docs/jupyter_execute
104 | docs/source/api/backend/*.rst
105 | !docs/source/api/backend/*.part.rst
106 | !docs/source/api/backend/index.rst
107 |
108 | # PyBuilder
109 | .pybuilder/
110 | target/
111 |
112 | # Jupyter Notebook
113 |
114 | # IPython
115 |
116 | # pyenv
117 | # For a library or package, you might want to ignore these files since the code is
118 | # intended to run in multiple environments; otherwise, check them in:
119 | # .python-version
120 |
121 | # pipenv
122 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
123 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
124 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
125 | # install all needed dependencies.
126 | #Pipfile.lock
127 |
128 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
129 | __pypackages__/
130 |
131 | # Celery stuff
132 | celerybeat-schedule
133 | celerybeat.pid
134 |
135 | # SageMath parsed files
136 | *.sage.py
137 |
138 | # Environments
139 | .env
140 | .venv
141 | env/
142 | venv/
143 | ENV/
144 | env.bak/
145 | venv.bak/
146 |
147 | # Spyder project settings
148 | .spyderproject
149 | .spyproject
150 |
151 | # Rope project settings
152 | .ropeproject
153 |
154 | # mkdocs documentation
155 | /site
156 |
157 | # mypy
158 | .mypy_cache/
159 | .dmypy.json
160 | dmypy.json
161 |
162 | # Pyre type checker
163 | .pyre/
164 |
165 | # pytype static type analyzer
166 | .pytype/
167 |
168 | # Cython debug symbols
169 | cython_debug/
170 |
171 | # macos
172 | .DS_Store
173 |
174 | # End of https://www.toptal.com/developers/gitignore/api/python,jupyternotebooks
175 |
176 | # pixi environments
177 | .pixi
178 | *.egg-info
179 | ## TO REMOVE ONCE WE HAVE THE MONOREPO
180 | pixi.lock
181 | pixi.toml
182 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/pre-commit/pre-commit-hooks
3 | rev: v4.3.0
4 | hooks:
5 | - id: check-added-large-files
6 | args: ['--maxkb=1500']
7 | - id: check-merge-conflict
8 |
9 | - repo: https://github.com/PyCQA/isort
10 | rev: 5.12.0
11 | hooks:
12 | - id: isort
13 | exclude: ^src/arviz_base/example_data/
14 |
15 | - repo: https://github.com/psf/black
16 | rev: 23.3.0
17 | hooks:
18 | - id: black
19 | exclude: ^docs/source/gallery/
20 |
21 | - repo: https://github.com/pycqa/pydocstyle
22 | rev: 6.3.0
23 | hooks:
24 | - id: pydocstyle
25 | additional_dependencies: [tomli]
26 | files: ^src/arviz_plots/.+\.py$
27 |
28 | - repo: local
29 | hooks:
30 | - id: pylint
31 | name: pylint
32 | entry: pylint
33 | language: system
34 | types: [python]
35 | args:
36 | [
37 | "-rn", # Only display messages
38 | "-sn", # Don't display the score
39 | ]
40 | exclude: ^docs/source/gallery/
41 |
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | # Read the Docs configuration file
2 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
3 | version: 2
4 |
5 | build:
6 | os: ubuntu-24.04
7 | tools:
8 | python: "3.12"
9 |
10 | sphinx:
11 | fail_on_warning: True
12 | configuration: docs/source/conf.py
13 |
14 |
15 | python:
16 | install:
17 | - method: pip
18 | path: .
19 | extra_requirements:
20 | - doc
21 | - matplotlib
22 | - bokeh
23 | - plotly
24 |
--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
1 |
2 | # [v0.5.0](https://github.com/arviz-devs/arviz-plots/releases/tag/v0.5.0) - 2025-03-21
3 |
4 | ## What's Changed
5 | * Randomized pit by [@aloctavodia](https://github.com/aloctavodia) in [#167](https://github.com/arviz-devs/arviz-plots/pull/167)
6 | * Add missing visuals to docs by [@aloctavodia](https://github.com/aloctavodia) in [#168](https://github.com/arviz-devs/arviz-plots/pull/168)
7 | * Make ecdf_line a step line by [@aloctavodia](https://github.com/aloctavodia) in [#169](https://github.com/arviz-devs/arviz-plots/pull/169)
8 | * Properly wrap columns and reduce duplicated code by [@aloctavodia](https://github.com/aloctavodia) in [#170](https://github.com/arviz-devs/arviz-plots/pull/170)
9 | * Added plot_bf() function for bayes_factor in arviz-plots by [@PiyushPanwarFST](https://github.com/PiyushPanwarFST) in [#158](https://github.com/arviz-devs/arviz-plots/pull/158)
10 | * Update utils.py by [@aloctavodia](https://github.com/aloctavodia) in [#171](https://github.com/arviz-devs/arviz-plots/pull/171)
11 | * Add plot_rank by [@aloctavodia](https://github.com/aloctavodia) in [#172](https://github.com/arviz-devs/arviz-plots/pull/172)
12 | * Add plot ecdf pit plot by [@aloctavodia](https://github.com/aloctavodia) in [#173](https://github.com/arviz-devs/arviz-plots/pull/173)
13 | * fix regression bug plot_psense_quantities by [@aloctavodia](https://github.com/aloctavodia) in [#175](https://github.com/arviz-devs/arviz-plots/pull/175)
14 | * Improve Data Generation & Plot Tests: by [@PiyushPanwarFST](https://github.com/PiyushPanwarFST) in [#177](https://github.com/arviz-devs/arviz-plots/pull/177)
15 | * Adds square root scale for bokeh and plotly by [@The-Broken-Keyboard](https://github.com/The-Broken-Keyboard) in [#178](https://github.com/arviz-devs/arviz-plots/pull/178)
16 | * Update plot_compare by [@aloctavodia](https://github.com/aloctavodia) in [#181](https://github.com/arviz-devs/arviz-plots/pull/181)
17 | * plot_converge_dist: Add grouped argument by [@aloctavodia](https://github.com/aloctavodia) in [#182](https://github.com/arviz-devs/arviz-plots/pull/182)
18 | * adds sqrt scale for yaxis in plotly by [@The-Broken-Keyboard](https://github.com/The-Broken-Keyboard) in [#183](https://github.com/arviz-devs/arviz-plots/pull/183)
19 | * Add coverage argument to pit plots by [@aloctavodia](https://github.com/aloctavodia) in [#185](https://github.com/arviz-devs/arviz-plots/pull/185)
20 |
21 |
22 | ## New Contributors
23 | * [@github-actions](https://github.com/github-actions) made their first contribution in [#165](https://github.com/arviz-devs/arviz-plots/pull/165)
24 | * [@PiyushPanwarFST](https://github.com/PiyushPanwarFST) made their first contribution in [#158](https://github.com/arviz-devs/arviz-plots/pull/158)
25 |
26 | **Full Changelog**: https://github.com/arviz-devs/arviz-plots/compare/v0.4.0...v0.5.0
27 |
28 | [Changes][v0.5.0]
29 |
30 |
31 |
32 | # [v0.4.0](https://github.com/arviz-devs/arviz-plots/releases/tag/v0.4.0) - 2025-03-05
33 |
34 | ## What's Changed
35 | * move out new_ds to arviz-stats by [@aloctavodia](https://github.com/aloctavodia) in [#102](https://github.com/arviz-devs/arviz-plots/pull/102)
36 | * update version, dependencies and CI by [@OriolAbril](https://github.com/OriolAbril) in [#110](https://github.com/arviz-devs/arviz-plots/pull/110)
37 | * Use DataTree class from xarray by [@OriolAbril](https://github.com/OriolAbril) in [#111](https://github.com/arviz-devs/arviz-plots/pull/111)
38 | * Update pyproject.toml by [@OriolAbril](https://github.com/OriolAbril) in [#113](https://github.com/arviz-devs/arviz-plots/pull/113)
39 | * Add energy plot by [@aloctavodia](https://github.com/aloctavodia) in [#108](https://github.com/arviz-devs/arviz-plots/pull/108)
40 | * Add plot for distribution of convergence diagnostics by [@aloctavodia](https://github.com/aloctavodia) in [#105](https://github.com/arviz-devs/arviz-plots/pull/105)
41 | * Add separated prior and likelihood groups by [@aloctavodia](https://github.com/aloctavodia) in [#117](https://github.com/arviz-devs/arviz-plots/pull/117)
42 | * Add psense_quantities plot by [@aloctavodia](https://github.com/aloctavodia) in [#119](https://github.com/arviz-devs/arviz-plots/pull/119)
43 | * Rename arviz-clean to arviz-variat by [@aloctavodia](https://github.com/aloctavodia) in [#120](https://github.com/arviz-devs/arviz-plots/pull/120)
44 | * add cetrino and vibrant styles for plotly by [@aloctavodia](https://github.com/aloctavodia) in [#121](https://github.com/arviz-devs/arviz-plots/pull/121)
45 | * psense: fix facetting and add xlabel by [@aloctavodia](https://github.com/aloctavodia) in [#123](https://github.com/arviz-devs/arviz-plots/pull/123)
46 | * Add summary dictionary arguments by [@aloctavodia](https://github.com/aloctavodia) in [#125](https://github.com/arviz-devs/arviz-plots/pull/125)
47 | * plotly: change format of title update in backend by [@The-Broken-Keyboard](https://github.com/The-Broken-Keyboard) in [#124](https://github.com/arviz-devs/arviz-plots/pull/124)
48 | * Update glossary.md by [@aloctavodia](https://github.com/aloctavodia) in [#126](https://github.com/arviz-devs/arviz-plots/pull/126)
49 | * Add PAV-adjusted calibration plot by [@aloctavodia](https://github.com/aloctavodia) in [#127](https://github.com/arviz-devs/arviz-plots/pull/127)
50 | * Upper bound plotly by [@aloctavodia](https://github.com/aloctavodia) in [#128](https://github.com/arviz-devs/arviz-plots/pull/128)
51 | * Use isotonic function that works with datatrees by [@aloctavodia](https://github.com/aloctavodia) in [#131](https://github.com/arviz-devs/arviz-plots/pull/131)
52 | * plot_pava_calibrarion: Add reference, fix xlabel by [@aloctavodia](https://github.com/aloctavodia) in [#132](https://github.com/arviz-devs/arviz-plots/pull/132)
53 | * Fix bug when setting some plot_kwargs to false by [@aloctavodia](https://github.com/aloctavodia) in [#134](https://github.com/arviz-devs/arviz-plots/pull/134)
54 | * Add citations by [@aloctavodia](https://github.com/aloctavodia) in [#135](https://github.com/arviz-devs/arviz-plots/pull/135)
55 | * use <6 version of plotly for documentation and use latest for other purposes by [@The-Broken-Keyboard](https://github.com/The-Broken-Keyboard) in [#136](https://github.com/arviz-devs/arviz-plots/pull/136)
56 | * fix see also pava gallery by [@aloctavodia](https://github.com/aloctavodia) in [#137](https://github.com/arviz-devs/arviz-plots/pull/137)
57 | * Add plot_ppc_dist by [@aloctavodia](https://github.com/aloctavodia) in [#138](https://github.com/arviz-devs/arviz-plots/pull/138)
58 | * Add warning message for discrete data by [@aloctavodia](https://github.com/aloctavodia) in [#139](https://github.com/arviz-devs/arviz-plots/pull/139)
59 | * rename plot_pava and minor fixes by [@aloctavodia](https://github.com/aloctavodia) in [#140](https://github.com/arviz-devs/arviz-plots/pull/140)
60 | * Plotly: Fix excesive margins by [@aloctavodia](https://github.com/aloctavodia) in [#141](https://github.com/arviz-devs/arviz-plots/pull/141)
61 | * add arviz-style for bokeh by [@aloctavodia](https://github.com/aloctavodia) in [#122](https://github.com/arviz-devs/arviz-plots/pull/122)
62 | * Add Plot ppc rootogram by [@aloctavodia](https://github.com/aloctavodia) in [#142](https://github.com/arviz-devs/arviz-plots/pull/142)
63 | * plot_ppc_rootogram: fix examples by [@aloctavodia](https://github.com/aloctavodia) in [#144](https://github.com/arviz-devs/arviz-plots/pull/144)
64 | * Reorganize categories in the gallery by [@aloctavodia](https://github.com/aloctavodia) in [#145](https://github.com/arviz-devs/arviz-plots/pull/145)
65 | * remove plots from titles by [@aloctavodia](https://github.com/aloctavodia) in [#146](https://github.com/arviz-devs/arviz-plots/pull/146)
66 | * Consistence use of data_pairs, remove default markers from pava by [@aloctavodia](https://github.com/aloctavodia) in [#152](https://github.com/arviz-devs/arviz-plots/pull/152)
67 | * added functionality of step histogram for all three backends by [@The-Broken-Keyboard](https://github.com/The-Broken-Keyboard) in [#147](https://github.com/arviz-devs/arviz-plots/pull/147)
68 | * Use continuous outcome for plot_ppc_dist example by [@aloctavodia](https://github.com/aloctavodia) in [#154](https://github.com/arviz-devs/arviz-plots/pull/154)
69 | * add grid visual by [@aloctavodia](https://github.com/aloctavodia) in [#155](https://github.com/arviz-devs/arviz-plots/pull/155)
70 | * Add plot_ppc_pit by [@aloctavodia](https://github.com/aloctavodia) in [#159](https://github.com/arviz-devs/arviz-plots/pull/159)
71 | * plot_ppc_pava: Change default xlabel by [@aloctavodia](https://github.com/aloctavodia) in [#161](https://github.com/arviz-devs/arviz-plots/pull/161)
72 |
73 | ## New Contributors
74 | * [@The-Broken-Keyboard](https://github.com/The-Broken-Keyboard) made their first contribution in [#124](https://github.com/arviz-devs/arviz-plots/pull/124)
75 |
76 | **Full Changelog**: https://github.com/arviz-devs/arviz-plots/compare/v0.3.0...v0.4.0
77 |
78 | [Changes][v0.4.0]
79 |
80 |
81 | [v0.5.0]: https://github.com/arviz-devs/arviz-plots/compare/v0.4.0...v0.5.0
82 | [v0.4.0]: https://github.com/arviz-devs/arviz-plots/tree/v0.4.0
83 |
84 |
85 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # ArviZ Community Code of Conduct
2 |
3 | ArviZ adopts the NumFOCUS Code of Conduct directly. In other words, we
4 | expect our community to treat others with kindness and understanding.
5 |
6 |
7 | # THE SHORT VERSION
8 | Be kind to others. Do not insult or put down others.
9 | Behave professionally. Remember that harassment and sexist, racist,
10 | or exclusionary jokes are not appropriate.
11 |
12 | All communication should be appropriate for a professional audience
13 | including people of many different backgrounds. Sexual language and
14 | imagery are not appropriate.
15 |
16 | ArviZ is dedicated to providing a harassment-free community for everyone,
17 | regardless of gender, sexual orientation, gender identity and
18 | expression, disability, physical appearance, body size, race,
19 | or religion. We do not tolerate harassment of community members
20 | in any form.
21 |
22 | Thank you for helping make this a welcoming, friendly community for all.
23 |
24 |
25 | # How to Submit a Report
26 | If you feel that there has been a Code of Conduct violation, an anonymous
27 | reporting form is available.
28 | **If you feel your safety is in jeopardy or the situation is an
29 | emergency, we urge you to contact local law enforcement before making
30 | a report. (In the U.S., dial 911.)**
31 |
32 | We are committed to promptly addressing any reported issues.
33 | If you have experienced or witnessed behavior that violates this
34 | Code of Conduct, please complete the form below to
35 | make a report.
36 |
37 | **REPORTING FORM:** https://numfocus.typeform.com/to/ynjGdT
38 |
39 | Reports are sent to the NumFOCUS Code of Conduct Enforcement Team
40 | (see below).
41 |
42 | You can view the Privacy Policy and Terms of Service for Typeform on Typeform's website.
43 | The NumFOCUS Privacy Policy is here:
44 | https://www.numfocus.org/privacy-policy
45 |
46 |
47 | # Full Code of Conduct
48 | The full text of the NumFOCUS/ArviZ Code of Conduct can be found on
49 | NumFOCUS's website
50 | https://numfocus.org/code-of-conduct
51 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing guidelines
2 |
3 | ## Before contributing
4 |
5 | Welcome to arviz-plots! Before contributing to the project,
6 | make sure that you **read our code of conduct** (`CODE_OF_CONDUCT.md`).
7 |
8 | ## Contributing code
9 |
10 | 1. Set up a Python development environment
11 | (advice: use [venv](https://docs.python.org/3/library/venv.html),
12 | [virtualenv](https://virtualenv.pypa.io/), or [miniconda](https://docs.conda.io/en/latest/miniconda.html))
13 | 2. Install tox: `python -m pip install tox`
14 | 3. Clone the repository
15 | 4. Start a new branch off main: `git switch -c new-branch main`
16 | 5. Make your code changes
17 | 6. Check that your code follows the style guidelines of the project: `tox -e check`
18 | 7. (optional) Build the documentation: `tox -e docs`
19 | 8. (optional) Run the tests: `tox -e py311`
20 | (change the version number according to the Python you are using)
21 | 9. Commit, push, and open a pull request!
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # arviz-plots
2 |
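3 | [![Run tests](https://github.com/arviz-devs/arviz-plots/actions/workflows/test.yml/badge.svg)](https://github.com/arviz-devs/arviz-plots/actions/workflows/test.yml)
4 | [![codecov](https://codecov.io/gh/arviz-devs/arviz-plots/graph/badge.svg)](https://codecov.io/gh/arviz-devs/arviz-plots)
5 | [![Powered by NumFOCUS](https://img.shields.io/badge/powered%20by-NumFOCUS-orange.svg?style=flat&colorA=E1523D&colorB=007D8A)](https://numfocus.org)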
6 |
7 | ArviZ plotting elements and static, batteries-included plots.
8 |
9 | We are currently working on splitting ArviZ into independent modules.
10 | See https://github.com/arviz-devs/arviz/issues/2088 for more details.
11 |
--------------------------------------------------------------------------------
/References.md:
--------------------------------------------------------------------------------
1 | # References
2 |
3 | This is a list of references for the methods implemented in ArviZ (including base/stats/plots).
4 | The references are organized by method or diagnostic. The format is the one used for docstrings.
5 | We could have references in BibTeX format, and use https://sphinxcontrib-bibtex.readthedocs.io/en/latest/
6 | to include them in the docstring of each function and in a single webpage.
7 | See https://github.com/arviz-devs/arviz-stats/issues/56
8 |
9 |
10 | ## Psense
11 |
12 | References
13 | ----------
14 | .. [1] Kallioinen et al. *Detecting and diagnosing prior and likelihood sensitivity with
15 | power-scaling*, Stat Comput 34(57) (2024), https://doi.org/10.1007/s11222-023-10366-5
16 |
17 |
18 | ## LOO
19 |
20 | References
21 | ----------
22 | .. [1] Vehtari et al. *Practical Bayesian model evaluation using leave-one-out cross-validation
23 | and WAIC*. Statistics and Computing. 27(5) (2017). https://doi.org/10.1007/s11222-016-9696-4.
24 | arXiv preprint https://arxiv.org/abs/1507.04544.
25 |
26 | .. [2] Yao et al. *Using stacking to average Bayesian predictive distributions*
27 | Bayesian Analysis, 13, 3 (2018). https://doi.org/10.1214/17-BA1091
28 | arXiv preprint https://arxiv.org/abs/1704.02030.
29 |
30 | .. [3] Vehtari et al. *Pareto Smoothed Importance Sampling*.
31 | Journal of Machine Learning Research, 25(72) (2024). https://jmlr.org/papers/v25/19-556.html.
32 | arXiv preprint https://arxiv.org/abs/1507.02646
33 |
34 | .. [4] Magnusson et al. *Bayesian Leave-One-Out Cross-Validation for Large Data.*
35 | Proceedings of the 36th International Conference on Machine Learning, 97 (2019)
36 | https://proceedings.mlr.press/v97/magnusson19a.html
37 | arXiv preprint https://arxiv.org/abs/1904.10679
38 |
39 |
40 | ## R-hat, ESS, MCSE
41 |
42 | References
43 | ----------
44 | .. [1] Vehtari et al. *Rank-normalization, folding, and localization: An improved Rhat for
45 | assessing convergence of MCMC*. Bayesian Analysis. 16(2) (2021)
46 | https://doi.org/10.1214/20-BA1221. arXiv preprint https://arxiv.org/abs/1903.08008
47 |
48 |
49 | ## PAV-adjusted calibration plot
50 |
51 | References
52 | ----------
53 | .. [1] Dimitriadis et al. *Stable reliability diagrams for probabilistic classifiers*.
54 | PNAS, 118(8) (2021). https://doi.org/10.1073/pnas.2016191118
55 |
56 | .. [2] Säilynoja et al. *Recommendations for visual predictive checks in Bayesian workflow*.
57 | (2025) arXiv preprint https://arxiv.org/abs/2503.01509
58 |
59 |
60 | ## Energy plot and divergences
61 |
62 | References
63 | ----------
64 | .. [1] Betancourt. *Diagnosing Suboptimal Cotangent Disintegrations in
65 | Hamiltonian Monte Carlo*. (2016) https://arxiv.org/abs/1604.00695
66 |
67 | ## Rootograms
68 |
69 | References
70 | ----------
71 | .. [1] Kleiber et al. *Visualizing Count Data Regressions Using Rootograms*.
72 | The American Statistician, 70(3). (2016) https://doi.org/10.1080/00031305.2016.1173590
73 |
74 | .. [2] Säilynoja et al. *Recommendations for visual predictive checks in Bayesian workflow*.
75 | (2025) arXiv preprint https://arxiv.org/abs/2503.01509
76 |
77 | ## ECDF-pit
78 |
79 | References
80 | ----------
81 | .. [1] Säilynoja et al. *Graphical test for discrete uniformity and
82 | its applications in goodness-of-fit evaluation and multiple sample comparison*.
83 | Statistics and Computing 32(32). (2022) https://doi.org/10.1007/s11222-022-10090-6
84 |
85 | ## Bayesian R2
86 |
87 | References
88 | ----------
89 | .. [1] Gelman et al. *R-squared for Bayesian regression models*.
90 | The American Statistician, 73(3). (2019) https://doi.org/10.1080/00031305.2018.1549100
91 | preprint http://www.stat.columbia.edu/~gelman/research/unpublished/bayes_R2_v3.pdf
92 |
93 |
--------------------------------------------------------------------------------
/docs/source/_static/ArviZ.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arviz-devs/arviz-plots/1856806a25750e264c42403ec6eb4c45fcc61c78/docs/source/_static/ArviZ.png
--------------------------------------------------------------------------------
/docs/source/_static/ArviZ_white.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arviz-devs/arviz-plots/1856806a25750e264c42403ec6eb4c45fcc61c78/docs/source/_static/ArviZ_white.png
--------------------------------------------------------------------------------
/docs/source/_static/bokeh-logo-dark.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/docs/source/_static/bokeh-logo-light.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/docs/source/_static/custom.css:
--------------------------------------------------------------------------------
1 | /* -------------------- THEME OVERRIDES -------------------- */
2 | html[data-theme="light"] {
3 | --pst-color-primary: rgb(11 117 145);
4 | --pst-color-secondary: rgb(238 144 64);
5 | --pst-color-plot-background: rgb(255, 255, 255);
6 | }
7 |
8 | html[data-theme="dark"] {
9 | --pst-color-primary: rgb(0 192 191);
10 | --pst-color-secondary: rgb(238 144 64);
11 | --pst-color-plot-background: rgb(218, 219, 220);
12 | }
13 |
14 | /* Override for example gallery - remove border around card */
15 | .bd-content div.sd-card.example-gallery {
16 | border: none;
17 | }
18 |
19 | /* -------------------- EXAMPLE GALLERY + (homepage) -------------------- */
20 |
21 | /* Homepage - grid layout */
22 | .home-flex-grid {
23 | display: flex;
24 | flex-flow: row wrap;
25 | justify-content: center;
26 | gap: 10px;
27 | padding: 20px 0px 40px;
28 | }
29 |
30 | /* Homepage + Example Gallery Body - Set dimensions */
31 | .home-img-plot,
32 | .bd-content div.sd-card.example-gallery .sd-card-body,
33 | .home-img-plot-overlay,
34 | .example-img-plot-overlay {
35 | display: flex;
36 | justify-content: center;
37 | align-items: center;
38 | overflow: hidden;
39 | padding: 10px;
40 | }
41 | .home-img-plot,
42 | .home-img-plot-overlay {
43 | width: 235px;
44 | height: 130px;
45 | }
46 | .bd-content div.sd-card.example-gallery .sd-card-body,
47 | .example-img-plot-overlay {
48 | width: 100%;
49 | height: 150px;
50 | }
51 | .home-img-plot img,
52 | .bd-content div.sd-card.example-gallery .sd-card-body img {
53 | /* Images keep aspect ratio and fit in container */
54 | /* To make images stretch/fill container, change to min-width */
55 | max-width: 100%;
56 | max-height: 100%;
57 | }
58 |
59 | /* Homepage + Example Gallery Body - Set color and hover */
60 | .home-img-plot.img-thumbnail,
61 | .bd-content div.sd-card.example-gallery .sd-card-body {
62 | background-color: var(--pst-color-plot-background); /* Same as img-thumbnail from pydata css, adjusted for dark mode */
63 | }
64 | .home-img-plot-overlay,
65 | .example-img-plot-overlay,
66 | .bd-content div.sd-card.example-gallery .sd-card-body {
67 | border: 1px solid #dee2e6; /* Same as img-thumbnail from pydata css */
68 | border-radius: 0.25rem; /* Same as img-thumbnail from pydata css */
69 | }
70 | .home-img-plot-overlay,
71 | .example-img-plot-overlay,
72 | .example-img-plot-overlay p.sd-card-text {
73 | background: var(--pst-color-primary);
74 | position: absolute;
75 | color: var(--pst-color-background);
76 | opacity: 0;
77 | transition: .2s ease;
78 | text-align: center;
79 | padding: 10px;
80 | z-index: 998; /* Make sure overlay is above image...this is here to handle dark mode */
81 | }
82 | .home-img-plot-overlay:hover,
83 | .bd-content div.sd-card.example-gallery:hover .example-img-plot-overlay,
84 | .example-img-plot-overlay p.sd-card-text {
85 | opacity: 90%;
86 | }
87 |
88 | /* Example Gallery Body - Set syntax highlighting for code on hover */
89 | .example-img-plot-overlay .sd-card-text code.code {
90 | background-color: var(--pst-color-background);
91 | }
92 | .example-img-plot-overlay .sd-card-text .code span.pre {
93 | color: var(--pst-color-primary);
94 | font-weight: 700;
95 | }
96 |
97 | /* Example Gallery Footer - Plot titles go here */
98 | .example-gallery .sd-card-footer {
99 | height: 40px;
100 | padding: 5px;
101 | border-top: none !important;
102 | }
103 | .example-gallery .sd-card-footer p.sd-card-text {
104 | color: var(--pst-color-text-muted);
105 | font-size: 1rem; /* This is font size for plot titles (below the figure) */
106 | font-weight: 700;
107 | }
108 | .sd-card.example-gallery:hover .sd-card-footer p.sd-card-text {
109 | color: var(--pst-color-primary); /* Change text color on hover over card */
110 | }
111 |
112 | /* Overlapping */
113 | .example-gallery a.sd-stretched-link.reference {
114 | z-index: 999; /* Keep the card link clickable above the overlay, which uses z-index: 998 */
115 | }
116 |
117 |
--------------------------------------------------------------------------------
/docs/source/_static/favicon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arviz-devs/arviz-plots/1856806a25750e264c42403ec6eb4c45fcc61c78/docs/source/_static/favicon.ico
--------------------------------------------------------------------------------
/docs/source/_static/none-logo-light.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arviz-devs/arviz-plots/1856806a25750e264c42403ec6eb4c45fcc61c78/docs/source/_static/none-logo-light.png
--------------------------------------------------------------------------------
/docs/source/_static/plotly-logo-dark.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arviz-devs/arviz-plots/1856806a25750e264c42403ec6eb4c45fcc61c78/docs/source/_static/plotly-logo-dark.png
--------------------------------------------------------------------------------
/docs/source/_static/plotly-logo-light.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arviz-devs/arviz-plots/1856806a25750e264c42403ec6eb4c45fcc61c78/docs/source/_static/plotly-logo-light.png
--------------------------------------------------------------------------------
/docs/source/_templates/name.html:
--------------------------------------------------------------------------------
1 |
6 |
--------------------------------------------------------------------------------
/docs/source/api/backend/bokeh.part.rst:
--------------------------------------------------------------------------------
1 | =============
2 | Bokeh backend
3 | =============
4 |
5 | .. automodule:: arviz_plots.backend.bokeh
6 |
--------------------------------------------------------------------------------
/docs/source/api/backend/index.rst:
--------------------------------------------------------------------------------
1 | ==============================
2 | Interface to plotting backends
3 | ==============================
4 |
5 | ------------------
6 | Available backends
7 | ------------------
8 |
9 | .. grid:: 1 1 2 2
10 |
11 | .. grid-item-card::
12 | :link: matplotlib
13 | :link-type: doc
14 | :link-alt: Matplotlib
15 | :img-background: ../../_static/matplotlib-logo-light.svg
16 | :class-img-bottom: dark-light
17 |
18 | .. grid-item-card::
19 | :link: bokeh
20 | :link-type: doc
21 | :link-alt: Bokeh
22 | :img-background: ../../_static/bokeh-logo-light.svg
23 | :class-img-bottom: dark-light
24 |
25 | .. grid-item-card::
26 | :link: plotly
27 | :link-type: doc
28 | :link-alt: Plotly
29 | :img-background: ../../_static/plotly-logo-light.png
30 | :class-img-bottom: dark-light
31 |
32 | .. grid-item-card::
33 | :link: none
34 | :link-type: doc
35 | :link-alt: None (no plotting, only data processing)
36 | :img-background: ../../_static/none-logo-light.png
37 | :class-img-bottom: dark-light
38 |
39 |
40 | .. toctree::
41 | :maxdepth: 1
42 | :hidden:
43 |
44 | Matplotlib
45 | Bokeh
46 | Plotly
47 | None (only processing, no plotting)
48 |
49 | ---------------------------
50 | Common interface definition
51 | ---------------------------
52 |
53 | .. automodule:: arviz_plots.backend
54 |
55 | .. dropdown:: Keyword arguments
56 | :name: backend_interface_arguments
57 | :open:
58 |
59 | The argument names are defined here to have a comprehensive list of all possibilities.
60 | If relevant, a keyword argument present here should be present in the function,
61 | and converted in each backend to its corresponding argument in that backend.
62 |
63 | This set of arguments doesn't aim to be complete, only to cover basic properties
64 | so the plotting functions can work on multiple backends without duplication.
65 | Advanced customization will be backend specific through ``**kwargs``.
66 |
67 | target
68 | This module is designed mainly in a functional way. Thus, all functions
69 | should take a ``target`` argument which indicates the object the function
70 | should be applied to.
71 |
72 | color
73 | Color of the visual element. Should also be present whenever ``facecolor``
74 | and ``edgecolor`` are present, setting the default value for both.
75 |
76 | facecolor
77 | Color for filling the visual element.
78 |
79 | edgecolor
80 | Color for the edges of the visual element.
81 |
82 | alpha
83 | Transparency of the visual element.
84 |
85 | width
86 | Width of the visual element itself or of its edges, whichever applies.
87 |
88 | size
89 | Size of the visual element.
90 |
91 | linestyle
92 | Style of the line plotted.
93 |
94 | marker
95 | Marker to be added to the plot.
96 |
97 | vertical_align
98 | Vertical alignment between the visual element and the data coordinates provided.
99 |
100 | horizontal_align
101 | Horizontal alignment between the visual element and the data coordinates provided.
102 |
103 | axis
104 | Data axis (x, y or both) on which to apply the function.
105 |
--------------------------------------------------------------------------------
/docs/source/api/backend/interface.template.rst:
--------------------------------------------------------------------------------
1 | Object creation and I/O
2 | .......................
3 |
4 | .. autosummary::
5 | :toctree: generated/
6 |
7 | create_plotting_grid
8 | show
9 |
10 | Geoms
11 | .....
12 |
13 | .. autosummary::
14 | :toctree: generated/
15 |
16 | line
17 | scatter
18 | text
19 |
20 | Plot appearance
21 | ................
22 |
23 | .. autosummary::
24 | :toctree: generated/
25 |
26 | title
27 | ylabel
28 | xlabel
29 | xticks
30 | yticks
31 | ticklabel_props
32 | remove_ticks
33 | remove_axis
34 |
35 | Legend
36 | ......
37 |
38 | .. autosummary::
39 | :toctree: generated/
40 |
41 | legend
42 |
--------------------------------------------------------------------------------
/docs/source/api/backend/matplotlib.part.rst:
--------------------------------------------------------------------------------
1 | ==================
2 | Matplotlib backend
3 | ==================
4 |
5 | .. automodule:: arviz_plots.backend.matplotlib
6 |
--------------------------------------------------------------------------------
/docs/source/api/backend/none.part.rst:
--------------------------------------------------------------------------------
1 | ============
2 | None backend
3 | ============
4 |
5 | .. automodule:: arviz_plots.backend.none
6 |
--------------------------------------------------------------------------------
/docs/source/api/backend/plotly.part.rst:
--------------------------------------------------------------------------------
1 | ==============
2 | Plotly backend
3 | ==============
4 |
5 | .. automodule:: arviz_plots.backend.plotly
6 |
--------------------------------------------------------------------------------
/docs/source/api/helpers.rst:
--------------------------------------------------------------------------------
1 | =========================
2 | Helper plotting functions
3 | =========================
4 |
5 | Helper plotting functions are available at the ``arviz_plots`` top level namespace and
6 | provide ways to customize plots by adding elements or modifying existing ones.
7 |
8 | An introductory guide to the ``plot_...`` functions is also available at :ref:`plots_intro`.
9 |
10 | .. currentmodule:: arviz_plots
11 |
12 | .. autosummary::
13 | :toctree: generated/
14 |
15 | add_bands
16 | add_lines
17 |
18 | Style
19 | ...............
20 |
21 | .. autosummary::
22 | :toctree: generated/
23 |
24 | style.available
25 | style.get
26 | style.use
27 |
28 |
--------------------------------------------------------------------------------
/docs/source/api/index.md:
--------------------------------------------------------------------------------
1 | # API reference
2 |
3 | ```{eval-rst}
4 | .. automodule:: arviz_plots
5 | ```
6 |
7 | ```{toctree}
8 | :maxdepth: 1
9 |
10 | plots
11 | helpers
12 | managers
13 | visuals
14 | backend/index
15 | ```
16 |
--------------------------------------------------------------------------------
/docs/source/api/managers.rst:
--------------------------------------------------------------------------------
1 | =============================================
2 | Managers for faceting and aesthetics mapping
3 | =============================================
4 | The classes in this module lie at the core of the library,
5 | and are consequently available at the ``arviz_plots`` top level namespace.
6 |
7 | They abstract all information regarding :term:`faceting` and :term:`aesthetic mapping`
8 | in our :term:`figure` to prevent duplication and ensure coherence between
9 | the different functions.
10 |
11 | .. currentmodule:: arviz_plots
12 |
13 | PlotCollection
14 | ==============
15 |
16 | Object creation
17 | ...............
18 |
19 | .. autosummary::
20 | :toctree: generated/
21 |
22 | PlotCollection
23 | PlotCollection.grid
24 | PlotCollection.wrap
25 |
26 | Plotting
27 | ........
28 |
29 | .. autosummary::
30 | :toctree: generated/
31 |
32 | PlotCollection.add_legend
33 | PlotCollection.map
34 |
35 | Attributes
36 | ..........
37 |
38 | .. autosummary::
39 | :toctree: generated/
40 |
41 | PlotCollection.aes
42 | PlotCollection.viz
43 | PlotCollection.aes_set
44 | PlotCollection.facet_dims
45 | PlotCollection.data
46 |
47 | Faceting and aesthetics mapping
48 | ................................
49 |
50 | .. autosummary::
51 | :toctree: generated/
52 |
53 | PlotCollection.generate_aes_dt
54 | PlotCollection.get_aes_as_dataset
55 | PlotCollection.get_aes_kwargs
56 | PlotCollection.update_aes
57 | PlotCollection.update_aes_from_dataset
58 |
59 | Other
60 | .....
61 |
62 | .. autosummary::
63 | :toctree: generated/
64 |
65 | PlotCollection.allocate_artist
66 | PlotCollection.get_viz
67 | PlotCollection.get_target
68 | PlotCollection.show
69 | PlotCollection.savefig
70 |
71 | PlotMatrix
72 | ==========
73 |
74 | Object creation
75 | ...............
76 |
77 | .. autosummary::
78 | :toctree: generated/
79 |
80 | PlotMatrix
81 |
82 | Plotting
83 | ........
84 |
85 | .. autosummary::
86 | :toctree: generated/
87 |
88 | PlotMatrix.map
89 | PlotMatrix.map_triangle
90 | PlotMatrix.map_lower
91 | PlotMatrix.map_upper
92 |
93 | Attributes
94 | ..........
95 |
96 | .. autosummary::
97 | :toctree: generated/
98 |
99 | PlotMatrix.aes
100 | PlotMatrix.viz
101 | PlotMatrix.aes_set
102 | PlotMatrix.facet_dims
103 | PlotMatrix.data
104 |
--------------------------------------------------------------------------------
/docs/source/api/plots.rst:
--------------------------------------------------------------------------------
1 | .. _plots_api:
2 |
3 | ========================
4 | Batteries-included plots
5 | ========================
6 |
7 | Batteries-included plotting functions are available at the ``arviz_plots``
8 | top level namespace and provide plug and play opinionated solutions
9 | to common tasks within the Bayesian workflow.
10 |
11 | Each of the entries below describes the behaviour of one function and all its arguments,
12 | and includes a handful of examples of its use.
13 | A complementary introduction and guide to ``plot_...`` functions is available at :ref:`plots_intro`.
14 |
15 | .. currentmodule:: arviz_plots
16 |
17 | .. autosummary::
18 | :toctree: generated/
19 |
20 | combine_plots
21 | plot_autocorr
22 | plot_bf
23 | plot_compare
24 | plot_convergence_dist
25 | plot_dist
26 | plot_energy
27 | plot_ecdf_pit
28 | plot_ess
29 | plot_ess_evolution
30 | plot_forest
31 | plot_loo_pit
32 | plot_pairs_focus
33 | plot_ppc_dist
34 | plot_ppc_pava
35 | plot_ppc_pit
36 | plot_ppc_rootogram
37 | plot_prior_posterior
38 | plot_psense_dist
39 | plot_psense_quantities
40 | plot_rank
41 | plot_rank_dist
42 | plot_ridge
43 | plot_trace
44 | plot_trace_dist
--------------------------------------------------------------------------------
/docs/source/api/visuals.rst:
--------------------------------------------------------------------------------
1 | ===============
2 | Visual elements
3 | ===============
4 |
5 | .. automodule:: arviz_plots.visuals
6 |
7 | Data plotting elements
8 | ----------------------
9 |
10 | .. autosummary::
11 | :toctree: generated/
12 |
13 | ci_line_y
14 | ecdf_line
15 | fill_between_y
16 | hline
17 | hist
18 | hspan
19 | line
20 | line_xy
21 | line_x
22 | scatter_xy
23 | scatter_x
24 | trace_rug
25 | vline
26 | vspan
27 |
28 |
29 | Data and axis annotating elements
30 | ---------------------------------
31 |
32 | .. autosummary::
33 | :toctree: generated/
34 |
35 | annotate_label
36 | annotate_xy
37 | labelled_title
38 | labelled_x
39 | labelled_y
40 | point_estimate_text
41 |
42 | Plot customization elements
43 | ---------------------------
44 |
45 | .. autosummary::
46 | :toctree: generated/
47 |
48 | remove_axis
49 | remove_ticks
50 | set_xticks
51 | ticklabel_props
52 | grid
53 |
--------------------------------------------------------------------------------
/docs/source/conf.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=redefined-builtin,invalid-name
2 | import os
3 | import sys
4 | from importlib.metadata import metadata
5 | from pathlib import Path
6 |
7 | # -- Project information
8 |
9 | _metadata = metadata("arviz-plots")
10 |
11 | project = _metadata["Name"]
12 | author = _metadata["Author-email"].split("<", 1)[0].strip()
13 | copyright = f"2022, {author}"
14 |
15 | version = _metadata["Version"]
16 | if os.environ.get("READTHEDOCS", False):
17 | rtd_version = os.environ.get("READTHEDOCS_VERSION", "")
18 | if "." not in rtd_version and rtd_version.lower() != "stable":
19 | version = "dev"
20 | else:
21 | branch_name = os.environ.get("BUILD_SOURCEBRANCHNAME", "")
22 | if branch_name == "main":
23 | version = "dev"
24 | release = version
25 |
26 |
27 | # -- General configuration
28 |
29 | sys.path.insert(0, os.path.abspath("../sphinxext"))
30 |
31 | templates_path = ["_templates"]
32 | exclude_patterns = [
33 | "Thumbs.db",
34 | ".DS_Store",
35 | ".ipynb_checkpoints",
36 | "**/*.template.rst",
37 | "**/*.part.rst",
38 | "**/*.part.md",
39 | ]
40 | skip_gallery = os.environ.get("ARVIZDOCS_NOGALLERY", False)
41 |
42 | extensions = [
43 | "sphinx.ext.intersphinx",
44 | "sphinx.ext.mathjax",
45 | "sphinx.ext.viewcode",
46 | "sphinx.ext.autosummary",
47 | "sphinx.ext.extlinks",
48 | "numpydoc",
49 | "myst_nb",
50 | "sphinx_copybutton",
51 | "sphinx_design",
52 | "jupyter_sphinx",
53 | "matplotlib.sphinxext.plot_directive",
54 | "bokeh.sphinxext.bokeh_plot",
55 | ]
56 | if skip_gallery:
57 | exclude_patterns.append("gallery/*")
58 | else:
59 | extensions.append("gallery_generator")
60 |
61 | suppress_warnings = ["mystnb.unknown_mime_type"]
62 |
63 | backend_modules = ("none", "matplotlib", "bokeh", "plotly")
64 | api_backend_dir = Path(__file__).parent.resolve() / "api" / "backend"
65 | with open(api_backend_dir / "interface.template.rst", "r", encoding="utf-8") as f:
66 | interface_template = f.read()
67 | for file in backend_modules:
68 | with open(api_backend_dir / f"{file}.part.rst", "r", encoding="utf-8") as f:
69 | intro = f.read()
70 | with open(api_backend_dir / f"{file}.rst", "w", encoding="utf-8") as f:
71 | f.write(f"{intro}\n\n{interface_template}")
72 |
73 | # The reST default role (used for this markup: `text`) to use for all documents.
74 | default_role = "autolink"
75 |
76 | # If true, '()' will be appended to :func: etc. cross-reference text.
77 | add_function_parentheses = False
78 |
79 | # -- Options for extensions
80 |
81 | plot_include_source = True
82 | plot_formats = [("png", 90)]
83 | plot_html_show_formats = False
84 | plot_html_show_source_link = False
85 |
86 | extlinks = {
87 | "issue": ("https://github.com/arviz-devs/arviz-plots/issues/%s", "GH#%s"),
88 | "pull": ("https://github.com/arviz-devs/arviz-plots/pull/%s", "PR#%s"),
89 | }
90 |
91 | copybutton_prompt_text = r">>> |\.\.\. |\$ |In \[\d*\]: | {2,5}\.\.\.: | {5,8}: "
92 | copybutton_prompt_is_regexp = True
93 |
94 | nb_execution_mode = "auto"
95 | nb_execution_excludepatterns = ["*.ipynb"]
96 | nb_kernel_rgx_aliases = {".*": "python3"}
97 | myst_enable_extensions = ["colon_fence", "deflist", "dollarmath", "amsmath", "linkify"]
98 |
99 | autosummary_generate = True
100 | autodoc_typehints = "none"
101 | autodoc_default_options = {
102 | "members": False,
103 | }
104 |
105 | numpydoc_show_class_members = False
106 | numpydoc_xref_param_type = True
107 | numpydoc_xref_ignore = {"of", "or", "optional", "scalar", "default"}
108 | singulars = ("int", "list", "dict", "float")
109 | numpydoc_xref_aliases = {
110 | "DataArray": ":class:`xarray.DataArray`",
111 | "Dataset": ":class:`xarray.Dataset`",
112 | "DataTree": ":class:`xarray.DataTree`",
113 | "mapping": ":term:`python:mapping`",
114 | "hashable": ":term:`python:hashable`",
115 | **{f"{singular}s": f":any:`{singular}s <{singular}>`" for singular in singulars},
116 | }
117 |
118 | intersphinx_mapping = {
119 | "arviz_org": ("https://www.arviz.org/en/latest/", None),
120 | "arviz_base": ("https://arviz-base.readthedocs.io/en/latest/", None),
121 | "arviz_stats": ("https://arviz-stats.readthedocs.io/en/latest/", None),
122 | "einstats": ("https://einstats.python.arviz.org/en/latest/", None),
123 | "numpy": ("https://numpy.org/doc/stable/", None),
124 | "python": ("https://docs.python.org/3/", None),
125 | "xarray": ("https://docs.xarray.dev/en/stable/", None),
126 | "matplotlib": ("https://matplotlib.org/stable/", None),
127 | "bokeh": ("https://docs.bokeh.org/en/latest", None),
128 | }
129 |
130 | # -- Options for HTML output
131 | html_theme = "sphinx_book_theme"
132 | html_context = {"default_mode": "light"}
133 | html_theme_options = {
134 | "logo": {
135 | "image_light": "_static/ArviZ.png",
136 | "image_dark": "_static/ArviZ_white.png",
137 | }
138 | }
139 | html_favicon = "_static/favicon.ico"
140 | html_static_path = ["_static"]
141 | html_css_files = ["custom.css"]
142 | html_sidebars = {
143 | "**": [
144 | "navbar-logo.html",
145 | "name.html",
146 | "icon-links.html",
147 | "search-button-field.html",
148 | "sbt-sidebar-nav.html",
149 | ]
150 | }
151 |
--------------------------------------------------------------------------------
/docs/source/contributing/docs.md:
--------------------------------------------------------------------------------
1 | # Documentation
2 |
3 | ## How to build the documentation locally
4 | Similarly to testing, there are also tox jobs that take care of building the right environment
5 | and running the required commands to build the documentation.
6 |
7 | In general the process should follow these three steps in this order:
8 |
9 | ```console
10 | tox -e cleandocs
11 | tox -e docs # or tox -e nogallerydocs
12 | tox -e viewdocs
13 | ```
14 |
15 | These commands will respectively:
16 |
17 | * Delete all intermediate sphinx files that were generated during the build process
18 | * Run `sphinx-build` command to parse and render the library documentation
19 | * Open the documentation homepage on the default browser with `python -m webbrowser`
20 |
21 | The only required step, however, is the middle one. In general, sphinx reuses the intermediate
22 | files of any source it detects hasn't been modified, so when iterating quickly it is recommended
23 | to skip the clean step in order to achieve faster builds. Moreover, if the documentation
24 | page is already open on the browser, there is no need for the viewdocs job because
25 | the documentation is always rendered on the same path; refreshing the page from the browser
26 | is enough.
27 |
28 | The example gallery requires processing the python scripts in order to execute each
29 | once per backend in order to generate the png or html+javascript preview.
30 | Therefore, it is the most time-consuming step of generating the documentation.
31 | As very often we'll work on the part of the docs not related to the example gallery,
32 | the command `tox -e nogallerydocs` will generate the documentation without the example gallery,
33 | which allows for much faster iteration when writing documentation.
34 | This also means for example the `minigallery` directive in the docstrings won't work,
35 | and sphinx will output warnings about it when using this option.
36 |
37 |
38 | ## How to add examples to the gallery
39 | Examples in the gallery are written in the form of python scripts.
40 | They are divided between multiple categories,
41 | with each category being a folder within `/docs/source/gallery/`.
42 | Therefore, anything matching this glob `/docs/source/gallery/**/*.py`
43 | will be rendered into the example gallery.
44 | To control the order in which examples appear in the gallery,
45 | all filenames should start with two digits, an underscore and then
46 | the unique name of the script.
47 |
48 | The script is divided in two parts, the first is a file level docstring,
49 | the second is the code example itself.
50 |
51 | The docstring part should contain a markdown top level title,
52 | a short description of the example and a seealso directive using MyST syntax.
53 |
54 | The code part should import arviz-plots as `azp`. It should then set `backend="none"`
55 | explicitly when calling the plotting functions and
56 | store the generated {class}`~arviz_plots.PlotCollection` as the `pc` variable
57 | so the example can finish with `pc.show()`.
58 |
59 | Here is an example that can be used as template:
60 |
61 | ```python
62 | """
63 | # Posterior ECDFs
64 |
65 | Faceted ECDF plots for 1D marginals of the distribution
66 |
67 | ---
68 |
69 | :::{seealso}
70 | API Documentation: {func}`~arviz_plots.plot_dist`
71 |
72 | EABM chapter on [Visualization of Random Variables with ArviZ](https://arviz-devs.github.io/EABM/Chapters/Distributions.html#distributions-in-arviz)
73 | :::
74 | """
75 | from arviz_base import load_arviz_data
76 |
77 | import arviz_plots as azp
78 |
79 | azp.style.use("arviz-variat")
80 |
81 | data = load_arviz_data("centered_eight")
82 | pc = azp.plot_dist(
83 | data,
84 | kind="ecdf",
85 | col_wrap=4,
86 | backend="none" # change to preferred backend
87 | )
88 | pc.show()
89 | ```
90 |
91 | ## About arviz-plots documentation
92 | Documentation for arviz-plots is written in both rST and MyST (which can be used from jupyter
93 | notebooks too) and rendered with Sphinx. Docstrings follow the numpydoc style guide.
94 |
95 | ### The gallery generator sphinxext
96 | We have a custom sphinx extension to generate the example gallery, located at
97 | `/docs/sphinxext/gallery_generator.py`.
98 |
99 | This sphinx extension reads the example scripts within `/docs/source/gallery`
100 | and takes care of the following tasks:
101 |
102 | 1. Process the script contents into a MyST source page with proper syntax, tabs, code block,
103 | and relevant links.
104 | 1. Execute the code for all plotting backends to generate the respective previews.
105 | In addition, when executing the matplotlib version, it is stored as a png to use as the
106 | miniature in the gallery page.
107 | 1. Generate the index page for the gallery, with the grid view
108 | 1. Generate a json with references of all the functions used in the different examples.
109 | This supports the `minigallery` directive that allows adding plot specific galleries
110 | in the examples section of the docstring.
111 |
--------------------------------------------------------------------------------
/docs/source/contributing/testing.md:
--------------------------------------------------------------------------------
1 | # Testing
2 |
3 | ## How to run the test suite
4 |
5 | To run the test suite `tox` should be installed.
6 |
7 | ### Run the whole test suite
8 | Tox creates an independent env where all testing dependencies are installed
9 | and then runs pytest to execute all tests in the `tests/` folder:
10 |
11 | ```console
12 | tox -e py311
13 | ```
14 |
15 | The `-e` flag selects the tox environment to run: a previously defined job that
16 | takes care of the steps above. The job name is "py" followed by your local python
17 | version without the decimal point.
18 |
19 | :::{note}
20 | It is also possible to run `pytest tests/` directly instead of `tox -e py311`,
21 | and all commands covered in this page work either way. However, it is recommended
22 | to use tox to isolate the testing environment and have local testing be as similar
23 | as possible to testing in CI jobs.
24 | :::
25 |
26 | ### Pass arguments to pytest
27 | We can also pass arguments through tox to pytest. With this we can for example
28 | select specific subsets of tests to be executed with the `-k` flag:
29 |
30 | ```console
31 | tox -e py311 -- -k plot_trace_dist
32 | ```
33 |
34 | This would run all tests whose name contains `plot_trace_dist`.
35 | The [pytest documentation](https://docs.pytest.org/en/stable/reference/reference.html#command-line-flags)
36 | lists and describes all available options.
37 |
38 | ### Custom pytest arguments
39 | In addition to built-in pytest arguments, we have also defined a couple extra flags
40 | in `tests/conftest.py` to handle arviz-plots specific situations.
41 |
42 | #### Skip flags
43 | One of the drivers of arviz-plots design is ensuring parity between the different
44 | plotting backends with minimal duplication.
45 |
46 | Therefore, all tests that depend on the plotting backend should be parametrized
47 | with `@pytest.mark.parametrize("backend", backend_list)` to make sure all backends
48 | pass all relevant tests.
49 |
50 | At the same time, however, arviz-plots considers all backends optional dependencies,
51 | so some backends might not be installed and, consequently, not all tests can be executed.
52 | By default, testing works under the assumption that all backends are installed,
53 | but backend-specific tests can be skipped when running the test suite:
54 |
55 | ```console
56 | tox -e py311 -- --skip-bokeh
57 | tox -e py311 -- --skip-mpl
58 | tox -e py311 -- --skip-plotly
59 | ```
60 |
61 | :::{note}
62 | It is also possible to combine all three flags, in which case only tests independent
63 | of the plotting backend, such as aesthetic generation, will be executed.
64 | :::
65 |
66 | #### Saving matplotlib figures generated by tests
67 | Tests check that plotting functions can be executed and return objects with the
68 | right properties, but there are no checks on the actual generated images.
69 |
70 | As we often want to inspect the generated images, there is also a `--save`
71 | flag that tells pytest to save all figures generated by matplotlib while testing.
72 |
73 | This flag takes one optional argument in case we want to specify the folder where
74 | the images will be saved; otherwise it defaults to `test_images` in the project
75 | root folder.
76 |
77 | ```console
78 | tox -e py311 -- --save
79 | ```
80 |
81 | This generates essentially the same output as any other test job:
82 |
83 | ```
84 | build: _optional_hooks> python ...
85 | [...]
86 | py311: OK (42.26=setup[1.79]+cmd[40.47] seconds)
87 | congratulations :) (42.31 seconds)
88 | ```
89 |
90 | But if you inspect the project root folder, you should see a `test_images/` folder
91 | with contents similar to:
92 |
93 | ```
94 | 'test_grid[matplotlib].png' 'test_plot_forest_extendable[matplotlib].png' 'test_plot_trace_dist[True-False-matplotlib].png'
95 | 'test_grid_rows_cols[cols-matplotlib].png' 'test_plot_forest[False-matplotlib].png' 'test_plot_trace_dist[True-True-matplotlib].png'
96 | 'test_grid_rows_cols[rows-matplotlib].png' 'test_plot_forest_models[matplotlib].png' 'test_plot_trace[matplotlib].png'
97 | 'test_grid_scalar[matplotlib].png' 'test_plot_forest_sample[matplotlib].png' 'test_plot_trace_sample[matplotlib].png'
98 | 'test_grid_variable[matplotlib].png' 'test_plot_forest[True-matplotlib].png' 'test_wrap[matplotlib].png'
99 | 'test_plot_dist[matplotlib].png' 'test_plot_trace_dist[False-False-matplotlib].png' 'test_wrap_only_variable[matplotlib].png'
100 | 'test_plot_dist_models[matplotlib].png' 'test_plot_trace_dist[False-True-matplotlib].png' 'test_wrap_variable[matplotlib].png'
101 | ```
102 |
103 | ## About arviz-plots testing
104 |
--------------------------------------------------------------------------------
/docs/source/gallery/distribution/00_plot_dist_ecdf.py:
--------------------------------------------------------------------------------
1 | """
2 | # Posterior ECDFs
3 |
4 | Faceted ECDF plots for 1D marginals of the distribution
5 |
6 | ---
7 |
8 | :::{seealso}
9 | API Documentation: {func}`~arviz_plots.plot_dist`
10 |
11 | EABM chapter on [Visualization of Random Variables with ArviZ](https://arviz-devs.github.io/EABM/Chapters/Distributions.html#distributions-in-arviz)
12 | :::
13 | """
14 | from arviz_base import load_arviz_data
15 |
16 | import arviz_plots as azp
17 |
18 | azp.style.use("arviz-variat")
19 |
20 | data = load_arviz_data("centered_eight")
21 | pc = azp.plot_dist(
22 | data,
23 | kind="ecdf",
24 | col_wrap=4,
25 | backend="none" # change to preferred backend
26 | )
27 | pc.show()
28 |
--------------------------------------------------------------------------------
/docs/source/gallery/distribution/01_plot_dist_hist.py:
--------------------------------------------------------------------------------
1 | """
2 | # Posterior Histograms
3 |
4 | Faceted histogram plots for 1D marginals of the distribution. Setting `point_estimate_text` to False in `visuals` omits that visual from the plot.
5 |
6 | ---
7 |
8 | :::{seealso}
9 | API Documentation: {func}`~arviz_plots.plot_dist`
10 |
11 | EABM chapter on [Visualization of Random Variables with ArviZ](https://arviz-devs.github.io/EABM/Chapters/Distributions.html#distributions-in-arviz)
12 | :::
13 | """
14 | from arviz_base import load_arviz_data
15 |
16 | import arviz_plots as azp
17 |
18 | azp.style.use("arviz-variat")
19 |
20 | data = load_arviz_data("centered_eight")
21 | pc = azp.plot_dist(
22 | data,
23 | kind="hist",
24 | visuals={"point_estimate_text": False},
25 | backend="none" # change to preferred backend
26 | )
27 | pc.show()
28 |
--------------------------------------------------------------------------------
/docs/source/gallery/distribution/02_plot_dist_kde.py:
--------------------------------------------------------------------------------
1 | """
2 | # Posterior KDEs
3 |
4 | KDE plot of the variable `mu` from the centered eight model. The `sample_dims` parameter restricts the KDE computation to the `draw` dimension only.
5 |
6 | ---
7 |
8 | :::{seealso}
9 | API Documentation: {func}`~arviz_plots.plot_dist`
10 |
11 | EABM chapter on [Visualization of Random Variables with ArviZ](https://arviz-devs.github.io/EABM/Chapters/Distributions.html#distributions-in-arviz)
12 | :::
13 | """
14 | from arviz_base import load_arviz_data
15 |
16 | import arviz_plots as azp
17 |
18 | azp.style.use("arviz-variat")
19 |
20 | data = load_arviz_data("centered_eight")
21 | pc = azp.plot_dist(
22 | data,
23 | kind="kde",
24 | var_names=["mu"],
25 | sample_dims=["draw"],
26 | backend="none" # change to preferred backend
27 | )
28 | pc.show()
29 |
--------------------------------------------------------------------------------
/docs/source/gallery/distribution/04_plot_forest.py:
--------------------------------------------------------------------------------
1 | """
2 | # Forest plot
3 |
4 | Default forest plot with marginal distribution summaries
5 |
6 | ---
7 |
8 | :::{seealso}
9 | API Documentation: {func}`~arviz_plots.plot_forest`
10 | :::
11 | """
12 | from arviz_base import load_arviz_data
13 |
14 | import arviz_plots as azp
15 |
16 | azp.style.use("arviz-variat")
17 |
18 | data = load_arviz_data("rugby")
19 | pc = azp.plot_forest(
20 | data,
21 | var_names=["home", "atts", "defs"],
22 | backend="none" # change to preferred backend
23 | )
24 | pc.show()
25 |
--------------------------------------------------------------------------------
/docs/source/gallery/distribution/05_plot_forest_shade.py:
--------------------------------------------------------------------------------
1 | """
2 | # Forest plot with shading
3 |
4 | Forest plot marginal summaries with row shading to improve readability
5 |
6 | ---
7 |
8 | :::{seealso}
9 | API Documentation: {func}`~arviz_plots.plot_forest`
10 | :::
11 | """
12 | from arviz_base import load_arviz_data
13 |
14 | import arviz_plots as azp
15 |
16 | azp.style.use("arviz-variat")
17 |
18 | data = load_arviz_data("rugby")
19 | pc = azp.plot_forest(
20 | data,
21 | var_names=["home", "atts", "defs"],
22 | shade_label="team",
23 | backend="none", # change to preferred backend
24 | )
25 | pc.show()
26 |
--------------------------------------------------------------------------------
/docs/source/gallery/distribution/06_plot_prior_posterior.py:
--------------------------------------------------------------------------------
1 | """
2 | # Plot prior and posterior
3 |
4 | Plot prior and posterior marginal distributions.
5 |
6 | ---
7 |
8 | :::{seealso}
9 | API Documentation: {func}`~arviz_plots.plot_prior_posterior`
10 | :::
11 | """
12 | from arviz_base import load_arviz_data
13 |
14 | import arviz_plots as azp
15 |
16 | azp.style.use("arviz-variat")
17 |
18 | data = load_arviz_data("centered_eight")
19 | pc = azp.plot_prior_posterior(
20 | data,
21 | var_names="mu",
22 | kind="hist",
23 | backend="none" # change to preferred backend
24 | )
25 | pc.show()
26 |
--------------------------------------------------------------------------------
/docs/source/gallery/distribution/07_plot_pairs_focus_distribution.py:
--------------------------------------------------------------------------------
1 | """
2 | # Scatterplot one variable against all others
3 |
4 | Plot one variable against other variables in the dataset.
5 |
6 | ---
7 |
8 | :::{seealso}
9 | API Documentation: {func}`~arviz_plots.plot_pairs_focus`
10 | :::
11 | """
12 | from arviz_base import load_arviz_data
13 |
14 | import arviz_plots as azp
15 |
16 | azp.style.use("arviz-variat")
17 |
18 | data = load_arviz_data("centered_eight")
19 | pc = azp.plot_pairs_focus(
20 | data,
21 | var_names=["theta","tau"],
22 | focus_var="mu",
23 | figure_kwargs={"sharex": True},
24 | backend="none", # change to preferred backend
25 | )
26 | pc.show()
27 |
--------------------------------------------------------------------------------
/docs/source/gallery/inference_diagnostics/00_plot_rank.py:
--------------------------------------------------------------------------------
1 | """
2 | # Rank plot
3 |
4 | Faceted plot with fractional ranks for each variable
5 |
6 | ---
7 |
8 | :::{seealso}
9 | API Documentation: {func}`~arviz_plots.plot_rank`
10 | :::
11 | """
12 | from arviz_base import load_arviz_data
13 |
14 | import arviz_plots as azp
15 |
16 | azp.style.use("arviz-variat")
17 |
18 | data = load_arviz_data("centered_eight")
19 | pc = azp.plot_rank(
20 | data,
21 | backend="none" # change to preferred backend
22 | )
23 | pc.show()
24 |
--------------------------------------------------------------------------------
/docs/source/gallery/inference_diagnostics/01_plot_trace.py:
--------------------------------------------------------------------------------
1 | """
2 | # Trace plot
3 |
4 | Faceted plot with MCMC traces for each variable
5 |
6 | ---
7 |
8 | :::{seealso}
9 | API Documentation: {func}`~arviz_plots.plot_trace`
10 | :::
11 | """
12 | from arviz_base import load_arviz_data
13 |
14 | import arviz_plots as azp
15 |
16 | azp.style.use("arviz-variat")
17 |
18 | data = load_arviz_data("centered_eight")
19 | pc = azp.plot_trace(
20 | data,
21 | backend="none" # change to preferred backend
22 | )
23 | pc.show()
24 |
--------------------------------------------------------------------------------
/docs/source/gallery/inference_diagnostics/02_plot_ess_evolution.py:
--------------------------------------------------------------------------------
1 | """
2 | # ESS evolution
3 |
4 | Faceted plot with bulk and tail ESS for each variable
5 |
6 | ---
7 |
8 | :::{seealso}
9 | API Documentation: {func}`~arviz_plots.plot_ess_evolution`
10 | :::
11 | """
12 |
13 | from arviz_base import load_arviz_data
14 |
15 | import arviz_plots as azp
16 |
17 | azp.style.use("arviz-variat")
18 |
19 | data = load_arviz_data("centered_eight")
20 | pc = azp.plot_ess_evolution(data, backend="none") # change to preferred backend
21 | pc.show()
22 |
--------------------------------------------------------------------------------
/docs/source/gallery/inference_diagnostics/03_plot_ess_local.py:
--------------------------------------------------------------------------------
1 | """
2 | # ESS local
3 |
4 | Faceted local ESS plot
5 |
6 | ---
7 |
8 | :::{seealso}
9 | API Documentation: {func}`~arviz_plots.plot_ess`
10 | :::
11 | """
12 |
13 | from arviz_base import load_arviz_data
14 |
15 | import arviz_plots as azp
16 |
17 | azp.style.use("arviz-variat")
18 |
19 | data = load_arviz_data("centered_eight")
20 | pc = azp.plot_ess(
21 | data,
22 | kind="local",
23 | backend="none", # change to preferred backend
24 | rug=True,
25 | )
26 | pc.show()
27 |
--------------------------------------------------------------------------------
/docs/source/gallery/inference_diagnostics/04_plot_ess_quantile.py:
--------------------------------------------------------------------------------
1 | """
2 | # ESS quantile
3 |
4 | Faceted quantile ESS plot
5 |
6 | ---
7 |
8 | :::{seealso}
9 | API Documentation: {func}`~arviz_plots.plot_ess`
10 | :::
11 | """
12 |
13 | from arviz_base import load_arviz_data
14 |
15 | import arviz_plots as azp
16 |
17 | azp.style.use("arviz-variat")
18 |
19 | data = load_arviz_data("centered_eight")
20 | pc = azp.plot_ess(
21 | data,
22 | kind="quantile",
23 | backend="none", # change to preferred backend
24 | )
25 | pc.show()
26 |
--------------------------------------------------------------------------------
/docs/source/gallery/inference_diagnostics/05_plot_ess_models.py:
--------------------------------------------------------------------------------
1 | """
2 | # ESS comparison
3 |
4 | Full ESS (Either local or quantile) comparison between different models
5 |
6 | ---
7 |
8 | :::{seealso}
9 | API Documentation: {func}`~arviz_plots.plot_ess`
10 | :::
11 | """
12 |
13 | from arviz_base import load_arviz_data
14 |
15 | import arviz_plots as azp
16 |
17 | azp.style.use("arviz-variat")
18 |
19 | c = load_arviz_data("centered_eight")
20 | n = load_arviz_data("non_centered_eight")
21 | pc = azp.plot_ess(
22 | {"Centered": c, "Non Centered": n},
23 | backend="none", # change to preferred backend
24 | )
25 | pc.show()
26 |
--------------------------------------------------------------------------------
/docs/source/gallery/inference_diagnostics/05_plot_mcse.py:
--------------------------------------------------------------------------------
1 | """
2 | # Monte Carlo standard error
3 |
4 | Faceted quantile MCSE plot
5 |
6 | ---
7 |
8 | :::{seealso}
9 | API Documentation: {func}`~arviz_plots.plot_mcse`
10 | :::
11 | """
12 |
13 | from arviz_base import load_arviz_data
14 |
15 | import arviz_plots as azp
16 |
17 | azp.style.use("arviz-variat")
18 |
19 | data = load_arviz_data("centered_eight")
20 | pc = azp.plot_mcse(
21 | data,
22 | extra_methods=True,
23 | var_names=["mu"],
24 | backend="none", # change to preferred backend
25 | )
26 | pc.show()
27 |
--------------------------------------------------------------------------------
/docs/source/gallery/inference_diagnostics/06_plot_convergence_dist.py:
--------------------------------------------------------------------------------
1 | """
2 | # Convergence diagnostics distribution
3 |
4 | Plot the distribution of ESS and R-hat.
5 |
6 | ---
7 |
8 | :::{seealso}
9 | API Documentation: {func}`~arviz_plots.plot_convergence_dist`
10 | :::
11 | """
12 |
13 | from arviz_base import load_arviz_data
14 |
15 | import arviz_plots as azp
16 |
17 | azp.style.use("arviz-variat")
18 |
19 | data = load_arviz_data("radon")
20 | pc = azp.plot_convergence_dist(
21 | data,
22 | var_names=["za_county"],
23 | backend="none", # change to preferred backend
24 | )
25 |
26 | pc.show()
27 |
--------------------------------------------------------------------------------
/docs/source/gallery/inference_diagnostics/07_plot_autocorr.py:
--------------------------------------------------------------------------------
1 | """
2 | # Autocorrelation Plot
3 |
4 | Faceted plot with autocorrelation for each variable
5 |
6 | ---
7 |
8 | :::{seealso}
9 | API Documentation: {func}`~arviz_plots.plot_autocorr`
10 | :::
11 | """
12 | from arviz_base import load_arviz_data
13 |
14 | import arviz_plots as azp
15 |
16 | azp.style.use("arviz-variat")
17 |
18 | data = load_arviz_data("centered_eight")
19 | pc = azp.plot_autocorr(
20 | data,
21 | backend="none" # change to preferred backend
22 | )
23 | pc.show()
24 |
--------------------------------------------------------------------------------
/docs/source/gallery/inference_diagnostics/08_plot_energy.py:
--------------------------------------------------------------------------------
1 | """
2 | # Energy
3 |
4 | Plot transition and marginal energy distributions
5 |
6 | ---
7 |
8 | :::{seealso}
9 | API Documentation: {func}`~arviz_plots.plot_energy`
10 | :::
11 | """
12 | from arviz_base import load_arviz_data
13 |
14 | import arviz_plots as azp
15 |
16 | azp.style.use("arviz-variat")
17 |
18 | data = load_arviz_data("centered_eight")
19 | pc = azp.plot_energy(
20 | data,
21 | backend="none" # change to preferred backend
22 | )
23 | pc.show()
--------------------------------------------------------------------------------
/docs/source/gallery/inference_diagnostics/09_plot_pairs_focus.py:
--------------------------------------------------------------------------------
1 | """
2 | # Scatter plot with divergences
3 |
4 | Plot one variable against other variables in the dataset, highlighting divergent transitions.
5 |
6 | ---
7 |
8 | :::{seealso}
9 | API Documentation: {func}`~arviz_plots.plot_pairs_focus`
10 | :::
11 | """
12 | import numpy as np
13 | from arviz_base import load_arviz_data
14 |
15 | import arviz_plots as azp
16 |
17 | azp.style.use("arviz-variat")
18 |
19 | dt = load_arviz_data("centered_eight")
20 | dt.posterior["log_tau"] = np.log(dt.posterior["tau"])
21 |
22 | pc = azp.plot_pairs_focus(
23 | dt,
24 | var_names=["theta"],
25 | focus_var="log_tau",
26 | visuals={"divergence":True},
27 | backend="none", # change to preferred backend
28 | )
29 | pc.show()
30 |
--------------------------------------------------------------------------------
/docs/source/gallery/mixed/00_plot_rank_dist.py:
--------------------------------------------------------------------------------
1 | """
2 | # Rank and distribution plot
3 |
4 | Two-column layout with marginal distributions on the left and fractional ranks on the right
5 |
6 | ---
7 |
8 | :::{seealso}
9 | API Documentation: {func}`~arviz_plots.plot_rank_dist`
10 | :::
11 | """
12 | from arviz_base import load_arviz_data
13 |
14 | import arviz_plots as azp
15 |
16 | azp.style.use("arviz-variat")
17 |
18 | data = load_arviz_data("non_centered_eight")
19 | pc = azp.plot_rank_dist(
20 | data,
21 | var_names=["mu", "tau"],
22 | backend="none" # change to preferred backend
23 | )
24 | pc.show()
25 |
--------------------------------------------------------------------------------
/docs/source/gallery/mixed/01_plot_trace_dist.py:
--------------------------------------------------------------------------------
1 | """
2 | # Trace and distribution plot
3 |
4 | Two-column layout with marginal distributions on the left and MCMC traces on the right
5 |
6 | ---
7 |
8 | :::{seealso}
9 | API Documentation: {func}`~arviz_plots.plot_trace_dist`
10 | :::
11 | """
12 | from arviz_base import load_arviz_data
13 |
14 | import arviz_plots as azp
15 |
16 | azp.style.use("arviz-variat")
17 |
18 | data = load_arviz_data("non_centered_eight")
19 | pc = azp.plot_trace_dist(
20 | data,
21 | backend="none" # change to preferred backend
22 | )
23 | pc.show()
24 |
--------------------------------------------------------------------------------
/docs/source/gallery/mixed/02_plot_forest_ess.py:
--------------------------------------------------------------------------------
1 | """
2 | # Forest plot with ESS
3 |
4 | Multiple panel visualization with a forest plot and ESS information
5 |
6 | ---
7 |
8 | :::{seealso}
9 | API Documentation: {func}`~arviz_plots.plot_forest`
10 | :::
11 | """
12 | import arviz_stats # make azstats accessor available
13 | from arviz_base import load_arviz_data
14 |
15 | import arviz_plots as azp
16 |
17 | azp.style.use("arviz-variat")
18 |
19 | centered = load_arviz_data("centered_eight")
20 |
21 | c_aux = (
22 | centered["posterior"]
23 | .dataset.expand_dims(column=3)
24 | .assign_coords(column=["labels", "forest", "ess"])
25 | )
26 |
27 | pc = azp.plot_forest(
28 | c_aux,
29 | combined=True,
30 | backend="none", # change to preferred backend
31 | )
32 |
33 | pc.map(
34 | azp.visuals.scatter_x,
35 | "ess",
36 | data=centered.posterior.ds.azstats.ess(),
37 | coords={"column": "ess"},
38 | color="gray",
39 | )
40 | pc.show()
41 |
--------------------------------------------------------------------------------
/docs/source/gallery/mixed/03_combine_plots.py:
--------------------------------------------------------------------------------
1 | """
2 | # Custom diagnostic plots combination
3 |
4 | Arrange three diagnostic plots (ESS evolution plot, rank plot and autocorrelation plot)
5 | in a custom column layout.
6 |
7 | ---
8 |
9 | :::{seealso}
10 | API Documentation: {func}`~arviz_plots.combine_plots`
11 | :::
12 | """
13 | from arviz_base import load_arviz_data
14 |
15 | import arviz_plots as azp
16 |
17 | azp.style.use("arviz-variat")
18 |
19 | data = load_arviz_data("non_centered_eight")
20 | pc = azp.combine_plots(
21 | data,
22 | [
23 | (azp.plot_ess_evolution, {}),
24 | (azp.plot_rank, {}),
25 | (azp.plot_autocorr, {}),
26 | ],
27 | var_names=["theta", "mu", "tau"],
28 | coords={"school": ["Hotchkiss", "St. Paul's"]},
29 | backend="none", # change to preferred backend
30 | )
31 | pc.show()
32 |
--------------------------------------------------------------------------------
/docs/source/gallery/model_comparison/00_plot_compare.py:
--------------------------------------------------------------------------------
1 | """
2 | # Predictive model comparison
3 |
4 | Compare multiple models using predictive accuracy estimated using PSIS-LOO-CV.
5 | Usually the DataFrame ``cmp_df`` is generated using ArviZ's ``compare`` function.
6 |
7 | ---
8 |
9 | :::{seealso}
10 | API Documentation: {func}`~arviz_plots.plot_compare`
11 | :::
12 | """
13 | import pandas as pd
14 |
15 | import arviz_plots as azp
16 |
17 | azp.style.use("arviz-variat")
18 |
19 |
20 | cmp_df = pd.DataFrame({"elpd": [-4.5, -14.3, -16.2],
21 | "p": [2.6, 2.3, 2.1],
22 | "elpd_diff": [0, 9.7, 11.3],
23 | "weight": [0.9, 0.1, 0],
24 | "se": [2.3, 2.7, 2.3],
25 | "dse": [0, 2.7, 2.3],
26 | "warning": [False, False, False],
27 | },
28 | index=["Model B", "Model A", "Model C"])
29 |
30 | pc = azp.plot_compare(cmp_df,
31 | backend="none", # change to preferred backend
32 | )
33 | pc.show()
--------------------------------------------------------------------------------
/docs/source/gallery/model_comparison/99_plot_bf.py:
--------------------------------------------------------------------------------
1 | """
2 | # Bayes factor
3 |
4 | Compute Bayes factor using Savage–Dickey ratio.
5 |
6 | We can apply this function when the null model is nested within the alternative.
7 | In other words when the null (``ref_val``) is a particular value of the model we are
8 | building (see [here](https://statproofbook.github.io/P/bf-sddr.html)).
9 |
10 | For other cases, computing Bayes factors is not straightforward and requires more complex
11 | methods. Instead of Bayes factors, we usually recommend Pareto smoothed
12 | importance sampling leave-one-out cross-validation (PSIS-LOO-CV). In ArviZ, you will find
13 | it implemented in functions with ``loo`` in their names.
14 |
15 | ---
16 | :::{seealso}
17 | API Documentation: {func}`~arviz_plots.plot_bf`
18 | :::
19 | """
20 | from arviz_base import load_arviz_data
21 |
22 | import arviz_plots as azp
23 |
24 | azp.style.use("arviz-variat")
25 |
26 | data = load_arviz_data("centered_eight")
27 |
28 | pc = azp.plot_bf(
29 | data,
30 | backend="none", # change to preferred backend
31 | var_names="mu"
32 | )
33 |
34 | pc.show()
--------------------------------------------------------------------------------
/docs/source/gallery/posterior_comparison/00_plot_dist_models.py:
--------------------------------------------------------------------------------
1 | """
2 | # Posterior KDEs for two models
3 |
4 | Full marginal distribution comparison between different models
5 |
6 | ---
7 |
8 | :::{seealso}
9 | API Documentation: {func}`~arviz_plots.plot_dist`
10 |
11 | Other examples comparing marginal distributions: {ref}`gallery_forest_models`
12 | :::
13 | """
14 | from arviz_base import load_arviz_data
15 |
16 | import arviz_plots as azp
17 |
18 | azp.style.use("arviz-variat")
19 |
20 | c = load_arviz_data("centered_eight")
21 | n = load_arviz_data("non_centered_eight")
22 | pc = azp.plot_dist(
23 | {"Centered": c, "Non Centered": n},
24 | backend="none" # change to preferred backend
25 | )
26 | pc.show()
27 |
--------------------------------------------------------------------------------
/docs/source/gallery/posterior_comparison/01_plot_forest_models.py:
--------------------------------------------------------------------------------
1 | """
2 | # Posterior forest for two models
3 |
4 | Forest plot summaries for 1D marginal distributions
5 |
6 | ---
7 |
8 | :::{seealso}
9 | API Documentation: {func}`~arviz_plots.plot_forest`
10 |
11 | Other examples comparing marginal distributions: {ref}`gallery_dist_models`
12 | :::
13 | """
14 | from arviz_base import load_arviz_data
15 |
16 | import arviz_plots as azp
17 |
18 | azp.style.use("arviz-variat")
19 |
20 | c = load_arviz_data("centered_eight")
21 | n = load_arviz_data("non_centered_eight")
22 | pc = azp.plot_forest(
23 | {"Centered": c, "Non Centered": n},
24 | backend="none" # change to preferred backend
25 | )
26 | pc.show()
27 |
--------------------------------------------------------------------------------
/docs/source/gallery/predictive_checks/00_plot_ppc_dist.py:
--------------------------------------------------------------------------------
1 | """
2 | # Predictive check with KDEs
3 |
4 | Plot of samples from the posterior predictive and observed data.
5 |
6 | ---
7 |
8 | :::{seealso}
9 | API Documentation: {func}`~arviz_plots.plot_ppc_dist`
10 |
11 | EABM chapter on [Posterior predictive checks](https://arviz-devs.github.io/EABM/Chapters/Prior_posterior_predictive_checks.html#posterior-predictive-checks)
12 | :::
13 | """
14 | from arviz_base import load_arviz_data
15 |
16 | import arviz_plots as azp
17 |
18 | azp.style.use("arviz-variat")
19 |
20 | dt = load_arviz_data("radon")
21 | pc = azp.plot_ppc_dist(
22 | dt,
23 | backend="none",
24 | )
25 | pc.show()
26 |
--------------------------------------------------------------------------------
/docs/source/gallery/predictive_checks/01_plot_ppc_rootogram.py:
--------------------------------------------------------------------------------
1 | """
2 | # Rootogram
3 |
4 | Rootogram for the posterior predictive and observed data.
5 |
6 | ---
7 |
8 | :::{seealso}
9 | API Documentation: {func}`~arviz_plots.plot_ppc_rootogram`
10 |
11 | EABM chapter on [Posterior predictive checks for count data](https://arviz-devs.github.io/EABM/Chapters/Prior_posterior_predictive_checks.html#posterior-predictive-checks-for-count-data)
12 | :::
13 | """
14 | from arviz_base import load_arviz_data
15 |
16 | import arviz_plots as azp
17 |
18 | azp.style.use("arviz-variat")
19 |
20 | dt = load_arviz_data("rugby")
21 | pc = azp.plot_ppc_rootogram(
22 | dt,
23 | aes={"color": ["__variable__"]}, # map variable to color
24 | aes_by_visuals={"title": ["color"]}, # change title's color per variable
25 | backend="none",
26 | )
27 | pc.show()
28 |
--------------------------------------------------------------------------------
/docs/source/gallery/predictive_checks/03_plot_pava_calibration.py:
--------------------------------------------------------------------------------
1 | """
2 | # PAV-adjusted calibration
3 |
4 | PAV-adjusted calibration plot for binary predictions.
5 |
6 | ---
7 |
8 | :::{seealso}
9 | API Documentation: {func}`~arviz_plots.plot_ppc_pava`
10 |
11 | EABM chapter on [Posterior predictive checks for binary data](https://arviz-devs.github.io/EABM/Chapters/Prior_posterior_predictive_checks.html#posterior-predictive-checks-for-binary-data)
12 | :::
13 | """
14 | from arviz_base import load_arviz_data
15 |
16 | import arviz_plots as azp
17 |
18 | azp.style.use("arviz-variat")
19 |
20 | dt = load_arviz_data("anes")
21 | pc = azp.plot_ppc_pava(
22 | dt,
23 | backend="none",
24 | )
25 | pc.show()
26 |
--------------------------------------------------------------------------------
/docs/source/gallery/predictive_checks/04_plot_ppc_pit.py:
--------------------------------------------------------------------------------
1 | """
2 | # PIT ECDF
3 |
4 | Plot of the probability integral transform of the posterior predictive distribution with respect to the observed data.
5 |
6 | ---
7 |
8 | :::{seealso}
9 | API Documentation: {func}`~arviz_plots.plot_ppc_pit`
10 |
11 | EABM chapter on [Posterior predictive checks with PIT-ECDFs](https://arviz-devs.github.io/EABM/Chapters/Prior_posterior_predictive_checks.html#pit-ecdfs)
12 | :::
13 | """
14 | from arviz_base import load_arviz_data
15 |
16 | import arviz_plots as azp
17 |
18 | azp.style.use("arviz-variat")
19 |
20 | dt = load_arviz_data("radon")
21 | pc = azp.plot_ppc_pit(
22 | dt,
23 | backend="none",
24 | )
25 | pc.show()
26 |
--------------------------------------------------------------------------------
/docs/source/gallery/predictive_checks/05_plot_ppc_coverage.py:
--------------------------------------------------------------------------------
1 | """
2 | # Coverage ECDF
3 |
4 | Proportion of true values that fall within a given prediction interval.
5 |
6 | ---
7 |
8 | :::{seealso}
9 | API Documentation: {func}`~arviz_plots.plot_ppc_pit`
10 |
11 | EABM chapter on [Posterior predictive checks and coverage](https://arviz-devs.github.io/EABM/Chapters/Prior_posterior_predictive_checks.html#coverage)
12 | :::
13 | """
14 | from arviz_base import load_arviz_data
15 |
16 | import arviz_plots as azp
17 |
18 | azp.style.use("arviz-variat")
19 |
20 | dt = load_arviz_data("radon")
21 | pc = azp.plot_ppc_pit(
22 | dt,
23 | coverage=True,
24 | backend="none",
25 | )
26 | pc.show()
27 |
--------------------------------------------------------------------------------
/docs/source/gallery/predictive_checks/06_plot_loo_pit.py:
--------------------------------------------------------------------------------
1 | """
2 | # LOO-PIT ECDF
3 |
4 | Plot of the probability integral transform of the posterior predictive distribution with
5 | respect to the observed data using the leave-one-out (LOO) method.
6 |
7 |
8 | ---
9 |
10 | :::{seealso}
11 | API Documentation: {func}`~arviz_plots.plot_loo_pit`
12 |
13 | EABM chapter on [Posterior predictive checks with LOO-PIT ECDFs](https://arviz-devs.github.io/EABM/Chapters/Prior_posterior_predictive_checks.html#sec-avoid-double-dipping)
14 | :::
15 | """
16 | from arviz_base import load_arviz_data
17 |
18 | import arviz_plots as azp
19 |
20 | azp.style.use("arviz-variat")
21 |
22 | dt = load_arviz_data("radon")
23 | pc = azp.plot_loo_pit(
24 | dt,
25 | backend="none",
26 | )
27 | pc.show()
28 |
--------------------------------------------------------------------------------
/docs/source/gallery/predictive_checks/07_plot_ppc_tstat.py:
--------------------------------------------------------------------------------
1 | """
2 | # Test statistics
3 |
4 | Test statistic (here the median) for the observed data and posterior predictive data.
5 |
6 | ---
7 |
8 | :::{seealso}
9 | API Documentation: {func}`~arviz_plots.plot_ppc_tstat`
10 |
11 | EABM chapter on [Posterior predictive checks with summary statistics](https://arviz-devs.github.io/EABM/Chapters/Prior_posterior_predictive_checks.html#using-summary-statistics)
12 | :::
13 | """
14 | from arviz_base import load_arviz_data
15 |
16 | import arviz_plots as azp
17 |
18 | azp.style.use("arviz-variat")
19 |
20 | dt = load_arviz_data("regression1d")
21 | pc = azp.plot_ppc_tstat(
22 | dt,
23 | t_stat="median",
24 | backend="none"
25 | )
26 | pc.show()
27 |
--------------------------------------------------------------------------------
/docs/source/gallery/predictive_checks/99_plot_forest_pp_obs.py:
--------------------------------------------------------------------------------
1 | """
2 | # Posterior predictive forest and observations
3 |
4 | Overlay of forest plot for the posterior predictive samples and the actual observations
5 |
6 | ---
7 |
8 | :::{seealso}
9 | API Documentation: {func}`~arviz_plots.plot_forest`
10 | :::
11 | """
12 | from arviz_base import load_arviz_data
13 |
14 | import arviz_plots as azp
15 |
azp.style.use("arviz-variat")

idata = load_arviz_data("non_centered_eight")

# Forest plot built from the posterior predictive group instead of the posterior
pc = azp.plot_forest(
    idata,
    group="posterior_predictive",
    combined=True,
    labels=["obs_dim_0"],
    backend="none",  # change to preferred backend
)

# Add the observed values in gray to the "forest" column
observations = idata.observed_data.ds
pc.map(
    azp.visuals.scatter_x,
    "observations",
    data=observations,
    coords={"column": "forest"},
    color="gray",
)

# Label the x axis of that same column
pc.map(
    azp.visuals.labelled_x,
    "xlabel",
    coords={"column": "forest"},
    text="Observations",
    ignore_aes="y",
)
pc.show()
43 |
--------------------------------------------------------------------------------
/docs/source/gallery/prior_and_likelihood_sensitivity_checks/00_plot_psense.py:
--------------------------------------------------------------------------------
1 | """
2 | # Sensitivity posterior marginals
3 |
4 | The posterior sensitivity is assessed by power-scaling the prior or likelihood and
5 | visualizing the resulting changes. Sensitivity can then be quantified by considering
6 | how much the perturbed posteriors differ from the base posterior.
7 |
8 | ---
9 |
10 | :::{seealso}
11 | API Documentation: {func}`~arviz_plots.plot_psense_dist`
12 | :::
13 | """
14 | from arviz_base import load_arviz_data
15 |
16 | import arviz_plots as azp
17 |
azp.style.use("arviz-variat")

idata = load_arviz_data("rugby")

# Restrict the plot to three variables and two teams
selected_vars = ["defs", "sd_att", "sd_def"]
selected_coords = {"team": ["Scotland", "Wales"]}

pc = azp.plot_psense_dist(
    idata,
    var_names=selected_vars,
    coords=selected_coords,
    y=[-2, -1, 0],
    backend="none",  # change to preferred backend
)
pc.show()
29 |
--------------------------------------------------------------------------------
/docs/source/gallery/prior_and_likelihood_sensitivity_checks/01_plot_psense_quantities.py:
--------------------------------------------------------------------------------
1 | """
2 | # Sensitivity posterior quantities
3 |
4 | The posterior quantities are computed by power-scaling the prior or likelihood and
5 | visualizing the resulting changes. Sensitivity can then be quantified by considering
6 | how much the perturbed quantities differ from the base quantities.
7 |
8 | ---
9 |
10 | :::{seealso}
11 | API Documentation: {func}`~arviz_plots.plot_psense_quantities`
12 | :::
13 | """
14 | from arviz_base import load_arviz_data
15 |
16 | import arviz_plots as azp
17 |
azp.style.use("arviz-variat")

idata = load_arviz_data("rugby")

# Summaries whose change under power-scaling is displayed
summary_quantities = ["mean", "sd", "0.25", "0.75"]

pc = azp.plot_psense_quantities(
    idata,
    var_names=["sd_att", "sd_def"],
    quantities=summary_quantities,
    col_wrap=2,
    backend="none",  # change to preferred backend
)
pc.show()
29 |
--------------------------------------------------------------------------------
/docs/source/gallery/sbc/00_plot_ecdf_pit.py:
--------------------------------------------------------------------------------
1 | """
2 | # PIT-ECDF
3 |
4 | Faceted plot with PIT Δ-ECDF values for each variable
5 |
The ``plot_ecdf_pit`` function assumes the values passed to it have already been transformed
to PIT values, as in the case of SBC analysis or values from ``arviz_stats.loo_pit``.
8 |
9 | The distribution should be uniform if the model is well-calibrated.
10 |
To make the plot easier to interpret, we plot the Δ-ECDF, that is, the difference between
the observed ECDF and the expected CDF. As small deviations from uniformity are expected,
the plot also shows the credible envelope.
14 |
15 | ---
16 |
17 | :::{seealso}
18 | API Documentation: {func}`~arviz_plots.plot_ecdf_pit`
19 | :::
20 | """
21 | from arviz_base import load_arviz_data
22 |
23 | import arviz_plots as azp
24 |
azp.style.use("arviz-variat")

# The "sbc" example data is passed as-is; no PIT transformation is applied here
data = load_arviz_data("sbc")
pc = azp.plot_ecdf_pit(data, backend="none")  # change to preferred backend
pc.show()
33 |
--------------------------------------------------------------------------------
/docs/source/gallery/sbc/01_plot_ecdf_coverage.py:
--------------------------------------------------------------------------------
1 | """
2 | # Coverage ECDF
3 |
4 | Coverage refers to the proportion of true values that fall within a given prediction interval.
5 | For a well-calibrated model, the coverage should match the intended interval width. For example,
6 | a 95% credible interval should contain the true value 95% of the time.
7 |
8 | The distribution should be uniform if the model is well-calibrated.
9 |
To make the plot easier to interpret, we plot the Δ-ECDF, that is, the difference between
the observed ECDF and the expected CDF. As small deviations from uniformity are expected,
the plot also shows the credible envelope.
13 |
14 | We can compute the coverage for equal-tailed intervals (ETI) by passing `coverage=True` to the
15 | `plot_ecdf_pit` function. This works because ETI coverage can be obtained by transforming the PIT
16 | values. However, for other interval types, such as HDI, coverage must be computed explicitly and
17 | is not supported by this function.
18 |
19 | ---
20 |
21 | :::{seealso}
22 | API Documentation: {func}`~arviz_plots.plot_ecdf_pit`
23 | :::
24 | """
25 | from arviz_base import load_arviz_data
26 |
27 | import arviz_plots as azp
28 |
azp.style.use("arviz-variat")

data = load_arviz_data("sbc")
# coverage=True transforms the PIT values into equal-tailed interval coverage
pc = azp.plot_ecdf_pit(data, coverage=True, backend="none")  # change to preferred backend
pc.show()
38 |
--------------------------------------------------------------------------------
/docs/source/gallery/utils/00_add_reference_lines.py:
--------------------------------------------------------------------------------
1 | """
2 | # Add Lines
3 |
4 | Draw lines on plots to highlight specific thresholds, targets, or important values.
5 |
6 | ---
7 |
8 | :::{seealso}
9 | API Documentation: {func}`~arviz_plots.add_lines`
10 | :::
11 | """
12 |
13 | from arviz_base import load_arviz_data
14 |
15 | import arviz_plots as azp
16 |
azp.style.use("arviz-variat")

data = load_arviz_data("centered_eight")

# Quantiles of every posterior variable, pooling chains and draws
probs = [0.5, 0.1, 0.9]
ref_ds = data.posterior.dataset.quantile(probs, dim=["chain", "draw"])

pc = azp.plot_dist(
    data,
    kind="ecdf",
    backend="none",  # change to preferred backend
)

# One reference line per quantile, with color mapped along the "quantile" dimension
pc = azp.add_lines(
    pc,
    values=ref_ds,
    ref_dim="quantile",
    aes_by_visuals={"ref_line": ["color"]},
    color=["black", "gray", "gray"],
)
pc.show()
34 |
--------------------------------------------------------------------------------
/docs/source/gallery/utils/01_add_reference_bands.py:
--------------------------------------------------------------------------------
1 | """
2 | # Add Reference Bands
3 |
4 | Draw reference bands to highlight specific regions.
5 |
6 | ---
7 |
8 | :::{seealso}
9 | API Documentation: {func}`~arviz_plots.add_bands`
10 | :::
11 | """
12 |
13 | from arviz_base import load_arviz_data
14 |
15 | import arviz_plots as azp
16 |
azp.style.use("arviz-variat")

data = load_arviz_data("centered_eight")
pc = azp.plot_forest(
    data,
    backend="none",  # change to preferred backend
)

# Band spanning -1 to 1 (e.g. a region of practical equivalence)
rope = [(-1, 1)]

# Select the "forest" column before adding the band
pc.coords = {"column": "forest"}
pc = azp.add_bands(
    pc,
    values=rope,
    visuals={"ref_band": {"color": "#f66d7f"}},
)

pc.show()
33 |
--------------------------------------------------------------------------------
/docs/source/glossary.md:
--------------------------------------------------------------------------------
1 | # Glossary
2 |
3 |
4 | :::{glossary}
5 | aesthetic
6 | aesthetics
7 | When used as a noun, we use _an aesthetic_ as a graphical property that is
8 | being used to encode data.
9 |
10 | Moreover, within `arviz_plots` _aesthetics_ can actually be any arbitrary
11 | keyword argument accepted by the plotting function being used.
12 |
13 | aesthetic mapping
14 | aesthetic mappings
15 | We use _aesthetic mapping_ to indicate the relation between the {term}`aesthetics`
16 | in our plot and properties in our dataset.
17 |
18 | figure
19 | Highest level data visualization structure. All plotted elements
20 | are contained within a figure or its children.
21 |
22 | EABM
23 | Acronym for Exploratory Analysis of Bayesian Models. We use this concept to
24 | reference all the tasks within a Bayesian workflow outside of
25 | building and fitting or sampling a model. For more details, see the
26 | [EABM virtual book](https://arviz-devs.github.io/EABM/)
27 |
28 | plot
29 | plots
    Area (or areas) into which the data is plotted. A {term}`figure`
    can contain multiple {term}`faceted` plots.
32 |
33 | visual
34 | visuals
    Graphical component or element added to a plot by `arviz-plots`.
36 |
37 | faceting
38 | faceted
39 | Generate multiple similar {term}`plot` elements with each of them
40 | referring to a specific property or value of the data.
41 |
42 | :::
43 |
44 | ## Equivalences with library specific objects
45 |
46 | | arviz-plots name | matplotlib | bokeh | plotly |
47 | |------------------|--------------|---------|------------------|
48 | | figure | figure | layout | Figure |
49 | | plot | axes/subplot | figure | -[^plotly_plot] |
50 | | visual | artist | glyph | trace |
51 |
52 | [^plotly_plot]: In plotly there is no specific object to represent a {term}`plot`.
53 |
54 | Instead, when adding {term}`visuals` one can choose to add a visual to all {term}`plots`
55 | in the {term}`figure`, or give the row/col indexes, or specify a subset of {term}`plots`
56 | on which to add the {term}`visual`.
57 |
--------------------------------------------------------------------------------
/docs/source/index.md:
--------------------------------------------------------------------------------
1 | # ArviZ-plots
2 |
3 | Welcome to the ArviZ-plots documentation! This library focuses on visual summaries and diagnostics for exploratory analysis of Bayesian models. It is one of the 3 components of the ArviZ library, the other two being:
4 |
* [arviz-base](https://arviz-base.readthedocs.io/en/latest/) for data-related functionality, including converters from different PPLs.
6 | * [arviz-stats](https://arviz-stats.readthedocs.io/en/latest/) for statistical functions and diagnostics.
7 |
8 | We recommend most users install and use all three ArviZ components together through the main ArviZ package, as they're designed to work seamlessly as one toolkit. Advanced users may choose to install components individually for finer control over dependencies.
9 |
10 | Note: All plotting functions - whether accessed through the full ArviZ package or directly via ArviZ-plots - are documented here.
11 |
12 |
13 | ## Exploratory Analysis of Bayesian Models
14 |
In modern Bayesian statistics, models are usually built and solved using probabilistic programming languages (PPLs) such as PyMC, Stan, NumPyro, etc. These languages allow users to specify models in a high-level language and perform inference using state-of-the-art algorithms like Markov Chain Monte Carlo (MCMC) or Variational Inference (VI). As a result, we usually get a posterior distribution in the form of samples. The posterior distribution has a central role in Bayesian statistics, but other distributions, like the posterior and prior predictive distributions, are also of interest, and other quantities may be relevant too.
16 |
The correct visualization, analysis, and interpretation of these computed data are key to properly answering the questions that motivate our analysis.
18 |
19 | When working with Bayesian models there are a series of related tasks that need to be addressed besides inference itself:
20 |
21 | * Diagnoses of the quality of the inference
22 |
23 | * Model criticism, including evaluations of both model assumptions and model predictions
24 |
25 | * Comparison of models, including model selection or model averaging
26 |
27 | * Preparation of the results for a particular audience.
28 |
We call these tasks exploratory analysis of Bayesian models (EABM). Successfully performing such tasks is central to the iterative and interactive modelling process (see Bayesian Workflow). In the words of Persi Diaconis:
30 |
31 | > Exploratory data analysis seeks to reveal structure, or simple descriptions in data. We look at numbers or graphs and try to find patterns. We pursue leads suggested by background information, imagination, patterns perceived, and experience with other data analyses.
32 |
33 | The goal of ArviZ is to provide a unified interface for performing exploratory analysis of Bayesian models in Python, regardless of the PPL used to perform inference. This allows users to focus on the analysis and interpretation of the results, rather than on the details of the implementation.
34 |
35 |
36 |
37 | ## Installation
38 |
39 | For instructions on how to install the full ArviZ package (including `arviz-base`, `arviz-stats` and `arviz-plots`), please refer to the [installation guide](https://python.arviz.org/en/latest/getting_started/Installation.html).
40 |
41 | However, if you are only interested in the plotting functions provided by ArviZ-plots, please follow the instructions below:
42 |
43 | ::::{tab-set}
44 | :::{tab-item} PyPI
45 | :sync: stable
46 |
47 | ```bash
48 | pip install "arviz-plots[]"
49 | ```
50 | :::
51 | :::{tab-item} GitHub
52 | :sync: dev
53 |
54 | ```bash
55 | pip install "arviz-plots[] @ git+https://github.com/arviz-devs/arviz-plots"
56 | ```
57 | :::
58 | ::::
59 |
60 | Note that `arviz-plots` is a minimal package, which only depends on
61 | xarray, numpy, arviz-base and arviz-stats.
62 | None of the possible backends: matplotlib, bokeh or plotly are installed
63 | by default.
64 |
Consequently, rather than installing plain `arviz-plots`, we recommend choosing which backend(s)
to use and installing them as extras, for example `arviz-plots[bokeh]`
or `arviz-plots[matplotlib, plotly]` (multiple comma-separated values are valid).
68 |
69 | This will ensure all relevant dependencies are installed. For example, to use the plotly backend,
70 | both `plotly>5` and `webcolors` are required.
71 |
72 | ```{toctree}
73 | :hidden:
74 | :caption: User guide
75 |
76 | tutorials/overview
77 | tutorials/plots_intro
78 | tutorials/intro_to_plotcollection
79 | tutorials/compose_own_plot
80 | ```
81 |
82 | ```{toctree}
83 | :hidden:
84 | :caption: Reference
85 |
86 | gallery/index
87 | api/index
88 | glossary
89 | ```
90 | ```{toctree}
91 | :hidden:
92 | :caption: Tutorials
93 |
94 | ArviZ in Context
95 | ```
96 |
97 | ```{toctree}
98 | :hidden:
99 | :caption: Contributing
100 |
101 | contributing/testing
102 | contributing/new_plot
103 | contributing/docs
104 | ```
105 |
106 | ```{toctree}
107 | :caption: About
108 | :hidden:
109 |
110 | BlueSky
111 | Mastodon
112 | GitHub repository
113 | ```
114 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["flit_core >=3.4,<4"]
3 | build-backend = "flit_core.buildapi"
4 |
[project]
name = "arviz-plots"
readme = "README.md"
requires-python = ">=3.11"
license = {file = "LICENSE"}
authors = [
    {name = "ArviZ team", email = "arvizdevs@gmail.com"}
]
classifiers = [
    "Development Status :: 3 - Alpha",
    "Intended Audience :: Developers",
    "Intended Audience :: Science/Research",
    "Intended Audience :: Education",
    "Framework :: Matplotlib",
    "License :: OSI Approved :: Apache Software License",
    "Operating System :: OS Independent",
    "Programming Language :: Python",
    "Programming Language :: Python :: 3",
    # Version classifiers must agree with requires-python (>=3.11)
    "Programming Language :: Python :: 3.11",
    "Programming Language :: Python :: 3.12",
]
dynamic = ["version", "description"]
dependencies = [
    "arviz-base @ git+https://github.com/arviz-devs/arviz-base",
    "arviz-stats[xarray] @ git+https://github.com/arviz-devs/arviz-stats",
]
32 |
33 | [tool.flit.module]
34 | name = "arviz_plots"
35 |
36 | [project.urls]
37 | source = "https://github.com/arviz-devs/arviz-plots"
38 | tracker = "https://github.com/arviz-devs/arviz-plots/issues"
39 | documentation = "https://arviz-plots.readthedocs.io"
40 | funding = "https://opencollective.com/arviz"
41 |
42 | [project.optional-dependencies]
43 | matplotlib = ["matplotlib"]
44 | bokeh = ["bokeh"]
45 | plotly = ["plotly", "webcolors"]
46 | test = [
47 | "hypothesis",
48 | "pytest",
49 | "pytest-cov",
50 | "h5netcdf",
51 | "kaleido",
52 | ]
53 | doc = [
54 | "sphinx-book-theme",
55 | "myst-parser[linkify]",
56 | "myst-nb",
57 | "sphinx-copybutton",
58 | "numpydoc",
59 | "sphinx>=6",
60 | "sphinx-design",
61 | "jupyter-sphinx",
62 | "h5netcdf",
63 | "plotly<6",
64 | ]
65 |
66 |
67 | [tool.black]
68 | line-length = 100
69 |
70 | [tool.isort]
71 | profile = "black"
72 | include_trailing_comma = true
73 | use_parentheses = true
74 | multi_line_output = 3
75 | line_length = 100
76 | skip = [
77 | "src/arviz_plots/__init__.py"
78 | ]
79 |
80 | [tool.pydocstyle]
81 | convention = "numpy"
82 | match_dir = "^(?!docs|.tox).*"
83 |
[tool.mypy]
# Target the minimum version allowed by requires-python
python_version = "3.11"
warn_redundant_casts = true
warn_unused_configs = true
pretty = true
show_error_codes = true
show_error_context = true
show_column_numbers = true

disallow_any_generics = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_incomplete_defs = true
check_untyped_defs = true
disallow_untyped_decorators = true
no_implicit_optional = true
warn_unused_ignores = true
warn_return_any = true
no_implicit_reexport = true
103 |
104 | # More strict checks for library code
105 | [[tool.mypy.overrides]]
106 | module = "arviz_plots"
107 | disallow_untyped_defs = true
108 |
109 | # Ignore certain missing imports
110 | # [[tool.mypy.overrides]]
111 | # module = "thirdparty.*"
112 | # ignore_missing_imports = true
113 |
114 | [tool.pytest.ini_options]
115 | filterwarnings = ["error"]
116 | addopts = "--durations=10"
117 | testpaths = [
118 | "tests",
119 | ]
120 |
121 | [tool.coverage.run]
122 | source = ["arviz_plots"]
123 |
--------------------------------------------------------------------------------
/src/arviz_plots/_version.py:
--------------------------------------------------------------------------------
1 | """Base ArviZ version."""
2 | __version__ = "0.6.0.dev0"
3 |
--------------------------------------------------------------------------------
/src/arviz_plots/backend/__init__.py:
--------------------------------------------------------------------------------
1 | """Common interface to plotting backends.
2 |
3 | Each submodule within this module defines a common interface layer to different plotting libraries.
4 |
5 | Outside ``arviz_plots.backend`` the corresponding backend module is imported,
6 | but only the common interface layer is used, making no distinctions between plotting backends.
7 | Each submodule inside ``arviz_plots.backend`` is expected to implement the same functions
8 | with the same call signature. Thus, adding a new backend requires only
9 | implementing this common interface for it, with no changes to any of the other modules.
10 | """
11 |
--------------------------------------------------------------------------------
/src/arviz_plots/backend/bokeh/legend.py:
--------------------------------------------------------------------------------
1 | """Bokeh manual legend generation."""
2 | import warnings
3 |
4 | import numpy as np
5 | from bokeh.models import Legend
6 |
7 |
def dealiase_line_kwargs(kwargs):
    """Translate arviz common interface line properties into their bokeh names.

    ``width`` becomes ``line_width`` and ``linestyle`` becomes ``line_dash``;
    every other key is passed through unchanged.
    """
    aliases = {"width": "line_width", "linestyle": "line_dash"}
    translated = {}
    for name, val in kwargs.items():
        translated[aliases.get(name, name)] = val
    return translated
12 |
13 |
def legend(
    target,
    kwarg_list,
    label_list,
    title=None,
    artist_type="line",
    artist_kwargs=None,
    legend_target=None,
    side="auto",
    legend_placement_threshold=600,  # Magic number
    **kwargs,
):
    """Generate a legend on a figure given lists of labels and property kwargs.

    Parameters
    ----------
    target : bokeh layout
        Object stored as the :term:`figure`. Its ``children`` are expected to be
        ``(plot, row, col)`` tuples, either directly or inside ``children[1]``.
    kwarg_list : sequence of mapping
        Line properties for each legend entry, using arviz common interface names
        (``width`` and ``linestyle`` are translated to bokeh names).
    label_list : sequence
        Label for each legend entry; each one is converted to ``str``.
    title : str, optional
        Legend title.
    artist_type : {"line"}, default "line"
        Type of miniature glyph drawn for each entry. Only "line" is supported.
    artist_kwargs : mapping, optional
        Passed to all glyphs when generating legend miniatures. Values from
        `kwarg_list` take precedence over these.
    legend_target : (int, int), default (0, -1)
        Row and column indicators of the :term:`plot` where the legend will be placed.
        They index the sorted unique row/column positions, so negative values are
        allowed. Bokeh does not support :term:`figure` level legend.
    side : str, default "auto"
        Side of the plot on which to place the legend. Use "center" to put the legend
        inside the plotting area. With "auto", "right" is used when the target plot
        width is at least `legend_placement_threshold`, and "center" otherwise.
    legend_placement_threshold : int, default 600
        Plot width used to resolve ``side="auto"``.
    **kwargs
        Passed to :class:`bokeh.models.Legend`.

    Returns
    -------
    bokeh.models.Legend
        The legend added to the target plot.

    Raises
    ------
    NotImplementedError
        If `artist_type` is not "line".
    """
    if artist_kwargs is None:
        artist_kwargs = {}
    if legend_target is None:
        legend_target = (0, -1)
    # TODO: improve selection of Figure object from what is stored as "figure"
    children = target.children
    if not isinstance(children[0], tuple):
        # The (plot, row, col) grid is then taken from the second child
        children = children[1].children
    plots = [child[0] for child in children]
    row_id = np.array([child[1] for child in children], dtype=int)
    col_id = np.array([child[2] for child in children], dtype=int)
    # NOTE: if no plot sits at the selected (row, col) pair, np.argmax of an
    # all-False mask returns 0 and the legend goes to the first plot.
    legend_id = np.argmax(
        (row_id == np.unique(row_id)[legend_target[0]])
        & (col_id == np.unique(col_id)[legend_target[1]])
    )
    target_plot = plots[legend_id]
    if target_plot.legend:
        # stacklevel=2 attributes the warning to the caller instead of this module
        warnings.warn("This target plot already contains a legend", stacklevel=2)
    glyph_list = []
    if artist_type == "line":
        artist_fun = target_plot.line
        kwarg_list = [dealiase_line_kwargs(kws) for kws in kwarg_list]
    else:
        raise NotImplementedError("Only line type legends supported for now")

    for kws in kwarg_list:
        # Per-entry properties override the shared artist_kwargs
        glyph = artist_fun(**{**artist_kwargs, **kws})
        glyph_list.append(glyph)

    if side == "auto":
        plot_width = target_plot.width
        if plot_width >= legend_placement_threshold:
            side = "right"
        else:
            side = "center"

    leg = Legend(
        items=[(str(label), [glyph]) for label, glyph in zip(label_list, glyph_list)],
        title=title,
        **kwargs,
    )
    target_plot.add_layout(leg, side)
    return leg
80 |
--------------------------------------------------------------------------------
/src/arviz_plots/backend/matplotlib/legend.py:
--------------------------------------------------------------------------------
1 | """Matplotlib manual legend generation."""
2 | from matplotlib.lines import Line2D
3 |
4 |
def dealiase_line_kwargs(kwargs):
    """Rename arviz common interface line properties to matplotlib ones.

    Only ``width`` needs translating (to ``linewidth``); other keys are kept as is.
    """
    renamed = {}
    for key, value in kwargs.items():
        renamed["linewidth" if key == "width" else key] = value
    return renamed
9 |
10 |
def legend(
    target, kwarg_list, label_list, title=None, artist_type="line", artist_kwargs=None, **kwargs
):
    """Generate a legend on a figure given lists of labels and property kwargs.

    Parameters
    ----------
    target : object with a ``legend`` method
        Receives ``legend(handles, label_list, title=title, **kwargs)``.
    kwarg_list : sequence of mapping
        Line properties for each entry, using arviz common interface names.
    label_list : sequence
        Labels forwarded unchanged to ``target.legend``.
    title : str, optional
        Legend title.
    artist_type : {"line"}, default "line"
        Only "line" is supported; handles are empty ``Line2D`` objects.
    artist_kwargs : mapping, optional
        Properties shared by every handle; per-entry values take precedence.
    **kwargs
        Passed to ``target.legend``. ``loc`` defaults to "outside right upper".

    Returns
    -------
    Whatever ``target.legend`` returns.

    Raises
    ------
    NotImplementedError
        If `artist_type` is not "line".
    """
    kwargs.setdefault("loc", "outside right upper")
    if artist_type != "line":
        raise NotImplementedError("Only line type legends supported for now")
    shared = {} if artist_kwargs is None else artist_kwargs
    handles = []
    for entry in kwarg_list:
        handles.append(Line2D([], [], **{**shared, **dealiase_line_kwargs(entry)}))
    return target.legend(handles, label_list, title=title, **kwargs)
25 |
--------------------------------------------------------------------------------
/src/arviz_plots/backend/none/legend.py:
--------------------------------------------------------------------------------
1 | """None backend manual legend generation.
2 |
3 | For now only used for documentation purposes.
4 | """
5 |
6 |
# pylint: disable=unused-argument
def legend(
    target, kwarg_list, label_list, title=None, artist_type="line", artist_kwargs=None, **kwargs
):
    """Accept the common legend interface without drawing anything.

    Every argument is ignored; this backend does not create a legend.

    Parameters
    ----------
    target : plot object
    kwarg_list : sequence of mapping
    label_list : sequence of str
    title : str, optional
    artist_type : {"line", "scatter", "rectangle"}, default "line"
    artist_kwargs : mapping, optional
        Passed to all visuals when generating legend miniatures.
    **kwargs
        Passed to backend legend generating function.

    Returns
    -------
    None
    """
26 |
--------------------------------------------------------------------------------
/src/arviz_plots/backend/plotly/legend.py:
--------------------------------------------------------------------------------
1 | """Plotly legend generation."""
2 |
3 |
def dealiase_line_kwargs(kwargs):
    """Map arviz common interface line properties onto plotly ``line`` attributes.

    ``linewidth`` becomes ``width`` and ``linestyle`` becomes ``dash``; any other
    key (including an already plotly-style ``width``) is left untouched.
    """
    lookup = dict((("linewidth", "width"), ("linestyle", "dash")))
    return dict((lookup.get(key, key), val) for key, val in kwargs.items())
8 |
9 |
def legend(
    target, kwarg_list, label_list, title=None, artist_type="line", artist_kwargs=None, **kwargs
):
    """Add legend entries and legend settings to a plotly figure.

    Each entry is an empty ``"lines"`` scatter trace (``x=[None]``, ``y=[None]``),
    so it appears in the legend without plotting any data.

    Parameters
    ----------
    target : plotly.graph_objects.Figure
        Figure receiving the legend traces.
    kwarg_list : list of dict
        Line style for each entry, using arviz common interface names; translated
        with ``dealiase_line_kwargs`` and passed as the trace ``line``.
    label_list : list
        Entry labels, converted to ``str``. Entries are paired with `kwarg_list`
        up to the shorter of the two.
    title : str, optional
        Legend title.
    artist_type : str, optional
        Type of visual used for legend entries. Only "line" is supported.
    artist_kwargs : dict, optional
        Extra keyword arguments passed to ``target.add_scatter`` for every entry.
    **kwargs : dict
        Passed to ``target.update_layout`` together with the legend settings.

    Returns
    -------
    None
        The legend is added to the target figure in place.

    Raises
    ------
    NotImplementedError
        If `artist_type` is not "line".
    """
    if artist_kwargs is None:
        artist_kwargs = {}
    if artist_type != "line":
        raise NotImplementedError("Only line type legends supported for now")

    add_entry = target.add_scatter
    line_styles = [dealiase_line_kwargs(style) for style in kwarg_list]

    for style, label in zip(line_styles, label_list):
        add_entry(
            x=[None],
            y=[None],
            name=str(label),
            mode="lines",
            line=style,
            showlegend=True,
            **artist_kwargs,
        )

    target.update_layout(showlegend=True, legend_title_text=title, **kwargs)
59 |
--------------------------------------------------------------------------------
/src/arviz_plots/backend/plotly/templates.py:
--------------------------------------------------------------------------------
1 | """Plotly templates for ArviZ styles."""
2 | import plotly.graph_objects as go
3 |
# Axis styling shared by every ArviZ template: no grid, ticks outside,
# visible axis line and no zero line.
axis_common = {"showgrid": False, "ticks": "outside", "showline": True, "zeroline": False}


def _build_template(colorway):
    """Create a plotly template with the common ArviZ layout and the given colorway."""
    template = go.layout.Template()
    template.layout.paper_bgcolor = "white"
    template.layout.plot_bgcolor = "white"
    template.layout.polar.bgcolor = "white"
    template.layout.ternary.bgcolor = "white"
    template.layout.margin = {"l": 50, "r": 10, "t": 40, "b": 45}
    template.layout.xaxis = axis_common
    template.layout.yaxis = axis_common
    template.layout.colorway = colorway
    return template


arviz_variat_template = _build_template(
    [
        "#36acc6",
        "#f66d7f",
        "#fac364",
        "#7c2695",
        "#228306",
        "#a252f4",
        "#63f0ea",
        "#000000",
        "#6f6f6f",
        "#b7b7b7",
    ]
)

arviz_cetrino_template = _build_template(
    [
        "#009988",
        "#9238b2",
        "#d2225f",
        "#ec8f26",
        "#fcd026",
        "#3cd186",
        "#a57119",
        "#2f5e14",
        "#f225f4",
        "#8f9fbf",
    ]
)

arviz_vibrant_template = _build_template(
    [
        "#008b92",
        "#f15c58",
        "#48cdef",
        "#98d81a",
        "#997ee5",
        "#f5dc9d",
        "#c90a4e",
        "#145393",
        "#323232",
        "#616161",
    ]
)
77 |
--------------------------------------------------------------------------------
/src/arviz_plots/plots/__init__.py:
--------------------------------------------------------------------------------
1 | """Batteries-included ArviZ plots."""
2 |
3 | from .autocorr_plot import plot_autocorr
4 | from .bf_plot import plot_bf
5 | from .combine import combine_plots
6 | from .compare_plot import plot_compare
7 | from .convergence_dist_plot import plot_convergence_dist
8 | from .dist_plot import plot_dist
9 | from .ecdf_plot import plot_ecdf_pit
10 | from .energy_plot import plot_energy
11 | from .ess_plot import plot_ess
12 | from .evolution_plot import plot_ess_evolution
13 | from .forest_plot import plot_forest
14 | from .loo_pit_plot import plot_loo_pit
15 | from .mcse_plot import plot_mcse
16 | from .pairs_focus_plot import plot_pairs_focus
17 | from .pava_calibration_plot import plot_ppc_pava
18 | from .ppc_dist_plot import plot_ppc_dist
19 | from .ppc_pit_plot import plot_ppc_pit
20 | from .ppc_rootogram_plot import plot_ppc_rootogram
21 | from .ppc_tstat import plot_ppc_tstat
22 | from .prior_posterior_plot import plot_prior_posterior
23 | from .psense_dist_plot import plot_psense_dist
24 | from .psense_quantities_plot import plot_psense_quantities
25 | from .rank_dist_plot import plot_rank_dist
26 | from .rank_plot import plot_rank
27 | from .ridge_plot import plot_ridge
28 | from .trace_dist_plot import plot_trace_dist
29 | from .trace_plot import plot_trace
30 | from .utils import add_bands, add_lines
31 |
# Public API, kept in alphabetical order so additions are easy to check
# against the import block above.
__all__ = [
    "add_bands",
    "add_lines",
    "combine_plots",
    "plot_autocorr",
    "plot_bf",
    "plot_compare",
    "plot_convergence_dist",
    "plot_dist",
    "plot_ecdf_pit",
    "plot_energy",
    "plot_ess",
    "plot_ess_evolution",
    "plot_forest",
    "plot_loo_pit",
    "plot_mcse",
    "plot_pairs_focus",
    "plot_ppc_dist",
    "plot_ppc_pava",
    "plot_ppc_pit",
    "plot_ppc_rootogram",
    "plot_ppc_tstat",
    "plot_prior_posterior",
    "plot_psense_dist",
    "plot_psense_quantities",
    "plot_rank",
    "plot_rank_dist",
    "plot_ridge",
    "plot_trace",
    "plot_trace_dist",
]
63 |
--------------------------------------------------------------------------------
/src/arviz_plots/plots/autocorr_plot.py:
--------------------------------------------------------------------------------
1 | """Autocorrelation plot code."""
2 |
3 | from copy import copy
4 | from importlib import import_module
5 |
6 | import numpy as np
7 | from arviz_base import rcParams
8 | from arviz_base.labels import BaseLabeller
9 |
10 | from arviz_plots.plot_collection import PlotCollection
11 | from arviz_plots.plots.utils import filter_aes, process_group_variables_coords, set_wrap_layout
12 | from arviz_plots.visuals import fill_between_y, labelled_title, labelled_x, line, line_xy
13 |
14 |
def plot_autocorr(
    dt,
    var_names=None,
    filter_vars=None,
    group="posterior",
    coords=None,
    sample_dims=None,
    max_lag=None,
    plot_collection=None,
    backend=None,
    labeller=None,
    aes_by_visuals=None,
    visuals=None,
    **pc_kwargs,
):
    """Autocorrelation plots for the given dataset.

    Line plot of the autocorrelation function (ACF).

    The ACF plots can be used as a convergence diagnostic for posteriors from MCMC
    samples.

    Parameters
    ----------
    dt : DataTree
        Input data
    var_names : str or list of str, optional
        One or more variables to be plotted.
        Prefix the variables by ~ when you want to exclude them from the plot.
    filter_vars : {None, “like”, “regex”}, optional, default=None
        If None (default), interpret var_names as the real variables names.
        If “like”, interpret var_names as substrings of the real variables names.
        If “regex”, interpret var_names as regular expressions on the real variables names.
    group : str, optional
        Which group to use. Defaults to "posterior".
    coords : dict, optional
        Coordinates to plot.
    sample_dims : str or sequence of hashable, optional
        Dimensions to reduce unless mapped to an aesthetic.
        Defaults to ``rcParams["data.sample_dims"]``
    max_lag : int, optional
        Maximum lag to compute the ACF. Defaults to 100.
    plot_collection : PlotCollection, optional
    backend : {"matplotlib", "bokeh", "plotly"}, optional
    labeller : labeller, optional
        Used to generate the titles. Defaults to :class:`~arviz_base.labels.BaseLabeller`.
    aes_by_visuals : mapping of {str : sequence of str}, optional
        Mapping of visuals to aesthetics that should use their mapping in `plot_collection`
        when plotted. Valid keys are the same as for `visuals`.

    visuals : mapping of {str : mapping or False}, optional
        Valid keys are:

        * lines -> passed to :func:`~arviz_plots.visuals.line`
        * ref_line -> passed to :func:`~arviz_plots.visuals.line_xy`
        * ci -> passed to :func:`~arviz_plots.visuals.fill_between_y`
        * xlabel -> passed to :func:`~arviz_plots.visuals.labelled_x`
        * title -> passed to :func:`~arviz_plots.visuals.labelled_title`

    pc_kwargs : mapping
        Passed to :class:`arviz_plots.PlotCollection.wrap`

    Returns
    -------
    PlotCollection

    Examples
    --------
    Autocorrelation plot for mu variable in the centered eight dataset.

    .. plot::
        :context: close-figs

        >>> from arviz_plots import plot_autocorr, style
        >>> style.use("arviz-variat")
        >>> from arviz_base import load_arviz_data
        >>> dt = load_arviz_data('centered_eight')
        >>> plot_autocorr(dt, var_names=["mu"])


    .. minigallery:: plot_autocorr

    """
    if sample_dims is None:
        sample_dims = rcParams["data.sample_dims"]
    if isinstance(sample_dims, str):
        sample_dims = [sample_dims]
    sample_dims = list(sample_dims)
    if visuals is None:
        visuals = {}
    else:
        visuals = visuals.copy()

    if backend is None:
        if plot_collection is None:
            backend = rcParams["plot.backend"]
        else:
            backend = plot_collection.backend

    # Only fall back to the default labeller; a user-provided one must be respected.
    if labeller is None:
        labeller = BaseLabeller()

    # Default max lag to 100
    if max_lag is None:
        max_lag = 100

    distribution = process_group_variables_coords(
        dt, group=group, var_names=var_names, filter_vars=filter_vars, coords=coords
    )

    acf_dataset = distribution.azstats.autocorr(dim=sample_dims).sel(draw=slice(0, max_lag - 1))
    # The +/- 1.96/sqrt(n) band depends on the number of draws in the input samples,
    # not on how many lags remain after truncating the ACF to `max_lag`.
    c_i = 1.96 / distribution.sizes["draw"] ** 0.5
    # One x value per lag actually present (fewer than max_lag when there are few draws).
    x_ci = np.arange(0, acf_dataset.sizes["draw"]).astype(float)

    plot_bknd = import_module(f".backend.{backend}", package="arviz_plots")
    default_linestyle = plot_bknd.get_default_aes("linestyle", 2, {})[1]

    if plot_collection is None:
        pc_kwargs["figure_kwargs"] = pc_kwargs.get("figure_kwargs", {}).copy()
        pc_kwargs["aes"] = pc_kwargs.get("aes", {}).copy()
        pc_kwargs.setdefault("col_wrap", 4)
        pc_kwargs.setdefault(
            "cols", ["__variable__"] + [dim for dim in acf_dataset.dims if dim not in sample_dims]
        )

        if "chain" in distribution:
            pc_kwargs["aes"].setdefault("color", ["chain"])
            pc_kwargs["aes"].setdefault("overlay", ["chain"])

        pc_kwargs = set_wrap_layout(pc_kwargs, plot_bknd, distribution)
        pc_kwargs["figure_kwargs"].setdefault("sharex", True)
        pc_kwargs["figure_kwargs"].setdefault("sharey", True)

        plot_collection = PlotCollection.wrap(
            distribution,
            backend=backend,
            **pc_kwargs,
        )

    if aes_by_visuals is None:
        aes_by_visuals = {}
    else:
        aes_by_visuals = aes_by_visuals.copy()
    aes_by_visuals.setdefault("lines", plot_collection.aes_set)

    ## reference line
    ref_ls_kwargs = copy(visuals.get("ref_line", {}))

    if ref_ls_kwargs is not False:
        _, _, ac_ls_ignore = filter_aes(plot_collection, aes_by_visuals, "ref_line", sample_dims)
        ref_ls_kwargs.setdefault("color", "gray")
        ref_ls_kwargs.setdefault("linestyle", default_linestyle)

        plot_collection.map(
            line_xy,
            "ref_line",
            data=acf_dataset,
            x=x_ci,
            y=0,
            ignore_aes=ac_ls_ignore,
            **ref_ls_kwargs,
        )

    ## autocorrelation line
    acf_ls_kwargs = copy(visuals.get("lines", {}))

    if acf_ls_kwargs is not False:
        _, _, ac_ls_ignore = filter_aes(plot_collection, aes_by_visuals, "lines", sample_dims)

        plot_collection.map(
            line,
            "lines",
            data=acf_dataset,
            ignore_aes=ac_ls_ignore,
            **acf_ls_kwargs,
        )

    # Plot confidence intervals
    ci_kwargs = copy(visuals.get("ci", {}))
    _, _, ci_ignore = filter_aes(plot_collection, aes_by_visuals, "ci", "draw")
    if ci_kwargs is not False:
        ci_kwargs.setdefault("color", "black")
        ci_kwargs.setdefault("alpha", 0.1)

        plot_collection.map(
            fill_between_y,
            "ci",
            data=acf_dataset,
            x=x_ci,
            y=0,
            y_bottom=-c_i,
            y_top=c_i,
            ignore_aes=ci_ignore,
            **ci_kwargs,
        )

    # set xlabel
    _, xlabels_aes, xlabels_ignore = filter_aes(
        plot_collection, aes_by_visuals, "xlabel", sample_dims
    )
    xlabel_kwargs = copy(visuals.get("xlabel", {}))
    if xlabel_kwargs is not False:
        if "color" not in xlabels_aes:
            xlabel_kwargs.setdefault("color", "black")

        xlabel_kwargs.setdefault("text", "Lag")
        plot_collection.map(
            labelled_x,
            "xlabel",
            ignore_aes=xlabels_ignore,
            subset_info=True,
            **xlabel_kwargs,
        )

    # title
    title_kwargs = copy(visuals.get("title", {}))
    _, _, title_ignore = filter_aes(plot_collection, aes_by_visuals, "title", sample_dims)

    if title_kwargs is not False:
        plot_collection.map(
            labelled_title,
            "title",
            ignore_aes=title_ignore,
            subset_info=True,
            labeller=labeller,
            **title_kwargs,
        )

    return plot_collection
242 |
--------------------------------------------------------------------------------
/src/arviz_plots/plots/bf_plot.py:
--------------------------------------------------------------------------------
1 | """Contain functions for Bayes Factor plotting."""
2 |
3 | from copy import copy
4 |
5 | import xarray as xr
6 | from arviz_stats.bayes_factor import bayes_factor
7 |
8 | from arviz_plots.plots.prior_posterior_plot import plot_prior_posterior
9 | from arviz_plots.plots.utils import add_lines, filter_aes
10 |
11 |
def plot_bf(
    dt,
    var_names,
    ref_val=0,
    kind=None,
    sample_dims=None,
    plot_collection=None,
    backend=None,
    labeller=None,
    aes_by_visuals=None,
    visuals=None,
    stats=None,
    **pc_kwargs,
):
    r"""Bayes Factor for comparing hypothesis of two nested models.

    The Bayes factor is estimated by comparing a model (H1) against a model
    in which the parameter of interest has been restricted to be a point-null (H0).
    This computation assumes H0 is a special case of H1. For more details see here
    https://arviz-devs.github.io/EABM/Chapters/Model_comparison.html#savagedickey-ratio

    Parameters
    ----------
    dt : DataTree or dict of {str : DataTree}
        Input data. In case of dictionary input, the keys are taken to be model names.
        In such cases, a dimension "model" is generated and can be used to map to aesthetics.
    var_names : str or list of str
        Variables for which the bayes factor will be computed and the prior and
        posterior will be plotted.
    ref_val : int or float, default 0
        Reference (point-null) value for Bayes factor estimation.
        Use ``ref_val=False`` to skip drawing the reference line.
    kind : {"kde", "hist", "dot", "ecdf"}, optional
        How to represent the marginal density.
        Defaults to ``rcParams["plot.density_kind"]``
    sample_dims : str or sequence of hashable, optional
        Dimensions to reduce unless mapped to an aesthetic.
        Defaults to ``rcParams["data.sample_dims"]``
    plot_collection : PlotCollection, optional
    backend : {"matplotlib", "bokeh", "plotly"}, optional
    labeller : labeller, optional
    aes_by_visuals : mapping of {str : sequence of str}, optional
        Mapping of visuals to aesthetics that should use their mapping in `plot_collection`
        when plotted. Valid keys are the same as for `visuals`.
    visuals : mapping of {str : mapping or False}, optional
        Valid keys are:

        * One of "kde", "ecdf", "dot" or "hist", matching the `kind` argument.

          * "kde" -> passed to :func:`~arviz_plots.visuals.line_xy`
          * "ecdf" -> passed to :func:`~arviz_plots.visuals.ecdf_line`
          * "hist" -> passed to :func:`~arviz_plots.visuals.hist`

        * "ref_line" -> passed to :func:`~arviz_plots.add_lines`. Can't be ``False``,
          use ``ref_val=False`` instead.
        * title -> passed to :func:`~arviz_plots.visuals.labelled_title`

    stats : mapping, optional
        Valid keys are:

        * density -> passed to kde, ecdf, ...

    pc_kwargs : mapping
        Passed to :class:`arviz_plots.PlotCollection.wrap`

    Returns
    -------
    PlotCollection

    Raises
    ------
    ValueError
        If ``visuals["ref_line"]`` is ``False``.

    Examples
    --------
    Select one variable.

    .. plot::
        :context: close-figs

        >>> from arviz_plots import plot_bf, style
        >>> style.use("arviz-variat")
        >>> from arviz_base import load_arviz_data
        >>> dt = load_arviz_data('centered_eight')
        >>> plot_bf(dt, var_names="mu", kind="hist")

    .. minigallery:: plot_bf
    """
    if visuals is None:
        visuals = {}
    else:
        visuals = visuals.copy()
    if aes_by_visuals is None:
        aes_by_visuals = {}
    else:
        aes_by_visuals = aes_by_visuals.copy()

    bf, _ = bayes_factor(dt, var_names, ref_val, return_ref_vals=True)

    if isinstance(var_names, str):
        var_names = [var_names]
    # Dummy dataset whose only purpose is to carry the BF01 value as a coordinate
    # so it can be shown as a (text only) legend entry.
    bf_aes_ds = xr.Dataset(
        {
            var: xr.DataArray(
                None,
                coords={"BF_type": [f"BF01:{bf[var]['BF01']:.2f}"]},
                dims=["BF_type"],
            )
            for var in var_names
        }
    )

    plot_collection = plot_prior_posterior(
        dt,
        var_names=var_names,
        coords=None,
        sample_dims=sample_dims,
        kind=kind,
        plot_collection=plot_collection,
        backend=backend,
        labeller=labeller,
        visuals=visuals,
        stats=stats,
        **pc_kwargs,
    )

    plot_collection.update_aes_from_dataset("bf_aes", bf_aes_ds)

    ref_line_kwargs = copy(visuals.get("ref_line", {}))
    if ref_line_kwargs is False:
        raise ValueError(
            "visuals['ref_line'] can't be False, use ref_val=False to remove this element"
        )

    if ref_val is not False:
        _, ref_aes, _ = filter_aes(plot_collection, aes_by_visuals, "ref_line", "sample")
        if "color" not in ref_aes:
            ref_line_kwargs.setdefault("color", "black")
        if "alpha" not in ref_aes:
            ref_line_kwargs.setdefault("alpha", 0.5)
        add_lines(
            plot_collection,
            ref_val,
            aes_by_visuals=aes_by_visuals,
            visuals={"ref_line": ref_line_kwargs},
        )

    # Use the backend actually in use: `backend` may be None (rcParams default) or
    # the plot collection may have been passed in already created.
    if plot_collection.backend == "matplotlib":  ## remove this when we have a better way to handle legends
        plot_collection.add_legend(
            ["__variable__", "BF_type"], loc="upper left", fontsize=10, text_only=True
        )

    return plot_collection
159 |
--------------------------------------------------------------------------------
/src/arviz_plots/plots/combine.py:
--------------------------------------------------------------------------------
1 | """Elements to combine multiple batteries-included plots into a single figure."""
2 | import re
3 | from importlib import import_module
4 |
5 | from arviz_base import rcParams
6 |
7 | from arviz_plots import PlotCollection
8 | from arviz_plots.plot_collection import backend_from_object
9 | from arviz_plots.plots.utils import process_group_variables_coords, set_grid_layout
10 |
11 |
def get_valid_arg(key, value, backend):
    """Convert none backend aesthetic argument indicator to a valid value for the given backend.

    Parameters
    ----------
    key : str
        The keyword part of the :ref:`backend-interface-arguments` for which `value` should
        be valid.
    value : any
        The current value for `key`. It might be an indicator from the none backend such as
        "color_0" or "linestyle_3" which gets processed or something else in which case
        it is assumed to be a valid argument already and returned as is. Only strings that
        are *exactly* ``<key>_<integer>`` are treated as indicators (``facecolor`` and
        ``edgecolor`` use ``color_<integer>`` indicators).
    backend : str
        The backend for which `value` should be valid.

    Returns
    -------
    valid_value : any
    """
    if not isinstance(value, str):
        return value
    key_matcher = "color" if key in {"facecolor", "edgecolor"} else key
    # Match the whole string so values that merely start with an indicator-like prefix
    # (e.g. "color_1abc") are returned untouched instead of being misinterpreted.
    match = re.fullmatch(re.escape(key_matcher) + r"_([0-9]+)", value)
    if match is None:
        return value
    index = int(match.group(1))
    # Only import the backend when a conversion is actually needed.
    plot_backend = import_module(f"arviz_plots.backend.{backend}")
    return plot_backend.get_default_aes(key, index + 1)[index]
39 |
40 |
def backendize_kwargs(kwargs, backend):
    """Turn a none backend visual description into kwargs valid for `backend`.

    The ``"function"`` entry (name of the backend function to call) is dropped;
    every other value goes through :func:`get_valid_arg`.
    """
    converted = {}
    for arg_name, arg_value in kwargs.items():
        if arg_name == "function":
            continue
        converted[arg_name] = get_valid_arg(arg_name, arg_value, backend)
    return converted
48 |
49 |
def render(da, target, **kwargs):
    """Draw a visual description produced by the none backend using a real plotting backend.

    The backend is inferred from `target`. The description stored in `da` names the
    backend function to call under ``"function"``; its remaining entries and `kwargs`
    (which take precedence) are converted with :func:`backendize_kwargs` and forwarded.
    """
    backend_name = backend_from_object(target, return_module=False)
    backend_module = import_module(f"arviz_plots.backend.{backend_name}")
    description = da.item()
    function_name = description["function"]
    call_kwargs = backendize_kwargs(description, backend_name)
    call_kwargs.update(backendize_kwargs(kwargs, backend_name))
    plotting_function = getattr(backend_module, function_name)
    return plotting_function(target=target, **call_kwargs)
59 |
60 |
def combine_plots(
    dt,
    plots,
    var_names=None,
    filter_vars=None,
    group="posterior",
    coords=None,
    sample_dims=None,
    expand="column",
    plot_names=None,
    backend=None,
    **pc_kwargs,
):
    """Arrange multiple batteries-included plots in a customizable column or row layout.

    Parameters
    ----------
    dt : DataTree or dict of {str : DataTree}
        Input data. In case of dictionary input, the keys are taken to be model names.
        In such cases, a dimension "model" is generated and can be used to map to aesthetics.

        Note that not all batteries included functions accept dictionary input, so it will
        only work when all plotting functions requested in `plots` are compatible with it.
    plots : list of tuple of (callable, mapping)
        List of all the plotting functions to be combined. Each element in this list
        is a tuple with two elements. The first is the function to be called, the second
        is a dictionary with any keyword arguments that should be used when calling that function.
    var_names : str or sequence of str, optional
        One or more variables to be plotted.
        Prefix the variables by ~ when you want to exclude them from the plot.
    filter_vars : {None, “like”, “regex”}, default None
        If None, interpret `var_names` as the real variables names.
        If “like”, interpret `var_names` as substrings of the real variables names.
        If “regex”, interpret `var_names` as regular expressions on the real variables names.
    group : str, default "posterior"
        Group to be plotted.
    coords : dict, optional
    sample_dims : str or sequence of hashable, optional
        Dimensions to reduce unless mapped to an aesthetic.
        Defaults to ``rcParams["data.sample_dims"]``
    expand : {"column", "row"}, default "column"
        How to combine the different plotting functions. If "column", each plotting function
        will be added as a new column, if "row" it will be a new row instead.
    plot_names : list of str, optional
        List of the same length as `plots` with the plot names to use as coordinate values
        in the returned :class:`~arviz_plots.PlotCollection`. Defaults to the function
        name followed by its position, e.g. ``"plot_ppc_pit_00"``.
    backend : {"matplotlib", "bokeh", "plotly"}, optional
        Plotting backend to use. Defaults to ``rcParams["plot.backend"]``.
    pc_kwargs : mapping, optional
        Passed to :class:`arviz_plots.PlotCollection.grid`

    Returns
    -------
    PlotCollection

    Raises
    ------
    ValueError
        If `expand` is not "row" or "column", or if `plot_names` and `plots`
        have different lengths.

    Examples
    --------
    Customize the names of the plots in the returned :class:`PlotCollection`

    .. plot::
        :context: close-figs

        >>> import arviz_plots as azp
        >>> azp.style.use("arviz-variat")
        >>> from arviz_base import load_arviz_data
        >>> rugby = load_arviz_data('rugby')
        >>> pc = azp.combine_plots(
        >>>     rugby,
        >>>     plots=[
        >>>         (azp.plot_ppc_pit, {}),
        >>>         (azp.plot_ppc_rootogram, {}),
        >>>     ],
        >>>     group="posterior_predictive",
        >>>     plot_names=["pit", "rootogram"],
        >>> )

    Now if we inspect the ``pc.viz`` attribute, we can see it has a ``column`` dimension
    with the requested coordinate values:

    .. plot::
        :context: close-figs

        >>> pc.viz

    .. minigallery:: combine_plots
    """
    if plot_names is None:
        # Fall back to the type name for callables without ``__name__`` (e.g. functools.partial).
        plot_names = [
            f"{getattr(elem[0], '__name__', type(elem[0]).__name__)}_{idx:02d}"
            for idx, elem in enumerate(plots)
        ]
    else:
        plot_names = list(plot_names)
        if len(plot_names) != len(plots):
            raise ValueError(
                f"`plot_names` must have the same length as `plots` ({len(plots)}) "
                f"but got {len(plot_names)} names"
            )
    if sample_dims is None:
        sample_dims = rcParams["data.sample_dims"]
    if isinstance(sample_dims, str):
        sample_dims = [sample_dims]
    if backend is None:
        backend = rcParams["plot.backend"]

    distribution = process_group_variables_coords(
        dt, group=group, var_names=var_names, filter_vars=filter_vars, coords=coords
    )
    # Predictive groups are faceted by variable only; other groups also by every
    # non-sample dimension.
    facet_dims = ["__variable__"] + (
        []
        if "predictive" in group
        else [dim for dim in distribution.dims if dim not in sample_dims]
    )

    pc_kwargs["figure_kwargs"] = pc_kwargs.get("figure_kwargs", {}).copy()
    if expand == "column":
        pc_kwargs.setdefault("cols", ["column"])
        pc_kwargs.setdefault("rows", facet_dims)
        expand_kwargs = {"column": len(plots)}
    elif expand == "row":
        pc_kwargs.setdefault("cols", facet_dims)
        pc_kwargs.setdefault("rows", ["row"])
        expand_kwargs = {"row": len(plots)}
    else:
        raise ValueError(f"`expand` must be 'row' or 'column' but got '{expand}'")
    # New dimension with one entry per plotting function, labelled with `plot_names`.
    distribution = distribution.expand_dims(**expand_kwargs).assign_coords({expand: plot_names})

    plot_bknd = import_module(f".backend.{backend}", package="arviz_plots")
    pc_kwargs = set_grid_layout(pc_kwargs, plot_bknd, distribution)

    pc = PlotCollection.grid(
        distribution,
        backend=backend,
        **pc_kwargs,
    )

    for name, (plot, kwargs) in zip(plot_names, plots):
        # Generate the visual descriptions with the none backend, then re-render them
        # in the row/column of `pc` reserved for this plot.
        pc_i = plot(
            dt,
            backend="none",
            group=group,
            var_names=var_names,
            filter_vars=filter_vars,
            coords=coords,
            sample_dims=sample_dims,
            **kwargs,
        )
        pc.coords = None
        pc.aes = pc_i.aes
        pc.coords = {expand: name}
        for viz_group, ds in pc_i.viz.children.items():
            if viz_group in {"plot", "row_index", "col_index"}:
                continue
            attrs = ds.attrs
            pc.map(
                render,
                fun_label=f"{viz_group}_{name}",
                data=ds.dataset,
                ignore_aes=attrs.get("ignore_aes", frozenset()),
            )
        pc.coords = None
    # TODO: at some point all `pc_i.aes` objects should be merged
    # and stored into the `pc.aes` attribute

    return pc
218 |
--------------------------------------------------------------------------------
/src/arviz_plots/plots/compare_plot.py:
--------------------------------------------------------------------------------
1 | """Compare plot code."""
2 | from importlib import import_module
3 |
4 | import numpy as np
5 | from arviz_base import rcParams
6 | from xarray import Dataset, DataTree
7 |
8 | from arviz_plots.plot_collection import PlotCollection
9 |
10 |
def plot_compare(
    cmp_df,
    similar_shade=True,
    relative_scale=False,
    backend=None,
    visuals=None,
    **pc_kwargs,
):
    r"""Summary plot for model comparison.

    Models are compared based on their expected log pointwise predictive density (ELPD).

    The ELPD is estimated by Pareto smoothed importance sampling leave-one-out
    cross-validation (PSIS-LOO-CV). Details are presented in [1]_ and [2]_.

    Parameters
    ----------
    cmp_df : pandas.DataFrame
        Result of the :func:`arviz_stats.compare` method.
    similar_shade : bool, optional
        If True, a shade is drawn to indicate models with similar
        predictive performance to the best model. The shade spans from the best
        model's ELPD down to 4 units below it. Defaults to True.
    relative_scale : bool, optional.
        If True scale the ELPD values relative to the best model.
        Defaults to False.
    backend : {"bokeh", "matplotlib", "plotly"}
        Select plotting backend. Defaults to rcParams["plot.backend"].
    visuals : mapping of {str : mapping or False}, optional
        The mapping and its nested mappings are not modified. Valid keys are:

        * point_estimate -> passed to :func:`~.backend.scatter`
        * error_bar -> passed to :func:`~.backend.line`
        * ref_line -> passed to :func:`~.backend.line`
        * shade -> passed to :func:`~.backend.fill_between_y`
        * labels -> passed to :func:`~.backend.xlabel` and :func:`~.backend.ylabel`
        * title -> passed to :func:`~.backend.title`
        * ticklabels -> passed to :func:`~.backend.yticks`

    pc_kwargs : mapping
        Passed to :class:`arviz_plots.PlotCollection`. ``pc_kwargs["figure_kwargs"]``
        may contain ``figsize`` and ``figsize_units``; they are passed to the backend's
        ``scale_fig_size`` together with the number of models.

    Returns
    -------
    PlotCollection

    Raises
    ------
    ValueError
        If `cmp_df` has no ``elpd`` column.

    See Also
    --------
    :func:`arviz_stats.compare`: Summary plot for model comparison.
    :func:`arviz_stats.loo` : Compute the ELPD using Pareto smoothed importance sampling
        Leave-one-out cross-validation method.

    References
    ----------
    .. [1] Vehtari et al. *Practical Bayesian model evaluation using leave-one-out cross-validation
        and WAIC*. Statistics and Computing. 27(5) (2017).
        https://doi.org/10.1007/s11222-016-9696-4. arXiv preprint https://arxiv.org/abs/1507.04544.

    .. [2] Vehtari et al. *Pareto Smoothed Importance Sampling*.
        Journal of Machine Learning Research, 25(72) (2024) https://jmlr.org/papers/v25/19-556.html
        arXiv preprint https://arxiv.org/abs/1507.02646
    """
    # Check if cmp_df contains the required information
    column_index = [c.lower() for c in cmp_df.columns]

    if "elpd" not in column_index:
        raise ValueError(
            "cmp_df must have been created using the `compare` function from ArviZ-Stats."
        )

    # Set default backend
    if backend is None:
        backend = rcParams["plot.backend"]

    # Copy the nested mappings too: the ``setdefault`` calls below must not leak
    # defaults back into the caller's ``visuals`` argument.
    if visuals is None:
        visuals = {}
    else:
        visuals = {key: val if val is False else dict(val) for key, val in visuals.items()}

    # Get plotting backend
    p_be = import_module(f"arviz_plots.backend.{backend}")

    # Get figure params and create figure and axis
    pc_kwargs["figure_kwargs"] = pc_kwargs.get("figure_kwargs", {}).copy()
    figsize = pc_kwargs["figure_kwargs"].get("figsize", None)
    figsize_units = pc_kwargs["figure_kwargs"].get("figsize_units", "inches")

    figsize = p_be.scale_fig_size(
        figsize,
        rows=int(len(cmp_df) ** 0.5),
        cols=2,
        figsize_units=figsize_units,
    )
    # scale_fig_size output is used with "dots" units from here on.
    figsize_units = "dots"

    figure, target = p_be.create_plotting_grid(1, figsize=figsize, figsize_units=figsize_units)

    # Create plot collection
    plot_collection = PlotCollection(
        Dataset({}),
        viz_dt=DataTree.from_dict(
            {"/": Dataset({"figure": np.array(figure, dtype=object), "plot": target})}
        ),
        backend=backend,
        **pc_kwargs,
    )

    if isinstance(target, np.ndarray):
        target = target.tolist()

    # Set scale relative to the best model
    if relative_scale:
        cmp_df = cmp_df.copy()
        cmp_df["elpd"] = cmp_df["elpd"] - cmp_df["elpd"].iloc[0]

    # Compute positions of yticks: first (best) model at the top
    yticks_pos = list(range(len(cmp_df), 0, -1))

    # Plot ELPD standard error bars
    if (error_kwargs := visuals.get("error_bar", {})) is not False:
        error_kwargs.setdefault("color", "black")

        # Compute values for standard error bars
        se_list = list(zip((cmp_df["elpd"] - cmp_df["se"]), (cmp_df["elpd"] + cmp_df["se"])))

        for se_vals, ytick in zip(se_list, yticks_pos):
            p_be.line(se_vals, (ytick, ytick), target, **error_kwargs)

    # Add reference line for the best model
    if (ref_kwargs := visuals.get("ref_line", {})) is not False:
        ref_kwargs.setdefault("color", "gray")
        ref_kwargs.setdefault("linestyle", p_be.get_default_aes("linestyle", 2, {})[-1])
        p_be.line(
            (cmp_df["elpd"].iloc[0], cmp_df["elpd"].iloc[0]),
            (yticks_pos[0], yticks_pos[-1]),
            target,
            **ref_kwargs,
        )

    # Plot ELPD point estimates
    if (pe_kwargs := visuals.get("point_estimate", {})) is not False:
        pe_kwargs.setdefault("color", "black")
        p_be.scatter(cmp_df["elpd"], yticks_pos, target, **pe_kwargs)

    # Add shade for statistically undistinguishable models
    if similar_shade and (shade_kwargs := visuals.get("shade", {})) is not False:
        shade_kwargs.setdefault("color", "black")
        shade_kwargs.setdefault("alpha", 0.1)

        # The shade covers the 4 ELPD units below the best model.
        x_0, x_1 = cmp_df["elpd"].iloc[0] - 4, cmp_df["elpd"].iloc[0]

        padding = (yticks_pos[0] - yticks_pos[-1]) * 0.05
        p_be.fill_between_y(
            x=[x_0, x_1],
            y_bottom=yticks_pos[-1] - padding,
            y_top=yticks_pos[0] + padding,
            target=target,
            **shade_kwargs,
        )

    # Add title and labels
    if (title_kwargs := visuals.get("title", {})) is not False:
        p_be.title(
            "Model comparison\nhigher is better",
            target,
            **title_kwargs,
        )

    if (labels_kwargs := visuals.get("labels", {})) is not False:
        p_be.ylabel("ranked models", target, **labels_kwargs)
        p_be.xlabel("ELPD", target, **labels_kwargs)

    if (ticklabels_kwargs := visuals.get("ticklabels", {})) is not False:
        p_be.yticks(yticks_pos, cmp_df.index, target, **ticklabels_kwargs)

    return plot_collection
187 |
--------------------------------------------------------------------------------
/src/arviz_plots/plots/energy_plot.py:
--------------------------------------------------------------------------------
1 | """Energy plot code."""
2 | import numpy as np
3 | from arviz_base import convert_to_dataset, rcParams
4 |
5 | from arviz_plots.plots.dist_plot import plot_dist
6 |
7 |
def plot_energy(
    dt,
    bfmi=False,
    kind=None,
    plot_collection=None,
    backend=None,
    labeller=None,
    aes_by_visuals=None,
    visuals=None,
    stats=None,
    **pc_kwargs,
):
    r"""Plot transition distribution and marginal energy distribution in HMC algorithms.

    This may help to diagnose poor exploration by gradient-based algorithms like HMC or NUTS.
    The energy function in HMC can identify posteriors with heavy tailed distributions, that
    in practice are challenging for sampling.

    This plot is in the style of the one used in [1]_.

    Parameters
    ----------
    dt : DataTree
        ``sample_stats`` group with an ``energy`` variable is mandatory.
    bfmi : bool
        Whether to plot the value of the estimated Bayesian fraction of missing
        information. Defaults to False. Not implemented yet: ``True`` raises
        :class:`NotImplementedError`.
    kind : {"kde", "hist", "dot", "ecdf"}, optional
        How to represent the marginal density.
        Defaults to ``rcParams["plot.density_kind"]``
    plot_collection : PlotCollection, optional
    backend : {"matplotlib", "bokeh", "plotly"}, optional
    labeller : labeller, optional
    aes_by_visuals : mapping of {str : sequence of str}, optional
        Mapping of visuals to aesthetics that should use their mapping in `plot_collection`
        when plotted. Valid keys are the same as for `visuals`.

    visuals : mapping of {str : mapping or False}, optional
        Valid keys are:

        * One of "kde", "ecdf", "dot" or "hist", matching the `kind` argument.

          * "kde" -> passed to :func:`~arviz_plots.visuals.line_xy`
          * "ecdf" -> passed to :func:`~arviz_plots.visuals.ecdf_line`
          * "hist" -> passed to :func:`~arviz_plots.visuals.hist`

        * title -> passed to :func:`~arviz_plots.visuals.labelled_title`
        * remove_axis -> not passed anywhere, can only be ``False`` to skip calling this function

    stats : mapping, optional
        Valid keys are:

        * density -> passed to kde, ecdf, ...

    pc_kwargs : mapping
        Passed to :class:`arviz_plots.PlotCollection.wrap`

    Returns
    -------
    PlotCollection

    Raises
    ------
    NotImplementedError
        If ``bfmi=True``.
    ValueError
        If the energy dataset lacks a ``chain`` or ``draw`` dimension.

    Examples
    --------
    Plot a default energy plot

    .. plot::
        :context: close-figs

        >>> from arviz_plots import plot_energy, style
        >>> style.use("arviz-variat")
        >>> from arviz_base import load_arviz_data
        >>> schools = load_arviz_data('centered_eight')
        >>> plot_energy(schools)


    .. minigallery:: plot_energy

    References
    ----------
    .. [1] Betancourt. Diagnosing Suboptimal Cotangent Disintegrations in
        Hamiltonian Monte Carlo. (2016) https://arxiv.org/abs/1604.00695
    """
    # Fail fast instead of building the whole figure and then raising.
    if bfmi:
        raise NotImplementedError("BFMI is not implemented yet")

    if kind is None:
        kind = rcParams["plot.density_kind"]
    if visuals is None:
        visuals = {}
    else:
        visuals = visuals.copy()

    new_ds = _get_energy_ds(dt)

    sample_dims = ["chain", "draw"]
    if not all(dim in new_ds.dims for dim in sample_dims):
        raise ValueError("Both 'chain' and 'draw' dimensions must be present in the dataset")

    pc_kwargs.setdefault("cols", None)
    pc_kwargs["aes"] = pc_kwargs.get("aes", {}).copy()
    # Marginal and transition energy are overlaid on the same plot, distinguished by color.
    pc_kwargs["aes"].setdefault("color", ["energy"])
    visuals.setdefault("credible_interval", False)
    visuals.setdefault("point_estimate", False)
    visuals.setdefault("point_estimate_text", False)
    visuals.setdefault("title", False)

    plot_collection = plot_dist(
        new_ds,
        var_names=None,
        filter_vars=None,
        group=None,
        coords=None,
        sample_dims=sample_dims,
        kind=kind,
        point_estimate=None,
        ci_kind=None,
        ci_prob=None,
        plot_collection=plot_collection,
        backend=backend,
        labeller=labeller,
        aes_by_visuals=aes_by_visuals,
        visuals=visuals,
        stats=stats,
        **pc_kwargs,
    )

    plot_collection.add_legend("energy")

    return plot_collection
136 |
137 |
def _get_energy_ds(dt):
    """Build the dataset plotted by :func:`plot_energy`.

    Returns a dataset with a single ``energy_`` variable and an ``energy`` dimension
    with coordinates ``["marginal", "transition"]``:

    * marginal: energy minus its overall mean.
    * transition: difference between consecutive draws (last value is NaN so both
      entries keep the same shape).
    """
    energy = (
        dt["sample_stats"]
        .energy
        # Make the axis order explicit: np.diff below works along the last axis, which
        # must be ``draw`` so transitions are computed within each chain.
        .transpose("chain", "draw", missing_dims="ignore")
        .values
    )
    return convert_to_dataset(
        {"energy_": np.dstack([energy - energy.mean(), np.diff(energy, append=np.nan)])},
        coords={"energy__dim_0": ["marginal", "transition"]},
    ).rename({"energy__dim_0": "energy"})
144 |
--------------------------------------------------------------------------------
/src/arviz_plots/plots/loo_pit_plot.py:
--------------------------------------------------------------------------------
1 | """Plot loo pit."""
2 | from arviz_base import convert_to_datatree
3 | from arviz_stats.loo import loo_pit
4 |
5 | from arviz_plots.plots.ecdf_plot import plot_ecdf_pit
6 |
7 |
def plot_loo_pit(
    dt,
    ci_prob=None,
    coverage=False,
    method="simulation",
    n_simulations=1000,
    var_names=None,
    filter_vars=None,
    group="posterior_predictive",
    coords=None,
    sample_dims=None,  # pylint: disable=unused-argument
    plot_collection=None,
    backend=None,
    labeller=None,
    aes_by_visuals=None,
    visuals=None,
    **pc_kwargs,
):
    r"""LOO-PIT Δ-ECDF values with simultaneous confidence envelope.

    For a calibrated model the LOO Probability Integral Transform (PIT) values,
    $p(\tilde{y}_i \le y_i \mid y_{-i})$, should be uniformly distributed.
    Here $y_i$ represents the observed data for index $i$ and $\tilde y_i$ represents
    the posterior predictive sample at index $i$. $y_{-i}$ indicates we have left out the
    $i$-th observation. LOO-PIT values are computed using the PSIS-LOO-CV method described
    in [1]_ and [2]_.

    This plot shows the empirical cumulative distribution function (ECDF) of the LOO-PIT values.
    To make the plot easier to interpret, we plot the Δ-ECDF, that is, the difference between the
    observed ECDF and the expected CDF. Simultaneous confidence bands are computed using the method
    described in [3]_.

    Alternatively, we can visualize the coverage of the central posterior credible intervals by
    setting ``coverage=True``. This allows us to assess whether the credible intervals include
    the observed values. We can obtain the coverage of the central intervals from the LOO-PIT by
    replacing the LOO-PIT with two times the absolute difference between the LOO-PIT values and 0.5.

    For more details on how to interpret this plot,
    see https://arviz-devs.github.io/EABM/Chapters/Prior_posterior_predictive_checks.html#pit-ecdfs.

    Parameters
    ----------
    dt : DataTree
        Input data. It is passed unchanged to :func:`arviz_stats.loo.loo_pit`.
    ci_prob : float, optional
        Indicates the probability that should be contained within the plotted credible interval.
        Defaults to ``rcParams["stats.ci_prob"]``
    coverage : bool, optional
        If True, plot the coverage of the central posterior credible intervals. Defaults to False.
    method : str, optional
        Method to compute the confidence intervals. Either "simulation" or "optimized".
        Defaults to "simulation".
    n_simulations : int, optional
        Number of simulations used to compute simultaneous confidence intervals when
        ``method="simulation"``; ignored if `method` is "optimized". Defaults to 1000.
    var_names : str or list of str, optional
        One or more variables to be plotted. Currently only one variable is supported.
        Prefix the variables by ~ when you want to exclude them from the plot.
    filter_vars : {None, “like”, “regex”}, optional, default=None
        If None (default), interpret var_names as the real variables names.
        If “like”, interpret var_names as substrings of the real variables names.
        If “regex”, interpret var_names as regular expressions on the real variables names.
    group : str, optional
        Only "posterior_predictive" (the default) is supported.
    coords : dict, optional
        Coordinates to plot.
    sample_dims : str or sequence of hashable, optional
        Currently ignored: every dimension of the computed LOO-PIT values is passed to
        :func:`~arviz_plots.plot_ecdf_pit` as a sample dimension.
    plot_collection : PlotCollection, optional
    backend : {"matplotlib", "bokeh", "plotly"}, optional
    labeller : labeller, optional
    aes_by_visuals : mapping of {str : sequence of str}, optional
        Mapping of visuals to aesthetics that should use their mapping in `plot_collection`
        when plotted. Valid keys are the same as for `visuals`.

    visuals : mapping of {str : mapping or False}, optional
        Valid keys are:

        * ecdf_lines -> passed to :func:`~arviz_plots.visuals.ecdf_line`
        * ci -> passed to :func:`~arviz_plots.visuals.ci_line_y`
        * xlabel -> passed to :func:`~arviz_plots.visuals.labelled_x`. Defaults to
          the text "LOO-PIT".
        * ylabel -> passed to :func:`~arviz_plots.visuals.labelled_y`
        * title -> passed to :func:`~arviz_plots.visuals.labelled_title`

    pc_kwargs : mapping
        Passed to :class:`arviz_plots.PlotCollection.grid`

    Returns
    -------
    PlotCollection

    Raises
    ------
    ValueError
        If `group` is not "posterior_predictive".

    Examples
    --------
    Plot the ecdf-PIT for the radon dataset.

    .. plot::
        :context: close-figs

        >>> from arviz_plots import plot_loo_pit, style
        >>> style.use("arviz-variat")
        >>> from arviz_base import load_arviz_data
        >>> dt = load_arviz_data('radon')
        >>> plot_loo_pit(dt)


    Plot the coverage for the radon dataset.

    .. plot::
        :context: close-figs

        >>> plot_loo_pit(dt, coverage=True)


    .. minigallery:: plot_loo_pit

    References
    ----------
    .. [1] Vehtari et al. Practical Bayesian model evaluation using leave-one-out cross-validation
        and WAIC. Statistics and Computing. 27(5) (2017) https://doi.org/10.1007/s11222-016-9696-4

    .. [2] Vehtari et al. Pareto Smoothed Importance Sampling. Journal of Machine Learning
        Research, 25(72) (2024) https://jmlr.org/papers/v25/19-556.html

    .. [3] Säilynoja et al. *Graphical test for discrete uniformity and
        its applications in goodness-of-fit evaluation and multiple sample comparison*.
        Statistics and Computing 32(32). (2022) https://doi.org/10.1007/s11222-022-10090-6
    """
    # Copy so the defaults set below never leak into the caller's mapping.
    visuals = {} if visuals is None else visuals.copy()

    if group != "posterior_predictive":
        raise ValueError(f"Group {group} not supported. Only 'posterior_predictive' is supported.")

    # Compute the LOO-PIT values and wrap them in a new DataTree under a dedicated
    # "loo_pit" group, which is the group plot_ecdf_pit is asked to read.
    lpv = loo_pit(dt)
    new_dt = convert_to_datatree(lpv, group="loo_pit")

    visuals.setdefault("ylabel", {})
    visuals.setdefault("remove_axis", False)
    visuals.setdefault("xlabel", {"text": "LOO-PIT"})

    plot_collection = plot_ecdf_pit(
        new_dt,
        var_names=var_names,
        filter_vars=filter_vars,
        group="loo_pit",
        coords=coords,
        # Every dimension of the LOO-PIT values is treated as a sample dimension;
        # a plain list of names avoids relying on how `.dims` iterates.
        sample_dims=list(lpv.dims),
        ci_prob=ci_prob,
        coverage=coverage,
        n_simulations=n_simulations,
        method=method,
        plot_collection=plot_collection,
        backend=backend,
        labeller=labeller,
        aes_by_visuals=aes_by_visuals,
        visuals=visuals,
        **pc_kwargs,
    )

    return plot_collection
171 |
--------------------------------------------------------------------------------
/src/arviz_plots/plots/pairs_focus_plot.py:
--------------------------------------------------------------------------------
1 | """Pair focus plot code."""
2 | from copy import copy
3 | from importlib import import_module
4 |
5 | import numpy as np
6 | from arviz_base import rcParams
7 | from arviz_base.labels import BaseLabeller
8 |
9 | from arviz_plots.plot_collection import PlotCollection
10 | from arviz_plots.plots.utils import (
11 | filter_aes,
12 | get_group,
13 | process_group_variables_coords,
14 | set_wrap_layout,
15 | )
16 | from arviz_plots.visuals import divergence_scatter, labelled_x, labelled_y, scatter_x
17 |
18 |
def plot_pairs_focus(
    dt,
    var_names=None,
    filter_vars=None,
    group="posterior",
    coords=None,
    sample_dims=None,
    focus_var=None,
    focus_var_coords=None,
    plot_collection=None,
    backend=None,
    labeller=None,
    aes_by_visuals=None,
    visuals=None,
    **pc_kwargs,
):
    """Plot a fixed variable against other variables in the dataset.

    Every panel is a scatter plot with the values of `focus_var` on the y axis and the
    values of one of the selected variables (or one coordinate subset of it) on the x axis.

    Parameters
    ----------
    dt : DataTree
        Input data
    var_names: str or list of str, optional
        One or more variables to be plotted.
        Prefix the variables by ~ when you want to exclude them from the plot.
        Defaults to every variable of `group` except `focus_var`.
    filter_vars: {None, “like”, “regex”}, optional, default=None
        If None (default), interpret var_names as the real variables names.
        If “like”, interpret var_names as substrings of the real variables names.
        If “regex”, interpret var_names as regular expressions on the real variables names.
    group : str, optional
        Group to use for plotting, both for the variables in `var_names` and for
        `focus_var`. Defaults to "posterior".
    coords : mapping, optional
        Coordinates to use for plotting. Defaults to None.
    sample_dims : iterable, optional
        Dimensions to reduce unless mapped to an aesthetic.
        Defaults to ``rcParams["data.sample_dims"]``
    focus_var: str
        Name of the variable to be plotted against all other variables. Required.
    focus_var_coords : mapping, optional
        Coordinates to use for the focus variable. Defaults to None.
    plot_collection : PlotCollection, optional
    backend : {"matplotlib", "bokeh", "plotly"}, optional
    labeller : labeller, optional
        Used for the x axis labels. Defaults to :class:`arviz_base.labels.BaseLabeller`.
    aes_by_visuals : mapping, optional
        Mapping of visuals to aesthetics that should use their mapping in `plot_collection`
        when plotted. Valid keys are the same as for `visuals`.
    visuals : mapping of {str : mapping or False}, optional
        Valid keys are:

        * scatter -> passed to :func:`~.visuals.scatter_x`
        * divergence -> passed to :func:`~.visuals.divergence_scatter`. Defaults to False;
          use True or a mapping to enable it. Only drawn when the ``sample_stats`` group
          has a ``diverging`` variable with at least one divergence.
        * xlabel -> passed to :func:`~.visuals.labelled_x`
        * ylabel -> passed to :func:`~.visuals.labelled_y`

    pc_kwargs : mapping
        Passed to :class:`arviz_plots.PlotCollection`

    Returns
    -------
    PlotCollection

    Raises
    ------
    ValueError
        If `focus_var` is not provided.

    Examples
    --------
    Default plot_pair_focus

    .. plot::
        :context: close-figs

        >>> from arviz_plots import plot_pairs_focus, style
        >>> style.use("arviz-variat")
        >>> from arviz_base import load_arviz_data
        >>> dt = load_arviz_data('centered_eight')
        >>> plot_pairs_focus(
        >>>     dt,
        >>>     var_names=["theta", "tau"],
        >>>     focus_var="mu",
        >>> )

    """
    if focus_var is None:
        raise ValueError("`focus_var` must be provided to plot_pairs_focus.")
    if sample_dims is None:
        sample_dims = rcParams["data.sample_dims"]
    if isinstance(sample_dims, str):
        sample_dims = [sample_dims]
    if visuals is None:
        visuals = {}
    pc_kwargs = pc_kwargs.copy()
    # Copy so the defaults added below never leak into the caller's mapping.
    aes_by_visuals = {} if aes_by_visuals is None else aes_by_visuals.copy()

    if backend is None:
        backend = rcParams["plot.backend"] if plot_collection is None else plot_collection.backend

    if var_names is None:
        # Never plot the focus variable against itself by default.
        var_names = "~" + focus_var

    distribution = process_group_variables_coords(
        dt, group=group, var_names=var_names, filter_vars=filter_vars, coords=coords
    )

    plot_bknd = import_module(f".backend.{backend}", package="arviz_plots")

    if plot_collection is None:
        pc_kwargs.setdefault(
            "cols", ["__variable__"] + [dim for dim in distribution.dims if dim not in sample_dims]
        )
        pc_kwargs["figure_kwargs"] = pc_kwargs.get("figure_kwargs", {}).copy()
        pc_kwargs = set_wrap_layout(pc_kwargs, plot_bknd, distribution)
        pc_kwargs["aes"] = pc_kwargs.get("aes", {}).copy()
        if "chain" in distribution:
            pc_kwargs["aes"].setdefault("overlay", ["chain"])
        # All panels share the focus variable on the y axis.
        pc_kwargs["figure_kwargs"].setdefault("sharey", True)
        plot_collection = PlotCollection.wrap(
            distribution,
            backend=backend,
            **pc_kwargs,
        )

    # Focus variable, read from the same group as the other variables so both axes
    # come from matching samples.
    y = get_group(dt, group)[focus_var]
    if focus_var_coords is not None:
        y = y.sel(focus_var_coords)

    # scatter
    scatter_kwargs = copy(visuals.get("scatter", {}))
    if scatter_kwargs is True:
        scatter_kwargs = {}
    if scatter_kwargs is not False:
        aes_by_visuals["scatter"] = {"overlay"}.union(aes_by_visuals.get("scatter", set()))
        scatter_kwargs.setdefault("alpha", 0.5)
        # Default to the first color of the backend's default color cycle.
        colors = plot_bknd.get_default_aes("color", 1, {})
        scatter_kwargs.setdefault("color", colors[0])
        scatter_kwargs.setdefault("width", 0)
        _, _, scatter_ignore = filter_aes(plot_collection, aes_by_visuals, "scatter", sample_dims)
        plot_collection.map(
            scatter_x,
            "scatter",
            ignore_aes=scatter_ignore,
            y=y,
            **scatter_kwargs,
        )

    # divergence (opt-in)
    aes_by_visuals["divergence"] = {"overlay"}.union(aes_by_visuals.get("divergence", set()))
    div_kwargs = copy(visuals.get("divergence", False))
    if div_kwargs is True:
        div_kwargs = {}
    sample_stats = get_group(dt, "sample_stats", allow_missing=True)
    if (
        div_kwargs is not False
        and sample_stats is not None
        and "diverging" in sample_stats.data_vars
        and np.any(sample_stats.diverging)
    ):
        divergence_mask = sample_stats.diverging
        _, div_aes, div_ignore = filter_aes(
            plot_collection, aes_by_visuals, "divergence", sample_dims
        )
        if "color" not in div_aes:
            div_kwargs.setdefault("color", "black")
        div_kwargs.setdefault("alpha", 0.4)
        plot_collection.map(
            divergence_scatter,
            "divergence",
            ignore_aes=div_ignore,
            y=y,
            mask=divergence_mask,
            **div_kwargs,
        )

    if labeller is None:
        labeller = BaseLabeller()

    # xlabel of plots: labelled from each panel's variable/coordinate subset
    xlabel_kwargs = copy(visuals.get("xlabel", {}))
    if xlabel_kwargs is not False:
        _, _, xlabel_ignore = filter_aes(plot_collection, aes_by_visuals, "xlabel", sample_dims)
        plot_collection.map(
            labelled_x,
            "xlabel",
            subset_info=True,
            ignore_aes=xlabel_ignore,
            labeller=labeller,
            **xlabel_kwargs,
        )

    # ylabel of plots: always the focus variable name
    ylabel_kwargs = copy(visuals.get("ylabel", {}))
    if ylabel_kwargs is not False:
        _, _, ylabel_ignore = filter_aes(plot_collection, aes_by_visuals, "ylabel", sample_dims)
        plot_collection.map(
            labelled_y,
            "ylabel",
            ignore_aes=ylabel_ignore,
            text=focus_var,
            **ylabel_kwargs,
        )

    return plot_collection
223 |
--------------------------------------------------------------------------------
/src/arviz_plots/plots/prior_posterior_plot.py:
--------------------------------------------------------------------------------
1 | """Contain functions for plotting prior and posterior marginal densities."""
2 |
3 | from importlib import import_module
4 |
5 | import numpy as np
6 | from arviz_base import extract, rcParams
7 | from xarray import concat
8 |
9 | from arviz_plots.plot_collection import PlotCollection
10 | from arviz_plots.plots.dist_plot import plot_dist
11 | from arviz_plots.plots.utils import process_group_variables_coords, set_wrap_layout
12 |
13 |
def plot_prior_posterior(
    dt,
    var_names=None,
    filter_vars=None,
    group=None,  # pylint: disable=unused-argument
    coords=None,
    sample_dims=None,
    kind=None,
    plot_collection=None,
    backend=None,
    labeller=None,
    aes_by_visuals=None,
    visuals=None,
    stats=None,
    **pc_kwargs,
):
    r"""Plot 1D marginal densities for prior and posterior.

    The ``prior`` and ``posterior`` groups are subsampled with a fixed random seed to a
    common number of samples (the smaller of the two group sizes), concatenated along a
    new ``group`` dimension and drawn with :func:`~arviz_plots.plot_dist`. When no
    `plot_collection` is provided, the ``color`` aesthetic is mapped to ``group`` by
    default. A legend for ``group`` is always added.

    Parameters
    ----------
    dt : DataTree
        Input data. It must contain both a ``prior`` and a ``posterior`` group.
    var_names : str or list of str, optional
        One or more variables to be plotted.
        Prefix the variables by ~ when you want to exclude them from the plot.
    filter_vars : {None, “like”, “regex”}, default=None
        If None, interpret var_names as the real variables names.
        If “like”, interpret var_names as substrings of the real variables names.
        If “regex”, interpret var_names as regular expressions on the real variables names.
    group : None
        This argument is ignored. Have it here for compatibility with other plotting functions.
    coords : dict, optional
    sample_dims : str or sequence of hashable, optional
        Dimensions to reduce unless mapped to an aesthetic.
        Defaults to ``rcParams["data.sample_dims"]``
    kind : {"kde", "hist", "dot", "ecdf"}, optional
        How to represent the marginal density.
        Defaults to ``rcParams["plot.density_kind"]``
    plot_collection : PlotCollection, optional
    backend : {"matplotlib", "bokeh", "plotly"}, optional
    labeller : labeller, optional
    aes_by_visuals : mapping of {str : sequence of str}, optional
        Mapping of visuals to aesthetics that should use their mapping in `plot_collection`
        when plotted. Valid keys are the same as for `visuals`.
    visuals : mapping of {str : mapping or False}, optional
        Valid keys are:

        * One of "kde", "ecdf", "dot" or "hist", matching the `kind` argument.

          * "kde" -> passed to :func:`~arviz_plots.visuals.line_xy`
          * "ecdf" -> passed to :func:`~arviz_plots.visuals.ecdf_line`
          * "hist" -> passed to :func:`~arviz_plots.visuals.hist`. Drawn as a step
            histogram unless ``step`` is given.

        * title -> passed to :func:`~arviz_plots.visuals.labelled_title`

        ``credible_interval``, ``point_estimate`` and ``point_estimate_text`` are
        disabled by default.

    stats : mapping, optional
        Valid keys are:

        * density -> passed to kde, ecdf, ... For ``kind="hist"`` it defaults to
          ``{"density": True}``.

    pc_kwargs : mapping
        Passed to :class:`arviz_plots.PlotCollection.wrap`

    Returns
    -------
    PlotCollection

    Examples
    --------
    Select two variables and plot them with a ecdf.

    .. plot::
        :context: close-figs

        >>> from arviz_plots import plot_prior_posterior, style
        >>> style.use("arviz-variat")
        >>> from arviz_base import load_arviz_data
        >>> dt = load_arviz_data('centered_eight')
        >>> plot_prior_posterior(dt, var_names=["mu", "tau"], kind="ecdf")


    .. minigallery:: plot_prior_posterior
    """
    if sample_dims is None:
        sample_dims = rcParams["data.sample_dims"]
    if isinstance(sample_dims, str):
        sample_dims = [sample_dims]
    sample_dims = list(sample_dims)
    if kind is None:
        kind = rcParams["plot.density_kind"]
    # Copies so the defaults set below never leak into the caller's mappings.
    stats = {} if stats is None else stats.copy()
    visuals = {} if visuals is None else visuals.copy()
    if not isinstance(visuals, dict):
        visuals = {}

    if backend is None:
        backend = rcParams["plot.backend"] if plot_collection is None else plot_collection.backend

    plot_bknd = import_module(f".backend.{backend}", package="arviz_plots")

    prior_size = np.prod([dt.prior.sizes[dim] for dim in sample_dims])
    posterior_size = np.prod([dt.posterior.sizes[dim] for dim in sample_dims])
    # Both groups get the same number of samples (and the same 0..n-1 "sample"
    # coordinate) so they can be concatenated along the new "group" dimension.
    num_samples = min(prior_size, posterior_size)

    ds_prior = _subsample_group(dt, "prior", sample_dims, num_samples)
    ds_posterior = _subsample_group(dt, "posterior", sample_dims, num_samples)

    distribution = concat([ds_prior, ds_posterior], dim="group").assign_coords(
        {"group": ["prior", "posterior"]}
    )

    distribution = process_group_variables_coords(
        distribution,
        group=None,
        var_names=var_names,
        filter_vars=filter_vars,
        coords=coords,
    )

    if len(sample_dims) > 1:
        # sample dims will have been stacked and renamed by `extract`
        sample_dims = ["sample"]

    if plot_collection is None:
        pc_kwargs["figure_kwargs"] = pc_kwargs.get("figure_kwargs", {}).copy()

        pc_kwargs["aes"] = pc_kwargs.get("aes", {}).copy()
        pc_kwargs["aes"].setdefault("color", ["group"])
        pc_kwargs.setdefault("col_wrap", 4)
        pc_kwargs.setdefault(
            "cols",
            ["__variable__"]
            + [dim for dim in distribution.dims if dim not in sample_dims + ["group"]],
        )

        pc_kwargs = set_wrap_layout(pc_kwargs, plot_bknd, distribution)

        plot_collection = PlotCollection.wrap(
            distribution,
            backend=backend,
            **pc_kwargs,
        )

    visuals.setdefault("credible_interval", False)
    visuals.setdefault("point_estimate", False)
    visuals.setdefault("point_estimate_text", False)

    aes_by_visuals = {} if aes_by_visuals is None else aes_by_visuals.copy()

    if kind == "hist":
        visuals.setdefault("hist", {})
        visuals.setdefault("remove_axis", True)
        if visuals["hist"] is not False:
            # Copy the nested mapping before adding defaults so the caller's dict
            # is not modified.
            visuals["hist"] = dict(visuals["hist"])
            visuals["hist"].setdefault("step", True)
        stats.setdefault("density", {"density": True})

    plot_collection = plot_dist(
        distribution,
        var_names=None,
        group=None,
        coords=None,
        sample_dims=sample_dims,
        kind=kind,
        point_estimate=None,
        ci_kind=None,
        ci_prob=None,
        plot_collection=plot_collection,
        backend=backend,
        labeller=labeller,
        aes_by_visuals=aes_by_visuals,
        visuals=visuals,
        stats=stats,
        **pc_kwargs,
    )

    plot_collection.add_legend("group")

    return plot_collection


def _subsample_group(dt, group, sample_dims, num_samples):
    """Extract `num_samples` combined samples of `group` (seed 0), re-indexed as 0..n-1."""
    return (
        extract(
            dt,
            group=group,
            sample_dims=sample_dims,
            num_samples=num_samples,
            random_seed=0,
            keep_dataset=True,
        )
        .drop_vars(sample_dims + ["sample"])
        .assign_coords(sample=("sample", np.arange(num_samples)))
    )
221 |
--------------------------------------------------------------------------------
/src/arviz_plots/plots/trace_plot.py:
--------------------------------------------------------------------------------
1 | """Trace plot code."""
2 | from copy import copy
3 | from importlib import import_module
4 |
5 | import numpy as np
6 | from arviz_base import rcParams
7 | from arviz_base.labels import BaseLabeller
8 |
9 | from arviz_plots.plot_collection import PlotCollection
10 | from arviz_plots.plots.utils import (
11 | filter_aes,
12 | get_group,
13 | process_group_variables_coords,
14 | set_wrap_layout,
15 | )
16 | from arviz_plots.visuals import labelled_title, labelled_x, line, ticklabel_props, trace_rug
17 |
18 |
def plot_trace(
    dt,
    var_names=None,
    filter_vars=None,
    group="posterior",
    coords=None,
    sample_dims=None,
    plot_collection=None,
    backend=None,
    labeller=None,
    aes_by_visuals=None,
    visuals=None,
    **pc_kwargs,
):
    """Plot iteration versus sampled values.

    Parameters
    ----------
    dt : DataTree
        Input data
    var_names: str or list of str, optional
        One or more variables to be plotted.
        Prefix the variables by ~ when you want to exclude them from the plot.
    filter_vars: {None, “like”, “regex”}, optional, default=None
        If None (default), interpret var_names as the real variables names.
        If “like”, interpret var_names as substrings of the real variables names.
        If “regex”, interpret var_names as regular expressions on the real variables names.
    group : str, optional
        Group to use for plotting. Defaults to "posterior".
    coords : mapping, optional
        Coordinates to use for plotting. Defaults to None.
    sample_dims : iterable, optional
        Dimensions to reduce unless mapped to an aesthetic.
        Defaults to ``rcParams["data.sample_dims"]``
    plot_collection : PlotCollection, optional
    backend : {"matplotlib", "bokeh", "plotly"}, optional
    labeller : labeller, optional
    aes_by_visuals : mapping, optional
        Mapping of visuals to aesthetics that should use their mapping in `plot_collection`
        when plotted. Defaults to only mapping properties to the trace lines.
    visuals : mapping of {str : mapping or False}, optional
        Valid keys are:

        * trace -> passed to :func:`~.visuals.line`
        * divergence -> passed to :func:`~.visuals.trace_rug`. True means default
          properties. Only drawn when the ``sample_stats`` group has a ``diverging``
          variable with at least one divergence.
        * title -> :func:`~.visuals.labelled_title`
        * xlabel -> :func:`~.visuals.labelled_x`
        * ticklabels -> :func:`~.visuals.ticklabel_props`

    pc_kwargs : mapping
        Passed to :class:`arviz_plots.PlotCollection`

    Returns
    -------
    PlotCollection

    Examples
    --------
    The following examples focus on behaviour specific to ``plot_trace``.
    For a general introduction to batteries-included functions like this one and common
    usage examples see :ref:`plots_intro`

    Default plot_trace

    .. plot::
        :context: close-figs

        >>> from arviz_plots import plot_trace, style
        >>> style.use("arviz-variat")
        >>> from arviz_base import load_arviz_data
        >>> centered = load_arviz_data('centered_eight')
        >>> plot_trace(centered)

    """
    if sample_dims is None:
        sample_dims = rcParams["data.sample_dims"]
    if isinstance(sample_dims, str):
        sample_dims = [sample_dims]
    if visuals is None:
        visuals = {}

    distribution = process_group_variables_coords(
        dt, group=group, var_names=var_names, filter_vars=filter_vars, coords=coords
    )

    if backend is None:
        if plot_collection is None:
            backend = rcParams["plot.backend"]
        else:
            backend = plot_collection.backend

    plot_bknd = import_module(f".backend.{backend}", package="arviz_plots")

    if plot_collection is None:
        pc_kwargs["aes"] = pc_kwargs.get("aes", {}).copy()
        if "chain" in distribution:
            pc_kwargs["aes"].setdefault("color", ["chain"])
            pc_kwargs["aes"].setdefault("overlay", ["chain"])
        pc_kwargs.setdefault(
            "cols", ["__variable__"] + [dim for dim in distribution.dims if dim not in sample_dims]
        )
        pc_kwargs["figure_kwargs"] = pc_kwargs.get("figure_kwargs", {}).copy()
        # Dimensions that get their own panel; everything else is reduced for the rug.
        aux_dim_list = [dim for dim in pc_kwargs["cols"] if dim != "__variable__"]
        pc_kwargs = set_wrap_layout(pc_kwargs, plot_bknd, distribution)
        pc_kwargs["figure_kwargs"].setdefault("sharex", True)
        plot_collection = PlotCollection.wrap(
            distribution,
            backend=backend,
            **pc_kwargs,
        )
    else:
        aux_dim_list = list(plot_collection.viz["plot"].dims)

    if aes_by_visuals is None:
        aes_by_visuals = {}
    else:
        aes_by_visuals = aes_by_visuals.copy()
    aes_by_visuals.setdefault("trace", plot_collection.aes_set)
    aes_by_visuals.setdefault("divergence", {"overlay"})

    if labeller is None:
        labeller = BaseLabeller()

    # trace
    trace_kwargs = copy(visuals.get("trace", {}))
    if trace_kwargs is False:
        xname = None
    else:
        # Use the (single) sample dim, or "draw", as x axis only if it is numeric.
        default_xname = sample_dims[0] if len(sample_dims) == 1 else "draw"
        if (default_xname not in distribution.dims) or (
            not np.issubdtype(distribution[default_xname].dtype, np.number)
        ):
            default_xname = None
        xname = trace_kwargs.get("xname", default_xname)
        trace_kwargs["xname"] = xname
        _, _, trace_ignore = filter_aes(plot_collection, aes_by_visuals, "trace", sample_dims)
        plot_collection.map(
            line,
            "trace",
            data=distribution,
            ignore_aes=trace_ignore,
            **trace_kwargs,
        )

    # divergences
    sample_stats = get_group(dt, "sample_stats", allow_missing=True)
    divergence_kwargs = copy(visuals.get("divergence", {}))
    if divergence_kwargs is True:
        # Same convention as plot_pairs_focus: True means "draw with defaults".
        divergence_kwargs = {}
    if (
        sample_stats is not None
        and "diverging" in sample_stats.data_vars
        and np.any(sample_stats.diverging)
        and divergence_kwargs is not False
    ):
        divergence_mask = dt.sample_stats.diverging
        _, div_aes, div_ignore = filter_aes(
            plot_collection, aes_by_visuals, "divergence", sample_dims
        )
        if "color" not in div_aes:
            divergence_kwargs.setdefault("color", "black")
        if "marker" not in div_aes:
            divergence_kwargs.setdefault("marker", "|")
        if "size" not in div_aes:
            divergence_kwargs.setdefault("size", 30)
        # The rug is placed at the minimum of each panel's data.
        div_reduce_dims = [dim for dim in distribution.dims if dim not in aux_dim_list]

        plot_collection.map(
            trace_rug,
            "divergence",
            data=distribution,
            ignore_aes=div_ignore,
            xname=xname,
            y=distribution.min(div_reduce_dims),
            mask=divergence_mask,
            **divergence_kwargs,
        )

    # aesthetics
    title_kwargs = copy(visuals.get("title", {}))
    if title_kwargs is not False:
        _, title_aes, title_ignore = filter_aes(
            plot_collection, aes_by_visuals, "title", sample_dims
        )
        if "color" not in title_aes:
            title_kwargs.setdefault("color", "black")
        plot_collection.map(
            labelled_title,
            "title",
            ignore_aes=title_ignore,
            subset_info=True,
            labeller=labeller,
            **title_kwargs,
        )

    # Add "Steps" as x_label for trace
    xlabel_kwargs = copy(visuals.get("xlabel", {}))
    if xlabel_kwargs is not False:
        _, xlabel_aes, xlabel_ignore = filter_aes(
            plot_collection, aes_by_visuals, "xlabel", sample_dims
        )

        if "color" not in xlabel_aes:
            xlabel_kwargs.setdefault("color", "black")

        plot_collection.map(
            labelled_x,
            "xlabel",
            ignore_aes=xlabel_ignore,
            text="Steps" if xname is None else xname.capitalize(),
            **xlabel_kwargs,
        )

    # Adjust tick labels
    ticklabels_kwargs = copy(visuals.get("ticklabels", {}))
    if ticklabels_kwargs is not False:
        _, _, ticklabels_ignore = filter_aes(
            plot_collection, aes_by_visuals, "ticklabels", sample_dims
        )
        plot_collection.map(
            ticklabel_props,
            "ticklabels",
            ignore_aes=ticklabels_ignore,
            axis="both",
            store_artist=backend == "none",
            **ticklabels_kwargs,
        )

    return plot_collection
242 |
--------------------------------------------------------------------------------
/src/arviz_plots/py.typed:
--------------------------------------------------------------------------------
1 | # Marker file for PEP 561
2 |
--------------------------------------------------------------------------------
/src/arviz_plots/style.py:
--------------------------------------------------------------------------------
1 | """Style/templating helpers."""
2 | import os
3 |
4 | from arviz_base import rcParams
5 |
6 |
def use(name):
    """Make an arviz style the default style/template for every available backend.

    Each backend is tried in turn (matplotlib, plotly, bokeh); backends that are not
    installed are skipped. Bokeh themes are only attempted for the bundled
    "arviz-cetrino", "arviz-variat" and "arviz-vibrant" styles and are read from the
    ``styles`` folder next to this module.

    Parameters
    ----------
    name : str
        Name of the style to be set as default.

    Raises
    ------
    ValueError
        If the style could not be applied to any backend.
    """

    def _apply_matplotlib():
        import matplotlib.pyplot as plt

        if name not in plt.style.available:
            return False
        plt.style.use(name)
        return True

    def _apply_plotly():
        import plotly.io as pio

        if name not in pio.templates:
            return False
        pio.templates.default = name
        return True

    def _apply_bokeh():
        if name not in ["arviz-cetrino", "arviz-variat", "arviz-vibrant"]:
            return False
        from bokeh.io import curdoc
        from bokeh.themes import Theme

        module_dir = os.path.dirname(os.path.abspath(__file__))
        curdoc().theme = Theme(filename=f"{module_dir}/styles/{name}.yml")
        return True

    appliers = (
        (_apply_matplotlib, ImportError),
        (_apply_plotly, ImportError),
        (_apply_bokeh, (ImportError, FileNotFoundError)),
    )
    applied = []
    for applier, tolerated in appliers:
        try:
            applied.append(applier())
        except tolerated:
            applied.append(False)

    if not any(applied):
        raise ValueError(f"Style {name} not found.")
48 |
49 |
def available():
    """List available styles.

    Returns
    -------
    dict
        Maps each installed backend among "matplotlib" and "plotly" to the names of its
        registered styles/templates. Backends that cannot be imported are omitted.
    """

    def _matplotlib_styles():
        import matplotlib.pyplot as plt

        return plt.style.available

    def _plotly_templates():
        import plotly.io as pio

        return list(pio.templates)

    styles = {}
    for backend_name, lister in (("matplotlib", _matplotlib_styles), ("plotly", _plotly_templates)):
        try:
            styles[backend_name] = lister()
        except ImportError:
            continue
    return styles
69 |
70 |
def get(name, backend=None):
    """Get the style/template with the given name.

    Parameters
    ----------
    name : str
        Name of the style/template to get.
    backend : str
        Name of the backend to use. Options are 'matplotlib' and 'plotly'.
        Defaults to ``rcParams["plot.backend"]``.

    Returns
    -------
    The matplotlib style dictionary or the plotly template registered under `name`.

    Raises
    ------
    ValueError
        If the backend is not supported or the style is not registered for it.
    """
    backend = rcParams["plot.backend"] if backend is None else backend

    if backend == "matplotlib":
        import matplotlib.pyplot as plt

        if name in plt.style.available:
            return plt.style.library[name]
    elif backend == "plotly":
        import plotly.io as pio

        if name in pio.templates:
            return pio.templates[name]
    else:
        raise ValueError(f"Default styles/templates are not supported for Backend {backend}")

    raise ValueError(f"Style {name} not found.")
100 |
--------------------------------------------------------------------------------
/src/arviz_plots/styles/arviz-cetrino.mplstyle:
--------------------------------------------------------------------------------
1 | ## ***************************************************************************
2 | ## * FIGURE *
3 | ## ***************************************************************************
4 |
5 | figure.facecolor: white # white figure background outside the axes
6 | figure.edgecolor: None # no edge around the figure
7 | figure.titleweight: bold # weight of the figure title
8 | figure.titlesize: x-large
9 |
10 | figure.figsize: 6, 5
11 | figure.dpi: 200.0
12 | figure.constrained_layout.use: True
13 |
14 | ## ***************************************************************************
15 | ## * FONT *
16 | ## ***************************************************************************
17 |
18 | font.size: 12
19 | font.style: normal
20 | font.variant: normal
21 | font.weight: normal
22 | font.stretch: normal
23 |
24 | text.color: .15
25 |
26 | ## ***************************************************************************
27 | ## * AXES *
28 | ## ***************************************************************************
29 |
30 | axes.facecolor: white
31 | axes.edgecolor: .33 # axes edge color
32 | axes.linewidth: 0.8 # edge line width
33 |
34 | axes.grid: False # do not show grid
35 | # axes.grid.axis: y # which axis the grid should apply to
36 | axes.grid.which: major # grid lines at {major, minor, both} ticks
37 | axes.axisbelow: True # keep grid layer in the back
38 |
39 | grid.color: .8 # grid color
40 | grid.linestyle: - # solid
41 | grid.linewidth: 0.8 # in points
42 | grid.alpha: 1.0 # transparency, between 0.0 and 1.0
43 |
44 | lines.solid_capstyle: round
45 |
46 | axes.spines.right: False # do not show right spine
47 | axes.spines.top: False # do not show top spine
48 |
49 | axes.titlesize: 16
50 | axes.titleweight: bold # font weight of title
51 |
52 | axes.labelsize: large
53 | axes.labelcolor: .15
54 | axes.labelweight: normal # weight of the x and y labels
55 |
56 | # color-blind friendly cycle designed using https://colorcyclepicker.mpetroff.net/
57 | # see preview and check for colorblindness here https://coolors.co/009988-9238b2-d2225f-ec8f26-fcd026-3cd186-a57119-2f5e14-f225f4-8f9fbf
58 | axes.prop_cycle: cycler(color=["009988", "9238b2", "d2225f", "ec8f26", "fcd026", "3cd186", "a57119", "2f5e14", "f225f4", "8f9fbf"])
59 |
60 | image.cmap: viridis
61 |
62 | ## ***************************************************************************
63 | ## * TICKS *
64 | ## ***************************************************************************
65 |
66 | xtick.labelsize: large
67 | xtick.color: .15
68 | xtick.top: False
69 | xtick.bottom: True
70 | xtick.direction: out
71 |
72 | ytick.labelsize: large
73 | ytick.color: .15
74 | ytick.left: True
75 | ytick.right: False
76 | ytick.direction: out
77 |
78 | ## ***************************************************************************
79 | ## * LEGEND *
80 | ## ***************************************************************************
81 |
82 | legend.framealpha: 0.5
83 | legend.frameon: False # do not draw on background patch
84 | legend.fancybox: False # do not round corners
85 |
86 | legend.numpoints: 1
87 | legend.scatterpoints: 1
88 |
89 | legend.fontsize: large
--------------------------------------------------------------------------------
/src/arviz_plots/styles/arviz-cetrino.yml:
--------------------------------------------------------------------------------
1 | attrs:
2 | Plot:
3 | background_fill_color: white
4 | border_fill_color: white
5 | outline_line_width: 0
6 | outline_line_color: null
7 | Axis:
8 | major_tick_line_color: '#262626'
9 | minor_tick_line_alpha: 0
10 | Grid:
11 | grid_line_color: null
12 | Title:
13 | text_color: 'black'
14 | text_font_style: 'bold'
15 | align: 'center'
16 | Text:
17 | text_color: '#262626'
18 | text_font_style: 'bold'
19 | Cycler:
20 | colors : [
21 | '#009988', '#9238b2', '#d2225f', '#ec8f26', '#fcd026',
22 | '#3cd186', '#a57119', '#2f5e14', '#f225f4', '#8f9fbf',
23 | ]
24 |
25 |
26 |
--------------------------------------------------------------------------------
/src/arviz_plots/styles/arviz-variat.mplstyle:
--------------------------------------------------------------------------------
1 | ## ***************************************************************************
2 | ## * FIGURE *
3 | ## ***************************************************************************
4 |
5 | figure.facecolor: white # white figure background outside the axes
6 | figure.edgecolor: None # no edge around the figure
7 | figure.titleweight: bold # weight of the figure title
8 | figure.titlesize: x-large
9 |
10 | figure.figsize: 6, 5
11 | figure.dpi: 200.0
12 | figure.constrained_layout.use: True
13 |
14 | ## ***************************************************************************
15 | ## * FONT *
16 | ## ***************************************************************************
17 |
18 | font.size: 12
19 | font.style: normal
20 | font.variant: normal
21 | font.weight: normal
22 | font.stretch: normal
23 |
24 | text.color: .15
25 |
26 | ## ***************************************************************************
27 | ## * AXES *
28 | ## ***************************************************************************
29 |
30 | axes.facecolor: white
31 | axes.edgecolor: .33 # axes edge color
32 | axes.linewidth: 0.8 # edge line width
33 |
34 | axes.grid: False # do not show grid
35 | # axes.grid.axis: y # which axis the grid should apply to
36 | axes.grid.which: major # grid lines at {major, minor, both} ticks
37 | axes.axisbelow: True # keep grid layer in the back
38 |
39 | grid.color: .8 # grid color
40 | grid.linestyle: - # solid
41 | grid.linewidth: 0.8 # in points
42 | grid.alpha: 1.0 # transparency, between 0.0 and 1.0
43 |
44 | lines.solid_capstyle: round
45 |
46 | axes.spines.right: False # do not show right spine
47 | axes.spines.top: False # do not show top spine
48 |
49 | axes.titlesize: 16
50 | axes.titleweight: bold # font weight of title
51 |
52 | axes.labelsize: large
53 | axes.labelcolor: .15
54 | axes.labelweight: normal # weight of the x and y labels
55 |
56 | # color-blind friendly cycle designed using https://colorcyclepicker.mpetroff.net/
57 | # see preview and check for colorblindness here https://coolors.co/36acc6-f66d7f-fac364-7c2695-228306-a252f4-63f0ea-000000-6f6f6f-b7b7b7
58 | axes.prop_cycle: cycler(color=["36acc6", "f66d7f", "fac364", "7c2695", "228306", "a252f4", "63f0ea", "000000", "6f6f6f", "b7b7b7"])
59 |
60 | image.cmap: viridis
61 |
62 | ## ***************************************************************************
63 | ## * TICKS *
64 | ## ***************************************************************************
65 |
66 | xtick.labelsize: large
67 | xtick.color: .15
68 | xtick.top: False
69 | xtick.bottom: True
70 | xtick.direction: out
71 |
72 | ytick.labelsize: large
73 | ytick.color: .15
74 | ytick.left: True
75 | ytick.right: False
76 | ytick.direction: out
77 |
78 | ## ***************************************************************************
79 | ## * LEGEND *
80 | ## ***************************************************************************
81 |
82 | legend.framealpha: 0.5
83 | legend.frameon: False # do not draw on background patch
84 | legend.fancybox: False # do not round corners
85 |
86 | legend.numpoints: 1
87 | legend.scatterpoints: 1
88 |
89 | legend.fontsize: large
90 |
91 |
--------------------------------------------------------------------------------
/src/arviz_plots/styles/arviz-variat.yml:
--------------------------------------------------------------------------------
1 | attrs:
2 | Plot:
3 | background_fill_color: white
4 | border_fill_color: white
5 | outline_line_width: 0
6 | outline_line_color: null
7 | Axis:
8 | major_tick_line_color: '#262626'
9 | minor_tick_line_alpha: 0
10 | Grid:
11 | grid_line_color: null
12 | Title:
13 | text_color: 'black'
14 | text_font_style: 'bold'
15 | align: 'center'
16 | Text:
17 | text_color: '#262626'
18 | text_font_style: 'bold'
19 | Cycler:
20 | colors : [
21 | '#36acc6', '#f66d7f', '#fac364', '#7c2695', '#228306',
22 | '#a252f4', '#63f0ea', '#000000', '#6f6f6f', '#b7b7b7'
23 | ]
24 |
--------------------------------------------------------------------------------
/src/arviz_plots/styles/arviz-vibrant.mplstyle:
--------------------------------------------------------------------------------
1 | ## ***************************************************************************
2 | ## * FIGURE *
3 | ## ***************************************************************************
4 |
5 | figure.facecolor: white # white figure background outside the axes
6 | figure.edgecolor: None # no edge around the figure
7 | figure.titleweight: bold # weight of the figure title
8 | figure.titlesize: x-large
9 |
10 | figure.figsize: 6, 5
11 | figure.dpi: 200.0
12 | figure.constrained_layout.use: True
13 |
14 | ## ***************************************************************************
15 | ## * FONT *
16 | ## ***************************************************************************
17 |
18 | font.size: 12
19 | font.style: normal
20 | font.variant: normal
21 | font.weight: normal
22 | font.stretch: normal
23 |
24 | text.color: .15
25 |
26 | ## ***************************************************************************
27 | ## * AXES *
28 | ## ***************************************************************************
29 |
30 | axes.facecolor: white
31 | axes.edgecolor: .33 # axes edge color
32 | axes.linewidth: 0.8 # edge line width
33 |
34 | axes.grid: False # do not show grid
35 | # axes.grid.axis: y # which axis the grid should apply to
36 | axes.grid.which: major # grid lines at {major, minor, both} ticks
37 | axes.axisbelow: True # keep grid layer in the back
38 |
39 | grid.color: .8 # grid color
40 | grid.linestyle: - # solid
41 | grid.linewidth: 0.8 # in points
42 | grid.alpha: 1.0 # transparency, between 0.0 and 1.0
43 |
44 | lines.solid_capstyle: round
45 |
46 | axes.spines.right: False # do not show right spine
47 | axes.spines.top: False # do not show top spine
48 |
49 | axes.titlesize: 16
50 | axes.titleweight: bold # font weight of title
51 |
52 | axes.labelsize: large
53 | axes.labelcolor: .15
54 | axes.labelweight: normal # weight of the x and y labels
55 |
56 | # color-blind friendly cycle designed using https://colorcyclepicker.mpetroff.net/
57 | # see preview and check for colorblindness here https://coolors.co/008b92-f15c58-48cdef-98d81a-997ee5-f5dc9d-c90a4e-145393-323232-616161
58 | axes.prop_cycle: cycler(color=["008b92", "f15c58", "48cdef", "98d81a", "997ee5", "f5dc9d", "c90a4e", "145393", "323232", "616161"])
59 | image.cmap: viridis
60 |
61 | ## ***************************************************************************
62 | ## * TICKS *
63 | ## ***************************************************************************
64 |
65 | xtick.labelsize: large
66 | xtick.color: .15
67 | xtick.top: False
68 | xtick.bottom: True
69 | xtick.direction: out
70 |
71 | ytick.labelsize: large
72 | ytick.color: .15
73 | ytick.left: True
74 | ytick.right: False
75 | ytick.direction: out
76 |
77 | ## ***************************************************************************
78 | ## * LEGEND *
79 | ## ***************************************************************************
80 |
81 | legend.framealpha: 0.5
82 | legend.frameon: False # do not draw on background patch
83 | legend.fancybox: False # do not round corners
84 |
85 | legend.numpoints: 1
86 | legend.scatterpoints: 1
87 |
88 | legend.fontsize: large
89 |
--------------------------------------------------------------------------------
/src/arviz_plots/styles/arviz-vibrant.yml:
--------------------------------------------------------------------------------
1 | attrs:
2 | Plot:
3 | background_fill_color: white
4 | border_fill_color: white
5 | outline_line_width: 0
6 | outline_line_color: null
7 | Axis:
8 | major_tick_line_color: '#262626'
9 | minor_tick_line_alpha: 0
10 | Grid:
11 | grid_line_color: null
12 | Title:
13 | text_color: 'black'
14 | text_font_style: 'bold'
15 | align: 'center'
16 | Text:
17 | text_color: '#262626'
18 | text_font_style: 'bold'
19 | Cycler:
20 | colors : [
21 | '#008b92', '#f15c58', '#48cdef', '#98d81a', '#997ee5',
22 | '#f5dc9d', '#c90a4e', '#145393', '#323232', '#616161',
23 | ]
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | """ArviZ plots test module."""
2 |
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=redefined-outer-name
2 | """Configuration for test suite."""
3 | import logging
4 | import os
5 |
6 | import pytest
7 | from arviz_base.testing import cmp as _cmp
8 | from arviz_base.testing import datatree as _datatree
9 | from arviz_base.testing import datatree2 as _datatree2
10 | from arviz_base.testing import datatree3 as _datatree3
11 | from arviz_base.testing import datatree_4d as _datatree_4d
12 | from arviz_base.testing import datatree_binary as _datatree_binary
13 | from arviz_base.testing import datatree_sample as _datatree_sample
14 | from hypothesis import settings
15 |
_log = logging.getLogger("arviz_plots")

# Hypothesis profiles: "fast" (the default) keeps the suite quick, "chron" runs
# many more examples per test. Pick one via the HYPOTHESIS_PROFILE env variable.
settings.register_profile("fast", max_examples=20, deadline=3000)
settings.register_profile("chron", max_examples=500, deadline=3000)
settings.load_profile(os.getenv("HYPOTHESIS_PROFILE", default="fast"))
21 |
22 |
def pytest_addoption(parser):
    """Register command line flags to save test figures or skip whole backends."""
    parser.addoption("--save", nargs="?", const="test_images", help="Save images rendered by plot")
    # one --skip-<flag> switch per plotting backend
    for flag, label in (("mpl", "matplotlib"), ("bokeh", "bokeh"), ("plotly", "plotly")):
        parser.addoption(
            f"--skip-{flag}", action="store_const", const=True, help=f"Skip {label} tests"
        )
29 |
30 |
@pytest.fixture(scope="session")
def save_figs(request):
    """Return the ``--save`` directory (or None), preparing it for new images.

    When ``--save`` is given, the directory is created if needed and every
    non-directory entry already in it is deleted so stale images from earlier
    runs do not mix with new ones. Subdirectories are left untouched.
    """
    fig_dir = request.config.getoption("--save")

    if fig_dir is not None:
        _log.info("Saving generated images in %s", fig_dir)

        # only report creation when the directory did not exist before
        if not os.path.isdir(fig_dir):
            os.makedirs(fig_dir, exist_ok=True)
            _log.info("Directory %s created", fig_dir)

        for entry in os.listdir(fig_dir):
            full_path = os.path.join(fig_dir, entry)
            # directories are deliberately kept; skip them instead of letting
            # os.remove fail and log a misleading "Failed to remove" message
            if os.path.isdir(full_path):
                continue
            try:
                os.remove(full_path)
            except OSError:
                _log.info("Failed to remove %s", full_path)

    return fig_dir
55 |
56 |
@pytest.fixture(scope="function")
def clean_plots(request, save_figs):
    """Register a finalizer that closes matplotlib figures after the test.

    The finalizer only acts for tests that request the ``backend`` fixture and
    carry a keyword containing "matplotlib". When ``--save`` was given, the
    current figure is first written to ``<save dir>/<test node name>.png``.
    """

    def _close_matplotlib_figures():
        if "backend" not in request.fixturenames or not any(
            "matplotlib" in keyword for keyword in request.keywords.keys()
        ):
            return
        import matplotlib.pyplot as plt

        if save_figs is not None:
            plt.savefig(f"{os.path.join(save_figs, request.node.name)}.png")
        plt.close("all")

    request.addfinalizer(_close_matplotlib_figures)
72 |
73 |
@pytest.fixture(scope="function")
def check_skips(request):
    """Skip matplotlib, bokeh or plotly tests if requested via command line.

    Only tests that request the ``backend`` fixture are considered. A test belongs
    to a backend when one of its keywords contains that backend's name.
    """
    if "backend" not in request.fixturenames:
        return

    keywords = list(request.keywords.keys())
    # (command line flag, backend name) pairs, checked in this order
    for option, backend_name in (
        ("--skip-mpl", "matplotlib"),
        ("--skip-bokeh", "bokeh"),
        ("--skip-plotly", "plotly"),
    ):
        if request.config.getoption(option) and any(backend_name in key for key in keywords):
            pytest.skip(
                reason=f"Requested skipping {backend_name} tests via command line argument"
            )
88 |
89 |
@pytest.fixture(scope="function")
def no_artist_kwargs(monkeypatch):
    """Make the 'none' backend reject visual kwargs during the test.

    Sets ``arviz_plots.backend.none.ALLOW_KWARGS`` to False through ``monkeypatch``.
    """
    flag_path = "arviz_plots.backend.none.ALLOW_KWARGS"
    monkeypatch.setattr(flag_path, False)
94 |
95 |
@pytest.fixture(scope="session")
def datatree():
    """Session-wide general purpose DataTree built by ``arviz_base.testing.datatree``."""
    tree = _datatree()
    return tree


@pytest.fixture(scope="session")
def datatree2():
    """Session-wide DataTree with posterior and sample stats (``arviz_base.testing.datatree2``)."""
    tree = _datatree2()
    return tree


@pytest.fixture(scope="session")
def datatree3():
    """Session-wide DataTree with discrete data (``arviz_base.testing.datatree3``)."""
    tree = _datatree3()
    return tree


@pytest.fixture(scope="session")
def datatree_4d():
    """Session-wide DataTree with a 4D posterior (``arviz_base.testing.datatree_4d``)."""
    tree = _datatree_4d()
    return tree


@pytest.fixture(scope="session")
def datatree_binary():
    """Session-wide DataTree with binary data (``arviz_base.testing.datatree_binary``)."""
    tree = _datatree_binary()
    return tree


@pytest.fixture(scope="session")
def datatree_sample():
    """Session-wide DataTree with sample stats (``arviz_base.testing.datatree_sample``)."""
    tree = _datatree_sample()
    return tree


@pytest.fixture(scope="session")
def cmp():
    """Session-wide result of calling ``arviz_base.testing.cmp``."""
    result = _cmp()
    return result
136 |
--------------------------------------------------------------------------------
/tests/test_fixtures.py:
--------------------------------------------------------------------------------
1 | """Test fixtures."""
2 | from importlib import import_module
3 |
4 | import pytest
5 |
6 |
@pytest.mark.usefixtures("no_artist_kwargs")
def test_no_artist_kwargs_fixture():
    """With the fixture active, an unknown kwarg passed to the 'none' backend raises."""
    backend_module = import_module("arviz_plots.backend.none")
    x_values, y_values = [1, 2], [0, 1]
    with pytest.raises(ValueError):
        backend_module.line(x_values, y_values, [], extra_kwarg="yes")
12 |
--------------------------------------------------------------------------------
/tests/test_plot_matrix.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=no-self-use, redefined-outer-name
2 | """Test PlotMatrix."""
3 | import numpy as np
4 | import pytest
5 | import xarray as xr
6 | from arviz_base import dict_to_dataset
7 |
8 | from arviz_plots import PlotMatrix
9 | from arviz_plots.plot_matrix import subset_matrix_da
10 |
11 |
@pytest.fixture(scope="module")
def dataset(seed=31):
    """Random dataset with variables ``mu``, ``theta`` and ``eta``.

    All variables share two leading axes of size (2, 7), presumably chain and
    draw as handled by ``dict_to_dataset``. ``theta`` adds a ``hierarchy``
    dimension of size 2; ``eta`` adds ``group`` (3) and ``hierarchy`` (2).
    """
    generator = np.random.default_rng(seed)
    # dict literal is evaluated in order, so the RNG draws are mu, theta, eta
    draws = {
        "mu": generator.normal(size=(2, 7)),
        "theta": generator.normal(size=(2, 7, 2)),
        "eta": generator.normal(size=(2, 7, 3, 2)),
    }
    extra_dims = {"theta": ["hierarchy"], "eta": ["group", "hierarchy"]}
    return dict_to_dataset(draws, dims=extra_dims)
23 |
24 |
@pytest.fixture(scope="module")
def matrix_da(seed=31):
    """Random 5x5 DataArray labelled like a PlotMatrix grid.

    Rows and columns carry the same ``var_name``/``hierarchy`` labels:
    ``mu``, ``theta`` (a, b, c) and ``tau``.
    """
    values = np.random.default_rng(seed).normal(size=(5, 5))
    var_names = ["mu", "theta", "theta", "theta", "tau"]
    hierarchy = [None, "a", "b", "c", None]
    coords = {
        "var_name_x": (("col_index",), var_names),
        "var_name_y": (("row_index",), var_names),
        "hierarchy_x": (("col_index",), hierarchy),
        "hierarchy_y": (("row_index",), hierarchy),
    }
    return xr.DataArray(values, dims=["row_index", "col_index"], coords=coords)
40 |
41 |
@pytest.mark.parametrize("subset", (["mu", {}], ["theta", {"hierarchy": "b"}]))
def test_subset_matrix_da_diag(matrix_da, subset):
    """Subsetting a diagonal entry should not return an xarray DataArray."""
    var_name, selection = subset
    result = subset_matrix_da(matrix_da, var_name, selection)
    assert not isinstance(result, xr.DataArray)
46 |
47 |
@pytest.mark.parametrize("subset_x", (["mu", {}], ["theta", {"hierarchy": "b"}]))
@pytest.mark.parametrize("subset_y", (["tau", {}], ["theta", {"hierarchy": "c"}]))
def test_subset_matrix_da_offdiag(matrix_da, subset_x, subset_y):
    """Subsetting an off-diagonal (x, y) entry should not return an xarray DataArray."""
    var_x, sel_x = subset_x
    var_y, sel_y = subset_y
    result = subset_matrix_da(matrix_da, var_x, sel_x, var_name_y=var_y, selection_y=sel_y)
    assert not isinstance(result, xr.DataArray)
55 |
56 |
def test_plot_matrix_init(dataset):
    """PlotMatrix builds a 9x9 ``plot`` grid with x/y copies of each facet coordinate."""
    pc = PlotMatrix(dataset, ["__variable__", "hierarchy", "group"], backend="none")
    assert "plot" in pc.viz.data_vars
    plot_da = pc.viz["plot"]
    expected = [
        f"{base}_{axis}" for base in ("var_name", "hierarchy", "group") for axis in ("x", "y")
    ]
    missing = [coord for coord in expected if coord not in plot_da.coords]
    assert not missing, list(plot_da.coords)
    assert plot_da.sizes == {"row_index": 9, "col_index": 9}
64 |
65 |
def test_plot_matrix_aes(dataset):
    """Mapping ``color`` to ``chain`` stores a mapping but no neutral element."""
    pc = PlotMatrix(
        dataset, ["__variable__", "hierarchy", "group"], backend="none", aes={"color": ["chain"]}
    )
    assert "/color" in pc.aes.groups
    color_vars = pc.aes["color"].data_vars
    assert "mapping" in color_vars
    assert "neutral_element" not in color_vars
73 |
74 |
75 | # pylint: disable=unused-argument
def map_auxiliar(da, target, target_list, kwarg_list, **kwargs):
    """Record ``target`` and the extra kwargs in the given lists, then return 1."""
    for store, item in ((target_list, target), (kwarg_list, kwargs)):
        store.append(item)
    return 1
80 |
81 |
82 | # pylint: disable=unused-argument
def map_auxiliar_couple(da_x, da_y, target, target_list, kwarg_list, **kwargs):
    """Pairwise variant of the recorder: store ``target`` and kwargs, then return 1."""
    for store, item in ((target_list, target), (kwarg_list, kwargs)):
        store.append(item)
    return 1
87 |
88 |
def test_plot_matrix_map(dataset):
    """``map`` fills only the diagonal of the 9x9 grid (9 * 2 calls in total)."""
    pc = PlotMatrix(
        dataset, ["__variable__", "hierarchy", "group"], backend="none", aes={"color": ["chain"]}
    )
    targets, kwargs_seen = [], []
    pc.map(map_auxiliar, "aux", target_list=targets, kwarg_list=kwargs_seen)

    expected_calls = 9 * 2
    assert len(targets) == expected_calls and len(kwargs_seen) == expected_calls
    aux = pc.viz["aux"]
    assert aux.dims == ("row_index", "col_index", "chain")
    for row in range(9):
        for col in range(9):
            cell = aux.sel(row_index=row, col_index=col).values
            if row == col:
                assert all(elem is not None for elem in cell)
            else:
                assert all(elem is None for elem in cell)
113 |
114 |
@pytest.mark.parametrize("triangle", ("both", "lower", "upper"))
def test_plot_matrix_map_triangle(dataset, triangle):
    """``map_triangle`` fills only the requested off-diagonal triangle(s), never the diagonal."""
    pc = PlotMatrix(
        dataset, ["__variable__", "hierarchy", "group"], backend="none", aes={"color": ["chain"]}
    )
    targets, kwargs_seen = [], []
    pc.map_triangle(
        map_auxiliar_couple,
        "aux",
        target_list=targets,
        kwarg_list=kwargs_seen,
        triangle=triangle,
    )

    # sum(range(9)) cells per triangle, x2 (presumably one call per chain),
    # doubled again when both triangles are filled
    expected_calls = sum(range(9)) * 2 * (2 if triangle == "both" else 1)
    assert len(targets) == expected_calls and len(kwargs_seen) == expected_calls
    assert pc.viz["aux"].dims == ("row_index", "col_index", "chain")

    lower_filled = triangle in ("both", "lower")
    upper_filled = triangle in ("both", "upper")
    for row in range(9):
        for col in range(9):
            nones = [
                elem is None for elem in pc.viz["aux"].sel(row_index=row, col_index=col).values
            ]
            if row == col:
                expect_filled = False
            elif row > col:
                expect_filled = lower_filled
            else:
                expect_filled = upper_filled
            if expect_filled:
                assert not any(nones)
            else:
                assert all(nones)
149 |
--------------------------------------------------------------------------------
/tox.ini:
--------------------------------------------------------------------------------
1 | [tox]
2 | envlist =
3 | check
4 | docs
5 | {py311,py312,py313}{,-coverage}
6 | # See https://tox.readthedocs.io/en/latest/example/package.html#flit
7 | isolated_build = True
8 | isolated_build_env = build
9 |
10 | [gh-actions]
11 | python =
12 | 3.11: py311-coverage
13 | 3.12: check, py312-coverage
14 | 3.13: py313-coverage
15 |
16 | [testenv]
17 | basepython =
18 | py311: python3.11
19 | py312: python3.12
20 | py313: python3.13
21 | # See https://github.com/tox-dev/tox/issues/1548
22 | {check,docs,cleandocs,viewdocs,build}: python3
23 | setenv =
24 | PYTHONUNBUFFERED = yes
25 | PYTEST_EXTRA_ARGS = -s
26 | coverage: PYTEST_EXTRA_ARGS = --cov --cov-report xml --cov-report term
27 | passenv =
28 | *
29 | extras =
30 | test
31 | matplotlib
32 | bokeh
33 | plotly
34 | commands =
35 | pytest {env:PYTEST_MARKERS:} {env:PYTEST_EXTRA_ARGS:} {posargs:-vv}
36 |
37 | [testenv:check]
38 | description = perform style checks
39 | deps =
40 | build
41 | pre-commit
42 | pylint
43 | skip_install = true
44 | commands =
45 | pre-commit install
46 | pre-commit run --all-files --show-diff-on-failure
47 | python -m build
48 |
49 | [testenv:docs]
50 | description = build HTML docs
51 | setenv =
52 | READTHEDOCS_PROJECT = arviz_plots
53 | READTHEDOCS_VERSION = latest
54 | extras =
55 | doc
56 | matplotlib
57 | bokeh
58 | plotly
59 | commands =
60 | sphinx-build -d "{toxworkdir}/docs_doctree" docs/source "{toxworkdir}/docs_out" --color -v -bhtml
61 |
62 | [testenv:nogallerydocs]
63 | description = build HTML docs without the example gallery
64 | setenv =
65 | READTHEDOCS_PROJECT = arviz_plots
66 | READTHEDOCS_VERSION = latest
67 | ARVIZDOCS_NOGALLERY = true
68 | extras =
69 | doc
70 | matplotlib
71 | bokeh
72 | plotly
73 | commands =
74 | sphinx-build -d "{toxworkdir}/docs_doctree" docs/source "{toxworkdir}/docs_out" --color -v -bhtml
75 |
76 | [testenv:cleandocs]
77 | description = remove generated docs sources and HTML outputs
78 | skip_install = true
79 | allowlist_externals =
80 | rm
81 | find
82 | commands =
83 | find docs/source/gallery -maxdepth 1 -type f -name '*.md' -delete
84 | find docs/source/api/backend -maxdepth 1 -type f -name '*.rst' ! -name '*.part.rst' ! -name 'index.rst' ! -name '*.template.rst' -delete
85 | rm -r "{toxworkdir}/docs_out" "{toxworkdir}/docs_doctree" "{toxworkdir}/jupyter_execute" "{toxworkdir}/plot_directive"
86 | rm -r docs/source/api/generated docs/source/api/backend/generated docs/source/gallery/_images docs/source/gallery/_scripts
87 | rm docs/source/gallery/backreferences.json
88 |
89 | [testenv:viewdocs]
90 | description = open HTML docs
91 | skip_install = true
92 | commands =
93 | python -m webbrowser "{toxworkdir}/docs_out/index.html"
94 |
--------------------------------------------------------------------------------