├── .gitattributes ├── .github ├── dependabot.yml └── workflows │ ├── hypothesis.yml │ ├── post-release.yml │ ├── publish.yml │ ├── rtd-preview.yml │ └── test.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .pylintrc ├── .readthedocs.yaml ├── CHANGELOG.md ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── References.md ├── docs ├── source │ ├── _static │ │ ├── ArviZ.png │ │ ├── ArviZ_white.png │ │ ├── bokeh-logo-dark.svg │ │ ├── bokeh-logo-light.svg │ │ ├── custom.css │ │ ├── favicon.ico │ │ ├── matplotlib-logo-dark.svg │ │ ├── matplotlib-logo-light.svg │ │ ├── none-logo-light.png │ │ ├── plotly-logo-dark.png │ │ └── plotly-logo-light.png │ ├── _templates │ │ └── name.html │ ├── api │ │ ├── backend │ │ │ ├── bokeh.part.rst │ │ │ ├── index.rst │ │ │ ├── interface.template.rst │ │ │ ├── matplotlib.part.rst │ │ │ ├── none.part.rst │ │ │ └── plotly.part.rst │ │ ├── helpers.rst │ │ ├── index.md │ │ ├── managers.rst │ │ ├── plots.rst │ │ └── visuals.rst │ ├── conf.py │ ├── contributing │ │ ├── docs.md │ │ ├── new_plot.md │ │ └── testing.md │ ├── gallery │ │ ├── distribution │ │ │ ├── 00_plot_dist_ecdf.py │ │ │ ├── 01_plot_dist_hist.py │ │ │ ├── 02_plot_dist_kde.py │ │ │ ├── 04_plot_forest.py │ │ │ ├── 05_plot_forest_shade.py │ │ │ ├── 06_plot_prior_posterior.py │ │ │ └── 07_plot_pairs_focus_distribution.py │ │ ├── inference_diagnostics │ │ │ ├── 00_plot_rank.py │ │ │ ├── 01_plot_trace.py │ │ │ ├── 02_plot_ess_evolution.py │ │ │ ├── 03_plot_ess_local.py │ │ │ ├── 04_plot_ess_quantile.py │ │ │ ├── 05_plot_ess_models.py │ │ │ ├── 05_plot_mcse.py │ │ │ ├── 06_plot_convergence_dist.py │ │ │ ├── 07_plot_autocorr.py │ │ │ ├── 08_plot_energy.py │ │ │ └── 09_plot_pairs_focus.py │ │ ├── mixed │ │ │ ├── 00_plot_rank_dist.py │ │ │ ├── 01_plot_trace_dist.py │ │ │ ├── 02_plot_forest_ess.py │ │ │ └── 03_combine_plots.py │ │ ├── model_comparison │ │ │ ├── 00_plot_compare.py │ │ │ └── 99_plot_bf.py │ │ ├── posterior_comparison │ │ │ ├── 00_plot_dist_models.py │ │ 
│ └── 01_plot_forest_models.py │ │ ├── predictive_checks │ │ │ ├── 00_plot_ppc_dist.py │ │ │ ├── 01_plot_ppc_rootogram.py │ │ │ ├── 03_plot_pava_calibration.py │ │ │ ├── 04_plot_ppc_pit.py │ │ │ ├── 05_plot_ppc_coverage.py │ │ │ ├── 06_plot_loo_pit.py │ │ │ ├── 07_plot_ppc_tstat.py │ │ │ └── 99_plot_forest_pp_obs.py │ │ ├── prior_and_likelihood_sensitivity_checks │ │ │ ├── 00_plot_psense.py │ │ │ └── 01_plot_psense_quantities.py │ │ ├── sbc │ │ │ ├── 00_plot_ecdf_pit.py │ │ │ └── 01_plot_ecdf_coverage.py │ │ └── utils │ │ │ ├── 00_add_reference_lines.py │ │ │ └── 01_add_reference_bands.py │ ├── glossary.md │ ├── index.md │ └── tutorials │ │ ├── compose_own_plot.ipynb │ │ ├── intro_to_plotcollection.ipynb │ │ ├── overview.ipynb │ │ └── plots_intro.ipynb └── sphinxext │ └── gallery_generator.py ├── pyproject.toml ├── src └── arviz_plots │ ├── __init__.py │ ├── _version.py │ ├── backend │ ├── __init__.py │ ├── bokeh │ │ ├── __init__.py │ │ └── legend.py │ ├── matplotlib │ │ ├── __init__.py │ │ └── legend.py │ ├── none │ │ ├── __init__.py │ │ └── legend.py │ └── plotly │ │ ├── __init__.py │ │ ├── legend.py │ │ └── templates.py │ ├── plot_collection.py │ ├── plot_matrix.py │ ├── plots │ ├── __init__.py │ ├── autocorr_plot.py │ ├── bf_plot.py │ ├── combine.py │ ├── compare_plot.py │ ├── convergence_dist_plot.py │ ├── dist_plot.py │ ├── ecdf_plot.py │ ├── energy_plot.py │ ├── ess_plot.py │ ├── evolution_plot.py │ ├── forest_plot.py │ ├── loo_pit_plot.py │ ├── mcse_plot.py │ ├── pairs_focus_plot.py │ ├── pava_calibration_plot.py │ ├── ppc_dist_plot.py │ ├── ppc_pit_plot.py │ ├── ppc_rootogram_plot.py │ ├── ppc_tstat.py │ ├── prior_posterior_plot.py │ ├── psense_dist_plot.py │ ├── psense_quantities_plot.py │ ├── rank_dist_plot.py │ ├── rank_plot.py │ ├── ridge_plot.py │ ├── trace_dist_plot.py │ ├── trace_plot.py │ └── utils.py │ ├── py.typed │ ├── style.py │ ├── styles │ ├── arviz-cetrino.mplstyle │ ├── arviz-cetrino.yml │ ├── arviz-variat.mplstyle │ ├── arviz-variat.yml │ 
├── arviz-vibrant.mplstyle │ └── arviz-vibrant.yml │ └── visuals │ └── __init__.py ├── tests ├── __init__.py ├── conftest.py ├── test_fixtures.py ├── test_hypothesis_plots.py ├── test_plot_collection.py ├── test_plot_matrix.py └── test_plots.py └── tox.ini /.gitattributes: -------------------------------------------------------------------------------- 1 | # SCM syntax highlighting & preventing 3-way merges 2 | pixi.lock merge=binary linguist-language=YAML linguist-generated=true 3 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: "github-actions" 4 | directory: "/" 5 | schedule: 6 | interval: "weekly" 7 | -------------------------------------------------------------------------------- /.github/workflows/hypothesis.yml: -------------------------------------------------------------------------------- 1 | name: Run extended tests with hypothesis 2 | on: 3 | schedule: 4 | - cron: '17 5 * * 1' 5 | workflow_dispatch: 6 | 7 | permissions: 8 | issues: write 9 | 10 | jobs: 11 | hypothesis_testing: 12 | runs-on: ubuntu-latest 13 | steps: 14 | - uses: actions/checkout@v4 15 | - name: Set up Python 16 | uses: actions/setup-python@v5 17 | with: 18 | python-version: '3.11' 19 | - name: Install arviz-plots 20 | run: | 21 | python -m pip install ".[test]" 22 | - name: Execute tests 23 | run: | 24 | pytest --hypothesis-profile chron -k hypothesis 25 | echo "DATE=$(date +'%Y-%m-%d %H:%M %z')" >> ${GITHUB_ENV} 26 | - name: Comment on issue if failed 27 | if: failure() 28 | uses: peter-evans/create-or-update-comment@v4 29 | with: 30 | issue-number: 43 31 | body: | 32 | The extended tests with hypothesis failed. 
33 | 34 | * Branch: ${{ github.ref_name }} 35 | * Date: ${{ env.DATE }} 36 | 37 | See [workflow logs](${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}) for details on which tests failed and why. 38 | -------------------------------------------------------------------------------- /.github/workflows/post-release.yml: -------------------------------------------------------------------------------- 1 | name: Post-release 2 | on: 3 | release: 4 | types: [published, released] 5 | workflow_dispatch: 6 | 7 | jobs: 8 | changelog: 9 | name: Update changelog 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v4 13 | with: 14 | ref: main 15 | - uses: rhysd/changelog-from-release/action@v3 16 | with: 17 | file: CHANGELOG.md 18 | github_token: ${{ secrets.GITHUB_TOKEN }} 19 | commit_summary_template: 'update changelog for %s changes' 20 | pull_request: true 21 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish library 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | tags: 8 | # Don't try to be smart about PEP 440 compliance, 9 | # see https://www.python.org/dev/peps/pep-0440/#appendix-b-parsing-version-strings-with-regular-expressions 10 | - v* 11 | 12 | jobs: 13 | build-package: 14 | runs-on: ubuntu-latest 15 | permissions: 16 | # write attestations and id-token are necessary for attest-build-provenance-github 17 | attestations: write 18 | id-token: write 19 | steps: 20 | - uses: actions/checkout@v4 21 | with: 22 | fetch-depth: 0 23 | persist-credentials: false 24 | - uses: hynek/build-and-inspect-python-package@v2 25 | with: 26 | # Prove that the packages were built in the context of this workflow. 
27 | attest-build-provenance-github: true 28 | publish: 29 | runs-on: ubuntu-latest 30 | if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags') 31 | # Use the `publish` GitHub environment to protect the Trusted Publishing (OIDC) 32 | # workflow by requiring signoff from a maintainer. 33 | environment: 34 | name: publish 35 | url: https://pypi.org/p/arviz-plots 36 | needs: build-package 37 | permissions: 38 | # write id-token is necessary for trusted publishing (OIDC) 39 | id-token: write 40 | steps: 41 | - name: Download Distribution Artifacts 42 | uses: actions/download-artifact@v4 43 | with: 44 | # The build-and-inspect-python-package action invokes upload-artifact. 45 | # These are the correct arguments from that action. 46 | name: Packages 47 | path: dist 48 | - name: Publish to PyPI 49 | uses: pypa/gh-action-pypi-publish@release/v1 50 | -------------------------------------------------------------------------------- /.github/workflows/rtd-preview.yml: -------------------------------------------------------------------------------- 1 | name: Read the Docs Pull Request Preview 2 | on: 3 | pull_request_target: 4 | types: 5 | - opened 6 | 7 | permissions: 8 | pull-requests: write 9 | 10 | jobs: 11 | documentation-links: 12 | runs-on: ubuntu-latest 13 | steps: 14 | - uses: readthedocs/actions/preview@v1 15 | with: 16 | project-slug: "arviz-plots" 17 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Run tests 2 | on: 3 | pull_request: 4 | push: 5 | branches: [main] 6 | paths-ignore: 7 | - "docs/" 8 | 9 | jobs: 10 | test: 11 | runs-on: ubuntu-latest 12 | strategy: 13 | matrix: 14 | python-version: ["3.11", "3.12"] 15 | fail-fast: false 16 | steps: 17 | - uses: actions/checkout@v4 18 | - name: Set up Python ${{ matrix.python-version }} 19 | uses: actions/setup-python@v5 20 | with: 21 | python-version:
${{ matrix.python-version }} 22 | - name: Install dependencies 23 | run: | 24 | python -m pip install --upgrade pip 25 | pip install tox tox-gh-actions 26 | - name: Test with tox 27 | run: tox 28 | - name: Upload coverage to Codecov 29 | uses: codecov/codecov-action@v5 30 | with: 31 | name: Python ${{ matrix.python-version }} 32 | fail_ci_if_error: false 33 | token: ${{ secrets.CODECOV_TOKEN }} 34 | verbose: true 35 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Created by https://www.toptal.com/developers/gitignore/api/python,jupyternotebooks 3 | # Edit at https://www.toptal.com/developers/gitignore?templates=python,jupyternotebooks 4 | 5 | ### JupyterNotebooks ### 6 | # gitignore template for Jupyter Notebooks 7 | # website: http://jupyter.org/ 8 | 9 | .ipynb_checkpoints 10 | */.ipynb_checkpoints/* 11 | .virtual_documents 12 | */.virtual_documents/* 13 | 14 | # IPython 15 | profile_default/ 16 | ipython_config.py 17 | 18 | # Remove previous ipynb_checkpoints 19 | # git rm -r .ipynb_checkpoints/ 20 | 21 | ### Python ### 22 | # Byte-compiled / optimized / DLL files 23 | __pycache__/ 24 | *.py[cod] 25 | *$py.class 26 | 27 | # C extensions 28 | *.so 29 | 30 | # Distribution / packaging 31 | .Python 32 | build/ 33 | develop-eggs/ 34 | dist/ 35 | downloads/ 36 | eggs/ 37 | .eggs/ 38 | lib/ 39 | lib64/ 40 | parts/ 41 | sdist/ 42 | var/ 43 | wheels/ 44 | share/python-wheels/ 45 | *.egg-info/ 46 | .installed.cfg 47 | *.egg 48 | MANIFEST 49 | 50 | # PyInstaller 51 | # Usually these files are written by a python script from a template 52 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
53 | *.manifest 54 | *.spec 55 | 56 | # Installer logs 57 | pip-log.txt 58 | pip-delete-this-directory.txt 59 | 60 | # Unit test / coverage reports 61 | htmlcov/ 62 | .tox/ 63 | .nox/ 64 | .coverage 65 | .coverage.* 66 | .cache 67 | nosetests.xml 68 | coverage.xml 69 | *.cover 70 | *.py,cover 71 | .hypothesis/ 72 | .pytest_cache/ 73 | cover/ 74 | 75 | # default test image folder 76 | test_images/ 77 | 78 | # Translations 79 | *.mo 80 | *.pot 81 | 82 | # Django stuff: 83 | *.log 84 | local_settings.py 85 | db.sqlite3 86 | db.sqlite3-journal 87 | 88 | # Flask stuff: 89 | instance/ 90 | .webassets-cache 91 | 92 | # Scrapy stuff: 93 | .scrapy 94 | 95 | # Sphinx documentation 96 | docs/_build/ 97 | docs/build 98 | docs/source/api/**/generated 99 | docs/source/gallery/_images 100 | docs/source/gallery/_scripts 101 | docs/source/gallery/*.md 102 | docs/source/gallery/backreferences.json 103 | docs/jupyter_execute 104 | docs/source/api/backend/*.rst 105 | !docs/source/api/backend/*.part.rst 106 | !docs/source/api/backend/index.rst 107 | 108 | # PyBuilder 109 | .pybuilder/ 110 | target/ 111 | 112 | # Jupyter Notebook 113 | 114 | # IPython 115 | 116 | # pyenv 117 | # For a library or package, you might want to ignore these files since the code is 118 | # intended to run in multiple environments; otherwise, check them in: 119 | # .python-version 120 | 121 | # pipenv 122 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 123 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 124 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 125 | # install all needed dependencies. 126 | #Pipfile.lock 127 | 128 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 129 | __pypackages__/ 130 | 131 | # Celery stuff 132 | celerybeat-schedule 133 | celerybeat.pid 134 | 135 | # SageMath parsed files 136 | *.sage.py 137 | 138 | # Environments 139 | .env 140 | .venv 141 | env/ 142 | venv/ 143 | ENV/ 144 | env.bak/ 145 | venv.bak/ 146 | 147 | # Spyder project settings 148 | .spyderproject 149 | .spyproject 150 | 151 | # Rope project settings 152 | .ropeproject 153 | 154 | # mkdocs documentation 155 | /site 156 | 157 | # mypy 158 | .mypy_cache/ 159 | .dmypy.json 160 | dmypy.json 161 | 162 | # Pyre type checker 163 | .pyre/ 164 | 165 | # pytype static type analyzer 166 | .pytype/ 167 | 168 | # Cython debug symbols 169 | cython_debug/ 170 | 171 | # macos 172 | .DS_Store 173 | 174 | # End of https://www.toptal.com/developers/gitignore/api/python,jupyternotebooks 175 | 176 | # pixi environments 177 | .pixi 178 | *.egg-info 179 | ## TO REMOVE ONCE WE HAVE THE MONOREPO 180 | pixi.lock 181 | pixi.toml 182 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.3.0 4 | hooks: 5 | - id: check-added-large-files 6 | args: ['--maxkb=1500'] 7 | - id: check-merge-conflict 8 | 9 | - repo: https://github.com/PyCQA/isort 10 | rev: 5.12.0 11 | hooks: 12 | - id: isort 13 | exclude: ^src/arviz_base/example_data/ 14 | 15 | - repo: https://github.com/psf/black 16 | rev: 23.3.0 17 | hooks: 18 | - id: black 19 | exclude: ^docs/source/gallery/ 20 | 21 | - repo: https://github.com/pycqa/pydocstyle 22 | rev: 6.3.0 23 | hooks: 24 | - id: pydocstyle 25 | additional_dependencies: [tomli] 26 | files: ^src/arviz_plots/.+\.py$ 27 | 28 | - repo: local 29 | hooks: 30 | - id: pylint 31 | name: pylint 32 | entry: pylint 33 | language: system 34 | types: [python] 35 | args: 36 | [ 37 | "-rn", # Only display messages 38 | "-sn", 
# Don't display the score 39 | ] 40 | exclude: ^docs/source/gallery/ 41 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # Read the Docs configuration file 2 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 3 | version: 2 4 | 5 | build: 6 | os: ubuntu-24.04 7 | tools: 8 | python: "3.12" 9 | 10 | sphinx: 11 | fail_on_warning: True 12 | configuration: docs/source/conf.py 13 | 14 | 15 | python: 16 | install: 17 | - method: pip 18 | path: . 19 | extra_requirements: 20 | - doc 21 | - matplotlib 22 | - bokeh 23 | - plotly 24 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | 2 | # [v0.5.0](https://github.com/arviz-devs/arviz-plots/releases/tag/v0.5.0) - 2025-03-21 3 | 4 | ## What's Changed 5 | * Randomized pit by [@aloctavodia](https://github.com/aloctavodia) in [#167](https://github.com/arviz-devs/arviz-plots/pull/167) 6 | * Add missing visuals to docs by [@aloctavodia](https://github.com/aloctavodia) in [#168](https://github.com/arviz-devs/arviz-plots/pull/168) 7 | * Make ecdf_line a step line by [@aloctavodia](https://github.com/aloctavodia) in [#169](https://github.com/arviz-devs/arviz-plots/pull/169) 8 | * Properly wrap columns and reduce duplicated code by [@aloctavodia](https://github.com/aloctavodia) in [#170](https://github.com/arviz-devs/arviz-plots/pull/170) 9 | * Added plot_bf() function for bayes_factor in arviz-plots by [@PiyushPanwarFST](https://github.com/PiyushPanwarFST) in [#158](https://github.com/arviz-devs/arviz-plots/pull/158) 10 | * Update utils.py by [@aloctavodia](https://github.com/aloctavodia) in [#171](https://github.com/arviz-devs/arviz-plots/pull/171) 11 | * Add plot_rank by [@aloctavodia](https://github.com/aloctavodia) in 
[#172](https://github.com/arviz-devs/arviz-plots/pull/172) 12 | * Add plot ecdf pit plot by [@aloctavodia](https://github.com/aloctavodia) in [#173](https://github.com/arviz-devs/arviz-plots/pull/173) 13 | * fix regression bug plot_psense_quantities by [@aloctavodia](https://github.com/aloctavodia) in [#175](https://github.com/arviz-devs/arviz-plots/pull/175) 14 | * Improve Data Generation & Plot Tests: by [@PiyushPanwarFST](https://github.com/PiyushPanwarFST) in [#177](https://github.com/arviz-devs/arviz-plots/pull/177) 15 | * Adds square root scale for bokeh and plotly by [@The-Broken-Keyboard](https://github.com/The-Broken-Keyboard) in [#178](https://github.com/arviz-devs/arviz-plots/pull/178) 16 | * Update plot_compare by [@aloctavodia](https://github.com/aloctavodia) in [#181](https://github.com/arviz-devs/arviz-plots/pull/181) 17 | * plot_converge_dist: Add grouped argument by [@aloctavodia](https://github.com/aloctavodia) in [#182](https://github.com/arviz-devs/arviz-plots/pull/182) 18 | * adds sqrt scale for yaxis in plotly by [@The-Broken-Keyboard](https://github.com/The-Broken-Keyboard) in [#183](https://github.com/arviz-devs/arviz-plots/pull/183) 19 | * Add coverage argument to pit plots by [@aloctavodia](https://github.com/aloctavodia) in [#185](https://github.com/arviz-devs/arviz-plots/pull/185) 20 | 21 | 22 | ## New Contributors 23 | * [@github-actions](https://github.com/github-actions) made their first contribution in [#165](https://github.com/arviz-devs/arviz-plots/pull/165) 24 | * [@PiyushPanwarFST](https://github.com/PiyushPanwarFST) made their first contribution in [#158](https://github.com/arviz-devs/arviz-plots/pull/158) 25 | 26 | **Full Changelog**: https://github.com/arviz-devs/arviz-plots/compare/v0.4.0...v0.5.0 27 | 28 | [Changes][v0.5.0] 29 | 30 | 31 | 32 | # [v0.4.0](https://github.com/arviz-devs/arviz-plots/releases/tag/v0.4.0) - 2025-03-05 33 | 34 | ## What's Changed 35 | * move out new_ds to arviz-stats by 
[@aloctavodia](https://github.com/aloctavodia) in [#102](https://github.com/arviz-devs/arviz-plots/pull/102) 36 | * update version, dependencies and CI by [@OriolAbril](https://github.com/OriolAbril) in [#110](https://github.com/arviz-devs/arviz-plots/pull/110) 37 | * Use DataTree class from xarray by [@OriolAbril](https://github.com/OriolAbril) in [#111](https://github.com/arviz-devs/arviz-plots/pull/111) 38 | * Update pyproject.toml by [@OriolAbril](https://github.com/OriolAbril) in [#113](https://github.com/arviz-devs/arviz-plots/pull/113) 39 | * Add energy plot by [@aloctavodia](https://github.com/aloctavodia) in [#108](https://github.com/arviz-devs/arviz-plots/pull/108) 40 | * Add plot for distribution of convergence diagnostics by [@aloctavodia](https://github.com/aloctavodia) in [#105](https://github.com/arviz-devs/arviz-plots/pull/105) 41 | * Add separated prior and likelihood groups by [@aloctavodia](https://github.com/aloctavodia) in [#117](https://github.com/arviz-devs/arviz-plots/pull/117) 42 | * Add psense_quantities plot by [@aloctavodia](https://github.com/aloctavodia) in [#119](https://github.com/arviz-devs/arviz-plots/pull/119) 43 | * Rename arviz-clean to arviz-variat by [@aloctavodia](https://github.com/aloctavodia) in [#120](https://github.com/arviz-devs/arviz-plots/pull/120) 44 | * add cetrino and vibrant styles for plotly by [@aloctavodia](https://github.com/aloctavodia) in [#121](https://github.com/arviz-devs/arviz-plots/pull/121) 45 | * psense: fix facetting and add xlabel by [@aloctavodia](https://github.com/aloctavodia) in [#123](https://github.com/arviz-devs/arviz-plots/pull/123) 46 | * Add summary dictionary arguments by [@aloctavodia](https://github.com/aloctavodia) in [#125](https://github.com/arviz-devs/arviz-plots/pull/125) 47 | * plotly: change format of title update in backend by [@The-Broken-Keyboard](https://github.com/The-Broken-Keyboard) in [#124](https://github.com/arviz-devs/arviz-plots/pull/124) 48 | * Update glossary.md by 
[@aloctavodia](https://github.com/aloctavodia) in [#126](https://github.com/arviz-devs/arviz-plots/pull/126) 49 | * Add PAV-adjusted calibration plot by [@aloctavodia](https://github.com/aloctavodia) in [#127](https://github.com/arviz-devs/arviz-plots/pull/127) 50 | * Upper bound plotly by [@aloctavodia](https://github.com/aloctavodia) in [#128](https://github.com/arviz-devs/arviz-plots/pull/128) 51 | * Use isotonic function that works with datatrees by [@aloctavodia](https://github.com/aloctavodia) in [#131](https://github.com/arviz-devs/arviz-plots/pull/131) 52 | * plot_pava_calibrarion: Add reference, fix xlabel by [@aloctavodia](https://github.com/aloctavodia) in [#132](https://github.com/arviz-devs/arviz-plots/pull/132) 53 | * Fix bug when setting some plot_kwargs to false by [@aloctavodia](https://github.com/aloctavodia) in [#134](https://github.com/arviz-devs/arviz-plots/pull/134) 54 | * Add citations by [@aloctavodia](https://github.com/aloctavodia) in [#135](https://github.com/arviz-devs/arviz-plots/pull/135) 55 | * use <6 version of plotly for documentation and use latest for other purposes by [@The-Broken-Keyboard](https://github.com/The-Broken-Keyboard) in [#136](https://github.com/arviz-devs/arviz-plots/pull/136) 56 | * fix see also pava gallery by [@aloctavodia](https://github.com/aloctavodia) in [#137](https://github.com/arviz-devs/arviz-plots/pull/137) 57 | * Add plot_ppc_dist by [@aloctavodia](https://github.com/aloctavodia) in [#138](https://github.com/arviz-devs/arviz-plots/pull/138) 58 | * Add warning message for discrete data by [@aloctavodia](https://github.com/aloctavodia) in [#139](https://github.com/arviz-devs/arviz-plots/pull/139) 59 | * rename plot_pava and minor fixes by [@aloctavodia](https://github.com/aloctavodia) in [#140](https://github.com/arviz-devs/arviz-plots/pull/140) 60 | * Plotly: Fix excesive margins by [@aloctavodia](https://github.com/aloctavodia) in [#141](https://github.com/arviz-devs/arviz-plots/pull/141) 61 | * add 
arviz-style for bokeh by [@aloctavodia](https://github.com/aloctavodia) in [#122](https://github.com/arviz-devs/arviz-plots/pull/122) 62 | * Add Plot ppc rootogram by [@aloctavodia](https://github.com/aloctavodia) in [#142](https://github.com/arviz-devs/arviz-plots/pull/142) 63 | * plot_ppc_rootogram: fix examples by [@aloctavodia](https://github.com/aloctavodia) in [#144](https://github.com/arviz-devs/arviz-plots/pull/144) 64 | * Reorganize categories in the gallery by [@aloctavodia](https://github.com/aloctavodia) in [#145](https://github.com/arviz-devs/arviz-plots/pull/145) 65 | * remove plots from titles by [@aloctavodia](https://github.com/aloctavodia) in [#146](https://github.com/arviz-devs/arviz-plots/pull/146) 66 | * Consistence use of data_pairs, remove default markers from pava by [@aloctavodia](https://github.com/aloctavodia) in [#152](https://github.com/arviz-devs/arviz-plots/pull/152) 67 | * added functionality of step histogram for all three backends by [@The-Broken-Keyboard](https://github.com/The-Broken-Keyboard) in [#147](https://github.com/arviz-devs/arviz-plots/pull/147) 68 | * Use continuous outcome for plot_ppc_dist example by [@aloctavodia](https://github.com/aloctavodia) in [#154](https://github.com/arviz-devs/arviz-plots/pull/154) 69 | * add grid visual by [@aloctavodia](https://github.com/aloctavodia) in [#155](https://github.com/arviz-devs/arviz-plots/pull/155) 70 | * Add plot_ppc_pit by [@aloctavodia](https://github.com/aloctavodia) in [#159](https://github.com/arviz-devs/arviz-plots/pull/159) 71 | * plot_ppc_pava: Change default xlabel by [@aloctavodia](https://github.com/aloctavodia) in [#161](https://github.com/arviz-devs/arviz-plots/pull/161) 72 | 73 | ## New Contributors 74 | * [@The-Broken-Keyboard](https://github.com/The-Broken-Keyboard) made their first contribution in [#124](https://github.com/arviz-devs/arviz-plots/pull/124) 75 | 76 | **Full Changelog**: https://github.com/arviz-devs/arviz-plots/compare/v0.3.0...v0.4.0 77 | 78 | 
[Changes][v0.4.0] 79 | 80 | 81 | [v0.5.0]: https://github.com/arviz-devs/arviz-plots/compare/v0.4.0...v0.5.0 82 | [v0.4.0]: https://github.com/arviz-devs/arviz-plots/tree/v0.4.0 83 | 84 | 85 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # ArviZ Community Code of Conduct 2 | 3 | ArviZ adopts the NumFOCUS Code of Conduct directly. In other words, we 4 | expect our community to treat others with kindness and understanding. 5 | 6 | 7 | # THE SHORT VERSION 8 | Be kind to others. Do not insult or put down others. 9 | Behave professionally. Remember that harassment and sexist, racist, 10 | or exclusionary jokes are not appropriate. 11 | 12 | All communication should be appropriate for a professional audience 13 | including people of many different backgrounds. Sexual language and 14 | imagery are not appropriate. 15 | 16 | ArviZ is dedicated to providing a harassment-free community for everyone, 17 | regardless of gender, sexual orientation, gender identity, and 18 | expression, disability, physical appearance, body size, race, 19 | or religion. We do not tolerate harassment of community members 20 | in any form. 21 | 22 | Thank you for helping make this a welcoming, friendly community for all. 23 | 24 | 25 | # How to Submit a Report 26 | If you feel that there has been a Code of Conduct violation an anonymous 27 | reporting form is available. 28 | **If you feel your safety is in jeopardy or the situation is an 29 | emergency, we urge you to contact local law enforcement before making 30 | a report. (In the U.S., dial 911.)** 31 | 32 | We are committed to promptly addressing any reported issues. 33 | If you have experienced or witnessed behavior that violates this 34 | Code of Conduct, please complete the form below to 35 | make a report. 
36 | 37 | **REPORTING FORM:** https://numfocus.typeform.com/to/ynjGdT 38 | 39 | Reports are sent to the NumFOCUS Code of Conduct Enforcement Team 40 | (see below). 41 | 42 | You can view the Privacy Policy and Terms of Service for TypeForm here. 43 | The NumFOCUS Privacy Policy is here: 44 | https://www.numfocus.org/privacy-policy 45 | 46 | 47 | # Full Code of Conduct 48 | The full text of the NumFOCUS/ArviZ Code of Conduct can be found on 49 | NumFOCUS's website 50 | https://numfocus.org/code-of-conduct 51 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing guidelines 2 | 3 | ## Before contributing 4 | 5 | Welcome to arviz-plots! Before contributing to the project, 6 | make sure that you **read our code of conduct** (`CODE_OF_CONDUCT.md`). 7 | 8 | ## Contributing code 9 | 10 | 1. Set up a Python development environment 11 | (advice: use [venv](https://docs.python.org/3/library/venv.html), 12 | [virtualenv](https://virtualenv.pypa.io/), or [miniconda](https://docs.conda.io/en/latest/miniconda.html)) 13 | 2. Install tox: `python -m pip install tox` 14 | 3. Clone the repository 15 | 4. Start a new branch off main: `git switch -c new-branch main` 16 | 5. Make your code changes 17 | 6. Check that your code follows the style guidelines of the project: `tox -e check` 18 | 7. (optional) Build the documentation: `tox -e docs` 19 | 8. (optional) Run the tests: `tox -e py311` 20 | (change the version number according to the Python you are using) 21 | 9. Commit, push, and open a pull request!
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # arviz-plots 2 | 3 | [![Run tests](https://github.com/arviz-devs/arviz-plots/actions/workflows/test.yml/badge.svg)](https://github.com/arviz-devs/arviz-plots/actions/workflows/test.yml) 4 | [![codecov](https://codecov.io/gh/arviz-devs/arviz-plots/graph/badge.svg?token=1VIPLXCOJQ)](https://codecov.io/gh/arviz-devs/arviz-plots) 5 | [![Powered by NumFOCUS](https://img.shields.io/badge/powered%20by-NumFOCUS-orange.svg?style=flat&colorA=E1523D&colorB=007D8A)](https://numfocus.org) 6 | 7 | ArviZ plotting elements and static, batteries-included plots. 8 | 9 | We are currently working on splitting ArviZ into independent modules. 10 | See https://github.com/arviz-devs/arviz/issues/2088 for more details. 11 | -------------------------------------------------------------------------------- /References.md: -------------------------------------------------------------------------------- 1 | # References 2 | 3 | This is a list of references for the methods implemented in ArviZ (including base/stats/plots). 4 | The references are organized by diagnosis. The format is the one used for docstrings. 5 | We could have references in BibTeX format, and use https://sphinxcontrib-bibtex.readthedocs.io/en/latest/ 6 | to include them in the docstring of each function and in a single webpage. 7 | See https://github.com/arviz-devs/arviz-stats/issues/56 8 | 9 | 10 | ## Psense 11 | 12 | References 13 | ---------- 14 | .. [1] Kallioinen et al. *Detecting and diagnosing prior and likelihood sensitivity with 15 | power-scaling*, Stat Comput 34(57) (2024), https://doi.org/10.1007/s11222-023-10366-5 16 | 17 | 18 | ## LOO 19 | 20 | References 21 | ---------- 22 | .. [1] Vehtari et al. *Practical Bayesian model evaluation using leave-one-out cross-validation 23 | and WAIC*. Statistics and Computing. 27(5) (2017).
https://doi.org/10.1007/s11222-016-9696-4. 24 | arXiv preprint https://arxiv.org/abs/1507.04544. 25 | 26 | .. [2] Yao et al. *Using stacking to average Bayesian predictive distributions* 27 | Bayesian Analysis, 13, 3 (2018). https://doi.org/10.1214/17-BA1091 28 | arXiv preprint https://arxiv.org/abs/1704.02030. 29 | 30 | .. [3] Vehtari et al. *Pareto Smoothed Importance Sampling*. 31 | Journal of Machine Learning Research, 25(72) (2024). https://jmlr.org/papers/v25/19-556.html. 32 | arXiv preprint https://arxiv.org/abs/1507.02646 33 | 34 | .. [4] Magnusson et al. *Bayesian Leave-One-Out Cross-Validation for Large Data.* 35 | Proceedings of the 36th International Conference on Machine Learning, 97 (2019) 36 | https://proceedings.mlr.press/v97/magnusson19a.html 37 | arXiv preprint https://arxiv.org/abs/1904.10679 38 | 39 | 40 | ## R-hat, ESS, MCSE 41 | 42 | References 43 | ---------- 44 | .. [1] Vehtari et al. *Rank-normalization, folding, and localization: An improved Rhat for 45 | assessing convergence of MCMC*. Bayesian Analysis. 16(2) (2021) 46 | https://doi.org/10.1214/20-BA1221. arXiv preprint https://arxiv.org/abs/1903.08008 47 | 48 | 49 | ## PAV-adjusted calibration plot 50 | 51 | References 52 | ---------- 53 | .. [1] Dimitriadis et al. *Stable reliability diagrams for probabilistic classifiers*. 54 | PNAS, 118(8) (2021). https://doi.org/10.1073/pnas.2016191118 55 | 56 | .. [2] Säilynoja et al. *Recommendations for visual predictive checks in Bayesian workflow*. 57 | (2025) arXiv preprint https://arxiv.org/abs/2503.01509 58 | 59 | 60 | ## Energy plot and divergences 61 | 62 | References 63 | ---------- 64 | .. [1] Betancourt. *Diagnosing Suboptimal Cotangent Disintegrations in 65 | Hamiltonian Monte Carlo*. (2016) https://arxiv.org/abs/1604.00695 66 | 67 | ## Rootograms 68 | 69 | References 70 | ---------- 71 | .. [1] Kleiber et al. *Visualizing Count Data Regressions Using Rootograms*. 72 | The American Statistician, 70(3). 
(2016) https://doi.org/10.1080/00031305.2016.1173590 73 | 74 | .. [2] Säilynoja et al. *Recommendations for visual predictive checks in Bayesian workflow*. 75 | (2025) arXiv preprint https://arxiv.org/abs/2503.01509 76 | 77 | ## ECDF-pit 78 | 79 | References 80 | ---------- 81 | .. [1] Säilynoja et al. *Graphical test for discrete uniformity and 82 | its applications in goodness-of-fit evaluation and multiple sample comparison*. 83 | Statistics and Computing 32(32). (2022) https://doi.org/10.1007/s11222-022-10090-6 84 | 85 | ## Bayesian R2 86 | 87 | References 88 | ---------- 89 | .. [1] Gelman et al. *R-squared for Bayesian regression models*. 90 | The American Statistician, 73(3). (2019) https://doi.org/10.1080/00031305.2018.1549100 91 | preprint http://www.stat.columbia.edu/~gelman/research/unpublished/bayes_R2_v3.pdf 92 | 93 | -------------------------------------------------------------------------------- /docs/source/_static/ArviZ.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arviz-devs/arviz-plots/1856806a25750e264c42403ec6eb4c45fcc61c78/docs/source/_static/ArviZ.png -------------------------------------------------------------------------------- /docs/source/_static/ArviZ_white.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arviz-devs/arviz-plots/1856806a25750e264c42403ec6eb4c45fcc61c78/docs/source/_static/ArviZ_white.png -------------------------------------------------------------------------------- /docs/source/_static/bokeh-logo-dark.svg: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/source/_static/bokeh-logo-light.svg: -------------------------------------------------------------------------------- 1 | --------------------------------------------------------------------------------
/docs/source/_static/custom.css: -------------------------------------------------------------------------------- 1 | /* -------------------- THEME OVERRIDES -------------------- */ 2 | html[data-theme="light"] { 3 | --pst-color-primary: rgb(11 117 145); 4 | --pst-color-secondary: rgb(238 144 64); 5 | --pst-color-plot-background: rgb(255, 255, 255); 6 | } 7 | 8 | html[data-theme="dark"] { 9 | --pst-color-primary: rgb(0 192 191); 10 | --pst-color-secondary: rgb(238 144 64); 11 | --pst-color-plot-background: rgb(218, 219, 220); 12 | } 13 | 14 | /* Override for example gallery - remove border around card */ 15 | .bd-content div.sd-card.example-gallery { 16 | border: none; 17 | } 18 | 19 | /* -------------------- EXAMPLE GALLERY + (homepage) -------------------- */ 20 | 21 | /* Homepage - grid layout */ 22 | .home-flex-grid { 23 | display: flex; 24 | flex-flow: row wrap; 25 | justify-content: center; 26 | gap: 10px; 27 | padding: 20px 0px 40px; 28 | } 29 | 30 | /* Homepage + Example Gallery Body - Set dimensions */ 31 | .home-img-plot, 32 | .bd-content div.sd-card.example-gallery .sd-card-body, 33 | .home-img-plot-overlay, 34 | .example-img-plot-overlay { 35 | display: flex; 36 | justify-content: center; 37 | align-items: center; 38 | overflow: hidden; 39 | padding: 10px; 40 | } 41 | .home-img-plot, 42 | .home-img-plot-overlay { 43 | width: 235px; 44 | height: 130px; 45 | } 46 | .bd-content div.sd-card.example-gallery .sd-card-body, 47 | .example-img-plot-overlay { 48 | width: 100%; 49 | height: 150px; 50 | } 51 | .home-img-plot img, 52 | .bd-content div.sd-card.example-gallery .sd-card-body img { 53 | /* Images keep aspect ratio and fit in container */ 54 | /* To make images stretch/fill container, change to min-width */ 55 | max-width: 100%; 56 | max-height: 100%; 57 | } 58 | 59 | /* Homepage + Example Gallery Body - Set color and hover */ 60 | .home-img-plot.img-thumbnail, 61 | .bd-content div.sd-card.example-gallery .sd-card-body { 62 | background-color: 
var(--pst-color-plot-background); /* Same as img-thumbnail from pydata css, adjusted for dark mode */ 63 | } 64 | .home-img-plot-overlay, 65 | .example-img-plot-overlay, 66 | .bd-content div.sd-card.example-gallery .sd-card-body { 67 | border: 1px solid #dee2e6; /* Same as img-thumbnail from pydata css */ 68 | border-radius: 0.25rem; /* Same as img-thumbnail from pydata css */ 69 | } 70 | .home-img-plot-overlay, 71 | .example-img-plot-overlay, 72 | .example-img-plot-overlay p.sd-card-text { 73 | background: var(--pst-color-primary); 74 | position: absolute; 75 | color: var(--pst-color-background); 76 | opacity: 0; 77 | transition: .2s ease; 78 | text-align: center; 79 | padding: 10px; 80 | z-index: 998; /* Make sure overlay is above image...this is here to handle dark mode */ 81 | } 82 | .home-img-plot-overlay:hover, 83 | .bd-content div.sd-card.example-gallery:hover .example-img-plot-overlay, 84 | .example-img-plot-overlay p.sd-card-text { 85 | opacity: 90%; 86 | } 87 | 88 | /* Example Gallery Body - Set syntax highlighting for code on hover */ 89 | .example-img-plot-overlay .sd-card-text code.code { 90 | background-color: var(--pst-color-background); 91 | } 92 | .example-img-plot-overlay .sd-card-text .code span.pre { 93 | color: var(--pst-color-primary); 94 | font-weight: 700; 95 | } 96 | 97 | /* Example Gallery Footer - Plot titles goes here */ 98 | .example-gallery .sd-card-footer { 99 | height: 40px; 100 | padding: 5px; 101 | border-top: none !important; 102 | } 103 | .example-gallery .sd-card-footer p.sd-card-text { 104 | color: var(--pst-color-text-muted); 105 | font-size: 1rem; /* This is font size for plot titles (below the figure) */ 106 | font-weight: 700; 107 | } 108 | .sd-card.example-gallery:hover .sd-card-footer p.sd-card-text { 109 | color: var(--pst-color-primary); /* Change text color on hover over card */ 110 | } 111 | 112 | /* Overlapping */ 113 | .example-gallery a.sd-stretched-link.reference { 114 | z-index: 999; /* Countermeasure where 
z-index = 998 */ 115 | } 116 | 117 | -------------------------------------------------------------------------------- /docs/source/_static/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arviz-devs/arviz-plots/1856806a25750e264c42403ec6eb4c45fcc61c78/docs/source/_static/favicon.ico -------------------------------------------------------------------------------- /docs/source/_static/none-logo-light.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arviz-devs/arviz-plots/1856806a25750e264c42403ec6eb4c45fcc61c78/docs/source/_static/none-logo-light.png -------------------------------------------------------------------------------- /docs/source/_static/plotly-logo-dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arviz-devs/arviz-plots/1856806a25750e264c42403ec6eb4c45fcc61c78/docs/source/_static/plotly-logo-dark.png -------------------------------------------------------------------------------- /docs/source/_static/plotly-logo-light.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arviz-devs/arviz-plots/1856806a25750e264c42403ec6eb4c45fcc61c78/docs/source/_static/plotly-logo-light.png -------------------------------------------------------------------------------- /docs/source/_templates/name.html: -------------------------------------------------------------------------------- 1 |
2 |
BASE
3 |
STATS
4 |
PLOTS
5 |
6 | -------------------------------------------------------------------------------- /docs/source/api/backend/bokeh.part.rst: -------------------------------------------------------------------------------- 1 | ============= 2 | Bokeh backend 3 | ============= 4 | 5 | .. automodule:: arviz_plots.backend.bokeh 6 | -------------------------------------------------------------------------------- /docs/source/api/backend/index.rst: -------------------------------------------------------------------------------- 1 | ============================== 2 | Interface to plotting backends 3 | ============================== 4 | 5 | ------------------ 6 | Available backends 7 | ------------------ 8 | 9 | .. grid:: 1 1 2 2 10 | 11 | .. grid-item-card:: 12 | :link: matplotlib 13 | :link-type: doc 14 | :link-alt: Matplotlib 15 | :img-background: ../../_static/matplotlib-logo-light.svg 16 | :class-img-bottom: dark-light 17 | 18 | .. grid-item-card:: 19 | :link: bokeh 20 | :link-type: doc 21 | :link-alt: Bokeh 22 | :img-background: ../../_static/bokeh-logo-light.svg 23 | :class-img-bottom: dark-light 24 | 25 | .. grid-item-card:: 26 | :link: plotly 27 | :link-type: doc 28 | :link-alt: Plotly 29 | :img-background: ../../_static/plotly-logo-light.png 30 | :class-img-bottom: dark-light 31 | 32 | .. grid-item-card:: 33 | :link: none 34 | :link-type: doc 35 | :link-alt: None (no plotting, only data processing) 36 | :img-background: ../../_static/none-logo-light.png 37 | :class-img-bottom: dark-light 38 | 39 | 40 | .. toctree:: 41 | :maxdepth: 1 42 | :hidden: 43 | 44 | Matplotlib 45 | Bokeh 46 | Plotly 47 | None (only processing, no plotting) 48 | 49 | --------------------------- 50 | Common interface definition 51 | --------------------------- 52 | 53 | .. automodule:: arviz_plots.backend 54 | 55 | .. dropdown:: Keyword arguments 56 | :name: backend_interface_arguments 57 | :open: 58 | 59 | The argument names are defined here to have a comprehensive list of all possibilities. 
60 | If relevant, a keyword argument present here should be present in the function, 61 | and converted in each backend to its corresponding argument in that backend. 62 | 63 | This set of arguments doesn't aim to be complete, only to cover basic properties 64 | so the plotting functions can work on multiple backends without duplication. 65 | Advanced customization will be backend-specific through ``**kwargs``. 66 | 67 | target 68 | This module is designed mainly in a functional way. Thus, all functions 69 | should take a ``target`` argument which indicates the object 70 | the function should be applied to. 71 | 72 | color 73 | Color of the visual element. Should also be present whenever ``facecolor`` 74 | and ``edgecolor`` are present, setting the default value for both. 75 | 76 | facecolor 77 | Color for filling the visual element. 78 | 79 | edgecolor 80 | Color for the edges of the visual element. 81 | 82 | alpha 83 | Transparency of the visual element. 84 | 85 | width 86 | Width of the visual element itself or of its edges, whichever applies. 87 | 88 | size 89 | Size of the visual element. 90 | 91 | linestyle 92 | Style of the line plotted. 93 | 94 | marker 95 | Marker to be added to the plot. 96 | 97 | vertical_align 98 | Vertical alignment between the visual element and the data coordinates provided. 99 | 100 | horizontal_align 101 | Horizontal alignment between the visual element and the data coordinates provided. 102 | 103 | axis 104 | Data axis (x, y or both) on which to apply the function. 105 | -------------------------------------------------------------------------------- /docs/source/api/backend/interface.template.rst: -------------------------------------------------------------------------------- 1 | Object creation and I/O 2 | ....................... 3 | 4 | .. autosummary:: 5 | :toctree: generated/ 6 | 7 | create_plotting_grid 8 | show 9 | 10 | Geoms 11 | ..... 12 | 13 | .
autosummary:: 14 | :toctree: generated/ 15 | 16 | line 17 | scatter 18 | text 19 | 20 | Plot appearance 21 | ................ 22 | 23 | .. autosummary:: 24 | :toctree: generated/ 25 | 26 | title 27 | ylabel 28 | xlabel 29 | xticks 30 | yticks 31 | ticklabel_props 32 | remove_ticks 33 | remove_axis 34 | 35 | Legend 36 | ...... 37 | 38 | .. autosummary:: 39 | :toctree: generated/ 40 | 41 | legend 42 | -------------------------------------------------------------------------------- /docs/source/api/backend/matplotlib.part.rst: -------------------------------------------------------------------------------- 1 | ================== 2 | Matplotlib backend 3 | ================== 4 | 5 | .. automodule:: arviz_plots.backend.matplotlib 6 | -------------------------------------------------------------------------------- /docs/source/api/backend/none.part.rst: -------------------------------------------------------------------------------- 1 | ============ 2 | None backend 3 | ============ 4 | 5 | .. automodule:: arviz_plots.backend.none 6 | -------------------------------------------------------------------------------- /docs/source/api/backend/plotly.part.rst: -------------------------------------------------------------------------------- 1 | ============== 2 | Plotly backend 3 | ============== 4 | 5 | .. automodule:: arviz_plots.backend.plotly 6 | -------------------------------------------------------------------------------- /docs/source/api/helpers.rst: -------------------------------------------------------------------------------- 1 | ========================= 2 | Helper plotting functions 3 | ========================= 4 | 5 | Helper plotting functions are available at the ``arviz_plots`` top level namespace and 6 | provide ways to customize plots by adding elements or modifying existing ones. 7 | 8 | An introductory guide to the ``plot_...`` functions is also available at :ref:`plots_intro`. 9 | 10 | .. currentmodule:: arviz_plots 11 | 12 | .
autosummary:: 13 | :toctree: generated/ 14 | 15 | add_bands 16 | add_lines 17 | 18 | Style 19 | ............... 20 | 21 | .. autosummary:: 22 | :toctree: generated/ 23 | 24 | style.available 25 | style.get 26 | style.use 27 | 28 | -------------------------------------------------------------------------------- /docs/source/api/index.md: -------------------------------------------------------------------------------- 1 | # API reference 2 | 3 | ```{eval-rst} 4 | .. automodule:: arviz_plots 5 | ``` 6 | 7 | ```{toctree} 8 | :maxdepth: 1 9 | 10 | plots 11 | helpers 12 | managers 13 | visuals 14 | backend/index 15 | ``` 16 | -------------------------------------------------------------------------------- /docs/source/api/managers.rst: -------------------------------------------------------------------------------- 1 | ============================================= 2 | Managers for faceting and aesthetics mapping 3 | ============================================= 4 | The classes in this module lie at the core of the library, 5 | and are consequently available at the ``arviz_plots`` top level namespace. 6 | 7 | They abstract all information regarding :term:`faceting` and :term:`aesthetic mapping` 8 | in our :term:`figure` to prevent duplication and ensure coherence between 9 | the different functions. 10 | 11 | .. currentmodule:: arviz_plots 12 | 13 | PlotCollection 14 | ============== 15 | 16 | Object creation 17 | ............... 18 | 19 | .. autosummary:: 20 | :toctree: generated/ 21 | 22 | PlotCollection 23 | PlotCollection.grid 24 | PlotCollection.wrap 25 | 26 | Plotting 27 | ........ 28 | 29 | .. autosummary:: 30 | :toctree: generated/ 31 | 32 | PlotCollection.add_legend 33 | PlotCollection.map 34 | 35 | Attributes 36 | .......... 37 | 38 | .
autosummary:: 39 | :toctree: generated/ 40 | 41 | PlotCollection.aes 42 | PlotCollection.viz 43 | PlotCollection.aes_set 44 | PlotCollection.facet_dims 45 | PlotCollection.data 46 | 47 | Faceting and aesthetics mapping 48 | ................................ 49 | 50 | .. autosummary:: 51 | :toctree: generated/ 52 | 53 | PlotCollection.generate_aes_dt 54 | PlotCollection.get_aes_as_dataset 55 | PlotCollection.get_aes_kwargs 56 | PlotCollection.update_aes 57 | PlotCollection.update_aes_from_dataset 58 | 59 | Other 60 | ..... 61 | 62 | .. autosummary:: 63 | :toctree: generated/ 64 | 65 | PlotCollection.allocate_artist 66 | PlotCollection.get_viz 67 | PlotCollection.get_target 68 | PlotCollection.show 69 | PlotCollection.savefig 70 | 71 | PlotMatrix 72 | ========== 73 | 74 | Object creation 75 | ............... 76 | 77 | .. autosummary:: 78 | :toctree: generated/ 79 | 80 | PlotMatrix 81 | 82 | Plotting 83 | ........ 84 | 85 | .. autosummary:: 86 | :toctree: generated/ 87 | 88 | PlotMatrix.map 89 | PlotMatrix.map_triangle 90 | PlotMatrix.map_lower 91 | PlotMatrix.map_upper 92 | 93 | Attributes 94 | .......... 95 | 96 | .. autosummary:: 97 | :toctree: generated/ 98 | 99 | PlotMatrix.aes 100 | PlotMatrix.viz 101 | PlotMatrix.aes_set 102 | PlotMatrix.facet_dims 103 | PlotMatrix.data 104 | -------------------------------------------------------------------------------- /docs/source/api/plots.rst: -------------------------------------------------------------------------------- 1 | .. _plots_api: 2 | 3 | ======================== 4 | Batteries-included plots 5 | ======================== 6 | 7 | Batteries-included plotting functions are available at the ``arviz_plots`` 8 | top level namespace and provide plug and play opinionated solutions 9 | to common tasks within the Bayesian workflow. 10 | 11 | Each of the entries below describes the behaviour of a function and all its arguments, 12 | and includes a handful of examples.
13 | A complementary introduction and guide to ``plot_...`` functions is available at :ref:`plots_intro`. 14 | 15 | .. currentmodule:: arviz_plots 16 | 17 | .. autosummary:: 18 | :toctree: generated/ 19 | 20 | combine_plots 21 | plot_autocorr 22 | plot_bf 23 | plot_compare 24 | plot_convergence_dist 25 | plot_dist 26 | plot_ecdf_pit 27 | plot_energy 28 | plot_ess 29 | plot_ess_evolution 30 | plot_forest 31 | plot_loo_pit 32 | plot_pairs_focus 33 | plot_ppc_dist 34 | plot_ppc_pava 35 | plot_ppc_pit 36 | plot_ppc_rootogram 37 | plot_prior_posterior 38 | plot_psense_dist 39 | plot_psense_quantities 40 | plot_rank 41 | plot_rank_dist 42 | plot_ridge 43 | plot_trace 44 | plot_trace_dist -------------------------------------------------------------------------------- /docs/source/api/visuals.rst: -------------------------------------------------------------------------------- 1 | =============== 2 | Visual elements 3 | =============== 4 | 5 | .. automodule:: arviz_plots.visuals 6 | 7 | Data plotting elements 8 | ---------------------- 9 | 10 | .. autosummary:: 11 | :toctree: generated/ 12 | 13 | ci_line_y 14 | ecdf_line 15 | fill_between_y 16 | hline 17 | hist 18 | hspan 19 | line 20 | line_xy 21 | line_x 22 | scatter_xy 23 | scatter_x 24 | 25 | trace_rug 26 | vline 27 | vspan 28 | 29 | 30 | Data and axis annotating elements 31 | --------------------------------- 32 | 33 | .. autosummary:: 34 | :toctree: generated/ 35 | 36 | annotate_label 37 | annotate_xy 38 | labelled_title 39 | labelled_x 40 | labelled_y 41 | point_estimate_text 42 | 43 | Plot customization elements 44 | --------------------------- 45 | 46 | .
autosummary:: 47 | :toctree: generated/ 48 | 49 | remove_axis 50 | remove_ticks 51 | set_xticks 52 | ticklabel_props 53 | grid 54 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=redefined-builtin,invalid-name 2 | import os 3 | import sys 4 | from importlib.metadata import metadata 5 | from pathlib import Path 6 | 7 | # -- Project information 8 | 9 | _metadata = metadata("arviz-plots") 10 | 11 | project = _metadata["Name"] 12 | author = _metadata["Author-email"].split("<", 1)[0].strip() 13 | copyright = f"2022, {author}" 14 | 15 | version = _metadata["Version"] 16 | if os.environ.get("READTHEDOCS", False): 17 | rtd_version = os.environ.get("READTHEDOCS_VERSION", "") 18 | if "." not in rtd_version and rtd_version.lower() != "stable": 19 | version = "dev" 20 | else: 21 | branch_name = os.environ.get("BUILD_SOURCEBRANCHNAME", "") 22 | if branch_name == "main": 23 | version = "dev" 24 | release = version 25 | 26 | 27 | # -- General configuration 28 | 29 | sys.path.insert(0, os.path.abspath("../sphinxext")) 30 | 31 | templates_path = ["_templates"] 32 | exclude_patterns = [ 33 | "Thumbs.db", 34 | ".DS_Store", 35 | ".ipynb_checkpoints", 36 | "**/*.template.rst", 37 | "**/*.part.rst", 38 | "**/*.part.md", 39 | ] 40 | skip_gallery = os.environ.get("ARVIZDOCS_NOGALLERY", False) 41 | 42 | extensions = [ 43 | "sphinx.ext.intersphinx", 44 | "sphinx.ext.mathjax", 45 | "sphinx.ext.viewcode", 46 | "sphinx.ext.autosummary", 47 | "sphinx.ext.extlinks", 48 | "numpydoc", 49 | "myst_nb", 50 | "sphinx_copybutton", 51 | "sphinx_design", 52 | "jupyter_sphinx", 53 | "matplotlib.sphinxext.plot_directive", 54 | "bokeh.sphinxext.bokeh_plot", 55 | ] 56 | if skip_gallery: 57 | exclude_patterns.append("gallery/*") 58 | else: 59 | extensions.append("gallery_generator") 60 | 61 | suppress_warnings = ["mystnb.unknown_mime_type"] 62 | 63 | 
backend_modules = ("none", "matplotlib", "bokeh", "plotly") 64 | api_backend_dir = Path(__file__).parent.resolve() / "api" / "backend" 65 | with open(api_backend_dir / "interface.template.rst", "r", encoding="utf-8") as f: 66 | interface_template = f.read() 67 | for file in backend_modules: 68 | with open(api_backend_dir / f"{file}.part.rst", "r", encoding="utf-8") as f: 69 | intro = f.read() 70 | with open(api_backend_dir / f"{file}.rst", "w", encoding="utf-8") as f: 71 | f.write(f"{intro}\n\n{interface_template}") 72 | 73 | # The reST default role (used for this markup: `text`) to use for all documents. 74 | default_role = "autolink" 75 | 76 | # If true, '()' will be appended to :func: etc. cross-reference text. 77 | add_function_parentheses = False 78 | 79 | # -- Options for extensions 80 | 81 | plot_include_source = True 82 | plot_formats = [("png", 90)] 83 | plot_html_show_formats = False 84 | plot_html_show_source_link = False 85 | 86 | extlinks = { 87 | "issue": ("https://github.com/arviz-devs/arviz-plots/issues/%s", "GH#%s"), 88 | "pull": ("https://github.com/arviz-devs/arviz-plots/pull/%s", "PR#%s"), 89 | } 90 | 91 | copybutton_prompt_text = r">>> |\.\.\. 
|\$ |In \[\d*\]: | {2,5}\.\.\.: | {5,8}: " 92 | copybutton_prompt_is_regexp = True 93 | 94 | nb_execution_mode = "auto" 95 | nb_execution_excludepatterns = ["*.ipynb"] 96 | nb_kernel_rgx_aliases = {".*": "python3"} 97 | myst_enable_extensions = ["colon_fence", "deflist", "dollarmath", "amsmath", "linkify"] 98 | 99 | autosummary_generate = True 100 | autodoc_typehints = "none" 101 | autodoc_default_options = { 102 | "members": False, 103 | } 104 | 105 | numpydoc_show_class_members = False 106 | numpydoc_xref_param_type = True 107 | numpydoc_xref_ignore = {"of", "or", "optional", "scalar", "default"} 108 | singulars = ("int", "list", "dict", "float") 109 | numpydoc_xref_aliases = { 110 | "DataArray": ":class:`xarray.DataArray`", 111 | "Dataset": ":class:`xarray.Dataset`", 112 | "DataTree": ":class:`xarray.DataTree`", 113 | "mapping": ":term:`python:mapping`", 114 | "hashable": ":term:`python:hashable`", 115 | **{f"{singular}s": f":any:`{singular}s <{singular}>`" for singular in singulars}, 116 | } 117 | 118 | intersphinx_mapping = { 119 | "arviz_org": ("https://www.arviz.org/en/latest/", None), 120 | "arviz_base": ("https://arviz-base.readthedocs.io/en/latest/", None), 121 | "arviz_stats": ("https://arviz-stats.readthedocs.io/en/latest/", None), 122 | "einstats": ("https://einstats.python.arviz.org/en/latest/", None), 123 | "numpy": ("https://numpy.org/doc/stable/", None), 124 | "python": ("https://docs.python.org/3/", None), 125 | "xarray": ("https://docs.xarray.dev/en/stable/", None), 126 | "matplotlib": ("https://matplotlib.org/stable/", None), 127 | "bokeh": ("https://docs.bokeh.org/en/latest", None), 128 | } 129 | 130 | # -- Options for HTML output 131 | html_theme = "sphinx_book_theme" 132 | html_context = {"default_mode": "light"} 133 | html_theme_options = { 134 | "logo": { 135 | "image_light": "_static/ArviZ.png", 136 | "image_dark": "_static/ArviZ_white.png", 137 | } 138 | } 139 | html_favicon = "_static/favicon.ico" 140 | html_static_path = ["_static"] 141 
| html_css_files = ["custom.css"] 142 | html_sidebars = { 143 | "**": [ 144 | "navbar-logo.html", 145 | "name.html", 146 | "icon-links.html", 147 | "search-button-field.html", 148 | "sbt-sidebar-nav.html", 149 | ] 150 | } 151 | -------------------------------------------------------------------------------- /docs/source/contributing/docs.md: -------------------------------------------------------------------------------- 1 | # Documentation 2 | 3 | ## How to build the documentation locally 4 | Similarly to testing, there are also tox jobs that take care of building the right environment 5 | and running the required commands to build the documentation. 6 | 7 | In general the process should follow these three steps in this order: 8 | 9 | ```console 10 | tox -e cleandocs 11 | tox -e docs # or tox -e nogallerydocs 12 | tox -e viewdocs 13 | ``` 14 | 15 | These commands will respectively: 16 | 17 | * Delete all intermediate sphinx files that were generated during the build process 18 | * Run the `sphinx-build` command to parse and render the library documentation 19 | * Open the documentation homepage in the default browser with `python -m webbrowser` 20 | 21 | The only required step, however, is the middle one. In general sphinx reuses the intermediate 22 | files only if it detects they haven't been modified, so when iterating quickly it is recommended 23 | to skip the clean step in order to achieve faster builds. Moreover, if the documentation 24 | page is already open in the browser, there is no need for the viewdocs job because 25 | the documentation is always rendered on the same path; refreshing the page from the browser 26 | is enough. 27 | 28 | The example gallery requires executing each python script 29 | once per backend in order to generate the png or html+javascript preview. 30 | Therefore, it is the most time-consuming step of generating the documentation.
31 | As we will often work on parts of the docs unrelated to the example gallery, 32 | the command `tox -e nogallerydocs` will generate the documentation without the example gallery, 33 | which allows for much faster iteration when writing documentation. 34 | This also means, for example, that the `minigallery` directive in the docstrings won't work, 35 | and sphinx will output warnings about it when using this option. 36 | 37 | 38 | ## How to add examples to the gallery 39 | Examples in the gallery are written in the form of python scripts. 40 | They are divided between multiple categories, 41 | with each category being a folder within `/docs/source/gallery/`. 42 | Therefore, anything matching this glob `/docs/source/gallery/**/*.py` 43 | will be rendered into the example gallery. 44 | To control the order in which examples appear in the gallery, 45 | all filenames should start with two digits, an underscore, and then 46 | the unique name of the script. 47 | 48 | The script is divided into two parts: the first is a file-level docstring, 49 | the second is the code example itself. 50 | 51 | The docstring part should contain a markdown top level title, 52 | a short description of the example, a `---` separator and a seealso directive using MyST syntax. 53 | 54 | The code part should import arviz-plots as `azp`. It should then set `backend="none"` 55 | explicitly when calling the plotting functions and 56 | store the generated {class}`~arviz_plots.PlotCollection` as the `pc` variable 57 | so the example can finish with `pc.show()`.
58 | 59 | Here is an example that can be used as template: 60 | 61 | ```python 62 | """ 63 | # Posterior ECDFs 64 | 65 | Faceted ECDF plots for 1D marginals of the distribution 66 | 67 | --- 68 | 69 | :::{seealso} 70 | API Documentation: {func}`~arviz_plots.plot_dist` 71 | 72 | EABM chapter on [Visualization of Random Variables with ArviZ](https://arviz-devs.github.io/EABM/Chapters/Distributions.html#distributions-in-arviz) 73 | ::: 74 | """ 75 | from arviz_base import load_arviz_data 76 | 77 | import arviz_plots as azp 78 | 79 | azp.style.use("arviz-variat") 80 | 81 | data = load_arviz_data("centered_eight") 82 | pc = azp.plot_dist( 83 | data, 84 | kind="ecdf", 85 | col_wrap=4, 86 | backend="none" # change to preferred backend 87 | ) 88 | pc.show() 89 | ``` 90 | 91 | ## About arviz-plots documentation 92 | Documentation for arviz-plots is written in both rST and MyST (which can be used from jupyter 93 | notebooks too) and rendered with Sphinx. Docstrings follow the numpydoc style guide. 94 | 95 | ### The gallery generator sphinxext 96 | We have a custom sphinx extension to generate the example gallery, located at 97 | `/docs/sphinxext/gallery_generator.py`. 98 | 99 | This sphinx extension reads the example scripts within `/docs/source/gallery` 100 | and takes care of the following tasks: 101 | 102 | 1. Process the script contents into a MyST source page with proper syntax, tabs, code block, 103 | and relevant links. 104 | 1. Execute the code for all plotting backends to generate the respective previews. 105 | In addition, when executing the matplotlib version, it is stored as a png to use as the 106 | miniature in the gallery page. 107 | 1. Generate the index page for the gallery, with the grid view 108 | 1. Generate a json with references of all the functions used in the different examples. 109 | This supports the `minigallery` directive that allows adding plot specific galleries 110 | in the examples section of the docstring. 
111 | -------------------------------------------------------------------------------- /docs/source/contributing/testing.md: -------------------------------------------------------------------------------- 1 | # Testing 2 | 3 | ## How to run the test suite 4 | 5 | To run the test suite, `tox` must be installed. 6 | 7 | ### Run the whole test suite 8 | Tox creates an independent env where all testing dependencies are installed 9 | and then runs pytest to execute all tests in the `tests/` folder: 10 | 11 | ```console 12 | tox -e py311 13 | ``` 14 | 15 | The `-e` flag stands for "environment": it selects a previously defined job that 16 | takes care of the steps above. The job name is "py" followed by your local python 17 | version without the decimal point. 18 | 19 | :::{note} 20 | It is also possible to run `pytest tests/` directly instead of `tox -e py311`, 21 | and all commands covered in this page work either way. However, it is recommended 22 | to use tox to isolate the testing environment and have local testing be as similar 23 | as possible to testing in CI jobs. 24 | ::: 25 | 26 | ### Pass arguments to pytest 27 | We can also pass arguments through tox to pytest. With this we can, for example, 28 | select specific subsets of tests to be executed with the `-k` flag: 29 | 30 | ```console 31 | tox -e py311 -- -k plot_trace_dist 32 | ``` 33 | 34 | This would run all tests whose name contains `plot_trace_dist`. 35 | The [pytest documentation](https://docs.pytest.org/en/stable/reference/reference.html#command-line-flags) 36 | lists and describes all available options. 37 | 38 | ### Custom pytest arguments 39 | In addition to built-in pytest arguments, we have also defined a couple of extra flags 40 | in `tests/conftest.py` to handle arviz-plots specific situations. 41 | 42 | #### Skip flags 43 | One of the drivers of arviz-plots design is ensuring parity between the different 44 | plotting backends with minimal duplication.
45 | 46 | Therefore, all tests that depend on the plotting backend should be parametrized 47 | with `@pytest.mark.parametrize("backend", backend_list)` to make sure all backends 48 | pass all relevant tests. 49 | 50 | At the same time, however, arviz-plots considers all backends optional dependencies, 51 | so not all backends might be installed and, consequently, not all tests can be executed. 52 | By default, testing works under the assumption that all backends are installed, 53 | but backend-specific tests can be skipped when running the test suite: 54 | 55 | ```console 56 | tox -e py311 -- --skip-bokeh 57 | tox -e py311 -- --skip-mpl 58 | tox -e py311 -- --skip-plotly 59 | ``` 60 | 61 | :::{note} It is also possible to use all three flags, in which case only tests 62 | independent of the plotting backend, such as aesthetics generation, 63 | will be executed. 64 | ::: 65 | 66 | #### Saving matplotlib figures generated by tests 67 | Testing checks that plotting functions can be executed and return objects with the 68 | right properties, but there are no checks against the actual generated images. 69 | 70 | As we might often want to check the generated images, there is also a `--save` 71 | flag that tells pytest to save all figures generated by matplotlib while testing. 72 | 73 | This flag takes one optional argument in case we want to specify the folder where 74 | the images will be saved; otherwise it defaults to `test_images` in the project 75 | home folder. 76 | 77 | ```console 78 | tox -e py311 -- --save 79 | ``` 80 | 81 | This generates basically the same output as any test job: 82 | 83 | ``` 84 | build: _optional_hooks> python ... 85 | [...]
86 | py311: OK (42.26=setup[1.79]+cmd[40.47] seconds) 87 | congratulations :) (42.31 seconds) 88 | ``` 89 | 90 | But if you inspect the project home folder, you should see a `/test_images` folder 91 | with contents similar to: 92 | 93 | ``` 94 | 'test_grid[matplotlib].png' 'test_plot_forest_extendable[matplotlib].png' 'test_plot_trace_dist[True-False-matplotlib].png' 95 | 'test_grid_rows_cols[cols-matplotlib].png' 'test_plot_forest[False-matplotlib].png' 'test_plot_trace_dist[True-True-matplotlib].png' 96 | 'test_grid_rows_cols[rows-matplotlib].png' 'test_plot_forest_models[matplotlib].png' 'test_plot_trace[matplotlib].png' 97 | 'test_grid_scalar[matplotlib].png' 'test_plot_forest_sample[matplotlib].png' 'test_plot_trace_sample[matplotlib].png' 98 | 'test_grid_variable[matplotlib].png' 'test_plot_forest[True-matplotlib].png' 'test_wrap[matplotlib].png' 99 | 'test_plot_dist[matplotlib].png' 'test_plot_trace_dist[False-False-matplotlib].png' 'test_wrap_only_variable[matplotlib].png' 100 | 'test_plot_dist_models[matplotlib].png' 'test_plot_trace_dist[False-True-matplotlib].png' 'test_wrap_variable[matplotlib].png' 101 | ``` 102 | 103 | ## About arviz-plots testing 104 | -------------------------------------------------------------------------------- /docs/source/gallery/distribution/00_plot_dist_ecdf.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Posterior ECDFs 3 | 4 | Faceted ECDF plots for 1D marginals of the distribution 5 | 6 | --- 7 | 8 | :::{seealso} 9 | API Documentation: {func}`~arviz_plots.plot_dist` 10 | 11 | EABM chapter on [Visualization of Random Variables with ArviZ](https://arviz-devs.github.io/EABM/Chapters/Distributions.html#distributions-in-arviz) 12 | ::: 13 | """ 14 | from arviz_base import load_arviz_data 15 | 16 | import arviz_plots as azp 17 | 18 | azp.style.use("arviz-variat") 19 | 20 | data = load_arviz_data("centered_eight") 21 | pc = azp.plot_dist( 22 | data, 23 | kind="ecdf", 24 | 
col_wrap=4, 25 | backend="none" # change to preferred backend 26 | ) 27 | pc.show() 28 | -------------------------------------------------------------------------------- /docs/source/gallery/distribution/01_plot_dist_hist.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Posterior Histograms 3 | 4 | Faceted histogram plots for 1D marginals of the distribution. 5 | The `point_estimate_text` option is set to False to omit that visual from the plot. 6 | --- 7 | 8 | :::{seealso} 9 | API Documentation: {func}`~arviz_plots.plot_dist` 10 | 11 | EABM chapter on [Visualization of Random Variables with ArviZ](https://arviz-devs.github.io/EABM/Chapters/Distributions.html#distributions-in-arviz) 12 | ::: 13 | """ 14 | from arviz_base import load_arviz_data 15 | 16 | import arviz_plots as azp 17 | 18 | azp.style.use("arviz-variat") 19 | 20 | data = load_arviz_data("centered_eight") 21 | pc = azp.plot_dist( 22 | data, 23 | kind="hist", 24 | visuals={"point_estimate_text": False}, 25 | backend="none" # change to preferred backend 26 | ) 27 | pc.show() 28 | -------------------------------------------------------------------------------- /docs/source/gallery/distribution/02_plot_dist_kde.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Posterior KDEs 3 | 4 | KDE plot of the variable `mu` from the centered eight model. The `sample_dims` parameter is 5 | used to restrict the KDE computation along the `draw` dimension only.
6 | --- 7 | 8 | :::{seealso} 9 | API Documentation: {func}`~arviz_plots.plot_dist` 10 | 11 | EABM chapter on [Visualization of Random Variables with ArviZ](https://arviz-devs.github.io/EABM/Chapters/Distributions.html#distributions-in-arviz) 12 | ::: 13 | """ 14 | from arviz_base import load_arviz_data 15 | 16 | import arviz_plots as azp 17 | 18 | azp.style.use("arviz-variat") 19 | 20 | data = load_arviz_data("centered_eight") 21 | pc = azp.plot_dist( 22 | data, 23 | kind="kde", 24 | var_names=["mu"], 25 | sample_dims=["draw"], 26 | backend="none" # change to preferred backend 27 | ) 28 | pc.show() 29 | -------------------------------------------------------------------------------- /docs/source/gallery/distribution/04_plot_forest.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Forest plot 3 | 4 | Default forest plot with marginal distribution summaries 5 | 6 | --- 7 | 8 | :::{seealso} 9 | API Documentation: {func}`~arviz_plots.plot_forest` 10 | ::: 11 | """ 12 | from arviz_base import load_arviz_data 13 | 14 | import arviz_plots as azp 15 | 16 | azp.style.use("arviz-variat") 17 | 18 | data = load_arviz_data("rugby") 19 | pc = azp.plot_forest( 20 | data, 21 | var_names=["home", "atts", "defs"], 22 | backend="none" # change to preferred backend 23 | ) 24 | pc.show() 25 | -------------------------------------------------------------------------------- /docs/source/gallery/distribution/05_plot_forest_shade.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Forest plot with shading 3 | 4 | Forest plot marginal summaries with row shading to enhance reading 5 | 6 | --- 7 | 8 | :::{seealso} 9 | API Documentation: {func}`~arviz_plots.plot_forest` 10 | ::: 11 | """ 12 | from arviz_base import load_arviz_data 13 | 14 | import arviz_plots as azp 15 | 16 | azp.style.use("arviz-variat") 17 | 18 | data = load_arviz_data("rugby") 19 | pc = azp.plot_forest( 20 | data, 21 | 
var_names=["home", "atts", "defs"], 22 | shade_label="team", 23 | backend="none", # change to preferred backend 24 | ) 25 | pc.show() 26 | -------------------------------------------------------------------------------- /docs/source/gallery/distribution/06_plot_prior_posterior.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Plot prior and posterior 3 | 4 | Plot prior and posterior marginal distributions. 5 | 6 | --- 7 | 8 | :::{seealso} 9 | API Documentation: {func}`~arviz_plots.plot_prior_posterior` 10 | ::: 11 | """ 12 | from arviz_base import load_arviz_data 13 | 14 | import arviz_plots as azp 15 | 16 | azp.style.use("arviz-variat") 17 | 18 | data = load_arviz_data("centered_eight") 19 | pc = azp.plot_prior_posterior( 20 | data, 21 | var_names="mu", 22 | kind="hist", 23 | backend="none" # change to preferred backend 24 | ) 25 | pc.show() 26 | -------------------------------------------------------------------------------- /docs/source/gallery/distribution/07_plot_pairs_focus_distribution.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Scatterplot one variable against all others 3 | 4 | Plot one variable against other variables in the dataset. 
5 | 6 | --- 7 | 8 | :::{seealso} 9 | API Documentation: {func}`~arviz_plots.plot_pairs_focus` 10 | ::: 11 | """ 12 | from arviz_base import load_arviz_data 13 | 14 | import arviz_plots as azp 15 | 16 | azp.style.use("arviz-variat") 17 | 18 | data = load_arviz_data("centered_eight") 19 | pc = azp.plot_pairs_focus( 20 | data, 21 | var_names=["theta","tau"], 22 | focus_var="mu", 23 | figure_kwargs={"sharex": True}, 24 | backend="none", # change to preferred backend 25 | ) 26 | pc.show() 27 | -------------------------------------------------------------------------------- /docs/source/gallery/inference_diagnostics/00_plot_rank.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Rank plot 3 | 4 | faceted plot with fractional ranks for each variable 5 | 6 | --- 7 | 8 | :::{seealso} 9 | API Documentation: {func}`~arviz_plots.plot_rank` 10 | ::: 11 | """ 12 | from arviz_base import load_arviz_data 13 | 14 | import arviz_plots as azp 15 | 16 | azp.style.use("arviz-variat") 17 | 18 | data = load_arviz_data("centered_eight") 19 | pc = azp.plot_rank( 20 | data, 21 | backend="none" # change to preferred backend 22 | ) 23 | pc.show() 24 | -------------------------------------------------------------------------------- /docs/source/gallery/inference_diagnostics/01_plot_trace.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Trace plot 3 | 4 | faceted plot with MCMC traces for each variable 5 | 6 | --- 7 | 8 | :::{seealso} 9 | API Documentation: {func}`~arviz_plots.plot_trace` 10 | ::: 11 | """ 12 | from arviz_base import load_arviz_data 13 | 14 | import arviz_plots as azp 15 | 16 | azp.style.use("arviz-variat") 17 | 18 | data = load_arviz_data("centered_eight") 19 | pc = azp.plot_trace( 20 | data, 21 | backend="none" # change to preferred backend 22 | ) 23 | pc.show() 24 | -------------------------------------------------------------------------------- 
/docs/source/gallery/inference_diagnostics/02_plot_ess_evolution.py: -------------------------------------------------------------------------------- 1 | """ 2 | # ESS evolution 3 | 4 | faceted plot with ESS 'bulk' and 'tail' for each variable 5 | 6 | --- 7 | 8 | :::{seealso} 9 | API Documentation: {func}`~arviz_plots.plot_ess_evolution` 10 | ::: 11 | """ 12 | 13 | from arviz_base import load_arviz_data 14 | 15 | import arviz_plots as azp 16 | 17 | azp.style.use("arviz-variat") 18 | 19 | data = load_arviz_data("centered_eight") 20 | pc = azp.plot_ess_evolution(data, backend="none") # change to preferred backend 21 | pc.show() 22 | -------------------------------------------------------------------------------- /docs/source/gallery/inference_diagnostics/03_plot_ess_local.py: -------------------------------------------------------------------------------- 1 | """ 2 | # ESS local 3 | 4 | faceted local ESS plot 5 | 6 | --- 7 | 8 | :::{seealso} 9 | API Documentation: {func}`~arviz_plots.plot_ess` 10 | ::: 11 | """ 12 | 13 | from arviz_base import load_arviz_data 14 | 15 | import arviz_plots as azp 16 | 17 | azp.style.use("arviz-variat") 18 | 19 | data = load_arviz_data("centered_eight") 20 | pc = azp.plot_ess( 21 | data, 22 | kind="local", 23 | backend="none", # change to preferred backend 24 | rug=True, 25 | ) 26 | pc.show() 27 | -------------------------------------------------------------------------------- /docs/source/gallery/inference_diagnostics/04_plot_ess_quantile.py: -------------------------------------------------------------------------------- 1 | """ 2 | # ESS quantile 3 | 4 | faceted quantile ESS plot 5 | 6 | --- 7 | 8 | :::{seealso} 9 | API Documentation: {func}`~arviz_plots.plot_ess` 10 | ::: 11 | """ 12 | 13 | from arviz_base import load_arviz_data 14 | 15 | import arviz_plots as azp 16 | 17 | azp.style.use("arviz-variat") 18 | 19 | data = load_arviz_data("centered_eight") 20 | pc = azp.plot_ess( 21 | data, 22 | kind="quantile", 23 | backend="none", # 
change to preferred backend 24 | ) 25 | pc.show() 26 | -------------------------------------------------------------------------------- /docs/source/gallery/inference_diagnostics/05_plot_ess_models.py: -------------------------------------------------------------------------------- 1 | """ 2 | # ESS comparison 3 | 4 | Full ESS (either local or quantile) comparison between different models 5 | 6 | --- 7 | 8 | :::{seealso} 9 | API Documentation: {func}`~arviz_plots.plot_ess` 10 | ::: 11 | """ 12 | 13 | from arviz_base import load_arviz_data 14 | 15 | import arviz_plots as azp 16 | 17 | azp.style.use("arviz-variat") 18 | 19 | c = load_arviz_data("centered_eight") 20 | n = load_arviz_data("non_centered_eight") 21 | pc = azp.plot_ess( 22 | {"Centered": c, "Non Centered": n}, 23 | backend="none", # change to preferred backend 24 | ) 25 | pc.show() 26 | -------------------------------------------------------------------------------- /docs/source/gallery/inference_diagnostics/05_plot_mcse.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Monte Carlo standard error 3 | 4 | faceted quantile MCSE plot 5 | 6 | --- 7 | 8 | :::{seealso} 9 | API Documentation: {func}`~arviz_plots.plot_mcse` 10 | ::: 11 | """ 12 | 13 | from arviz_base import load_arviz_data 14 | 15 | import arviz_plots as azp 16 | 17 | azp.style.use("arviz-variat") 18 | 19 | data = load_arviz_data("centered_eight") 20 | pc = azp.plot_mcse( 21 | data, 22 | extra_methods=True, 23 | var_names=["mu"], 24 | backend="none", # change to preferred backend 25 | ) 26 | pc.show() 27 | -------------------------------------------------------------------------------- /docs/source/gallery/inference_diagnostics/06_plot_convergence_dist.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Convergence diagnostics distribution 3 | 4 | Plot the distribution of ESS and R-hat.
5 | 6 | --- 7 | 8 | :::{seealso} 9 | API Documentation: {func}`~arviz_plots.plot_convergence_dist` 10 | ::: 11 | """ 12 | 13 | from arviz_base import load_arviz_data 14 | 15 | import arviz_plots as azp 16 | 17 | azp.style.use("arviz-variat") 18 | 19 | data = load_arviz_data("radon") 20 | pc = azp.plot_convergence_dist( 21 | data, 22 | var_names=["za_county"], 23 | backend="none", # change to preferred backend 24 | ) 25 | 26 | pc.show() 27 | -------------------------------------------------------------------------------- /docs/source/gallery/inference_diagnostics/07_plot_autocorr.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Autocorrelation Plot 3 | 4 | faceted plot with autocorrelation for each variable 5 | 6 | --- 7 | 8 | :::{seealso} 9 | API Documentation: {func}`~arviz_plots.plot_autocorr` 10 | ::: 11 | """ 12 | from arviz_base import load_arviz_data 13 | 14 | import arviz_plots as azp 15 | 16 | azp.style.use("arviz-variat") 17 | 18 | data = load_arviz_data("centered_eight") 19 | pc = azp.plot_autocorr( 20 | data, 21 | backend="none" # change to preferred backend 22 | ) 23 | pc.show() 24 | -------------------------------------------------------------------------------- /docs/source/gallery/inference_diagnostics/08_plot_energy.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Energy 3 | 4 | Plot transition and marginal energy distributions 5 | 6 | --- 7 | 8 | :::{seealso} 9 | API Documentation: {func}`~arviz_plots.plot_energy` 10 | ::: 11 | """ 12 | from arviz_base import load_arviz_data 13 | 14 | import arviz_plots as azp 15 | 16 | azp.style.use("arviz-variat") 17 | 18 | data = load_arviz_data("centered_eight") 19 | pc = azp.plot_energy( 20 | data, 21 | backend="none" # change to preferred backend 22 | ) 23 | pc.show() -------------------------------------------------------------------------------- 
/docs/source/gallery/inference_diagnostics/09_plot_pairs_focus.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Scatter plot with divergences 3 | 4 | Plot one variable against other variables in the dataset. 5 | 6 | --- 7 | 8 | :::{seealso} 9 | API Documentation: {func}`~arviz_plots.plot_pairs_focus` 10 | ::: 11 | """ 12 | import numpy as np 13 | from arviz_base import load_arviz_data 14 | 15 | import arviz_plots as azp 16 | 17 | azp.style.use("arviz-variat") 18 | 19 | dt = load_arviz_data("centered_eight") 20 | dt.posterior["log_tau"] = np.log(dt.posterior["tau"]) 21 | 22 | pc = azp.plot_pairs_focus( 23 | dt, 24 | var_names=["theta"], 25 | focus_var="log_tau", 26 | visuals={"divergence":True}, 27 | backend="none", # change to preferred backend 28 | ) 29 | pc.show() 30 | -------------------------------------------------------------------------------- /docs/source/gallery/mixed/00_plot_rank_dist.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Rank and distribution plot 3 | 4 | Two column layout with marginal distributions on the left and fractional ranks on the right 5 | 6 | --- 7 | 8 | :::{seealso} 9 | API Documentation: {func}`~arviz_plots.plot_rank_dist` 10 | ::: 11 | """ 12 | from arviz_base import load_arviz_data 13 | 14 | import arviz_plots as azp 15 | 16 | azp.style.use("arviz-variat") 17 | 18 | data = load_arviz_data("non_centered_eight") 19 | pc = azp.plot_rank_dist( 20 | data, 21 | var_names=["mu", "tau"], 22 | backend="none" # change to preferred backend 23 | ) 24 | pc.show() 25 | -------------------------------------------------------------------------------- /docs/source/gallery/mixed/01_plot_trace_dist.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Trace and distribution plot 3 | 4 | Two column layout with marginal distributions on the left and MCMC traces on the right 5 | 6 | --- 7 | 8 | :::{seealso} 
9 | API Documentation: {func}`~arviz_plots.plot_trace_dist` 10 | ::: 11 | """ 12 | from arviz_base import load_arviz_data 13 | 14 | import arviz_plots as azp 15 | 16 | azp.style.use("arviz-variat") 17 | 18 | data = load_arviz_data("non_centered_eight") 19 | pc = azp.plot_trace_dist( 20 | data, 21 | backend="none" # change to preferred backend 22 | ) 23 | pc.show() 24 | -------------------------------------------------------------------------------- /docs/source/gallery/mixed/02_plot_forest_ess.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Forest plot with ESS 3 | 4 | Multiple panel visualization with a forest plot and ESS information 5 | 6 | --- 7 | 8 | :::{seealso} 9 | API Documentation: {func}`~arviz_plots.plot_forest` 10 | ::: 11 | """ 12 | import arviz_stats # make azstats accessor available 13 | from arviz_base import load_arviz_data 14 | 15 | import arviz_plots as azp 16 | 17 | azp.style.use("arviz-variat") 18 | 19 | centered = load_arviz_data("centered_eight") 20 | 21 | c_aux = ( 22 | centered["posterior"] 23 | .dataset.expand_dims(column=3) 24 | .assign_coords(column=["labels", "forest", "ess"]) 25 | ) 26 | 27 | pc = azp.plot_forest( 28 | c_aux, 29 | combined=True, 30 | backend="none", # change to preferred backend 31 | ) 32 | 33 | pc.map( 34 | azp.visuals.scatter_x, 35 | "ess", 36 | data=centered.posterior.ds.azstats.ess(), 37 | coords={"column": "ess"}, 38 | color="gray", 39 | ) 40 | pc.show() 41 | -------------------------------------------------------------------------------- /docs/source/gallery/mixed/03_combine_plots.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Custom diagnostic plots combination 3 | 4 | Arrange three diagnostic plots (ESS evolution plot, rank plot and autocorrelation plot) 5 | in a custom column layout. 
6 | 7 | --- 8 | 9 | :::{seealso} 10 | API Documentation: {func}`~arviz_plots.combine_plots` 11 | ::: 12 | """ 13 | from arviz_base import load_arviz_data 14 | 15 | import arviz_plots as azp 16 | 17 | azp.style.use("arviz-variat") 18 | 19 | data = load_arviz_data("non_centered_eight") 20 | pc = azp.combine_plots( 21 | data, 22 | [ 23 | (azp.plot_ess_evolution, {}), 24 | (azp.plot_rank, {}), 25 | (azp.plot_autocorr, {}), 26 | ], 27 | var_names=["theta", "mu", "tau"], 28 | coords={"school": ["Hotchkiss", "St. Paul's"]}, 29 | backend="none", # change to preferred backend 30 | ) 31 | pc.show() 32 | -------------------------------------------------------------------------------- /docs/source/gallery/model_comparison/00_plot_compare.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Predictive model comparison 3 | 4 | Compare multiple models by their predictive accuracy, estimated using PSIS-LOO-CV. 5 | Usually the DataFrame ``cmp_df`` is generated using ArviZ's ``compare`` function. 6 | 7 | --- 8 | 9 | :::{seealso} 10 | API Documentation: {func}`~arviz_plots.plot_compare` 11 | ::: 12 | """ 13 | import pandas as pd 14 | 15 | import arviz_plots as azp 16 | 17 | azp.style.use("arviz-variat") 18 | 19 | 20 | cmp_df = pd.DataFrame({"elpd": [-4.5, -14.3, -16.2], 21 | "p": [2.6, 2.3, 2.1], 22 | "elpd_diff": [0, 9.7, 11.3], 23 | "weight": [0.9, 0.1, 0], 24 | "se": [2.3, 2.7, 2.3], 25 | "dse": [0, 2.7, 2.3], 26 | "warning": [False, False, False], 27 | }, 28 | index=["Model B", "Model A", "Model C"]) 29 | 30 | pc = azp.plot_compare(cmp_df, 31 | backend="none", # change to preferred backend 32 | ) 33 | pc.show() -------------------------------------------------------------------------------- /docs/source/gallery/model_comparison/99_plot_bf.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Bayes factor 3 | 4 | Compute the Bayes factor using the Savage–Dickey density ratio.
5 | 6 | We can apply this function when the null model is nested within the alternative. 7 | In other words, when the null (``ref_val``) is a particular value of the model we are 8 | building (see [here](https://statproofbook.github.io/P/bf-sddr.html)). 9 | 10 | For other cases, computing Bayes factors is not straightforward and requires more complex 11 | methods. Instead of Bayes factors, we usually recommend Pareto smoothed 12 | importance sampling leave-one-out cross-validation (PSIS-LOO-CV). In ArviZ, you will find 13 | these as functions with ``loo`` in their names. 14 | 15 | --- 16 | :::{seealso} 17 | API Documentation: {func}`~arviz_plots.plot_bf` 18 | ::: 19 | """ 20 | from arviz_base import load_arviz_data 21 | 22 | import arviz_plots as azp 23 | 24 | azp.style.use("arviz-variat") 25 | 26 | data = load_arviz_data("centered_eight") 27 | 28 | pc = azp.plot_bf( 29 | data, 30 | backend="none", # change to preferred backend 31 | var_names="mu" 32 | ) 33 | 34 | pc.show() -------------------------------------------------------------------------------- /docs/source/gallery/posterior_comparison/00_plot_dist_models.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Posterior KDEs for two models 3 | 4 | Full marginal distribution comparison between different models 5 | 6 | --- 7 | 8 | :::{seealso} 9 | API Documentation: {func}`~arviz_plots.plot_dist` 10 | 11 | Other examples comparing marginal distributions: {ref}`gallery_forest_models` 12 | ::: 13 | """ 14 | from arviz_base import load_arviz_data 15 | 16 | import arviz_plots as azp 17 | 18 | azp.style.use("arviz-variat") 19 | 20 | c = load_arviz_data("centered_eight") 21 | n = load_arviz_data("non_centered_eight") 22 | pc = azp.plot_dist( 23 | {"Centered": c, "Non Centered": n}, 24 | backend="none" # change to preferred backend 25 | ) 26 | pc.show() 27 | --------------------------------------------------------------------------------
/docs/source/gallery/posterior_comparison/01_plot_forest_models.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Posterior forest for two models 3 | 4 | Forest plot summaries for 1D marginal distributions 5 | 6 | --- 7 | 8 | :::{seealso} 9 | API Documentation: {func}`~arviz_plots.plot_forest` 10 | 11 | Other examples comparing marginal distributions: {ref}`gallery_dist_models` 12 | ::: 13 | """ 14 | from arviz_base import load_arviz_data 15 | 16 | import arviz_plots as azp 17 | 18 | azp.style.use("arviz-variat") 19 | 20 | c = load_arviz_data("centered_eight") 21 | n = load_arviz_data("non_centered_eight") 22 | pc = azp.plot_forest( 23 | {"Centered": c, "Non Centered": n}, 24 | backend="none" # change to preferred backend 25 | ) 26 | pc.show() 27 | -------------------------------------------------------------------------------- /docs/source/gallery/predictive_checks/00_plot_ppc_dist.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Predictive check with KDEs 3 | 4 | Plot of samples from the posterior predictive and observed data. 5 | 6 | --- 7 | 8 | :::{seealso} 9 | API Documentation: {func}`~arviz_plots.plot_ppc_dist` 10 | 11 | EABM chapter on [Posterior predictive checks](https://arviz-devs.github.io/EABM/Chapters/Prior_posterior_predictive_checks.html#posterior-predictive-checks) 12 | ::: 13 | """ 14 | from arviz_base import load_arviz_data 15 | 16 | import arviz_plots as azp 17 | 18 | azp.style.use("arviz-variat") 19 | 20 | dt = load_arviz_data("radon") 21 | pc = azp.plot_ppc_dist( 22 | dt, 23 | backend="none", 24 | ) 25 | pc.show() 26 | -------------------------------------------------------------------------------- /docs/source/gallery/predictive_checks/01_plot_ppc_rootogram.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Rootogram 3 | 4 | Rootogram for the posterior predictive and observed data. 
5 | 6 | --- 7 | 8 | :::{seealso} 9 | API Documentation: {func}`~arviz_plots.plot_ppc_rootogram` 10 | 11 | EABM chapter on [Posterior predictive checks for count data](https://arviz-devs.github.io/EABM/Chapters/Prior_posterior_predictive_checks.html#posterior-predictive-checks-for-count-data) 12 | ::: 13 | """ 14 | from arviz_base import load_arviz_data 15 | 16 | import arviz_plots as azp 17 | 18 | azp.style.use("arviz-variat") 19 | 20 | dt = load_arviz_data("rugby") 21 | pc = azp.plot_ppc_rootogram( 22 | dt, 23 | aes={"color": ["__variable__"]}, # map variable to color 24 | aes_by_visuals={"title": ["color"]}, # change title's color per variable 25 | backend="none", 26 | ) 27 | pc.show() 28 | -------------------------------------------------------------------------------- /docs/source/gallery/predictive_checks/03_plot_pava_calibration.py: -------------------------------------------------------------------------------- 1 | """ 2 | # PAV-adjusted calibration 3 | 4 | PAV-adjusted calibration plot for binary predictions. 5 | 6 | --- 7 | 8 | :::{seealso} 9 | API Documentation: {func}`~arviz_plots.plot_ppc_pava` 10 | 11 | EABM chapter on [Posterior predictive checks for binary data](https://arviz-devs.github.io/EABM/Chapters/Prior_posterior_predictive_checks.html#posterior-predictive-checks-for-binary-data) 12 | ::: 13 | """ 14 | from arviz_base import load_arviz_data 15 | 16 | import arviz_plots as azp 17 | 18 | azp.style.use("arviz-variat") 19 | 20 | dt = load_arviz_data("anes") 21 | pc = azp.plot_ppc_pava( 22 | dt, 23 | backend="none", 24 | ) 25 | pc.show() 26 | -------------------------------------------------------------------------------- /docs/source/gallery/predictive_checks/04_plot_ppc_pit.py: -------------------------------------------------------------------------------- 1 | """ 2 | # PIT ECDF 3 | 4 | Plot of the probability integral transform of the posterior predictive distribution with respect to the observed data. 
5 | 6 | --- 7 | 8 | :::{seealso} 9 | API Documentation: {func}`~arviz_plots.plot_ppc_pit` 10 | 11 | EABM chapter on [Posterior predictive checks with PIT-ECDFs](https://arviz-devs.github.io/EABM/Chapters/Prior_posterior_predictive_checks.html#pit-ecdfs) 12 | ::: 13 | """ 14 | from arviz_base import load_arviz_data 15 | 16 | import arviz_plots as azp 17 | 18 | azp.style.use("arviz-variat") 19 | 20 | dt = load_arviz_data("radon") 21 | pc = azp.plot_ppc_pit( 22 | dt, 23 | backend="none", 24 | ) 25 | pc.show() 26 | -------------------------------------------------------------------------------- /docs/source/gallery/predictive_checks/05_plot_ppc_coverage.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Coverage ECDF 3 | 4 | Proportion of true values that fall within a given prediction interval. 5 | 6 | --- 7 | 8 | :::{seealso} 9 | API Documentation: {func}`~arviz_plots.plot_ppc_pit` 10 | 11 | EABM chapter on [Posterior predictive checks and coverage](https://arviz-devs.github.io/EABM/Chapters/Prior_posterior_predictive_checks.html#coverage) 12 | ::: 13 | """ 14 | from arviz_base import load_arviz_data 15 | 16 | import arviz_plots as azp 17 | 18 | azp.style.use("arviz-variat") 19 | 20 | dt = load_arviz_data("radon") 21 | pc = azp.plot_ppc_pit( 22 | dt, 23 | coverage=True, 24 | backend="none", 25 | ) 26 | pc.show() 27 | -------------------------------------------------------------------------------- /docs/source/gallery/predictive_checks/06_plot_loo_pit.py: -------------------------------------------------------------------------------- 1 | """ 2 | # LOO-PIT ECDF 3 | 4 | Plot of the probability integral transform of the posterior predictive distribution with 5 | respect to the observed data using the leave-one-out (LOO) method. 
6 | 7 | 8 | --- 9 | 10 | :::{seealso} 11 | API Documentation: {func}`~arviz_plots.plot_loo_pit` 12 | 13 | EABM chapter on [Posterior predictive checks with LOO-PIT ECDFs](https://arviz-devs.github.io/EABM/Chapters/Prior_posterior_predictive_checks.html#sec-avoid-double-dipping) 14 | ::: 15 | """ 16 | from arviz_base import load_arviz_data 17 | 18 | import arviz_plots as azp 19 | 20 | azp.style.use("arviz-variat") 21 | 22 | dt = load_arviz_data("radon") 23 | pc = azp.plot_loo_pit( 24 | dt, 25 | backend="none", 26 | ) 27 | pc.show() 28 | -------------------------------------------------------------------------------- /docs/source/gallery/predictive_checks/07_plot_ppc_tstat.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Test statistics 3 | 4 | Test statistic (here the median) for the observed data and the posterior predictive data. 5 | 6 | --- 7 | 8 | :::{seealso} 9 | API Documentation: {func}`~arviz_plots.plot_ppc_tstat` 10 | 11 | EABM chapter on [Posterior predictive checks with summary statistics](https://arviz-devs.github.io/EABM/Chapters/Prior_posterior_predictive_checks.html#using-summary-statistics) 12 | ::: 13 | """ 14 | from arviz_base import load_arviz_data 15 | 16 | import arviz_plots as azp 17 | 18 | azp.style.use("arviz-variat") 19 | 20 | dt = load_arviz_data("regression1d") 21 | pc = azp.plot_ppc_tstat( 22 | dt, 23 | t_stat="median", 24 | backend="none" 25 | ) 26 | pc.show() 27 | -------------------------------------------------------------------------------- /docs/source/gallery/predictive_checks/99_plot_forest_pp_obs.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Posterior predictive forest and observations 3 | 4 | Overlay of forest plot for the posterior predictive samples and the actual observations 5 | 6 | --- 7 | 8 | :::{seealso} 9 | API Documentation: {func}`~arviz_plots.plot_forest` 10 | ::: 11 | """ 12 | from arviz_base import load_arviz_data 13 | 14 | import
arviz_plots as azp 15 | 16 | azp.style.use("arviz-variat") 17 | 18 | idata = load_arviz_data("non_centered_eight") 19 | pc = azp.plot_forest( 20 | idata, 21 | group="posterior_predictive", 22 | combined=True, 23 | labels=["obs_dim_0"], 24 | backend="none", # change to preferred backend 25 | ) 26 | 27 | pc.map( 28 | azp.visuals.scatter_x, 29 | "observations", 30 | data=idata.observed_data.ds, 31 | coords={"column": "forest"}, 32 | color="gray", 33 | ) 34 | 35 | pc.map( 36 | azp.visuals.labelled_x, 37 | "xlabel", 38 | coords={"column": "forest"}, 39 | text="Observations", 40 | ignore_aes="y", 41 | ) 42 | pc.show() 43 | -------------------------------------------------------------------------------- /docs/source/gallery/prior_and_likelihood_sensitivity_checks/00_plot_psense.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Sensitivity posterior marginals 3 | 4 | The posterior sensitivity is assessed by power-scaling the prior or likelihood and 5 | visualizing the resulting changes. Sensitivity can then be quantified by considering 6 | how much the perturbed posteriors differ from the base posterior. 
7 | 8 | --- 9 | 10 | :::{seealso} 11 | API Documentation: {func}`~arviz_plots.plot_psense_dist` 12 | ::: 13 | """ 14 | from arviz_base import load_arviz_data 15 | 16 | import arviz_plots as azp 17 | 18 | azp.style.use("arviz-variat") 19 | 20 | idata = load_arviz_data("rugby") 21 | pc = azp.plot_psense_dist( 22 | idata, 23 | var_names=["defs", "sd_att", "sd_def"], 24 | coords={"team": ["Scotland", "Wales"]}, 25 | y=[-2, -1, 0], 26 | backend="none", 27 | ) 28 | pc.show() 29 | -------------------------------------------------------------------------------- /docs/source/gallery/prior_and_likelihood_sensitivity_checks/01_plot_psense_quantities.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Sensitivity posterior quantities 3 | 4 | The posterior quantities are computed by power-scaling the prior or likelihood and 5 | visualizing the resulting changes. Sensitivity can then be quantified by considering 6 | how much the perturbed quantities differ from the base quantities. 7 | 8 | --- 9 | 10 | :::{seealso} 11 | API Documentation: {func}`~arviz_plots.plot_psense_quantities` 12 | ::: 13 | """ 14 | from arviz_base import load_arviz_data 15 | 16 | import arviz_plots as azp 17 | 18 | azp.style.use("arviz-variat") 19 | 20 | idata = load_arviz_data("rugby") 21 | pc = azp.plot_psense_quantities( 22 | idata, 23 | var_names=["sd_att", "sd_def"], 24 | quantities=["mean", "sd", "0.25", "0.75"], 25 | col_wrap=2, 26 | backend="none", 27 | ) 28 | pc.show() 29 | -------------------------------------------------------------------------------- /docs/source/gallery/sbc/00_plot_ecdf_pit.py: -------------------------------------------------------------------------------- 1 | """ 2 | # PIT-ECDF 3 | 4 | Faceted plot with PIT Δ-ECDF values for each variable 5 | 6 | The ``plot_ecdf_pit`` function assumes the values passed to it have already been transformed 7 | to PIT values, as in the case of SBC analysis or values from ``arviz_base.loo_pit``.
8 | 9 | The distribution should be uniform if the model is well-calibrated. 10 | 11 | To make the plot easier to interpret, we plot the Δ-ECDF, that is, the difference between 12 | the observed ECDF and the expected CDF. As small deviations from uniformity are expected, 13 | the plot also shows the credible envelope. 14 | 15 | --- 16 | 17 | :::{seealso} 18 | API Documentation: {func}`~arviz_plots.plot_ecdf_pit` 19 | ::: 20 | """ 21 | from arviz_base import load_arviz_data 22 | 23 | import arviz_plots as azp 24 | 25 | azp.style.use("arviz-variat") 26 | 27 | data = load_arviz_data("sbc") 28 | pc = azp.plot_ecdf_pit( 29 | data, 30 | backend="none" # change to preferred backend 31 | ) 32 | pc.show() 33 | -------------------------------------------------------------------------------- /docs/source/gallery/sbc/01_plot_ecdf_coverage.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Coverage ECDF 3 | 4 | Coverage refers to the proportion of true values that fall within a given prediction interval. 5 | For a well-calibrated model, the coverage should match the intended interval width. For example, 6 | a 95% credible interval should contain the true value 95% of the time. 7 | 8 | The distribution should be uniform if the model is well-calibrated. 9 | 10 | To make the plot easier to interpret, we plot the Δ-ECDF, that is, the difference between 11 | the observed ECDF and the expected CDF. As small deviations from uniformity are expected, 12 | the plot also shows the credible envelope. 13 | 14 | We can compute the coverage for equal-tailed intervals (ETI) by passing `coverage=True` to the 15 | `plot_ecdf_pit` function. This works because ETI coverage can be obtained by transforming the PIT 16 | values. However, for other interval types, such as HDI, coverage must be computed explicitly and 17 | is not supported by this function.
18 | 19 | --- 20 | 21 | :::{seealso} 22 | API Documentation: {func}`~arviz_plots.plot_ecdf_pit` 23 | ::: 24 | """ 25 | from arviz_base import load_arviz_data 26 | 27 | import arviz_plots as azp 28 | 29 | azp.style.use("arviz-variat") 30 | 31 | data = load_arviz_data("sbc") 32 | pc = azp.plot_ecdf_pit( 33 | data, 34 | coverage=True, 35 | backend="none" # change to preferred backend 36 | ) 37 | pc.show() 38 | -------------------------------------------------------------------------------- /docs/source/gallery/utils/00_add_reference_lines.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Add Lines 3 | 4 | Draw lines on plots to highlight specific thresholds, targets, or important values. 5 | 6 | --- 7 | 8 | :::{seealso} 9 | API Documentation: {func}`~arviz_plots.add_lines` 10 | ::: 11 | """ 12 | 13 | from arviz_base import load_arviz_data 14 | 15 | import arviz_plots as azp 16 | 17 | azp.style.use("arviz-variat") 18 | 19 | data = load_arviz_data("centered_eight") 20 | ref_ds = data.posterior.dataset.quantile([0.5, 0.1, 0.9], dim=["chain", "draw"]) 21 | pc = azp.plot_dist( 22 | data, 23 | kind="ecdf", 24 | backend="none", # change to preferred backend 25 | ) 26 | pc = azp.add_lines( 27 | pc, 28 | values=ref_ds, 29 | ref_dim="quantile", 30 | aes_by_visuals={"ref_line": ["color"]}, 31 | color=["black", "gray", "gray"] 32 | ) 33 | pc.show() 34 | -------------------------------------------------------------------------------- /docs/source/gallery/utils/01_add_reference_bands.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Add Reference Bands 3 | 4 | Draw reference bands to highlight specific regions. 
5 | 6 | --- 7 | 8 | :::{seealso} 9 | API Documentation: {func}`~arviz_plots.add_bands` 10 | ::: 11 | """ 12 | 13 | from arviz_base import load_arviz_data 14 | 15 | import arviz_plots as azp 16 | 17 | azp.style.use("arviz-variat") 18 | 19 | data = load_arviz_data("centered_eight") 20 | rope = [(-1, 1)] 21 | pc = azp.plot_forest( 22 | data, 23 | backend="none", # change to preferred backend 24 | ) 25 | pc.coords = {"column": "forest"} 26 | pc = azp.add_bands( 27 | pc, 28 | values=rope, 29 | visuals={"ref_band":{"color": "#f66d7f"}}, 30 | ) 31 | 32 | pc.show() 33 | -------------------------------------------------------------------------------- /docs/source/glossary.md: -------------------------------------------------------------------------------- 1 | # Glossary 2 | 3 | 4 | :::{glossary} 5 | aesthetic 6 | aesthetics 7 | When used as a noun, we use _an aesthetic_ as a graphical property that is 8 | being used to encode data. 9 | 10 | Moreover, within `arviz_plots` _aesthetics_ can actually be any arbitrary 11 | keyword argument accepted by the plotting function being used. 12 | 13 | aesthetic mapping 14 | aesthetic mappings 15 | We use _aesthetic mapping_ to indicate the relation between the {term}`aesthetics` 16 | in our plot and properties in our dataset. 17 | 18 | figure 19 | Highest level data visualization structure. All plotted elements 20 | are contained within a figure or its children. 21 | 22 | EABM 23 | Acronym for Exploratory Analysis of Bayesian Models. We use this concept to 24 | reference all the tasks within a Bayesian workflow outside of 25 | building and fitting or sampling a model. For more details, see the 26 | [EABM virtual book](https://arviz-devs.github.io/EABM/) 27 | 28 | plot 29 | plots 30 | Area (or areas) where the data will be plotted into. A {term}`figure` 31 | can contain multiple {term}`faceted` plots. 
32 | 33 | visual 34 | visuals 35 | Graphical component or element added by `arviz-plots` 36 | 37 | faceting 38 | faceted 39 | Generate multiple similar {term}`plot` elements with each of them 40 | referring to a specific property or value of the data. 41 | 42 | ::: 43 | 44 | ## Equivalences with library-specific objects 45 | 46 | | arviz-plots name | matplotlib | bokeh | plotly | 47 | |------------------|--------------|---------|------------------| 48 | | figure | figure | layout | Figure | 49 | | plot | axes/subplot | figure | -[^plotly_plot] | 50 | | visual | artist | glyph | trace | 51 | 52 | [^plotly_plot]: In plotly there is no specific object to represent a {term}`plot`. 53 | 54 | Instead, when adding {term}`visuals` one can choose to add a visual to all {term}`plots` 55 | in the {term}`figure`, or give the row/col indexes, or specify a subset of {term}`plots` 56 | on which to add the {term}`visual`. 57 | -------------------------------------------------------------------------------- /docs/source/index.md: -------------------------------------------------------------------------------- 1 | # ArviZ-plots 2 | 3 | Welcome to the ArviZ-plots documentation! This library focuses on visual summaries and diagnostics for exploratory analysis of Bayesian models. It is one of the three components of the ArviZ library, the other two being: 4 | 5 | * [arviz-base](https://arviz-base.readthedocs.io/en/latest/) for data-related functionality, including converters from different PPLs. 6 | * [arviz-stats](https://arviz-stats.readthedocs.io/en/latest/) for statistical functions and diagnostics. 7 | 8 | We recommend most users install and use all three ArviZ components together through the main ArviZ package, as they're designed to work seamlessly as one toolkit. Advanced users may choose to install components individually for finer control over dependencies.
9 | 10 | Note: All plotting functions - whether accessed through the full ArviZ package or directly via ArviZ-plots - are documented here. 11 | 12 | 13 | ## Exploratory Analysis of Bayesian Models 14 | 15 | In modern Bayesian statistics, models are usually built and fitted using probabilistic programming languages (PPLs) such as PyMC, Stan, NumPyro, etc. These languages allow users to specify models in a high-level language and perform inference using state-of-the-art algorithms like Markov Chain Monte Carlo (MCMC) or Variational Inference (VI). As a result, we usually get a posterior distribution in the form of samples. The posterior distribution has a central role in Bayesian statistics, but other distributions, like the posterior and prior predictive distributions, are also of interest, and other quantities may be relevant too. 16 | 17 | The correct visualization, analysis, and interpretation of these computed data is key to properly answering the questions that motivate our analysis. 18 | 19 | When working with Bayesian models, there is a series of related tasks that need to be addressed besides inference itself: 20 | 21 | * Diagnosis of the quality of the inference 22 | 23 | * Model criticism, including evaluations of both model assumptions and model predictions 24 | 25 | * Comparison of models, including model selection or model averaging 26 | 27 | * Preparation of the results for a particular audience. 28 | 29 | We call these tasks exploratory analysis of Bayesian models (EABM). Successfully performing such tasks is central to the iterative and interactive modelling process (see Bayesian Workflow). In the words of Persi Diaconis: 30 | 31 | > Exploratory data analysis seeks to reveal structure, or simple descriptions in data. We look at numbers or graphs and try to find patterns. We pursue leads suggested by background information, imagination, patterns perceived, and experience with other data analyses.
32 | 33 | The goal of ArviZ is to provide a unified interface for performing exploratory analysis of Bayesian models in Python, regardless of the PPL used to perform inference. This allows users to focus on the analysis and interpretation of the results, rather than on the details of the implementation. 34 | 35 | 36 | 37 | ## Installation 38 | 39 | For instructions on how to install the full ArviZ package (including `arviz-base`, `arviz-stats` and `arviz-plots`), please refer to the [installation guide](https://python.arviz.org/en/latest/getting_started/Installation.html). 40 | 41 | However, if you are only interested in the plotting functions provided by ArviZ-plots, please follow the instructions below: 42 | 43 | ::::{tab-set} 44 | :::{tab-item} PyPI 45 | :sync: stable 46 | 47 | ```bash 48 | pip install "arviz-plots[<backend>]" 49 | ``` 50 | ::: 51 | :::{tab-item} GitHub 52 | :sync: dev 53 | 54 | ```bash 55 | pip install "arviz-plots[<backend>] @ git+https://github.com/arviz-devs/arviz-plots" 56 | ``` 57 | ::: 58 | :::: 59 | 60 | Note that `arviz-plots` is a minimal package, which only depends on 61 | xarray, numpy, arviz-base and arviz-stats. 62 | None of the supported backends (matplotlib, bokeh or plotly) is installed 63 | by default. 64 | 65 | Consequently, it is not recommended to install plain `arviz-plots`; 66 | instead, choose which backend(s) to use, for example `arviz-plots[bokeh]` 67 | or `arviz-plots[matplotlib, plotly]` (multiple comma-separated values are valid). 68 | 69 | This will ensure all relevant dependencies are installed. For example, to use the plotly backend, 70 | both `plotly>5` and `webcolors` are required.
71 | 72 | ```{toctree} 73 | :hidden: 74 | :caption: User guide 75 | 76 | tutorials/overview 77 | tutorials/plots_intro 78 | tutorials/intro_to_plotcollection 79 | tutorials/compose_own_plot 80 | ``` 81 | 82 | ```{toctree} 83 | :hidden: 84 | :caption: Reference 85 | 86 | gallery/index 87 | api/index 88 | glossary 89 | ``` 90 | ```{toctree} 91 | :hidden: 92 | :caption: Tutorials 93 | 94 | ArviZ in Context 95 | ``` 96 | 97 | ```{toctree} 98 | :hidden: 99 | :caption: Contributing 100 | 101 | contributing/testing 102 | contributing/new_plot 103 | contributing/docs 104 | ``` 105 | 106 | ```{toctree} 107 | :caption: About 108 | :hidden: 109 | 110 | BlueSky 111 | Mastodon 112 | GitHub repository 113 | ``` 114 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["flit_core >=3.4,<4"] 3 | build-backend = "flit_core.buildapi" 4 | 5 | [project] 6 | name = "arviz-plots" 7 | readme = "README.md" 8 | requires-python = ">=3.11" 9 | license = {file = "LICENSE"} 10 | authors = [ 11 | {name = "ArviZ team", email = "arvizdevs@gmail.com"} 12 | ] 13 | classifiers = [ 14 | "Development Status :: 3 - Alpha", 15 | "Intended Audience :: Developers", 16 | "Intended Audience :: Science/Research", 17 | "Intended Audience :: Education", 18 | "Framework :: Matplotlib", 19 | "License :: OSI Approved :: Apache Software License", 20 | "Operating System :: OS Independent", 21 | "Programming Language :: Python", 22 | "Programming Language :: Python :: 3", 23 | "Programming Language :: Python :: 3.10", 24 | "Programming Language :: Python :: 3.11", 25 | "Programming Language :: Python :: 3.12", 26 | ] 27 | dynamic = ["version", "description"] 28 | dependencies = [ 29 | "arviz-base @ git+https://github.com/arviz-devs/arviz-base", 30 | "arviz-stats[xarray] @ git+https://github.com/arviz-devs/arviz-stats", 31 | ] 32 | 33 | [tool.flit.module] 34 
| name = "arviz_plots" 35 | 36 | [project.urls] 37 | source = "https://github.com/arviz-devs/arviz-plots" 38 | tracker = "https://github.com/arviz-devs/arviz-plots/issues" 39 | documentation = "https://arviz-plots.readthedocs.io" 40 | funding = "https://opencollective.com/arviz" 41 | 42 | [project.optional-dependencies] 43 | matplotlib = ["matplotlib"] 44 | bokeh = ["bokeh"] 45 | plotly = ["plotly", "webcolors"] 46 | test = [ 47 | "hypothesis", 48 | "pytest", 49 | "pytest-cov", 50 | "h5netcdf", 51 | "kaleido", 52 | ] 53 | doc = [ 54 | "sphinx-book-theme", 55 | "myst-parser[linkify]", 56 | "myst-nb", 57 | "sphinx-copybutton", 58 | "numpydoc", 59 | "sphinx>=6", 60 | "sphinx-design", 61 | "jupyter-sphinx", 62 | "h5netcdf", 63 | "plotly<6", 64 | ] 65 | 66 | 67 | [tool.black] 68 | line-length = 100 69 | 70 | [tool.isort] 71 | profile = "black" 72 | include_trailing_comma = true 73 | use_parentheses = true 74 | multi_line_output = 3 75 | line_length = 100 76 | skip = [ 77 | "src/arviz_plots/__init__.py" 78 | ] 79 | 80 | [tool.pydocstyle] 81 | convention = "numpy" 82 | match_dir = "^(?!docs|.tox).*" 83 | 84 | [tool.mypy] 85 | python_version = "3.11" # keep in sync with requires-python 86 | warn_redundant_casts = true 87 | warn_unused_configs = true 88 | pretty = true 89 | show_error_codes = true 90 | show_error_context = true 91 | show_column_numbers = true 92 | 93 | disallow_any_generics = true 94 | disallow_subclassing_any = true 95 | disallow_untyped_calls = true 96 | disallow_incomplete_defs = true 97 | check_untyped_defs = true 98 | disallow_untyped_decorators = true 99 | no_implicit_optional = true 100 | warn_unused_ignores = true 101 | warn_return_any = true 102 | no_implicit_reexport = true 103 | 104 | # More strict checks for library code 105 | [[tool.mypy.overrides]] 106 | module = "arviz_plots" 107 | disallow_untyped_defs = true 108 | 109 | # Ignore certain missing imports 110 | # [[tool.mypy.overrides]] 111 | # module = "thirdparty.*" 112 | # ignore_missing_imports = true 113 | 114 |
[tool.pytest.ini_options] 115 | filterwarnings = ["error"] 116 | addopts = "--durations=10" 117 | testpaths = [ 118 | "tests", 119 | ] 120 | 121 | [tool.coverage.run] 122 | source = ["arviz_plots"] 123 | -------------------------------------------------------------------------------- /src/arviz_plots/_version.py: -------------------------------------------------------------------------------- 1 | """Base ArviZ version.""" 2 | __version__ = "0.6.0.dev0" 3 | -------------------------------------------------------------------------------- /src/arviz_plots/backend/__init__.py: -------------------------------------------------------------------------------- 1 | """Common interface to plotting backends. 2 | 3 | Each submodule within this module defines a common interface layer to different plotting libraries. 4 | 5 | Outside ``arviz_plots.backend`` the corresponding backend module is imported, 6 | but only the common interface layer is used, making no distinctions between plotting backends. 7 | Each submodule inside ``arviz_plots.backend`` is expected to implement the same functions 8 | with the same call signature. Thus, adding a new backend requires only 9 | implementing this common interface for it, with no changes to any of the other modules. 
10 | """ 11 | -------------------------------------------------------------------------------- /src/arviz_plots/backend/bokeh/legend.py: -------------------------------------------------------------------------------- 1 | """Bokeh manual legend generation.""" 2 | import warnings 3 | 4 | import numpy as np 5 | from bokeh.models import Legend 6 | 7 | 8 | def dealiase_line_kwargs(kwargs): 9 | """Convert arviz common interface properties to bokeh ones.""" 10 | prop_map = {"width": "line_width", "linestyle": "line_dash"} 11 | return {prop_map.get(key, key): value for key, value in kwargs.items()} 12 | 13 | 14 | def legend( 15 | target, 16 | kwarg_list, 17 | label_list, 18 | title=None, 19 | artist_type="line", 20 | artist_kwargs=None, 21 | legend_target=None, 22 | side="auto", 23 | legend_placement_threshold=600, # Magic number 24 | **kwargs, 25 | ): 26 | """Generate a legend on a figure given lists of labels and property kwargs. 27 | 28 | Parameters 29 | ---------- 30 | legend_target : (int, int), default (0, -1) 31 | Row and colum indicators of the :term:`plot` where the legend will be placed. 32 | Bokeh does not support :term:`figure` level legend. 33 | side : str, optional 34 | Side of the plot on which to place the legend. Use "center" to put the legend 35 | inside the plotting area. 
36 | """ 37 | if artist_kwargs is None: 38 | artist_kwargs = {} 39 | if legend_target is None: 40 | legend_target = (0, -1) 41 | # TODO: improve selection of Figure object from what is stored as "figure" 42 | children = target.children 43 | if not isinstance(children[0], tuple): 44 | children = children[1].children 45 | plots = [child[0] for child in children] 46 | row_id = np.array([child[1] for child in children], dtype=int) 47 | col_id = np.array([child[2] for child in children], dtype=int) 48 | legend_id = np.argmax( 49 | (row_id == np.unique(row_id)[legend_target[0]]) 50 | & (col_id == np.unique(col_id)[legend_target[1]]) 51 | ) 52 | target_plot = plots[legend_id] 53 | if target_plot.legend: 54 | warnings.warn("This target plot already contains a legend") 55 | glyph_list = [] 56 | if artist_type == "line": 57 | artist_fun = target_plot.line 58 | kwarg_list = [dealiase_line_kwargs(kws) for kws in kwarg_list] 59 | else: 60 | raise NotImplementedError("Only line type legends supported for now") 61 | 62 | for kws in kwarg_list: 63 | glyph = artist_fun(**{**artist_kwargs, **kws}) 64 | glyph_list.append(glyph) 65 | 66 | if side == "auto": 67 | plot_width = target_plot.width 68 | if plot_width >= legend_placement_threshold: 69 | side = "right" 70 | else: 71 | side = "center" 72 | 73 | leg = Legend( 74 | items=[(str(label), [glyph]) for label, glyph in zip(label_list, glyph_list)], 75 | title=title, 76 | **kwargs, 77 | ) 78 | target_plot.add_layout(leg, side) 79 | return leg 80 | -------------------------------------------------------------------------------- /src/arviz_plots/backend/matplotlib/legend.py: -------------------------------------------------------------------------------- 1 | """Matplotlib manual legend generation.""" 2 | from matplotlib.lines import Line2D 3 | 4 | 5 | def dealiase_line_kwargs(kwargs): 6 | """Convert arviz common interface properties to matplotlib ones.""" 7 | prop_map = {"width": "linewidth"} 8 | return {prop_map.get(key, key): value for 
key, value in kwargs.items()} 9 | 10 | 11 | def legend( 12 | target, kwarg_list, label_list, title=None, artist_type="line", artist_kwargs=None, **kwargs 13 | ): 14 | """Generate a legend on a figure given lists of labels and property kwargs.""" 15 | if artist_kwargs is None: 16 | artist_kwargs = {} 17 | kwargs.setdefault("loc", "outside right upper") 18 | if artist_type == "line": 19 | artist_fun = Line2D 20 | kwarg_list = [dealiase_line_kwargs(kws) for kws in kwarg_list] 21 | else: 22 | raise NotImplementedError("Only line type legends supported for now") 23 | handles = [artist_fun([], [], **{**artist_kwargs, **kws}) for kws in kwarg_list] 24 | return target.legend(handles, label_list, title=title, **kwargs) 25 | -------------------------------------------------------------------------------- /src/arviz_plots/backend/none/legend.py: -------------------------------------------------------------------------------- 1 | """None backend manual legend generation. 2 | 3 | For now only used for documentation purposes. 4 | """ 5 | 6 | 7 | # pylint: disable=unused-argument 8 | def legend( 9 | target, kwarg_list, label_list, title=None, artist_type="line", artist_kwargs=None, **kwargs 10 | ): 11 | """Generate a legend on a figure given lists of labels and property kwargs. 12 | 13 | Parameters 14 | ---------- 15 | target : plot object 16 | kwarg_list : sequence of mapping 17 | label_list : sequence of str 18 | title : str, optional 19 | artist_type : {"line", "scatter", "rectangle"}, default "line" 20 | artist_kwargs : mapping, optional 21 | Passed to all visuals when generating legend miniatures. 22 | **kwargs 23 | Passed to backend legend generating function. 
24 | """ 25 | return None 26 | -------------------------------------------------------------------------------- /src/arviz_plots/backend/plotly/legend.py: -------------------------------------------------------------------------------- 1 | """Plotly legend generation.""" 2 | 3 | 4 | def dealiase_line_kwargs(kwargs): 5 | """Convert arviz common interface properties to plotly ones.""" 6 | prop_map = {"linewidth": "width", "linestyle": "dash"} 7 | return {prop_map.get(key, key): value for key, value in kwargs.items()} 8 | 9 | 10 | def legend( 11 | target, kwarg_list, label_list, title=None, artist_type="line", artist_kwargs=None, **kwargs 12 | ): 13 | """Generate a legend with plotly. 14 | 15 | Parameters 16 | ---------- 17 | target : plotly.graph_objects.Figure 18 | The figure to add the legend to 19 | kwarg_list : list 20 | List of style dictionaries for each legend entry 21 | label_list : list 22 | List of labels for each legend entry 23 | title : str, optional 24 | Title of the legend 25 | artist_type : str, optional 26 | Type of visual to use for legend entries. Currently only "line" is supported. 
27 | artist_kwargs : dict, optional 28 | Additional kwargs passed to all visuals 29 | **kwargs : dict 30 | Additional kwargs passed to legend configuration 31 | 32 | Returns 33 | ------- 34 | None 35 | The legend is added to the target figure inplace 36 | """ 37 | if artist_kwargs is None: 38 | artist_kwargs = {} 39 | 40 | if artist_type == "line": 41 | artist_fun = target.add_scatter 42 | kwarg_list = [dealiase_line_kwargs(kws) for kws in kwarg_list] 43 | mode = "lines" 44 | else: 45 | raise NotImplementedError("Only line type legends supported for now") 46 | 47 | for kws, label in zip(kwarg_list, label_list): 48 | artist_fun( 49 | x=[None], 50 | y=[None], 51 | name=str(label), 52 | mode=mode, 53 | line=kws, 54 | showlegend=True, 55 | **artist_kwargs, 56 | ) 57 | 58 | target.update_layout(showlegend=True, legend_title_text=title, **kwargs) 59 | -------------------------------------------------------------------------------- /src/arviz_plots/backend/plotly/templates.py: -------------------------------------------------------------------------------- 1 | """Plotly templates for ArviZ styles.""" 2 | import plotly.graph_objects as go 3 | 4 | arviz_variat_template = go.layout.Template() 5 | 6 | arviz_variat_template.layout.paper_bgcolor = "white" 7 | arviz_variat_template.layout.plot_bgcolor = "white" 8 | arviz_variat_template.layout.polar.bgcolor = "white" 9 | arviz_variat_template.layout.ternary.bgcolor = "white" 10 | arviz_variat_template.layout.margin = {"l": 50, "r": 10, "t": 40, "b": 45} 11 | axis_common = {"showgrid": False, "ticks": "outside", "showline": True, "zeroline": False} 12 | arviz_variat_template.layout.xaxis = axis_common 13 | arviz_variat_template.layout.yaxis = axis_common 14 | 15 | arviz_variat_template.layout.colorway = [ 16 | "#36acc6", 17 | "#f66d7f", 18 | "#fac364", 19 | "#7c2695", 20 | "#228306", 21 | "#a252f4", 22 | "#63f0ea", 23 | "#000000", 24 | "#6f6f6f", 25 | "#b7b7b7", 26 | ] 27 | 28 | 29 | arviz_cetrino_template = go.layout.Template() 
30 | 31 | arviz_cetrino_template.layout.paper_bgcolor = "white" 32 | arviz_cetrino_template.layout.plot_bgcolor = "white" 33 | arviz_cetrino_template.layout.polar.bgcolor = "white" 34 | arviz_cetrino_template.layout.ternary.bgcolor = "white" 35 | arviz_cetrino_template.layout.margin = {"l": 50, "r": 10, "t": 40, "b": 45} 36 | axis_common = {"showgrid": False, "ticks": "outside", "showline": True, "zeroline": False} 37 | arviz_cetrino_template.layout.xaxis = axis_common 38 | arviz_cetrino_template.layout.yaxis = axis_common 39 | 40 | arviz_cetrino_template.layout.colorway = [ 41 | "#009988", 42 | "#9238b2", 43 | "#d2225f", 44 | "#ec8f26", 45 | "#fcd026", 46 | "#3cd186", 47 | "#a57119", 48 | "#2f5e14", 49 | "#f225f4", 50 | "#8f9fbf", 51 | ] 52 | 53 | 54 | arviz_vibrant_template = go.layout.Template() 55 | 56 | arviz_vibrant_template.layout.paper_bgcolor = "white" 57 | arviz_vibrant_template.layout.plot_bgcolor = "white" 58 | arviz_vibrant_template.layout.polar.bgcolor = "white" 59 | arviz_vibrant_template.layout.ternary.bgcolor = "white" 60 | arviz_vibrant_template.layout.margin = {"l": 50, "r": 10, "t": 40, "b": 45} 61 | axis_common = {"showgrid": False, "ticks": "outside", "showline": True, "zeroline": False} 62 | arviz_vibrant_template.layout.xaxis = axis_common 63 | arviz_vibrant_template.layout.yaxis = axis_common 64 | 65 | arviz_vibrant_template.layout.colorway = [ 66 | "#008b92", 67 | "#f15c58", 68 | "#48cdef", 69 | "#98d81a", 70 | "#997ee5", 71 | "#f5dc9d", 72 | "#c90a4e", 73 | "#145393", 74 | "#323232", 75 | "#616161", 76 | ] 77 | -------------------------------------------------------------------------------- /src/arviz_plots/plots/__init__.py: -------------------------------------------------------------------------------- 1 | """Batteries-included ArviZ plots.""" 2 | 3 | from .autocorr_plot import plot_autocorr 4 | from .bf_plot import plot_bf 5 | from .combine import combine_plots 6 | from .compare_plot import plot_compare 7 | from .convergence_dist_plot 
import plot_convergence_dist 8 | from .dist_plot import plot_dist 9 | from .ecdf_plot import plot_ecdf_pit 10 | from .energy_plot import plot_energy 11 | from .ess_plot import plot_ess 12 | from .evolution_plot import plot_ess_evolution 13 | from .forest_plot import plot_forest 14 | from .loo_pit_plot import plot_loo_pit 15 | from .mcse_plot import plot_mcse 16 | from .pairs_focus_plot import plot_pairs_focus 17 | from .pava_calibration_plot import plot_ppc_pava 18 | from .ppc_dist_plot import plot_ppc_dist 19 | from .ppc_pit_plot import plot_ppc_pit 20 | from .ppc_rootogram_plot import plot_ppc_rootogram 21 | from .ppc_tstat import plot_ppc_tstat 22 | from .prior_posterior_plot import plot_prior_posterior 23 | from .psense_dist_plot import plot_psense_dist 24 | from .psense_quantities_plot import plot_psense_quantities 25 | from .rank_dist_plot import plot_rank_dist 26 | from .rank_plot import plot_rank 27 | from .ridge_plot import plot_ridge 28 | from .trace_dist_plot import plot_trace_dist 29 | from .trace_plot import plot_trace 30 | from .utils import add_bands, add_lines 31 | 32 | __all__ = [ 33 | "combine_plots", 34 | "plot_autocorr", 35 | "plot_bf", 36 | "plot_compare", 37 | "plot_convergence_dist", 38 | "plot_dist", 39 | "plot_forest", 40 | "plot_trace", 41 | "plot_trace_dist", 42 | "plot_ecdf_pit", 43 | "plot_energy", 44 | "plot_ess", 45 | "plot_ess_evolution", 46 | "plot_loo_pit", 47 | "plot_mcse", 48 | "plot_ppc_dist", 49 | "plot_ppc_rootogram", 50 | "plot_prior_posterior", 51 | "plot_rank", 52 | "plot_rank_dist", 53 | "plot_ridge", 54 | "plot_ppc_pava", 55 | "plot_ppc_pit", 56 | "plot_ppc_tstat", 57 | "plot_psense_dist", 58 | "plot_psense_quantities", 59 | "plot_pairs_focus", 60 | "add_lines", 61 | "add_bands", 62 | ] 63 | -------------------------------------------------------------------------------- /src/arviz_plots/plots/autocorr_plot.py: -------------------------------------------------------------------------------- 1 | """Autocorrelation plot 
code.""" 2 | 3 | from copy import copy 4 | from importlib import import_module 5 | 6 | import numpy as np 7 | from arviz_base import rcParams 8 | from arviz_base.labels import BaseLabeller 9 | 10 | from arviz_plots.plot_collection import PlotCollection 11 | from arviz_plots.plots.utils import filter_aes, process_group_variables_coords, set_wrap_layout 12 | from arviz_plots.visuals import fill_between_y, labelled_title, labelled_x, line, line_xy 13 | 14 | 15 | def plot_autocorr( 16 | dt, 17 | var_names=None, 18 | filter_vars=None, 19 | group="posterior", 20 | coords=None, 21 | sample_dims=None, 22 | max_lag=None, 23 | plot_collection=None, 24 | backend=None, 25 | labeller=None, 26 | aes_by_visuals=None, 27 | visuals=None, 28 | **pc_kwargs, 29 | ): 30 | """Autocorrelation plots for the given dataset. 31 | 32 | Line plot of the autocorrelation function (ACF) 33 | 34 | The ACF plots can be used as a convergence diagnostic for posteriors from MCMC 35 | samples. 36 | 37 | Parameters 38 | ---------- 39 | dt : DataTree 40 | Input data 41 | var_names : str or list of str, optional 42 | One or more variables to be plotted. Currently only one variable is supported. 43 | Prefix the variables by ~ when you want to exclude them from the plot. 44 | filter_vars : {None, “like”, “regex”}, optional, default=None 45 | If None (default), interpret var_names as the real variables names. 46 | If “like”, interpret var_names as substrings of the real variables names. 47 | If “regex”, interpret var_names as regular expressions on the real variables names. 48 | group : str, optional 49 | Which group to use. Defaults to "posterior". 50 | coords : dict, optional 51 | Coordinates to plot. 52 | sample_dims : str or sequence of hashable, optional 53 | Dimensions to reduce unless mapped to an aesthetic. 54 | Defaults to ``rcParams["data.sample_dims"]`` 55 | max_lag : int, optional 56 | Maximum lag to compute the ACF. Defaults to 100. 
57 | plot_collection : PlotCollection, optional 58 | backend : {"matplotlib", "bokeh", "plotly"}, optional 59 | labeller : labeller, optional 60 | aes_by_visuals : mapping of {str : sequence of str}, optional 61 | Mapping of visuals to aesthetics that should use their mapping in `plot_collection` 62 | when plotted. Valid keys are the same as for `visuals`. 63 | 64 | visuals : mapping of {str : mapping or False}, optional 65 | Valid keys are: 66 | 67 | * lines -> passed to :func:`~arviz_plots.visuals.line` 68 | * ref_line -> passed to :func:`~arviz_plots.visuals.line_xy` 69 | * ci -> passed to :func:`~arviz_plots.visuals.fill_between_y` 70 | * xlabel -> passed to :func:`~arviz_plots.visuals.labelled_x` 71 | * title -> passed to :func:`~arviz_plots.visuals.labelled_title` 72 | 73 | pc_kwargs : mapping 74 | Passed to :meth:`arviz_plots.PlotCollection.wrap` 75 | 76 | Returns 77 | ------- 78 | PlotCollection 79 | 80 | Examples 81 | -------- 82 | Autocorrelation plot for mu variable in the centered eight dataset. 83 | 84 | .. plot:: 85 | :context: close-figs 86 | 87 | >>> from arviz_plots import plot_autocorr, style 88 | >>> style.use("arviz-variat") 89 | >>> from arviz_base import load_arviz_data 90 | >>> dt = load_arviz_data('centered_eight') 91 | >>> plot_autocorr(dt, var_names=["mu"]) 92 | 93 | 94 | ..
minigallery:: plot_autocorr 95 | 96 | """ 97 | if sample_dims is None: 98 | sample_dims = rcParams["data.sample_dims"] 99 | if isinstance(sample_dims, str): 100 | sample_dims = [sample_dims] 101 | sample_dims = list(sample_dims) 102 | if visuals is None: 103 | visuals = {} 104 | else: 105 | visuals = visuals.copy() 106 | 107 | if backend is None: 108 | if plot_collection is None: 109 | backend = rcParams["plot.backend"] 110 | else: 111 | backend = plot_collection.backend 112 | 113 | labeller = BaseLabeller() 114 | 115 | # Default max lag to 100 116 | if max_lag is None: 117 | max_lag = 100 118 | 119 | distribution = process_group_variables_coords( 120 | dt, group=group, var_names=var_names, filter_vars=filter_vars, coords=coords 121 | ) 122 | 123 | acf_dataset = distribution.azstats.autocorr(dim=sample_dims).sel(draw=slice(0, max_lag - 1)) 124 | c_i = 1.96 / acf_dataset.sizes["draw"] ** 0.5 125 | x_ci = np.arange(0, max_lag).astype(float) 126 | 127 | plot_bknd = import_module(f".backend.{backend}", package="arviz_plots") 128 | default_linestyle = plot_bknd.get_default_aes("linestyle", 2, {})[1] 129 | 130 | if plot_collection is None: 131 | pc_kwargs["figure_kwargs"] = pc_kwargs.get("figure_kwargs", {}).copy() 132 | pc_kwargs["aes"] = pc_kwargs.get("aes", {}).copy() 133 | pc_kwargs.setdefault("col_wrap", 4) 134 | pc_kwargs.setdefault( 135 | "cols", ["__variable__"] + [dim for dim in acf_dataset.dims if dim not in sample_dims] 136 | ) 137 | 138 | if "chain" in distribution: 139 | pc_kwargs["aes"].setdefault("color", ["chain"]) 140 | pc_kwargs["aes"].setdefault("overlay", ["chain"]) 141 | 142 | pc_kwargs = set_wrap_layout(pc_kwargs, plot_bknd, distribution) 143 | pc_kwargs["figure_kwargs"].setdefault("sharex", True) 144 | pc_kwargs["figure_kwargs"].setdefault("sharey", True) 145 | 146 | plot_collection = PlotCollection.wrap( 147 | distribution, 148 | backend=backend, 149 | **pc_kwargs, 150 | ) 151 | 152 | if aes_by_visuals is None: 153 | aes_by_visuals = {} 154 | 
else: 155 | aes_by_visuals = aes_by_visuals.copy() 156 | aes_by_visuals.setdefault("lines", plot_collection.aes_set) 157 | 158 | ## reference line 159 | ref_ls_kwargs = copy(visuals.get("ref_line", {})) 160 | 161 | if ref_ls_kwargs is not False: 162 | _, _, ac_ls_ignore = filter_aes(plot_collection, aes_by_visuals, "ref_line", sample_dims) 163 | ref_ls_kwargs.setdefault("color", "gray") 164 | ref_ls_kwargs.setdefault("linestyle", default_linestyle) 165 | 166 | plot_collection.map( 167 | line_xy, 168 | "ref_line", 169 | data=acf_dataset, 170 | x=x_ci, 171 | y=0, 172 | ignore_aes=ac_ls_ignore, 173 | **ref_ls_kwargs, 174 | ) 175 | 176 | ## autocorrelation line 177 | acf_ls_kwargs = copy(visuals.get("lines", {})) 178 | 179 | if acf_ls_kwargs is not False: 180 | _, _, ac_ls_ignore = filter_aes(plot_collection, aes_by_visuals, "lines", sample_dims) 181 | 182 | plot_collection.map( 183 | line, 184 | "lines", 185 | data=acf_dataset, 186 | ignore_aes=ac_ls_ignore, 187 | **acf_ls_kwargs, 188 | ) 189 | 190 | # Plot confidence intervals 191 | ci_kwargs = copy(visuals.get("ci", {})) 192 | _, _, ci_ignore = filter_aes(plot_collection, aes_by_visuals, "ci", "draw") 193 | if ci_kwargs is not False: 194 | ci_kwargs.setdefault("color", "black") 195 | ci_kwargs.setdefault("alpha", 0.1) 196 | 197 | plot_collection.map( 198 | fill_between_y, 199 | "ci", 200 | data=acf_dataset, 201 | x=x_ci, 202 | y=0, 203 | y_bottom=-c_i, 204 | y_top=c_i, 205 | ignore_aes=ci_ignore, 206 | **ci_kwargs, 207 | ) 208 | 209 | # set xlabel 210 | _, xlabels_aes, xlabels_ignore = filter_aes( 211 | plot_collection, aes_by_visuals, "xlabel", sample_dims 212 | ) 213 | xlabel_kwargs = copy(visuals.get("xlabel", {})) 214 | if xlabel_kwargs is not False: 215 | if "color" not in xlabels_aes: 216 | xlabel_kwargs.setdefault("color", "black") 217 | 218 | xlabel_kwargs.setdefault("text", "Lag") 219 | plot_collection.map( 220 | labelled_x, 221 | "xlabel", 222 | ignore_aes=xlabels_ignore, 223 | subset_info=True, 224 | 
**xlabel_kwargs, 225 | ) 226 | 227 | # title 228 | title_kwargs = copy(visuals.get("title", {})) 229 | _, _, title_ignore = filter_aes(plot_collection, aes_by_visuals, "title", sample_dims) 230 | 231 | if title_kwargs is not False: 232 | plot_collection.map( 233 | labelled_title, 234 | "title", 235 | ignore_aes=title_ignore, 236 | subset_info=True, 237 | labeller=labeller, 238 | **title_kwargs, 239 | ) 240 | 241 | return plot_collection 242 | -------------------------------------------------------------------------------- /src/arviz_plots/plots/bf_plot.py: -------------------------------------------------------------------------------- 1 | """Contain functions for Bayes Factor plotting.""" 2 | 3 | from copy import copy 4 | 5 | import xarray as xr 6 | from arviz_stats.bayes_factor import bayes_factor 7 | 8 | from arviz_plots.plots.prior_posterior_plot import plot_prior_posterior 9 | from arviz_plots.plots.utils import add_lines, filter_aes 10 | 11 | 12 | def plot_bf( 13 | dt, 14 | var_names, 15 | ref_val=0, 16 | kind=None, 17 | sample_dims=None, 18 | plot_collection=None, 19 | backend=None, 20 | labeller=None, 21 | aes_by_visuals=None, 22 | visuals=None, 23 | stats=None, 24 | **pc_kwargs, 25 | ): 26 | r"""Bayes Factor for comparing hypothesis of two nested models. 27 | 28 | The Bayes factor is estimated by comparing a model (H1) against a model 29 | in which the parameter of interest has been restricted to be a point-null (H0) 30 | This computation assumes H0 is a special case of H1. For more details see here 31 | https://arviz-devs.github.io/EABM/Chapters/Model_comparison.html#savagedickey-ratio 32 | 33 | Parameters 34 | ---------- 35 | dt : DataTree or dict of {str : DataTree} 36 | Input data. In case of dictionary input, the keys are taken to be model names. 37 | In such cases, a dimension "model" is generated and can be used to map to aesthetics. 
38 | var_names : str, optional 39 | Variables for which the bayes factor will be computed and the prior and 40 | posterior will be plotted. 41 | ref_val : int or float, default 0 42 | Reference (point-null) value for Bayes factor estimation. 43 | kind : {"kde", "hist", "dot", "ecdf"}, optional 44 | How to represent the marginal density. 45 | Defaults to ``rcParams["plot.density_kind"]`` 46 | sample_dims : str or sequence of hashable, optional 47 | Dimensions to reduce unless mapped to an aesthetic. 48 | Defaults to ``rcParams["data.sample_dims"]`` 49 | plot_collection : PlotCollection, optional 50 | backend : {"matplotlib", "bokeh", "plotly"}, optional 51 | labeller : labeller, optional 52 | aes_by_visuals : mapping of {str : sequence of str}, optional 53 | Mapping of visuals to aesthetics that should use their mapping in `plot_collection` 54 | when plotted. Valid keys are the same as for `visuals`. 55 | visuals : mapping of {str : mapping or False}, optional 56 | Valid keys are: 57 | 58 | * One of "kde", "ecdf", "dot" or "hist", matching the `kind` argument. 59 | 60 | * "kde" -> passed to :func:`~arviz_plots.visuals.line_xy` 61 | * "ecdf" -> passed to :func:`~arviz_plots.visuals.ecdf_line` 62 | * "hist" -> passed to :func: `~arviz_plots.visuals.hist` 63 | 64 | * "ref_line -> passed to :func: `~arviz_plots.visuals.vline` 65 | * title -> passed to :func:`~arviz_plots.visuals.labelled_title` 66 | 67 | stats : mapping, optional 68 | Valid keys are: 69 | 70 | * density -> passed to kde, ecdf, ... 71 | 72 | pc_kwargs : mapping 73 | Passed to :class:`arviz_plots.PlotCollection.wrap` 74 | 75 | Returns 76 | ------- 77 | PlotCollection 78 | 79 | Examples 80 | -------- 81 | Select one variable. 82 | 83 | .. 
plot:: 84 | :context: close-figs 85 | 86 | >>> from arviz_plots import plot_bf, style 87 | >>> style.use("arviz-variat") 88 | >>> from arviz_base import load_arviz_data 89 | >>> dt = load_arviz_data('centered_eight') 90 | >>> plot_bf(dt, var_names="mu", kind="hist") 91 | 92 | .. minigallery:: plot_bf 93 | """ 94 | if visuals is None: 95 | visuals = {} 96 | else: 97 | visuals = visuals.copy() 98 | if aes_by_visuals is None: 99 | aes_by_visuals = {} 100 | else: 101 | aes_by_visuals = aes_by_visuals.copy() 102 | 103 | bf, _ = bayes_factor(dt, var_names, ref_val, return_ref_vals=True) 104 | 105 | if isinstance(var_names, str): 106 | var_names = [var_names] 107 | bf_aes_ds = xr.Dataset( 108 | { 109 | var: xr.DataArray( 110 | None, 111 | coords={"BF_type": [f"BF01:{bf[var]['BF01']:.2f}"]}, 112 | dims=["BF_type"], 113 | ) 114 | for var in var_names 115 | } 116 | ) 117 | 118 | plot_collection = plot_prior_posterior( 119 | dt, 120 | var_names=var_names, 121 | coords=None, 122 | sample_dims=sample_dims, 123 | kind=kind, 124 | plot_collection=plot_collection, 125 | backend=backend, 126 | labeller=labeller, 127 | visuals=visuals, 128 | stats=stats, 129 | **pc_kwargs, 130 | ) 131 | 132 | plot_collection.update_aes_from_dataset("bf_aes", bf_aes_ds) 133 | 134 | ref_line_kwargs = copy(visuals.get("ref_line", {})) 135 | if ref_line_kwargs is False: 136 | raise ValueError( 137 | "visuals['ref_line'] can't be False, use ref_val=False to remove this element" 138 | ) 139 | 140 | if ref_val is not False: 141 | _, ref_aes, _ = filter_aes(plot_collection, aes_by_visuals, "ref_line", "sample") 142 | if "color" not in ref_aes: 143 | ref_line_kwargs.setdefault("color", "black") 144 | if "alpha" not in ref_aes: 145 | ref_line_kwargs.setdefault("alpha", 0.5) 146 | add_lines( 147 | plot_collection, 148 | ref_val, 149 | aes_by_visuals=aes_by_visuals, 150 | visuals={"ref_line": ref_line_kwargs}, 151 | ) 152 | 153 | if backend == "matplotlib": ## remove this when we have a better way to handle 
legends 154 | plot_collection.add_legend( 155 | ["__variable__", "BF_type"], loc="upper left", fontsize=10, text_only=True 156 | ) 157 | 158 | return plot_collection 159 | -------------------------------------------------------------------------------- /src/arviz_plots/plots/combine.py: -------------------------------------------------------------------------------- 1 | """Elements to combine multiple batteries-included plots into a single figure.""" 2 | import re 3 | from importlib import import_module 4 | 5 | from arviz_base import rcParams 6 | 7 | from arviz_plots import PlotCollection 8 | from arviz_plots.plot_collection import backend_from_object 9 | from arviz_plots.plots.utils import process_group_variables_coords, set_grid_layout 10 | 11 | 12 | def get_valid_arg(key, value, backend): 13 | """Convert none backend aesthetic argument indicator to a valid value for the given backend. 14 | 15 | Parameters 16 | ---------- 17 | key : str 18 | The keyword part of the :ref:`backend-interface-arguments` for which `value` should 19 | be valid. 20 | value : any 21 | The current value for `key`. It might be an indicator from the none backend such as 22 | "color_0" or "linestyle_3" which gets processed or something else in which case 23 | it is assumed to be a valid argument already and returned as is. 24 | backend : str 25 | The backend for which `value` should be valid. 
26 | 27 | Returns 28 | ------- 29 | valid_value : any 30 | """ 31 | plot_backend = import_module(f"arviz_plots.backend.{backend}") 32 | key_matcher = "color" if key in {"facecolor", "edgecolor"} else key 33 | if isinstance(value, str): 34 | match = re.match(key_matcher + "_([0-9]+)", value) 35 | if match: 36 | index = int(match.groups()[0]) 37 | return plot_backend.get_default_aes(key, index + 1)[index] 38 | return value 39 | 40 | 41 | def backendize_kwargs(kwargs, backend): 42 | """Process the visual description dictionary from the none backend to valid kwargs.""" 43 | return { 44 | key: get_valid_arg(key, value, backend) 45 | for key, value in kwargs.items() 46 | if key != "function" 47 | } 48 | 49 | 50 | def render(da, target, **kwargs): 51 | """Render visual descriptions from the none backend with a plotting backend.""" 52 | backend = backend_from_object(target, return_module=False) 53 | plot_backend = import_module(f"arviz_plots.backend.{backend}") 54 | visuals = da.item() 55 | plot_fun_name = visuals["function"] 56 | visuals = backendize_kwargs(visuals, backend) 57 | kwargs = backendize_kwargs(kwargs, backend) 58 | return getattr(plot_backend, plot_fun_name)(target=target, **{**visuals, **kwargs}) 59 | 60 | 61 | def combine_plots( 62 | dt, 63 | plots, 64 | var_names=None, 65 | filter_vars=None, 66 | group="posterior", 67 | coords=None, 68 | sample_dims=None, 69 | expand="column", 70 | plot_names=None, 71 | backend=None, 72 | **pc_kwargs, 73 | ): 74 | """Arrange multiple batteries-included plots in a customizable column or row layout. 75 | 76 | Parameters 77 | ---------- 78 | dt : DataTree of dict of {str : DataTree} 79 | Input data. In case of dictionary input, the keys are taken to be model names. 80 | In such cases, a dimension "model" is generated and can be used to map to aesthetics. 
81 | 82 | Note that not all batteries included functions accept dictionary input, so it will 83 | only work when all plotting functions requested in `plots` are compatible with it. 84 | plots : list of tuple of (callable, mapping) 85 | List of all the plotting functions to be combined. Each element in this list 86 | is a tuple with two elements. The first is the function to be called, the second 87 | is a dictionary with any keyword arguments that should be used when calling that function. 88 | var_names : str or sequence of str, optional 89 | One or more variables to be plotted. 90 | Prefix the variables by ~ when you want to exclude them from the plot. 91 | filter_vars : {None, “like”, “regex”}, default None 92 | If None, interpret `var_names` as the real variables names. 93 | If “like”, interpret `var_names` as substrings of the real variables names. 94 | If “regex”, interpret `var_names` as regular expressions on the real variables names. 95 | group : str, default "posterior" 96 | Group to be plotted. 97 | coords : dict, optional 98 | sample_dims : str or sequence of hashable, optional 99 | Dimensions to reduce unless mapped to an aesthetic. 100 | Defaults to ``rcParams["data.sample_dims"]`` 101 | expand : {"column", "row"}, default "column" 102 | How to combine the different plotting functions. If "column", each plotting function 103 | will be added as a new column, if "row" it will be a new row instead. 104 | plot_names : list of str, optional 105 | List of the same length as `plots` with the plot names to use as coordinate values 106 | in the returned :class:`~arviz_plots.PlotCollection`. 107 | backend : {"matplotlib", "bokeh", "plotly"}, optional 108 | Plotting backend to use. Defaults to ``rcParams["plot.backend"]``. 
109 | pc_kwargs : mapping, optional 110 | Passed to :class:`arviz_plots.PlotCollection.grid` 111 | 112 | Returns 113 | ------- 114 | PlotCollection 115 | 116 | Examples 117 | -------- 118 | Customize the names of the plots in the returned :class:`PlotCollection` 119 | 120 | .. plot:: 121 | :context: close-figs 122 | 123 | >>> import arviz_plots as azp 124 | >>> azp.style.use("arviz-variat") 125 | >>> from arviz_base import load_arviz_data 126 | >>> rugby = load_arviz_data('rugby') 127 | >>> pc = azp.combine_plots( 128 | >>> rugby, 129 | >>> plots=[ 130 | >>> (azp.plot_ppc_pit, {}), 131 | >>> (azp.plot_ppc_rootogram, {}), 132 | >>> ], 133 | >>> group="posterior_predictive", 134 | >>> plot_names=["pit", "rootogram"], 135 | >>> ) 136 | 137 | Now if we inspect the ``pc.viz`` attribute, we can see it has a ``column`` dimension 138 | with the requested coordinate values: 139 | 140 | .. plot:: 141 | :context: close-figs 142 | 143 | >>> pc.viz 144 | 145 | .. minigallery:: combine_plots 146 | """ 147 | if plot_names is None: 148 | plot_names = [ 149 | getattr(elem[0], "__name__") + f"_{idx:02d}" for idx, elem in enumerate(plots) 150 | ] 151 | if sample_dims is None: 152 | sample_dims = rcParams["data.sample_dims"] 153 | if isinstance(sample_dims, str): 154 | sample_dims = [sample_dims] 155 | if backend is None: 156 | backend = rcParams["plot.backend"] 157 | 158 | distribution = process_group_variables_coords( 159 | dt, group=group, var_names=var_names, filter_vars=filter_vars, coords=coords 160 | ) 161 | facet_dims = ["__variable__"] + ( 162 | [] 163 | if "predictive" in group 164 | else [dim for dim in distribution.dims if dim not in sample_dims] 165 | ) 166 | 167 | pc_kwargs["figure_kwargs"] = pc_kwargs.get("figure_kwargs", {}).copy() 168 | if expand == "column": 169 | pc_kwargs.setdefault("cols", ["column"]) 170 | pc_kwargs.setdefault("rows", facet_dims) 171 | expand_kwargs = {"column": len(plots)} 172 | elif expand == "row": 173 | pc_kwargs.setdefault("cols", 
facet_dims) 174 | pc_kwargs.setdefault("rows", ["row"]) 175 | expand_kwargs = {"row": len(plots)} 176 | else: 177 | raise ValueError(f"`expand` must be 'row' or 'column' but got '{expand}'") 178 | distribution = distribution.expand_dims(**expand_kwargs).assign_coords({expand: plot_names}) 179 | 180 | plot_bknd = import_module(f".backend.{backend}", package="arviz_plots") 181 | pc_kwargs = set_grid_layout(pc_kwargs, plot_bknd, distribution) 182 | 183 | pc = PlotCollection.grid( 184 | distribution, 185 | backend=backend, 186 | **pc_kwargs, 187 | ) 188 | 189 | for name, (plot, kwargs) in zip(plot_names, plots): 190 | pc_i = plot( 191 | dt, 192 | backend="none", 193 | group=group, 194 | var_names=var_names, 195 | filter_vars=filter_vars, 196 | coords=coords, 197 | sample_dims=sample_dims, 198 | **kwargs, 199 | ) 200 | pc.coords = None 201 | pc.aes = pc_i.aes 202 | pc.coords = {expand: name} 203 | for viz_group, ds in pc_i.viz.children.items(): 204 | if viz_group in {"plot", "row_index", "col_index"}: 205 | continue 206 | attrs = ds.attrs 207 | pc.map( 208 | render, 209 | fun_label=f"{viz_group}_{name}", 210 | data=ds.dataset, 211 | ignore_aes=attrs.get("ignore_aes", frozenset()), 212 | ) 213 | pc.coords = None 214 | # TODO: at some point all `pc_i.aes` objects should be merged 215 | # and stored into the `pc.aes` attribute 216 | 217 | return pc 218 | -------------------------------------------------------------------------------- /src/arviz_plots/plots/compare_plot.py: -------------------------------------------------------------------------------- 1 | """Compare plot code.""" 2 | from importlib import import_module 3 | 4 | import numpy as np 5 | from arviz_base import rcParams 6 | from xarray import Dataset, DataTree 7 | 8 | from arviz_plots.plot_collection import PlotCollection 9 | 10 | 11 | def plot_compare( 12 | cmp_df, 13 | similar_shade=True, 14 | relative_scale=False, 15 | backend=None, 16 | visuals=None, 17 | **pc_kwargs, 18 | ): 19 | r"""Summary plot for model 
comparison. 20 | 21 | Models are compared based on their expected log pointwise predictive density (ELPD). 22 | 23 | The ELPD is estimated either by Pareto smoothed importance sampling leave-one-out 24 | cross-validation (LOO). Details are presented in [1]_ and [2]_. 25 | 26 | Parameters 27 | ---------- 28 | comp_df : pandas.DataFrame 29 | Result of the :func:`arviz_stats.compare` method. 30 | similar_shade : bool, optional 31 | If True, a shade is drawn to indicate models with similar 32 | predictive performance to the best model. Defaults to True. 33 | relative_scale : bool, optional. 34 | If True scale the ELPD values relative to the best model. 35 | Defaults to False. 36 | backend : {"bokeh", "matplotlib", "plotly"} 37 | Select plotting backend. Defaults to rcParams["plot.backend"]. 38 | figsize : (float, float), optional 39 | If `None`, size is (10, num of models) inches. 40 | visuals : mapping of {str : mapping or False}, optional 41 | Valid keys are: 42 | 43 | * point_estimate -> passed to :func:`~.backend.scatter` 44 | * error_bar -> passed to :func:`~.backend.line` 45 | * ref_line -> passed to :func:`~.backend.line` 46 | * shade -> passed to :func:`~.backend.fill_between_y` 47 | * labels -> passed to :func:`~.backend.xticks` and :func:`~.backend.yticks` 48 | * title -> passed to :func:`~.backend.title` 49 | * ticklabels -> passed to :func:`~.backend.yticks` 50 | 51 | pc_kwargs : mapping 52 | Passed to :class:`arviz_plots.PlotCollection` 53 | 54 | Returns 55 | ------- 56 | axes :bokeh figure, matplotlib axes or plotly figure 57 | 58 | See Also 59 | -------- 60 | :func:`arviz_stats.compare`: Summary plot for model comparison. 61 | :func:`arviz_stats.loo` : Compute the ELPD using Pareto smoothed importance sampling 62 | Leave-one-out cross-validation method. 63 | 64 | References 65 | ---------- 66 | .. [1] Vehtari et al. *Practical Bayesian model evaluation using leave-one-out cross-validation 67 | and WAIC*. Statistics and Computing. 27(5) (2017). 
68 | https://doi.org/10.1007/s11222-016-9696-4. arXiv preprint https://arxiv.org/abs/1507.04544. 69 | 70 | .. [2] Vehtari et al. *Pareto Smoothed Importance Sampling*. 71 | Journal of Machine Learning Research, 25(72) (2024) https://jmlr.org/papers/v25/19-556.html 72 | arXiv preprint https://arxiv.org/abs/1507.02646 73 | """ 74 | # Check if cmp_df contains the required information 75 | 76 | column_index = [c.lower() for c in cmp_df.columns] 77 | 78 | if "elpd" not in column_index: 79 | raise ValueError( 80 | "cmp_df must have been created using the `compare` function from ArviZ-Stats." 81 | ) 82 | 83 | # Set default backend 84 | if backend is None: 85 | backend = rcParams["plot.backend"] 86 | 87 | if visuals is None: 88 | visuals = {} 89 | 90 | # Get plotting backend 91 | p_be = import_module(f"arviz_plots.backend.{backend}") 92 | 93 | # Get figure params and create figure and axis 94 | pc_kwargs["figure_kwargs"] = pc_kwargs.get("figure_kwargs", {}).copy() 95 | figsize = pc_kwargs.get("figure_kwargs", {}).get("figsize", None) 96 | figsize_units = pc_kwargs["figure_kwargs"].get("figsize_units", "inches") 97 | 98 | figsize = p_be.scale_fig_size( 99 | figsize, 100 | rows=int(len(cmp_df) ** 0.5), 101 | cols=2, 102 | figsize_units=figsize_units, 103 | ) 104 | figsize_units = "dots" 105 | 106 | figure, target = p_be.create_plotting_grid(1, figsize=figsize, figsize_units=figsize_units) 107 | 108 | # Create plot collection 109 | plot_collection = PlotCollection( 110 | Dataset({}), 111 | viz_dt=DataTree.from_dict( 112 | {"/": Dataset({"figure": np.array(figure, dtype=object), "plot": target})} 113 | ), 114 | backend=backend, 115 | **pc_kwargs, 116 | ) 117 | 118 | if isinstance(target, np.ndarray): 119 | target = target.tolist() 120 | 121 | # Set scale relative to the best model 122 | if relative_scale: 123 | cmp_df = cmp_df.copy() 124 | cmp_df["elpd"] = cmp_df["elpd"] - cmp_df["elpd"].iloc[0] 125 | 126 | # Compute positions of yticks 127 | yticks_pos = 
list(range(len(cmp_df), 0, -1)) 128 | 129 | # Plot ELPD standard error bars 130 | if (error_kwargs := visuals.get("error_bar", {})) is not False: 131 | error_kwargs.setdefault("color", "black") 132 | 133 | # Compute values for standard error bars 134 | se_list = list(zip((cmp_df["elpd"] - cmp_df["se"]), (cmp_df["elpd"] + cmp_df["se"]))) 135 | 136 | for se_vals, ytick in zip(se_list, yticks_pos): 137 | p_be.line(se_vals, (ytick, ytick), target, **error_kwargs) 138 | 139 | # Add reference line for the best model 140 | if (ref_kwargs := visuals.get("ref_line", {})) is not False: 141 | ref_kwargs.setdefault("color", "gray") 142 | ref_kwargs.setdefault("linestyle", p_be.get_default_aes("linestyle", 2, {})[-1]) 143 | p_be.line( 144 | (cmp_df["elpd"].iloc[0], cmp_df["elpd"].iloc[0]), 145 | (yticks_pos[0], yticks_pos[-1]), 146 | target, 147 | **ref_kwargs, 148 | ) 149 | 150 | # Plot ELPD point estimates 151 | if (pe_kwargs := visuals.get("point_estimate", {})) is not False: 152 | pe_kwargs.setdefault("color", "black") 153 | p_be.scatter(cmp_df["elpd"], yticks_pos, target, **pe_kwargs) 154 | 155 | # Add shade for statistically undistinguishable models 156 | if similar_shade and (shade_kwargs := visuals.get("shade", {})) is not False: 157 | shade_kwargs.setdefault("color", "black") 158 | shade_kwargs.setdefault("alpha", 0.1) 159 | 160 | x_0, x_1 = cmp_df["elpd"].iloc[0] - 4, cmp_df["elpd"].iloc[0] 161 | 162 | padding = (yticks_pos[0] - yticks_pos[-1]) * 0.05 163 | p_be.fill_between_y( 164 | x=[x_0, x_1], 165 | y_bottom=yticks_pos[-1] - padding, 166 | y_top=yticks_pos[0] + padding, 167 | target=target, 168 | **shade_kwargs, 169 | ) 170 | 171 | # Add title and labels 172 | if (title_kwargs := visuals.get("title", {})) is not False: 173 | p_be.title( 174 | "Model comparison\nhigher is better", 175 | target, 176 | **title_kwargs, 177 | ) 178 | 179 | if (labels_kwargs := visuals.get("labels", {})) is not False: 180 | p_be.ylabel("ranked models", target, **labels_kwargs) 181 | 
p_be.xlabel("ELPD", target, **labels_kwargs) 182 | 183 | if (ticklabels_kwargs := visuals.get("ticklabels", {})) is not False: 184 | p_be.yticks(yticks_pos, cmp_df.index, target, **ticklabels_kwargs) 185 | 186 | return plot_collection 187 | -------------------------------------------------------------------------------- /src/arviz_plots/plots/energy_plot.py: -------------------------------------------------------------------------------- 1 | """Energy plot code.""" 2 | import numpy as np 3 | from arviz_base import convert_to_dataset, rcParams 4 | 5 | from arviz_plots.plots.dist_plot import plot_dist 6 | 7 | 8 | def plot_energy( 9 | dt, 10 | bfmi=False, 11 | kind=None, 12 | plot_collection=None, 13 | backend=None, 14 | labeller=None, 15 | aes_by_visuals=None, 16 | visuals=None, 17 | stats=None, 18 | **pc_kwargs, 19 | ): 20 | r"""Plot transition distribution and marginal energy distribution in HMC algorithms. 21 | 22 | This may help to diagnose poor exploration by gradient-based algorithms like HMC or NUTS. 23 | The energy function in HMC can identify posteriors with heavy tailed distributions, that 24 | in practice are challenging for sampling. 25 | 26 | This plot is in the style of the one used in [1]_. 27 | 28 | Parameters 29 | ---------- 30 | dt : DataTree 31 | ``sample_stats`` group with an ``energy`` variable is mandatory. 32 | bfmi : bool 33 | Whether to the plot the value of the estimated Bayesian fraction of missing 34 | information. Defaults to False. Not implemented yet. 35 | kind : {"kde", "hist", "dot", "ecdf"}, optional 36 | How to represent the marginal density. 37 | Defaults to ``rcParams["plot.density_kind"]`` 38 | plot_collection : PlotCollection, optional 39 | backend : {"matplotlib", "bokeh", "plotly"}, optional 40 | labeller : labeller, optional 41 | aes_by_visuals : mapping of {str : sequence of str}, optional 42 | Mapping of visuals to aesthetics that should use their mapping in `plot_collection` 43 | when plotted. 
Valid keys are the same as for `visuals`. 44 | 45 | visuals : mapping of {str : mapping or False}, optional 46 | Valid keys are: 47 | 48 | * One of "kde", "ecdf", "dot" or "hist", matching the `kind` argument. 49 | 50 | * "kde" -> passed to :func:`~arviz_plots.visuals.line_xy` 51 | * "ecdf" -> passed to :func:`~arviz_plots.visuals.ecdf_line` 52 | * "hist" -> passed to :func: `~arviz_plots.visuals.hist` 53 | 54 | * title -> passed to :func:`~arviz_plots.visuals.labelled_title` 55 | * remove_axis -> not passed anywhere, can only be ``False`` to skip calling this function 56 | 57 | stats : mapping, optional 58 | Valid keys are: 59 | * density -> passed to kde, ecdf, ... 60 | 61 | pc_kwargs : mapping 62 | Passed to :class:`arviz_plots.PlotCollection.wrap` 63 | 64 | Returns 65 | ------- 66 | PlotCollection 67 | 68 | Examples 69 | -------- 70 | Plot a default energy plot 71 | 72 | .. plot:: 73 | :context: close-figs 74 | 75 | >>> from arviz_plots import plot_energy, style 76 | >>> style.use("arviz-variat") 77 | >>> from arviz_base import load_arviz_data 78 | >>> schools = load_arviz_data('centered_eight') 79 | >>> plot_energy(schools) 80 | 81 | 82 | .. minigallery:: plot_energy 83 | 84 | References 85 | ---------- 86 | .. [1] Betancourt. Diagnosing Suboptimal Cotangent Disintegrations in 87 | Hamiltonian Monte Carlo. 
(2016) https://arxiv.org/abs/1604.00695 88 | """ 89 | if kind is None: 90 | kind = rcParams["plot.density_kind"] 91 | if visuals is None: 92 | visuals = {} 93 | else: 94 | visuals = visuals.copy() 95 | 96 | new_ds = _get_energy_ds(dt) 97 | 98 | sample_dims = ["chain", "draw"] 99 | if not all(dim in new_ds.dims for dim in sample_dims): 100 | raise ValueError("Both 'chain' and 'draw' dimensions must be present in the dataset") 101 | 102 | pc_kwargs.setdefault("cols", None) 103 | pc_kwargs["aes"] = pc_kwargs.get("aes", {}).copy() 104 | pc_kwargs["aes"].setdefault("color", ["energy"]) 105 | visuals.setdefault("credible_interval", False) 106 | visuals.setdefault("point_estimate", False) 107 | visuals.setdefault("point_estimate_text", False) 108 | visuals.setdefault("title", False) 109 | 110 | plot_collection = plot_dist( 111 | new_ds, 112 | var_names=None, 113 | filter_vars=None, 114 | group=None, 115 | coords=None, 116 | sample_dims=sample_dims, 117 | kind=kind, 118 | point_estimate=None, 119 | ci_kind=None, 120 | ci_prob=None, 121 | plot_collection=plot_collection, 122 | backend=backend, 123 | labeller=labeller, 124 | aes_by_visuals=aes_by_visuals, 125 | visuals=visuals, 126 | stats=stats, 127 | **pc_kwargs, 128 | ) 129 | 130 | plot_collection.add_legend("energy") 131 | 132 | if bfmi: 133 | raise NotImplementedError("BFMI is not implemented yet") 134 | 135 | return plot_collection 136 | 137 | 138 | def _get_energy_ds(dt): 139 | energy = dt["sample_stats"].energy.values 140 | return convert_to_dataset( 141 | {"energy_": np.dstack([energy - energy.mean(), np.diff(energy, append=np.nan)])}, 142 | coords={"energy__dim_0": ["marginal", "transition"]}, 143 | ).rename({"energy__dim_0": "energy"}) 144 | -------------------------------------------------------------------------------- /src/arviz_plots/plots/loo_pit_plot.py: -------------------------------------------------------------------------------- 1 | """Plot loo pit.""" 2 | from arviz_base import convert_to_datatree 3 | 
from arviz_stats.loo import loo_pit 4 | 5 | from arviz_plots.plots.ecdf_plot import plot_ecdf_pit 6 | 7 | 8 | def plot_loo_pit( 9 | dt, 10 | ci_prob=None, 11 | coverage=False, 12 | method="simulation", 13 | n_simulations=1000, 14 | var_names=None, 15 | filter_vars=None, 16 | group="posterior_predictive", 17 | coords=None, 18 | sample_dims=None, 19 | plot_collection=None, 20 | backend=None, 21 | labeller=None, 22 | aes_by_visuals=None, 23 | visuals=None, 24 | **pc_kwargs, 25 | ): 26 | r"""LOO-PIT Δ-ECDF values with simultaneous confidence envelope. 27 | 28 | For a calibrated model the LOO Probability Integral Transform (PIT) values, 29 | $p(\tilde{y}_i \le y_i \mid y_{-i})$, should be uniformly distributed. 30 | Here $y_i$ represents the observed data for index $i$ and $\tilde y_i$ represents 31 | the posterior predictive sample at index $i$. $y_{-i}$ indicates we have left out the 32 | $i$-th observation. LOO-PIT values are computed using the PSIS-LOO-CV method described 33 | in [1]_ and [2]_. 34 | 35 | This plot shows the empirical cumulative distribution function (ECDF) of the LOO-PIT values. 36 | To make the plot easier to interpret, we plot the Δ-ECDF, that is, the difference between the 37 | observed ECDF and the expected CDF. Simultaneous confidence bands are computed using the method 38 | described in [3]_. 39 | 40 | Alternatively, we can visualize the coverage of the central posterior credible intervals by 41 | setting ``coverage=True``. This allows us to assess whether the credible intervals include 42 | the observed values. We can obtain the coverage of the central intervals from the LOO-PIT by 43 | replacing the LOO-PIT with two times the absolute difference between the LOO-PIT values and 0.5. 44 | 45 | For more details on how to interpret this plot, 46 | see https://arviz-devs.github.io/EABM/Chapters/Prior_posterior_predictive_checks.html#pit-ecdfs.
47 | 48 | Parameters 49 | ---------- 50 | dt : DataTree 51 | Input data 52 | ci_prob : float, optional 53 | Indicates the probability that should be contained within the plotted credible interval. 54 | Defaults to ``rcParams["stats.ci_prob"]`` 55 | coverage : bool, optional 56 | If True, plot the coverage of the central posterior credible intervals. Defaults to False. 57 | n_simulations : int, optional 58 | Number of simulations used to compute the simultaneous confidence intervals when 59 | ``method="simulation"``; ignored if method is "optimized". Defaults to 1000. 60 | method : str, optional 61 | Method to compute the confidence intervals. Either "simulation" or "optimized". 62 | Defaults to "simulation". 63 | var_names : str or list of str, optional 64 | One or more variables to be plotted. Currently only one variable is supported. 65 | Prefix the variables by ~ when you want to exclude them from the plot. 66 | filter_vars : {None, “like”, “regex”}, optional, default=None 67 | If None (default), interpret var_names as the real variables names. 68 | If “like”, interpret var_names as substrings of the real variables names. 69 | If “regex”, interpret var_names as regular expressions on the real variables names. 70 | coords : dict, optional 71 | Coordinates to plot. 72 | sample_dims : str or sequence of hashable, optional 73 | Currently ignored: the ECDF is computed over all dimensions of the LOO-PIT values. 74 | Kept for compatibility with the other plotting functions. 75 | plot_collection : PlotCollection, optional 76 | backend : {"matplotlib", "bokeh", "plotly"}, optional 77 | labeller : labeller, optional 78 | aes_by_visuals : mapping of {str : sequence of str}, optional 79 | Mapping of visuals to aesthetics that should use their mapping in `plot_collection` 80 | when plotted. Valid keys are the same as for `visuals`.
81 | 82 | visuals : mapping of {str : mapping or False}, optional 83 | Valid keys are: 84 | 85 | * ecdf_lines -> passed to :func:`~arviz_plots.visuals.ecdf_line` 86 | * ci -> passed to :func:`~arviz_plots.visuals.ci_line_y` 87 | * xlabel -> passed to :func:`~arviz_plots.visuals.labelled_x` 88 | * ylabel -> passed to :func:`~arviz_plots.visuals.labelled_y` 89 | * title -> passed to :func:`~arviz_plots.visuals.labelled_title` 90 | 91 | pc_kwargs : mapping 92 | Passed to :class:`arviz_plots.PlotCollection.grid` 93 | 94 | Returns 95 | ------- 96 | PlotCollection 97 | 98 | Examples 99 | -------- 100 | Plot the LOO-PIT Δ-ECDF for the radon dataset. 101 | 102 | .. plot:: 103 | :context: close-figs 104 | 105 | >>> from arviz_plots import plot_loo_pit, style 106 | >>> style.use("arviz-variat") 107 | >>> from arviz_base import load_arviz_data 108 | >>> dt = load_arviz_data('radon') 109 | >>> plot_loo_pit(dt) 110 | 111 | 112 | Plot the coverage for the radon dataset. 113 | 114 | .. plot:: 115 | :context: close-figs 116 | 117 | >>> plot_loo_pit(dt, coverage=True) 118 | 119 | 120 | .. minigallery:: plot_loo_pit 121 | 122 | References 123 | ---------- 124 | .. [1] Vehtari et al. Practical Bayesian model evaluation using leave-one-out cross-validation 125 | and WAIC. Statistics and Computing. 27(5) (2017) https://doi.org/10.1007/s11222-016-9696-4 126 | 127 | .. [2] Vehtari et al. Pareto Smoothed Importance Sampling. Journal of Machine Learning 128 | Research, 25(72) (2024) https://jmlr.org/papers/v25/19-556.html 129 | 130 | .. [3] Säilynoja et al. *Graphical test for discrete uniformity and 131 | its applications in goodness-of-fit evaluation and multiple sample comparison*. 132 | Statistics and Computing 32(32).
(2022) https://doi.org/10.1007/s11222-022-10090-6 133 | """ 134 | if visuals is None: 135 | visuals = {} 136 | else: 137 | visuals = visuals.copy() 138 | if isinstance(sample_dims, str): 139 | sample_dims = [sample_dims] 140 | 141 | if group != "posterior_predictive": 142 | raise ValueError(f"Group {group} not supported. Only 'posterior_predictive' is supported.") 143 | 144 | lpv = loo_pit(dt) 145 | new_dt = convert_to_datatree(lpv, group="loo_pit") 146 | 147 | visuals.setdefault("ylabel", {}) 148 | visuals.setdefault("remove_axis", False) 149 | visuals.setdefault("xlabel", {"text": "LOO-PIT"}) 150 | 151 | plot_collection = plot_ecdf_pit( 152 | new_dt, 153 | var_names=var_names, 154 | filter_vars=filter_vars, 155 | group="loo_pit", 156 | coords=coords, 157 | sample_dims=lpv.dims, 158 | ci_prob=ci_prob, 159 | coverage=coverage, 160 | n_simulations=n_simulations, 161 | method=method, 162 | plot_collection=plot_collection, 163 | backend=backend, 164 | labeller=labeller, 165 | aes_by_visuals=aes_by_visuals, 166 | visuals=visuals, 167 | **pc_kwargs, 168 | ) 169 | 170 | return plot_collection 171 | -------------------------------------------------------------------------------- /src/arviz_plots/plots/pairs_focus_plot.py: -------------------------------------------------------------------------------- 1 | """Pair focus plot code.""" 2 | from copy import copy 3 | from importlib import import_module 4 | 5 | import numpy as np 6 | from arviz_base import rcParams 7 | from arviz_base.labels import BaseLabeller 8 | 9 | from arviz_plots.plot_collection import PlotCollection 10 | from arviz_plots.plots.utils import ( 11 | filter_aes, 12 | get_group, 13 | process_group_variables_coords, 14 | set_wrap_layout, 15 | ) 16 | from arviz_plots.visuals import divergence_scatter, labelled_x, labelled_y, scatter_x 17 | 18 | 19 | def plot_pairs_focus( 20 | dt, 21 | var_names=None, 22 | filter_vars=None, 23 | group="posterior", 24 | coords=None, 25 | sample_dims=None, 26 | focus_var=None, 27
| focus_var_coords=None, 28 | plot_collection=None, 29 | backend=None, 30 | labeller=None, 31 | aes_by_visuals=None, 32 | visuals=None, 33 | **pc_kwargs, 34 | ): 35 | """Plot a fixed variable against other variables in the dataset. 36 | 37 | Parameters 38 | ---------- 39 | dt : DataTree 40 | Input data 41 | var_names: str or list of str, optional 42 | One or more variables to be plotted. 43 | Prefix the variables by ~ when you want to exclude them from the plot. 44 | filter_vars: {None, “like”, “regex”}, optional, default=None 45 | If None (default), interpret var_names as the real variables names. 46 | If “like”, interpret var_names as substrings of the real variables names. 47 | If “regex”, interpret var_names as regular expressions on the real variables names. 48 | group : str, optional 49 | Group holding both the plotted variables and `focus_var`. Defaults to "posterior". 50 | coords : mapping, optional 51 | Coordinates to use for plotting. Defaults to None. 52 | sample_dims : iterable, optional 53 | Dimensions to reduce unless mapped to an aesthetic. 54 | Defaults to ``rcParams["data.sample_dims"]`` 55 | focus_var: str 56 | Name of the variable to be plotted against all other variables. Required. 57 | focus_var_coords : mapping, optional 58 | Coordinates to use for the target variable. Defaults to None. 59 | plot_collection : PlotCollection, optional 60 | backend : {"matplotlib", "bokeh", "plotly"}, optional 61 | labeller : labeller, optional 62 | aes_by_visuals : mapping, optional 63 | Mapping of visuals to aesthetics that should use their mapping in `plot_collection` 64 | when plotted. Valid keys are the same as for `visuals`. 65 | visuals : mapping of {str : mapping or False}, optional 66 | Valid keys are: 67 | 68 | * scatter -> passed to :func:`~.visuals.scatter_x` 69 | * divergence -> passed to :func:`~.visuals.divergence_scatter`. Defaults to False.
70 | * xlabel, ylabel -> :func:`~.visuals.labelled_x`, :func:`~.visuals.labelled_y` 71 | 72 | pc_kwargs : mapping 73 | Passed to :class:`arviz_plots.PlotCollection` 74 | 75 | Returns 76 | ------- 77 | PlotCollection 78 | 79 | Examples 80 | -------- 81 | Default plot_pairs_focus 82 | 83 | .. plot:: 84 | :context: close-figs 85 | 86 | >>> from arviz_plots import plot_pairs_focus, style 87 | >>> style.use("arviz-variat") 88 | >>> from arviz_base import load_arviz_data 89 | >>> dt = load_arviz_data('centered_eight') 90 | >>> plot_pairs_focus( 91 | >>> dt, 92 | >>> var_names=["theta", "tau"], 93 | >>> focus_var="mu", 94 | >>> ) 95 | 96 | """ 97 | if sample_dims is None: 98 | sample_dims = rcParams["data.sample_dims"] 99 | if isinstance(sample_dims, str): 100 | sample_dims = [sample_dims] 101 | if visuals is None: 102 | visuals = {} 103 | if focus_var is None: 104 | raise ValueError("focus_var must be specified.") 105 | # pc_kwargs always arrives as a dict from **pc_kwargs, never None 106 | pc_kwargs = pc_kwargs.copy() 107 | 108 | if aes_by_visuals is None: 109 | aes_by_visuals = {} 110 | else: 111 | aes_by_visuals = aes_by_visuals.copy() 112 | 113 | if backend is None: 114 | if plot_collection is None: 115 | backend = rcParams["plot.backend"] 116 | else: 117 | backend = plot_collection.backend 118 | 119 | if var_names is None: 120 | var_names = "~" + focus_var 121 | 122 | distribution = process_group_variables_coords( 123 | dt, group=group, var_names=var_names, filter_vars=filter_vars, coords=coords 124 | ) 125 | 126 | plot_bknd = import_module(f".backend.{backend}", package="arviz_plots") 127 | 128 | if plot_collection is None: 129 | pc_kwargs.setdefault( 130 | "cols", ["__variable__"] + [dim for dim in distribution.dims if dim not in sample_dims] 131 | ) 132 | pc_kwargs["figure_kwargs"] = pc_kwargs.get("figure_kwargs", {}).copy() 133 | pc_kwargs = set_wrap_layout(pc_kwargs, plot_bknd, distribution) 134 | pc_kwargs["aes"] = pc_kwargs.get("aes", {}).copy() 135 | if "chain" in distribution: 136 | pc_kwargs["aes"].setdefault("overlay", ["chain"]) 137 | pc_kwargs["figure_kwargs"].setdefault("sharey",
True) 138 | plot_collection = PlotCollection.wrap( 139 | distribution, 140 | backend=backend, 141 | **pc_kwargs, 142 | ) 143 | 144 | # scatter: focus_var is taken from the same group as the plotted variables 145 | y = ( 146 | get_group(dt, group)[focus_var].sel(focus_var_coords) 147 | if focus_var_coords is not None 148 | else get_group(dt, group)[focus_var] 149 | ) 150 | aes_by_visuals["scatter"] = {"overlay"}.union(aes_by_visuals.get("scatter", {})) 151 | scatter_kwargs = copy(visuals.get("scatter", {})) 152 | scatter_kwargs.setdefault("alpha", 0.5) 153 | colors = plot_bknd.get_default_aes("color", 1, {}) 154 | scatter_kwargs.setdefault("color", colors[0]) 155 | scatter_kwargs.setdefault("width", 0) 156 | _, _, scatter_ignore = filter_aes(plot_collection, aes_by_visuals, "scatter", sample_dims) 157 | 158 | plot_collection.map( 159 | scatter_x, 160 | "scatter", 161 | ignore_aes=scatter_ignore, 162 | y=y, 163 | **scatter_kwargs, 164 | ) 165 | 166 | # divergence 167 | 168 | aes_by_visuals["divergence"] = {"overlay"}.union(aes_by_visuals.get("divergence", {})) 169 | div_kwargs = copy(visuals.get("divergence", False)) 170 | if div_kwargs is True: 171 | div_kwargs = {} 172 | sample_stats = get_group(dt, "sample_stats", allow_missing=True) 173 | if ( 174 | div_kwargs is not False 175 | and sample_stats is not None 176 | and "diverging" in sample_stats.data_vars 177 | and np.any(sample_stats.diverging) 178 | ): 179 | divergence_mask = dt.sample_stats.diverging 180 | _, div_aes, div_ignore = filter_aes( 181 | plot_collection, aes_by_visuals, "divergence", sample_dims 182 | ) 183 | if "color" not in div_aes: 184 | div_kwargs.setdefault("color", "black") 185 | div_kwargs.setdefault("alpha", 0.4) 186 | plot_collection.map( 187 | divergence_scatter, 188 | "divergence", 189 | ignore_aes=div_ignore, 190 | y=y, 191 | mask=divergence_mask, 192 | **div_kwargs, 193 | ) 194 | 195 | if labeller is None: 196 | labeller = BaseLabeller() 197 | 198 | # xlabel of plots 199 | 200 | xlabel_kwargs = copy(visuals.get("xlabel", {})) 201 | _, _, xlabel_ignore =
filter_aes(plot_collection, aes_by_visuals, "xlabel", sample_dims) 202 | plot_collection.map( 203 | labelled_x, 204 | "xlabel", 205 | subset_info=True, 206 | ignore_aes=xlabel_ignore, 207 | labeller=labeller, 208 | **xlabel_kwargs, 209 | ) 210 | 211 | # ylabel of plots 212 | ylabel_kwargs = copy(visuals.get("ylabel", {})) 213 | _, _, ylabel_ignore = filter_aes(plot_collection, aes_by_visuals, "ylabel", sample_dims) 214 | plot_collection.map( 215 | labelled_y, 216 | "ylabel", 217 | ignore_aes=ylabel_ignore, 218 | text=focus_var, 219 | **ylabel_kwargs, 220 | ) 221 | 222 | return plot_collection 223 | -------------------------------------------------------------------------------- /src/arviz_plots/plots/prior_posterior_plot.py: -------------------------------------------------------------------------------- 1 | """Prior and posterior marginal density plot code.""" 2 | 3 | from importlib import import_module 4 | 5 | import numpy as np 6 | from arviz_base import extract, rcParams 7 | from xarray import concat 8 | 9 | from arviz_plots.plot_collection import PlotCollection 10 | from arviz_plots.plots.dist_plot import plot_dist 11 | from arviz_plots.plots.utils import process_group_variables_coords, set_wrap_layout 12 | 13 | 14 | def plot_prior_posterior( 15 | dt, 16 | var_names=None, 17 | filter_vars=None, 18 | group=None, # pylint: disable=unused-argument 19 | coords=None, 20 | sample_dims=None, 21 | kind=None, 22 | plot_collection=None, 23 | backend=None, 24 | labeller=None, 25 | aes_by_visuals=None, 26 | visuals=None, 27 | stats=None, 28 | **pc_kwargs, 29 | ): 30 | r"""Plot 1D marginal densities for prior and posterior. 31 | 32 | Both groups are subsampled to a common number of samples, concatenated along a new 33 | ``group`` dimension and plotted with :func:`~arviz_plots.plot_dist`. By default 34 | ``group`` is mapped to color and a legend for it is added. 35 | 36 | Parameters 37 | ---------- 38 | dt : DataTree 39 | Input data.
It must contain both ``prior`` and ``posterior`` groups. 40 | Both are subsampled to the same number of samples (with a fixed random seed). 41 | var_names : str or list of str, optional 42 | One or more variables to be plotted. 43 | Prefix the variables by ~ when you want to exclude them from the plot. 44 | filter_vars : {None, “like”, “regex”}, default=None 45 | If None, interpret var_names as the real variables names. 46 | If “like”, interpret var_names as substrings of the real variables names. 47 | If “regex”, interpret var_names as regular expressions on the real variables names. 48 | group : None 49 | This argument is ignored. It is kept for compatibility with other plotting functions. 50 | coords : dict, optional 51 | sample_dims : str or sequence of hashable, optional 52 | Dimensions to reduce unless mapped to an aesthetic. 53 | Defaults to ``rcParams["data.sample_dims"]`` 54 | kind : {"kde", "hist", "dot", "ecdf"}, optional 55 | How to represent the marginal density. 56 | Defaults to ``rcParams["plot.density_kind"]`` 57 | plot_collection : PlotCollection, optional 58 | backend : {"matplotlib", "bokeh", "plotly"}, optional 59 | labeller : labeller, optional 60 | aes_by_visuals : mapping of {str : sequence of str}, optional 61 | Mapping of visuals to aesthetics that should use their mapping in `plot_collection` 62 | when plotted. Valid keys are the same as for `visuals`. 63 | visuals : mapping of {str : mapping or False}, optional 64 | Valid keys are: 65 | 66 | * One of "kde", "ecdf", "dot" or "hist", matching the `kind` argument. 67 | 68 | * "kde" -> passed to :func:`~arviz_plots.visuals.line_xy` 69 | * "ecdf" -> passed to :func:`~arviz_plots.visuals.ecdf_line` 70 | * "hist" -> passed to :func:`~arviz_plots.visuals.hist` 71 | 72 | * title -> passed to :func:`~arviz_plots.visuals.labelled_title` 73 | 74 | stats : mapping, optional 75 | Valid keys are: 76 | 77 | * density -> passed to kde, ecdf, ... 
78 | 79 | pc_kwargs : mapping 80 | Passed to :class:`arviz_plots.PlotCollection.wrap` 81 | 82 | Returns 83 | ------- 84 | PlotCollection 85 | 86 | Examples 87 | -------- 88 | Select two variables and plot them with an ECDF. 89 | 90 | .. plot:: 91 | :context: close-figs 92 | 93 | >>> from arviz_plots import plot_prior_posterior, style 94 | >>> style.use("arviz-variat") 95 | >>> from arviz_base import load_arviz_data 96 | >>> dt = load_arviz_data('centered_eight') 97 | >>> plot_prior_posterior(dt, var_names=["mu", "tau"], kind="ecdf") 98 | 99 | 100 | .. minigallery:: plot_prior_posterior 101 | """ 102 | if sample_dims is None: 103 | sample_dims = rcParams["data.sample_dims"] 104 | if isinstance(sample_dims, str): 105 | sample_dims = [sample_dims] 106 | sample_dims = list(sample_dims) 107 | if kind is None: 108 | kind = rcParams["plot.density_kind"] 109 | if stats is None: 110 | stats = {} 111 | else: 112 | stats = stats.copy() 113 | if visuals is None: 114 | visuals = {} 115 | else: 116 | visuals = visuals.copy() 117 | if sample_dims is None: 118 | sample_dims = rcParams["data.sample_dims"] 119 | if isinstance(sample_dims, str): 120 | sample_dims = [sample_dims] 121 | sample_dims = list(sample_dims) 122 | if not isinstance(visuals, dict): 123 | visuals = {} 124 | 125 | if backend is None: 126 | if plot_collection is None: 127 | backend = rcParams["plot.backend"] 128 | else: 129 | backend = plot_collection.backend 130 | 131 | plot_bknd = import_module(f".backend.{backend}", package="arviz_plots") 132 | 133 | prior_size = np.prod([dt.prior.sizes[dim] for dim in sample_dims]) 134 | posterior_size = np.prod([dt.posterior.sizes[dim] for dim in sample_dims]) 135 | num_samples = min(prior_size, posterior_size) 136 | 137 | ds_prior = ( 138 | extract(dt, group="prior", num_samples=num_samples, random_seed=0, keep_dataset=True) 139 | .drop_vars(sample_dims + ["sample"]) 140 | .assign_coords(sample=("sample", np.arange(num_samples))) 141 | ) 142 | ds_posterior = (
extract(dt, group="posterior", num_samples=num_samples, random_seed=0, keep_dataset=True) 144 | .drop_vars(sample_dims + ["sample"]) 145 | .assign_coords(sample=("sample", np.arange(num_samples))) 146 | ) 147 | 148 | distribution = concat([ds_prior, ds_posterior], dim="group").assign_coords( 149 | {"group": ["prior", "posterior"]} 150 | ) 151 | 152 | distribution = process_group_variables_coords( 153 | distribution, 154 | group=None, 155 | var_names=var_names, 156 | filter_vars=filter_vars, 157 | coords=coords, 158 | ) 159 | 160 | if len(sample_dims) > 1: 161 | # sample dims will have been stacked and renamed by `extract` 162 | sample_dims = ["sample"] 163 | 164 | if plot_collection is None: 165 | pc_kwargs["figure_kwargs"] = pc_kwargs.get("figure_kwargs", {}).copy() 166 | 167 | pc_kwargs["aes"] = pc_kwargs.get("aes", {}).copy() 168 | pc_kwargs["aes"].setdefault("color", ["group"]) 169 | pc_kwargs.setdefault("col_wrap", 4) 170 | pc_kwargs.setdefault( 171 | "cols", 172 | ["__variable__"] 173 | + [dim for dim in distribution.dims if dim not in sample_dims + ["group"]], 174 | ) 175 | 176 | pc_kwargs = set_wrap_layout(pc_kwargs, plot_bknd, distribution) 177 | 178 | plot_collection = PlotCollection.wrap( 179 | distribution, 180 | backend=backend, 181 | **pc_kwargs, 182 | ) 183 | 184 | visuals.setdefault("credible_interval", False) 185 | visuals.setdefault("point_estimate", False) 186 | visuals.setdefault("point_estimate_text", False) 187 | 188 | if aes_by_visuals is None: 189 | aes_by_visuals = {} 190 | else: 191 | aes_by_visuals = aes_by_visuals.copy() 192 | 193 | if kind == "hist": 194 | visuals.setdefault("hist", {}) 195 | visuals.setdefault("remove_axis", True) 196 | if visuals["hist"] is not False: 197 | visuals["hist"] = {"step": True, **visuals["hist"]} 198 | stats.setdefault("density", {"density": True}) 199 | 200 | plot_collection = plot_dist( 201 | distribution, 202 | var_names=None, 203 | group=None, 204 | coords=None, 205 | sample_dims=sample_dims, 206 | kind=kind, 207
| point_estimate=None, 208 | ci_kind=None, 209 | ci_prob=None, 210 | plot_collection=plot_collection, 211 | backend=backend, 212 | labeller=labeller, 213 | aes_by_visuals=aes_by_visuals, 214 | visuals=visuals, 215 | stats=stats, 216 | **pc_kwargs, 217 | ) 218 | 219 | plot_collection.add_legend("group") 220 | 221 | return plot_collection 222 | -------------------------------------------------------------------------------- /src/arviz_plots/plots/trace_plot.py: -------------------------------------------------------------------------------- 1 | """Trace plot code.""" 2 | from copy import copy 3 | from importlib import import_module 4 | 5 | import numpy as np 6 | from arviz_base import rcParams 7 | from arviz_base.labels import BaseLabeller 8 | 9 | from arviz_plots.plot_collection import PlotCollection 10 | from arviz_plots.plots.utils import ( 11 | filter_aes, 12 | get_group, 13 | process_group_variables_coords, 14 | set_wrap_layout, 15 | ) 16 | from arviz_plots.visuals import labelled_title, labelled_x, line, ticklabel_props, trace_rug 17 | 18 | 19 | def plot_trace( 20 | dt, 21 | var_names=None, 22 | filter_vars=None, 23 | group="posterior", 24 | coords=None, 25 | sample_dims=None, 26 | plot_collection=None, 27 | backend=None, 28 | labeller=None, 29 | aes_by_visuals=None, 30 | visuals=None, 31 | **pc_kwargs, 32 | ): 33 | """Plot iteration versus sampled values. 34 | 35 | Parameters 36 | ---------- 37 | dt : DataTree 38 | Input data 39 | var_names: str or list of str, optional 40 | One or more variables to be plotted. 41 | Prefix the variables by ~ when you want to exclude them from the plot. 42 | filter_vars: {None, “like”, “regex”}, optional, default=None 43 | If None (default), interpret var_names as the real variables names. 44 | If “like”, interpret var_names as substrings of the real variables names. 45 | If “regex”, interpret var_names as regular expressions on the real variables names. 46 | sample_dims : iterable, optional 47 | Dimensions to reduce unless mapped to an aesthetic.
48 | Defaults to ``rcParams["data.sample_dims"]`` 49 | plot_collection : PlotCollection, optional 50 | backend : {"matplotlib", "bokeh", "plotly"}, optional 51 | labeller : labeller, optional 52 | aes_by_visuals : mapping, optional 53 | Mapping of visuals to aesthetics that should use their mapping in `plot_collection` 54 | when plotted. By default trace lines use all aesthetics and divergences only ``overlay``. 55 | visuals : mapping of {str : mapping or False}, optional 56 | Valid keys are: 57 | 58 | * trace -> passed to :func:`~.visuals.line` 59 | * divergence -> passed to :func:`~.visuals.trace_rug` 60 | * title -> :func:`~.visuals.labelled_title` 61 | * xlabel -> :func:`~.visuals.labelled_x` 62 | * ticklabels -> :func:`~.visuals.ticklabel_props` 63 | 64 | pc_kwargs : mapping 65 | Passed to :class:`arviz_plots.PlotCollection` 66 | 67 | Returns 68 | ------- 69 | PlotCollection 70 | 71 | Examples 72 | -------- 73 | The following examples focus on behaviour specific to ``plot_trace``. 74 | For a general introduction to batteries-included functions like this one and common 75 | usage examples see :ref:`plots_intro` 76 | 77 | Default plot_trace 78 | 79 | ..
plot:: 80 | :context: close-figs 81 | 82 | >>> from arviz_plots import plot_trace, style 83 | >>> style.use("arviz-variat") 84 | >>> from arviz_base import load_arviz_data 85 | >>> centered = load_arviz_data('centered_eight') 86 | >>> plot_trace(centered) 87 | 88 | """ 89 | if sample_dims is None: 90 | sample_dims = rcParams["data.sample_dims"] 91 | if isinstance(sample_dims, str): 92 | sample_dims = [sample_dims] 93 | if visuals is None: 94 | visuals = {} 95 | 96 | distribution = process_group_variables_coords( 97 | dt, group=group, var_names=var_names, filter_vars=filter_vars, coords=coords 98 | ) 99 | 100 | if backend is None: 101 | if plot_collection is None: 102 | backend = rcParams["plot.backend"] 103 | else: 104 | backend = plot_collection.backend 105 | 106 | plot_bknd = import_module(f".backend.{backend}", package="arviz_plots") 107 | 108 | if plot_collection is None: 109 | pc_kwargs["aes"] = pc_kwargs.get("aes", {}).copy() 110 | if "chain" in distribution: 111 | pc_kwargs["aes"].setdefault("color", ["chain"]) 112 | pc_kwargs["aes"].setdefault("overlay", ["chain"]) 113 | pc_kwargs.setdefault( 114 | "cols", ["__variable__"] + [dim for dim in distribution.dims if dim not in sample_dims] 115 | ) 116 | pc_kwargs["figure_kwargs"] = pc_kwargs.get("figure_kwargs", {}).copy() 117 | aux_dim_list = [dim for dim in pc_kwargs["cols"] if dim != "__variable__"] 118 | pc_kwargs = set_wrap_layout(pc_kwargs, plot_bknd, distribution) 119 | pc_kwargs["figure_kwargs"].setdefault("sharex", True) 120 | plot_collection = PlotCollection.wrap( 121 | distribution, 122 | backend=backend, 123 | **pc_kwargs, 124 | ) 125 | else: 126 | aux_dim_list = list(plot_collection.viz["plot"].dims) 127 | 128 | if aes_by_visuals is None: 129 | aes_by_visuals = {} 130 | else: 131 | aes_by_visuals = aes_by_visuals.copy() 132 | aes_by_visuals.setdefault("trace", plot_collection.aes_set) 133 | aes_by_visuals.setdefault("divergence", {"overlay"}) 134 | 135 | if labeller is None: 136 | labeller =
BaseLabeller() 137 | 138 | # trace 139 | trace_kwargs = copy(visuals.get("trace", {})) 140 | if trace_kwargs is False: 141 | xname = None 142 | else: 143 | default_xname = sample_dims[0] if len(sample_dims) == 1 else "draw" 144 | if (default_xname not in distribution.dims) or ( 145 | not np.issubdtype(distribution[default_xname].dtype, np.number) 146 | ): 147 | default_xname = None 148 | xname = trace_kwargs.get("xname", default_xname) 149 | trace_kwargs["xname"] = xname 150 | _, _, trace_ignore = filter_aes(plot_collection, aes_by_visuals, "trace", sample_dims) 151 | plot_collection.map( 152 | line, 153 | "trace", 154 | data=distribution, 155 | ignore_aes=trace_ignore, 156 | **trace_kwargs, 157 | ) 158 | 159 | # divergences 160 | sample_stats = get_group(dt, "sample_stats", allow_missing=True) 161 | divergence_kwargs = copy(visuals.get("divergence", {})) 162 | if ( 163 | sample_stats is not None 164 | and "diverging" in sample_stats.data_vars 165 | and np.any(sample_stats.diverging) 166 | and divergence_kwargs is not False 167 | ): 168 | divergence_mask = dt.sample_stats.diverging 169 | _, div_aes, div_ignore = filter_aes( 170 | plot_collection, aes_by_visuals, "divergence", sample_dims 171 | ) 172 | if "color" not in div_aes: 173 | divergence_kwargs.setdefault("color", "black") 174 | if "marker" not in div_aes: 175 | divergence_kwargs.setdefault("marker", "|") 176 | if "size" not in div_aes: 177 | divergence_kwargs.setdefault("size", 30) 178 | div_reduce_dims = [dim for dim in distribution.dims if dim not in aux_dim_list] 179 | 180 | plot_collection.map( 181 | trace_rug, 182 | "divergence", 183 | data=distribution, 184 | ignore_aes=div_ignore, 185 | xname=xname, 186 | y=distribution.min(div_reduce_dims), 187 | mask=divergence_mask, 188 | **divergence_kwargs, 189 | ) 190 | 191 | # titles 192 | title_kwargs = copy(visuals.get("title", {})) 193 | if title_kwargs is not False: 194 | _, title_aes, title_ignore = filter_aes( 195 | plot_collection, aes_by_visuals,
"title", sample_dims 196 | ) 197 | if "color" not in title_aes: 198 | title_kwargs.setdefault("color", "black") 199 | plot_collection.map( 200 | labelled_title, 201 | "title", 202 | ignore_aes=title_ignore, 203 | subset_info=True, 204 | labeller=labeller, 205 | **title_kwargs, 206 | ) 207 | 208 | # x label: the name of the x dimension, or "Steps" when there is none 209 | xlabel_kwargs = copy(visuals.get("xlabel", {})) 210 | if xlabel_kwargs is not False: 211 | _, xlabel_aes, xlabel_ignore = filter_aes( 212 | plot_collection, aes_by_visuals, "xlabel", sample_dims 213 | ) 214 | 215 | if "color" not in xlabel_aes: 216 | xlabel_kwargs.setdefault("color", "black") 217 | 218 | plot_collection.map( 219 | labelled_x, 220 | "xlabel", 221 | ignore_aes=xlabel_ignore, 222 | text="Steps" if xname is None else xname.capitalize(), 223 | **xlabel_kwargs, 224 | ) 225 | 226 | # Adjust tick labels 227 | ticklabels_kwargs = copy(visuals.get("ticklabels", {})) 228 | if ticklabels_kwargs is not False: 229 | _, _, ticklabels_ignore = filter_aes( 230 | plot_collection, aes_by_visuals, "ticklabels", sample_dims 231 | ) 232 | plot_collection.map( 233 | ticklabel_props, 234 | "ticklabels", 235 | ignore_aes=ticklabels_ignore, 236 | axis="both", 237 | store_artist=backend == "none", 238 | **ticklabels_kwargs, 239 | ) 240 | 241 | return plot_collection 242 | -------------------------------------------------------------------------------- /src/arviz_plots/py.typed: -------------------------------------------------------------------------------- 1 | # Marker file for PEP 561 2 | -------------------------------------------------------------------------------- /src/arviz_plots/style.py: -------------------------------------------------------------------------------- 1 | """Style/templating helpers.""" 2 | import os 3 | 4 | from arviz_base import rcParams 5 | 6 | 7 | def use(name): 8 | """Set an arviz style as the default style/template for all available backends.
9 | 10 | Parameters 11 | ---------- 12 | name : str 13 | Name of the style to be set as default. A ValueError is raised if no backend provides it. 14 | """ 15 | ok = False 16 | 17 | try: 18 | import matplotlib.pyplot as plt 19 | 20 | if name in plt.style.available: 21 | plt.style.use(name) 22 | ok = True 23 | except ImportError: 24 | pass 25 | 26 | try: 27 | import plotly.io as pio 28 | 29 | if name in pio.templates: 30 | pio.templates.default = name 31 | ok = True 32 | except ImportError: 33 | pass 34 | 35 | try: 36 | if name in ["arviz-cetrino", "arviz-variat", "arviz-vibrant"]: 37 | from bokeh.io import curdoc 38 | from bokeh.themes import Theme 39 | 40 | path = os.path.dirname(os.path.abspath(__file__)) 41 | curdoc().theme = Theme(filename=f"{path}/styles/{name}.yml") 42 | ok = True 43 | except (ImportError, FileNotFoundError): 44 | pass 45 | 46 | if not ok: 47 | raise ValueError(f"Style {name} not found.") 48 | 49 | 50 | def available(): 51 | """List the matplotlib and plotly styles that are available, keyed by backend name.""" 52 | styles = {} 53 | 54 | try: 55 | import matplotlib.pyplot as plt 56 | 57 | styles["matplotlib"] = plt.style.available 58 | except ImportError: 59 | pass 60 | 61 | try: 62 | import plotly.io as pio 63 | 64 | styles["plotly"] = list(pio.templates) 65 | except ImportError: 66 | pass 67 | 68 | return styles 69 | 70 | 71 | def get(name, backend=None): 72 | """Get the style/template with the given name. 73 | 74 | Parameters 75 | ---------- 76 | name : str 77 | Name of the style/template to get. 78 | backend : str, optional 79 | Name of the backend to use. Options are 'matplotlib' and 'plotly'. 80 | Defaults to ``rcParams["plot.backend"]``.
81 | """ 82 | if backend is None: 83 | backend = rcParams["plot.backend"] 84 | if backend not in ["matplotlib", "plotly"]: 85 | raise ValueError(f"Default styles/templates are not supported for Backend {backend}") 86 | 87 | if backend == "matplotlib": 88 | import matplotlib.pyplot as plt 89 | 90 | if name in plt.style.available: 91 | return plt.style.library[name] 92 | 93 | elif backend == "plotly": 94 | import plotly.io as pio 95 | 96 | if name in pio.templates: 97 | return pio.templates[name] 98 | 99 | raise ValueError(f"Style {name} not found.") 100 | -------------------------------------------------------------------------------- /src/arviz_plots/styles/arviz-cetrino.mplstyle: -------------------------------------------------------------------------------- 1 | ## *************************************************************************** 2 | ## * FIGURE * 3 | ## *************************************************************************** 4 | 5 | figure.facecolor: white # white background around the axes 6 | figure.edgecolor: None # no figure edge color 7 | figure.titleweight: bold # weight of the figure title 8 | figure.titlesize: x-large 9 | 10 | figure.figsize: 6, 5 11 | figure.dpi: 200.0 12 | figure.constrained_layout.use: True 13 | 14 | ## *************************************************************************** 15 | ## * FONT * 16 | ## *************************************************************************** 17 | 18 | font.size: 12 19 | font.style: normal 20 | font.variant: normal 21 | font.weight: normal 22 | font.stretch: normal 23 | 24 | text.color: .15 25 | 26 | ## *************************************************************************** 27 | ## * AXES * 28 | ## *************************************************************************** 29 | 30 | axes.facecolor: white 31 | axes.edgecolor: .33 # axes edge color 32 | axes.linewidth: 0.8 # edge line width 33 | 34 | axes.grid: False # do not show grid 35 | # axes.grid.axis: y # which axis the grid should
apply to 36 | axes.grid.which: major # grid lines at {major, minor, both} ticks 37 | axes.axisbelow: True # keep grid layer in the back 38 | 39 | grid.color: .8 # grid color 40 | grid.linestyle: - # solid 41 | grid.linewidth: 0.8 # in points 42 | grid.alpha: 1.0 # transparency, between 0.0 and 1.0 43 | 44 | lines.solid_capstyle: round 45 | 46 | axes.spines.right: False # do not show right spine 47 | axes.spines.top: False # do not show top spine 48 | 49 | axes.titlesize: 16 50 | axes.titleweight: bold # font weight of title 51 | 52 | axes.labelsize: large 53 | axes.labelcolor: .15 54 | axes.labelweight: normal # weight of the x and y labels 55 | 56 | # color-blind friendly cycle designed using https://colorcyclepicker.mpetroff.net/ 57 | # see preview and check for colorblindness here https://coolors.co/009988-9238b2-d2225f-ec8f26-fcd026-3cd186-a57119-2f5e14-f225f4-8f9fbf 58 | axes.prop_cycle: cycler(color=["009988", "9238b2", "d2225f", "ec8f26", "fcd026", "3cd186", "a57119", "2f5e14", "f225f4", "8f9fbf"]) 59 | 60 | image.cmap: viridis 61 | 62 | ## *************************************************************************** 63 | ## * TICKS * 64 | ## *************************************************************************** 65 | 66 | xtick.labelsize: large 67 | xtick.color: .15 68 | xtick.top: False 69 | xtick.bottom: True 70 | xtick.direction: out 71 | 72 | ytick.labelsize: large 73 | ytick.color: .15 74 | ytick.left: True 75 | ytick.right: False 76 | ytick.direction: out 77 | 78 | ## *************************************************************************** 79 | ## * LEGEND * 80 | ## *************************************************************************** 81 | 82 | legend.framealpha: 0.5 83 | legend.frameon: False # do not draw on background patch 84 | legend.fancybox: False # do not round corners 85 | 86 | legend.numpoints: 1 87 | legend.scatterpoints: 1 88 | 89 | legend.fontsize: large 
-------------------------------------------------------------------------------- /src/arviz_plots/styles/arviz-cetrino.yml: -------------------------------------------------------------------------------- 1 | attrs: 2 | Plot: 3 | background_fill_color: white 4 | border_fill_color: white 5 | outline_line_width: 0 6 | outline_line_color: null 7 | Axis: 8 | major_tick_line_color: '#262626' 9 | minor_tick_line_alpha: 0 10 | Grid: 11 | grid_line_color: null 12 | Title: 13 | text_color: 'black' 14 | text_font_style: 'bold' 15 | align: 'center' 16 | Text: 17 | text_color: '#262626' 18 | text_font_style: 'bold' 19 | Cycler: 20 | colors : [ 21 | '#009988', '#9238b2', '#d2225f', '#ec8f26', '#fcd026', 22 | '#3cd186', '#a57119', '#2f5e14', '#f225f4', '#8f9fbf', 23 | ] 24 | 25 | 26 | -------------------------------------------------------------------------------- /src/arviz_plots/styles/arviz-variat.mplstyle: -------------------------------------------------------------------------------- 1 | ## *************************************************************************** 2 | ## * FIGURE * 3 | ## *************************************************************************** 4 | 5 | figure.facecolor: white # white background outside the axes 6 | figure.edgecolor: None # no edge color around the figure 7 | figure.titleweight: bold # weight of the figure title 8 | figure.titlesize: x-large 9 | 10 | figure.figsize: 6, 5 11 | figure.dpi: 200.0 12 | figure.constrained_layout.use: True 13 | 14 | ## *************************************************************************** 15 | ## * FONT * 16 | ## *************************************************************************** 17 | 18 | font.size: 12 19 | font.style: normal 20 | font.variant: normal 21 | font.weight: normal 22 | font.stretch: normal 23 | 24 | text.color: .15 25 | 26 | ## *************************************************************************** 27 | ## * AXES * 28 | ## 
*************************************************************************** 29 | 30 | axes.facecolor: white 31 | axes.edgecolor: .33 # axes edge color 32 | axes.linewidth: 0.8 # edge line width 33 | 34 | axes.grid: False # do not show grid 35 | # axes.grid.axis: y # which axis the grid should apply to 36 | axes.grid.which: major # grid lines at {major, minor, both} ticks 37 | axes.axisbelow: True # keep grid layer in the back 38 | 39 | grid.color: .8 # grid color 40 | grid.linestyle: - # solid 41 | grid.linewidth: 0.8 # in points 42 | grid.alpha: 1.0 # transparency, between 0.0 and 1.0 43 | 44 | lines.solid_capstyle: round 45 | 46 | axes.spines.right: False # do not show right spine 47 | axes.spines.top: False # do not show top spine 48 | 49 | axes.titlesize: 16 50 | axes.titleweight: bold # font weight of title 51 | 52 | axes.labelsize: large 53 | axes.labelcolor: .15 54 | axes.labelweight: normal # weight of the x and y labels 55 | 56 | # color-blind friendly cycle designed using https://colorcyclepicker.mpetroff.net/ 57 | # see preview and check for colorblindness here https://coolors.co/36acc6-f66d7f-fac364-7c2695-228306-a252f4-63f0ea-000000-6f6f6f-b7b7b7 58 | axes.prop_cycle: cycler(color=["36acc6", "f66d7f", "fac364", "7c2695", "228306", "a252f4", "63f0ea", "000000", "6f6f6f", "b7b7b7"]) 59 | 60 | image.cmap: viridis 61 | 62 | ## *************************************************************************** 63 | ## * TICKS * 64 | ## *************************************************************************** 65 | 66 | xtick.labelsize: large 67 | xtick.color: .15 68 | xtick.top: False 69 | xtick.bottom: True 70 | xtick.direction: out 71 | 72 | ytick.labelsize: large 73 | ytick.color: .15 74 | ytick.left: True 75 | ytick.right: False 76 | ytick.direction: out 77 | 78 | ## *************************************************************************** 79 | ## * LEGEND * 80 | ## *************************************************************************** 81 | 82 | 
legend.framealpha: 0.5 83 | legend.frameon: False # do not draw on background patch 84 | legend.fancybox: False # do not round corners 85 | 86 | legend.numpoints: 1 87 | legend.scatterpoints: 1 88 | 89 | legend.fontsize: large 90 | 91 | -------------------------------------------------------------------------------- /src/arviz_plots/styles/arviz-variat.yml: -------------------------------------------------------------------------------- 1 | attrs: 2 | Plot: 3 | background_fill_color: white 4 | border_fill_color: white 5 | outline_line_width: 0 6 | outline_line_color: null 7 | Axis: 8 | major_tick_line_color: '#262626' 9 | minor_tick_line_alpha: 0 10 | Grid: 11 | grid_line_color: null 12 | Title: 13 | text_color: 'black' 14 | text_font_style: 'bold' 15 | align: 'center' 16 | Text: 17 | text_color: '#262626' 18 | text_font_style: 'bold' 19 | Cycler: 20 | colors : [ 21 | '#36acc6', '#f66d7f', '#fac364', '#7c2695', '#228306', 22 | '#a252f4', '#63f0ea', '#000000', '#6f6f6f', '#b7b7b7' 23 | ] 24 | -------------------------------------------------------------------------------- /src/arviz_plots/styles/arviz-vibrant.mplstyle: -------------------------------------------------------------------------------- 1 | ## *************************************************************************** 2 | ## * FIGURE * 3 | ## *************************************************************************** 4 | 5 | figure.facecolor: white # white background outside the axes 6 | figure.edgecolor: None # no edge color around the figure 7 | figure.titleweight: bold # weight of the figure title 8 | figure.titlesize: x-large 9 | 10 | figure.figsize: 6, 5 11 | figure.dpi: 200.0 12 | figure.constrained_layout.use: True 13 | 14 | ## *************************************************************************** 15 | ## * FONT * 16 | ## *************************************************************************** 17 | 18 | font.size: 12 19 | font.style: normal 20 | font.variant: normal 21 | font.weight: normal 22 | 
font.stretch: normal 23 | 24 | text.color: .15 25 | 26 | ## *************************************************************************** 27 | ## * AXES * 28 | ## *************************************************************************** 29 | 30 | axes.facecolor: white 31 | axes.edgecolor: .33 # axes edge color 32 | axes.linewidth: 0.8 # edge line width 33 | 34 | axes.grid: False # do not show grid 35 | # axes.grid.axis: y # which axis the grid should apply to 36 | axes.grid.which: major # grid lines at {major, minor, both} ticks 37 | axes.axisbelow: True # keep grid layer in the back 38 | 39 | grid.color: .8 # grid color 40 | grid.linestyle: - # solid 41 | grid.linewidth: 0.8 # in points 42 | grid.alpha: 1.0 # transparency, between 0.0 and 1.0 43 | 44 | lines.solid_capstyle: round 45 | 46 | axes.spines.right: False # do not show right spine 47 | axes.spines.top: False # do not show top spine 48 | 49 | axes.titlesize: 16 50 | axes.titleweight: bold # font weight of title 51 | 52 | axes.labelsize: large 53 | axes.labelcolor: .15 54 | axes.labelweight: normal # weight of the x and y labels 55 | 56 | # color-blind friendly cycle designed using https://colorcyclepicker.mpetroff.net/ 57 | # see preview and check for colorblindness here https://coolors.co/008b92-f15c58-48cdef-98d81a-997ee5-f5dc9d-c90a4e-145393-323232-616161 58 | axes.prop_cycle: cycler(color=["008b92", "f15c58", "48cdef", "98d81a", "997ee5", "f5dc9d", "c90a4e", "145393", "323232", "616161"]) 59 | image.cmap: viridis 60 | 61 | ## *************************************************************************** 62 | ## * TICKS * 63 | ## *************************************************************************** 64 | 65 | xtick.labelsize: large 66 | xtick.color: .15 67 | xtick.top: False 68 | xtick.bottom: True 69 | xtick.direction: out 70 | 71 | ytick.labelsize: large 72 | ytick.color: .15 73 | ytick.left: True 74 | ytick.right: False 75 | ytick.direction: out 76 | 77 | ## 
*************************************************************************** 78 | ## * LEGEND * 79 | ## *************************************************************************** 80 | 81 | legend.framealpha: 0.5 82 | legend.frameon: False # do not draw on background patch 83 | legend.fancybox: False # do not round corners 84 | 85 | legend.numpoints: 1 86 | legend.scatterpoints: 1 87 | 88 | legend.fontsize: large 89 | -------------------------------------------------------------------------------- /src/arviz_plots/styles/arviz-vibrant.yml: -------------------------------------------------------------------------------- 1 | attrs: 2 | Plot: 3 | background_fill_color: white 4 | border_fill_color: white 5 | outline_line_width: 0 6 | outline_line_color: null 7 | Axis: 8 | major_tick_line_color: '#262626' 9 | minor_tick_line_alpha: 0 10 | Grid: 11 | grid_line_color: null 12 | Title: 13 | text_color: 'black' 14 | text_font_style: 'bold' 15 | align: 'center' 16 | Text: 17 | text_color: '#262626' 18 | text_font_style: 'bold' 19 | Cycler: 20 | colors : [ 21 | '#008b92', '#f15c58', '#48cdef', '#98d81a', '#997ee5', 22 | '#f5dc9d', '#c90a4e', '#145393', '#323232', '#616161', 23 | ] -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | """ArviZ plots test module.""" 2 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=redefined-outer-name 2 | """Configuration for test suite.""" 3 | import logging 4 | import os 5 | 6 | import pytest 7 | from arviz_base.testing import cmp as _cmp 8 | from arviz_base.testing import datatree as _datatree 9 | from arviz_base.testing import datatree2 as _datatree2 10 | from arviz_base.testing import datatree3 as _datatree3 11 | from arviz_base.testing import datatree_4d 
as _datatree_4d 12 | from arviz_base.testing import datatree_binary as _datatree_binary 13 | from arviz_base.testing import datatree_sample as _datatree_sample 14 | from hypothesis import settings 15 | 16 | _log = logging.getLogger("arviz_plots") 17 | 18 | settings.register_profile("fast", deadline=3000, max_examples=20) 19 | settings.register_profile("chron", deadline=3000, max_examples=500) 20 | settings.load_profile(os.getenv("HYPOTHESIS_PROFILE", "fast")) 21 | 22 | 23 | def pytest_addoption(parser): 24 | """Definition for command line option to save figures from tests or skip backends.""" 25 | parser.addoption("--save", nargs="?", const="test_images", help="Save images rendered by plot") 26 | parser.addoption("--skip-mpl", action="store_const", const=True, help="Skip matplotlib tests") 27 | parser.addoption("--skip-bokeh", action="store_const", const=True, help="Skip bokeh tests") 28 | parser.addoption("--skip-plotly", action="store_const", const=True, help="Skip plotly tests") 29 | 30 | 31 | @pytest.fixture(scope="session") 32 | def save_figs(request): 33 | """Enable command line switch for saving generation figures upon testing.""" 34 | fig_dir = request.config.getoption("--save") 35 | 36 | if fig_dir is not None: 37 | # Try creating directory if it doesn't exist 38 | _log.info("Saving generated images in %s", fig_dir) 39 | 40 | os.makedirs(fig_dir, exist_ok=True) 41 | _log.info("Directory %s created", fig_dir) 42 | 43 | # Clear all files from the directory 44 | # Does not alter or delete directories 45 | for file in os.listdir(fig_dir): 46 | full_path = os.path.join(fig_dir, file) 47 | 48 | try: 49 | os.remove(full_path) 50 | 51 | except OSError: 52 | _log.info("Failed to remove %s", full_path) 53 | 54 | return fig_dir 55 | 56 | 57 | @pytest.fixture(scope="function") 58 | def clean_plots(request, save_figs): 59 | """Close plots after each test, saving too if --save is specified during test invocation.""" 60 | 61 | def fin(): 62 | if ("backend" in 
request.fixturenames) and any( 63 | "matplotlib" in key for key in request.keywords.keys() 64 | ): 65 | import matplotlib.pyplot as plt 66 | 67 | if save_figs is not None: 68 | plt.savefig(f"{os.path.join(save_figs, request.node.name)}.png") 69 | plt.close("all") 70 | 71 | request.addfinalizer(fin) 72 | 73 | 74 | @pytest.fixture(scope="function") 75 | def check_skips(request): 76 | """Skip bokeh or matplotlib tests if requested via command line.""" 77 | skip_mpl = request.config.getoption("--skip-mpl") 78 | skip_bokeh = request.config.getoption("--skip-bokeh") 79 | skip_plotly = request.config.getoption("--skip-plotly") 80 | 81 | if "backend" in request.fixturenames: 82 | if skip_mpl and any("matplotlib" in key for key in request.keywords.keys()): 83 | pytest.skip(reason="Requested skipping matplolib tests via command line argument") 84 | if skip_bokeh and any("bokeh" in key for key in request.keywords.keys()): 85 | pytest.skip(reason="Requested skipping bokeh tests via command line argument") 86 | if skip_plotly and any("plotly" in key for key in request.keywords.keys()): 87 | pytest.skip(reason="Requested skipping plotly tests via command line argument") 88 | 89 | 90 | @pytest.fixture(scope="function") 91 | def no_artist_kwargs(monkeypatch): 92 | """Raise an error if visual kwargs are present when using 'none' backend.""" 93 | monkeypatch.setattr("arviz_plots.backend.none.ALLOW_KWARGS", False) 94 | 95 | 96 | @pytest.fixture(scope="session") 97 | def datatree(): 98 | """Fixture for a general DataTree.""" 99 | return _datatree() 100 | 101 | 102 | @pytest.fixture(scope="session") 103 | def datatree2(): 104 | """Fixture for a DataTree with a posterior and sample stats.""" 105 | return _datatree2() 106 | 107 | 108 | @pytest.fixture(scope="session") 109 | def datatree3(): 110 | """Fixture for a DataTree with discrete data.""" 111 | return _datatree3() 112 | 113 | 114 | @pytest.fixture(scope="session") 115 | def datatree_4d(): 116 | """Fixture for a DataTree with a 4D 
posterior.""" 117 | return _datatree_4d() 118 | 119 | 120 | @pytest.fixture(scope="session") 121 | def datatree_binary(): 122 | """Fixture for a DataTree with binary data.""" 123 | return _datatree_binary() 124 | 125 | 126 | @pytest.fixture(scope="session") 127 | def datatree_sample(): 128 | """Fixture for a DataTree with sample stats.""" 129 | return _datatree_sample() 130 | 131 | 132 | @pytest.fixture(scope="session") 133 | def cmp(): 134 | """Fixture for the cmp function.""" 135 | return _cmp() 136 | -------------------------------------------------------------------------------- /tests/test_fixtures.py: -------------------------------------------------------------------------------- 1 | """Test fixtures.""" 2 | from importlib import import_module 3 | 4 | import pytest 5 | 6 | 7 | @pytest.mark.usefixtures("no_artist_kwargs") 8 | def test_no_artist_kwargs_fixture(): 9 | none_backend = import_module("arviz_plots.backend.none") 10 | with pytest.raises(ValueError): 11 | none_backend.line([1, 2], [0, 1], [], extra_kwarg="yes") 12 | -------------------------------------------------------------------------------- /tests/test_plot_matrix.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=no-self-use, redefined-outer-name 2 | """Test PlotMatrix.""" 3 | import numpy as np 4 | import pytest 5 | import xarray as xr 6 | from arviz_base import dict_to_dataset 7 | 8 | from arviz_plots import PlotMatrix 9 | from arviz_plots.plot_matrix import subset_matrix_da 10 | 11 | 12 | @pytest.fixture(scope="module") 13 | def dataset(seed=31): 14 | rng = np.random.default_rng(seed) 15 | mu = rng.normal(size=(2, 7)) 16 | theta = rng.normal(size=(2, 7, 2)) 17 | eta = rng.normal(size=(2, 7, 3, 2)) 18 | 19 | return dict_to_dataset( 20 | {"mu": mu, "theta": theta, "eta": eta}, 21 | dims={"theta": ["hierarchy"], "eta": ["group", "hierarchy"]}, 22 | ) 23 | 24 | 25 | @pytest.fixture(scope="module") 26 | def matrix_da(seed=31): 27 | rng = 
np.random.default_rng(seed) 28 | var_coord = ["mu", "theta", "theta", "theta", "tau"] 29 | hierarchy_coord = [None, "a", "b", "c", None] 30 | return xr.DataArray( 31 | rng.normal(size=(5, 5)), 32 | dims=["row_index", "col_index"], 33 | coords={ 34 | "var_name_x": (("col_index",), var_coord), 35 | "var_name_y": (("row_index",), var_coord), 36 | "hierarchy_x": (("col_index",), hierarchy_coord), 37 | "hierarchy_y": (("row_index",), hierarchy_coord), 38 | }, 39 | ) 40 | 41 | 42 | @pytest.mark.parametrize("subset", (["mu", {}], ["theta", {"hierarchy": "b"}])) 43 | def test_subset_matrix_da_diag(matrix_da, subset): 44 | da_subset = subset_matrix_da(matrix_da, subset[0], subset[1]) 45 | assert not isinstance(da_subset, xr.DataArray) 46 | 47 | 48 | @pytest.mark.parametrize("subset_x", (["mu", {}], ["theta", {"hierarchy": "b"}])) 49 | @pytest.mark.parametrize("subset_y", (["tau", {}], ["theta", {"hierarchy": "c"}])) 50 | def test_subset_matrix_da_offdiag(matrix_da, subset_x, subset_y): 51 | da_subset = subset_matrix_da( 52 | matrix_da, subset_x[0], subset_x[1], var_name_y=subset_y[0], selection_y=subset_y[1] 53 | ) 54 | assert not isinstance(da_subset, xr.DataArray) 55 | 56 | 57 | def test_plot_matrix_init(dataset): 58 | pc = PlotMatrix(dataset, ["__variable__", "hierarchy", "group"], backend="none") 59 | assert "plot" in pc.viz.data_vars 60 | coord_names = ("var_name_x", "var_name_y", "hierarchy_x", "hierarchy_y", "group_x", "group_y") 61 | missing_coord_names = [name for name in coord_names if name not in pc.viz["plot"].coords] 62 | assert not missing_coord_names, list(pc.viz["plot"].coords) 63 | assert pc.viz["plot"].sizes == {"row_index": 9, "col_index": 9} 64 | 65 | 66 | def test_plot_matrix_aes(dataset): 67 | pc = PlotMatrix( 68 | dataset, ["__variable__", "hierarchy", "group"], backend="none", aes={"color": ["chain"]} 69 | ) 70 | assert "/color" in pc.aes.groups 71 | assert "mapping" in pc.aes["color"].data_vars 72 | assert "neutral_element" not in 
pc.aes["color"].data_vars 73 | 74 | 75 | # pylint: disable=unused-argument 76 | def map_auxiliar(da, target, target_list, kwarg_list, **kwargs): 77 | target_list.append(target) 78 | kwarg_list.append(kwargs) 79 | return 1 80 | 81 | 82 | # pylint: disable=unused-argument 83 | def map_auxiliar_couple(da_x, da_y, target, target_list, kwarg_list, **kwargs): 84 | target_list.append(target) 85 | kwarg_list.append(kwargs) 86 | return 1 87 | 88 | 89 | def test_plot_matrix_map(dataset): 90 | pc = PlotMatrix( 91 | dataset, ["__variable__", "hierarchy", "group"], backend="none", aes={"color": ["chain"]} 92 | ) 93 | target_list = [] 94 | kwarg_list = [] 95 | pc.map( 96 | map_auxiliar, 97 | "aux", 98 | target_list=target_list, 99 | kwarg_list=kwarg_list, 100 | ) 101 | assert all(len(aux_list) == 9 * 2 for aux_list in (target_list, kwarg_list)) 102 | assert pc.viz["aux"].dims == ("row_index", "col_index", "chain") 103 | for i in range(9): 104 | for j in range(9): 105 | if i == j: 106 | assert all( 107 | elem is not None for elem in pc.viz["aux"].sel(row_index=i, col_index=j).values 108 | ) 109 | else: 110 | assert all( 111 | elem is None for elem in pc.viz["aux"].sel(row_index=i, col_index=j).values 112 | ) 113 | 114 | 115 | @pytest.mark.parametrize("triangle", ("both", "lower", "upper")) 116 | def test_plot_matrix_map_triangle(dataset, triangle): 117 | pc = PlotMatrix( 118 | dataset, ["__variable__", "hierarchy", "group"], backend="none", aes={"color": ["chain"]} 119 | ) 120 | target_list = [] 121 | kwarg_list = [] 122 | pc.map_triangle( 123 | map_auxiliar_couple, 124 | "aux", 125 | target_list=target_list, 126 | kwarg_list=kwarg_list, 127 | triangle=triangle, 128 | ) 129 | aux_len = sum(range(9)) * 2 130 | if triangle == "both": 131 | aux_len *= 2 132 | assert all(len(aux_list) == aux_len for aux_list in (target_list, kwarg_list)) 133 | assert pc.viz["aux"].dims == ("row_index", "col_index", "chain") 134 | for i in range(9): 135 | for j in range(9): 136 | is_none = (elem is 
None for elem in pc.viz["aux"].sel(row_index=i, col_index=j).values) 137 | if i == j: 138 | assert all(is_none) 139 | elif i > j: 140 | if triangle in ("both", "lower"): 141 | assert not any(is_none) 142 | else: 143 | assert all(is_none) 144 | else: 145 | if triangle in ("both", "upper"): 146 | assert not any(is_none) 147 | else: 148 | assert all(is_none) 149 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | envlist = 3 | check 4 | docs 5 | {py311,py312,py313}{,-coverage} 6 | # See https://tox.readthedocs.io/en/latest/example/package.html#flit 7 | isolated_build = True 8 | isolated_build_env = build 9 | 10 | [gh-actions] 11 | python = 12 | 3.11: py311-coverage 13 | 3.12: check, py312-coverage 14 | 3.13: py313-coverage 15 | 16 | [testenv] 17 | basepython = 18 | py311: python3.11 19 | py312: python3.12 20 | py313: python3.13 21 | # See https://github.com/tox-dev/tox/issues/1548 22 | {check,docs,cleandocs,viewdocs,build}: python3 23 | setenv = 24 | PYTHONUNBUFFERED = yes 25 | PYTEST_EXTRA_ARGS = -s 26 | coverage: PYTEST_EXTRA_ARGS = --cov --cov-report xml --cov-report term 27 | passenv = 28 | * 29 | extras = 30 | test 31 | matplotlib 32 | bokeh 33 | plotly 34 | commands = 35 | pytest {env:PYTEST_MARKERS:} {env:PYTEST_EXTRA_ARGS:} {posargs:-vv} 36 | 37 | [testenv:check] 38 | description = perform style checks 39 | deps = 40 | build 41 | pre-commit 42 | pylint 43 | skip_install = true 44 | commands = 45 | pre-commit install 46 | pre-commit run --all-files --show-diff-on-failure 47 | python -m build 48 | 49 | [testenv:docs] 50 | description = build HTML docs 51 | setenv = 52 | READTHEDOCS_PROJECT = arviz_plots 53 | READTHEDOCS_VERSION = latest 54 | extras = 55 | doc 56 | matplotlib 57 | bokeh 58 | plotly 59 | commands = 60 | sphinx-build -d "{toxworkdir}/docs_doctree" docs/source "{toxworkdir}/docs_out" --color -v -bhtml 61 | 62 | 
[testenv:nogallerydocs] 63 | description = build HTML docs without the example gallery 64 | setenv = 65 | READTHEDOCS_PROJECT = arviz_plots 66 | READTHEDOCS_VERSION = latest 67 | ARVIZDOCS_NOGALLERY = true 68 | extras = 69 | doc 70 | matplotlib 71 | bokeh 72 | plotly 73 | commands = 74 | sphinx-build -d "{toxworkdir}/docs_doctree" docs/source "{toxworkdir}/docs_out" --color -v -bhtml 75 | 76 | [testenv:cleandocs] 77 | description = remove generated HTML docs and intermediate build files 78 | skip_install = true 79 | allowlist_externals = 80 | rm 81 | find 82 | commands = 83 | find docs/source/gallery -maxdepth 1 -type f -name '*.md' -delete 84 | find docs/source/api/backend -maxdepth 1 -type f -name '*.rst' ! -name '*.part.rst' ! -name 'index.rst' ! -name '*.template.rst' -delete 85 | rm -r "{toxworkdir}/docs_out" "{toxworkdir}/docs_doctree" "{toxworkdir}/jupyter_execute" "{toxworkdir}/plot_directive" 86 | rm -r docs/source/api/generated docs/source/api/backend/generated docs/source/gallery/_images docs/source/gallery/_scripts 87 | rm docs/source/gallery/backreferences.json 88 | 89 | [testenv:viewdocs] 90 | description = open HTML docs 91 | skip_install = true 92 | commands = 93 | python -m webbrowser "{toxworkdir}/docs_out/index.html" 94 | --------------------------------------------------------------------------------