├── .editorconfig ├── .gitignore ├── .idea ├── .gitignore ├── ipme.iml ├── modules.xml └── vcs.xml ├── .travis.yml ├── AUTHORS.rst ├── CONTRIBUTING.rst ├── HISTORY.rst ├── LICENSE ├── MANIFEST.in ├── Makefile ├── README.md ├── examples ├── coal_mining_disasters │ ├── ipme.py │ └── model.py ├── data │ └── evaluation_sleepstudy.csv ├── drivers_reaction_times │ ├── ipme.py │ ├── model_hierarchical.py │ └── model_pooled.py ├── eight_schools_problem │ ├── ipme.py │ ├── model_centered.py │ └── model_non_centered.py ├── golf_putting │ ├── ipme.py │ ├── model_geometry.py │ └── model_simple.py ├── radon_basement │ ├── ipme.py │ └── model.py ├── stochastic_volatility │ ├── ipme.py │ └── model.py └── user_study │ ├── min_temperature │ ├── ipme.py │ └── model.py │ ├── random_number_generator │ ├── ipme.py │ └── model.py │ └── reaction_times │ ├── ipme.py │ └── model.py ├── ipme ├── __init__.py ├── classes │ ├── __init__.py │ ├── cell │ │ ├── __init__.py │ │ ├── interactive_continuous_cell.py │ │ ├── interactive_discrete_cell.py │ │ ├── interactive_pred_ckeck_cell.py │ │ ├── interactive_scatter_cell.py │ │ ├── static_continuous_cell.py │ │ ├── static_discrete_cell.py │ │ ├── static_pred_check_cell.py │ │ ├── static_scatter_cell.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── cell_clear_selection.py │ │ │ ├── cell_continuous_handler.py │ │ │ ├── cell_discrete_handler.py │ │ │ ├── cell_pred_check_handler.py │ │ │ ├── cell_scatter_handler.py │ │ │ ├── cell_widgets.py │ │ │ └── global_reset.py │ ├── data │ │ ├── __init__.py │ │ ├── data.py │ │ └── dimension.py │ ├── graph.py │ ├── grid │ │ ├── __init__.py │ │ ├── graph_grid.py │ │ ├── predictive_ckecks_grid.py │ │ └── scatter_matrix_grid.py │ ├── interaction_control │ │ ├── __init__.py │ │ └── interaction_control.py │ └── scatter_matrix.py ├── cli.py ├── interfaces │ ├── __init__.py │ ├── cell.py │ ├── data_interface.py │ ├── grid.py │ ├── predictive_check_cell.py │ ├── scatter_cell.py │ └── variable_cell.py ├── methods.py └── utils │ ├── __init__.py │ ├── constants.py │ ├── functions.py │ ├── js_code.py │ └── stats.py ├── requirements_dev.txt ├── setup.cfg ├── setup.py └── tox.ini /.editorconfig: -------------------------------------------------------------------------------- 1 | # http://editorconfig.org 2 | 3 | root = true 4 | 5 | [*] 6 | indent_style = space 7 | indent_size = 4 8 | trim_trailing_whitespace = true 9 | insert_final_newline = true 10 | charset = utf-8 11 | end_of_line = lf 12 | 13 | [*.bat] 14 | indent_style = tab 15 | end_of_line = crlf 16 | 17 | [LICENSE] 18 | insert_final_newline = false 19 | 20 | [Makefile] 21 | indent_style = tab 22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | 58 | # Flask stuff: 59 | instance/ 60 | .webassets-cache 61 | 62 | # Scrapy stuff: 63 | .scrapy 64 | 65 | # Sphinx documentation 66 | docs/_build/ 67 | 68 | # PyBuilder 69 | target/ 70 | 71 | # Jupyter Notebook 72 | .ipynb_checkpoints 73 | 74 | # pyenv 75 | .python-version 76 | 77 | # celery beat schedule file 78 | celerybeat-schedule 79 | 80 | # SageMath parsed files 81 | *.sage.py 82 | 83 | # dotenv 84 | .env 85 | 86 | # virtualenv 87 | .venv 88 | venv/ 89 | ENV/ 90 | 91 | # Spyder project settings 92 | .spyderproject 93 | .spyproject 94 | 95 | # Rope project settings 96 | .ropeproject 97 | 98 | # mkdocs documentation 99 | /site 100 | 101 | # mypy 102 | .mypy_cache/ 103 | 104 | # IDE settings 105 | .vscode/ -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /workspace.xml -------------------------------------------------------------------------------- /.idea/ipme.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | # Config file for automatic testing at travis-ci.com 2 | 3 | language: python 4 | python: 5 | - 3.8 6 | - 3.7 7 | - 3.6 8 | - 3.5 9 | 10 | # Command to install dependencies, e.g. pip install -r requirements.txt --use-mirrors 11 | install: pip install -U tox-travis 12 | 13 | # Command to run tests, e.g. python setup.py test 14 | script: tox 15 | 16 | 17 | -------------------------------------------------------------------------------- /AUTHORS.rst: -------------------------------------------------------------------------------- 1 | ======= 2 | Credits 3 | ======= 4 | 5 | Development Lead 6 | ---------------- 7 | 8 | * Evdoxia Taka 9 | 10 | Contributors 11 | ------------ 12 | 13 | None yet. Why not be the first? 14 | -------------------------------------------------------------------------------- /CONTRIBUTING.rst: -------------------------------------------------------------------------------- 1 | .. highlight:: shell 2 | 3 | ============ 4 | Contributing 5 | ============ 6 | 7 | Contributions are welcome, and they are greatly appreciated! Every little bit 8 | helps, and credit will always be given. 9 | 10 | You can contribute in many ways: 11 | 12 | Types of Contributions 13 | ---------------------- 14 | 15 | Report Bugs 16 | ~~~~~~~~~~~ 17 | 18 | Report bugs at https://github.com/evdoxiataka/ipme/issues. 
19 | 20 | If you are reporting a bug, please include: 21 | 22 | * Your operating system name and version. 23 | * Any details about your local setup that might be helpful in troubleshooting. 24 | * Detailed steps to reproduce the bug. 25 | 26 | Fix Bugs 27 | ~~~~~~~~ 28 | 29 | Look through the GitHub issues for bugs. Anything tagged with "bug" and "help 30 | wanted" is open to whoever wants to implement it. 31 | 32 | Implement Features 33 | ~~~~~~~~~~~~~~~~~~ 34 | 35 | Look through the GitHub issues for features. Anything tagged with "enhancement" 36 | and "help wanted" is open to whoever wants to implement it. 37 | 38 | Write Documentation 39 | ~~~~~~~~~~~~~~~~~~~ 40 | 41 | ipme could always use more documentation, whether as part of the 42 | official ipme docs, in docstrings, or even on the web in blog posts, 43 | articles, and such. 44 | 45 | Submit Feedback 46 | ~~~~~~~~~~~~~~~ 47 | 48 | The best way to send feedback is to file an issue at https://github.com/evdoxiataka/ipme/issues. 49 | 50 | If you are proposing a feature: 51 | 52 | * Explain in detail how it would work. 53 | * Keep the scope as narrow as possible, to make it easier to implement. 54 | * Remember that this is a volunteer-driven project, and that contributions 55 | are welcome :) 56 | 57 | Get Started! 58 | ------------ 59 | 60 | Ready to contribute? Here's how to set up `ipme` for local development. 61 | 62 | 1. Fork the `ipme` repo on GitHub. 63 | 2. Clone your fork locally:: 64 | 65 | $ git clone git@github.com:your_name_here/ipme.git 66 | 67 | 3. Install your local copy into a virtualenv. Assuming you have virtualenvwrapper installed, this is how you set up your fork for local development:: 68 | 69 | $ mkvirtualenv ipme 70 | $ cd ipme/ 71 | $ python setup.py develop 72 | 73 | 4. Create a branch for local development:: 74 | 75 | $ git checkout -b name-of-your-bugfix-or-feature 76 | 77 | Now you can make your changes locally. 78 | 79 | 5. When you're done making changes, check that your changes pass flake8 and the 80 | tests, including testing other Python versions with tox:: 81 | 82 | $ flake8 ipme tests 83 | $ python setup.py test or pytest 84 | $ tox 85 | 86 | To get flake8 and tox, just pip install them into your virtualenv. 87 | 88 | 6. Commit your changes and push your branch to GitHub:: 89 | 90 | $ git add . 91 | $ git commit -m "Your detailed description of your changes." 92 | $ git push origin name-of-your-bugfix-or-feature 93 | 94 | 7. Submit a pull request through the GitHub website. 95 | 96 | Pull Request Guidelines 97 | ----------------------- 98 | 99 | Before you submit a pull request, check that it meets these guidelines: 100 | 101 | 1. The pull request should include tests. 102 | 2. If the pull request adds functionality, the docs should be updated. Put 103 | your new functionality into a function with a docstring, and add the 104 | feature to the list in README.rst. 105 | 3. The pull request should work for Python 3.5, 3.6, 3.7 and 3.8, and for PyPy. Check 106 | https://travis-ci.com/evdoxiataka/ipme/pull_requests 107 | and make sure that the tests pass for all supported Python versions. 108 | 109 | Tips 110 | ---- 111 | 112 | To run a subset of tests:: 113 | 114 | 115 | $ python -m unittest tests.test_imd 116 | 117 | Deploying 118 | --------- 119 | 120 | A reminder for the maintainers on how to deploy. 121 | Make sure all your changes are committed (including an entry in HISTORY.rst). 
122 | Then run:: 123 | 124 | $ bump2version patch # possible: major / minor / patch 125 | $ git push 126 | $ git push --tags 127 | 128 | Travis will then deploy to PyPI if tests pass. 129 | -------------------------------------------------------------------------------- /HISTORY.rst: -------------------------------------------------------------------------------- 1 | ======= 2 | History 3 | ======= 4 | 5 | 0.1.0 (2020-05-31) 6 | ------------------ 7 | 8 | * First release on PyPI. 9 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020, Evdoxia Taka 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include AUTHORS.rst 2 | include CONTRIBUTING.rst 3 | include HISTORY.rst 4 | include LICENSE 5 | include README.rst 6 | 7 | recursive-include tests * 8 | recursive-exclude * __pycache__ 9 | recursive-exclude * *.py[co] 10 | 11 | recursive-include docs *.rst conf.py Makefile make.bat *.jpg *.png *.gif 12 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: clean clean-test clean-pyc clean-build docs help 2 | .DEFAULT_GOAL := help 3 | 4 | define BROWSER_PYSCRIPT 5 | import os, webbrowser, sys 6 | 7 | from urllib.request import pathname2url 8 | 9 | webbrowser.open("file://" + pathname2url(os.path.abspath(sys.argv[1]))) 10 | endef 11 | export BROWSER_PYSCRIPT 12 | 13 | define PRINT_HELP_PYSCRIPT 14 | import re, sys 15 | 16 | for line in sys.stdin: 17 | match = re.match(r'^([a-zA-Z_-]+):.*?## (.*)$$', line) 18 | if match: 19 | target, help = match.groups() 20 | print("%-20s %s" % (target, help)) 21 | endef 22 | export PRINT_HELP_PYSCRIPT 23 | 24 | BROWSER := python -c "$$BROWSER_PYSCRIPT" 25 | 26 | help: 27 | @python -c "$$PRINT_HELP_PYSCRIPT" < $(MAKEFILE_LIST) 28 | 29 | clean: clean-build clean-pyc clean-test ## remove all build, test, coverage and Python artifacts 30 | 31 | clean-build: ## remove build artifacts 32 | rm -fr build/ 33 | rm -fr dist/ 34 | rm -fr .eggs/ 35 | find . -name '*.egg-info' -exec rm -fr {} + 36 | find . 
-name '*.egg' -exec rm -f {} + 37 | 38 | clean-pyc: ## remove Python file artifacts 39 | find . -name '*.pyc' -exec rm -f {} + 40 | find . -name '*.pyo' -exec rm -f {} + 41 | find . -name '*~' -exec rm -f {} + 42 | find . -name '__pycache__' -exec rm -fr {} + 43 | 44 | clean-test: ## remove test and coverage artifacts 45 | rm -fr .tox/ 46 | rm -f .coverage 47 | rm -fr htmlcov/ 48 | rm -fr .pytest_cache 49 | 50 | lint: ## check style with flake8 51 | flake8 ipme tests 52 | 53 | test: ## run tests quickly with the default Python 54 | python setup.py test 55 | 56 | test-all: ## run tests on every Python version with tox 57 | tox 58 | 59 | coverage: ## check code coverage quickly with the default Python 60 | coverage run --source ipme setup.py test 61 | coverage report -m 62 | coverage html 63 | $(BROWSER) htmlcov/index.html 64 | 65 | docs: ## generate Sphinx HTML documentation, including API docs 66 | rm -f docs/ipme.rst 67 | rm -f docs/modules.rst 68 | sphinx-apidoc -o docs/ ipme 69 | $(MAKE) -C docs clean 70 | $(MAKE) -C docs html 71 | $(BROWSER) docs/_build/html/index.html 72 | 73 | servedocs: docs ## compile the docs watching for changes 74 | watchmedo shell-command -p '*.rst' -c '$(MAKE) -C docs html' -R -D . 75 | 76 | release: dist ## package and upload a release 77 | twine upload dist/* 78 | 79 | dist: clean ## builds source and wheel package 80 | python setup.py sdist 81 | python setup.py bdist_wheel 82 | ls -l dist 83 | 84 | install: clean ## install the package to the active Python's site-packages 85 | python setup.py install 86 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Interactive Visualizations of Probabilistic Models 2 | This package provides interactive visualizations of probabilistic models' inherent uncertainty in the prior and posterior sample space. 3 | 4 | These visualizations are intended for interactive exploration of the *uncertainty* in Bayesian probabilistic models and enhancement of their interpretability. 5 | 6 | ## Requirements 7 | * Probabilistic models are expressed in a Probabilistic Programming Language (PPL), and 8 | * A sample-based inference algorithm is used for the inference (e.g. MCMC). 9 | 10 | ## Input 11 | The visualizations provided by this package take as input a zip file in the *.npz* format. This file contains a description of the model's structure and the inference results (MCMC samples) 12 | in a standardized form of a collection of npy arrays and metadata (in a JSON format). 13 | The [arviz_json](https://github.com/johnhw/arviz_json) package creates this standardized output of probabilistic models expressed in *PyMC3*. For more details about the standardization of the IPME input see *Taka et al. 2020*. 14 | 15 | The following figure presents the pipeline for the automatic transformation of a probabilistic model expressed in a PPL and its sample-based inference results into the standardized .npz file. 16 | 17 | ![method](https://user-images.githubusercontent.com/37831445/97790524-20ed6900-1bc1-11eb-950c-838ea67b4163.jpg) 18 | 19 | ## Example of a Probabilistic Model 20 | The following probabilistic statements describe the drivers' reaction times model under sleep-deprivation conditions for 10 consecutive days and 18 lorry drivers. This is a hierarchical linear regression probabilistic model. The problem and data for this model were retrieved from *Belenky et al. 2003*.
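The figure below gives the full probabilistic statements. As a quick orientation, here is a condensed sketch of how this model is defined and exported, taken from `examples/drivers_reaction_times/model_hierarchical.py` (the CSV path is abbreviated here for readability; the example script uses `../data/evaluation_sleepstudy.csv`):

```python
import pandas as pd
import pymc3 as pm
import arviz as az
from arviz_json import get_dag, arviz_to_json

# Data from Belenky et al. (2003): reaction time per driver per day of sleep deprivation
reactions = pd.read_csv("evaluation_sleepstudy.csv", usecols=["Reaction", "Days", "Subject"])
driver_idx, drivers = pd.factorize(reactions["Subject"], sort=True)
day_idx, days = pd.factorize(reactions["Days"], sort=True)

coords = {"driver": drivers, "driver_idx_day": reactions.Subject}
with pm.Model(coords=coords) as hierarchical_model:
    # hyper-priors shared across drivers
    mu_a = pm.Normal("mu_a", mu=100, sd=250)
    sigma_a = pm.HalfNormal("sigma_a", sd=250)
    mu_b = pm.Normal("mu_b", mu=10, sd=250)
    sigma_b = pm.HalfNormal("sigma_b", sd=250)
    sigma_sigma = pm.HalfNormal("sigma_sigma", sd=200)
    # per-driver intercept, slope and observation noise
    a = pm.Normal("a", mu=mu_a, sd=sigma_a, dims="driver")
    b = pm.Normal("b", mu=mu_b, sd=sigma_b, dims="driver")
    sigma = pm.HalfNormal("sigma", sd=sigma_sigma, dims="driver")
    # likelihood: reaction time changes linearly with the day of sleep deprivation
    y_pred = pm.Normal("y_pred", mu=a[driver_idx] + b[driver_idx] * day_idx,
                       sd=sigma[driver_idx], observed=reactions.Reaction,
                       dims="driver_idx_day")
    # sample-based inference
    trace = pm.sample(draws=2000, chains=2, tune=1000)
    prior = pm.sample_prior_predictive(samples=2000)
    posterior_predictive = pm.sample_posterior_predictive(trace, samples=2000)

# export the inference results and the model graph into the standardized .npz file
data = az.from_pymc3(trace=trace, prior=prior, posterior_predictive=posterior_predictive)
data.sample_stats.attrs["graph"] = str(get_dag(hierarchical_model))
arviz_to_json(data, "reaction_times_hierarchical.npz")
```

The exported `reaction_times_hierarchical.npz` file is what the `ipme.graph()` and `ipme.scatter_matrix()` calls further below load.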
21 | 22 | ![image](https://user-images.githubusercontent.com/37831445/120328046-56d30700-c2e2-11eb-8e4f-05d891e4b2d4.png) 23 | 24 | ## Interactive Probabilistic Models Explorer (IPME) 25 | The *IPME* representation of the model is presented in the video below. See more details about this representation in *Taka et al. 2020*. A demo of this visualization can be found in my [talk](https://www.youtube.com/watch?v=2hadiSJRAJI&feature=youtu.be) at PyMCon 2020. 26 | ```python 27 | import ipme 28 | """ 29 | mode: String in {'i','s'} for interactive or static 30 | vars: 'all' or List of variable names e.g. ['a','b'] 31 | spaces: String in {'all','prior','posterior'} or List of spaces e.g. ['prior','posterior'] 32 | predictive_checks: List of observed variable names 33 | """ 34 | ipme.graph("reaction_times_hierarchical.npz", mode = "i", vars = 'all', spaces = 'all', predictive_checks = ['y_pred']) 35 | ``` 36 | 37 | 38 | 39 | 40 | https://user-images.githubusercontent.com/37831445/205636036-ec1a6820-f368-4332-9fcf-0a508c8dd63f.mp4 41 | 42 | 43 | ## Interactive Pair Plot (IPP) 44 | The *IPP* representation of the model is presented in the video below. This visualization was introduced and evaluated in *Taka et al. 2022*. 45 | ```python 46 | import ipme 47 | """ 48 | mode: String in {'i','s'} for interactive or static 49 | vars: List of variable names e.g. ['a','b'] 50 | spaces: String in {'all','prior','posterior'} or List of spaces e.g. ['prior','posterior'] 51 | """ 52 | ipme.scatter_matrix('reaction_times_hierarchical.npz', mode = "i", vars = ['sigma_a','sigma_b','sigma_sigma','mu_a','mu_b','sigma','a','b','y_pred'], spaces = 'all') 53 | ``` 54 | 55 | 56 | 57 | https://user-images.githubusercontent.com/37831445/205637701-ac11ff87-c240-4882-8b69-8b0b15d8e344.mp4 58 | 59 | 60 | # Examples 61 | The folder `/examples` in this repository includes some examples of use. The examples illustrate the definition of Bayesian probabilistic models and the running of sample-based inference in PyMC3. The examples are organized per problem. Each problem's directory includes the following Python scripts: 62 | * *`model.py`*: includes the definition of the model in PyMC3, and exports the inference data into a *.npz* file. 63 | * *`ipme.py`*: demonstrates the use of the ipme package for the visualization of the model. 64 | 65 | The folder `/examples/user_study` contains the models used in the user study presented in *Taka et al. 2022*. 66 | 67 | **Note:** To run these scripts, you need to install the following Python libraries: PyMC3, ArviZ, and the arviz_json and ipme packages (the last two can only be installed from GitHub). 68 | 69 | # Please Cite: 70 | *E. Taka, S. Stein, and J. H. Williamson.* Increasing interpretability of Bayesian probabilistic programming models through interactive representations. Frontiers in Computer Science, 2:52, 2020. doi: 10.3389/fcomp.2020.567344. URL: https://www.frontiersin.org/article/10.3389/fcomp.2020.567344 71 | 72 | *E. Taka, S. Stein and J. H. Williamson,* "Does Interactive Conditioning Help Users Better Understand the Structure of Probabilistic Models?," in IEEE Transactions on Visualization and Computer Graphics, doi: 10.1109/TVCG.2022.3231967 73 | 74 | # References 75 | *E. Taka, S. Stein, and J. H. Williamson.* Increasing interpretability of Bayesian probabilistic programming models through interactive representations. Frontiers in Computer Science, 2:52, 2020. doi: 10.3389/fcomp.2020.567344.
URL: https://www.frontiersin.org/article/10.3389/fcomp.2020.567344 76 | 77 | *E. Taka, S. Stein and J. H. Williamson,* "Does Interactive Conditioning Help Users Better Understand the Structure of Probabilistic Models?," in IEEE Transactions on Visualization and Computer Graphics, doi: 10.1109/TVCG.2022.3231967 78 | 79 | *G. Belenky, N. J. Wesensten, D. R. Thorne, M. L. Thomas, H. C. Sing, D. P. Redmond, M. B. Russo, and T. J.Balkin.* Patterns of performance degradation and restoration during sleep restriction and subsequent recovery: a sleep dose-response study. Journal of Sleep Research, vol. 12, no. 1, pp. 1–12, 2003. URL: https://onlinelibrary.wiley.com/doi/abs/10.1046/j.1365-2869.2003.00337.x 80 | 81 | 82 | *The “Closed-Loop Data Science for Complex, Computationally- and Data-Intensive Analytics” project*. URL: https://www.gla.ac.uk/schools/computing/research/researchsections/ida-section/closedloop/ 83 | 84 | *arviz_json* (Automatic Transformation of PyMC3 models into standardized output): https://github.com/johnhw/arviz_json 85 | -------------------------------------------------------------------------------- /examples/coal_mining_disasters/ipme.py: -------------------------------------------------------------------------------- 1 | import ipme 2 | 3 | if __name__=="__main__": 4 | infer_datapath='coal_mining_disasters_PyMC3.npz' 5 | ipme.graph(infer_datapath, predictive_checks = ['disasters']) -------------------------------------------------------------------------------- /examples/coal_mining_disasters/model.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import pymc3 as pm 3 | import numpy as np 4 | import arviz as az 5 | from arviz_json import get_dag, arviz_to_json 6 | 7 | ## Discrete Variables Model 8 | ## Reference: https://docs.pymc.io/notebooks/getting_started.html#Case-study-2:-Coal-mining-disasters 9 | #data 10 | disaster_data = pd.Series([4, 5, 4, 0, 1, 4, 3, 4, 0, 6, 3, 3, 4, 0, 2, 6, 11 | 3, 3, 5, 4, 5, 3, 1, 4, 4, 1, 5, 5, 3, 4, 2, 5, 12 | 2, 2, 3, 4, 2, 1, 3, np.nan, 2, 1, 1, 1, 1, 3, 0, 0, 13 | 1, 0, 1, 1, 0, 0, 3, 1, 0, 3, 2, 2, 0, 1, 1, 1, 14 | 0, 1, 0, 1, 0, 0, 0, 2, 1, 0, 0, 0, 1, 1, 0, 2, 15 | 3, 3, 1, np.nan, 2, 1, 1, 1, 1, 2, 4, 2, 0, 0, 1, 4, 16 | 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1]) 17 | years = np.arange(1851, 1962) 18 | years_missing=[1890,1935] 19 | 20 | #model-inference 21 | fileName='coal_mining_disasters_PyMC3' 22 | samples=10000 23 | tune=10000 24 | chains=2 25 | coords = {"year": years} 26 | with pm.Model(coords=coords) as disaster_model: 27 | switchpoint = pm.DiscreteUniform('switchpoint', lower=years.min(), upper=years.max(), testval=1900) 28 | early_rate = pm.Exponential('early_rate', 1) 29 | late_rate = pm.Exponential('late_rate', 1) 30 | rate = pm.math.switch(switchpoint >= years, early_rate, late_rate) 31 | disasters = pm.Poisson('disasters', rate, observed=disaster_data, dims='year') 32 | #inference 33 | trace = pm.sample(samples, chains=chains, tune=tune) 34 | prior = pm.sample_prior_predictive(samples=samples) 35 | posterior_predictive = pm.sample_posterior_predictive(trace,samples=samples) 36 | 37 | ## STEP 1 38 | # will also capture all the sampler statistics 39 | data = az.from_pymc3(trace=trace, prior=prior, posterior_predictive=posterior_predictive) 40 | 41 | ## STEP 2 42 | # extract dag 43 | dag = get_dag(disaster_model) 44 | # insert dag into sampler stat attributes 45 | data.sample_stats.attrs["graph"] = str(dag) 46 | 47 | ## STEP 3 48 | # save data 49 | 
arviz_to_json(data, fileName+'.npz') 50 | -------------------------------------------------------------------------------- /examples/data/evaluation_sleepstudy.csv: -------------------------------------------------------------------------------- 1 | "","Reaction","Days","Subject" 2 | "1",249.56,0,"308" 3 | "2",258.7047,1,"308" 4 | "3",250.8006,2,"308" 5 | "4",321.4398,3,"308" 6 | "5",356.8519,4,"308" 7 | "6",414.6901,5,"308" 8 | "7",382.2038,6,"308" 9 | "8",290.1486,7,"308" 10 | "9",430.5853,8,"308" 11 | "10",466.3535,9,"308" 12 | "11",222.7339,0,"309" 13 | "12",205.2658,1,"309" 14 | "13",202.9778,2,"309" 15 | "14",204.707,3,"309" 16 | "15",207.7161,4,"309" 17 | "16",215.9618,5,"309" 18 | "17",213.6303,6,"309" 19 | "18",217.7272,7,"309" 20 | "19",224.2957,8,"309" 21 | "20",237.3142,9,"309" 22 | "21",199.0539,0,"310" 23 | "22",194.3322,1,"310" 24 | "23",234.32,2,"310" 25 | "24",232.8416,3,"310" 26 | "25",229.3074,4,"310" 27 | "26",220.4579,5,"310" 28 | "27",235.4208,6,"310" 29 | "28",255.7511,7,"310" 30 | "29",261.0125,8,"310" 31 | "30",247.5153,9,"310" 32 | "31",321.5426,0,"330" 33 | "32",300.4002,1,"330" 34 | "33",283.8565,2,"330" 35 | "34",285.133,3,"330" 36 | "35",285.7973,4,"330" 37 | "36",297.5855,5,"330" 38 | "37",280.2396,6,"330" 39 | "38",318.2613,7,"330" 40 | "39",305.3495,8,"330" 41 | "40",354.0487,9,"330" 42 | "41",287.6079,0,"331" 43 | "42",285,1,"331" 44 | "43",301.8206,2,"331" 45 | "44",320.1153,3,"331" 46 | "45",316.2773,4,"331" 47 | "46",293.3187,5,"331" 48 | "47",290.075,6,"331" 49 | "48",334.8177,7,"331" 50 | "49",293.7469,8,"331" 51 | "50",371.5811,9,"331" 52 | "51",234.8606,0,"332" 53 | "52",242.8118,1,"332" 54 | "53",272.9613,2,"332" 55 | "54",309.7688,3,"332" 56 | "55",317.4629,4,"332" 57 | "56",309.9976,5,"332" 58 | "57",454.1619,6,"332" 59 | "58",346.8311,7,"332" 60 | "59",330.3003,8,"332" 61 | "60",253.8644,9,"332" 62 | "61",283.8424,0,"333" 63 | "62",289.555,1,"333" 64 | "63",276.7693,2,"333" 65 | "64",299.8097,3,"333" 66 | "65",297.171,4,"333" 67 | "66",338.1665,5,"333" 68 | "67",332.0265,6,"333" 69 | "68",348.8399,7,"333" 70 | "69",333.36,8,"333" 71 | "70",362.0428,9,"333" 72 | "71",265.4731,0,"334" 73 | "72",276.2012,1,"334" 74 | "73",243.3647,2,"334" 75 | "74",254.6723,3,"334" 76 | "75",279.0244,4,"334" 77 | "76",284.1912,5,"334" 78 | "77",305.5248,6,"334" 79 | "78",331.5229,7,"334" 80 | "79",335.7469,8,"334" 81 | "80",377.299,9,"334" 82 | "81",241.6083,0,"335" 83 | "82",273.9472,1,"335" 84 | "83",254.4907,2,"335" 85 | "84",270.8021,3,"335" 86 | "85",251.4519,4,"335" 87 | "86",254.6362,5,"335" 88 | "87",245.4523,6,"335" 89 | "88",235.311,7,"335" 90 | "89",235.7541,8,"335" 91 | "90",237.2466,9,"335" 92 | "91",312.3666,0,"337" 93 | "92",313.8058,1,"337" 94 | "93",291.6112,2,"337" 95 | "94",346.1222,3,"337" 96 | "95",365.7324,4,"337" 97 | "96",391.8385,5,"337" 98 | "97",404.2601,6,"337" 99 | "98",416.6923,7,"337" 100 | "99",455.8643,8,"337" 101 | "100",458.9167,9,"337" 102 | "101",236.1032,0,"349" 103 | "102",230.3167,1,"349" 104 | "103",238.9256,2,"349" 105 | "104",254.922,3,"349" 106 | "105",250.7103,4,"349" 107 | "106",269.7744,5,"349" 108 | "107",281.5648,6,"349" 109 | "108",308.102,7,"349" 110 | "109",336.2806,8,"349" 111 | "110",351.6451,9,"349" 112 | "111",256.2968,0,"350" 113 | "112",243.4543,1,"350" 114 | "113",256.2046,2,"350" 115 | "114",255.5271,3,"350" 116 | "115",268.9165,4,"350" 117 | "116",329.7247,5,"350" 118 | "117",379.4445,6,"350" 119 | "118",362.9184,7,"350" 120 | "119",394.4872,8,"350" 121 | "120",389.0527,9,"350" 122 | 
"121",250.5265,0,"351" 123 | "122",300.0576,1,"351" 124 | "123",269.8939,2,"351" 125 | "124",280.5891,3,"351" 126 | "125",271.8274,4,"351" 127 | "126",304.6336,5,"351" 128 | "127",287.7466,6,"351" 129 | "128",266.5955,7,"351" 130 | "129",321.5418,8,"351" 131 | "130",347.5655,9,"351" 132 | "131",221.6771,0,"352" 133 | "132",298.1939,1,"352" 134 | "133",326.8785,2,"352" 135 | "134",346.8555,3,"352" 136 | "135",348.7402,4,"352" 137 | "136",352.8287,5,"352" 138 | "137",354.4266,6,"352" 139 | "138",360.4326,7,"352" 140 | "139",375.6406,8,"352" 141 | "140",388.5417,9,"352" 142 | "141",271.9235,0,"369" 143 | "142",268.4369,1,"369" 144 | "143",257.2424,2,"369" 145 | "144",277.6566,3,"369" 146 | "145",314.8222,4,"369" 147 | "146",317.2135,5,"369" 148 | "147",298.1353,6,"369" 149 | "148",348.1229,7,"369" 150 | "149",340.28,8,"369" 151 | "150",366.5131,9,"369" 152 | "151",225.264,0,"370" 153 | "152",234.5235,1,"370" 154 | "153",238.9008,2,"370" 155 | "154",240.473,3,"370" 156 | "155",267.5373,4,"370" 157 | "156",344.1937,5,"370" 158 | "157",281.1481,6,"370" 159 | "158",347.5855,7,"370" 160 | "159",365.163,8,"370" 161 | "160",372.2288,9,"370" 162 | "161",269.8804,0,"371" 163 | "162",272.4428,1,"371" 164 | "163",277.8989,2,"371" 165 | "164",281.7895,3,"371" 166 | "165",279.1705,4,"371" 167 | "166",284.512,5,"371" 168 | "167",259.2658,6,"371" 169 | "168",304.6306,7,"371" 170 | "169",350.7807,8,"371" 171 | "170",369.4692,9,"371" 172 | "171",269.4117,0,"372" 173 | "172",273.474,1,"372" 174 | "173",297.5968,2,"372" 175 | "174",310.6316,3,"372" 176 | "175",287.1726,4,"372" 177 | "176",329.6076,5,"372" 178 | "177",334.4818,6,"372" 179 | "178",343.2199,7,"372" 180 | "179",369.1417,8,"372" 181 | "180",364.1236,9,"372" 182 | -------------------------------------------------------------------------------- /examples/drivers_reaction_times/ipme.py: -------------------------------------------------------------------------------- 1 | import ipme 2 | 3 | if __name__=="__main__": 4 | infer_datapath='reaction_times_pooled.npz' 5 | # infer_datapath='reaction_times_hierarchical.npz' 6 | ipme.graph(infer_datapath, predictive_checks = ['y_pred']) -------------------------------------------------------------------------------- /examples/drivers_reaction_times/model_hierarchical.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import pymc3 as pm 3 | import arviz as az 4 | from arviz_json import get_dag, arviz_to_json 5 | 6 | ## data 7 | DATAPATH='../data/evaluation_sleepstudy.csv' ## data from Belenky et al. 
(2003) 8 | reactions = pd.read_csv(DATAPATH, usecols=['Reaction','Days','Subject']) 9 | 10 | ## Drivers Reaction Times Hierarchical Model 11 | samples = 2000 12 | chains = 2 13 | tune = 1000 14 | 15 | ## data 16 | driver_idx, drivers = pd.factorize(reactions["Subject"], sort=True) 17 | day_idx, days = pd.factorize(reactions["Days"], sort=True) 18 | 19 | ##dims 20 | coords_h = {"driver": drivers,"driver_idx_day":reactions.Subject} 21 | 22 | with pm.Model(coords=coords_h) as hierarchical_model: 23 | ## model 24 | #hyper-priors 25 | mu_a = pm.Normal('mu_a' ,mu=100, sd=250) 26 | sigma_a = pm.HalfNormal('sigma_a', sd=250) 27 | mu_b = pm.Normal('mu_b',mu=10, sd=250) 28 | sigma_b = pm.HalfNormal('sigma_b', sd=250) 29 | sigma_sigma = pm.HalfNormal('sigma_sigma', sd=200) 30 | #priors 31 | a = pm.Normal("a", mu=mu_a, sd=sigma_a, dims="driver") 32 | b = pm.Normal("b", mu=mu_b, sd=sigma_b, dims="driver") 33 | sigma = pm.HalfNormal("sigma", sd=sigma_sigma, dims="driver") 34 | y_pred = pm.Normal('y_pred', mu=a[driver_idx]+b[driver_idx]*day_idx, sd=sigma[driver_idx], observed=reactions.Reaction, dims="driver_idx_day") 35 | ## inference 36 | trace_hi = pm.sample(draws=samples, chains=chains, tune=tune) 37 | prior_hi = pm.sample_prior_predictive(samples=samples) 38 | posterior_predictive_hi = pm.sample_posterior_predictive(trace_hi, samples=samples) 39 | 40 | ## STEP 1 41 | ## export inference results in ArviZ InferenceData obj 42 | ## will also capture all the sampler statistics 43 | data_hi = az.from_pymc3(trace = trace_hi, prior = prior_hi, posterior_predictive = posterior_predictive_hi) 44 | 45 | ## STEP 2 46 | ## extract dag 47 | dag_hi = get_dag(hierarchical_model) 48 | ## insert dag into sampler stat attributes 49 | data_hi.sample_stats.attrs["graph"] = str(dag_hi) 50 | 51 | ## STEP 3 52 | ## save data 53 | fileName_hi = "reaction_times_hierarchical" 54 | arviz_to_json(data_hi, fileName_hi+'.npz') -------------------------------------------------------------------------------- /examples/drivers_reaction_times/model_pooled.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import pymc3 as pm 3 | import arviz as az 4 | from arviz_json import get_dag, arviz_to_json 5 | 6 | ## data 7 | DATAPATH='../data/evaluation_sleepstudy.csv' ## data from Belenky et al. 
(2003) 8 | reactions = pd.read_csv(DATAPATH,usecols=['Reaction','Days','Subject']) 9 | 10 | ## Drivers Reaction Times Pooled Model 11 | samples=4000 12 | chains=2 13 | tune=2000 14 | 15 | ## data 16 | driver_idx, drivers = pd.factorize(reactions["Subject"], sort=True) 17 | day_idx, days = pd.factorize(reactions["Days"], sort=True) 18 | 19 | ## dims 20 | coords_p = {"driver": drivers,"driver_idx_day":reactions.Subject} 21 | 22 | with pm.Model(coords=coords_p) as fullyPooled_model: 23 | ## model 24 | a = pm.Normal("a", mu=100, sd=250) 25 | b = pm.Normal("b", mu=10, sd=250) 26 | sigma = pm.HalfNormal("sigma", sd=200) 27 | y_pred = pm.Normal('y_pred', mu = a + b*day_idx, sd = sigma, observed = reactions.Reaction, dims = "driver_idx_day") 28 | ## inference 29 | trace_p = pm.sample(samples, chains=chains, tune=tune) 30 | prior_p = pm.sample_prior_predictive(samples=samples) 31 | posterior_predictive_p = pm.sample_posterior_predictive(trace_p, samples=samples) 32 | 33 | ## STEP 1 34 | ## export inference results in ArviZ InferenceData obj 35 | ## will also capture all the sampler statistics 36 | data_p = az.from_pymc3(trace = trace_p, prior = prior_p, posterior_predictive = posterior_predictive_p) 37 | 38 | ## STEP 2 39 | ## extract dag 40 | dag_p = get_dag(fullyPooled_model) 41 | ## insert dag into sampler stat attributes 42 | data_p.sample_stats.attrs["graph"] = str(dag_p) 43 | 44 | ## STEP 3 45 | ## save data 46 | fileName_p = "reaction_times_pooled" 47 | arviz_to_json(data_p, fileName_p+'.npz') -------------------------------------------------------------------------------- /examples/eight_schools_problem/ipme.py: -------------------------------------------------------------------------------- 1 | import ipme 2 | 3 | if __name__=="__main__": 4 | infer_datapath='eight_schools_centered.npz' 5 | #infer_datapath='eight_schools_non_centered.npz' 6 | ipme.graph(infer_datapath, predictive_checks = ['y']) 7 | -------------------------------------------------------------------------------- /examples/eight_schools_problem/model_centered.py: -------------------------------------------------------------------------------- 1 | import pymc3 as pm 2 | import numpy as np 3 | import arviz as az 4 | from arviz_json import get_dag, arviz_to_json 5 | SEED = [20100420, 20134234] 6 | 7 | #Hierarchical Model 8 | #Reference: https://docs.pymc.io/notebooks/Diagnosing_biased_Inference_with_Divergences.html#The-Eight-Schools-Model 9 | #data 10 | J = 8 11 | obs = np.array([28., 8., -3., 7., -1., 1., 18., 12.]) 12 | sigma = np.array([15., 10., 16., 11., 9., 11., 10., 18.]) 13 | 14 | #model-inference 15 | coords_c = {"school": ["A","B","C","D","E","F","G","H"]} 16 | fileName_c="eight_schools_centered" 17 | samples=4000 18 | chains=2 19 | tune=1000 20 | with pm.Model(coords=coords_c) as centered_eight: 21 | mu = pm.Normal('mu', mu=0, sigma=5) 22 | tau = pm.HalfCauchy('tau', beta=5) 23 | theta = pm.Normal('theta', mu=mu, sigma=tau, dims='school') 24 | y = pm.Normal('y', mu=theta, sigma=sigma, observed=obs, dims='school') 25 | #inference 26 | trace_c = pm.sample(samples, chains=chains, tune=tune, random_seed=SEED) 27 | prior_c= pm.sample_prior_predictive(samples=samples) 28 | posterior_predictive_c = pm.sample_posterior_predictive(trace_c, samples=samples) 29 | 30 | ## STEP 1 31 | # will also capture all the sampler statistics 32 | data_c = az.from_pymc3(trace = trace_c, prior = prior_c, posterior_predictive = posterior_predictive_c) 33 | 34 | ## STEP 2 35 | #dag 36 | dag_c = get_dag(centered_eight) 37 | # insert
dag into sampler stat attributes 38 | data_c.sample_stats.attrs["graph"] = str(dag_c) 39 | 40 | ## STEP 3 41 | # save data 42 | arviz_to_json(data_c, fileName_c+'.npz') 43 | -------------------------------------------------------------------------------- /examples/eight_schools_problem/model_non_centered.py: -------------------------------------------------------------------------------- 1 | import pymc3 as pm 2 | import pandas as pd 3 | import numpy as np 4 | import arviz as az 5 | from arviz_json import get_dag, arviz_to_json 6 | SEED = [20100420, 20134234] 7 | 8 | #Hierarchical Model 9 | #Reference: https://docs.pymc.io/notebooks/Diagnosing_biased_Inference_with_Divergences.html#A-Non-Centered-Eight-Schools-Implementation 10 | #data 11 | J = 8 12 | obs = np.array([28., 8., -3., 7., -1., 1., 18., 12.]) 13 | sigma = np.array([15., 10., 16., 11., 9., 11., 10., 18.]) 14 | 15 | #model-inference 16 | coords = {"school": ["A","B","C","D","E","F","G","H"]} 17 | samples=5000 18 | chains=2 19 | tune=1000 20 | fileName="eight_schools_non_centered" 21 | with pm.Model(coords=coords) as NonCentered_eight: 22 | mu = pm.Normal('mu', mu=0, sigma=5) 23 | tau = pm.HalfCauchy('tau', beta=5) 24 | theta_tilde = pm.Normal('theta_t', mu=0, sigma=1, dims='school') 25 | theta = pm.Deterministic('theta', mu + tau * theta_tilde, dims='school') 26 | y = pm.Normal('y', mu=theta, sigma=sigma, observed=obs, dims='school') 27 | #inference 28 | trace_nc = pm.sample(samples, chains=chains, tune=tune, random_seed=SEED, target_accept=.90) 29 | prior_nc= pm.sample_prior_predictive(samples=samples) 30 | posterior_predictive_nc = pm.sample_posterior_predictive(trace_nc,samples=samples) 31 | 32 | ## STEP 1 33 | # will also capture all the sampler statistics 34 | data_nc = az.from_pymc3(trace = trace_nc, prior = prior_nc, posterior_predictive = posterior_predictive_nc) 35 | 36 | ## STEP 2 37 | #dag 38 | dag_nc = get_dag(NonCentered_eight) 39 | # insert dag into sampler stat attributes 40 | data_nc.sample_stats.attrs["graph"] = str(dag_nc) 41 | 42 | ## STEP 3 43 | # save data 44 | arviz_to_json(data_nc, fileName+'.npz') 45 | -------------------------------------------------------------------------------- /examples/golf_putting/ipme.py: -------------------------------------------------------------------------------- 1 | import ipme 2 | 3 | if __name__=="__main__": 4 | infer_datapath='golf_simple_PyMC3.npz' 5 | #infer_datapath='golf_geometry_PyMC3.npz' 6 | ipme.graph(infer_datapath, predictive_checks = ['successes']) 7 | -------------------------------------------------------------------------------- /examples/golf_putting/model_geometry.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import io 3 | import pymc3 as pm 4 | import arviz as az 5 | from arviz_json import get_dag, arviz_to_json 6 | import theano.tensor as tt 7 | 8 | CUP_RADIUS = (4.25 / 2) / 12 9 | BALL_RADIUS = (1.68 / 2) / 12 10 | 11 | def Phi(x): 12 | """Calculates the standard normal cumulative distribution function.""" 13 | return 0.5 + 0.5 * tt.erf(x / tt.sqrt(2.)) 14 | 15 | #Binomial Model 16 | #Reference: https://docs.pymc.io/notebooks/putting_workflow.html#Geometry-based-model 17 | #data 18 | golf_data = """distance tries successes 19 | 2 1443 1346 20 | 3 694 577 21 | 4 455 337 22 | 5 353 208 23 | 6 272 149 24 | 7 256 136 25 | 8 240 111 26 | 9 217 69 27 | 10 200 67 28 | 11 237 75 29 | 12 202 52 30 | 13 192 46 31 | 14 174 54 32 | 15 167 28 33 | 16 201 27 34 | 17 195 31 35 | 18 191 33 36 | 19 147 20 
37 | 20 152 24""" 38 | data = pd.read_csv(io.StringIO(golf_data), sep=" ") 39 | 40 | #model-inference 41 | coords = {"distance": data.distance} 42 | fileName='golf_geometry_PyMC3' 43 | samples=2000 44 | chains=2 45 | tune=1000 46 | geometry_model=pm.Model(coords=coords) 47 | with geometry_model: 48 | #to store the n-parameter of Binomial dist 49 | #in the constant group of ArviZ InferenceData 50 | #You should always call it n for imd to retrieve it 51 | n = pm.Data('n', data.tries) 52 | sigma_angle = pm.HalfNormal('sigma_angle') 53 | p_goes_in = pm.Deterministic('p_goes_in', 2 * Phi(tt.arcsin((CUP_RADIUS - BALL_RADIUS) / data.distance) / sigma_angle) - 1, dims='distance') 54 | successes = pm.Binomial('successes', n=n, p=p_goes_in, observed=data.successes, dims='distance') 55 | #inference 56 | trace_g = pm.sample(draws=samples, chains=chains, tune=tune) 57 | prior_g= pm.sample_prior_predictive(samples=samples) 58 | posterior_predictive_g = pm.sample_posterior_predictive(trace_g,samples=samples) 59 | 60 | ## STEP 1 61 | # will also capture all the sampler statistics 62 | data_g = az.from_pymc3(trace=trace_g, prior=prior_g, posterior_predictive=posterior_predictive_g) 63 | 64 | ## STEP 2 65 | #dag 66 | dag_g = get_dag(geometry_model) 67 | # insert dag into sampler stat attributes 68 | data_g.sample_stats.attrs["graph"] = str(dag_g) 69 | 70 | ## STEP 3 71 | # save data 72 | arviz_to_json(data_g, fileName+'.npz') -------------------------------------------------------------------------------- /examples/golf_putting/model_simple.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import io 3 | import pymc3 as pm 4 | import arviz as az 5 | from arviz_json import get_dag, arviz_to_json 6 | 7 | #Binomial Logistic Regression Model 8 | #Reference: https://docs.pymc.io/notebooks/putting_workflow.html#Logit-model 9 | #data 10 | golf_data = """distance tries successes 11 | 2 1443 1346 12 | 3 694 577 13 | 4 455 337 14 | 5 353 208 15 | 6 272 149 16 | 7 256 136 17 | 8 240 111 18 | 9 217 69 19 | 10 200 67 20 | 11 237 75 21 | 12 202 52 22 | 13 192 46 23 | 14 174 54 24 | 15 167 28 25 | 16 201 27 26 | 17 195 31 27 | 18 191 33 28 | 19 147 20 29 | 20 152 24""" 30 | data = pd.read_csv(io.StringIO(golf_data), sep=" ") 31 | 32 | #model-inference 33 | coords = {"distance": data.distance} 34 | fileName='golf_simple_PyMC3' 35 | samples=2000 36 | chains=2 37 | tune=1000 38 | simple_model=pm.Model(coords=coords) 39 | with simple_model: 40 | #to store the n-parameter of Binomial dist 41 | #in the constant group of ArviZ InferenceData 42 | #You should always call it n for imd to retrieve it 43 | n = pm.Data('n', data.tries) 44 | a = pm.Normal('a') 45 | b = pm.Normal('b') 46 | p_goes_in = pm.Deterministic('p_goes_in', pm.math.invlogit(a * data.distance + b), dims='distance') 47 | successes = pm.Binomial('successes', n=n, p=p_goes_in, observed=data.successes, dims='distance') 48 | #inference 49 | # Get posterior trace, prior trace, posterior predictive samples, and the DAG 50 | trace = pm.sample(draws=samples, chains=chains, tune=tune) 51 | prior= pm.sample_prior_predictive(samples=samples) 52 | posterior_predictive = pm.sample_posterior_predictive(trace,samples=samples) 53 | 54 | ## STEP 1 55 | # will also capture all the sampler statistics 56 | data_s = az.from_pymc3(trace=trace, prior=prior, posterior_predictive=posterior_predictive) 57 | 58 | ## STEP 2 59 | #dag 60 | dag = get_dag(simple_model) 61 | # insert dag into sampler stat attributes 62 | 
data_s.sample_stats.attrs["graph"] = str(dag) 63 | 64 | ## STEP 3 65 | # save data 66 | arviz_to_json(data_s, fileName+'.npz') -------------------------------------------------------------------------------- /examples/radon_basement/ipme.py: -------------------------------------------------------------------------------- 1 | import ipme 2 | 3 | if __name__=="__main__": 4 | infer_datapath='radon_basement_PyMC3.npz' 5 | ipme.graph(infer_datapath, predictive_checks = ['radon']) -------------------------------------------------------------------------------- /examples/radon_basement/model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pymc3 as pm 3 | import pandas as pd 4 | import theano 5 | import arviz as az 6 | from arviz_json import get_dag, arviz_to_json 7 | 8 | #Hierarchical Linear Regression Model 9 | #Reference1: https://docs.pymc.io/notebooks/multilevel_modeling.html 10 | #Reference2: https://docs.pymc.io/notebooks/GLM-hierarchical.html#The-data-set 11 | #data 12 | data_r = pd.read_csv(pm.get_data('radon.csv')) 13 | data_r['log_radon'] = data_r['log_radon'].astype(theano.config.floatX) 14 | county_names = data_r.county.unique() 15 | county_idx = data_r.county_code.values 16 | 17 | n_counties = len(data_r.county.unique()) 18 | 19 | #model-inference 20 | fileName='radon_basement_PyMC3' 21 | samples=3000 22 | tune=10000 23 | chains=2 24 | coords = {"county":county_names, "county_idx_household":data_r["county"].tolist()} 25 | with pm.Model(coords=coords) as model: 26 | # Hyperpriors for group nodes 27 | mu_a = pm.Normal('mu_a', mu=0., sigma=100) 28 | sigma_a = pm.HalfNormal('sigma_a', 5.) 29 | mu_b = pm.Normal('mu_b', mu=0., sigma=100) 30 | sigma_b = pm.HalfNormal('sigma_b', 5.) 31 | 32 | # Intercept for each county, distributed around group mean mu_a 33 | # Above we just set mu and sd to a fixed value while here we 34 | # plug in a common group distribution for all a and b (which are 35 | # vectors of length n_counties). 36 | a = pm.Normal('a', mu=mu_a, sigma=sigma_a, dims='county') 37 | # Slope for each county, distributed around group mean mu_b 38 | b = pm.Normal('b', mu=mu_b, sigma=sigma_b, dims='county') 39 | 40 | # Model error 41 | eps = pm.HalfCauchy('eps', 5.)
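    # Expected log-radon level per household (the linear predictor defined next):
    # county-specific intercept a plus county-specific slope b times the household's floor value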
42 | 43 | radon_est = a[county_idx] + b[county_idx]*data_r.floor.values 44 | 45 | # Data likelihood 46 | radon = pm.Normal('radon', mu=radon_est, 47 | sigma=eps, observed=data_r.log_radon, dims='county_idx_household') 48 | 49 | #Inference 50 | trace = pm.sample(samples, chains=chains, tune=tune, target_accept=1.0) 51 | prior = pm.sample_prior_predictive(samples=samples) 52 | posterior_predictive = pm.sample_posterior_predictive(trace, samples=samples) 53 | 54 | ## STEP 1 55 | # will also capture all the sampler statistics 56 | data = az.from_pymc3(trace=trace, prior=prior, posterior_predictive=posterior_predictive) 57 | 58 | ## STEP 2 59 | #dag 60 | dag = get_dag(model) 61 | # insert dag into sampler stat attributes 62 | data.sample_stats.attrs["graph"] = str(dag) 63 | 64 | ## STEP 3 65 | # save data 66 | arviz_to_json(data, fileName+'.npz') -------------------------------------------------------------------------------- /examples/stochastic_volatility/ipme.py: -------------------------------------------------------------------------------- 1 | import ipme 2 | 3 | if __name__=="__main__": 4 | infer_datapath='stochastic_volatility_PyMC3.npz' 5 | ipme.graph(infer_datapath, predictive_checks=['returns']) -------------------------------------------------------------------------------- /examples/stochastic_volatility/model.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import pymc3 as pm 3 | import arviz as az 4 | import numpy as np 5 | from arviz_json import get_dag, arviz_to_json 6 | 7 | #StudentT Timeseries Model 8 | #Reference1: https://docs.pymc.io/notebooks/getting_started.html#Case-study-1:-Stochastic-volatility 9 | #Reference2: https://docs.pymc.io/notebooks/stochastic_volatility.html#Stochastic-Volatility-model 10 | #data 11 | returns = pd.read_csv(pm.get_data('SP500.csv'), parse_dates=True, index_col=0) 12 | dates = returns.index.strftime("%Y/%m/%d").tolist() 13 | 14 | #model-inference 15 | fileName='stochastic_volatility_PyMC3' 16 | samples=2000 17 | tune=2000 18 | chains=2 19 | coords = {"date": dates} 20 | with pm.Model(coords=coords) as model: 21 | step_size = pm.Exponential('step_size', 10) 22 | volatility = pm.GaussianRandomWalk('volatility', sigma = step_size, dims='date') 23 | nu = pm.Exponential('nu', 0.1) 24 | returns = pm.StudentT('returns', nu = nu, lam = np.exp(-2*volatility) , observed = returns["change"], dims='date') 25 | #inference 26 | trace = pm.sample(draws=samples, chains=chains, tune=tune) 27 | prior = pm.sample_prior_predictive(samples=samples) 28 | posterior_predictive = pm.sample_posterior_predictive(trace, samples=samples) 29 | 30 | ## STEP 1 31 | # will also capture all the sampler statistics 32 | data = az.from_pymc3(trace=trace, prior=prior, posterior_predictive=posterior_predictive) 33 | 34 | ## STEP 2 35 | #dag 36 | dag = get_dag(model) 37 | # insert dag into sampler stat attributes 38 | data.sample_stats.attrs["graph"] = str(dag) 39 | 40 | ## STEP 3 41 | # save data 42 | arviz_to_json(data, fileName+'.npz') 43 | -------------------------------------------------------------------------------- /examples/user_study/min_temperature/ipme.py: -------------------------------------------------------------------------------- 1 | import ipme 2 | 3 | if __name__=="__main__": 4 | infer_datapath='min_temperature.npz' 5 | ipme.scatter_matrix(infer_datapath,vars=['a','b','c','temperature']) --------------------------------------------------------------------------------
/examples/user_study/min_temperature/model.py: -------------------------------------------------------------------------------- 1 | import pymc3 as pm 2 | import numpy as np 3 | import arviz as az 4 | # !pip install git+https://github.com/johnhw/arviz_json.git 5 | from arviz_json import get_dag, arviz_to_json 6 | 7 | RANDOM_SEED = np.random.seed(1225) 8 | 9 | ## the average minimum temperature in Scotland in month November for the years 10 | ## 1884-2020: metoffice.gov.uk 11 | ## https://www.metoffice.gov.uk/pub/data/weather/uk/climate/datasets/Tmin/date/Scotland.txt 12 | l = [1.0,1.1,2.5,0.8,2.7,2.9,1.4,1.1,2.2,0.3,3.6,1.7,2.1,3.5,1.5,4.6,2.3, 13 | 1.2,3.5,1.8,1.3,1.2,3.7,1.4,2.9,-0.3,-1.1,1.1,1.6,3.2,1.9,-1.3,3.2, 14 | 3.1,0.8,-1.6,3.6,0.7,2.1,-0.5,3.5,-0.0,1.1,1.8,2.7,2.0,0.6,3.6,1.8, 15 | 1.7,1.5,1.7,1.3,1.5,3.8,3.0,1.9,1.7,1.0,2.2,0.4,3.4,2.7,1.6,3.0,2.4, 16 | 0.6,3.4,-0.3,4.0,1.6,3.5,2.6,2.8,2.2,2.9,1.7,1.1,1.5,2.2,2.3,-0.2,0.5, 17 | 1.5,1.7,-0.8,1.9,1.7,0.9,0.9,1.7,1.8,1.5,0.9,3.1,1.2,1.9,2.5,2.2,3.0, 18 | 3.3,-1.1,2.5,2.6,1.7,2.0,2.0,2.0,1.5,0.4,5.2,3.0,0.2,4.6,1.3,2.9,1.8, 19 | 3.0,3.8,3.4,3.3,1.5,3.4,3.4,2.0,2.8,0.0,4.8,1.7,1.1,3.8,3.6,0.3,1.5,3.6,1.1] 20 | years = np.arange(1884,2020) 21 | 22 | fileName='min_temperature' 23 | samples=4000 24 | chains=2 25 | tune=1000 26 | coords = {'years':years} 27 | temperature_model = pm.Model(coords=coords) 28 | with temperature_model: 29 | #priors 30 | a = pm.Uniform('a', 80, 100) 31 | b = pm.Normal('b', mu=2, sd=10) 32 | c = pm.HalfNormal('c', sd=10) 33 | 34 | #predictions 35 | temperature = pm.Normal('temperature', 36 | mu = b, 37 | sd = c, 38 | observed = l, 39 | dims = 'years') 40 | 41 | trace = pm.sample(draws=samples, chains=chains, tune=tune) 42 | prior = pm.sample_prior_predictive(samples=samples) 43 | posterior_predictive = pm.sample_posterior_predictive(trace, samples=samples) 44 | dag = get_dag(temperature_model) 45 | 46 | # will also capture all the sampler statistics 47 | data = az.from_pymc3(trace=trace, 48 | prior=prior, 49 | posterior_predictive=posterior_predictive) 50 | 51 | # insert dag into sampler stat attributes 52 | data.sample_stats.attrs["graph"] = str(dag) 53 | 54 | # save data 55 | arviz_to_json(data, fileName+'.npz') -------------------------------------------------------------------------------- /examples/user_study/random_number_generator/ipme.py: -------------------------------------------------------------------------------- 1 | import ipme 2 | 3 | if __name__=="__main__": 4 | infer_datapath='transformation.npz' 5 | ipme.scatter_matrix(infer_datapath, vars=['a','b','c','random_number']) -------------------------------------------------------------------------------- /examples/user_study/random_number_generator/model.py: -------------------------------------------------------------------------------- 1 | import pymc3 as pm 2 | import numpy as np 3 | import arviz as az 4 | # !pip install git+https://github.com/johnhw/arviz_json.git 5 | from arviz_json import get_dag, arviz_to_json 6 | 7 | RANDOM_SEED = np.random.seed(1225) 8 | 9 | x_data = np.random.uniform(-3, 5, size=5) 10 | obs = np.arange(len(x_data)) 11 | 12 | fileName = 'transformation' 13 | samples = 4000 14 | # tune = 10000 15 | chains = 1 16 | coords = {"obs": obs} 17 | uniform_model = pm.Model(coords=coords) 18 | with uniform_model: 19 | # Priors for unknown model parameters 20 | a = pm.Normal('a', mu = 0, sd=10) 21 | b = pm.HalfNormal('b', sd = 10) 22 | c = pm.HalfNormal('c', sd=20) 23 | 24 | l = a - c 25 | u = a + c 26 | random_number = 
pm.Uniform('random_number', 27 | lower=l, 28 | upper=u, 29 | observed=x_data, 30 | dims="obs") 31 | 32 | trace = pm.sample(samples, chains = chains, target_accept=0.95) 33 | posterior_predictive = pm.sample_posterior_predictive(trace, samples=samples) 34 | prior = pm.sample_prior_predictive(samples=samples, random_seed=RANDOM_SEED) 35 | 36 | # will also capture all the sampler statistics 37 | data = az.from_pymc3(trace = trace) 38 | 39 | # will also capture all the sampler statistics 40 | data = az.from_pymc3(trace = trace, 41 | prior = prior, 42 | posterior_predictive = posterior_predictive) 43 | 44 | dag = get_dag(uniform_model) 45 | data.sample_stats.attrs["graph"] = str(dag) 46 | 47 | # save data 48 | arviz_to_json(data, fileName+'.npz') -------------------------------------------------------------------------------- /examples/user_study/reaction_times/ipme.py: -------------------------------------------------------------------------------- 1 | import ipme 2 | 3 | if __name__=="__main__": 4 | infer_datapath='reaction_times_hierarchical.npz' 5 | ipme.scatter_matrix(infer_datapath,vars=['a','b','c','d','reaction_time']) -------------------------------------------------------------------------------- /examples/user_study/reaction_times/model.py: -------------------------------------------------------------------------------- 1 | import pymc3 as pm 2 | import pandas as pd 3 | import numpy as np 4 | import arviz as az 5 | # !pip install git+https://github.com/johnhw/arviz_json.git 6 | from arviz_json import get_dag, arviz_to_json 7 | 8 | RANDOM_SEED = np.random.seed(1225) 9 | 10 | DATAPATH='../../data/evaluation_sleepstudy.csv' ## data from Belenky et al. (2003) 11 | reactions = pd.read_csv(DATAPATH,usecols=['Reaction','Days','Subject']) 12 | reactions = reactions[0:3*10] 13 | 14 | driver_idx, drivers = pd.factorize(reactions["Subject"], sort=True) 15 | day_idx, days = pd.factorize(reactions["Days"], sort=True) 16 | coords_c = {"driver": drivers,"driver_idx_day":reactions.Subject} 17 | with pm.Model(coords=coords_c) as hierarchical_model: 18 | #hyper-priors 19 | c = pm.Normal('c', mu=100, sd=150) 20 | e = pm.HalfNormal('e', sd=150) 21 | f = pm.Normal('f', mu=10, sd=100) 22 | g = pm.HalfNormal('g', sd=100) 23 | h = pm.HalfNormal('h', sd=200) 24 | #priors 25 | a = pm.Normal("a", mu=c, sd=e, dims="driver") 26 | b = pm.Normal("b", mu=f, sd=g, dims="driver") 27 | sigma = pm.HalfNormal("sigma", sd=h, dims="driver") 28 | d = pm.Normal("d", mu = 0, sd=10.0) 29 | 30 | reaction_time = pm.Normal('reaction_time', 31 | mu = a[driver_idx]+b[driver_idx]*day_idx, 32 | sd=sigma[driver_idx], 33 | observed=reactions.Reaction, 34 | dims="driver_idx_day") 35 | 36 | samples=4000 37 | chains=3 38 | tune=3000 39 | fileName_hi="reaction_times_hierarchical" 40 | with hierarchical_model: 41 | # Get posterior trace, prior trace, posterior predictive samples, and the DAG 42 | trace_hi = pm.sample(draws=samples, chains=chains, tune=tune, target_accept=0.9) 43 | prior_hi = pm.sample_prior_predictive(samples=samples) 44 | posterior_predictive_hi = pm.sample_posterior_predictive(trace_hi, 45 | samples=samples) 46 | 47 | # export inference results in ArviZ InferenceData obj 48 | # will also capture all the sampler statistics 49 | data_hi = az.from_pymc3(trace = trace_hi, 50 | prior = prior_hi, 51 | posterior_predictive = posterior_predictive_hi) 52 | 53 | # insert dag into sampler stat attributes 54 | dag_hi = get_dag(hierarchical_model) 55 | data_hi.sample_stats.attrs["graph"] = str(dag_hi) 56 | 57 | # save data 58 | 
arviz_to_json(data_hi, fileName_hi+'.npz') -------------------------------------------------------------------------------- /ipme/__init__.py: -------------------------------------------------------------------------------- 1 | """Top-level package for imd.""" 2 | 3 | __author__ = """Evdoxia Taka""" 4 | __email__ = 'e.taka.1@research.gla.ac.uk' 5 | __version__ = '0.1.0' 6 | 7 | from .methods import * -------------------------------------------------------------------------------- /ipme/classes/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = """Evdoxia Taka""" 2 | __email__ = 'e.taka.1@research.gla.ac.uk' 3 | __version__ = '0.1.0' 4 | -------------------------------------------------------------------------------- /ipme/classes/cell/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/evdoxiataka/ipme/f3398596c6af547908f39683eb1830a6bc081482/ipme/classes/cell/__init__.py -------------------------------------------------------------------------------- /ipme/classes/cell/interactive_continuous_cell.py: -------------------------------------------------------------------------------- 1 | from ipme.interfaces.variable_cell import VariableCell 2 | from .utils.cell_continuous_handler import CellContinuousHandler 3 | from .utils.cell_clear_selection import CellClearSelection 4 | from ..cell.utils.cell_widgets import CellWidgets 5 | 6 | class InteractiveContinuousCell(VariableCell): 7 | def __init__(self, name, control): 8 | """ 9 | Parameters: 10 | -------- 11 | name A String within the set {""}. 12 | control A Control object 13 | """ 14 | self.selection = {} 15 | self.sel_samples = {} 16 | self.non_sel_samples = {} 17 | self.reconstructed = {} 18 | self.clear_selection = {} 19 | VariableCell.__init__(self, name, control) 20 | 21 | def initialize_cds(self, space): 22 | CellContinuousHandler.initialize_cds_interactive(self, space) 23 | 24 | def initialize_fig(self, space): 25 | CellContinuousHandler.initialize_fig_interactive(self, space) 26 | 27 | def initialize_glyphs(self, space): 28 | CellContinuousHandler.initialize_glyphs_interactive(self, space) 29 | CellClearSelection.initialize_glyphs_x_button(self, space) 30 | 31 | def widget_callback(self, attr, old, new, w_title, space): 32 | CellWidgets.widget_callback_interactive(self, attr, old, new, w_title, space) 33 | 34 | def update_cds(self, space): 35 | CellContinuousHandler.update_cds_interactive(self, space) 36 | 37 | ## ONLY FOR INTERACTIVE CASE 38 | def update_source_cds(self, space): 39 | CellContinuousHandler.update_source_cds_interactive(self, space) 40 | 41 | def update_selection_cds(self, space, xmin, xmax): 42 | CellContinuousHandler.update_selection_cds_interactive(self, space, xmin, xmax) 43 | 44 | def update_reconstructed_cds(self, space): 45 | CellContinuousHandler.update_reconstructed_cds_interactive(self, space) 46 | -------------------------------------------------------------------------------- /ipme/classes/cell/interactive_discrete_cell.py: -------------------------------------------------------------------------------- 1 | from ipme.interfaces.variable_cell import VariableCell 2 | from .utils.cell_discrete_handler import CellDiscreteHandler 3 | from .utils.cell_clear_selection import CellClearSelection 4 | from ..cell.utils.cell_widgets import CellWidgets 5 | 6 | class InteractiveDiscreteCell(VariableCell): 7 | def __init__(self, name, control): 8 | """ 9 | Parameters: 10 | -------- 11 | 
name A String within the set {""}. 12 | control A Control object 13 | """ 14 | self.selection = {} 15 | self.reconstructed = {} 16 | self.clear_selection = {} 17 | VariableCell.__init__(self, name, control) 18 | 19 | def initialize_cds(self, space): 20 | CellDiscreteHandler.initialize_cds_interactive(self, space) 21 | 22 | def initialize_fig(self, space): 23 | CellDiscreteHandler.initialize_fig_interactive(self, space) 24 | 25 | def initialize_glyphs(self, space): 26 | CellDiscreteHandler.initialize_glyphs_interactive(self, space) 27 | CellClearSelection.initialize_glyphs_x_button(self, space) 28 | 29 | def widget_callback(self, attr, old, new, w_title, space): 30 | CellWidgets.widget_callback_interactive(self, attr, old, new, w_title, space) 31 | 32 | def update_cds(self, space): 33 | CellDiscreteHandler.update_cds_interactive(self, space) 34 | 35 | ## ONLY FOR INTERACTIVE CASE 36 | def update_source_cds(self, space): 37 | CellDiscreteHandler.update_source_cds_interactive(self, space) 38 | 39 | def update_selection_cds(self, space, xmin, xmax): 40 | CellDiscreteHandler.update_selection_cds_interactive(self, space, xmin, xmax) 41 | 42 | def update_reconstructed_cds(self, space): 43 | CellDiscreteHandler.update_reconstructed_cds_interactive(self, space) 44 | -------------------------------------------------------------------------------- /ipme/classes/cell/interactive_pred_ckeck_cell.py: -------------------------------------------------------------------------------- 1 | from ipme.interfaces.predictive_check_cell import PredictiveCheckCell 2 | from .utils.cell_pred_check_handler import CellPredCheckHandler 3 | from ..cell.utils.cell_widgets import CellWidgets 4 | 5 | # from ipme.utils.stats import hist 6 | # from ipme.utils.functions import get_finite_samples, get_samples_for_pred_check, get_hist_bins_range 7 | # from ipme.utils.constants import COLORS 8 | 9 | # import numpy as np 10 | from bokeh.models import LegendItem 11 | 12 | # import threading 13 | # from abc import abstractmethod 14 | 15 | class InteractivePredCheckCell(PredictiveCheckCell): 16 | def __init__(self, name, control, function = 'min'): 17 | """ 18 | Parameters: 19 | -------- 20 | name A String within the set {""}. 21 | function A String in {"min","max","mean","std"}. 
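This function is the test statistic of the predictive check: it is applied to the observed data and to every predictive sample when the check's histogram and p-value are computed.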
22 | Sets: 23 | -------- 24 | func 25 | source 26 | reconstructed 27 | samples 28 | seg 29 | """ 30 | self.func = function 31 | PredictiveCheckCell.__init__(self, name, control) 32 | 33 | def initialize_cds(self, space): 34 | CellPredCheckHandler.initialize_cds_interactive(self, space) 35 | 36 | def initialize_fig(self, space): 37 | CellPredCheckHandler.initialize_fig_interactive(self, space) 38 | 39 | def initialize_glyphs(self, space): 40 | CellPredCheckHandler.initialize_glyphs_interactive(self, space) 41 | 42 | def widget_callback(self, attr, old, new, w_title, space): 43 | CellWidgets.widget_callback_interactive(self, attr, old, new, w_title, space) 44 | 45 | def update_cds(self, space): 46 | CellPredCheckHandler.update_cds_interactive(self, space) 47 | 48 | ## ONLY FOR INTERACTIVE CASE 49 | def update_source_cds(self, space): 50 | CellPredCheckHandler.update_source_cds_interactive(self, space) 51 | 52 | def update_sel_samples_cds(self, space): 53 | CellPredCheckHandler.update_sel_samples_cds_interactive(self, space) 54 | 55 | # def initialize_glyphs(self,space): 56 | # q = self.plot[space].quad(top='top', bottom='bottom', left='left', right='right', source=self.source[space], \ 57 | # fill_color=COLORS[0], line_color="white", fill_alpha=1.0, name="full") 58 | # seg = self.plot[space].segment(x0 ='x0', y0 ='y0', x1='x1', y1='y1', source=self.seg[space], \ 59 | # color="black", line_width=2, name="seg") 60 | # q_sel = self.plot[space].quad(top='top', bottom='bottom', left='left', right='right', source=self.reconstructed[space], \ 61 | # fill_color=COLORS[1], line_color="white", fill_alpha=0.7, name="sel") 62 | # ## Add Legends 63 | # data = self.seg[space].data['x0'] 64 | # pvalue = self.pvalue[space].data["pv"] 65 | # if len(data) and len(pvalue): 66 | # legend = Legend(items=[ (self.func + "(obs) = " + format(data[0], '.2f'), [seg]), 67 | # ("p-value = "+format(pvalue[0],'.4f'), [q]), 68 | # ], location="top_left") 69 | # self.plot[space].add_layout(legend, 'above') 70 | # ## Add Tooltips for hist 71 | # #####TODO:Correct overlap of tooltips##### 72 | # TOOLTIPS = [ 73 | # ("top", "@top"), 74 | # ("right","@right"), 75 | # ("left","@left"), 76 | # ] 77 | # hover = HoverTool( tooltips=TOOLTIPS,renderers=[q,q_sel]) 78 | # self.plot[space].tools.append(hover) 79 | 80 | ## Update legends when data in _pvalue_rec cds is updated 81 | def update_legends(self, space, attr, old, new): 82 | if len(self.plot[space].legend.items) == 3: 83 | self.plot[space].legend.items.pop() 84 | r = self.plot[space].select(name="sel") 85 | pvalue = self.pvalue_rec[space].data["pv"] 86 | if len(r) and len(pvalue): 87 | self.plot[space].legend.items.append(LegendItem(label="p-value = " + format(pvalue[0], '.4f'), \ 88 | renderers=[r[0]])) 89 | 90 | # def widget_callback(self, attr, old, new, w_title, space): 91 | # inds = self.ic.data.get_indx_for_idx_dim(self.name, w_title, new) 92 | # if inds==-1: 93 | # return 94 | # self.cur_idx_dims_values[self.name][w_title] = inds 95 | # self._update_plot(space) 96 | 97 | # def _update_plot(self, space): 98 | # pass 99 | 100 | # def initialize_cds(self, space): 101 | # CellPredCheckHandler.initialize_cds_interactive(self, space) 102 | # # ## ColumnDataSource for full sample set 103 | # # data, samples = self.get_data_for_cur_idx_dims_values(space) 104 | # # self.samples[space] = ColumnDataSource(data=dict(x=samples)) 105 | # # #data func 106 | # # if ~np.isfinite(data).all(): 107 | # # data = get_finite_samples(data) 108 | # # data_func = get_samples_for_pred_check(data, 
self.func) 109 | # # #samples func 110 | # # if ~np.isfinite(samples).all(): 111 | # # samples = get_finite_samples(samples) 112 | # # samples_func = get_samples_for_pred_check(samples, self.func) 113 | # # if samples_func.size: 114 | # # #pvalue 115 | # # pv = np.count_nonzero(samples_func>=data_func) / len(samples_func) 116 | # # #histogram 117 | # # type = self._data.get_var_dist_type(self.name) 118 | # # if type == "Continuous": 119 | # # bins, range = get_hist_bins_range(samples_func, self.func, type) 120 | # # else: 121 | # # bins, range = get_hist_bins_range(samples_func, self.func, type, ref_length = None, ref_values=np.unique(samples.flatten())) 122 | 123 | # # his, edges = hist(samples_func, bins=bins, range=range, density=True) 124 | # # #cds 125 | # # self.pvalue[space] = ColumnDataSource(data=dict(pv=[pv])) 126 | # # self.source[space] = ColumnDataSource(data=dict(left=edges[:-1], top=his, right=edges[1:], bottom=np.zeros(len(his)))) 127 | # # self.seg[space] = ColumnDataSource(data=dict(x0=[data_func], x1=[data_func], y0=[0], y1=[his.max() + 0.1 * his.max()])) 128 | # # else: 129 | # # self.pvalue[space] = ColumnDataSource(data=dict(pv=[])) 130 | # # self.source[space] = ColumnDataSource(data=dict(left=[], top=[], right=[], bottom=[])) 131 | # # self.seg[space] = ColumnDataSource(data=dict(x0=[], x1=[], y0=[], y1=[])) 132 | 133 | # # ## ColumnDataSource for restricted sample set 134 | # # self.pvalue_rec[space] = ColumnDataSource(data=dict(pv=[])) 135 | # # self.pvalue_rec[space].on_change('data', partial(self._update_legends, space)) 136 | # # self.reconstructed[space] = ColumnDataSource(data=dict(left=[], top=[], right=[], bottom=[])) 137 | 138 | # ## Update plots when indices of selected samples are updated 139 | # def sample_inds_callback(self, space, attr, old, new): 140 | # # _, samples = self._get_data_for_cur_idx_dims_values(space) 141 | # samples = self.samples[space].data['x'] 142 | # max_full_hist = self.source[space].data['top'].max() 143 | # if samples.size: 144 | # inds=self.ic.sample_inds[space].data['inds'] 145 | # if len(inds): 146 | # sel_sample = samples[inds] 147 | # if ~np.isfinite(sel_sample).all(): 148 | # sel_sample = get_finite_samples(sel_sample) 149 | # sel_sample_func = get_samples_for_pred_check(sel_sample, self.func) 150 | # #data func 151 | # data_func = self.seg[space].data['x0'][0] 152 | # #pvalue in restricted space 153 | # sel_pv = np.count_nonzero(sel_sample_func >= data_func) / len(sel_sample_func) 154 | # #compute updated histogram 155 | # min_p = self.source[space].data['left'][0] 156 | # max_p = self.source[space].data['right'][-1] 157 | # min_c = sel_sample_func.min() 158 | # max_c = sel_sample_func.max() 159 | # if min_c < min_p or max_c > max_p: 160 | # ref_len = self.source[space].data['right'][0] - min_p 161 | # bins, range = get_hist_bins_range(sel_sample_func, self.func, self._type, ref_length=ref_len) 162 | # else: 163 | # range = (min_p,max_p) 164 | # bins = len(self.source[space].data['right']) 165 | # his, edges = hist(sel_sample_func, bins=bins, range=range) 166 | # ##max selected hist 167 | # max_sel_hist = his.max() 168 | # #update reconstructed cds 169 | # self.pvalue_rec[space].data = dict(pv=[sel_pv]) 170 | # self.reconstructed[space].data = dict(left=edges[:-1], top=his, right =edges[1:], bottom=np.zeros(len(his))) 171 | # self.seg[space].data['y1'] = [max_sel_hist + 0.1 * max_sel_hist] 172 | # else: 173 | # self.pvalue_rec[space].data = dict(pv=[]) 174 | # self.reconstructed[space].data = dict(left=[], top=[], 
right=[], bottom=[]) 175 | # self.seg[space].data['y1'] = [max_full_hist + 0.1 * max_full_hist] 176 | # else: 177 | # self.pvalue_rec[space].data = dict(pv=[]) 178 | # self.reconstructed[space].data = dict(left=[], top=[], right=[], bottom=[]) 179 | # self.seg[space].data['y1'] = [max_full_hist + 0.1 * max_full_hist] 180 | -------------------------------------------------------------------------------- /ipme/classes/cell/interactive_scatter_cell.py: -------------------------------------------------------------------------------- 1 | from ipme.interfaces.scatter_cell import ScatterCell 2 | from .utils.cell_scatter_handler import CellScatterHandler 3 | from ..cell.utils.cell_widgets import CellWidgets 4 | 5 | class InteractiveScatterCell(ScatterCell): 6 | def __init__(self, vars, control): 7 | """ 8 | Parameters: 9 | -------- 10 | name A String within the set {""}. 11 | control A Control object 12 | """ 13 | self.sel_samples = {} 14 | self.non_sel_samples = {} 15 | ScatterCell.__init__(self, vars, control) 16 | 17 | def initialize_cds(self, space): 18 | CellScatterHandler.initialize_cds_interactive(self, space) 19 | 20 | def initialize_fig(self, space): 21 | CellScatterHandler.initialize_fig_interactive(self, space) 22 | 23 | def initialize_glyphs(self, space): 24 | CellScatterHandler.initialize_glyphs_interactive(self, space) 25 | 26 | def widget_callback(self, attr, old, new, w_title, space): 27 | CellWidgets.widget_callback_interactive(self, attr, old, new, w_title, space) 28 | 29 | def update_cds(self, space): 30 | CellScatterHandler.update_cds_interactive(self, space) 31 | 32 | ## ONLY FOR INTERACTIVE CASE 33 | def update_source_cds(self, space): 34 | CellScatterHandler.update_source_cds_interactive(self, space) 35 | 36 | def update_sel_samples_cds(self, space): 37 | CellScatterHandler.update_sel_samples_cds_interactive(self, space) 38 | -------------------------------------------------------------------------------- /ipme/classes/cell/static_continuous_cell.py: -------------------------------------------------------------------------------- 1 | from ipme.interfaces.variable_cell import VariableCell 2 | from .utils.cell_continuous_handler import CellContinuousHandler 3 | from ..cell.utils.cell_widgets import CellWidgets 4 | 5 | from ipme.utils.functions import get_stratum_range, find_indices 6 | 7 | class StaticContinuousCell(VariableCell): 8 | def __init__(self, name, control): 9 | """ 10 | Parameters: 11 | -------- 12 | name A String within the set {""}. 13 | control A Control object 14 | """ 15 | VariableCell.__init__(self, name, control) 16 | 17 | def initialize_cds(self, space): 18 | CellContinuousHandler.initialize_cds_static(self, space) 19 | 20 | def initialize_fig(self, space): 21 | CellContinuousHandler.initialize_fig_static(self, space) 22 | 23 | def initialize_glyphs(self, space): 24 | CellContinuousHandler.initialize_glyphs_static(self, space) 25 | 26 | def widget_callback(self, attr, old, new, w_title, space): 27 | CellWidgets.widget_callback_static(self, attr, old, new, w_title, space) 28 | 29 | def update_cds(self, space): 30 | CellContinuousHandler.update_cds_static(self, space) 31 | 32 | ## ONLY FOR STATIC CASE 33 | def set_stratum(self, space, stratum = 0): 34 | """ 35 | Sets selection by spliting the ordered sample set 36 | in 4 equal-sized subsets. 
37 | """ 38 | samples = self.get_samples_for_cur_idx_dims_values(space) 39 | xmin,xmax = get_stratum_range(samples, stratum) 40 | inds = find_indices(samples, lambda e: xmin <= e <= xmax, xmin, xmax) 41 | self.ic.set_sel_var_inds(space, self.name, inds) 42 | self.compute_intersection_of_samples(space) 43 | return (xmin,xmax) 44 | 45 | -------------------------------------------------------------------------------- /ipme/classes/cell/static_discrete_cell.py: -------------------------------------------------------------------------------- 1 | from ipme.interfaces.variable_cell import VariableCell 2 | from .utils.cell_discrete_handler import CellDiscreteHandler 3 | from ..cell.utils.cell_widgets import CellWidgets 4 | 5 | from ipme.utils.functions import get_stratum_range, find_indices 6 | 7 | class StaticDiscreteCell(VariableCell): 8 | def __init__(self, name, control): 9 | """ 10 | Parameters: 11 | -------- 12 | name A String within the set {""}. 13 | control A Control object 14 | """ 15 | VariableCell.__init__(self, name, control) 16 | 17 | def initialize_cds(self, space): 18 | CellDiscreteHandler.initialize_cds_static(self, space) 19 | 20 | def initialize_fig(self, space): 21 | CellDiscreteHandler.initialize_fig_static(self, space) 22 | 23 | def initialize_glyphs(self, space): 24 | CellDiscreteHandler.initialize_glyphs_static(self, space) 25 | 26 | def widget_callback(self, attr, old, new, w_title, space): 27 | CellWidgets.widget_callback_static(self, attr, old, new, w_title, space) 28 | 29 | def update_cds(self, space): 30 | CellDiscreteHandler.update_cds_static(self, space) 31 | 32 | ## ONLY FOR STATIC CASE 33 | def set_stratum(self, space, stratum = 0): 34 | """ 35 | Sets selection by spliting the ordered sample set 36 | in 4 equal-sized subsets. 37 | """ 38 | samples = self.get_samples_for_cur_idx_dims_values(space) 39 | xmin,xmax = get_stratum_range(samples, stratum) 40 | inds = find_indices(samples, lambda e: xmin <= e <= xmax, xmin, xmax) 41 | self.ic.set_sel_var_inds(space, self.name, inds) 42 | self.compute_intersection_of_samples(space) 43 | return (xmin,xmax) 44 | -------------------------------------------------------------------------------- /ipme/classes/cell/static_pred_check_cell.py: -------------------------------------------------------------------------------- 1 | from ipme.interfaces.predictive_check_cell import PredictiveCheckCell 2 | from .utils.cell_pred_check_handler import CellPredCheckHandler 3 | from ..cell.utils.cell_widgets import CellWidgets 4 | 5 | class StaticPredCheckCell(PredictiveCheckCell): 6 | def __init__(self, vars, control, function = 'min'): 7 | """ 8 | Parameters: 9 | -------- 10 | vars A List of variableNames of the model. 11 | control A Control object 12 | function A String in {"min","max","mean","std"}. 
13 | """ 14 | self.func = function 15 | PredictiveCheckCell.__init__(self, vars, control) 16 | 17 | def initialize_cds(self, space): 18 | CellPredCheckHandler.initialize_cds_static(self, space) 19 | 20 | def initialize_fig(self, space): 21 | CellPredCheckHandler.initialize_fig_static(self, space) 22 | 23 | def initialize_glyphs(self, space): 24 | CellPredCheckHandler.initialize_glyphs_static(self, space) 25 | 26 | def widget_callback(self, attr, old, new, w_title, space): 27 | CellWidgets.widget_callback_static(self, attr, old, new, w_title, space) 28 | 29 | def update_cds(self, space): 30 | CellPredCheckHandler.update_cds_static(self, space) 31 | 32 | -------------------------------------------------------------------------------- /ipme/classes/cell/static_scatter_cell.py: -------------------------------------------------------------------------------- 1 | from ipme.interfaces.scatter_cell import ScatterCell 2 | from .utils.cell_scatter_handler import CellScatterHandler 3 | from ..cell.utils.cell_widgets import CellWidgets 4 | 5 | class StaticScatterCell(ScatterCell): 6 | def __init__(self, vars, control): 7 | """ 8 | Parameters: 9 | -------- 10 | vars A List of variableNames of the model. 11 | control A Control object 12 | """ 13 | ScatterCell.__init__(self, vars, control) 14 | 15 | def initialize_cds(self, space): 16 | CellScatterHandler.initialize_cds_static(self, space) 17 | 18 | def initialize_fig(self, space): 19 | CellScatterHandler.initialize_fig_static(self, space) 20 | 21 | def initialize_glyphs(self, space): 22 | CellScatterHandler.initialize_glyphs_static(self, space) 23 | 24 | def widget_callback(self, attr, old, new, w_title, space): 25 | CellWidgets.widget_callback_static(self, attr, old, new, w_title, space) 26 | 27 | def update_cds(self, space): 28 | CellScatterHandler.update_cds_static(self, space) 29 | 30 | -------------------------------------------------------------------------------- /ipme/classes/cell/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/evdoxiataka/ipme/f3398596c6af547908f39683eb1830a6bc081482/ipme/classes/cell/utils/__init__.py -------------------------------------------------------------------------------- /ipme/classes/cell/utils/cell_clear_selection.py: -------------------------------------------------------------------------------- 1 | from ipme.utils.functions import find_highest_point 2 | from ipme.utils.js_code import HOVER_CODE 3 | 4 | from bokeh.models import HoverTool, CustomJS 5 | 6 | class CellClearSelection: 7 | 8 | def __init__(self): 9 | pass 10 | 11 | @staticmethod 12 | def initialize_glyphs_x_button(variableCell, space): 13 | ## x-button to clear selection 14 | sq_x = variableCell.plot[space].scatter('x', 'y', marker = "square_x", size = 10, fill_color = "grey", hover_fill_color = "firebrick", \ 15 | fill_alpha = 0.5, hover_alpha = 1.0, line_color = "grey", hover_line_color = "white", \ 16 | source = variableCell.clear_selection[space], name = 'clear_selection') 17 | ## Add HoverTool for x-button 18 | variableCell.plot[space].add_tools(HoverTool(tooltips = "Clear Selection", renderers = [sq_x], mode = 'mouse', show_arrow = False, 19 | callback = CustomJS(args = dict(source = variableCell.clear_selection[space]), code = HOVER_CODE))) 20 | 21 | @staticmethod 22 | def update_clear_selection_cds(variableCell, space): 23 | """ 24 | Updates clear_selection ColumnDataSource (cds). 
25 | """ 26 | sel_var_idx_dims_values = variableCell.ic.get_sel_var_idx_dims_values() 27 | sel_space = variableCell.ic.get_sel_space() 28 | var_x_range = variableCell.ic.get_var_x_range() 29 | cur_idx_dims_values = {} 30 | if variableCell.name in variableCell.cur_idx_dims_values: 31 | cur_idx_dims_values = variableCell.cur_idx_dims_values[variableCell.name] 32 | if (variableCell.name in sel_var_idx_dims_values and space == sel_space and 33 | cur_idx_dims_values == sel_var_idx_dims_values[variableCell.name]): 34 | min_x_range = var_x_range[(space, variableCell.name)].data['xmin'][0] 35 | max_x_range = var_x_range[(space, variableCell.name)].data['xmax'][0] 36 | hp = find_highest_point(variableCell.reconstructed[space].data['x'], variableCell.reconstructed[space].data['y']) 37 | if not hp: 38 | hp = find_highest_point(variableCell.selection[space].data['x'], variableCell.selection[space].data['y']) 39 | if not hp: 40 | hp = find_highest_point(variableCell.source[space].data['x'], variableCell.source[space].data['y']) 41 | if not hp: 42 | hp = (0,0) 43 | variableCell.clear_selection[space].data = dict( x = [(max_x_range + min_x_range) / 2.], y = [hp[1]+hp[1]*0.1], isIn = [0]) 44 | else: 45 | variableCell.clear_selection[space].data = dict( x = [], y = [], isIn = []) 46 | -------------------------------------------------------------------------------- /ipme/classes/cell/utils/cell_discrete_handler.py: -------------------------------------------------------------------------------- 1 | from ipme.classes.cell.utils.cell_clear_selection import CellClearSelection 2 | 3 | from ipme.utils.constants import COLORS, BORDER_COLORS, PLOT_HEIGHT, PLOT_WIDTH, SIZING_MODE, RUG_DIST_RATIO, DATA_SIZE 4 | from ipme.utils.stats import pmf 5 | from ipme.utils.functions import find_indices 6 | 7 | from bokeh.models import ColumnDataSource, BoxSelectTool, HoverTool 8 | 9 | from bokeh import events 10 | from bokeh.plotting import figure 11 | 12 | import numpy as np 13 | import threading 14 | from functools import partial 15 | 16 | class CellDiscreteHandler: 17 | 18 | def __init__(self): 19 | pass 20 | 21 | @staticmethod 22 | def initialize_glyphs_interactive(variableCell, space): 23 | hover_renderer = [] 24 | so_seg = variableCell.plot[space].segment(x0 = 'x', y0 ='y0', x1 = 'x', y1 = 'y', source = variableCell.source[space], line_alpha = 1.0, color = COLORS[0], line_width = 1, selection_color = COLORS[0], \ 25 | nonselection_color = COLORS[0], nonselection_line_alpha = 1.0) 26 | variableCell.plot[space].scatter('x', 'y', source = variableCell.source[space], size = 4, fill_color = COLORS[0], fill_alpha = 1.0, line_color = COLORS[0], selection_fill_color = COLORS[0], \ 27 | nonselection_fill_color = COLORS[0], nonselection_fill_alpha = 1.0, nonselection_line_color = COLORS[0]) 28 | variableCell.plot[space].segment(x0 = 'x', y0 ='y0', x1 = 'x', y1 = 'y', source = variableCell.selection[space], line_alpha = 0.7, color = COLORS[2], line_width = 1) 29 | variableCell.plot[space].scatter('x', 'y', source = variableCell.selection[space], size = 4, fill_color = COLORS[2], fill_alpha = 0.7, line_color = COLORS[2]) 30 | rec = variableCell.plot[space].segment(x0 = 'x', y0 ='y0', x1 = 'x', y1 = 'y', source = variableCell.reconstructed[space], line_alpha = 0.5, color = COLORS[1], line_width = 1) 31 | variableCell.plot[space].scatter('x', 'y', source = variableCell.reconstructed[space], size = 4, fill_color = COLORS[1], fill_alpha = 0.5, line_color = COLORS[1]) 32 | hover_renderer.append(so_seg) 33 | hover_renderer.append(rec) 34 
| ## Add observations as yellow asterisks 35 | if space in variableCell.data: 36 | dat = variableCell.plot[space].asterisk('x', 'y', size = DATA_SIZE, line_color = COLORS[3], source = variableCell.data[space]) 37 | hover_renderer.append(dat) 38 | ##Add BoxSelectTool 39 | variableCell.plot[space].add_tools(BoxSelectTool(dimensions = 'width', renderers = [so_seg])) 40 | ##Tooltips 41 | TOOLTIPS = [("x", "@x"), ("y","@y"),] 42 | hover = HoverTool( tooltips = TOOLTIPS, renderers = hover_renderer, mode = 'mouse') 43 | variableCell.plot[space].tools.append(hover) 44 | 45 | @staticmethod 46 | def initialize_glyphs_static(variableCell, space): 47 | hover_renderer = [] 48 | so_seg = variableCell.plot[space].segment(x0 = 'x', y0 ='y0', x1 = 'x', y1 = 'y', source = variableCell.source[space], line_alpha = 1.0, color = COLORS[0], line_width = 1, selection_color = COLORS[0], \ 49 | nonselection_color = COLORS[0], nonselection_line_alpha = 1.0) 50 | variableCell.plot[space].scatter('x', 'y', source = variableCell.source[space], size = 4, fill_color = COLORS[0], fill_alpha = 1.0, line_color = COLORS[0], selection_fill_color = COLORS[0], \ 51 | nonselection_fill_color = COLORS[0], nonselection_fill_alpha = 1.0, nonselection_line_color = COLORS[0]) 52 | hover_renderer.append(so_seg) 53 | ## Add observations as yellow asterisks 54 | if space in variableCell.data: 55 | dat = variableCell.plot[space].asterisk('x', 'y', size = DATA_SIZE, line_color = COLORS[3], source = variableCell.data[space]) 56 | hover_renderer.append(dat) 57 | ## Tooltips 58 | TOOLTIPS = [("x", "@x"), ("y","@y"),] 59 | hover = HoverTool( tooltips = TOOLTIPS, renderers = hover_renderer, mode = 'mouse') 60 | variableCell.plot[space].tools.append(hover) 61 | 62 | @staticmethod 63 | def initialize_fig(variableCell, space): 64 | variableCell.plot[space] = figure( x_range = variableCell.x_range[variableCell.name][space], tools = "wheel_zoom,reset,box_zoom", toolbar_location = 'right', 65 | plot_width = PLOT_WIDTH, plot_height = PLOT_HEIGHT, sizing_mode = SIZING_MODE) 66 | variableCell.plot[space].border_fill_color = BORDER_COLORS[0] 67 | variableCell.plot[space].xaxis.axis_label = variableCell.name 68 | variableCell.plot[space].yaxis.visible = False 69 | variableCell.plot[space].toolbar.logo = None 70 | variableCell.plot[space].xaxis[0].ticker.desired_num_ticks = 3 71 | 72 | @staticmethod 73 | def initialize_fig_interactive(variableCell, space): 74 | CellDiscreteHandler.initialize_fig(variableCell, space) 75 | ##Events 76 | variableCell.plot[space].on_event(events.Tap, partial(CellDiscreteHandler.clear_selection_callback, variableCell, space)) 77 | variableCell.plot[space].on_event(events.SelectionGeometry, partial(CellDiscreteHandler.selectionbox_callback, variableCell, space)) 78 | ##on_change 79 | variableCell.ic.sample_inds_update[space].on_change('data', partial(variableCell.sample_inds_callback, space)) 80 | 81 | @staticmethod 82 | def initialize_fig_static(variableCell, space): 83 | CellDiscreteHandler.initialize_fig(variableCell, space) 84 | ##on_change 85 | variableCell.ic.sample_inds_update[space].on_change('data', partial(variableCell.sample_inds_callback, space)) 86 | 87 | @staticmethod 88 | def initialize_cds(variableCell, space): 89 | samples = variableCell.get_samples_for_cur_idx_dims_values(variableCell.name, space) 90 | variableCell.source[space] = ColumnDataSource(data = pmf(samples)) 91 | variableCell.samples[space] = ColumnDataSource(data = dict(x = samples)) 92 | # data cds 93 | data = 
variableCell.get_data_for_cur_idx_dims_values(variableCell.name) 94 | if data is not None: 95 | max_v = variableCell.source[space].data['y'].max() 96 | variableCell.data[space] = ColumnDataSource(data = dict(x = data, y = np.asarray([-1*max_v/RUG_DIST_RATIO]*len(data)))) 97 | # initialize sample inds 98 | variableCell.ic.initialize_sample_inds(space, dict(inds = [False]*len(variableCell.samples[space].data['x'])), dict(non_inds = [True]*len(variableCell.samples[space].data['x']))) 99 | 100 | @staticmethod 101 | def initialize_cds_interactive(variableCell, space): 102 | CellDiscreteHandler.initialize_cds(variableCell, space) 103 | variableCell.selection[space] = ColumnDataSource(data = dict(x = np.array([]), y = np.array([]), y0 = np.array([]))) 104 | variableCell.reconstructed[space] = ColumnDataSource(data = dict(x = np.array([]), y = np.array([]), y0 = np.array([]))) 105 | variableCell.clear_selection[space] = ColumnDataSource(data = dict(x = [], y = [], isIn = [])) 106 | variableCell.ic.var_x_range[(space, variableCell.name)] = ColumnDataSource(data = dict(xmin = np.array([]), xmax = np.array([]))) 107 | 108 | @staticmethod 109 | def initialize_cds_static(variableCell, space): 110 | CellDiscreteHandler.initialize_cds(variableCell, space) 111 | #########TEST########### 112 | 113 | @staticmethod 114 | def update_cds_interactive(variableCell, space): 115 | """ 116 | Updates interaction-related ColumnDataSources (cds). 117 | """ 118 | sel_var_idx_dims_values = variableCell.ic.get_sel_var_idx_dims_values() 119 | sel_space = variableCell.ic.get_sel_space() 120 | var_x_range = variableCell.ic.get_var_x_range() 121 | global_update = variableCell.ic.get_global_update() 122 | if global_update: 123 | cur_idx_dims_values = {} 124 | if variableCell.name in variableCell.cur_idx_dims_values: 125 | cur_idx_dims_values = variableCell.cur_idx_dims_values[variableCell.name] 126 | if variableCell.name in sel_var_idx_dims_values and space == sel_space and cur_idx_dims_values == sel_var_idx_dims_values[variableCell.name]: 127 | variableCell.update_selection_cds(space, var_x_range[(space, variableCell.name)].data['xmin'][0], var_x_range[(space, variableCell.name)].data['xmax'][0]) 128 | else: 129 | variableCell.selection[space].data = dict(x = np.array([]), y = np.array([]), y0 = np.array([])) 130 | variableCell.update_reconstructed_cds(space) 131 | CellClearSelection.update_clear_selection_cds(variableCell, space) 132 | 133 | @staticmethod 134 | def update_cds_static(variableCell, space): 135 | """ 136 | Update source & samples cds in the static mode 137 | """ 138 | samples = variableCell.get_samples_for_cur_idx_dims_values(variableCell.name, space) 139 | inds,_ = variableCell.ic.get_sample_inds(space) 140 | if True in inds: 141 | sel_sample = samples[inds] 142 | variableCell.source[space].data = pmf(sel_sample) 143 | variableCell.samples[space].data = dict( x = sel_sample) 144 | else: 145 | variableCell.source[space].data = pmf(samples) 146 | variableCell.samples[space].data = dict( x = samples) 147 | # data cds 148 | data = variableCell.get_data_for_cur_idx_dims_values(variableCell.name) 149 | if data is not None: 150 | max_v = variableCell.get_max_prob(space) 151 | variableCell.data[space] = ColumnDataSource(data = dict(x = data, y = np.asarray([-1*max_v/RUG_DIST_RATIO]*len(data)))) 152 | 153 | ## ONLY FOR INTERACTIVE CASE 154 | @staticmethod 155 | def selectionbox_callback(variableCell, space, event): 156 | """ 157 | Callback called when selection box is drawn. 
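The selected (xmin, xmax) range is registered with the interaction control, and each sample space is then updated on its own thread: its selection cds is redrawn and the intersection of the per-variable sample selections is recomputed.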
158 | """ 159 | xmin = event.geometry['x0'] 160 | xmax = event.geometry['x1'] 161 | cur_idx_dims_values = {} 162 | if variableCell.name in variableCell.cur_idx_dims_values: 163 | cur_idx_dims_values = variableCell.cur_idx_dims_values[variableCell.name] 164 | variableCell.ic.set_selection(variableCell.name, space, (xmin, xmax), cur_idx_dims_values) 165 | for sp in variableCell.spaces: 166 | samples = variableCell.samples[sp].data['x'] 167 | variableCell.ic.add_space_threads(threading.Thread(target = partial(CellDiscreteHandler._selectionbox_space_thread, variableCell, sp, samples, xmin, xmax), daemon = True)) 168 | variableCell.ic.space_threads_join() 169 | 170 | @staticmethod 171 | def _selectionbox_space_thread(variableCell, space, samples, xmin, xmax): 172 | x_range = variableCell.ic.get_var_x_range(space, variableCell.name) 173 | xmin_list = x_range['xmin'] 174 | xmax_list = x_range['xmax'] 175 | if len(xmin_list): 176 | variableCell.update_selection_cds(space, xmin_list[0], xmax_list[0]) 177 | else: 178 | variableCell.selection[space].data = dict(x = np.array([]), y = np.array([]), y0 = np.array([])) 179 | inds = find_indices(samples, lambda e: xmin <= e <= xmax, xmin, xmax) 180 | variableCell.ic.set_sel_var_inds(space, variableCell.name, inds) 181 | variableCell.compute_intersection_of_samples(space) 182 | variableCell.ic.selection_threads_join(space) 183 | 184 | @staticmethod 185 | def update_source_cds_interactive(variableCell, space): 186 | """ 187 | Updates source ColumnDataSource (cds). 188 | """ 189 | samples = variableCell.get_samples_for_cur_idx_dims_values(variableCell.name, space) 190 | variableCell.source[space].data = pmf(samples) 191 | variableCell.samples[space].data = dict( x = samples) 192 | 193 | @staticmethod 194 | def update_selection_cds_interactive(variableCell, space, xmin, xmax): 195 | """ 196 | Updates selection ColumnDataSource (cds). 197 | """ 198 | # Get kde points within [xmin,xmax] 199 | data = {} 200 | data['x'] = np.array([]) 201 | data['y'] = np.array([]) 202 | kde_indices = find_indices(variableCell.source[space].data['x'], lambda e: xmin <= e <= xmax, xmin, xmax) 203 | if len(kde_indices) == 0: 204 | variableCell.selection[space].data = dict(x = np.array([]), y = np.array([]), y0 = np.array([])) 205 | return 206 | data['x'] = variableCell.source[space].data['x'][kde_indices] 207 | data['y'] = variableCell.source[space].data['y'][kde_indices] 208 | data['y0'] = np.asarray(len(data['x'])*[0]) 209 | variableCell.selection[space].data = data 210 | 211 | @staticmethod 212 | def update_reconstructed_cds_interactive(variableCell, space): 213 | """ 214 | Updates reconstructed ColumnDataSource (cds). 
215 | """ 216 | samples = variableCell.samples[space].data['x'] 217 | inds,_ = variableCell.ic.get_sample_inds(space) 218 | # if Tulen(inds): 219 | sel_sample = samples[inds] 220 | variableCell.reconstructed[space].data = pmf(sel_sample) 221 | # data cds 222 | data = variableCell.get_data_for_cur_idx_dims_values(variableCell.name) 223 | if data is not None: 224 | max_v = variableCell.get_max_prob(space) 225 | variableCell.data[space] = ColumnDataSource(data = dict(x = data, y = np.asarray([-1*max_v/RUG_DIST_RATIO]*len(data)))) 226 | # # else: 227 | # # variableCell.reconstructed[space].data = dict(x = np.array([]), y = np.array([]), y0 = np.array([])) 228 | # ##########TEST###################to be deleted 229 | # # max_v = variableCell.get_max_prob(space) 230 | # # if max_v!=-1: 231 | # # variableCell.samples[space].data['y'] = np.asarray([-1*max_v/RUG_DIST_RATIO]*len(variableCell.samples[space].data['x'])) 232 | 233 | @staticmethod 234 | def clear_selection_callback(variableCell, space, event): 235 | """ 236 | Callback called when clear selection glyph is clicked. 237 | """ 238 | isIn = variableCell.clear_selection[space].data['isIn'] 239 | if 1 in isIn: 240 | variableCell.ic.set_var_x_range(space, variableCell.name, dict(xmin = np.array([]), xmax = np.array([]))) 241 | variableCell.ic.delete_sel_var_idx_dims_values(variableCell.name) 242 | for sp in variableCell.spaces: 243 | variableCell.ic.add_space_threads(threading.Thread(target = partial(CellDiscreteHandler._clear_selection_cds_update, variableCell, sp), daemon = True)) 244 | variableCell.ic.space_threads_join() 245 | 246 | @staticmethod 247 | def _clear_selection_cds_update(variableCell, space): 248 | x_range = variableCell.ic.get_var_x_range(space, variableCell.name) 249 | xmin_list = x_range['xmin'] 250 | xmax_list = x_range['xmax'] 251 | if len(xmin_list): 252 | variableCell.update_selection_cds(space, xmin_list[0], xmax_list[0]) 253 | else: 254 | variableCell.selection[space].data = dict(x = np.array([]), y = np.array([]), y0 = np.array([])) 255 | variableCell.ic.delete_sel_var_inds(space, variableCell.name) 256 | variableCell.compute_intersection_of_samples(space) 257 | variableCell.ic.selection_threads_join(space) 258 | -------------------------------------------------------------------------------- /ipme/classes/cell/utils/cell_pred_check_handler.py: -------------------------------------------------------------------------------- 1 | from ipme.utils.constants import COLORS, BORDER_COLORS, PLOT_HEIGHT, PLOT_WIDTH, SIZING_MODE 2 | from ipme.utils.functions import get_finite_samples, get_samples_for_pred_check, get_hist_bins_range 3 | from ipme.utils.stats import hist 4 | 5 | import numpy as np 6 | 7 | from bokeh.plotting import figure 8 | from bokeh.models import ColumnDataSource, HoverTool, Legend 9 | from bokeh import events 10 | 11 | from functools import partial 12 | 13 | class CellPredCheckHandler: 14 | 15 | def __init__(self): 16 | pass 17 | 18 | @staticmethod 19 | def initialize_glyphs_interactive(predcheckCell, space): 20 | q = predcheckCell.plot[space].quad(top='top', bottom='bottom', left='left', right='right', source=predcheckCell.source[space], \ 21 | fill_color=COLORS[0], line_color="white", fill_alpha=1.0, name="full") 22 | seg = predcheckCell.plot[space].segment(x0 ='x0', y0 ='y0', x1='x1', y1='y1', source=predcheckCell.seg[space], \ 23 | color="black", line_width=2, name="seg") 24 | q_sel = predcheckCell.plot[space].quad(top='top', bottom='bottom', left='left', right='right', 
source=predcheckCell.reconstructed[space], \ 25 | fill_color=COLORS[1], line_color="white", fill_alpha=0.7, name="sel") 26 | ## Add Legends 27 | data = predcheckCell.seg[space].data['x0'] 28 | pvalue = predcheckCell.pvalue[space].data["pv"] 29 | if len(data) and len(pvalue): 30 | legend = Legend(items=[ (predcheckCell.func + "(obs) = " + format(data[0], '.2f'), [seg]), 31 | ("p-value = "+format(pvalue[0],'.4f'), [q]), 32 | ], location="top_left") 33 | predcheckCell.plot[space].add_layout(legend, 'above') 34 | ## Add Tooltips for hist 35 | #####TODO:Correct overlap of tooltips##### 36 | TOOLTIPS = [ 37 | ("top", "@top"), 38 | ("right","@right"), 39 | ("left","@left"), 40 | ] 41 | hover = HoverTool( tooltips=TOOLTIPS,renderers=[q,q_sel]) 42 | predcheckCell.plot[space].tools.append(hover) 43 | 44 | @staticmethod 45 | def initialize_glyphs_static(predcheckCell, space): 46 | q = predcheckCell.plot[space].quad(top='top', bottom='bottom', left='left', right='right', source=predcheckCell.source[space], \ 47 | fill_color=COLORS[0], line_color="white", fill_alpha=1.0, name="full") 48 | seg = predcheckCell.plot[space].segment(x0 ='x0', y0 ='y0', x1='x1', y1='y1', source=predcheckCell.seg[space], \ 49 | color="black", line_width=2, name="seg") 50 | ## Add Legends 51 | data = predcheckCell.seg[space].data['x0'] 52 | pvalue = predcheckCell.pvalue[space].data["pv"] 53 | if len(data) and len(pvalue): 54 | legend = Legend(items=[ (predcheckCell.func + "(obs) = " + format(data[0], '.2f'), [seg]), 55 | ("p-value = "+format(pvalue[0],'.4f'), [q]), 56 | ], location="top_left") 57 | predcheckCell.plot[space].add_layout(legend, 'above') 58 | ## Add Tooltips for hist 59 | #####TODO:Correct overlap of tooltips##### 60 | TOOLTIPS = [ 61 | ("top", "@top"), 62 | ("right","@right"), 63 | ("left","@left"), 64 | ] 65 | hover = HoverTool( tooltips=TOOLTIPS,renderers=[q]) 66 | predcheckCell.plot[space].tools.append(hover) 67 | 68 | @staticmethod 69 | def initialize_fig(predcheckCell, space): 70 | predcheckCell.plot[space] = figure(tools = "wheel_zoom,reset", toolbar_location = 'right', plot_width = PLOT_WIDTH, plot_height = PLOT_HEIGHT, sizing_mode = SIZING_MODE) 71 | predcheckCell.plot[space].toolbar.logo = None 72 | predcheckCell.plot[space].yaxis.visible = False 73 | predcheckCell.plot[space].xaxis.axis_label = predcheckCell.func + "(" + predcheckCell.name + ")" 74 | predcheckCell.plot[space].border_fill_color = BORDER_COLORS[0] 75 | predcheckCell.plot[space].xaxis[0].ticker.desired_num_ticks = 3 76 | 77 | @staticmethod 78 | def initialize_fig_interactive(predcheckCell, space): 79 | CellPredCheckHandler.initialize_fig(predcheckCell, space) 80 | predcheckCell.ic.sample_inds[space].on_change('data', partial(predcheckCell.sample_inds_callback, space)) 81 | 82 | @staticmethod 83 | def initialize_fig_static(predcheckCell, space): 84 | CellPredCheckHandler.initialize_fig(predcheckCell, space) 85 | 86 | @staticmethod 87 | def initialize_cds(predcheckCell, space): 88 | ## ColumnDataSource for full sample set 89 | data, samples = predcheckCell.get_samples_for_cur_idx_dims_values(space) 90 | predcheckCell.samples[space] = ColumnDataSource(data=dict(x=samples)) 91 | #data func 92 | if ~np.isfinite(data).all(): 93 | data = get_finite_samples(data) 94 | data_func = get_samples_for_pred_check(data, predcheckCell.func) 95 | #samples func 96 | if ~np.isfinite(samples).all(): 97 | samples = get_finite_samples(samples) 98 | samples_func = get_samples_for_pred_check(samples, predcheckCell.func) 99 | if samples_func.size: 100 | #pvalue 
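(Bayesian predictive p-value: the fraction of predictive test-statistic samples that are greater than or equal to the observed statistic)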
101 | pv = np.count_nonzero(samples_func>=data_func) / len(samples_func) 102 | #histogram 103 | type = predcheckCell._data.get_var_dist_type(predcheckCell.name) 104 | if type == "Continuous": 105 | bins, range = get_hist_bins_range(samples_func, predcheckCell.func, type) 106 | else: 107 | bins, range = get_hist_bins_range(samples_func, predcheckCell.func, type, ref_length = None, ref_values=np.unique(samples.flatten())) 108 | 109 | his, edges = hist(samples_func, bins=bins, range=range, density = True) 110 | #cds 111 | predcheckCell.pvalue[space] = ColumnDataSource(data=dict(pv=[pv])) 112 | predcheckCell.source[space] = ColumnDataSource(data=dict(left=edges[:-1], top=his, right=edges[1:], bottom=np.zeros(len(his)))) 113 | predcheckCell.seg[space] = ColumnDataSource(data=dict(x0=[data_func], x1=[data_func], y0=[0], y1=[his.max() + 0.1 * his.max()])) 114 | else: 115 | predcheckCell.pvalue[space] = ColumnDataSource(data=dict(pv=[])) 116 | predcheckCell.source[space] = ColumnDataSource(data=dict(left=[], top=[], right=[], bottom=[])) 117 | predcheckCell.seg[space] = ColumnDataSource(data=dict(x0=[], x1=[], y0=[], y1=[])) 118 | predcheckCell.ic.initialize_sample_inds(space, dict(inds = [False]*len(predcheckCell.samples[space].data['x'])), dict(non_inds = [True]*len(predcheckCell.samples[space].data['x']))) 119 | 120 | @staticmethod 121 | def initialize_cds_interactive(predcheckCell, space): 122 | CellPredCheckHandler.initialize_cds(predcheckCell, space) 123 | ## ColumnDataSource for restricted sample set 124 | predcheckCell.pvalue_rec[space] = ColumnDataSource(data=dict(pv=[])) 125 | predcheckCell.pvalue_rec[space].on_change('data', partial(predcheckCell.update_legends, space)) 126 | predcheckCell.reconstructed[space] = ColumnDataSource(data=dict(left=[], top=[], right=[], bottom=[])) 127 | 128 | @staticmethod 129 | def initialize_cds_static(predcheckCell, space): 130 | CellPredCheckHandler.initialize_cds(predcheckCell, space) 131 | 132 | @staticmethod 133 | def update_cds_interactive(predcheckCell, space): 134 | """ 135 | Updates interaction-related ColumnDataSources (cds). 
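Only the histogram and p-value of the restricted (selected) sample set are refreshed here; the full-set cds is left as is.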
136 | """ 137 | predcheckCell.update_sel_samples_cds(space) 138 | 139 | @staticmethod 140 | def update_cds_static(predcheckCell, space): 141 | """ 142 | Update samples cds in the static mode 143 | """ 144 | ## ColumnDataSource for full sample set 145 | data, samples = predcheckCell.get_samples_for_cur_idx_dims_values(space) 146 | inds, _ = predcheckCell.ic.get_sample_inds(space) 147 | if True in inds: 148 | samples = samples[inds] 149 | predcheckCell.samples[space].data = dict(x=samples) 150 | #data func 151 | if ~np.isfinite(data).all(): 152 | data = get_finite_samples(data) 153 | data_func = get_samples_for_pred_check(data, predcheckCell.func) 154 | #samples func 155 | if ~np.isfinite(samples).all(): 156 | samples = get_finite_samples(samples) 157 | samples_func = get_samples_for_pred_check(samples, predcheckCell.func) 158 | if samples_func.size: 159 | #pvalue 160 | pv = np.count_nonzero(samples_func>=data_func) / len(samples_func) 161 | #histogram 162 | type = predcheckCell._data.get_var_dist_type(predcheckCell.name) 163 | if type == "Continuous": 164 | bins, range = get_hist_bins_range(samples_func, predcheckCell.func, type) 165 | else: 166 | bins, range = get_hist_bins_range(samples_func, predcheckCell.func, type, ref_length = None, ref_values=np.unique(samples.flatten())) 167 | 168 | his, edges = hist(samples_func, bins=bins, range=range, density = True) 169 | #cds 170 | predcheckCell.pvalue[space].data = dict(pv=[pv]) 171 | predcheckCell.source[space].data = dict(left=edges[:-1], top=his, right=edges[1:], bottom=np.zeros(len(his))) 172 | predcheckCell.seg[space].data = dict(x0=[data_func], x1=[data_func], y0=[0], y1=[his.max() + 0.1 * his.max()]) 173 | else: 174 | predcheckCell.pvalue[space].data = dict(pv=[]) 175 | predcheckCell.source[space].data = dict(left=[], top=[], right=[], bottom=[]) 176 | predcheckCell.seg[space].data = dict(x0=[], x1=[], y0=[], y1=[]) 177 | 178 | ## ONLY FOR INTERACTIVE CASE 179 | @staticmethod 180 | def update_source_cds_interactive(predcheckCell, space): 181 | """ 182 | Updates samples ColumnDataSource (cds). 
183 | """ 184 | ## ColumnDataSource for full sample set 185 | data, samples = predcheckCell.get_samples_for_cur_idx_dims_values(space) 186 | predcheckCell.samples[space].data = dict(x=samples) 187 | #data func 188 | if ~np.isfinite(data).all(): 189 | data = get_finite_samples(data) 190 | data_func = get_samples_for_pred_check(data, predcheckCell.func) 191 | #samples func 192 | if ~np.isfinite(samples).all(): 193 | samples = get_finite_samples(samples) 194 | samples_func = get_samples_for_pred_check(samples, predcheckCell.func) 195 | if samples_func.size: 196 | #pvalue 197 | pv = np.count_nonzero(samples_func>=data_func) / len(samples_func) 198 | #histogram 199 | type = predcheckCell._data.get_var_dist_type(predcheckCell.name) 200 | if type == "Continuous": 201 | bins, range = get_hist_bins_range(samples_func, predcheckCell.func, type) 202 | else: 203 | bins, range = get_hist_bins_range(samples_func, predcheckCell.func, type, ref_length = None, ref_values=np.unique(samples.flatten())) 204 | 205 | his, edges = hist(samples_func, bins=bins, range=range, density = True) 206 | #cds 207 | predcheckCell.pvalue[space].data = dict(pv=[pv]) 208 | predcheckCell.source[space].data = dict(left=edges[:-1], top=his, right=edges[1:], bottom=np.zeros(len(his))) 209 | predcheckCell.seg[space].data = dict(x0=[data_func], x1=[data_func], y0=[0], y1=[his.max() + 0.1 * his.max()]) 210 | else: 211 | predcheckCell.pvalue[space].data = data=dict(pv=[]) 212 | predcheckCell.source[space].data = dict(left=[], top=[], right=[], bottom=[]) 213 | predcheckCell.seg[space].data = dict(x0=[], x1=[], y0=[], y1=[]) 214 | 215 | @staticmethod 216 | def update_sel_samples_cds_interactive(predcheckCell, space): 217 | """ 218 | Updates reconstructed ColumnDataSource (cds). 219 | """ 220 | samples = predcheckCell.samples[space].data['x'] 221 | max_full_hist = predcheckCell.source[space].data['top'].max() 222 | if samples.size: 223 | inds,_ = predcheckCell.ic.get_sample_inds(space) 224 | if True in inds: 225 | sel_sample = samples[inds] 226 | if ~np.isfinite(sel_sample).all(): 227 | sel_sample = get_finite_samples(sel_sample) 228 | sel_sample_func = get_samples_for_pred_check(sel_sample, predcheckCell.func) 229 | #data func 230 | data_func = predcheckCell.seg[space].data['x0'][0] 231 | #pvalue in restricted space 232 | sel_pv = np.count_nonzero(sel_sample_func >= data_func) / len(sel_sample_func) 233 | #compute updated histogram 234 | min_p = predcheckCell.source[space].data['left'][0] 235 | max_p = predcheckCell.source[space].data['right'][-1] 236 | min_c = sel_sample_func.min() 237 | max_c = sel_sample_func.max() 238 | if min_c < min_p or max_c > max_p: 239 | ref_len = predcheckCell.source[space].data['right'][0] - min_p 240 | bins, range = get_hist_bins_range(sel_sample_func, predcheckCell.func, predcheckCell._type, ref_length=ref_len) 241 | else: 242 | range = (min_p,max_p) 243 | bins = len(predcheckCell.source[space].data['right']) 244 | his, edges = hist(sel_sample_func, bins=bins, range=range) 245 | ##max selected hist 246 | max_sel_hist = his.max() 247 | #update reconstructed cds 248 | predcheckCell.pvalue_rec[space].data = dict(pv=[sel_pv]) 249 | predcheckCell.reconstructed[space].data = dict(left=edges[:-1], top=his, right =edges[1:], bottom=np.zeros(len(his))) 250 | predcheckCell.seg[space].data['y1'] = [max_sel_hist + 0.1 * max_sel_hist] 251 | else: 252 | predcheckCell.pvalue_rec[space].data = dict(pv=[]) 253 | predcheckCell.reconstructed[space].data = dict(left=[], top=[], right=[], bottom=[]) 254 | 
predcheckCell.seg[space].data['y1'] = [max_full_hist + 0.1 * max_full_hist] 255 | else: 256 | predcheckCell.pvalue_rec[space].data = dict(pv=[]) 257 | predcheckCell.reconstructed[space].data = dict(left=[], top=[], right=[], bottom=[]) 258 | predcheckCell.seg[space].data['y1'] = [max_full_hist + 0.1 * max_full_hist] 259 | 260 | 261 | -------------------------------------------------------------------------------- /ipme/classes/cell/utils/cell_scatter_handler.py: -------------------------------------------------------------------------------- 1 | from logging import error 2 | from ipme.utils.constants import COLORS, BORDER_COLORS, PLOT_HEIGHT, PLOT_WIDTH, SIZING_MODE 3 | 4 | from bokeh.models import ColumnDataSource, HoverTool 5 | from bokeh.plotting import figure 6 | 7 | import arviz as az 8 | from functools import partial 9 | 10 | class CellScatterHandler: 11 | 12 | def __init__(self): 13 | pass 14 | 15 | @staticmethod 16 | def initialize_glyphs_interactive(scatterCell, space): 17 | so = scatterCell.plot[space].circle(x="x", y="y", source = scatterCell.non_sel_samples[space], size=4, color=COLORS[0], line_color=None, fill_alpha = 0.1) 18 | scatterCell.plot[space].patches(xs="x", ys="y", source = scatterCell.contours[space], line_color="line_color", fill_alpha="fill_alpha") 19 | re = scatterCell.plot[space].circle(x="x", y="y", source = scatterCell.sel_samples[space], size=4, color=COLORS[1], line_color=None, fill_alpha = 0.4, name="re") 20 | ##Tooltips 21 | TOOLTIPS = [("x", "@x"), ("y","@y"),] 22 | hover = HoverTool( tooltips = TOOLTIPS, renderers = [so,re], mode = 'mouse') 23 | scatterCell.plot[space].tools.append(hover) 24 | 25 | @staticmethod 26 | def initialize_glyphs_static(scatterCell, space): 27 | so = scatterCell.plot[space].circle(x="x", y="y", source = scatterCell.samples[space], size=7, color=COLORS[0], line_color=None, fill_alpha = 0.1) 28 | scatterCell.plot[space].patches(xs="x", ys="y", source = scatterCell.contours[space], line_color="line_color", fill_alpha="fill_alpha") 29 | ##Tooltips 30 | TOOLTIPS = [("x", "@x"), ("y","@y"),] 31 | hover = HoverTool( tooltips = TOOLTIPS, renderers = [so], mode = 'mouse') 32 | scatterCell.plot[space].tools.append(hover) 33 | 34 | @staticmethod 35 | def initialize_fig(scatterCell, space): 36 | var1 = scatterCell.vars[0] 37 | var2 = scatterCell.vars[1] 38 | scatterCell.plot[space] = figure(x_range = scatterCell.x_range[var2][space], y_range = scatterCell.x_range[var1][space], tools = "wheel_zoom,reset,box_zoom", toolbar_location = 'right', 39 | plot_width = PLOT_WIDTH, plot_height = PLOT_HEIGHT, sizing_mode = SIZING_MODE)#tools = [], toolbar_location = None 40 | scatterCell.plot[space].border_fill_color = BORDER_COLORS[0] 41 | scatterCell.plot[space].min_border = 15 42 | scatterCell.plot[space].xaxis.axis_label = var2 43 | scatterCell.plot[space].yaxis.axis_label = var1 44 | # scatterCell.plot[space].yaxis.visible = False 45 | scatterCell.plot[space].toolbar.logo = None 46 | scatterCell.plot[space].xaxis[0].ticker.desired_num_ticks = 3 47 | 48 | @staticmethod 49 | def initialize_fig_interactive(scatterCell, space): 50 | CellScatterHandler.initialize_fig(scatterCell, space) 51 | ##on_change 52 | scatterCell.ic.sample_inds_update[space].on_change('data', partial(scatterCell.sample_inds_callback, space)) 53 | 54 | @staticmethod 55 | def initialize_fig_static(scatterCell, space): 56 | CellScatterHandler.initialize_fig(scatterCell, space) 57 | ##on_change 58 | scatterCell.ic.sample_inds_update[space].on_change('data', 
partial(scatterCell.sample_inds_callback, space)) 59 | 60 | @staticmethod 61 | def initialize_cds(scatterCell, space): 62 | var1 = scatterCell.vars[0] 63 | var2 = scatterCell.vars[1] 64 | samples1 = scatterCell.get_samples_for_cur_idx_dims_values(var1, space) 65 | samples2 = scatterCell.get_samples_for_cur_idx_dims_values(var2, space) 66 | scatterCell.samples[space] = ColumnDataSource(data = dict(x = samples2, y = samples1)) 67 | patch_x, patch_y = CellScatterHandler.get_contours(samples2, samples1) 68 | scatterCell.contours[space] = ColumnDataSource(data = dict(x = patch_x, y = patch_y, line_color=["black"]*len(patch_x), fill_alpha=[0]*len(patch_x))) 69 | scatterCell.ic.initialize_sample_inds(space, dict(inds = [False]*len(scatterCell.samples[space].data['x'])), dict(non_inds = [True]*len(scatterCell.samples[space].data['x']))) 70 | 71 | @staticmethod 72 | def initialize_cds_interactive(scatterCell, space): 73 | CellScatterHandler.initialize_cds(scatterCell, space) 74 | inds, non_inds = scatterCell.ic.get_sample_inds(space) 75 | sel_y = scatterCell.samples[space].data['y'][inds] 76 | non_sel_y = scatterCell.samples[space].data['y'][non_inds] 77 | sel_samples = scatterCell.samples[space].data['x'][inds] 78 | non_sel_samples = scatterCell.samples[space].data['x'][non_inds] 79 | scatterCell.sel_samples[space] = ColumnDataSource(data = dict( x = sel_samples, y = sel_y)) 80 | scatterCell.non_sel_samples[space] = ColumnDataSource(data = dict( x = non_sel_samples, y = non_sel_y)) 81 | 82 | @staticmethod 83 | def initialize_cds_static(scatterCell, space): 84 | CellScatterHandler.initialize_cds(scatterCell, space) 85 | #########TEST########### 86 | 87 | @staticmethod 88 | def update_cds_interactive(scatterCell, space): 89 | """ 90 | Updates interaction-related ColumnDataSources (cds). 91 | """ 92 | scatterCell.update_sel_samples_cds(space) 93 | 94 | @staticmethod 95 | def update_cds_static(scatterCell, space): 96 | """ 97 | Update samples cds in the static mode 98 | """ 99 | var1 = scatterCell.vars[0] 100 | var2 = scatterCell.vars[1] 101 | samples1 = scatterCell.get_samples_for_cur_idx_dims_values(var1, space) 102 | samples2 = scatterCell.get_samples_for_cur_idx_dims_values(var2, space) 103 | inds, _ = scatterCell.ic.get_sample_inds(space) 104 | if True in inds: 105 | sel_sample1 = samples1[inds] 106 | sel_sample2 = samples2[inds] 107 | scatterCell.samples[space].data = dict(x=sel_sample2, y=sel_sample1) 108 | patch_x, patch_y = CellScatterHandler.get_contours(sel_sample2, sel_sample1) 109 | scatterCell.contours[space].data = dict(x = patch_x, y = patch_y, line_color=["black"]*len(patch_x), fill_alpha=[0]*len(patch_x)) 110 | else: 111 | scatterCell.samples[space].data = dict(x=samples2, y=samples1) 112 | patch_x, patch_y = CellScatterHandler.get_contours(samples2, samples1) 113 | scatterCell.contours[space].data = dict(x = patch_x, y = patch_y, line_color=["black"]*len(patch_x), fill_alpha=[0]*len(patch_x)) 114 | 115 | ## ONLY FOR INTERACTIVE CASE 116 | @staticmethod 117 | def update_source_cds_interactive(scatterCell, space): 118 | """ 119 | Updates samples ColumnDataSource (cds). 
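Both variables' samples are re-fetched for the current index-dimension values and the KDE contour patches are recomputed.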
120 | """ 121 | var1 = scatterCell.vars[0] 122 | var2 = scatterCell.vars[1] 123 | samples1 = scatterCell.get_samples_for_cur_idx_dims_values(var1, space) 124 | samples2 = scatterCell.get_samples_for_cur_idx_dims_values(var2, space) 125 | scatterCell.samples[space].data = dict(x=samples2, y=samples1) 126 | patch_x, patch_y = CellScatterHandler.get_contours(samples2, samples1) 127 | scatterCell.contours[space].data = dict(x = patch_x, y = patch_y, line_color=["black"]*len(patch_x), fill_alpha=[0]*len(patch_x)) 128 | 129 | @staticmethod 130 | def update_sel_samples_cds_interactive(scatterCell, space): 131 | """ 132 | Updates reconstructed ColumnDataSource (cds). 133 | """ 134 | samples1 = scatterCell.samples[space].data['x'] 135 | samples2 = scatterCell.samples[space].data['y'] 136 | inds, non_inds = scatterCell.ic.get_sample_inds(space) 137 | # update sel samples 138 | sel_sample1 = samples1[inds] 139 | sel_sample2 = samples2[inds] 140 | scatterCell.sel_samples[space].data = dict(x=sel_sample1, y=sel_sample2) 141 | # update transparency 142 | for re in scatterCell.plot[space].renderers: 143 | if re.name == "re" and len(sel_sample1) < 50: 144 | re.glyph.fill_alpha = 0.8 145 | elif re.name == "re" and len(sel_sample1) < 200: 146 | re.glyph.fill_alpha = 0.5 147 | elif re.name == "re": 148 | re.glyph.fill_alpha = 0.2 149 | # update non_sel samples 150 | non_sel_sample1 = samples1[non_inds] 151 | non_sel_sample2 = samples2[non_inds] 152 | scatterCell.non_sel_samples[space].data = dict(x=non_sel_sample1, y=non_sel_sample2) 153 | 154 | # @staticmethod 155 | # def set_transparency(samples1, samples2, sel_samples1, sel_samples2): 156 | # s_x_min = samples1.min() 157 | # s_x_max = samples1.max() 158 | # s_y_min = samples2.min() 159 | # s_y_max = samples2.max() 160 | 161 | # sels_x_min = sel_samples1.min() 162 | # sels_x_max = sel_samples1.max() 163 | # sels_y_min = sel_samples2.min() 164 | # sels_y_max = sel_samples2.max() 165 | 166 | @staticmethod 167 | def get_contours(x, y): 168 | try: 169 | _, contour_glyphs = az.plot_kde(x, y, 170 | # hdi_probs=[0.393, 0.865, 0.989], # 1, 2 and 3 sigma contours 171 | contour_kwargs={"line_color":"black", "line_alpha":1}, 172 | contourf_kwargs={"fill_alpha": 0, "cmap": "viridis"}, 173 | backend="bokeh", return_glyph = True, show = False ) 174 | patch_x = [] 175 | patch_y = [] 176 | for renderer in contour_glyphs: 177 | patch_x.append(renderer.data_source.data['x'].tolist()) 178 | patch_y.append(renderer.data_source.data['y'].tolist()) 179 | return patch_x, patch_y 180 | except ValueError: 181 | return [],[] -------------------------------------------------------------------------------- /ipme/classes/cell/utils/cell_widgets.py: -------------------------------------------------------------------------------- 1 | from ipme.utils.functions import get_dim_names_options 2 | from .global_reset import GlobalReset 3 | 4 | from bokeh.models.widgets import Select, Button 5 | 6 | from functools import partial 7 | import threading 8 | 9 | class CellWidgets: 10 | 11 | def __init__(self): 12 | pass 13 | 14 | @staticmethod 15 | def initialize_widgets(cell): 16 | for space in cell.spaces: 17 | cell.widgets[space] = {} 18 | for var in cell.idx_dims: 19 | for _, d_dim in cell.idx_dims[var].items(): 20 | n1, n2, opt1, opt2 = get_dim_names_options(d_dim) 21 | if n1 not in cell.widgets[space]: 22 | cell.widgets[space][n1] = Select(title = n1, value = opt1[0], options = opt1) 23 | cell.widgets[space][n1].on_change("value", partial(cell.widget_callback, w_title = n1, space = space)) 24 
| if var not in cell.cur_idx_dims_values: 25 | cell.cur_idx_dims_values[var] = {} 26 | if n1 not in cell.cur_idx_dims_values[var]: 27 | inds = [i for i,v in enumerate(d_dim.values) if v == opt1[0]] 28 | cell.cur_idx_dims_values[var][n1] = inds 29 | if n2: 30 | if n2 not in cell.widgets[space]: 31 | cell.widgets[space][n2] = Select(title = n2, value = opt2[0], options=opt2) 32 | cell.widgets[space][n2].on_change("value", partial(cell.widget_callback, w_title = n2, space = space)) 33 | cell.ic.idx_widgets_mapping(space, d_dim, n1, n2) 34 | if n2 not in cell.cur_idx_dims_values[var]: 35 | cell.cur_idx_dims_values[var][n2] = [0] 36 | 37 | # def _widget_callback(self, attr, old, new, w_title, space): 38 | # """ 39 | # Callback called when an indexing dimension is set to 40 | # a new coordinate (e.g through indexing dimensions widgets). 41 | # """ 42 | # if old == new: 43 | # return 44 | # self._ic.add_widget_threads(threading.Thread(target=partial(self._widget_callback_thread, new, w_title, space), daemon=True)) 45 | # self._ic._widget_lock_event.set() 46 | # 47 | # def _widget_callback_thread(self, new, w_title, space): 48 | # inds = -1 49 | # w2_title = "" 50 | # values = [] 51 | # w1_w2_idx_mapping = self._ic._get_w1_w2_idx_mapping() 52 | # w2_w1_val_mapping = self._ic._get_w2_w1_val_mapping() 53 | # w2_w1_idx_mapping = self._ic._get_w2_w1_idx_mapping() 54 | # widgets = self._widgets[space] 55 | # if space in w1_w2_idx_mapping and \ 56 | # w_title in w1_w2_idx_mapping[space]: 57 | # for w2_title in w1_w2_idx_mapping[space][w_title]: 58 | # name = w_title+"_idx_"+w2_title 59 | # if name in self._idx_dims: 60 | # values = self._idx_dims[name].values 61 | # elif w_title in self._idx_dims: 62 | # values = self._idx_dims[w_title].values 63 | # elif space in w2_w1_idx_mapping and \ 64 | # w_title in w2_w1_idx_mapping[space]: 65 | # for w1_idx in w2_w1_idx_mapping[space][w_title]: 66 | # w1_value = widgets[w1_idx].value 67 | # values = w2_w1_val_mapping[space][w_title][w1_value] 68 | # inds = [i for i,v in enumerate(values) if v == new] 69 | # if inds == -1 or len(inds) == 0: 70 | # return 71 | # self._cur_idx_dims_values[w_title] = inds 72 | # if w2_title and w2_title in self._cur_idx_dims_values: 73 | # self._cur_idx_dims_values[w2_title] = [0] 74 | # if self._mode == 'i': 75 | # self._update_source_cds(space) 76 | # self._ic._set_global_update(True) 77 | # self._update_cds_interactive(space) 78 | # elif self._mode == 's': 79 | # self._update_cds_static(space) 80 | 81 | @staticmethod 82 | def _widget_callback_int(variableCell, attr, old, new, w_title, space): 83 | """ 84 | Callback called when an indexing dimension is set to 85 | a new coordinate (e.g through indexing dimensions widgets). 86 | """ 87 | if old == new: 88 | return 89 | variableCell.ic.add_widget_threads(threading.Thread(target = partial(CellWidgets._widget_callback_thread_inter, variableCell, new, w_title, space), daemon = True)) 90 | variableCell.ic.widget_lock_event.set() 91 | 92 | @staticmethod 93 | def _widget_callback_static(variableCell, attr, old, new, w_title, space): 94 | """ 95 | Callback called when an indexing dimension is set to 96 | a new coordinate (e.g through indexing dimensions widgets). 
97 | """ 98 | if old == new: 99 | return 100 | variableCell.ic.add_widget_threads(threading.Thread(target = partial(CellWidgets._widget_callback_thread_static, variableCell, new, w_title, space), daemon = True)) 101 | variableCell.ic.widget_lock_event.set() 102 | 103 | def _widget_callback_thread_inter(variableCell, new, w_title, space): 104 | w1_w2_idx_mapping = variableCell.ic.get_w1_w2_idx_mapping() 105 | w2_w1_val_mapping = variableCell.ic.get_w2_w1_val_mapping() 106 | w2_w1_idx_mapping = variableCell.ic.get_w2_w1_idx_mapping() 107 | widgets = variableCell.widgets[space] 108 | for var in variableCell.idx_dims: 109 | inds = -1 110 | w2_title = "" 111 | values = [] 112 | if space in w1_w2_idx_mapping and w_title in w1_w2_idx_mapping[space]: 113 | for w2_title in w1_w2_idx_mapping[space][w_title]:##review this 114 | name = w_title+"_idx_"+w2_title 115 | if name in variableCell.idx_dims[var]: 116 | values = variableCell.idx_dims[var][name].values 117 | if len(values) == 0 and w_title in variableCell.idx_dims[var]: 118 | values = variableCell.idx_dims[var][w_title].values 119 | if len(values) == 0 and space in w2_w1_idx_mapping and w_title in w2_w1_idx_mapping[space]: 120 | for w1_idx in w2_w1_idx_mapping[space][w_title]: 121 | w1_value = widgets[w1_idx].value 122 | values = w2_w1_val_mapping[space][w_title][w1_value] 123 | inds = [i for i,v in enumerate(values) if v == new] 124 | if inds == -1 or len(inds) == 0: 125 | return 126 | if w_title in variableCell.cur_idx_dims_values[var]: 127 | variableCell.cur_idx_dims_values[var][w_title] = inds 128 | if w2_title and w2_title in variableCell.cur_idx_dims_values[var]: 129 | variableCell.cur_idx_dims_values[var][w2_title] = [0] 130 | variableCell.update_source_cds(space) 131 | variableCell.ic.set_global_update(True) 132 | variableCell.update_cds(space) 133 | 134 | def _widget_callback_thread_static(variableCell, new, w_title, space): 135 | w1_w2_idx_mapping = variableCell.ic.get_w1_w2_idx_mapping() 136 | w2_w1_val_mapping = variableCell.ic.get_w2_w1_val_mapping() 137 | w2_w1_idx_mapping = variableCell.ic.get_w2_w1_idx_mapping() 138 | widgets = variableCell.widgets[space] 139 | for var in variableCell.idx_dims: 140 | inds = -1 141 | w2_title = "" 142 | values = [] 143 | if space in w1_w2_idx_mapping and w_title in w1_w2_idx_mapping[space]: 144 | for w2_title in w1_w2_idx_mapping[space][w_title]: 145 | name = w_title+"_idx_"+w2_title 146 | if name in variableCell.idx_dims[var]: 147 | values = variableCell.idx_dims[var][name].values 148 | if len(values) == 0 and w_title in variableCell.idx_dims[var]: 149 | values = variableCell.idx_dims[var][w_title].values 150 | # elif w_title in variableCell.idx_dims: 151 | # values = variableCell.idx_dims[w_title].values 152 | if len(values) == 0 and space in w2_w1_idx_mapping and w_title in w2_w1_idx_mapping[space]: 153 | # elif space in w2_w1_idx_mapping and w_title in w2_w1_idx_mapping[space]: 154 | for w1_idx in w2_w1_idx_mapping[space][w_title]: 155 | w1_value = widgets[w1_idx].value 156 | values = w2_w1_val_mapping[space][w_title][w1_value] 157 | inds = [i for i,v in enumerate(values) if v == new] 158 | if inds == -1 or len(inds) == 0: 159 | return 160 | if w_title in variableCell.cur_idx_dims_values[var]: 161 | variableCell.cur_idx_dims_values[var][w_title] = inds 162 | if w2_title and w2_title in variableCell.cur_idx_dims_values[var]: 163 | variableCell.cur_idx_dims_values[var][w2_title] = [0] 164 | variableCell.update_cds(space) 165 | 166 | @staticmethod 167 | def 
widget_callback_interactive(variableCell, attr, old, new, w_title, space): 168 | CellWidgets._widget_callback_int(variableCell, attr, old, new, w_title, space) 169 | 170 | @staticmethod 171 | def widget_callback_static(variableCell, attr, old, new, w_title, space): 172 | CellWidgets._widget_callback_static(variableCell, attr, old, new, w_title, space) 173 | 174 | @staticmethod 175 | def link_cells_widgets(grid): 176 | for c_id, cell in grid.cells.items(): 177 | cell_spaces = cell.get_spaces() 178 | for space in cell_spaces: 179 | for w_id, w in cell.get_widgets_in_space(space).items(): 180 | if w_id in grid.cells_widgets: 181 | if space in grid.cells_widgets[w_id]: 182 | grid.cells_widgets[w_id][space].append(c_id) 183 | else: 184 | grid.cells_widgets[w_id][space] = [c_id] 185 | ## Every new widget is linked to the corresponding widget (of same name) 186 | ## of the 1st space in grid.cells_widgets[w_id] 187 | ## Find target cell to link with current cell 188 | f_space = list(grid.cells_widgets[w_id].keys())[0] 189 | CellWidgets._link_widget_to_target(grid, w, w_id, f_space) 190 | else: 191 | grid.cells_widgets[w_id] = {} 192 | grid.cells_widgets[w_id][space] = [c_id] 193 | f_space = list(grid.cells_widgets[w_id].keys())[0] 194 | if f_space != space: 195 | CellWidgets._link_widget_to_target(grid, w, w_id, f_space) 196 | else: 197 | w = grid.cells[c_id].get_widget(space, w_id) 198 | w.on_change('value', partial(grid.ic.menu_item_click_callback, grid, space, w_id)) 199 | 200 | @staticmethod 201 | def _link_widget_to_target(grid, w, w_id, f_space): 202 | if len(grid.cells_widgets[w_id][f_space]): 203 | t_c_id = grid.cells_widgets[w_id][f_space][0] 204 | t_w = grid.cells[t_c_id].get_widget(f_space, w_id) 205 | if t_w is not None and hasattr(t_w,'js_link'): 206 | t_w.js_link('value', w, 'value') 207 | 208 | @staticmethod 209 | def set_plotted_widgets_interactive(grid): 210 | grid.plotted_widgets = {} 211 | for w_id, space_widgets_dict in grid.cells_widgets.items(): 212 | w_spaces = list(space_widgets_dict.keys()) 213 | if len(w_spaces): 214 | f_space = w_spaces[0] 215 | f_w_list = space_widgets_dict[f_space] 216 | if len(f_w_list): 217 | c_id = f_w_list[0] 218 | for space in w_spaces: 219 | if space not in grid.plotted_widgets: 220 | grid.plotted_widgets[space] = {} 221 | grid.plotted_widgets[space][w_id] = grid.cells[c_id].get_widget(f_space, w_id) 222 | b = Button(label='Reset Diagram', button_type="primary") 223 | b.on_click(partial(GlobalReset.global_reset_callback, grid)) 224 | for space in grid.get_grids(): 225 | if space not in grid.plotted_widgets: 226 | grid.plotted_widgets[space] = {} 227 | grid.plotted_widgets[space]["resetButton"] = b 228 | 229 | @staticmethod 230 | def set_plotted_widgets_static(grid): 231 | grid.plotted_widgets = {} 232 | for w_id, space_widgets_dict in grid.cells_widgets.items(): 233 | w_spaces = list(space_widgets_dict.keys()) 234 | if len(w_spaces): 235 | f_space = w_spaces[0] 236 | f_w_list = space_widgets_dict[f_space] 237 | if len(f_w_list): 238 | c_id = f_w_list[0] 239 | for space in w_spaces: 240 | if space not in grid.plotted_widgets: 241 | grid.plotted_widgets[space] = {} 242 | grid.plotted_widgets[space][w_id] = grid.cells[c_id].get_widget(f_space, w_id) 243 | for space in grid.get_grids(): 244 | if space not in grid.plotted_widgets: 245 | grid.plotted_widgets[space] = {} 246 | -------------------------------------------------------------------------------- /ipme/classes/cell/utils/global_reset.py: 
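# A minimal, hedged sketch (standalone, not part of this package) of the widget-linking
# pattern used in CellWidgets._link_widget_to_target above: two Bokeh Select widgets that
# share a coordinate name are kept in sync client-side via js_link, so changing one updates
# the other without a server round trip. The titles and options below are invented.
from bokeh.models.widgets import Select

target = Select(title="school", value="A", options=["A", "B", "C"])
mirror = Select(title="school", value="A", options=["A", "B", "C"])
# one-way link: whenever target.value changes in the browser, mirror.value follows
target.js_link('value', mirror, 'value')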
-------------------------------------------------------------------------------- 1 | import threading 2 | from functools import partial 3 | 4 | class GlobalReset: 5 | 6 | def __init__(self): 7 | pass 8 | 9 | @staticmethod 10 | def global_reset_callback(grid, event): 11 | grid.ic.reset_sel_var_inds() 12 | grid.ic.reset_sel_space() 13 | grid.ic.reset_sel_var_idx_dims_values() 14 | grid.ic.reset_var_x_range() 15 | grid.ic.set_global_update(True) 16 | for sp in grid.get_grids(): 17 | grid.ic.add_space_threads(threading.Thread(target = partial(GlobalReset._global_reset_thread, grid.ic, sp), daemon = True)) 18 | grid.ic.space_threads_join() 19 | grid.ic.set_global_update(False) 20 | 21 | @staticmethod 22 | def _global_reset_thread(ic, space): 23 | ic.reset_sample_inds(space) 24 | ic.selection_threads_join(space) 25 | -------------------------------------------------------------------------------- /ipme/classes/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/evdoxiataka/ipme/f3398596c6af547908f39683eb1830a6bc081482/ipme/classes/data/__init__.py -------------------------------------------------------------------------------- /ipme/classes/data/data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import json 3 | 4 | from ...interfaces.data_interface import Data_Interface 5 | from .dimension import Dimension 6 | 7 | class Data(Data_Interface): 8 | 9 | def __init__(self, inference_path): 10 | """ 11 | Parameters: 12 | -------- 13 | inference_path A String of the inference data file path. 14 | Sets: 15 | -------- 16 | _inferencedata A structure of the inference data. 17 | _graph A structure of the model's variables graph data. 18 | _observed_variables A List of Strings of the observed variables. 19 | _idx_dimensions A List of Dimension Objects defining 20 | the indexing dimensions of the data. 21 | _spaces A List of Strings in {'prior', 'posterior'} of all 22 | the available MCMC sample spaces in the inference data 23 | """ 24 | Data_Interface.__init__(self, inference_path) 25 | 26 | def _load_inference_data(self, datapath): 27 | """ 28 | Inference data is retrieved from memory in .npz format 29 | 30 | Parameters: 31 | -------- 32 | datapath A String of the inference data path. 33 | Returns: 34 | -------- 35 | A Dictionary of inference data 36 | """ 37 | try: 38 | return np.load(datapath) 39 | except IOError: 40 | print("File %s cannot be loaded" % datapath) 41 | return None 42 | 43 | def _get_graph(self): 44 | """ 45 | A graph of the probabilistic model in json format is 46 | retrieved from inference data. 
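# A hedged sketch of the per-node layout that the accessors below appear to assume for the
# graph dictionary (keys are variable names). The variable names, dims and coords here are
# invented for illustration; only the field names are taken from the code in this class.
example_graph = {
    "theta": {
        "name": "theta",
        "type": "free",                                  # read by get_var_type / _get_observed_nodes
        "dims": ["school"],                              # read by _get_idx_dimensions
        "coords": {"school": ["A", "B", "C"]},           # read by _get_idx_dimensions
        "parents": ["mu", "tau"],                        # read by get_var_parents / _add_nodes_to_graph
        "distribution": {"dist": "Normal", "type": "Continuous"},  # read by get_var_dist / get_var_dist_type
    },
    "y": {
        "name": "y",
        "type": "observed",
        "dims": ["school"],
        "coords": {"school": ["A", "B", "C"]},
        "parents": ["theta"],
        "distribution": {"dist": "Normal", "type": "Continuous"},
    },
}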
47 | 48 | Returns: 49 | -------- 50 | A Dictionary of graph data 51 | """ 52 | graph="" 53 | try: 54 | header = self._inferencedata['header.json'] 55 | header_js = json.loads(header) 56 | graph = header_js["inference_data"]["sample_stats"]["attrs"]["graph"] 57 | except KeyError: 58 | print("Inference_data has no key 'header.json'") 59 | return None 60 | return json.loads(graph.replace("'", "\"")) 61 | #return json.loads(graph) 62 | 63 | def _get_obeserved_variables(self): 64 | return [k for k,v in self._graph.items() if v['type'] == 'observed'] 65 | 66 | def _get_idx_dimensions(self): 67 | """ 68 | Returns: 69 | -------- 70 | idx_dimensions A Dict {: List of Dimension objects} 71 | """ 72 | idx_dimensions={} 73 | for k,v in self._graph.items(): 74 | if not k.endswith("__"): 75 | vdims = v["dims"] 76 | vcoords = v["coords"] 77 | if len(vdims): 78 | idx_dimensions[v["name"]] = {} 79 | for d in vdims: 80 | if d in vcoords: 81 | idx_dimensions[v["name"]][d] = Dimension(d, values = vcoords[d]) 82 | return idx_dimensions 83 | 84 | def _get_spaces_from_data(self): 85 | """ 86 | Returns: 87 | -------- 88 | spaces A List of Strings in {'prior','posterior','prior_predictive','posterior_predictive','observed_data', 89 | 'constant_data', 'sample_stats', 'log_likelihood', 'predictions', 'predictions_constant_data'} 90 | """ 91 | spaces=[] 92 | if 'header.json' in self._inferencedata: 93 | header = self._inferencedata['header.json'] 94 | header_js = json.loads(header) 95 | if 'inference_data' in header_js: 96 | for space in ['prior','posterior','prior_predictive','posterior_predictive','observed_data','constant_data', 'sample_stats', 'log_likelihood', 'predictions', 'predictions_constant_data']: 97 | if space in header_js['inference_data']: 98 | spaces.append(space) 99 | return spaces 100 | 101 | def is_observed_variable(self,var_name): 102 | if var_name in self._observed_variables: 103 | return True 104 | else: 105 | return False 106 | 107 | def get_var_type(self, var_name): 108 | """" 109 | Return any in {"observed","free","deterministic"} 110 | """ 111 | return self._graph[var_name]["type"] 112 | 113 | def get_var_dist(self, var_name): 114 | if "dist" in self._graph[var_name]["distribution"]: 115 | return self._graph[var_name]["distribution"]["dist"] 116 | else: 117 | return self._graph[var_name]["type"] 118 | 119 | def get_var_dist_type(self, var_name): 120 | """" 121 | Return any in {"Continuous","Discrete"} 122 | """ 123 | if "type" in self._graph[var_name]["distribution"]: 124 | return self._graph[var_name]["distribution"]["type"] 125 | else: 126 | return "" 127 | 128 | def get_var_parents(self, var_name): 129 | if "parents" in self._graph[var_name]: 130 | return self._graph[var_name]["parents"] 131 | else: 132 | return [] 133 | 134 | def get_samples(self, var_name, space=['prior','posterior']): 135 | """ 136 | Returns the samples of variable of the given space(s). 137 | 138 | Parameters: 139 | -------- 140 | var_name A String of the model's variables name 141 | space Either a List of Strings or a String with String in {'prior','posterior'} 142 | Returns: 143 | -------- 144 | A Dictionary of the form { : } 145 | A String from {"prior", "posterior"} 146 | A numpy.ndarray of samples of the parameter. 147 | e.g. PyMC3 shape=N, sample=M 148 | for i in M 149 | for j in N: 150 | Element (i,j) = (ith sample, jth prior/posterior distribution draw). 151 | If a String of space is given, it returns the numpy.ndarray. 
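# A toy numpy illustration (shapes invented, not package code) of the chain handling in
# get_samples below: when the stored array still carries a leading 'chain' dimension,
# the chains are collapsed by averaging before the draws are returned.
import numpy as np

chains, draws = 4, 500
arr = np.random.randn(chains, draws)   # raw samples laid out as (chain, draw)
collapsed = np.mean(arr, axis=0)       # cf. np.mean(self._inferencedata[array_name], axis=0)
assert collapsed.shape == (draws,)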
152 | 153 | """ 154 | array_name="" 155 | header = self._inferencedata['header.json'] 156 | header_js = json.loads(header) 157 | if isinstance(space, list): 158 | data = {} 159 | for sp in space: 160 | if sp in self._spaces and self._is_var_in_space(var_name,sp): 161 | array_name = header_js['inference_data'][sp]['array_names'][var_name] 162 | if 'chain' in header_js['inference_data'][sp]['vars'][var_name]['dims']: 163 | data[sp] = np.mean(self._inferencedata[array_name],axis=0) 164 | else: 165 | data[sp] = self._inferencedata[array_name] 166 | return data 167 | elif isinstance(space, str): 168 | data = np.asarray([]) 169 | if space in self._spaces and self._is_var_in_space(var_name,space): 170 | array_name = header_js['inference_data'][space]['array_names'][var_name] 171 | if 'chain' in header_js['inference_data'][space]['vars'][var_name]['dims']: 172 | data = np.mean(self._inferencedata[array_name],axis=0) 173 | else: 174 | data = self._inferencedata[array_name] 175 | return data 176 | else: 177 | raise ValueError("space argument of get_sample should be either a List of Strings or a String") 178 | 179 | def get_observations(self, var_name): 180 | """ 181 | Returns the observations of variable. 182 | 183 | Parameters: 184 | -------- 185 | var_name A String of the model's variables name 186 | Returns: 187 | -------- 188 | A numpy.ndarray of observations of the parameter. 189 | """ 190 | array_name="" 191 | header = self._inferencedata['header.json'] 192 | header_js = json.loads(header) 193 | data = None 194 | if var_name in header_js['inference_data']['observed_data']['array_names']: 195 | array_name = header_js['inference_data']['observed_data']['array_names'][var_name] 196 | dims = header_js['inference_data']['observed_data']['vars'][var_name]['dims'] 197 | if 'chain' in dims: 198 | data = np.mean(self._inferencedata[array_name], axis=0) 199 | else: 200 | data = self._inferencedata[array_name] 201 | return data 202 | else: 203 | return data 204 | 205 | def get_range(self, var_name, space=['prior','posterior']): 206 | """ 207 | Returns the range of samples of variable of the given space(s). 208 | 209 | Parameters: 210 | -------- 211 | var_name A String of the model's variables name 212 | space Either a List of Strings or a String with String in {'prior','posterior'} 213 | Returns: 214 | -------- 215 | Tuple (min,max) 216 | 217 | """ 218 | if self.get_var_type(var_name) == "observed": 219 | if space == "posterior" and "posterior_predictive" in self.get_spaces(): 220 | space="posterior_predictive" 221 | elif space == "prior" and "prior_predictive" in self.get_spaces(): 222 | space="prior_predictive" 223 | data = self.get_samples(var_name, space) 224 | min=0 225 | max=0 226 | if data.size: 227 | min = np.amin(data) 228 | max = np.amax(data) 229 | return (min - 0.1*(max-min),max + 0.1*(max-min)) 230 | 231 | def get_varnames_per_graph_level(self, vars): 232 | """ 233 | Matches variable names to graph levels. 234 | 235 | Returns: 236 | -------- 237 | A Dict of the model's parameters names per graph level. 238 | Level 0 corresponds to the higher level. 239 | {: List of variable names}, level=0,1,2... 
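# A small self-contained sketch (toy model, not the package's implementation) of the level
# computation described above: observed nodes seed level 0, parents go one level deeper,
# and the levels are then reversed so that level 0 ends up holding the top-level priors.
def levels_from_observed(parents, observed):
    depth = {}
    def walk(names, d):
        for n in names:
            depth[n] = max(d, depth.get(n, 0))
            walk(parents.get(n, []), d + 1)
    walk(observed, 0)
    top = max(depth.values())
    levels = {}
    for name, d in depth.items():
        levels.setdefault(top - d, []).append(name)   # reverse: level 0 = highest ancestors
    return levels

toy_parents = {"y": ["theta"], "theta": ["mu", "tau"], "mu": [], "tau": []}
print(levels_from_observed(toy_parents, ["y"]))
# -> {2: ['y'], 1: ['theta'], 0: ['mu', 'tau']}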
240 | """ 241 | nodes = self._add_nodes_to_graph(self._get_observed_nodes(), 0) 242 | varnames_per_graph_level = {} 243 | if vars == 'all': 244 | varnames_per_graph_level = self._reverse_nodes_levels(nodes) 245 | else: 246 | vars_per_level = self._reverse_nodes_levels(nodes) 247 | level = 0 248 | for l in sorted(vars_per_level): 249 | for var in vars: 250 | if var in vars_per_level[l] : 251 | if level not in varnames_per_graph_level: 252 | varnames_per_graph_level[level] = [] 253 | varnames_per_graph_level[level].append(var) 254 | level+=1 255 | return varnames_per_graph_level 256 | 257 | def _is_var_in_space(self, var_name, space): 258 | """ 259 | Returns True if variable is in space and False if not. 260 | 261 | Parameters: 262 | -------- 263 | var_name A String of the model's variables name. 264 | space A String in {'prior','posterior','posterior_predictive','prior_predictive'}. 265 | Returns: 266 | -------- 267 | A Boolean 268 | """ 269 | header = self._inferencedata['header.json'] 270 | header_js = json.loads(header) 271 | if var_name in header_js['inference_data'][space]['array_names']: 272 | return True 273 | else: 274 | return False 275 | 276 | def _get_observed_nodes(self): 277 | """ 278 | Get the observed nodes of the graph. 279 | 280 | Parameters: 281 | -------- 282 | graph A Dictionary 283 | Returns: 284 | -------- 285 | A List of the model's observed nodes of the graph (dict objects) 286 | """ 287 | try: 288 | return [v for _,v in self._graph.items() if v['type'] == 'observed'] 289 | except KeyError: 290 | print("Graph has no key 'type'") 291 | return None 292 | 293 | def _get_graph_nodes(self, varnames): 294 | """ 295 | Get the nodes of the graph indicated by a list of varnames. 296 | 297 | Parameters: 298 | -------- 299 | graph A Dictionary 300 | varnames A List of Strings 301 | Returns: 302 | -------- 303 | A List of the model's nodes of the graph (Dictionary objects) 304 | """ 305 | nodes = [] 306 | for vn in varnames: 307 | if vn in self._graph: 308 | nodes.append(self._graph[vn]) 309 | return nodes 310 | 311 | @staticmethod 312 | def _reverse_nodes_levels(nodes): 313 | """ 314 | Reverses the nodes' levels so that level 0 corresponds 315 | to the highest grid row. 316 | 317 | Parameters: 318 | -------- 319 | nodes A Dictionary 320 | Returns: 321 | -------- 322 | A Dictionary of the nodes with reversed keys 323 | """ 324 | max_level = max(nodes.keys()) 325 | return dict((max_level-k, v) for k, v in nodes.items()) 326 | 327 | def _add_nodes_to_graph(self, level_nodes, level): 328 | """ 329 | Adds nodes to graph levels recursively. 330 | 331 | Parameters: 332 | -------- 333 | level_nodes A List of nodes (dict) of the same graph . 334 | level An Int denoting the level of the graph 335 | where belong to. 336 | Returns: 337 | -------- 338 | A Dict of the model's parameters names per graph level. 339 | Level 0 corresponds to the lowest level. 340 | {: List of variables names}, level=0,1,2... 
341 | """ 342 | nodes = {} 343 | try: 344 | for v in level_nodes: 345 | if level in nodes and v['name'] not in nodes[level]: 346 | nodes[level].append(v['name']) 347 | else: 348 | nodes[level]=[v['name']] 349 | 350 | parents_nodes = self._get_graph_nodes(v['parents']) 351 | if(len(parents_nodes)): 352 | # if incl_nodes == 'all': 353 | # par_nodes = self._add_nodes_to_graph(parents_nodes, level+1, 'all') 354 | # else: 355 | # parents_nodes_to_go_ahead = [] 356 | # rest_nodes = [] 357 | # for node in incl_nodes: 358 | # if node in parents_nodes: 359 | # parents_nodes_to_go_ahead.append(node) 360 | # else: 361 | # rest_nodes.append(node) 362 | par_nodes = self._add_nodes_to_graph(parents_nodes, level+1) 363 | for k in par_nodes: 364 | if k in nodes: 365 | for n in par_nodes[k]: 366 | if n not in nodes[k]: 367 | nodes[k].append(n) 368 | else: 369 | nodes[k] = par_nodes[k] 370 | return nodes 371 | except KeyError: 372 | print("Graph node has no key 'name' or 'parents' ") 373 | return None 374 | -------------------------------------------------------------------------------- /ipme/classes/data/dimension.py: -------------------------------------------------------------------------------- 1 | class Dimension: 2 | def __init__(self, name, label='', range=(None,None), step=0, unit='', values=[]): 3 | self.name=name 4 | if label!='' and label is not None: 5 | self.label=label 6 | else: 7 | self.label=name 8 | self.range=range 9 | self.step=step 10 | self.unit=unit 11 | self.values=values -------------------------------------------------------------------------------- /ipme/classes/graph.py: -------------------------------------------------------------------------------- 1 | from .data.data import Data 2 | from .grid.graph_grid import GraphGrid 3 | from .grid.predictive_ckecks_grid import PredictiveChecksGrid 4 | from .interaction_control.interaction_control import IC 5 | 6 | import panel as pn 7 | 8 | class Graph(): 9 | def __init__(self, data_path, mode = "i", vars = 'all', spaces = 'all', predictive_checks = []): 10 | """ 11 | Parameters: 12 | -------- 13 | data_path A String of the zip file with the inference data. 14 | mode A String in {'i','s'}: defines the type of diagram 15 | (interactive or static). 16 | vars A List of variables to be presented in the graph 17 | spaces A List of spaces to be included in graph 18 | predictive_checks A List of observed variables to plot predictive checks. 19 | Sets: 20 | -------- 21 | _mode A String in {"i","s"}, "i":interactive, "s":static. 22 | _graph A Panel component object to visualize model's graoh. 23 | """ 24 | self.ic = IC(Data(data_path)) 25 | if mode not in ["s","i"]: 26 | raise ValueError("ValueError: mode should take a value in {'i','s'}") 27 | self._mode = mode 28 | self._vars = vars 29 | self._spaces = spaces 30 | self._pred_checks = predictive_checks 31 | self._graph_grid = self._create_graph_grid() 32 | self._predictive_checks_grid = self._create_pred_checks_grid() 33 | self._graph = self._create_graph() 34 | 35 | def _create_graph_grid(self): 36 | """ 37 | Creates a GraphGrid object representing the model as a 38 | collection of Panel grids (one per space) and a 39 | collection of plotted widges. 40 | """ 41 | return GraphGrid(self.ic, self._mode, self._vars, self._spaces) 42 | 43 | def _create_pred_checks_grid(self): 44 | """ 45 | Creates a PredictiveChecks object representing the model's 46 | predictive checks for min, max, mean, std of predictions as a 47 | collection of Panel grids (one per space) and a 48 | collection of plotted widges. 
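# A hedged usage sketch of the Graph class defined in this file: the archive name and the
# observed-variable name "y" are invented, and the archive is assumed to follow the exported
# inference-data format read by Data. Mirrors the graph() helper in ipme/methods.py.
from ipme.classes.graph import Graph

g = Graph("inference_data.npz", mode="i", vars="all", spaces="all", predictive_checks=["y"])
g.get_graph().show()   # Panel Tabs: one tab per sample space, plus predictive-check tabs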
49 | """ 50 | return PredictiveChecksGrid(self.ic, self._mode, self._spaces, self._pred_checks) 51 | 52 | def _create_graph(self): 53 | """ 54 | Creates one Tab per space (posterior, prior) presenting the graph. 55 | """ 56 | tabs = pn.Tabs(sizing_mode='stretch_both')#sizing_mode='stretch_both' 57 | ## Tabs for prior-posterior graph 58 | g_grids = self._graph_grid.get_grids() 59 | g_plotted_widgets = self._graph_grid.get_plotted_widgets() 60 | for space in g_grids: 61 | g_col = pn.Column(g_grids[space]) 62 | if space in g_plotted_widgets: 63 | widgetBox = pn.WidgetBox(*list(g_plotted_widgets[space].values()),sizing_mode = 'scale_both') 64 | w_col = pn.Column(widgetBox, width_policy='max', max_width=300, width=250) 65 | tabs.append((space, pn.Row(w_col, g_col)))#, height_policy='max', max_height=800 66 | else: 67 | tabs.append((space, pn.Row(g_col))) 68 | ## Tabs for predictive checks 69 | if self._pred_checks: 70 | pc_grids = self._predictive_checks_grid.get_grids() 71 | for var in pc_grids: 72 | for space in pc_grids[var]: 73 | g_col = pn.Column(pc_grids[var][space]) 74 | tabs.append((var+'_'+space+'_predictive_checks', pn.Row(g_col))) 75 | #tabs.append((space+'_predictive_checks', pn.Row(c.get_plot(space,add_info=False), sizing_mode='stretch_both'))) 76 | return tabs 77 | 78 | def set_coordinates(self, dim, options, value): 79 | self.ic.set_coordinates(self._graph_grid, dim, options, value) 80 | # try: 81 | # if coord_name in self.cells_widgets: 82 | # # space_widgets = self.cells_widgets[coord_name] 83 | # for space in self.cells_widgets[coord_name]: 84 | # c_id_list = self.cells_widgets[coord_name][space] 85 | # for c_id in c_id_list: 86 | # w = self.cells[c_id].get_widget(space, coord_name) 87 | # # old_v = w.value 88 | # w.value = new_val 89 | # w.trigger('value', new_val, new_val) 90 | # except IndexError: 91 | # raise IndexError() 92 | 93 | def get_selection_interactions(self): 94 | return self.ic.get_selection_interactions() 95 | 96 | def get_widgets_interactions(self): 97 | return self.ic.get_widgets_interactions() 98 | 99 | def get_graph(self): 100 | return self._graph 101 | 102 | def get_graph_grid(self): 103 | return self._graph_grid 104 | 105 | def get_pred_checks_grid(self): 106 | return self._create_pred_checks_grid 107 | -------------------------------------------------------------------------------- /ipme/classes/grid/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/evdoxiataka/ipme/f3398596c6af547908f39683eb1830a6bc081482/ipme/classes/grid/__init__.py -------------------------------------------------------------------------------- /ipme/classes/grid/graph_grid.py: -------------------------------------------------------------------------------- 1 | from ...interfaces.grid import Grid 2 | from ipme.classes.cell.interactive_continuous_cell import InteractiveContinuousCell 3 | from ipme.classes.cell.interactive_discrete_cell import InteractiveDiscreteCell 4 | from ipme.classes.cell.static_continuous_cell import StaticContinuousCell 5 | from ipme.classes.cell.static_discrete_cell import StaticDiscreteCell 6 | 7 | from ...utils.constants import MAX_NUM_OF_COLS_PER_ROW, MAX_NUM_OF_VARS_PER_ROW, COLS_PER_VAR 8 | import panel as pn 9 | 10 | class GraphGrid(Grid): 11 | def _create_grids(self): 12 | """ 13 | Creates one Cell object per variable. Cell object is the smallest 14 | visualization unit in the grid. Moreover, it creates one Panel GridSpec 15 | object per space. 
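# A minimal, standalone sketch (toy figure, not package code) of the Panel GridSpec placement
# pattern used by _create_grids below: each cell's plot is assigned to a block of grid rows
# and columns through slice indexing, one block per variable.
import panel as pn
from bokeh.plotting import figure

gs = pn.GridSpec(sizing_mode='stretch_both')
fig = figure(plot_width=220, plot_height=220)
gs[0:1, 4:8] = pn.Column(fig, width=220, height=220)   # cf. start_point / end_point slices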
16 | 17 | Sets: 18 | -------- 19 | _cells A Dict {:Cell object}. 20 | _grids A Dict of pn.GridSpec objects: 21 | {:pn.GridSpec} 22 | """ 23 | graph_grid_map = self._create_graph_grid_mapping() 24 | for row, map_data in graph_grid_map.items(): 25 | level = map_data[0] 26 | vars_list = map_data[1] 27 | level_previous = -1 28 | if (row-1) in graph_grid_map: 29 | level_previous = graph_grid_map[row-1][0] 30 | if level != level_previous: 31 | col = int((MAX_NUM_OF_COLS_PER_ROW - len(vars_list)*COLS_PER_VAR) / 2.) 32 | else: 33 | col = int((MAX_NUM_OF_COLS_PER_ROW - MAX_NUM_OF_VARS_PER_ROW*COLS_PER_VAR) / 2.) 34 | for i,var_name in enumerate(vars_list): 35 | start_point = ( row, int(col + i*COLS_PER_VAR) ) 36 | end_point = ( row+1, int(col + (i+1)*COLS_PER_VAR) ) 37 | #col_l = int(col_f + (i+1)*COLS_PER_VAR) 38 | # grid_bgrd_col = level 39 | if self._mode == "i": 40 | if self._data.get_var_dist_type(var_name) == "Continuous": 41 | c = InteractiveContinuousCell(var_name, self.ic) 42 | else: 43 | c = InteractiveDiscreteCell(var_name, self.ic) 44 | elif self._mode == "s": 45 | if self._data.get_var_dist_type(var_name) == "Continuous": 46 | c = StaticContinuousCell(var_name, self.ic) 47 | else: 48 | c = StaticDiscreteCell(var_name, self.ic) 49 | self.cells[var_name] = c 50 | ##Add to grid 51 | cell_spaces = c.get_spaces() 52 | for space in cell_spaces: 53 | # if space in self._spaces_to_included and space not in self.spaces: 54 | # self.spaces.append(space) 55 | if space in self.spaces or self.spaces == 'all': 56 | if space not in self._grids: 57 | self._grids[space] = pn.GridSpec(sizing_mode = 'stretch_both') 58 | self._grids[space][ start_point[0]:end_point[0], start_point[1]:end_point[1] ] = pn.Column(c.get_plot(space, add_info = True), width=220, height=220) 59 | self.ic.num_cells = len(self.cells) 60 | 61 | def _create_graph_grid_mapping(self): 62 | """ 63 | Maps the graph levels and the variables to Panel GridSpec rows/cols. 64 | Both =0 and =0 correspond to higher row/level. 65 | 66 | Returns: 67 | -------- 68 | A Dict {: (, List of varnames) } 69 | """ 70 | _varnames_per_graph_level = self._data.get_varnames_per_graph_level(self._vars) 71 | graph_grid_map = {} 72 | grid_level = 0 73 | for graph_level in sorted(_varnames_per_graph_level): 74 | num_vars = len(_varnames_per_graph_level[graph_level]) 75 | row = grid_level 76 | indx = 0 77 | while num_vars > MAX_NUM_OF_VARS_PER_ROW: 78 | while row in graph_grid_map: 79 | row+=1 80 | graph_grid_map[row] = (grid_level,_varnames_per_graph_level[graph_level][indx:indx+MAX_NUM_OF_VARS_PER_ROW]) 81 | row += 1 82 | indx += MAX_NUM_OF_VARS_PER_ROW 83 | num_vars -= MAX_NUM_OF_VARS_PER_ROW 84 | while row in graph_grid_map: 85 | row+=1 86 | graph_grid_map[row] = (graph_level,_varnames_per_graph_level[graph_level][indx:indx+num_vars]) 87 | return graph_grid_map 88 | -------------------------------------------------------------------------------- /ipme/classes/grid/predictive_ckecks_grid.py: -------------------------------------------------------------------------------- 1 | from ...interfaces.grid import Grid 2 | from ipme.classes.cell.interactive_pred_ckeck_cell import InteractivePredCheckCell 3 | from ...utils.constants import MAX_NUM_OF_COLS_PER_ROW, COLS_PER_VAR 4 | import panel as pn 5 | 6 | class PredictiveChecksGrid(Grid): 7 | def __init__(self, control, mode, spaces, predictive_ckecks = []): 8 | """ 9 | Parameters: 10 | -------- 11 | data_obj A Data object. 12 | mode A String in {"i","s"}, "i":interactive, "s":static. 
13 | predictive_ckecks A List of observed variables to plot predictive checks. 14 | Sets: 15 | -------- 16 | _data A Data object. 17 | mode A String in {"i","s"}, "i":interactive, "s":static. 18 | _grids A Dict of pn.GridSpec objects: 19 | {:{:pn.GridSpec}} 20 | cells A Dict {:Cell object}, 21 | where pred_check in {'min','max','mean','std'}. 22 | cells_widgets A Dict dict1 of the form (key1,value1) = (, dict2) 23 | dict2 of the form (key1,value1) = (, List of tuples (,) 24 | of the widgets with same name). 25 | plotted_widgets A List of widget objects to be plotted. 26 | """ 27 | self._pred_checks = predictive_ckecks 28 | Grid.__init__(self, control, mode, spaces = spaces) 29 | 30 | def _create_grids(self): 31 | """ 32 | Creates a 2x2 grid of the prior and posterior predictive checks 33 | for min, max, mean and std function. 34 | 35 | Sets: 36 | -------- 37 | _grids A Dict of pn.GridSpec objects: 38 | {:{:pn.GridSpec}} 39 | """ 40 | for var in self._pred_checks: 41 | if self._data.is_observed_variable(var): 42 | # if self._mode == "i": 43 | c_min = InteractivePredCheckCell(var, self.ic, "min") 44 | c_max = InteractivePredCheckCell(var, self.ic, "max") 45 | c_mean = InteractivePredCheckCell(var,self.ic, "mean") 46 | c_std = InteractivePredCheckCell(var, self.ic, "std") 47 | self.cells['min'] = c_min 48 | self.cells['max'] = c_max 49 | self.cells['mean'] = c_mean 50 | self.cells['std'] = c_std 51 | ##Add to grid 52 | cell_spaces = c_min.get_spaces() 53 | self._grids[var] = {} 54 | for space in cell_spaces: 55 | if space in self.spaces or self.spaces == 'all': 56 | if space not in self._grids[var]: 57 | self._grids[var][space] = pn.GridSpec(sizing_mode='stretch_both') 58 | for row in [0,1]: 59 | for i in [0,1]: 60 | col = int((MAX_NUM_OF_COLS_PER_ROW - 2.*COLS_PER_VAR) / 2.) 
61 | start_point = ( row, int(col + i*COLS_PER_VAR) ) 62 | end_point = ( row+1, int(col + (i+1)*COLS_PER_VAR) ) 63 | if row == 0 and i == 0: 64 | self._grids[var][space][ start_point[0]:end_point[0], start_point[1]:end_point[1] ] = \ 65 | pn.Column(c_min.get_plot(space), width=220, height=220) 66 | elif row == 0 and i == 1: 67 | self._grids[var][space][ start_point[0]:end_point[0], start_point[1]:end_point[1] ] = \ 68 | pn.Column(c_max.get_plot(space), width=220, height=220) 69 | elif row == 1 and i == 0: 70 | self._grids[var][space][ start_point[0]:end_point[0], start_point[1]:end_point[1] ] = \ 71 | pn.Column(c_mean.get_plot(space), width=220, height=220) 72 | elif row == 1 and i == 1: 73 | self._grids[var][space][ start_point[0]:end_point[0], start_point[1]:end_point[1] ] = \ 74 | pn.Column(c_std.get_plot(space), width=220, height=220) 75 | else: 76 | raise ValueError("Declared predictive check variable {} is not an observed variable".format(var)) 77 | -------------------------------------------------------------------------------- /ipme/classes/grid/scatter_matrix_grid.py: -------------------------------------------------------------------------------- 1 | from ...interfaces.grid import Grid 2 | from ipme.classes.cell.interactive_scatter_cell import InteractiveScatterCell 3 | from ipme.classes.cell.static_scatter_cell import StaticScatterCell 4 | from ipme.classes.cell.interactive_continuous_cell import InteractiveContinuousCell 5 | from ipme.classes.cell.interactive_discrete_cell import InteractiveDiscreteCell 6 | from ipme.classes.cell.static_continuous_cell import StaticContinuousCell 7 | from ipme.classes.cell.static_discrete_cell import StaticDiscreteCell 8 | 9 | from ...utils.constants import COLS_PER_VAR 10 | import panel as pn 11 | 12 | class ScatterMatrixGrid(Grid): 13 | def _create_grids(self): 14 | """ 15 | Creates one Cell object per variable. Cell object is the smallest 16 | visualization unit in the grid. Moreover, it creates one Panel GridSpec 17 | object per space. 18 | 19 | Sets: 20 | -------- 21 | _cells A Dict {:Cell object}. 
22 | _grids A Dict of pn.GridSpec objects: 23 | {:pn.GridSpec} 24 | """ 25 | for row in range(len(self._vars)): 26 | for col in range(len(self._vars)): 27 | if col > row: 28 | break 29 | start_point = ( row, int(col*COLS_PER_VAR) ) 30 | end_point = ( row+1, int((col+1)*COLS_PER_VAR) ) 31 | var ="" 32 | if col == row: 33 | ##plot VariableCell 34 | var_name = self._vars[col] 35 | if self._mode == "i": 36 | if self._data.get_var_dist_type(var_name) == "Continuous": 37 | c = InteractiveContinuousCell(var_name, self.ic) 38 | else: 39 | c = InteractiveDiscreteCell(var_name, self.ic) 40 | elif self._mode == "s": 41 | if self._data.get_var_dist_type(var_name) == "Continuous": 42 | c = StaticContinuousCell(var_name, self.ic) 43 | else: 44 | c = StaticDiscreteCell(var_name, self.ic) 45 | var = var_name 46 | else: 47 | ##plot pair scatter 48 | var1 = self._vars[row] 49 | var2 = self._vars[col] 50 | if self._mode == "i": 51 | c = InteractiveScatterCell([var1, var2], self.ic) 52 | elif self._mode == "s": 53 | c = StaticScatterCell([var1, var2], self.ic) 54 | var = var1+"_"+var2 55 | self.cells[var] = c 56 | ##Add to grid 57 | cell_spaces = c.get_spaces() 58 | for space in cell_spaces: 59 | # if space not in self.spaces: 60 | # self.spaces.append(space) 61 | if space in self.spaces or self.spaces == 'all': 62 | if space not in self._grids: 63 | self._grids[space] = pn.GridSpec(sizing_mode = 'stretch_both') 64 | self._grids[space][ start_point[0]:end_point[0], start_point[1]:end_point[1] ] = pn.Column(c.get_plot(space), width=220, height=220) 65 | -------------------------------------------------------------------------------- /ipme/classes/interaction_control/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/evdoxiataka/ipme/f3398596c6af547908f39683eb1830a6bc081482/ipme/classes/interaction_control/__init__.py -------------------------------------------------------------------------------- /ipme/classes/scatter_matrix.py: -------------------------------------------------------------------------------- 1 | from .data.data import Data 2 | from .grid.scatter_matrix_grid import ScatterMatrixGrid 3 | from .interaction_control.interaction_control import IC 4 | 5 | import panel as pn 6 | 7 | class ScatterMatrix(): 8 | def __init__(self, data_path, mode = "i", vars = [], spaces = 'all'): 9 | """ 10 | Parameters: 11 | -------- 12 | data_path A String of the zip file with the inference data. 13 | mode A String in {'i','s'}: defines the type of diagram 14 | (interactive or static). 15 | vars A List of model variables to be included in the plot. 16 | spaces A List of spaces to be included in graph 17 | Sets: 18 | -------- 19 | _mode A String in {"i","s"}, "i":interactive, "s":static. 20 | _scatter_matrix A Panel component object to visualize model's scatter matrix. 21 | """ 22 | self.ic = IC(Data(data_path)) 23 | if mode not in ["s","i"]: 24 | raise ValueError("ValueError: mode should take a value in {'i','s'}") 25 | self._mode = mode 26 | self._vars = vars 27 | self._spaces = spaces 28 | self._scatter_matrix_grid = self._create_scatter_matrix_grid() 29 | self._scatter_matrix = self._create_scatter_matrix() 30 | 31 | def _create_scatter_matrix_grid(self): 32 | """ 33 | Creates a ScatterMatrixGrid object representing the model as a 34 | collection of Panel grids (one per space) and a 35 | collection of plotted widges. 
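# A hedged usage sketch of the ScatterMatrix class defined above: the archive name and the
# variable names are invented; the call mirrors the scatter_matrix() helper in ipme/methods.py.
from ipme.classes.scatter_matrix import ScatterMatrix

sm = ScatterMatrix("inference_data.npz", mode="i", vars=["mu", "tau"], spaces="all")
sm.get_scatter_matrix().show()   # one tab per sample space (prior / posterior)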
36 | """ 37 | return ScatterMatrixGrid(self.ic, self._mode, self._vars, self._spaces) 38 | 39 | def _create_scatter_matrix(self): 40 | """ 41 | Creates one Tab per space (posterior, prior) presenting the scatter matrix. 42 | """ 43 | tabs = pn.Tabs(sizing_mode='stretch_both')#sizing_mode='stretch_both' 44 | ## Tabs for prior-posterior scatter matrix 45 | g_grids = self._scatter_matrix_grid.get_grids() 46 | g_plotted_widgets = self._scatter_matrix_grid.get_plotted_widgets() 47 | for space in g_grids: 48 | g_col = pn.Column(g_grids[space]) 49 | if space in g_plotted_widgets: 50 | widgetBox = pn.WidgetBox(*list(g_plotted_widgets[space].values()),sizing_mode = 'scale_both') 51 | w_col = pn.Column(widgetBox, width_policy='max', max_width=300, width=250) 52 | tabs.append((space, pn.Row(w_col, g_col)))#, height_policy='max', max_height=800 53 | else: 54 | tabs.append((space, pn.Row(g_col))) 55 | return tabs 56 | 57 | def set_coordinates(self, dim, options, value): 58 | self.ic.set_coordinates(self._scatter_matrix_grid, dim, options, value) 59 | 60 | def set_coordinate(self, dim, value): 61 | self.ic.set_coordinate(self._scatter_matrix_grid, dim, value) 62 | 63 | def get_selection_interactions(self): 64 | return self.ic.get_selection_interactions() 65 | 66 | def get_widgets_interactions(self): 67 | return self.ic.get_widgets_interactions() 68 | 69 | def get_selection_ranges(self): 70 | """ 71 | Returns List of tuples (xmin,xmax) of selection box 72 | """ 73 | return self.ic.get_selection_ranges() 74 | 75 | def get_scatter_matrix(self): 76 | return self._scatter_matrix 77 | 78 | def get_scatter_matrix_grid(self): 79 | return self._scatter_matrix_grid 80 | 81 | # def get_pred_checks_grid(self): 82 | # return self._create_pred_checks_grid 83 | -------------------------------------------------------------------------------- /ipme/cli.py: -------------------------------------------------------------------------------- 1 | """Console script for imd.""" 2 | import sys 3 | import click 4 | 5 | 6 | @click.command() 7 | def main(args=None): 8 | """Console script for imd.""" 9 | click.echo("Replace this message by putting your code into " 10 | "imd.cli.main") 11 | click.echo("See click documentation at https://click.palletsprojects.com/") 12 | return 0 13 | 14 | 15 | if __name__ == "__main__": 16 | sys.exit(main()) # pragma: no cover 17 | -------------------------------------------------------------------------------- /ipme/interfaces/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/evdoxiataka/ipme/f3398596c6af547908f39683eb1830a6bc081482/ipme/interfaces/__init__.py -------------------------------------------------------------------------------- /ipme/interfaces/cell.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | from ipme.classes.cell.utils.cell_widgets import CellWidgets 4 | from bokeh.io.export import get_screenshot_as_png 5 | 6 | class Cell(ABC): 7 | def __init__(self, vars, control): 8 | """ 9 | Each cell will occupy a certain number of grid columns and will lie on a certain grid row. 10 | Parameters: 11 | -------- 12 | vars A List of variableNames of the model. 13 | control An IC object 14 | Sets: 15 | -------- 16 | vars 17 | ic 18 | spaces A List of Strings in {"prior","posterior"}. 19 | idx_dims A Dict {:{:Dimension obj}}. 20 | cur_idx_dims_values A Dict {:{: Current value of }}. 21 | 22 | plot A Dict {: (bokeh) plot object}. 
23 | widgets A Dict {: {: A (bokeh) widget object} }. 24 | """ 25 | self.vars = vars 26 | self.ic = control 27 | self._data = control.data 28 | self.spaces = self._define_spaces() 29 | 30 | #idx_dims-related variables 31 | self.idx_dims = self._data.get_idx_dimensions(self.vars) 32 | self.cur_idx_dims_values = {} 33 | 34 | self.plot = {} 35 | self.widgets = {} 36 | self._initialize_widgets() 37 | self._initialize_plot() 38 | 39 | def _define_spaces(self): 40 | data_spaces = self._data.get_spaces() 41 | spaces = [] 42 | if "prior" in data_spaces: 43 | spaces.append("prior") 44 | if "posterior" in data_spaces: 45 | spaces.append("posterior") 46 | return spaces 47 | 48 | def _initialize_widgets(self): 49 | CellWidgets.initialize_widgets(self) 50 | 51 | @abstractmethod 52 | def widget_callback(self, attr, old, new, w_title, space): 53 | pass 54 | 55 | @abstractmethod 56 | def _initialize_plot(self): 57 | pass 58 | 59 | ## GETTERS 60 | def get_widgets(self): 61 | return self.widgets 62 | 63 | def get_widgets_in_space(self, space): 64 | if space in self.widgets: 65 | return self.widgets[space] 66 | else: 67 | return [] 68 | 69 | def get_widget(self, space, id): 70 | try: 71 | return self.widgets[space][id] 72 | except IndexError: 73 | return None 74 | 75 | def get_plot(self, space): 76 | if space in self.plot: 77 | return self.plot[space] 78 | else: 79 | return None 80 | 81 | def get_screenshot(self, space): 82 | if space in self.plot: 83 | return get_screenshot_as_png(self.plot[space]) 84 | else: 85 | return None 86 | 87 | def get_spaces(self): 88 | return self.spaces 89 | -------------------------------------------------------------------------------- /ipme/interfaces/data_interface.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | class Data_Interface(ABC): 4 | 5 | def __init__(self, inference_path): 6 | """ 7 | Parameters: 8 | -------- 9 | inference_path A String of the inference data file path. 10 | Sets: 11 | -------- 12 | _inferencedata A structure of the inference data. 13 | _graph A structure of the model's variables graph data. 14 | _observed_variables A List of Strings of the observed variables. 15 | _idx_dimensions A Dict {:{:Dimension obj}}. 
16 | _spaces A List of Strings in {'prior', 'posterior'} of all 17 | the available MCMC sample spaces in the inference data 18 | """ 19 | self._inferencedata = self._load_inference_data(inference_path) 20 | self._graph = self._get_graph() 21 | self._observed_variables = self._get_obeserved_variables() 22 | self.all_variables = [] 23 | self._idx_dimensions = self._get_idx_dimensions() 24 | self._spaces = self._get_spaces_from_data() 25 | 26 | @abstractmethod 27 | def _load_inference_data(self,datapath): 28 | pass 29 | 30 | @abstractmethod 31 | def _get_graph(self): 32 | pass 33 | 34 | @abstractmethod 35 | def _get_obeserved_variables(self): 36 | pass 37 | 38 | @abstractmethod 39 | def _get_idx_dimensions(self): 40 | pass 41 | 42 | @abstractmethod 43 | def _get_spaces_from_data(self): 44 | pass 45 | 46 | @abstractmethod 47 | def is_observed_variable(self,var_name): 48 | pass 49 | 50 | @abstractmethod 51 | def get_var_type(self,var_name): 52 | pass 53 | 54 | @abstractmethod 55 | def get_var_dist(self,var_name): 56 | pass 57 | 58 | @abstractmethod 59 | def get_var_dist_type(self,var_name): 60 | pass 61 | 62 | @abstractmethod 63 | def get_var_parents(self,var_name): 64 | pass 65 | 66 | @abstractmethod 67 | def get_samples(self,var_name,space=['prior','posterior']): 68 | pass 69 | 70 | @abstractmethod 71 | def get_range(self, var_name, space=['prior','posterior']): 72 | pass 73 | 74 | @abstractmethod 75 | def get_varnames_per_graph_level(self): 76 | pass 77 | 78 | def get_data(self): 79 | return self._data 80 | 81 | def get_obeserved_variables(self): 82 | return self._observed_variables 83 | 84 | def get_inferencedata(self): 85 | return self._inferencedata 86 | 87 | def get_idx_dimensions(self, var_names): 88 | idx_dims = {} 89 | for var in var_names: 90 | if var in self._idx_dimensions: 91 | idx_dims[var] = {} 92 | for dim in self._idx_dimensions[var]: 93 | idx_dims[var][dim] = self._idx_dimensions[var][dim] 94 | return idx_dims 95 | 96 | def get_indx_for_idx_dim(self, var_name, d_name, d_value): 97 | indx=-1 98 | if var_name in self._idx_dimensions and d_name in self._idx_dimensions[var_name]: 99 | dvalues = self._idx_dimensions[var_name][d_name].values 100 | if d_value in dvalues: 101 | indx = dvalues.index(d_value) 102 | return indx 103 | 104 | def get_spaces(self): 105 | return self._spaces 106 | 107 | -------------------------------------------------------------------------------- /ipme/interfaces/grid.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from ..classes.cell.utils.cell_widgets import CellWidgets 3 | 4 | class Grid(ABC): 5 | def __init__(self, control, mode, vars = 'all', spaces = 'all'): 6 | """ 7 | Parameters: 8 | -------- 9 | control A IC object. 10 | mode A String in {"i","s"}, "i":interactive, "s":static. 11 | vars A List of variables to be presented in the graph 12 | spaces A List of spaces to be included in graph 13 | Sets: 14 | -------- 15 | ic 16 | _data 17 | _mode 18 | _grids A Dict of pn.GridSpec objects 19 | Either {:{:pn.GridSpec}} 20 | Or {:pn.GridSpec} 21 | cells A Dict either {:Cell object} or {:Cell object}, 22 | where pred_check in {'min','max','mean','std'}. 23 | spaces A List of available spaces in this grid. Elements in {'prior','posterior'} 24 | cells_widgets A Dict dict1 of the form (key1,value1) = (, dict2) 25 | dict2 of the form (key1,value1) = (, List of ). 26 | plotted_widgets A Dict of the form {: List of widget objects to be plotted} . 
27 | """ 28 | self.ic = control 29 | self._data = control.data 30 | self._mode = mode 31 | self._vars = vars 32 | # self._spaces_to_included = spaces 33 | self._grids = {} 34 | self.cells = {} 35 | self.spaces = spaces 36 | self._create_grids() 37 | 38 | self.cells_widgets = {} 39 | self.plotted_widgets = {} 40 | self._add_widgets() 41 | 42 | @abstractmethod 43 | def _create_grids(self): 44 | pass 45 | 46 | def _add_widgets(self): 47 | CellWidgets.link_cells_widgets(self) 48 | if self._mode == "i": 49 | CellWidgets.set_plotted_widgets_interactive(self) 50 | else: 51 | CellWidgets.set_plotted_widgets_static(self) 52 | 53 | ## GETTERS 54 | def get_grids(self): 55 | return self._grids 56 | 57 | def get_plotted_widgets(self): 58 | return self.plotted_widgets 59 | 60 | def get_cells(self): 61 | return self.cells 62 | 63 | -------------------------------------------------------------------------------- /ipme/interfaces/predictive_check_cell.py: -------------------------------------------------------------------------------- 1 | from ..interfaces.cell import Cell 2 | 3 | import threading 4 | from abc import abstractmethod 5 | 6 | class PredictiveCheckCell(Cell): 7 | def __init__(self, name, control): 8 | """ 9 | Parameters: 10 | -------- 11 | name A String within the set {""}. 12 | control An InteractiveControl object 13 | Sets: 14 | -------- 15 | _func 16 | _source 17 | _reconstructed 18 | _samples 19 | _seg 20 | 21 | """ 22 | self.name = name 23 | self.source = {} 24 | self.reconstructed = {} 25 | self.samples = {} 26 | self.seg = {} 27 | self.pvalue = {} 28 | self.pvalue_rec = {} 29 | Cell.__init__(self, [name], control) 30 | 31 | ## DATA 32 | def get_samples_for_cur_idx_dims_values(self, space): 33 | """ 34 | Returns the observed data and predictive samples of the observed variable 35 | in space . 36 | 37 | Returns: 38 | -------- 39 | A Tuple (data,samples): data-> observed data and samples-> predictive samples. 40 | """ 41 | data = self.ic.data.get_samples(self.name, 'observed_data') 42 | if self.ic.data.get_var_type(self.name) == "observed": 43 | if space == "posterior" and "posterior_predictive" in self.ic.data.get_spaces(): 44 | space="posterior_predictive" 45 | elif space == "prior" and "prior_predictive" in self.ic.data.get_spaces(): 46 | space="prior_predictive" 47 | samples = self.ic.data.get_samples(self.name, space) 48 | return data, samples 49 | 50 | ## INITIALIZATIONS 51 | def _initialize_plot(self): 52 | for space in self.spaces: 53 | self.initialize_cds(space) 54 | self.initialize_fig(space) 55 | self.initialize_glyphs(space) 56 | 57 | @abstractmethod 58 | def initialize_fig(self, space): 59 | pass 60 | 61 | @abstractmethod 62 | def initialize_cds(self, space): 63 | pass 64 | 65 | @abstractmethod 66 | def initialize_glyphs(self, space): 67 | pass 68 | 69 | ## WIDGETS 70 | @abstractmethod 71 | def widget_callback(self, attr, old, new, w_title, space): 72 | pass 73 | 74 | ## UPDATE 75 | @abstractmethod 76 | def update_cds(self, space): 77 | pass 78 | 79 | def sample_inds_callback(self, space, attr, old, new): 80 | """ 81 | Updates cds when indices of selected samples -- Cell._sample_inds-- 82 | are updated. 
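# A generic, standalone sketch (not package code) of the pattern used by sample_inds_callback
# below: the Bokeh callback only queues the cds update on a daemon thread and sets an event;
# an interaction-control loop is assumed to wait on the event and run the queued threads, so
# the callback itself returns immediately.
import threading

pending, wake = [], threading.Event()

def on_selection_change(space):
    t = threading.Thread(target=lambda: print("update cds for", space), daemon=True)
    pending.append(t)   # cf. ic.add_selection_threads(space, ...)
    wake.set()          # cf. ic.sel_lock_event.set()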
83 | """ 84 | self.ic.add_selection_threads(space, threading.Thread(target = self._sample_inds_thread, args = (space,), daemon = True)) 85 | self.ic.sel_lock_event.set() 86 | 87 | def _sample_inds_thread(self, space): 88 | self.update_cds(space) 89 | -------------------------------------------------------------------------------- /ipme/interfaces/scatter_cell.py: -------------------------------------------------------------------------------- 1 | from ..interfaces.cell import Cell 2 | from ..utils.stats import find_x_range 3 | 4 | import numpy as np 5 | import threading 6 | from abc import abstractmethod 7 | 8 | class ScatterCell(Cell): 9 | def __init__(self, vars, control): 10 | """ 11 | Parameters: 12 | -------- 13 | vars A List of variableNames of the model. 14 | control A Control object 15 | Sets: 16 | ----- 17 | x_range Figures axes x_range 18 | """ 19 | self.samples = {} 20 | self.contours = {} 21 | self._all_samples = {} 22 | self.x_range = {} 23 | Cell.__init__(self, vars, control) 24 | 25 | ## DATA 26 | def get_samples(self, space): 27 | """ 28 | Retrieves MCMC samples of into a numpy.ndarray and 29 | sets an entry into self._all_samples Dict. 30 | """ 31 | for var in self.vars: 32 | space_gsam = space 33 | if self._data.get_var_type(var) == "observed": 34 | if space == "posterior" and "posterior_predictive" in self._data.get_spaces(): 35 | space_gsam = "posterior_predictive" 36 | elif space == "prior" and "prior_predictive" in self._data.get_spaces(): 37 | space_gsam = "prior_predictive" 38 | if var not in self._all_samples: 39 | self._all_samples[var] = {} 40 | self._all_samples[var][space] = self._data.get_samples(var, space_gsam).T 41 | # compute x_range 42 | self.x_range[var] = {} 43 | self.x_range[var][space] = find_x_range(self._all_samples[var][space]) 44 | # self.x_range[var][space] = find_x_range(self.get_samples_for_cur_idx_dims_values(var, space)) 45 | 46 | def get_samples_for_cur_idx_dims_values(self, var_name, space): 47 | """ 48 | Returns a numpy.ndarray of the MCMC samples of the 49 | parameter for current index dimensions values. 50 | 51 | Returns: 52 | -------- 53 | A numpy.ndarray. 54 | """ 55 | if var_name in self._all_samples: 56 | data = self._all_samples[var_name] 57 | if space in data: 58 | data = data[space] 59 | else: 60 | raise ValueError("cel {}-{}: space {} not in self._all_samples[{}].keys() {}".format(self.vars[0],self.vars[1],space,var_name,data.keys())) 61 | else: 62 | raise ValueError("var_name {} not in self._all_samples.keys() {}".format(var_name, self._all_samples.keys())) 63 | if var_name in self.cur_idx_dims_values: 64 | for _, dim_value in self.cur_idx_dims_values[var_name].items(): 65 | data = data[dim_value] 66 | return np.squeeze(data).T 67 | 68 | ## INITIALIZATION 69 | def _initialize_plot(self): 70 | for space in self.spaces: 71 | self.get_samples(space) 72 | self.initialize_cds(space) 73 | self.initialize_fig(space) 74 | self.initialize_glyphs(space) 75 | 76 | @abstractmethod 77 | def initialize_fig(self, space): 78 | pass 79 | 80 | @abstractmethod 81 | def initialize_cds(self, space): 82 | pass 83 | 84 | @abstractmethod 85 | def initialize_glyphs(self, space): 86 | pass 87 | 88 | ## WIDGETS 89 | @abstractmethod 90 | def widget_callback(self, attr, old, new, w_title, space): 91 | pass 92 | 93 | ## UPDATE 94 | @abstractmethod 95 | def update_cds(self, space): 96 | pass 97 | 98 | def sample_inds_callback(self, space, attr, old, new): 99 | """ 100 | Updates cds when indices of selected samples -- Cell._sample_inds-- 101 | are updated. 
102 | """ 103 | self.ic.add_selection_threads(space, threading.Thread(target = self._sample_inds_thread, args = (space,), daemon = True)) 104 | self.ic.sel_lock_event.set() 105 | 106 | def _sample_inds_thread(self, space): 107 | self.update_cds(space) 108 | 109 | -------------------------------------------------------------------------------- /ipme/interfaces/variable_cell.py: -------------------------------------------------------------------------------- 1 | from ..interfaces.cell import Cell 2 | from ..utils.stats import find_x_range 3 | from ..utils.constants import BORDER_COLORS 4 | 5 | from bokeh.models import Toggle, Div 6 | from bokeh.layouts import layout 7 | from bokeh.io.export import get_screenshot_as_png 8 | 9 | import numpy as np 10 | import threading 11 | from abc import abstractmethod 12 | 13 | class VariableCell(Cell): 14 | def __init__(self, name, control): 15 | """ 16 | Parameters: 17 | -------- 18 | name A String within the set {""}. 19 | control A Control object 20 | Sets: 21 | ----- 22 | x_range Figures axes x_range 23 | _toggle A Dict {: (bokeh) toggle button for visibility of figure}. 24 | _div A Dict {: (bokeh) div parameter-related information}. 25 | """ 26 | self.name = name 27 | self.source = {} 28 | self.samples = {} 29 | self.data = {} 30 | self._all_samples = {} 31 | self._all_data = {} 32 | self.x_range = {} 33 | Cell.__init__(self, [name], control) 34 | self._toggle = {} 35 | self._div = {} 36 | self._initialize_toggle_div() 37 | 38 | ## DATA 39 | def get_samples(self, space): 40 | """ 41 | Retrieves MCMC samples of into a numpy.ndarray and 42 | sets an entry into self._all_samples Dict. 43 | """ 44 | for var in self.vars: 45 | space_gsam = space 46 | if self._data.get_var_type(var) == "observed": 47 | if space == "posterior" and "posterior_predictive" in self._data.get_spaces(): 48 | space_gsam = "posterior_predictive" 49 | elif space == "prior" and "prior_predictive" in self._data.get_spaces(): 50 | space_gsam = "prior_predictive" 51 | if var not in self._all_samples: 52 | self._all_samples[var] = {} 53 | self._all_samples[var][space] = self._data.get_samples(var, space_gsam).T 54 | # get observed data 55 | data = self._data.get_observations(var) 56 | if data is not None: 57 | self._all_data[var] = data 58 | # compute x_range 59 | self.x_range[var] = {} 60 | self.x_range[var][space] = find_x_range(self._all_samples[var][space]) 61 | # self.x_range[var][space] = find_x_range(self.get_samples_for_cur_idx_dims_values(var, space)) 62 | 63 | def get_samples_for_cur_idx_dims_values(self, var_name, space): 64 | """ 65 | Returns a numpy.ndarray of the MCMC samples of the 66 | parameter for current index dimensions values. 67 | 68 | Returns: 69 | -------- 70 | A numpy.ndarray. 
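# A toy numpy illustration (shapes invented) of what the method below does: samples are stored
# transposed as (coordinate dims..., draws), the rows picked by the current coordinate indices
# are selected, and the result is squeezed back into a draws-first array for plotting.
import numpy as np

draws, n_schools = 1000, 3
samples_T = np.random.randn(n_schools, draws)   # cf. self._all_samples[var][space] = get_samples(...).T
cur_idx = [1]                                   # cf. cur_idx_dims_values[var_name][<dim>]
selected = np.squeeze(samples_T[cur_idx]).T
assert selected.shape == (draws,)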
71 | """ 72 | if var_name in self._all_samples: 73 | data = self._all_samples[var_name] 74 | if space in data: 75 | data = data[space] 76 | else: 77 | raise ValueError("cel {}-{}: space {} not in self._all_samples[{}].keys() {}".format(self.vars[0],self.vars[1],space,var_name,data.keys())) 78 | else: 79 | raise ValueError("var_name {} not in self._all_samples.keys() {}".format(var_name, self._all_samples.keys())) 80 | if var_name in self.cur_idx_dims_values: 81 | for _, dim_value in self.cur_idx_dims_values[var_name].items(): 82 | data = data[dim_value] 83 | if data.shape == (1,)*len(data.shape): 84 | return data.flatten() 85 | else: 86 | return np.squeeze(data).T 87 | 88 | def get_data_for_cur_idx_dims_values(self, var_name): 89 | """ 90 | Returns a numpy.ndarray of the observations of the 91 | parameter for current index dimensions values. 92 | 93 | Returns: 94 | -------- 95 | A numpy.ndarray. 96 | """ 97 | if var_name in self._all_data: 98 | data = self._all_data[var_name] 99 | else: 100 | return None 101 | if var_name in self.cur_idx_dims_values: 102 | for _, dim_value in self.cur_idx_dims_values[var_name].items(): 103 | data = data[dim_value] 104 | if data.shape == (1,)*len(data.shape): 105 | return data.flatten() 106 | else: 107 | return np.squeeze(data).T 108 | 109 | ## INITIALIZATION 110 | def _initialize_plot(self): 111 | for space in self.spaces: 112 | self.get_samples(space) 113 | self.initialize_cds(space) 114 | self.initialize_fig(space) 115 | self.initialize_glyphs(space) 116 | 117 | @abstractmethod 118 | def initialize_fig(self, space): 119 | pass 120 | 121 | @abstractmethod 122 | def initialize_cds(self, space): 123 | pass 124 | 125 | @abstractmethod 126 | def initialize_glyphs(self, space): 127 | pass 128 | 129 | ## WIDGETS 130 | @abstractmethod 131 | def widget_callback(self, attr, old, new, w_title, space): 132 | pass 133 | 134 | ## UPDATE 135 | @abstractmethod 136 | def update_cds(self, space): 137 | pass 138 | 139 | def sample_inds_callback(self, space, attr, old, new): 140 | """ 141 | Updates cds when indices of selected samples -- Cell._sample_inds-- 142 | are updated. 143 | """ 144 | self.ic.add_selection_threads(space, threading.Thread(target = self._sample_inds_thread, args = (space,), daemon = True)) 145 | self.ic.sel_lock_event.set() 146 | 147 | def _sample_inds_thread(self, space): 148 | self.update_cds(space) 149 | 150 | def compute_intersection_of_samples(self, space): 151 | """ 152 | Computes intersection of sample points based on user's 153 | restrictions per parameter. 154 | """ 155 | sel_var_inds = self.ic.get_sel_var_inds(space = space) 156 | sp_keys = list(sel_var_inds.keys()) 157 | inds_list = np.full((len(self.samples[space].data['x']),), False) 158 | if len(sp_keys)>1: 159 | sets = [] 160 | for var in sp_keys: 161 | sets.append(set(sel_var_inds[var])) 162 | inds_set = set.intersection(*sorted(sets, key = len)) 163 | inds_list[list(inds_set)] = True 164 | elif len(sp_keys) == 1: 165 | inds_list[sel_var_inds[sp_keys[0]]] = True 166 | non_inds_list = list(~inds_list) 167 | self.ic.set_sample_inds(space, dict(inds = list(inds_list)), dict(non_inds = non_inds_list)) 168 | 169 | def _initialize_toggle_div(self): 170 | """" 171 | Creates the toggle headers of each variable node. 172 | """ 173 | for space in self.spaces: 174 | width = self.plot[space].plot_width 175 | height = 40 176 | sizing_mode = self.plot[space].sizing_mode 177 | label = self.name + " ~ " + self._data.get_var_dist(self.name) 178 | text = """parents: %s
dims: %s"""%(self._data.get_var_parents(self.name), list(self._data.get_idx_dimensions(self.name))) 179 | if sizing_mode == 'fixed': 180 | self._toggle[space] = Toggle(label = label, active = False, 181 | width = width, height = height, sizing_mode = sizing_mode, margin = (0,0,0,0)) 182 | self._div[space] = Div(text = text, 183 | width = width, height = height, sizing_mode = sizing_mode, margin = (0,0,0,0), background = BORDER_COLORS[0] ) 184 | elif sizing_mode == 'scale_width' or sizing_mode == 'stretch_width': 185 | self._toggle[space] = Toggle(label = label, active = False, 186 | height = height, sizing_mode = sizing_mode, margin = (0,0,0,0)) 187 | self._div[space] = Div(text = text, 188 | height = height, sizing_mode = sizing_mode, margin = (0,0,0,0), background = BORDER_COLORS[0] ) 189 | elif sizing_mode == 'scale_height' or sizing_mode == 'stretch_height': 190 | self._toggle[space] = Toggle(label = label, active = False, 191 | width = width, sizing_mode = sizing_mode, margin = (0,0,0,0)) 192 | self._div[space] = Div(text = text, 193 | width = width, sizing_mode = sizing_mode, margin = (0,0,0,0), background = BORDER_COLORS[0] ) 194 | else: 195 | self._toggle[space] = Toggle(label = label, active = False, 196 | sizing_mode = sizing_mode, margin = (0,0,0,0)) 197 | self._div[space] = Div(text = text, sizing_mode = sizing_mode, margin = (0,0,0,0), background = BORDER_COLORS[0] ) 198 | self._toggle[space].js_link('active', self.plot[space], 'visible') 199 | 200 | 201 | def get_max_prob(self, space): 202 | """ 203 | Gets highest point --max probability-- of cds 204 | """ 205 | max_sv = -1 206 | max_rv = -1 207 | if self.source[space].data['y'].size: 208 | max_sv = self.source[space].data['y'].max() 209 | if hasattr(self,'reconstructed') and self.reconstructed[space].data['y'].size: 210 | max_rv = self.reconstructed[space].data['y'].max() 211 | max_v = max([max_sv,max_rv]) 212 | return max_v if max_v!=-1 else None 213 | 214 | def get_plot(self, space, add_info = False): 215 | if space in self.plot: 216 | if add_info and space in self._toggle and space in self._div: 217 | return layout([self._toggle[space]], [self._div[space]], [self.plot[space]]) 218 | else: 219 | return self.plot[space] 220 | else: 221 | return None 222 | 223 | def get_screenshot(self, space, add_info=False): 224 | if space in self.plot: 225 | if add_info and space in self._toggle and space in self._div: 226 | return get_screenshot_as_png(layout([self._toggle[space]], [self._div[space]], [self.plot[space]])) 227 | else: 228 | return get_screenshot_as_png(self.plot[space]) 229 | else: 230 | return None 231 | -------------------------------------------------------------------------------- /ipme/methods.py: -------------------------------------------------------------------------------- 1 | from .classes.graph import Graph 2 | from .classes.scatter_matrix import ScatterMatrix 3 | 4 | def graph(data_path, mode = "i", vars = 'all', spaces = 'all', predictive_checks = []): 5 | graph = Graph(data_path, mode, vars, spaces, predictive_checks) 6 | graph.get_graph().show() 7 | 8 | def scatter_matrix(data_path, mode = "i", vars = [], spaces = 'all'): 9 | scatter_matrix = ScatterMatrix(data_path, mode, vars, spaces) 10 | scatter_matrix.get_scatter_matrix().show() -------------------------------------------------------------------------------- /ipme/utils/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = """Evdoxia Taka""" 2 | __email__ = 'e.taka.1@research.gla.ac.uk' 3 | 
__version__ = '0.1.0' -------------------------------------------------------------------------------- /ipme/utils/constants.py: -------------------------------------------------------------------------------- 1 | """ Cell Interface 2 | 3 | Colors-related constants: 4 | 5 | Same color palette used in arviz-darkgrid style: 6 | https://arviz-devs.github.io/arviz/examples/matplotlib/mpl_styles.html 7 | 8 | Color-blind friendly cycle designed using https://colorcyclepicker.mpetroff.net/ 9 | """ 10 | # colors 11 | COLORS = ['#2a2eec', '#fa7c17', '#328c06', '#c10c90', '#933708', '#65e5f3', '#e6e135', '#1ccd6a', '#bd8ad5', '#b16b57'] 12 | BORDER_COLORS=['#d8d8d8','#FFFFFF'] 13 | 14 | ## Plot sizing-related constants 15 | PLOT_WIDTH = 220 16 | PLOT_HEIGHT = 220 17 | SIZING_MODE = "fixed" 18 | 19 | ## Rug plot 20 | RUG_DIST_RATIO = 5.0 21 | RUG_SIZE = 10 #in screen units 22 | 23 | ## Data plot 24 | DATA_DIST_RATIO = 3.0 25 | DATA_SIZE = 6 #in screen units 26 | 27 | """ Grid Interface 28 | 29 | """ 30 | ## 31 | MAX_NUM_OF_COLS_PER_ROW = 12 32 | COLS_PER_VAR = 2 33 | MAX_NUM_OF_VARS_PER_ROW = 5 -------------------------------------------------------------------------------- /ipme/utils/functions.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from math import gcd, ceil 3 | 4 | def lcm(list_of_int): 5 | """ 6 | Get the Least Common Multiple (LCM) of a list of integer numbers. 7 | 8 | Parameters: 9 | -------- 10 | list_of_int A List of integers 11 | Returns: 12 | -------- 13 | An integer (the lcm of list_of_int). 14 | """ 15 | try: 16 | lcm = list_of_int[0] 17 | for i in list_of_int[1:]: 18 | lcm = lcm*i/gcd(int(lcm), i) 19 | return int(lcm) 20 | except IndexError: 21 | return None 22 | 23 | def find_indices(lst, condition, xmin=0, xmax=0): 24 | # return [i for i, elem in enumerate(lst) if condition(elem)] 25 | # return [ _ for _ in itertools.compress(list(range(0,len(lst))), map(condition,lst)) ] 26 | return list(np.where((lst>=xmin) & (lst<=xmax))[0]) 27 | 28 | def find_inds_before_after(lst, el): 29 | inds_sm=find_indices(lst, lambda e: e<= el) 30 | if len(inds_sm): 31 | ind_before=inds_sm[-1] 32 | else: 33 | ind_before=-1 34 | inds_bi=find_indices(lst, lambda e: e>= el) 35 | if len(inds_bi): 36 | ind_after=inds_bi[0] 37 | else: 38 | ind_after=-1 39 | return (ind_before,ind_after) 40 | 41 | def find_highest_point( x, y): 42 | x=np.asarray(x, dtype=np.float64) 43 | y=np.asarray(y, dtype=np.float64) 44 | if len(y): 45 | max_idx=np.argmax(y) 46 | return (x[max_idx],y[max_idx]) 47 | else: 48 | return () 49 | 50 | def get_samples_for_pred_check(samples, func): 51 | # samples = np.asarray(samples) 52 | shape = samples.shape 53 | if 0 not in shape: 54 | if len(shape) == 1: 55 | axis = 0 56 | else: 57 | axis = 1 58 | if len(shape) > 2: 59 | fir_dim = shape[0] 60 | sec_dim = 1 61 | for i in np.arange(1,len(shape),1): 62 | sec_dim = sec_dim*shape[i] 63 | samples=samples.reshape(fir_dim,sec_dim) 64 | if func == "min": 65 | samples = samples.min(axis=axis) 66 | elif func == "max": 67 | samples = samples.max(axis=axis) 68 | elif func == "mean": 69 | samples = samples.mean(axis=axis) 70 | elif func == "std": 71 | samples = samples.std(axis=axis) 72 | else: 73 | samples = np.empty([1, 2]) 74 | if ~np.isfinite(samples).all(): 75 | samples = get_finite_samples(samples) 76 | else: 77 | samples = np.empty([1, 2]) 78 | return samples 79 | 80 | def get_finite_samples(np_array): 81 | if isinstance(np_array, np.ndarray): 82 | shape = len(np_array.shape) 83 | 
if shape == 1: 84 | np_array = np_array[np.isfinite(np_array)] 85 | elif shape > 1: 86 | samples_idx = np.isfinite(np_array).all(axis=shape-1) 87 | for axis in np.arange(shape-2,0,-1): 88 | samples_idx = np.isfinite(np_array).all(axis=axis) 89 | np_array = np_array[samples_idx] 90 | return np_array 91 | 92 | def get_hist_bins_range(samples, func, var_type, ref_length = None, ref_values=None): 93 | """ 94 | Parameters: 95 | -------- 96 | samples Flattened finite samples 97 | func Predictive check criterion {'min','max','mean','std'} 98 | var_type Variable type in {'Discrete','Continuous'} 99 | ref_length A reference bin length used to estimate the number of bins 100 | ref_values A numpy.ndarray with the unique values of a Discrete variable 101 | """ 102 | if (func == 'min' or func == 'max') and var_type == "Discrete": 103 | if ref_values is not None: 104 | if len(ref_values)<20: 105 | min_v = ref_values.min() 106 | max_v = ref_values.max() 107 | bins = len(ref_values) 108 | if bins > 1: 109 | range = ( min_v, max_v + (max_v - min_v) / (bins - 1)) 110 | else: 111 | range = ( min_v, min_v+1) 112 | return (bins, range) 113 | else: 114 | values = np.unique(samples) 115 | if len(values) < 20: 116 | min_v = values.min() 117 | max_v = values.max() 118 | bins = len(values) 119 | if bins > 1: 120 | range = ( min_v, max_v + (max_v - min_v) / (bins - 1)) 121 | else: 122 | range = ( min_v, min_v+1) 123 | return (bins, range) 124 | range = (samples.min(),samples.max()) 125 | if ref_length: 126 | bins = ceil((range[1] - range[0]) / ref_length) 127 | range = (range[0], range[0] + bins*ref_length) 128 | else: 129 | bins = 20 130 | return (bins, range) 131 | 132 | def get_dim_names_options(dim): 133 | """ 134 | dim: imd.Dimension object 135 | """ 136 | name1 = dim.name 137 | name2 = None 138 | options1 = dim.values 139 | options2 = [] 140 | if "_idx_" in name1: 141 | idx = name1.find("_idx_") 142 | st_n1 = idx + 5 143 | end_n1 = len(name1) 144 | name2 = name1[st_n1:end_n1] 145 | name1 = name1[0:idx] 146 | values = np.array(dim.values) 147 | options1 = np.unique(values).tolist() 148 | if len(options1): 149 | tmp = np.arange(np.count_nonzero(values == options1[0])) 150 | options2 = list(map(str,tmp)) 151 | return (name1, name2, options1, options2) 152 | 153 | def get_w2_w1_val_mapping(dim): 154 | """ 155 | dim: imd.Dimension object 156 | Returns: 157 | ------- 158 | A Dict {<w1 value>: A List of <w2 values> for this <w1 value>} 159 | """ 160 | options1 = dim.values 161 | values = np.array(dim.values) 162 | options1 = np.unique(values) 163 | val_dict = {} 164 | if len(options1): 165 | for v1 in options1: 166 | tmp = np.arange(np.count_nonzero(values == v1)) 167 | val_dict[v1] = list(map(str,tmp)) 168 | return val_dict 169 | 170 | def get_stratum_range(samples, stratum): 171 | median = np.median(samples) 172 | if stratum == 0 or stratum == 1: 173 | inds_l = np.where(samples < median)[0] 174 | median_l = np.median(samples[inds_l]) 175 | if stratum == 0: 176 | xmin = np.min(samples).item() 177 | xmax = median_l 178 | else: 179 | xmin = median_l 180 | xmax = median 181 | elif stratum == 2 or stratum == 3: 182 | inds_h = np.where(samples >= median)[0] 183 | median_h = np.median(samples[inds_h]) 184 | if stratum == 2: 185 | xmin = median 186 | xmax = median_h 187 | elif stratum == 3: 188 | xmin = median_h 189 | xmax = np.max(samples).item() 190 | else: 191 | xmin = np.min(samples).item() 192 | xmax = np.max(samples).item() 193 | return (xmin,xmax)
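Note (not part of the package): the two predictive-check helpers above are meant to work together -- get_samples_for_pred_check collapses a draws x observations array to one summary statistic per draw, and get_hist_bins_range then picks histogram bins for that statistic. A small usage sketch with toy data, assuming ipme is importable from ipme.utils.functions:

import numpy as np
from ipme.utils.functions import get_samples_for_pred_check, get_hist_bins_range

pred_samples = np.random.normal(size=(500, 100))         # toy posterior-predictive draws: 500 draws x 100 observations
stat = get_samples_for_pred_check(pred_samples, "mean")  # one 'mean' value per draw -> shape (500,)
bins, hist_range = get_hist_bins_range(stat, "mean", "Continuous")
print(stat.shape, bins, hist_range)                      # e.g. (500,) 20 (stat.min(), stat.max())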
-------------------------------------------------------------------------------- /ipme/utils/js_code.py: -------------------------------------------------------------------------------- 1 | HOVER_CODE=""" 2 | const data = {'x': [], 'y': [], 'isIn': []} 3 | data['x']=source.data.x 4 | data['y']=source.data.y 5 | for (var i = 0; i 1: 113 | new_shape = tuple([-1] + list(x_shape)) 114 | y = y.reshape(new_shape) 115 | 116 | hpd_ = az.hpd(y, credible_interval=credible_interval, circular=False, multimodal=False) 117 | 118 | if smooth: 119 | x_data = np.linspace(x.min(), x.max(), 200) 120 | x_data[0] = (x_data[0] + x_data[1]) / 2 121 | hpd_interp = griddata(x, hpd_, x_data) 122 | y_data = savgol_filter(hpd_interp, axis=0, window_length=55, polyorder=2) 123 | return (np.concatenate((x_data, x_data[::-1])),np.concatenate((y_data[:, 0], y_data[:, 1][::-1]))) 124 | else: 125 | return (np.concatenate((x, x[::-1])),np.concatenate((hpd_[:,0], hpd_[:, 1][::-1]))) 126 | 127 | def hpd(y, credible_interval=0.94): 128 | y = np.asarray(y) 129 | y_shape = y.shape 130 | if 0 in y_shape: 131 | return (np.array([]),np.array([])) 132 | hpd_ = az.hpd(y, credible_interval=credible_interval, circular=False, multimodal=False) 133 | if hpd_.ndim == 1: 134 | hpd_ = np.expand_dims(hpd_, axis=0) 135 | return (hpd_[:,0],hpd_[:,1]) -------------------------------------------------------------------------------- /requirements_dev.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/evdoxiataka/ipme/f3398596c6af547908f39683eb1830a6bc081482/requirements_dev.txt -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [bumpversion] 2 | current_version = 0.1.0 3 | commit = True 4 | tag = True 5 | 6 | [bumpversion:file:setup.py] 7 | search = version='{current_version}' 8 | replace = version='{new_version}' 9 | 10 | [bumpversion:file:ipme/__init__.py] 11 | search = __version__ = '{current_version}' 12 | replace = __version__ = '{new_version}' 13 | 14 | [bdist_wheel] 15 | universal = 1 16 | 17 | [flake8] 18 | exclude = docs 19 | 20 | [aliases] 21 | # Define setup.py command aliases here 22 | 23 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """The setup script.""" 4 | 5 | from setuptools import setup, find_packages, find_namespace_packages 6 | 7 | # with open('README.md') as readme_file: 8 | # readme = readme_file.read() 9 | 10 | with open('HISTORY.rst') as history_file: 11 | history = history_file.read() 12 | 13 | requirements = ['Click>=7.0', ] 14 | 15 | setup_requirements = [ ] 16 | 17 | test_requirements = [ ] 18 | 19 | setup( 20 | author="Evdoxia Taka", 21 | author_email='e.taka.1@research.gla.ac.uk', 22 | python_requires='>=3.5', 23 | classifiers=[ 24 | 'Development Status :: 2 - Pre-Alpha', 25 | 'Intended Audience :: Developers', 26 | 'License :: OSI Approved :: MIT License', 27 | 'Natural Language :: English', 28 | 'Programming Language :: Python :: 3', 29 | 'Programming Language :: Python :: 3.5', 30 | 'Programming Language :: Python :: 3.6', 31 | 'Programming Language :: Python :: 3.7', 32 | 'Programming Language :: Python :: 3.8', 33 | ], 34 | description="Interactive probabilistic models explorer is an interactive tool for visualizing and exploring Bayesian probabilistic programming models and inference data.", 35 | entry_points={ 36 | 'console_scripts': [ 37 | 'ipme=ipme.cli:main', 38 | ], 39 | }, 40 | install_requires=requirements, 41 | license="MIT license", 42 | # long_description=readme + '\n\n' + history, 43 | include_package_data=True, 44 | keywords='Bayesian probabilistic modeling, Bayesian inference, Markov Chain Monte Carlo, 
interactive visualization, uncertainty visualization, interpretability', 45 | name='ipme', 46 | packages=find_packages(include=['ipme','ipme*']), 47 | setup_requires=setup_requirements, 48 | test_suite='tests', 49 | tests_require=test_requirements, 50 | url='https://github.com/evdoxiataka/ipme', 51 | version='0.1.0', 52 | zip_safe=False, 53 | ) 54 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | envlist = py35, py36, py37, py38, flake8 3 | 4 | [travis] 5 | python = 6 | 3.8: py38 7 | 3.7: py37 8 | 3.6: py36 9 | 3.5: py35 10 | 11 | [testenv:flake8] 12 | basepython = python 13 | deps = flake8 14 | commands = flake8 ipme tests 15 | 16 | [testenv] 17 | setenv = 18 | PYTHONPATH = {toxinidir} 19 | 20 | commands = python setup.py test 21 | --------------------------------------------------------------------------------
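Note (not part of the repository): a hedged usage sketch of the public entry points defined in ipme/methods.py, in the spirit of the examples/*/ipme.py driver scripts. The inference-data path and the variable names below are hypothetical placeholders; the exact file format expected by the Data layer is not shown in this excerpt.

from ipme.methods import graph, scatter_matrix

data_path = "path/to/inference_data"             # hypothetical placeholder for an exported inference-data file
graph(data_path, mode="i")                       # interactive graph of all variables in all available spaces
scatter_matrix(data_path, vars=["mu", "sigma"])  # pairwise scatter plots; "mu"/"sigma" are placeholder names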