├── .editorconfig
├── .gitignore
├── .idea
│   ├── .gitignore
│   ├── ipme.iml
│   ├── modules.xml
│   └── vcs.xml
├── .travis.yml
├── AUTHORS.rst
├── CONTRIBUTING.rst
├── HISTORY.rst
├── LICENSE
├── MANIFEST.in
├── Makefile
├── README.md
├── examples
│   ├── coal_mining_disasters
│   │   ├── ipme.py
│   │   └── model.py
│   ├── data
│   │   └── evaluation_sleepstudy.csv
│   ├── drivers_reaction_times
│   │   ├── ipme.py
│   │   ├── model_hierarchical.py
│   │   └── model_pooled.py
│   ├── eight_schools_problem
│   │   ├── ipme.py
│   │   ├── model_centered.py
│   │   └── model_non_centered.py
│   ├── golf_putting
│   │   ├── ipme.py
│   │   ├── model_geometry.py
│   │   └── model_simple.py
│   ├── radon_basement
│   │   ├── ipme.py
│   │   └── model.py
│   ├── stochastic_volatility
│   │   ├── ipme.py
│   │   └── model.py
│   └── user_study
│       ├── min_temperature
│       │   ├── ipme.py
│       │   └── model.py
│       ├── random_number_generator
│       │   ├── ipme.py
│       │   └── model.py
│       └── reaction_times
│           ├── ipme.py
│           └── model.py
├── ipme
│   ├── __init__.py
│   ├── classes
│   │   ├── __init__.py
│   │   ├── cell
│   │   │   ├── __init__.py
│   │   │   ├── interactive_continuous_cell.py
│   │   │   ├── interactive_discrete_cell.py
│   │   │   ├── interactive_pred_ckeck_cell.py
│   │   │   ├── interactive_scatter_cell.py
│   │   │   ├── static_continuous_cell.py
│   │   │   ├── static_discrete_cell.py
│   │   │   ├── static_pred_check_cell.py
│   │   │   ├── static_scatter_cell.py
│   │   │   └── utils
│   │   │       ├── __init__.py
│   │   │       ├── cell_clear_selection.py
│   │   │       ├── cell_continuous_handler.py
│   │   │       ├── cell_discrete_handler.py
│   │   │       ├── cell_pred_check_handler.py
│   │   │       ├── cell_scatter_handler.py
│   │   │       ├── cell_widgets.py
│   │   │       └── global_reset.py
│   │   ├── data
│   │   │   ├── __init__.py
│   │   │   ├── data.py
│   │   │   └── dimension.py
│   │   ├── graph.py
│   │   ├── grid
│   │   │   ├── __init__.py
│   │   │   ├── graph_grid.py
│   │   │   ├── predictive_ckecks_grid.py
│   │   │   └── scatter_matrix_grid.py
│   │   ├── interaction_control
│   │   │   ├── __init__.py
│   │   │   └── interaction_control.py
│   │   └── scatter_matrix.py
│   ├── cli.py
│   ├── interfaces
│   │   ├── __init__.py
│   │   ├── cell.py
│   │   ├── data_interface.py
│   │   ├── grid.py
│   │   ├── predictive_check_cell.py
│   │   ├── scatter_cell.py
│   │   └── variable_cell.py
│   ├── methods.py
│   └── utils
│       ├── __init__.py
│       ├── constants.py
│       ├── functions.py
│       ├── js_code.py
│       └── stats.py
├── requirements_dev.txt
├── setup.cfg
├── setup.py
└── tox.ini
/.editorconfig:
--------------------------------------------------------------------------------
1 | # http://editorconfig.org
2 |
3 | root = true
4 |
5 | [*]
6 | indent_style = space
7 | indent_size = 4
8 | trim_trailing_whitespace = true
9 | insert_final_newline = true
10 | charset = utf-8
11 | end_of_line = lf
12 |
13 | [*.bat]
14 | indent_style = tab
15 | end_of_line = crlf
16 |
17 | [LICENSE]
18 | insert_final_newline = false
19 |
20 | [Makefile]
21 | indent_style = tab
22 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 |
58 | # Flask stuff:
59 | instance/
60 | .webassets-cache
61 |
62 | # Scrapy stuff:
63 | .scrapy
64 |
65 | # Sphinx documentation
66 | docs/_build/
67 |
68 | # PyBuilder
69 | target/
70 |
71 | # Jupyter Notebook
72 | .ipynb_checkpoints
73 |
74 | # pyenv
75 | .python-version
76 |
77 | # celery beat schedule file
78 | celerybeat-schedule
79 |
80 | # SageMath parsed files
81 | *.sage.py
82 |
83 | # dotenv
84 | .env
85 |
86 | # virtualenv
87 | .venv
88 | venv/
89 | ENV/
90 |
91 | # Spyder project settings
92 | .spyderproject
93 | .spyproject
94 |
95 | # Rope project settings
96 | .ropeproject
97 |
98 | # mkdocs documentation
99 | /site
100 |
101 | # mypy
102 | .mypy_cache/
103 |
104 | # IDE settings
105 | .vscode/
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /workspace.xml
--------------------------------------------------------------------------------
/.idea/ipme.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | # Config file for automatic testing at travis-ci.com
2 |
3 | language: python
4 | python:
5 | - 3.8
6 | - 3.7
7 | - 3.6
8 | - 3.5
9 |
10 | # Command to install dependencies, e.g. pip install -r requirements.txt --use-mirrors
11 | install: pip install -U tox-travis
12 |
13 | # Command to run tests, e.g. python setup.py test
14 | script: tox
15 |
16 |
17 |
--------------------------------------------------------------------------------
/AUTHORS.rst:
--------------------------------------------------------------------------------
1 | =======
2 | Credits
3 | =======
4 |
5 | Development Lead
6 | ----------------
7 |
8 | * Evdoxia Taka
9 |
10 | Contributors
11 | ------------
12 |
13 | None yet. Why not be the first?
14 |
--------------------------------------------------------------------------------
/CONTRIBUTING.rst:
--------------------------------------------------------------------------------
1 | .. highlight:: shell
2 |
3 | ============
4 | Contributing
5 | ============
6 |
7 | Contributions are welcome, and they are greatly appreciated! Every little bit
8 | helps, and credit will always be given.
9 |
10 | You can contribute in many ways:
11 |
12 | Types of Contributions
13 | ----------------------
14 |
15 | Report Bugs
16 | ~~~~~~~~~~~
17 |
18 | Report bugs at https://github.com/evdoxiataka/ipme/issues.
19 |
20 | If you are reporting a bug, please include:
21 |
22 | * Your operating system name and version.
23 | * Any details about your local setup that might be helpful in troubleshooting.
24 | * Detailed steps to reproduce the bug.
25 |
26 | Fix Bugs
27 | ~~~~~~~~
28 |
29 | Look through the GitHub issues for bugs. Anything tagged with "bug" and "help
30 | wanted" is open to whoever wants to implement it.
31 |
32 | Implement Features
33 | ~~~~~~~~~~~~~~~~~~
34 |
35 | Look through the GitHub issues for features. Anything tagged with "enhancement"
36 | and "help wanted" is open to whoever wants to implement it.
37 |
38 | Write Documentation
39 | ~~~~~~~~~~~~~~~~~~~
40 |
41 | ipme could always use more documentation, whether as part of the
42 | official ipme docs, in docstrings, or even on the web in blog posts,
43 | articles, and such.
44 |
45 | Submit Feedback
46 | ~~~~~~~~~~~~~~~
47 |
48 | The best way to send feedback is to file an issue at https://github.com/evdoxiataka/ipme/issues.
49 |
50 | If you are proposing a feature:
51 |
52 | * Explain in detail how it would work.
53 | * Keep the scope as narrow as possible, to make it easier to implement.
54 | * Remember that this is a volunteer-driven project, and that contributions
55 | are welcome :)
56 |
57 | Get Started!
58 | ------------
59 |
60 | Ready to contribute? Here's how to set up `ipme` for local development.
61 |
62 | 1. Fork the `ipme` repo on GitHub.
63 | 2. Clone your fork locally::
64 |
65 | $ git clone git@github.com:your_name_here/ipme.git
66 |
67 | 3. Install your local copy into a virtualenv. Assuming you have virtualenvwrapper installed, this is how you set up your fork for local development::
68 |
69 | $ mkvirtualenv ipme
70 | $ cd ipme/
71 | $ python setup.py develop
72 |
73 | 4. Create a branch for local development::
74 |
75 | $ git checkout -b name-of-your-bugfix-or-feature
76 |
77 | Now you can make your changes locally.
78 |
79 | 5. When you're done making changes, check that your changes pass flake8 and the
80 | tests, including testing other Python versions with tox::
81 |
82 | $ flake8 ipme tests
83 | $ python setup.py test or pytest
84 | $ tox
85 |
86 | To get flake8 and tox, just pip install them into your virtualenv.
87 |
88 | 6. Commit your changes and push your branch to GitHub::
89 |
90 | $ git add .
91 | $ git commit -m "Your detailed description of your changes."
92 | $ git push origin name-of-your-bugfix-or-feature
93 |
94 | 7. Submit a pull request through the GitHub website.
95 |
96 | Pull Request Guidelines
97 | -----------------------
98 |
99 | Before you submit a pull request, check that it meets these guidelines:
100 |
101 | 1. The pull request should include tests.
102 | 2. If the pull request adds functionality, the docs should be updated. Put
103 | your new functionality into a function with a docstring, and add the
104 | feature to the list in README.rst.
105 | 3. The pull request should work for Python 3.5, 3.6, 3.7 and 3.8, and for PyPy. Check
106 | https://travis-ci.com/evdoxiataka/ipme/pull_requests
107 | and make sure that the tests pass for all supported Python versions.
108 |
109 | Tips
110 | ----
111 |
112 | To run a subset of tests::
113 |
114 |
115 | $ python -m unittest tests.test_imd
116 |
117 | Deploying
118 | ---------
119 |
120 | A reminder for the maintainers on how to deploy.
121 | Make sure all your changes are committed (including an entry in HISTORY.rst).
122 | Then run::
123 |
124 | $ bump2version patch # possible: major / minor / patch
125 | $ git push
126 | $ git push --tags
127 |
128 | Travis will then deploy to PyPI if tests pass.
129 |
--------------------------------------------------------------------------------
/HISTORY.rst:
--------------------------------------------------------------------------------
1 | =======
2 | History
3 | =======
4 |
5 | 0.1.0 (2020-05-31)
6 | ------------------
7 |
8 | * First release on PyPI.
9 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020, Evdoxia Taka
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
23 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include AUTHORS.rst
2 | include CONTRIBUTING.rst
3 | include HISTORY.rst
4 | include LICENSE
5 | include README.rst
6 |
7 | recursive-include tests *
8 | recursive-exclude * __pycache__
9 | recursive-exclude * *.py[co]
10 |
11 | recursive-include docs *.rst conf.py Makefile make.bat *.jpg *.png *.gif
12 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | .PHONY: clean clean-test clean-pyc clean-build docs help
2 | .DEFAULT_GOAL := help
3 |
4 | define BROWSER_PYSCRIPT
5 | import os, webbrowser, sys
6 |
7 | from urllib.request import pathname2url
8 |
9 | webbrowser.open("file://" + pathname2url(os.path.abspath(sys.argv[1])))
10 | endef
11 | export BROWSER_PYSCRIPT
12 |
13 | define PRINT_HELP_PYSCRIPT
14 | import re, sys
15 |
16 | for line in sys.stdin:
17 | match = re.match(r'^([a-zA-Z_-]+):.*?## (.*)$$', line)
18 | if match:
19 | target, help = match.groups()
20 | print("%-20s %s" % (target, help))
21 | endef
22 | export PRINT_HELP_PYSCRIPT
23 |
24 | BROWSER := python -c "$$BROWSER_PYSCRIPT"
25 |
26 | help:
27 | @python -c "$$PRINT_HELP_PYSCRIPT" < $(MAKEFILE_LIST)
28 |
29 | clean: clean-build clean-pyc clean-test ## remove all build, test, coverage and Python artifacts
30 |
31 | clean-build: ## remove build artifacts
32 | rm -fr build/
33 | rm -fr dist/
34 | rm -fr .eggs/
35 | find . -name '*.egg-info' -exec rm -fr {} +
36 | find . -name '*.egg' -exec rm -f {} +
37 |
38 | clean-pyc: ## remove Python file artifacts
39 | find . -name '*.pyc' -exec rm -f {} +
40 | find . -name '*.pyo' -exec rm -f {} +
41 | find . -name '*~' -exec rm -f {} +
42 | find . -name '__pycache__' -exec rm -fr {} +
43 |
44 | clean-test: ## remove test and coverage artifacts
45 | rm -fr .tox/
46 | rm -f .coverage
47 | rm -fr htmlcov/
48 | rm -fr .pytest_cache
49 |
50 | lint: ## check style with flake8
51 | flake8 ipme tests
52 |
53 | test: ## run tests quickly with the default Python
54 | python setup.py test
55 |
56 | test-all: ## run tests on every Python version with tox
57 | tox
58 |
59 | coverage: ## check code coverage quickly with the default Python
60 | coverage run --source ipme setup.py test
61 | coverage report -m
62 | coverage html
63 | $(BROWSER) htmlcov/index.html
64 |
65 | docs: ## generate Sphinx HTML documentation, including API docs
66 | rm -f docs/ipme.rst
67 | rm -f docs/modules.rst
68 | sphinx-apidoc -o docs/ ipme
69 | $(MAKE) -C docs clean
70 | $(MAKE) -C docs html
71 | $(BROWSER) docs/_build/html/index.html
72 |
73 | servedocs: docs ## compile the docs watching for changes
74 | watchmedo shell-command -p '*.rst' -c '$(MAKE) -C docs html' -R -D .
75 |
76 | release: dist ## package and upload a release
77 | twine upload dist/*
78 |
79 | dist: clean ## builds source and wheel package
80 | python setup.py sdist
81 | python setup.py bdist_wheel
82 | ls -l dist
83 |
84 | install: clean ## install the package to the active Python's site-packages
85 | python setup.py install
86 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Interactive Visualizations of Probabilistic Models
2 | This package provides interactive visualizations of probabilistic models' inherent uncertainty in the prior and posterior sample spaces.
3 |
4 | These visualizations are intended for the interactive exploration of the *uncertainty* in Bayesian probabilistic models and for enhancing their interpretability.
5 |
6 | ## Requirements
7 | * The probabilistic model is expressed in a Probabilistic Programming Language (PPL), and
8 | * A sample-based inference algorithm (e.g., MCMC) is used for inference.
9 |
10 | ## Input
11 | The visualizations provided by this package take as input a zip file in the *.npz* format. This file contains a description of the model's structure and the inference results (MCMC samples)
12 | in a standardized form: a collection of npy arrays plus metadata in JSON format.
13 | The [arviz_json](https://github.com/johnhw/arviz_json) package creates this standardized output for probabilistic models expressed in *PyMC3*. For more details about the standardization of the IPME input, see *Taka et al. 2020*.
14 |
15 | The following figure presents the pipeline for the automatic transformation of a probabilistic model expressed in a PPL and its sample-based inference results into the standardized .npz file.
16 |
17 | 
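In code, the export step at the end of this pipeline looks roughly as follows. This is a minimal sketch mirroring the `model.py` scripts under `/examples`; it assumes a PyMC3 model `model` for which `trace`, `prior`, and `posterior_predictive` samples have already been drawn:

```python
import arviz as az
from arviz_json import get_dag, arviz_to_json

# collect the posterior trace, prior and posterior predictive samples
# (plus all sampler statistics) into an ArviZ InferenceData object
data = az.from_pymc3(trace=trace, prior=prior, posterior_predictive=posterior_predictive)

# attach a JSON description of the model's DAG to the sampler statistics
data.sample_stats.attrs["graph"] = str(get_dag(model))

# write the standardized .npz file that ipme takes as input
arviz_to_json(data, "model.npz")
```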
18 |
19 | ## Example of a Probabilistic Model
20 | The following probabilistic statements describe the model of drivers' reaction times under sleep-deprivation conditions over 10 consecutive days for 18 lorry drivers. This is a hierarchical linear regression probabilistic model. The problem and data for this model were retrieved from *Belenky et al. 2003*.
21 |
22 | 
23 |
24 | ## Interactive Probabilistic Models Explorer (IPME)
25 | The *IPME* representation of the model is presented in the following figure. More details about this representation can be found in *Taka et al. 2020*. A demo of this visualization is included in my [talk](https://www.youtube.com/watch?v=2hadiSJRAJI&feature=youtu.be) at PyMCon 2020.
26 | ```python
27 | import ipme
28 | """
29 | mode: String in {'i','s'} for interactive or static
30 | vars: 'all' or List of variable names e.g. ['a','b']
31 | spaces: String in {'all','prior','posterior'} or List of spaces e.g. ['prior','posterior']
32 | predictive_checks: List of observed variables names
33 | """
34 | ipme.graph("reaction_times_hierarchical.npz", mode = "i", vars = 'all', spaces = 'all', predictive_checks = ['y_pred'])
35 | ```
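The same graph can also be rendered as a static (non-interactive) figure by switching the `mode` flag described in the docstring above, for example:

```python
ipme.graph("reaction_times_hierarchical.npz", mode="s", vars="all", spaces="all")
```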
36 |
37 |
38 |
39 |
40 | https://user-images.githubusercontent.com/37831445/205636036-ec1a6820-f368-4332-9fcf-0a508c8dd63f.mp4
41 |
42 |
43 | ## Interactive Pair Plot (IPP)
44 | The *IPP* representation of the model is presented in the following figure. This visualization was introduced and evaluated in *Taka et al. 2022*.
45 | ```python
46 | import ipme
47 | """
48 | mode: String in {'i','s'} for interactive or static
49 | vars: List of variable names e.g. ['a','b']
50 | spaces: String in {'all','prior','posterior'} or List of spaces e.g. ['prior','posterior']
51 | """
52 | ipme.scatter_matrix('reaction_times_hierarchical.npz', mode = "i", vars = ['sigma_a','sigma_b','sigma_sigma','mu_a','mu_b','sigma','a','b','y_pred'], spaces = 'all')
53 | ```
54 |
55 |
56 |
57 | https://user-images.githubusercontent.com/37831445/205637701-ac11ff87-c240-4882-8b69-8b0b15d8e344.mp4
58 |
59 |
60 | # Examples
61 | The folder `/examples` in this repository includes some examples of use. The examples illustrate the definition of Bayesian probabilistic models and the running of sample-based inference in PyMC3. The examples are organized per problem, and each problem's directory includes the following Python scripts:
62 | * *`model.py`*: includes the definition of the model in PyMC3, and exports the inference data into a *.npz* file.
63 | * *`ipme.py`*: demonstrates the use of the ipme package for the visualization of the model.
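For instance, the coal mining disasters example under `/examples/coal_mining_disasters/` boils down to two steps: run `model.py` to fit the model and write `coal_mining_disasters_PyMC3.npz`, then point the visualization at that file:

```python
import ipme

# run after examples/coal_mining_disasters/model.py has produced the .npz file
ipme.graph('coal_mining_disasters_PyMC3.npz', predictive_checks=['disasters'])
```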
64 |
65 | The folder `/examples/user_study` contains the models used in the user study presented in *Taka et al. 2022*.
66 |
67 | **Note:** To run these scripts, you need to install the following Python libraries: PyMC3, ArviZ, and the arviz_json and ipme packages (the last two can only be installed from GitHub).
68 |
69 | # Please Cite:
70 | *E. Taka, S. Stein, and J. H. Williamson.* Increasing interpretability of Bayesian probabilistic programming models through interactive representations. Frontiers in Computer Science, 2:52, 2020. doi: 10.3389/fcomp.2020.567344. URL: https://www.frontiersin.org/article/10.3389/fcomp.2020.567344
71 |
72 | *E. Taka, S. Stein and J. H. Williamson,* "Does Interactive Conditioning Help Users Better Understand the Structure of Probabilistic Models?," in IEEE Transactions on Visualization and Computer Graphics, doi: 10.1109/TVCG.2022.3231967
73 |
74 | # References
75 | *E. Taka, S. Stein, and J. H. Williamson.* Increasing interpretability of Bayesian probabilistic programming models through interactive representations. Frontiers in Computer Science, 2:52, 2020. doi: 10.3389/fcomp.2020.567344. URL: https://www.frontiersin.org/article/10.3389/fcomp.2020.567344
76 |
77 | *E. Taka, S. Stein and J. H. Williamson,* "Does Interactive Conditioning Help Users Better Understand the Structure of Probabilistic Models?," in IEEE Transactions on Visualization and Computer Graphics, doi: 10.1109/TVCG.2022.3231967
78 |
79 | *G. Belenky, N. J. Wesensten, D. R. Thorne, M. L. Thomas, H. C. Sing, D. P. Redmond, M. B. Russo, and T. J.Balkin.* Patterns of performance degradation and restoration during sleep restriction and subsequent recovery: a sleep dose-response study. Journal of Sleep Research, vol. 12, no. 1, pp. 1–12, 2003. URL: https://onlinelibrary.wiley.com/doi/abs/10.1046/j.1365-2869.2003.00337.x
80 |
81 |
82 | *The “Closed-Loop Data Science for Complex, Computationally- and Data-Intensive Analytics” project*. URL: https://www.gla.ac.uk/schools/computing/research/researchsections/ida-section/closedloop/
83 |
84 | *arviz_json* (Automatic Transformation of PyMC3 models into standardized output): https://github.com/johnhw/arviz_json
85 |
--------------------------------------------------------------------------------
/examples/coal_mining_disasters/ipme.py:
--------------------------------------------------------------------------------
1 | import ipme
2 |
3 | if __name__=="__main__":
4 | infer_datapath='coal_mining_disasters_PyMC3.npz'
5 | ipme.graph(infer_datapath, predictive_checks = ['disasters'])
--------------------------------------------------------------------------------
/examples/coal_mining_disasters/model.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import pymc3 as pm
3 | import numpy as np
4 | import arviz as az
5 | from arviz_json import get_dag, arviz_to_json
6 |
7 | ## Discrete Variables Model
8 | ## Reference: https://docs.pymc.io/notebooks/getting_started.html#Case-study-2:-Coal-mining-disasters
9 | #data
10 | disaster_data = pd.Series([4, 5, 4, 0, 1, 4, 3, 4, 0, 6, 3, 3, 4, 0, 2, 6,
11 | 3, 3, 5, 4, 5, 3, 1, 4, 4, 1, 5, 5, 3, 4, 2, 5,
12 | 2, 2, 3, 4, 2, 1, 3, np.nan, 2, 1, 1, 1, 1, 3, 0, 0,
13 | 1, 0, 1, 1, 0, 0, 3, 1, 0, 3, 2, 2, 0, 1, 1, 1,
14 | 0, 1, 0, 1, 0, 0, 0, 2, 1, 0, 0, 0, 1, 1, 0, 2,
15 | 3, 3, 1, np.nan, 2, 1, 1, 1, 1, 2, 4, 2, 0, 0, 1, 4,
16 | 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1])
17 | years = np.arange(1851, 1962)
18 | years_missing=[1890,1935]
19 |
20 | #model-inference
21 | fileName='coal_mining_disasters_PyMC3'
22 | samples=10000
23 | tune=10000
24 | chains=2
25 | coords = {"year": years}
26 | with pm.Model(coords=coords) as disaster_model:
27 | switchpoint = pm.DiscreteUniform('switchpoint', lower=years.min(), upper=years.max(), testval=1900)
28 | early_rate = pm.Exponential('early_rate', 1)
29 | late_rate = pm.Exponential('late_rate', 1)
30 | rate = pm.math.switch(switchpoint >= years, early_rate, late_rate)
31 | disasters = pm.Poisson('disasters', rate, observed=disaster_data, dims='year')
32 | #inference
33 | trace = pm.sample(samples, chains=chains, tune=tune)
34 | prior = pm.sample_prior_predictive(samples=samples)
35 | posterior_predictive = pm.sample_posterior_predictive(trace,samples=samples)
36 |
37 | ## STEP 1
38 | # will also capture all the sampler statistics
39 | data = az.from_pymc3(trace=trace, prior=prior, posterior_predictive=posterior_predictive)
40 |
41 | ## STEP 2
42 | # extract dag
43 | dag = get_dag(disaster_model)
44 | # insert dag into sampler stat attributes
45 | data.sample_stats.attrs["graph"] = str(dag)
46 |
47 | ## STEP 3
48 | # save data
49 | arviz_to_json(data, fileName+'.npz')
50 |
--------------------------------------------------------------------------------
/examples/data/evaluation_sleepstudy.csv:
--------------------------------------------------------------------------------
1 | "","Reaction","Days","Subject"
2 | "1",249.56,0,"308"
3 | "2",258.7047,1,"308"
4 | "3",250.8006,2,"308"
5 | "4",321.4398,3,"308"
6 | "5",356.8519,4,"308"
7 | "6",414.6901,5,"308"
8 | "7",382.2038,6,"308"
9 | "8",290.1486,7,"308"
10 | "9",430.5853,8,"308"
11 | "10",466.3535,9,"308"
12 | "11",222.7339,0,"309"
13 | "12",205.2658,1,"309"
14 | "13",202.9778,2,"309"
15 | "14",204.707,3,"309"
16 | "15",207.7161,4,"309"
17 | "16",215.9618,5,"309"
18 | "17",213.6303,6,"309"
19 | "18",217.7272,7,"309"
20 | "19",224.2957,8,"309"
21 | "20",237.3142,9,"309"
22 | "21",199.0539,0,"310"
23 | "22",194.3322,1,"310"
24 | "23",234.32,2,"310"
25 | "24",232.8416,3,"310"
26 | "25",229.3074,4,"310"
27 | "26",220.4579,5,"310"
28 | "27",235.4208,6,"310"
29 | "28",255.7511,7,"310"
30 | "29",261.0125,8,"310"
31 | "30",247.5153,9,"310"
32 | "31",321.5426,0,"330"
33 | "32",300.4002,1,"330"
34 | "33",283.8565,2,"330"
35 | "34",285.133,3,"330"
36 | "35",285.7973,4,"330"
37 | "36",297.5855,5,"330"
38 | "37",280.2396,6,"330"
39 | "38",318.2613,7,"330"
40 | "39",305.3495,8,"330"
41 | "40",354.0487,9,"330"
42 | "41",287.6079,0,"331"
43 | "42",285,1,"331"
44 | "43",301.8206,2,"331"
45 | "44",320.1153,3,"331"
46 | "45",316.2773,4,"331"
47 | "46",293.3187,5,"331"
48 | "47",290.075,6,"331"
49 | "48",334.8177,7,"331"
50 | "49",293.7469,8,"331"
51 | "50",371.5811,9,"331"
52 | "51",234.8606,0,"332"
53 | "52",242.8118,1,"332"
54 | "53",272.9613,2,"332"
55 | "54",309.7688,3,"332"
56 | "55",317.4629,4,"332"
57 | "56",309.9976,5,"332"
58 | "57",454.1619,6,"332"
59 | "58",346.8311,7,"332"
60 | "59",330.3003,8,"332"
61 | "60",253.8644,9,"332"
62 | "61",283.8424,0,"333"
63 | "62",289.555,1,"333"
64 | "63",276.7693,2,"333"
65 | "64",299.8097,3,"333"
66 | "65",297.171,4,"333"
67 | "66",338.1665,5,"333"
68 | "67",332.0265,6,"333"
69 | "68",348.8399,7,"333"
70 | "69",333.36,8,"333"
71 | "70",362.0428,9,"333"
72 | "71",265.4731,0,"334"
73 | "72",276.2012,1,"334"
74 | "73",243.3647,2,"334"
75 | "74",254.6723,3,"334"
76 | "75",279.0244,4,"334"
77 | "76",284.1912,5,"334"
78 | "77",305.5248,6,"334"
79 | "78",331.5229,7,"334"
80 | "79",335.7469,8,"334"
81 | "80",377.299,9,"334"
82 | "81",241.6083,0,"335"
83 | "82",273.9472,1,"335"
84 | "83",254.4907,2,"335"
85 | "84",270.8021,3,"335"
86 | "85",251.4519,4,"335"
87 | "86",254.6362,5,"335"
88 | "87",245.4523,6,"335"
89 | "88",235.311,7,"335"
90 | "89",235.7541,8,"335"
91 | "90",237.2466,9,"335"
92 | "91",312.3666,0,"337"
93 | "92",313.8058,1,"337"
94 | "93",291.6112,2,"337"
95 | "94",346.1222,3,"337"
96 | "95",365.7324,4,"337"
97 | "96",391.8385,5,"337"
98 | "97",404.2601,6,"337"
99 | "98",416.6923,7,"337"
100 | "99",455.8643,8,"337"
101 | "100",458.9167,9,"337"
102 | "101",236.1032,0,"349"
103 | "102",230.3167,1,"349"
104 | "103",238.9256,2,"349"
105 | "104",254.922,3,"349"
106 | "105",250.7103,4,"349"
107 | "106",269.7744,5,"349"
108 | "107",281.5648,6,"349"
109 | "108",308.102,7,"349"
110 | "109",336.2806,8,"349"
111 | "110",351.6451,9,"349"
112 | "111",256.2968,0,"350"
113 | "112",243.4543,1,"350"
114 | "113",256.2046,2,"350"
115 | "114",255.5271,3,"350"
116 | "115",268.9165,4,"350"
117 | "116",329.7247,5,"350"
118 | "117",379.4445,6,"350"
119 | "118",362.9184,7,"350"
120 | "119",394.4872,8,"350"
121 | "120",389.0527,9,"350"
122 | "121",250.5265,0,"351"
123 | "122",300.0576,1,"351"
124 | "123",269.8939,2,"351"
125 | "124",280.5891,3,"351"
126 | "125",271.8274,4,"351"
127 | "126",304.6336,5,"351"
128 | "127",287.7466,6,"351"
129 | "128",266.5955,7,"351"
130 | "129",321.5418,8,"351"
131 | "130",347.5655,9,"351"
132 | "131",221.6771,0,"352"
133 | "132",298.1939,1,"352"
134 | "133",326.8785,2,"352"
135 | "134",346.8555,3,"352"
136 | "135",348.7402,4,"352"
137 | "136",352.8287,5,"352"
138 | "137",354.4266,6,"352"
139 | "138",360.4326,7,"352"
140 | "139",375.6406,8,"352"
141 | "140",388.5417,9,"352"
142 | "141",271.9235,0,"369"
143 | "142",268.4369,1,"369"
144 | "143",257.2424,2,"369"
145 | "144",277.6566,3,"369"
146 | "145",314.8222,4,"369"
147 | "146",317.2135,5,"369"
148 | "147",298.1353,6,"369"
149 | "148",348.1229,7,"369"
150 | "149",340.28,8,"369"
151 | "150",366.5131,9,"369"
152 | "151",225.264,0,"370"
153 | "152",234.5235,1,"370"
154 | "153",238.9008,2,"370"
155 | "154",240.473,3,"370"
156 | "155",267.5373,4,"370"
157 | "156",344.1937,5,"370"
158 | "157",281.1481,6,"370"
159 | "158",347.5855,7,"370"
160 | "159",365.163,8,"370"
161 | "160",372.2288,9,"370"
162 | "161",269.8804,0,"371"
163 | "162",272.4428,1,"371"
164 | "163",277.8989,2,"371"
165 | "164",281.7895,3,"371"
166 | "165",279.1705,4,"371"
167 | "166",284.512,5,"371"
168 | "167",259.2658,6,"371"
169 | "168",304.6306,7,"371"
170 | "169",350.7807,8,"371"
171 | "170",369.4692,9,"371"
172 | "171",269.4117,0,"372"
173 | "172",273.474,1,"372"
174 | "173",297.5968,2,"372"
175 | "174",310.6316,3,"372"
176 | "175",287.1726,4,"372"
177 | "176",329.6076,5,"372"
178 | "177",334.4818,6,"372"
179 | "178",343.2199,7,"372"
180 | "179",369.1417,8,"372"
181 | "180",364.1236,9,"372"
182 |
--------------------------------------------------------------------------------
/examples/drivers_reaction_times/ipme.py:
--------------------------------------------------------------------------------
1 | import ipme
2 |
3 | if __name__=="__main__":
4 | infer_datapath='reaction_times_pooled.npz'
5 | # infer_datapath='reaction_times_hierarchical.npz'
6 | ipme.graph(infer_datapath, predictive_checks = ['y_pred'])
--------------------------------------------------------------------------------
/examples/drivers_reaction_times/model_hierarchical.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import pymc3 as pm
3 | import arviz as az
4 | from arviz_json import get_dag, arviz_to_json
5 |
6 | ## data
7 | DATAPATH='../data/evaluation_sleepstudy.csv' ## data from Belenky et al. (2003)
8 | reactions = pd.read_csv(DATAPATH, usecols=['Reaction','Days','Subject'])
9 |
10 | ## Drivers Reaction Times Hierarchical Model
11 | samples = 2000
12 | chains = 2
13 | tune = 1000
14 |
15 | ## data
16 | driver_idx, drivers = pd.factorize(reactions["Subject"], sort=True)
17 | day_idx, days = pd.factorize(reactions["Days"], sort=True)
18 |
19 | ##dims
20 | coords_h = {"driver": drivers,"driver_idx_day":reactions.Subject}
21 |
22 | with pm.Model(coords=coords_h) as hierarchical_model:
23 | ## model
24 | #hyper-priors
25 | mu_a = pm.Normal('mu_a' ,mu=100, sd=250)
26 | sigma_a = pm.HalfNormal('sigma_a', sd=250)
27 | mu_b = pm.Normal('mu_b',mu=10, sd=250)
28 | sigma_b = pm.HalfNormal('sigma_b', sd=250)
29 | sigma_sigma = pm.HalfNormal('sigma_sigma', sd=200)
30 | #priors
31 | a = pm.Normal("a", mu=mu_a, sd=sigma_a, dims="driver")
32 | b = pm.Normal("b", mu=mu_b, sd=sigma_b, dims="driver")
33 | sigma = pm.HalfNormal("sigma", sd=sigma_sigma, dims="driver")
34 | y_pred = pm.Normal('y_pred', mu=a[driver_idx]+b[driver_idx]*day_idx, sd=sigma[driver_idx], observed=reactions.Reaction, dims="driver_idx_day")
35 | ## inference
36 | trace_hi = pm.sample(draws=samples, chains=chains, tune=tune)
37 | prior_hi = pm.sample_prior_predictive(samples=samples)
38 | posterior_predictive_hi = pm.sample_posterior_predictive(trace_hi, samples=samples)
39 |
40 | ## STEP 1
41 | ## export inference results in ArviZ InferenceData obj
42 | ## will also capture all the sampler statistics
43 | data_hi = az.from_pymc3(trace = trace_hi, prior = prior_hi, posterior_predictive = posterior_predictive_hi)
44 |
45 | ## STEP 2
46 | ## extract dag
47 | dag_hi = get_dag(hierarchical_model)
48 | ## insert dag into sampler stat attributes
49 | data_hi.sample_stats.attrs["graph"] = str(dag_hi)
50 |
51 | ## STEP 3
52 | ## save data
53 | fileName_hi = "reaction_times_hierarchical"
54 | arviz_to_json(data_hi, fileName_hi+'.npz')
--------------------------------------------------------------------------------
/examples/drivers_reaction_times/model_pooled.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import pymc3 as pm
3 | import arviz as az
4 | from arviz_json import get_dag, arviz_to_json
5 |
6 | ## data
7 | DATAPATH='../data/evaluation_sleepstudy.csv' ## data from Belenky et al. (2003)
8 | reactions = pd.read_csv(DATAPATH,usecols=['Reaction','Days','Subject'])
9 |
10 | ## Drivers Reaction Times Pooled Model
11 | samples=4000
12 | chains=2
13 | tune=2000
14 |
15 | ## data
16 | driver_idx, drivers = pd.factorize(reactions["Subject"], sort=True)
17 | day_idx, days = pd.factorize(reactions["Days"], sort=True)
18 |
19 | ## dims
20 | coords_p = {"driver": drivers,"driver_idx_day":reactions.Subject}
21 |
22 | with pm.Model(coords=coords_p) as fullyPooled_model:
23 | ## model
24 | a = pm.Normal("a", mu=100, sd=250)
25 | b = pm.Normal("b", mu=10, sd=250)
26 | sigma = pm.HalfNormal("sigma", sd=200)
27 | y_pred = pm.Normal('y_pred', mu = a + b*day_idx, sd = sigma, observed = reactions.Reaction, dims = "driver_idx_day")
28 | ## inference
29 | trace_p = pm.sample(samples, chains=chains, tune=tune)
30 | prior_p = pm.sample_prior_predictive(samples=samples)
31 | posterior_predictive_p = pm.sample_posterior_predictive(trace_p, samples=samples)
32 |
33 | ## STEP 1
34 | ## export inference results in ArviZ InferenceData obj
35 | ## will also capture all the sampler statistics
36 | data_p = az.from_pymc3(trace = trace_p, prior = prior_p, posterior_predictive = posterior_predictive_p)
37 |
38 | ## STEP 2
39 | ## extract dag
40 | dag_p = get_dag(fullyPooled_model)
41 | ## insert dag into sampler stat attributes
42 | data_p.sample_stats.attrs["graph"] = str(dag_p)
43 |
44 | ## STEP 3
45 | ## save data
46 | fileName_p = "reaction_times_pooled"
47 | arviz_to_json(data_p, fileName_p+'.npz')
--------------------------------------------------------------------------------
/examples/eight_schools_problem/ipme.py:
--------------------------------------------------------------------------------
1 | import ipme
2 |
3 | if __name__=="__main__":
4 |     infer_datapath='eight_schools_centered.npz'
5 |     #infer_datapath='eight_schools_non_centered.npz'
6 | ipme.graph(infer_datapath, predictive_checks = ['y'])
7 |
--------------------------------------------------------------------------------
/examples/eight_schools_problem/model_centered.py:
--------------------------------------------------------------------------------
1 | import pymc3 as pm
2 | import numpy as np
3 | import arviz as az
4 | from arviz_json import get_dag, arviz_to_json
5 | SEED = [20100420, 20134234]
6 |
7 | #Hierarchical Model
8 | #Reference: https://docs.pymc.io/notebooks/Diagnosing_biased_Inference_with_Divergences.html#The-Eight-Schools-Model
9 | #data
10 | J = 8
11 | obs = np.array([28., 8., -3., 7., -1., 1., 18., 12.])
12 | sigma = np.array([15., 10., 16., 11., 9., 11., 10., 18.])
13 |
14 | #model-inference
15 | coords_c = {"school": ["A","B","C","D","E","F","G","H"]}
16 | fileName_c="eight_schools_centered"
17 | samples=4000
18 | chains=2
19 | tune=1000
20 | with pm.Model(coords=coords_c) as centered_eight:
21 | mu = pm.Normal('mu', mu=0, sigma=5)
22 | tau = pm.HalfCauchy('tau', beta=5)
23 | theta = pm.Normal('theta', mu=mu, sigma=tau, dims='school')
24 | y = pm.Normal('y', mu=theta, sigma=sigma, observed=obs, dims='school')
25 | #inference
26 | trace_c = pm.sample(samples, chains=chains, tune=tune, random_seed=SEED)
27 | prior_c= pm.sample_prior_predictive(samples=samples)
28 | posterior_predictive_c = pm.sample_posterior_predictive(trace_c, samples=samples)
29 |
30 | ## STEP 1
31 | # will also capture all the sampler statistics
32 | data_c = az.from_pymc3(trace = trace_c, prior = prior_c, posterior_predictive = posterior_predictive_c)
33 |
34 | ## STEP 2
35 | #dag
36 | dag_c = get_dag(centered_eight)
37 | # insert dag into sampler stat attributes
38 | data_c.sample_stats.attrs["graph"] = str(dag_c)
39 |
40 | ## STEP 3
41 | # save data
42 | arviz_to_json(data_c, fileName_c+'.npz')
43 |
--------------------------------------------------------------------------------
/examples/eight_schools_problem/model_non_centered.py:
--------------------------------------------------------------------------------
1 | import pymc3 as pm
2 | import pandas as pd
3 | import numpy as np
4 | import arviz as az
5 | from arviz_json import get_dag, arviz_to_json
6 | SEED = [20100420, 20134234]
7 |
8 | #Hierarchical Model
9 | #Reference: https://docs.pymc.io/notebooks/Diagnosing_biased_Inference_with_Divergences.html#A-Non-Centered-Eight-Schools-Implementation
10 | #data
11 | J = 8
12 | obs = np.array([28., 8., -3., 7., -1., 1., 18., 12.])
13 | sigma = np.array([15., 10., 16., 11., 9., 11., 10., 18.])
14 |
15 | #model-inference
16 | coords = {"school": ["A","B","C","D","E","F","G","H"]}
17 | samples=5000
18 | chains=2
19 | tune=1000
20 | fileName="eight_schools_non_centered"
21 | with pm.Model(coords=coords) as NonCentered_eight:
22 | mu = pm.Normal('mu', mu=0, sigma=5)
23 | tau = pm.HalfCauchy('tau', beta=5)
24 | theta_tilde = pm.Normal('theta_t', mu=0, sigma=1, dims='school')
25 | theta = pm.Deterministic('theta', mu + tau * theta_tilde, dims='school')
26 | y = pm.Normal('y', mu=theta, sigma=sigma, observed=obs, dims='school')
27 | #inference
28 | trace_nc = pm.sample(samples, chains=chains, tune=tune, random_seed=SEED, target_accept=.90)
29 | prior_nc= pm.sample_prior_predictive(samples=samples)
30 | posterior_predictive_nc = pm.sample_posterior_predictive(trace_nc,samples=samples)
31 |
32 | ## STEP 1
33 | # will also capture all the sampler statistics
34 | data_nc = az.from_pymc3(trace = trace_nc, prior = prior_nc, posterior_predictive = posterior_predictive_nc)
35 |
36 | ## STEP 2
37 | #dag
38 | dag_nc = get_dag(NonCentered_eight)
39 | # insert dag into sampler stat attributes
40 | data_nc.sample_stats.attrs["graph"] = str(dag_nc)
41 |
42 | ## STEP 3
43 | # save data
44 | arviz_to_json(data_nc, fileName+'.npz')
45 |
--------------------------------------------------------------------------------
/examples/golf_putting/ipme.py:
--------------------------------------------------------------------------------
1 | import ipme
2 |
3 | if __name__=="__main__":
4 | infer_datapath='golf_simple_PyMC3.npz'
5 | #infer_datapath='golf_geometry_PyMC3.npz'
6 | ipme.graph(infer_datapath, predictive_checks = ['successes'])
7 |
--------------------------------------------------------------------------------
/examples/golf_putting/model_geometry.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import io
3 | import pymc3 as pm
4 | import arviz as az
5 | from arviz_json import get_dag, arviz_to_json
6 | import theano.tensor as tt
7 |
8 | CUP_RADIUS = (4.25 / 2) / 12
9 | BALL_RADIUS = (1.68 / 2) / 12
10 |
11 | def Phi(x):
12 | """Calculates the standard normal cumulative distribution function."""
13 | return 0.5 + 0.5 * tt.erf(x / tt.sqrt(2.))
14 |
15 | #Binomial Model
16 | #Reference: https://docs.pymc.io/notebooks/putting_workflow.html#Geometry-based-model
17 | #data
18 | golf_data = """distance tries successes
19 | 2 1443 1346
20 | 3 694 577
21 | 4 455 337
22 | 5 353 208
23 | 6 272 149
24 | 7 256 136
25 | 8 240 111
26 | 9 217 69
27 | 10 200 67
28 | 11 237 75
29 | 12 202 52
30 | 13 192 46
31 | 14 174 54
32 | 15 167 28
33 | 16 201 27
34 | 17 195 31
35 | 18 191 33
36 | 19 147 20
37 | 20 152 24"""
38 | data = pd.read_csv(io.StringIO(golf_data), sep=" ")
39 |
40 | #model-inference
41 | coords = {"distance": data.distance}
42 | fileName='golf_geometry_PyMC3'
43 | samples=2000
44 | chains=2
45 | tune=1000
46 | geometry_model=pm.Model(coords=coords)
47 | with geometry_model:
48 |     # store the n-parameter of the Binomial dist
49 |     # in the constant group of ArviZ InferenceData;
50 |     # it should always be named n so that ipme can retrieve it
51 | n = pm.Data('n', data.tries)
52 | sigma_angle = pm.HalfNormal('sigma_angle')
53 | p_goes_in = pm.Deterministic('p_goes_in', 2 * Phi(tt.arcsin((CUP_RADIUS - BALL_RADIUS) / data.distance) / sigma_angle) - 1, dims='distance')
54 | successes = pm.Binomial('successes', n=n, p=p_goes_in, observed=data.successes, dims='distance')
55 | #inference
56 | trace_g = pm.sample(draws=samples, chains=chains, tune=tune)
57 | prior_g= pm.sample_prior_predictive(samples=samples)
58 | posterior_predictive_g = pm.sample_posterior_predictive(trace_g,samples=samples)
59 |
60 | ## STEP 1
61 | # will also capture all the sampler statistics
62 | data_g = az.from_pymc3(trace=trace_g, prior=prior_g, posterior_predictive=posterior_predictive_g)
63 |
64 | ## STEP 2
65 | #dag
66 | dag_g = get_dag(geometry_model)
67 | # insert dag into sampler stat attributes
68 | data_g.sample_stats.attrs["graph"] = str(dag_g)
69 |
70 | ## STEP 3
71 | # save data
72 | arviz_to_json(data_g, fileName+'.npz')
--------------------------------------------------------------------------------
/examples/golf_putting/model_simple.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import io
3 | import pymc3 as pm
4 | import arviz as az
5 | from arviz_json import get_dag, arviz_to_json
6 |
7 | #Binomial Logistic Regression Model
8 | #Reference: https://docs.pymc.io/notebooks/putting_workflow.html#Logit-model
9 | #data
10 | golf_data = """distance tries successes
11 | 2 1443 1346
12 | 3 694 577
13 | 4 455 337
14 | 5 353 208
15 | 6 272 149
16 | 7 256 136
17 | 8 240 111
18 | 9 217 69
19 | 10 200 67
20 | 11 237 75
21 | 12 202 52
22 | 13 192 46
23 | 14 174 54
24 | 15 167 28
25 | 16 201 27
26 | 17 195 31
27 | 18 191 33
28 | 19 147 20
29 | 20 152 24"""
30 | data = pd.read_csv(io.StringIO(golf_data), sep=" ")
31 |
32 | #model-inference
33 | coords = {"distance": data.distance}
34 | fileName='golf_simple_PyMC3'
35 | samples=2000
36 | chains=2
37 | tune=1000
38 | simple_model=pm.Model(coords=coords)
39 | with simple_model:
40 |     # store the n-parameter of the Binomial dist
41 |     # in the constant group of ArviZ InferenceData;
42 |     # it should always be named n so that ipme can retrieve it
43 | n = pm.Data('n', data.tries)
44 | a = pm.Normal('a')
45 | b = pm.Normal('b')
46 | p_goes_in = pm.Deterministic('p_goes_in', pm.math.invlogit(a * data.distance + b), dims='distance')
47 | successes = pm.Binomial('successes', n=n, p=p_goes_in, observed=data.successes, dims='distance')
48 | #inference
49 | # Get posterior trace, prior trace, posterior predictive samples, and the DAG
50 | trace = pm.sample(draws=samples, chains=chains, tune=tune)
51 | prior= pm.sample_prior_predictive(samples=samples)
52 | posterior_predictive = pm.sample_posterior_predictive(trace,samples=samples)
53 |
54 | ## STEP 1
55 | # will also capture all the sampler statistics
56 | data_s = az.from_pymc3(trace=trace, prior=prior, posterior_predictive=posterior_predictive)
57 |
58 | ## STEP 2
59 | #dag
60 | dag = get_dag(simple_model)
61 | # insert dag into sampler stat attributes
62 | data_s.sample_stats.attrs["graph"] = str(dag)
63 |
64 | ## STEP 3
65 | # save data
66 | arviz_to_json(data_s, fileName+'.npz')
--------------------------------------------------------------------------------
/examples/radon_basement/ipme.py:
--------------------------------------------------------------------------------
1 | import ipme
2 |
3 | if __name__=="__main__":
4 | infer_datapath='radon_basement_PyMC3.npz'
5 | ipme.graph(infer_datapath, predictive_checks = ['radon'])
--------------------------------------------------------------------------------
/examples/radon_basement/model.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pymc3 as pm
3 | import pandas as pd
4 | import theano
5 | import arviz as az
6 | from arviz_json import get_dag, arviz_to_json
7 |
8 | #Hierarchical Linear Regression Model
9 | #Reference1: https://docs.pymc.io/notebooks/multilevel_modeling.html
10 | #Reference2: https://docs.pymc.io/notebooks/GLM-hierarchical.html#The-data-set
11 | #data
12 | data_r = pd.read_csv(pm.get_data('radon.csv'))
13 | data_r['log_radon'] = data_r['log_radon'].astype(theano.config.floatX)
14 | county_names = data_r.county.unique()
15 | county_idx = data_r.county_code.values
16 |
17 | n_counties = len(data_r.county.unique())
18 |
19 | #model-inference
20 | fileName='radon_basement_PyMC3'
21 | samples=3000
22 | tune=10000
23 | chains=2
24 | coords = {"county":county_names, "county_idx_household":data_r["county"].tolist()}
25 | with pm.Model(coords=coords) as model:
26 | # Hyperpriors for group nodes
27 | mu_a = pm.Normal('mu_a', mu=0., sigma=100)
28 | sigma_a = pm.HalfNormal('sigma_a', 5.)
29 | mu_b = pm.Normal('mu_b', mu=0., sigma=100)
30 | sigma_b = pm.HalfNormal('sigma_b', 5.)
31 |
32 | # Intercept for each county, distributed around group mean mu_a
33 | # Above we just set mu and sd to a fixed value while here we
34 | # plug in a common group distribution for all a and b (which are
35 | # vectors of length n_counties).
36 | a = pm.Normal('a', mu=mu_a, sigma=sigma_a, dims='county')
37 |     # Slope for each county, distributed around group mean mu_b
38 | b = pm.Normal('b', mu=mu_b, sigma=sigma_b, dims='county')
39 |
40 | # Model error
41 | eps = pm.HalfCauchy('eps', 5.)
42 |
43 | radon_est = a[county_idx] + b[county_idx]*data_r.floor.values
44 |
45 | # Data likelihood
46 | radon = pm.Normal('radon', mu=radon_est,
47 | sigma=eps, observed=data_r.log_radon, dims='county_idx_household')
48 |
49 | #Inference
50 | trace = pm.sample(samples, chains=chains, tune=tune, target_accept=1.0)
51 | prior = pm.sample_prior_predictive(samples=samples)
52 | posterior_predictive = pm.sample_posterior_predictive(trace, samples=samples)
53 |
54 | ## STEP 1
55 | # will also capture all the sampler statistics
56 | data = az.from_pymc3(trace=trace, prior=prior, posterior_predictive=posterior_predictive)
57 |
58 | ## STEP 2
59 | #dag
60 | dag = get_dag(model)
61 | # insert dag into sampler stat attributes
62 | data.sample_stats.attrs["graph"] = str(dag)
63 |
64 | ## STEP 3
65 | # save data
66 | arviz_to_json(data, fileName+'.npz')
--------------------------------------------------------------------------------
/examples/stochastic_volatility/ipme.py:
--------------------------------------------------------------------------------
1 | import ipme
2 |
3 | if __name__=="__main__":
4 |     infer_datapath='stochastic_volatility_PyMC3.npz'
5 | ipme.graph(infer_datapath, predictive_checks=['returns'])
--------------------------------------------------------------------------------
/examples/stochastic_volatility/model.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import pymc3 as pm
3 | import arviz as az
4 | import numpy as np
5 | from arviz_json import get_dag, arviz_to_json
6 |
7 | #StudentT Timeseries Model
8 | #Reference1: https://docs.pymc.io/notebooks/getting_started.html#Case-study-1:-Stochastic-volatility
9 | #Reference2: https://docs.pymc.io/notebooks/stochastic_volatility.html#Stochastic-Volatility-model
10 | #data
11 | returns = pd.read_csv(pm.get_data('SP500.csv'), parse_dates=True, index_col=0)
12 | dates = returns.index.strftime("%Y/%m/%d").tolist()
13 |
14 | #model-inference
15 | fileName='stochastic_volatility_PyMC3'
16 | samples=2000
17 | tune=2000
18 | chains=2
19 | coords = {"date": dates}
20 | with pm.Model(coords=coords) as model:
21 | step_size = pm.Exponential('step_size', 10)
22 | volatility = pm.GaussianRandomWalk('volatility', sigma = step_size, dims='date')
23 | nu = pm.Exponential('nu', 0.1)
24 | returns = pm.StudentT('returns', nu = nu, lam = np.exp(-2*volatility) , observed = returns["change"], dims='date')
25 | #inference
26 | trace = pm.sample(draws=samples, chains=chains, tune=tune)
27 | prior = pm.sample_prior_predictive(samples=samples)
28 | posterior_predictive = pm.sample_posterior_predictive(trace, samples=samples)
29 |
30 | ## STEP 1
31 | # will also capture all the sampler statistics
32 | data = az.from_pymc3(trace=trace, prior=prior, posterior_predictive=posterior_predictive)
33 |
34 | ## STEP 2
35 | #dag
36 | dag = get_dag(model)
37 | # insert dag into sampler stat attributes
38 | data.sample_stats.attrs["graph"] = str(dag)
39 |
40 | ## STEP 3
41 | # save data
42 | arviz_to_json(data, fileName+'.npz')
43 |
--------------------------------------------------------------------------------
/examples/user_study/min_temperature/ipme.py:
--------------------------------------------------------------------------------
1 | import ipme
2 |
3 | if __name__=="__main__":
4 | infer_datapath='min_temperature.npz'
5 | ipme.scatter_matrix(infer_datapath,vars=['a','b','c','temperature'])
--------------------------------------------------------------------------------
/examples/user_study/min_temperature/model.py:
--------------------------------------------------------------------------------
1 | import pymc3 as pm
2 | import numpy as np
3 | import arviz as az
4 | # !pip install git+https://github.com/johnhw/arviz_json.git
5 | from arviz_json import get_dag, arviz_to_json
6 |
7 | RANDOM_SEED = np.random.seed(1225)
8 |
9 | ## the average minimum temperature in Scotland in month November for the years
10 | ## 1884-2020: metoffice.gov.uk
11 | ## https://www.metoffice.gov.uk/pub/data/weather/uk/climate/datasets/Tmin/date/Scotland.txt
12 | l = [1.0,1.1,2.5,0.8,2.7,2.9,1.4,1.1,2.2,0.3,3.6,1.7,2.1,3.5,1.5,4.6,2.3,
13 | 1.2,3.5,1.8,1.3,1.2,3.7,1.4,2.9,-0.3,-1.1,1.1,1.6,3.2,1.9,-1.3,3.2,
14 | 3.1,0.8,-1.6,3.6,0.7,2.1,-0.5,3.5,-0.0,1.1,1.8,2.7,2.0,0.6,3.6,1.8,
15 | 1.7,1.5,1.7,1.3,1.5,3.8,3.0,1.9,1.7,1.0,2.2,0.4,3.4,2.7,1.6,3.0,2.4,
16 | 0.6,3.4,-0.3,4.0,1.6,3.5,2.6,2.8,2.2,2.9,1.7,1.1,1.5,2.2,2.3,-0.2,0.5,
17 | 1.5,1.7,-0.8,1.9,1.7,0.9,0.9,1.7,1.8,1.5,0.9,3.1,1.2,1.9,2.5,2.2,3.0,
18 | 3.3,-1.1,2.5,2.6,1.7,2.0,2.0,2.0,1.5,0.4,5.2,3.0,0.2,4.6,1.3,2.9,1.8,
19 | 3.0,3.8,3.4,3.3,1.5,3.4,3.4,2.0,2.8,0.0,4.8,1.7,1.1,3.8,3.6,0.3,1.5,3.6,1.1]
20 | years = np.arange(1884,2020)
21 |
22 | fileName='min_temperature'
23 | samples=4000
24 | chains=2
25 | tune=1000
26 | coords = {'years':years}
27 | temperature_model = pm.Model(coords=coords)
28 | with temperature_model:
29 | #priors
30 | a = pm.Uniform('a', 80, 100)
31 | b = pm.Normal('b', mu=2, sd=10)
32 | c = pm.HalfNormal('c', sd=10)
33 |
34 | #predictions
35 | temperature = pm.Normal('temperature',
36 | mu = b,
37 | sd = c,
38 | observed = l,
39 | dims = 'years')
40 |
41 | trace = pm.sample(draws=samples, chains=chains, tune=tune)
42 | prior = pm.sample_prior_predictive(samples=samples)
43 | posterior_predictive = pm.sample_posterior_predictive(trace, samples=samples)
44 | dag = get_dag(temperature_model)
45 |
46 | # will also capture all the sampler statistics
47 | data = az.from_pymc3(trace=trace,
48 | prior=prior,
49 | posterior_predictive=posterior_predictive)
50 |
51 | # insert dag into sampler stat attributes
52 | data.sample_stats.attrs["graph"] = str(dag)
53 |
54 | # save data
55 | arviz_to_json(data, fileName+'.npz')
--------------------------------------------------------------------------------
/examples/user_study/random_number_generator/ipme.py:
--------------------------------------------------------------------------------
1 | import ipme
2 |
3 | if __name__=="__main__":
4 | infer_datapath='transformation.npz'
5 | ipme.scatter_matrix(infer_datapath, vars=['a','b','c','random_number'])
--------------------------------------------------------------------------------
/examples/user_study/random_number_generator/model.py:
--------------------------------------------------------------------------------
1 | import pymc3 as pm
2 | import numpy as np
3 | import arviz as az
4 | # !pip install git+https://github.com/johnhw/arviz_json.git
5 | from arviz_json import get_dag, arviz_to_json
6 |
7 | RANDOM_SEED = np.random.seed(1225)
8 |
9 | x_data = np.random.uniform(-3, 5, size=5)
10 | obs = np.arange(len(x_data))
11 |
12 | fileName = 'transformation'
13 | samples = 4000
14 | # tune = 10000
15 | chains = 1
16 | coords = {"obs": obs}
17 | uniform_model = pm.Model(coords=coords)
18 | with uniform_model:
19 | # Priors for unknown model parameters
20 | a = pm.Normal('a', mu = 0, sd=10)
21 | b = pm.HalfNormal('b', sd = 10)
22 | c = pm.HalfNormal('c', sd=20)
23 |
24 | l = a - c
25 | u = a + c
26 | random_number = pm.Uniform('random_number',
27 | lower=l,
28 | upper=u,
29 | observed=x_data,
30 | dims="obs")
31 |
32 | trace = pm.sample(samples, chains = chains, target_accept=0.95)
33 | posterior_predictive = pm.sample_posterior_predictive(trace, samples=samples)
34 | prior = pm.sample_prior_predictive(samples=samples, random_seed=RANDOM_SEED)
35 |
36 | # export inference results in an ArviZ InferenceData obj
37 | # (a single call with the trace, the prior and the posterior predictive
38 | # samples; this will also capture all the sampler statistics)
39 | data = az.from_pymc3(trace = trace,
40 |                      prior = prior,
41 |                      posterior_predictive = posterior_predictive)
42 |
43 |
44 | dag = get_dag(uniform_model)
45 | data.sample_stats.attrs["graph"] = str(dag)
46 |
47 | # save data
48 | arviz_to_json(data, fileName+'.npz')
--------------------------------------------------------------------------------
/examples/user_study/reaction_times/ipme.py:
--------------------------------------------------------------------------------
1 | import ipme
2 |
3 | if __name__=="__main__":
4 | infer_datapath='reaction_times_hierarchical.npz'
5 | ipme.scatter_matrix(infer_datapath,vars=['a','b','c','d','reaction_time'])
--------------------------------------------------------------------------------
/examples/user_study/reaction_times/model.py:
--------------------------------------------------------------------------------
1 | import pymc3 as pm
2 | import pandas as pd
3 | import numpy as np
4 | import arviz as az
5 | # !pip install git+https://github.com/johnhw/arviz_json.git
6 | from arviz_json import get_dag, arviz_to_json
7 |
8 | RANDOM_SEED = np.random.seed(1225)
9 |
10 | DATAPATH='../../data/evaluation_sleepstudy.csv' ## data from Belenky et al. (2003)
11 | reactions = pd.read_csv(DATAPATH,usecols=['Reaction','Days','Subject'])
12 | reactions = reactions[0:3*10]
13 |
14 | driver_idx, drivers = pd.factorize(reactions["Subject"], sort=True)
15 | day_idx, days = pd.factorize(reactions["Days"], sort=True)
16 | coords_c = {"driver": drivers,"driver_idx_day":reactions.Subject}
17 | with pm.Model(coords=coords_c) as hierarchical_model:
18 | #hyper-priors
19 | c = pm.Normal('c', mu=100, sd=150)
20 | e = pm.HalfNormal('e', sd=150)
21 | f = pm.Normal('f', mu=10, sd=100)
22 | g = pm.HalfNormal('g', sd=100)
23 | h = pm.HalfNormal('h', sd=200)
24 | #priors
25 | a = pm.Normal("a", mu=c, sd=e, dims="driver")
26 | b = pm.Normal("b", mu=f, sd=g, dims="driver")
27 | sigma = pm.HalfNormal("sigma", sd=h, dims="driver")
28 | d = pm.Normal("d", mu = 0, sd=10.0)
29 |
30 | reaction_time = pm.Normal('reaction_time',
31 | mu = a[driver_idx]+b[driver_idx]*day_idx,
32 | sd=sigma[driver_idx],
33 | observed=reactions.Reaction,
34 | dims="driver_idx_day")
35 |
36 | samples=4000
37 | chains=3
38 | tune=3000
39 | fileName_hi="reaction_times_hierarchical"
40 | with hierarchical_model:
41 | # Get posterior trace, prior trace, posterior predictive samples, and the DAG
42 | trace_hi = pm.sample(draws=samples, chains=chains, tune=tune, target_accept=0.9)
43 | prior_hi = pm.sample_prior_predictive(samples=samples)
44 | posterior_predictive_hi = pm.sample_posterior_predictive(trace_hi,
45 | samples=samples)
46 |
47 | # export inference results in ArviZ InferenceData obj
48 | # will also capture all the sampler statistics
49 | data_hi = az.from_pymc3(trace = trace_hi,
50 | prior = prior_hi,
51 | posterior_predictive = posterior_predictive_hi)
52 |
53 | # insert dag into sampler stat attributes
54 | dag_hi = get_dag(hierarchical_model)
55 | data_hi.sample_stats.attrs["graph"] = str(dag_hi)
56 |
57 | # save data
58 | arviz_to_json(data_hi, fileName_hi+'.npz')
--------------------------------------------------------------------------------
/ipme/__init__.py:
--------------------------------------------------------------------------------
1 | """Top-level package for imd."""
2 |
3 | __author__ = """Evdoxia Taka"""
4 | __email__ = 'e.taka.1@research.gla.ac.uk'
5 | __version__ = '0.1.0'
6 |
7 | from .methods import *
--------------------------------------------------------------------------------
/ipme/classes/__init__.py:
--------------------------------------------------------------------------------
1 | __author__ = """Evdoxia Taka"""
2 | __email__ = 'e.taka.1@research.gla.ac.uk'
3 | __version__ = '0.1.0'
4 |
--------------------------------------------------------------------------------
/ipme/classes/cell/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/evdoxiataka/ipme/f3398596c6af547908f39683eb1830a6bc081482/ipme/classes/cell/__init__.py
--------------------------------------------------------------------------------
/ipme/classes/cell/interactive_continuous_cell.py:
--------------------------------------------------------------------------------
1 | from ipme.interfaces.variable_cell import VariableCell
2 | from .utils.cell_continuous_handler import CellContinuousHandler
3 | from .utils.cell_clear_selection import CellClearSelection
4 | from ..cell.utils.cell_widgets import CellWidgets
5 |
6 | class InteractiveContinuousCell(VariableCell):
7 | def __init__(self, name, control):
8 | """
9 | Parameters:
10 | --------
11 |             name            A String, the name of the model's variable plotted in this cell.
12 | control A Control object
13 | """
14 | self.selection = {}
15 | self.sel_samples = {}
16 | self.non_sel_samples = {}
17 | self.reconstructed = {}
18 | self.clear_selection = {}
19 | VariableCell.__init__(self, name, control)
20 |
21 | def initialize_cds(self, space):
22 | CellContinuousHandler.initialize_cds_interactive(self, space)
23 |
24 | def initialize_fig(self, space):
25 | CellContinuousHandler.initialize_fig_interactive(self, space)
26 |
27 | def initialize_glyphs(self, space):
28 | CellContinuousHandler.initialize_glyphs_interactive(self, space)
29 | CellClearSelection.initialize_glyphs_x_button(self, space)
30 |
31 | def widget_callback(self, attr, old, new, w_title, space):
32 | CellWidgets.widget_callback_interactive(self, attr, old, new, w_title, space)
33 |
34 | def update_cds(self, space):
35 | CellContinuousHandler.update_cds_interactive(self, space)
36 |
37 | ## ONLY FOR INTERACTIVE CASE
38 | def update_source_cds(self, space):
39 | CellContinuousHandler.update_source_cds_interactive(self, space)
40 |
41 | def update_selection_cds(self, space, xmin, xmax):
42 | CellContinuousHandler.update_selection_cds_interactive(self, space, xmin, xmax)
43 |
44 | def update_reconstructed_cds(self, space):
45 | CellContinuousHandler.update_reconstructed_cds_interactive(self, space)
46 |
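
The interactive and static cell classes in this package are deliberately thin: every overridden method forwards to a static method on a handler class (CellContinuousHandler here), so the Bokeh-specific logic for each plot type lives in a single module. A stripped-down sketch of that dispatch pattern, with illustrative names rather than ipme's own:

    class Handler:
        @staticmethod
        def initialize_fig_interactive(cell, space):
            # All figure construction for this plot type is centralised here.
            print("building figure for", cell.name, "in space", space)

    class InteractiveCell:
        def __init__(self, name):
            self.name = name

        def initialize_fig(self, space):
            # The cell only dispatches; the handler does the work.
            Handler.initialize_fig_interactive(self, space)

    InteractiveCell("theta").initialize_fig("posterior")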
--------------------------------------------------------------------------------
/ipme/classes/cell/interactive_discrete_cell.py:
--------------------------------------------------------------------------------
1 | from ipme.interfaces.variable_cell import VariableCell
2 | from .utils.cell_discrete_handler import CellDiscreteHandler
3 | from .utils.cell_clear_selection import CellClearSelection
4 | from ..cell.utils.cell_widgets import CellWidgets
5 |
6 | class InteractiveDiscreteCell(VariableCell):
7 | def __init__(self, name, control):
8 | """
9 | Parameters:
10 | --------
11 |             name            A String, the name of the model's variable plotted in this cell.
12 | control A Control object
13 | """
14 | self.selection = {}
15 | self.reconstructed = {}
16 | self.clear_selection = {}
17 | VariableCell.__init__(self, name, control)
18 |
19 | def initialize_cds(self, space):
20 | CellDiscreteHandler.initialize_cds_interactive(self, space)
21 |
22 | def initialize_fig(self, space):
23 | CellDiscreteHandler.initialize_fig_interactive(self, space)
24 |
25 | def initialize_glyphs(self, space):
26 | CellDiscreteHandler.initialize_glyphs_interactive(self, space)
27 | CellClearSelection.initialize_glyphs_x_button(self, space)
28 |
29 | def widget_callback(self, attr, old, new, w_title, space):
30 | CellWidgets.widget_callback_interactive(self, attr, old, new, w_title, space)
31 |
32 | def update_cds(self, space):
33 | CellDiscreteHandler.update_cds_interactive(self, space)
34 |
35 | ## ONLY FOR INTERACTIVE CASE
36 | def update_source_cds(self, space):
37 | CellDiscreteHandler.update_source_cds_interactive(self, space)
38 |
39 | def update_selection_cds(self, space, xmin, xmax):
40 | CellDiscreteHandler.update_selection_cds_interactive(self, space, xmin, xmax)
41 |
42 | def update_reconstructed_cds(self, space):
43 | CellDiscreteHandler.update_reconstructed_cds_interactive(self, space)
44 |
--------------------------------------------------------------------------------
/ipme/classes/cell/interactive_pred_ckeck_cell.py:
--------------------------------------------------------------------------------
1 | from ipme.interfaces.predictive_check_cell import PredictiveCheckCell
2 | from .utils.cell_pred_check_handler import CellPredCheckHandler
3 | from ..cell.utils.cell_widgets import CellWidgets
4 |
5 | # from ipme.utils.stats import hist
6 | # from ipme.utils.functions import get_finite_samples, get_samples_for_pred_check, get_hist_bins_range
7 | # from ipme.utils.constants import COLORS
8 |
9 | # import numpy as np
10 | from bokeh.models import LegendItem
11 |
12 | # import threading
13 | # from abc import abstractmethod
14 |
15 | class InteractivePredCheckCell(PredictiveCheckCell):
16 | def __init__(self, name, control, function = 'min'):
17 | """
18 | Parameters:
19 | --------
20 |             name            A String, the name of the observed variable whose predictive check is plotted.
21 | function A String in {"min","max","mean","std"}.
22 | Sets:
23 | --------
24 | func
25 | source
26 | reconstructed
27 | samples
28 | seg
29 | """
30 | self.func = function
31 | PredictiveCheckCell.__init__(self, name, control)
32 |
33 | def initialize_cds(self, space):
34 | CellPredCheckHandler.initialize_cds_interactive(self, space)
35 |
36 | def initialize_fig(self, space):
37 | CellPredCheckHandler.initialize_fig_interactive(self, space)
38 |
39 | def initialize_glyphs(self, space):
40 | CellPredCheckHandler.initialize_glyphs_interactive(self, space)
41 |
42 | def widget_callback(self, attr, old, new, w_title, space):
43 | CellWidgets.widget_callback_interactive(self, attr, old, new, w_title, space)
44 |
45 | def update_cds(self, space):
46 | CellPredCheckHandler.update_cds_interactive(self, space)
47 |
48 | ## ONLY FOR INTERACTIVE CASE
49 | def update_source_cds(self, space):
50 | CellPredCheckHandler.update_source_cds_interactive(self, space)
51 |
52 | def update_sel_samples_cds(self, space):
53 | CellPredCheckHandler.update_sel_samples_cds_interactive(self, space)
54 |
55 | # def initialize_glyphs(self,space):
56 | # q = self.plot[space].quad(top='top', bottom='bottom', left='left', right='right', source=self.source[space], \
57 | # fill_color=COLORS[0], line_color="white", fill_alpha=1.0, name="full")
58 | # seg = self.plot[space].segment(x0 ='x0', y0 ='y0', x1='x1', y1='y1', source=self.seg[space], \
59 | # color="black", line_width=2, name="seg")
60 | # q_sel = self.plot[space].quad(top='top', bottom='bottom', left='left', right='right', source=self.reconstructed[space], \
61 | # fill_color=COLORS[1], line_color="white", fill_alpha=0.7, name="sel")
62 | # ## Add Legends
63 | # data = self.seg[space].data['x0']
64 | # pvalue = self.pvalue[space].data["pv"]
65 | # if len(data) and len(pvalue):
66 | # legend = Legend(items=[ (self.func + "(obs) = " + format(data[0], '.2f'), [seg]),
67 | # ("p-value = "+format(pvalue[0],'.4f'), [q]),
68 | # ], location="top_left")
69 | # self.plot[space].add_layout(legend, 'above')
70 | # ## Add Tooltips for hist
71 | # #####TODO:Correct overlap of tooltips#####
72 | # TOOLTIPS = [
73 | # ("top", "@top"),
74 | # ("right","@right"),
75 | # ("left","@left"),
76 | # ]
77 | # hover = HoverTool( tooltips=TOOLTIPS,renderers=[q,q_sel])
78 | # self.plot[space].tools.append(hover)
79 |
80 | ## Update legends when data in _pvalue_rec cds is updated
81 | def update_legends(self, space, attr, old, new):
82 | if len(self.plot[space].legend.items) == 3:
83 | self.plot[space].legend.items.pop()
84 | r = self.plot[space].select(name="sel")
85 | pvalue = self.pvalue_rec[space].data["pv"]
86 | if len(r) and len(pvalue):
87 | self.plot[space].legend.items.append(LegendItem(label="p-value = " + format(pvalue[0], '.4f'), \
88 | renderers=[r[0]]))
89 |
90 | # def widget_callback(self, attr, old, new, w_title, space):
91 | # inds = self.ic.data.get_indx_for_idx_dim(self.name, w_title, new)
92 | # if inds==-1:
93 | # return
94 | # self.cur_idx_dims_values[self.name][w_title] = inds
95 | # self._update_plot(space)
96 |
97 | # def _update_plot(self, space):
98 | # pass
99 |
100 | # def initialize_cds(self, space):
101 | # CellPredCheckHandler.initialize_cds_interactive(self, space)
102 | # # ## ColumnDataSource for full sample set
103 | # # data, samples = self.get_data_for_cur_idx_dims_values(space)
104 | # # self.samples[space] = ColumnDataSource(data=dict(x=samples))
105 | # # #data func
106 | # # if ~np.isfinite(data).all():
107 | # # data = get_finite_samples(data)
108 | # # data_func = get_samples_for_pred_check(data, self.func)
109 | # # #samples func
110 | # # if ~np.isfinite(samples).all():
111 | # # samples = get_finite_samples(samples)
112 | # # samples_func = get_samples_for_pred_check(samples, self.func)
113 | # # if samples_func.size:
114 | # # #pvalue
115 | # # pv = np.count_nonzero(samples_func>=data_func) / len(samples_func)
116 | # # #histogram
117 | # # type = self._data.get_var_dist_type(self.name)
118 | # # if type == "Continuous":
119 | # # bins, range = get_hist_bins_range(samples_func, self.func, type)
120 | # # else:
121 | # # bins, range = get_hist_bins_range(samples_func, self.func, type, ref_length = None, ref_values=np.unique(samples.flatten()))
122 |
123 | # # his, edges = hist(samples_func, bins=bins, range=range, density=True)
124 | # # #cds
125 | # # self.pvalue[space] = ColumnDataSource(data=dict(pv=[pv]))
126 | # # self.source[space] = ColumnDataSource(data=dict(left=edges[:-1], top=his, right=edges[1:], bottom=np.zeros(len(his))))
127 | # # self.seg[space] = ColumnDataSource(data=dict(x0=[data_func], x1=[data_func], y0=[0], y1=[his.max() + 0.1 * his.max()]))
128 | # # else:
129 | # # self.pvalue[space] = ColumnDataSource(data=dict(pv=[]))
130 | # # self.source[space] = ColumnDataSource(data=dict(left=[], top=[], right=[], bottom=[]))
131 | # # self.seg[space] = ColumnDataSource(data=dict(x0=[], x1=[], y0=[], y1=[]))
132 |
133 | # # ## ColumnDataSource for restricted sample set
134 | # # self.pvalue_rec[space] = ColumnDataSource(data=dict(pv=[]))
135 | # # self.pvalue_rec[space].on_change('data', partial(self._update_legends, space))
136 | # # self.reconstructed[space] = ColumnDataSource(data=dict(left=[], top=[], right=[], bottom=[]))
137 |
138 | # ## Update plots when indices of selected samples are updated
139 | # def sample_inds_callback(self, space, attr, old, new):
140 | # # _, samples = self._get_data_for_cur_idx_dims_values(space)
141 | # samples = self.samples[space].data['x']
142 | # max_full_hist = self.source[space].data['top'].max()
143 | # if samples.size:
144 | # inds=self.ic.sample_inds[space].data['inds']
145 | # if len(inds):
146 | # sel_sample = samples[inds]
147 | # if ~np.isfinite(sel_sample).all():
148 | # sel_sample = get_finite_samples(sel_sample)
149 | # sel_sample_func = get_samples_for_pred_check(sel_sample, self.func)
150 | # #data func
151 | # data_func = self.seg[space].data['x0'][0]
152 | # #pvalue in restricted space
153 | # sel_pv = np.count_nonzero(sel_sample_func >= data_func) / len(sel_sample_func)
154 | # #compute updated histogram
155 | # min_p = self.source[space].data['left'][0]
156 | # max_p = self.source[space].data['right'][-1]
157 | # min_c = sel_sample_func.min()
158 | # max_c = sel_sample_func.max()
159 | # if min_c < min_p or max_c > max_p:
160 | # ref_len = self.source[space].data['right'][0] - min_p
161 | # bins, range = get_hist_bins_range(sel_sample_func, self.func, self._type, ref_length=ref_len)
162 | # else:
163 | # range = (min_p,max_p)
164 | # bins = len(self.source[space].data['right'])
165 | # his, edges = hist(sel_sample_func, bins=bins, range=range)
166 | # ##max selected hist
167 | # max_sel_hist = his.max()
168 | # #update reconstructed cds
169 | # self.pvalue_rec[space].data = dict(pv=[sel_pv])
170 | # self.reconstructed[space].data = dict(left=edges[:-1], top=his, right =edges[1:], bottom=np.zeros(len(his)))
171 | # self.seg[space].data['y1'] = [max_sel_hist + 0.1 * max_sel_hist]
172 | # else:
173 | # self.pvalue_rec[space].data = dict(pv=[])
174 | # self.reconstructed[space].data = dict(left=[], top=[], right=[], bottom=[])
175 | # self.seg[space].data['y1'] = [max_full_hist + 0.1 * max_full_hist]
176 | # else:
177 | # self.pvalue_rec[space].data = dict(pv=[])
178 | # self.reconstructed[space].data = dict(left=[], top=[], right=[], bottom=[])
179 | # self.seg[space].data['y1'] = [max_full_hist + 0.1 * max_full_hist]
180 |
--------------------------------------------------------------------------------
/ipme/classes/cell/interactive_scatter_cell.py:
--------------------------------------------------------------------------------
1 | from ipme.interfaces.scatter_cell import ScatterCell
2 | from .utils.cell_scatter_handler import CellScatterHandler
3 | from ..cell.utils.cell_widgets import CellWidgets
4 |
5 | class InteractiveScatterCell(ScatterCell):
6 | def __init__(self, vars, control):
7 | """
8 | Parameters:
9 | --------
10 |             vars            A List of variableNames of the model.
11 | control A Control object
12 | """
13 | self.sel_samples = {}
14 | self.non_sel_samples = {}
15 | ScatterCell.__init__(self, vars, control)
16 |
17 | def initialize_cds(self, space):
18 | CellScatterHandler.initialize_cds_interactive(self, space)
19 |
20 | def initialize_fig(self, space):
21 | CellScatterHandler.initialize_fig_interactive(self, space)
22 |
23 | def initialize_glyphs(self, space):
24 | CellScatterHandler.initialize_glyphs_interactive(self, space)
25 |
26 | def widget_callback(self, attr, old, new, w_title, space):
27 | CellWidgets.widget_callback_interactive(self, attr, old, new, w_title, space)
28 |
29 | def update_cds(self, space):
30 | CellScatterHandler.update_cds_interactive(self, space)
31 |
32 | ## ONLY FOR INTERACTIVE CASE
33 | def update_source_cds(self, space):
34 | CellScatterHandler.update_source_cds_interactive(self, space)
35 |
36 | def update_sel_samples_cds(self, space):
37 | CellScatterHandler.update_sel_samples_cds_interactive(self, space)
38 |
--------------------------------------------------------------------------------
/ipme/classes/cell/static_continuous_cell.py:
--------------------------------------------------------------------------------
1 | from ipme.interfaces.variable_cell import VariableCell
2 | from .utils.cell_continuous_handler import CellContinuousHandler
3 | from ..cell.utils.cell_widgets import CellWidgets
4 |
5 | from ipme.utils.functions import get_stratum_range, find_indices
6 |
7 | class StaticContinuousCell(VariableCell):
8 | def __init__(self, name, control):
9 | """
10 | Parameters:
11 | --------
12 |             name            A String, the name of the model's variable plotted in this cell.
13 | control A Control object
14 | """
15 | VariableCell.__init__(self, name, control)
16 |
17 | def initialize_cds(self, space):
18 | CellContinuousHandler.initialize_cds_static(self, space)
19 |
20 | def initialize_fig(self, space):
21 | CellContinuousHandler.initialize_fig_static(self, space)
22 |
23 | def initialize_glyphs(self, space):
24 | CellContinuousHandler.initialize_glyphs_static(self, space)
25 |
26 | def widget_callback(self, attr, old, new, w_title, space):
27 | CellWidgets.widget_callback_static(self, attr, old, new, w_title, space)
28 |
29 | def update_cds(self, space):
30 | CellContinuousHandler.update_cds_static(self, space)
31 |
32 | ## ONLY FOR STATIC CASE
33 | def set_stratum(self, space, stratum = 0):
34 | """
35 | Sets selection by spliting the ordered sample set
36 | in 4 equal-sized subsets.
37 | """
38 | samples = self.get_samples_for_cur_idx_dims_values(space)
39 | xmin,xmax = get_stratum_range(samples, stratum)
40 | inds = find_indices(samples, lambda e: xmin <= e <= xmax, xmin, xmax)
41 | self.ic.set_sel_var_inds(space, self.name, inds)
42 | self.compute_intersection_of_samples(space)
43 | return (xmin,xmax)
44 |
45 |
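
set_stratum delegates the interval computation to get_stratum_range from ipme.utils.functions, whose source is not shown in this listing. A plausible quartile-based version, consistent with the docstring's description of splitting the ordered samples into 4 equal-sized subsets, might look like the sketch below; this is an assumption about its behaviour, not the actual implementation:

    import numpy as np

    def get_stratum_range_sketch(samples, stratum):
        """Return (xmin, xmax) bounding the requested quartile (stratum 0..3)."""
        ordered = np.sort(np.asarray(samples))
        # Quartile boundaries at 0%, 25%, 50%, 75% and 100% of the ordered samples.
        edges = np.percentile(ordered, [0, 25, 50, 75, 100])
        return edges[stratum], edges[stratum + 1]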
--------------------------------------------------------------------------------
/ipme/classes/cell/static_discrete_cell.py:
--------------------------------------------------------------------------------
1 | from ipme.interfaces.variable_cell import VariableCell
2 | from .utils.cell_discrete_handler import CellDiscreteHandler
3 | from ..cell.utils.cell_widgets import CellWidgets
4 |
5 | from ipme.utils.functions import get_stratum_range, find_indices
6 |
7 | class StaticDiscreteCell(VariableCell):
8 | def __init__(self, name, control):
9 | """
10 | Parameters:
11 | --------
12 |             name            A String, the name of the model's variable plotted in this cell.
13 | control A Control object
14 | """
15 | VariableCell.__init__(self, name, control)
16 |
17 | def initialize_cds(self, space):
18 | CellDiscreteHandler.initialize_cds_static(self, space)
19 |
20 | def initialize_fig(self, space):
21 | CellDiscreteHandler.initialize_fig_static(self, space)
22 |
23 | def initialize_glyphs(self, space):
24 | CellDiscreteHandler.initialize_glyphs_static(self, space)
25 |
26 | def widget_callback(self, attr, old, new, w_title, space):
27 | CellWidgets.widget_callback_static(self, attr, old, new, w_title, space)
28 |
29 | def update_cds(self, space):
30 | CellDiscreteHandler.update_cds_static(self, space)
31 |
32 | ## ONLY FOR STATIC CASE
33 | def set_stratum(self, space, stratum = 0):
34 | """
35 |         Sets the selection by splitting the ordered sample set
36 |         into 4 equal-sized subsets (strata) and selecting the given stratum.
37 | """
38 | samples = self.get_samples_for_cur_idx_dims_values(space)
39 | xmin,xmax = get_stratum_range(samples, stratum)
40 | inds = find_indices(samples, lambda e: xmin <= e <= xmax, xmin, xmax)
41 | self.ic.set_sel_var_inds(space, self.name, inds)
42 | self.compute_intersection_of_samples(space)
43 | return (xmin,xmax)
44 |
--------------------------------------------------------------------------------
/ipme/classes/cell/static_pred_check_cell.py:
--------------------------------------------------------------------------------
1 | from ipme.interfaces.predictive_check_cell import PredictiveCheckCell
2 | from .utils.cell_pred_check_handler import CellPredCheckHandler
3 | from ..cell.utils.cell_widgets import CellWidgets
4 |
5 | class StaticPredCheckCell(PredictiveCheckCell):
6 | def __init__(self, vars, control, function = 'min'):
7 | """
8 | Parameters:
9 | --------
10 | vars A List of variableNames of the model.
11 | control A Control object
12 | function A String in {"min","max","mean","std"}.
13 | """
14 | self.func = function
15 | PredictiveCheckCell.__init__(self, vars, control)
16 |
17 | def initialize_cds(self, space):
18 | CellPredCheckHandler.initialize_cds_static(self, space)
19 |
20 | def initialize_fig(self, space):
21 | CellPredCheckHandler.initialize_fig_static(self, space)
22 |
23 | def initialize_glyphs(self, space):
24 | CellPredCheckHandler.initialize_glyphs_static(self, space)
25 |
26 | def widget_callback(self, attr, old, new, w_title, space):
27 | CellWidgets.widget_callback_static(self, attr, old, new, w_title, space)
28 |
29 | def update_cds(self, space):
30 | CellPredCheckHandler.update_cds_static(self, space)
31 |
32 |
--------------------------------------------------------------------------------
/ipme/classes/cell/static_scatter_cell.py:
--------------------------------------------------------------------------------
1 | from ipme.interfaces.scatter_cell import ScatterCell
2 | from .utils.cell_scatter_handler import CellScatterHandler
3 | from ..cell.utils.cell_widgets import CellWidgets
4 |
5 | class StaticScatterCell(ScatterCell):
6 | def __init__(self, vars, control):
7 | """
8 | Parameters:
9 | --------
10 | vars A List of variableNames of the model.
11 | control A Control object
12 | """
13 | ScatterCell.__init__(self, vars, control)
14 |
15 | def initialize_cds(self, space):
16 | CellScatterHandler.initialize_cds_static(self, space)
17 |
18 | def initialize_fig(self, space):
19 | CellScatterHandler.initialize_fig_static(self, space)
20 |
21 | def initialize_glyphs(self, space):
22 | CellScatterHandler.initialize_glyphs_static(self, space)
23 |
24 | def widget_callback(self, attr, old, new, w_title, space):
25 | CellWidgets.widget_callback_static(self, attr, old, new, w_title, space)
26 |
27 | def update_cds(self, space):
28 | CellScatterHandler.update_cds_static(self, space)
29 |
30 |
--------------------------------------------------------------------------------
/ipme/classes/cell/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/evdoxiataka/ipme/f3398596c6af547908f39683eb1830a6bc081482/ipme/classes/cell/utils/__init__.py
--------------------------------------------------------------------------------
/ipme/classes/cell/utils/cell_clear_selection.py:
--------------------------------------------------------------------------------
1 | from ipme.utils.functions import find_highest_point
2 | from ipme.utils.js_code import HOVER_CODE
3 |
4 | from bokeh.models import HoverTool, CustomJS
5 |
6 | class CellClearSelection:
7 |
8 | def __init__(self):
9 | pass
10 |
11 | @staticmethod
12 | def initialize_glyphs_x_button(variableCell, space):
13 | ## x-button to clear selection
14 | sq_x = variableCell.plot[space].scatter('x', 'y', marker = "square_x", size = 10, fill_color = "grey", hover_fill_color = "firebrick", \
15 | fill_alpha = 0.5, hover_alpha = 1.0, line_color = "grey", hover_line_color = "white", \
16 | source = variableCell.clear_selection[space], name = 'clear_selection')
17 | ## Add HoverTool for x-button
18 | variableCell.plot[space].add_tools(HoverTool(tooltips = "Clear Selection", renderers = [sq_x], mode = 'mouse', show_arrow = False,
19 | callback = CustomJS(args = dict(source = variableCell.clear_selection[space]), code = HOVER_CODE)))
20 |
21 | @staticmethod
22 | def update_clear_selection_cds(variableCell, space):
23 | """
24 | Updates clear_selection ColumnDataSource (cds).
25 | """
26 | sel_var_idx_dims_values = variableCell.ic.get_sel_var_idx_dims_values()
27 | sel_space = variableCell.ic.get_sel_space()
28 | var_x_range = variableCell.ic.get_var_x_range()
29 | cur_idx_dims_values = {}
30 | if variableCell.name in variableCell.cur_idx_dims_values:
31 | cur_idx_dims_values = variableCell.cur_idx_dims_values[variableCell.name]
32 | if (variableCell.name in sel_var_idx_dims_values and space == sel_space and
33 | cur_idx_dims_values == sel_var_idx_dims_values[variableCell.name]):
34 | min_x_range = var_x_range[(space, variableCell.name)].data['xmin'][0]
35 | max_x_range = var_x_range[(space, variableCell.name)].data['xmax'][0]
36 | hp = find_highest_point(variableCell.reconstructed[space].data['x'], variableCell.reconstructed[space].data['y'])
37 | if not hp:
38 | hp = find_highest_point(variableCell.selection[space].data['x'], variableCell.selection[space].data['y'])
39 | if not hp:
40 | hp = find_highest_point(variableCell.source[space].data['x'], variableCell.source[space].data['y'])
41 | if not hp:
42 | hp = (0,0)
43 | variableCell.clear_selection[space].data = dict( x = [(max_x_range + min_x_range) / 2.], y = [hp[1]+hp[1]*0.1], isIn = [0])
44 | else:
45 | variableCell.clear_selection[space].data = dict( x = [], y = [], isIn = [])
46 |
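
update_clear_selection_cds positions the x-button above the tallest point of whichever curve is currently drawn, using find_highest_point from ipme.utils.functions. That helper's source is not included here; a minimal version consistent with how it is called (an (x, y) pair, or a falsy value when the arrays are empty) could be:

    import numpy as np

    def find_highest_point_sketch(x, y):
        """Return the (x, y) pair with the maximum y, or () for empty input."""
        x, y = np.asarray(x), np.asarray(y)
        if x.size == 0 or y.size == 0:
            return ()
        i = int(np.argmax(y))
        return (x[i], y[i])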
--------------------------------------------------------------------------------
/ipme/classes/cell/utils/cell_discrete_handler.py:
--------------------------------------------------------------------------------
1 | from ipme.classes.cell.utils.cell_clear_selection import CellClearSelection
2 |
3 | from ipme.utils.constants import COLORS, BORDER_COLORS, PLOT_HEIGHT, PLOT_WIDTH, SIZING_MODE, RUG_DIST_RATIO, DATA_SIZE
4 | from ipme.utils.stats import pmf
5 | from ipme.utils.functions import find_indices
6 |
7 | from bokeh.models import ColumnDataSource, BoxSelectTool, HoverTool
8 |
9 | from bokeh import events
10 | from bokeh.plotting import figure
11 |
12 | import numpy as np
13 | import threading
14 | from functools import partial
15 |
16 | class CellDiscreteHandler:
17 |
18 | def __init__(self):
19 | pass
20 |
21 | @staticmethod
22 | def initialize_glyphs_interactive(variableCell, space):
23 | hover_renderer = []
24 | so_seg = variableCell.plot[space].segment(x0 = 'x', y0 ='y0', x1 = 'x', y1 = 'y', source = variableCell.source[space], line_alpha = 1.0, color = COLORS[0], line_width = 1, selection_color = COLORS[0], \
25 | nonselection_color = COLORS[0], nonselection_line_alpha = 1.0)
26 | variableCell.plot[space].scatter('x', 'y', source = variableCell.source[space], size = 4, fill_color = COLORS[0], fill_alpha = 1.0, line_color = COLORS[0], selection_fill_color = COLORS[0], \
27 | nonselection_fill_color = COLORS[0], nonselection_fill_alpha = 1.0, nonselection_line_color = COLORS[0])
28 | variableCell.plot[space].segment(x0 = 'x', y0 ='y0', x1 = 'x', y1 = 'y', source = variableCell.selection[space], line_alpha = 0.7, color = COLORS[2], line_width = 1)
29 | variableCell.plot[space].scatter('x', 'y', source = variableCell.selection[space], size = 4, fill_color = COLORS[2], fill_alpha = 0.7, line_color = COLORS[2])
30 | rec = variableCell.plot[space].segment(x0 = 'x', y0 ='y0', x1 = 'x', y1 = 'y', source = variableCell.reconstructed[space], line_alpha = 0.5, color = COLORS[1], line_width = 1)
31 | variableCell.plot[space].scatter('x', 'y', source = variableCell.reconstructed[space], size = 4, fill_color = COLORS[1], fill_alpha = 0.5, line_color = COLORS[1])
32 | hover_renderer.append(so_seg)
33 | hover_renderer.append(rec)
34 | ## Add observations as yellow asterisks
35 | if space in variableCell.data:
36 | dat = variableCell.plot[space].asterisk('x', 'y', size = DATA_SIZE, line_color = COLORS[3], source = variableCell.data[space])
37 | hover_renderer.append(dat)
38 | ##Add BoxSelectTool
39 | variableCell.plot[space].add_tools(BoxSelectTool(dimensions = 'width', renderers = [so_seg]))
40 | ##Tooltips
41 | TOOLTIPS = [("x", "@x"), ("y","@y"),]
42 | hover = HoverTool( tooltips = TOOLTIPS, renderers = hover_renderer, mode = 'mouse')
43 | variableCell.plot[space].tools.append(hover)
44 |
45 | @staticmethod
46 | def initialize_glyphs_static(variableCell, space):
47 | hover_renderer = []
48 | so_seg = variableCell.plot[space].segment(x0 = 'x', y0 ='y0', x1 = 'x', y1 = 'y', source = variableCell.source[space], line_alpha = 1.0, color = COLORS[0], line_width = 1, selection_color = COLORS[0], \
49 | nonselection_color = COLORS[0], nonselection_line_alpha = 1.0)
50 | variableCell.plot[space].scatter('x', 'y', source = variableCell.source[space], size = 4, fill_color = COLORS[0], fill_alpha = 1.0, line_color = COLORS[0], selection_fill_color = COLORS[0], \
51 | nonselection_fill_color = COLORS[0], nonselection_fill_alpha = 1.0, nonselection_line_color = COLORS[0])
52 | hover_renderer.append(so_seg)
53 | ## Add observations as yellow asterisks
54 | if space in variableCell.data:
55 | dat = variableCell.plot[space].asterisk('x', 'y', size = DATA_SIZE, line_color = COLORS[3], source = variableCell.data[space])
56 | hover_renderer.append(dat)
57 | ## Tooltips
58 | TOOLTIPS = [("x", "@x"), ("y","@y"),]
59 | hover = HoverTool( tooltips = TOOLTIPS, renderers = hover_renderer, mode = 'mouse')
60 | variableCell.plot[space].tools.append(hover)
61 |
62 | @staticmethod
63 | def initialize_fig(variableCell, space):
64 | variableCell.plot[space] = figure( x_range = variableCell.x_range[variableCell.name][space], tools = "wheel_zoom,reset,box_zoom", toolbar_location = 'right',
65 | plot_width = PLOT_WIDTH, plot_height = PLOT_HEIGHT, sizing_mode = SIZING_MODE)
66 | variableCell.plot[space].border_fill_color = BORDER_COLORS[0]
67 | variableCell.plot[space].xaxis.axis_label = variableCell.name
68 | variableCell.plot[space].yaxis.visible = False
69 | variableCell.plot[space].toolbar.logo = None
70 | variableCell.plot[space].xaxis[0].ticker.desired_num_ticks = 3
71 |
72 | @staticmethod
73 | def initialize_fig_interactive(variableCell, space):
74 | CellDiscreteHandler.initialize_fig(variableCell, space)
75 | ##Events
76 | variableCell.plot[space].on_event(events.Tap, partial(CellDiscreteHandler.clear_selection_callback, variableCell, space))
77 | variableCell.plot[space].on_event(events.SelectionGeometry, partial(CellDiscreteHandler.selectionbox_callback, variableCell, space))
78 | ##on_change
79 | variableCell.ic.sample_inds_update[space].on_change('data', partial(variableCell.sample_inds_callback, space))
80 |
81 | @staticmethod
82 | def initialize_fig_static(variableCell, space):
83 | CellDiscreteHandler.initialize_fig(variableCell, space)
84 | ##on_change
85 | variableCell.ic.sample_inds_update[space].on_change('data', partial(variableCell.sample_inds_callback, space))
86 |
87 | @staticmethod
88 | def initialize_cds(variableCell, space):
89 | samples = variableCell.get_samples_for_cur_idx_dims_values(variableCell.name, space)
90 | variableCell.source[space] = ColumnDataSource(data = pmf(samples))
91 | variableCell.samples[space] = ColumnDataSource(data = dict(x = samples))
92 | # data cds
93 | data = variableCell.get_data_for_cur_idx_dims_values(variableCell.name)
94 | if data is not None:
95 | max_v = variableCell.source[space].data['y'].max()
96 | variableCell.data[space] = ColumnDataSource(data = dict(x = data, y = np.asarray([-1*max_v/RUG_DIST_RATIO]*len(data))))
97 | # initialize sample inds
98 | variableCell.ic.initialize_sample_inds(space, dict(inds = [False]*len(variableCell.samples[space].data['x'])), dict(non_inds = [True]*len(variableCell.samples[space].data['x'])))
99 |
100 | @staticmethod
101 | def initialize_cds_interactive(variableCell, space):
102 | CellDiscreteHandler.initialize_cds(variableCell, space)
103 | variableCell.selection[space] = ColumnDataSource(data = dict(x = np.array([]), y = np.array([]), y0 = np.array([])))
104 | variableCell.reconstructed[space] = ColumnDataSource(data = dict(x = np.array([]), y = np.array([]), y0 = np.array([])))
105 | variableCell.clear_selection[space] = ColumnDataSource(data = dict(x = [], y = [], isIn = []))
106 | variableCell.ic.var_x_range[(space, variableCell.name)] = ColumnDataSource(data = dict(xmin = np.array([]), xmax = np.array([])))
107 |
108 | @staticmethod
109 | def initialize_cds_static(variableCell, space):
110 | CellDiscreteHandler.initialize_cds(variableCell, space)
111 | #########TEST###########
112 |
113 | @staticmethod
114 | def update_cds_interactive(variableCell, space):
115 | """
116 | Updates interaction-related ColumnDataSources (cds).
117 | """
118 | sel_var_idx_dims_values = variableCell.ic.get_sel_var_idx_dims_values()
119 | sel_space = variableCell.ic.get_sel_space()
120 | var_x_range = variableCell.ic.get_var_x_range()
121 | global_update = variableCell.ic.get_global_update()
122 | if global_update:
123 | cur_idx_dims_values = {}
124 | if variableCell.name in variableCell.cur_idx_dims_values:
125 | cur_idx_dims_values = variableCell.cur_idx_dims_values[variableCell.name]
126 | if variableCell.name in sel_var_idx_dims_values and space == sel_space and cur_idx_dims_values == sel_var_idx_dims_values[variableCell.name]:
127 | variableCell.update_selection_cds(space, var_x_range[(space, variableCell.name)].data['xmin'][0], var_x_range[(space, variableCell.name)].data['xmax'][0])
128 | else:
129 | variableCell.selection[space].data = dict(x = np.array([]), y = np.array([]), y0 = np.array([]))
130 | variableCell.update_reconstructed_cds(space)
131 | CellClearSelection.update_clear_selection_cds(variableCell, space)
132 |
133 | @staticmethod
134 | def update_cds_static(variableCell, space):
135 | """
136 | Update source & samples cds in the static mode
137 | """
138 | samples = variableCell.get_samples_for_cur_idx_dims_values(variableCell.name, space)
139 | inds,_ = variableCell.ic.get_sample_inds(space)
140 | if True in inds:
141 | sel_sample = samples[inds]
142 | variableCell.source[space].data = pmf(sel_sample)
143 | variableCell.samples[space].data = dict( x = sel_sample)
144 | else:
145 | variableCell.source[space].data = pmf(samples)
146 | variableCell.samples[space].data = dict( x = samples)
147 | # data cds
148 | data = variableCell.get_data_for_cur_idx_dims_values(variableCell.name)
149 | if data is not None:
150 | max_v = variableCell.get_max_prob(space)
151 | variableCell.data[space] = ColumnDataSource(data = dict(x = data, y = np.asarray([-1*max_v/RUG_DIST_RATIO]*len(data))))
152 |
153 | ## ONLY FOR INTERACTIVE CASE
154 | @staticmethod
155 | def selectionbox_callback(variableCell, space, event):
156 | """
157 | Callback called when selection box is drawn.
158 | """
159 | xmin = event.geometry['x0']
160 | xmax = event.geometry['x1']
161 | cur_idx_dims_values = {}
162 | if variableCell.name in variableCell.cur_idx_dims_values:
163 | cur_idx_dims_values = variableCell.cur_idx_dims_values[variableCell.name]
164 | variableCell.ic.set_selection(variableCell.name, space, (xmin, xmax), cur_idx_dims_values)
165 | for sp in variableCell.spaces:
166 | samples = variableCell.samples[sp].data['x']
167 | variableCell.ic.add_space_threads(threading.Thread(target = partial(CellDiscreteHandler._selectionbox_space_thread, variableCell, sp, samples, xmin, xmax), daemon = True))
168 | variableCell.ic.space_threads_join()
169 |
170 | @staticmethod
171 | def _selectionbox_space_thread(variableCell, space, samples, xmin, xmax):
172 | x_range = variableCell.ic.get_var_x_range(space, variableCell.name)
173 | xmin_list = x_range['xmin']
174 | xmax_list = x_range['xmax']
175 | if len(xmin_list):
176 | variableCell.update_selection_cds(space, xmin_list[0], xmax_list[0])
177 | else:
178 | variableCell.selection[space].data = dict(x = np.array([]), y = np.array([]), y0 = np.array([]))
179 | inds = find_indices(samples, lambda e: xmin <= e <= xmax, xmin, xmax)
180 | variableCell.ic.set_sel_var_inds(space, variableCell.name, inds)
181 | variableCell.compute_intersection_of_samples(space)
182 | variableCell.ic.selection_threads_join(space)
183 |
184 | @staticmethod
185 | def update_source_cds_interactive(variableCell, space):
186 | """
187 | Updates source ColumnDataSource (cds).
188 | """
189 | samples = variableCell.get_samples_for_cur_idx_dims_values(variableCell.name, space)
190 | variableCell.source[space].data = pmf(samples)
191 | variableCell.samples[space].data = dict( x = samples)
192 |
193 | @staticmethod
194 | def update_selection_cds_interactive(variableCell, space, xmin, xmax):
195 | """
196 | Updates selection ColumnDataSource (cds).
197 | """
198 | # Get kde points within [xmin,xmax]
199 | data = {}
200 | data['x'] = np.array([])
201 | data['y'] = np.array([])
202 | kde_indices = find_indices(variableCell.source[space].data['x'], lambda e: xmin <= e <= xmax, xmin, xmax)
203 | if len(kde_indices) == 0:
204 | variableCell.selection[space].data = dict(x = np.array([]), y = np.array([]), y0 = np.array([]))
205 | return
206 | data['x'] = variableCell.source[space].data['x'][kde_indices]
207 | data['y'] = variableCell.source[space].data['y'][kde_indices]
208 | data['y0'] = np.asarray(len(data['x'])*[0])
209 | variableCell.selection[space].data = data
210 |
211 | @staticmethod
212 | def update_reconstructed_cds_interactive(variableCell, space):
213 | """
214 | Updates reconstructed ColumnDataSource (cds).
215 | """
216 | samples = variableCell.samples[space].data['x']
217 | inds,_ = variableCell.ic.get_sample_inds(space)
218 |         # if True in inds:
219 | sel_sample = samples[inds]
220 | variableCell.reconstructed[space].data = pmf(sel_sample)
221 | # data cds
222 | data = variableCell.get_data_for_cur_idx_dims_values(variableCell.name)
223 | if data is not None:
224 | max_v = variableCell.get_max_prob(space)
225 | variableCell.data[space] = ColumnDataSource(data = dict(x = data, y = np.asarray([-1*max_v/RUG_DIST_RATIO]*len(data))))
226 | # # else:
227 | # # variableCell.reconstructed[space].data = dict(x = np.array([]), y = np.array([]), y0 = np.array([]))
228 | # ##########TEST###################to be deleted
229 | # # max_v = variableCell.get_max_prob(space)
230 | # # if max_v!=-1:
231 | # # variableCell.samples[space].data['y'] = np.asarray([-1*max_v/RUG_DIST_RATIO]*len(variableCell.samples[space].data['x']))
232 |
233 | @staticmethod
234 | def clear_selection_callback(variableCell, space, event):
235 | """
236 | Callback called when clear selection glyph is clicked.
237 | """
238 | isIn = variableCell.clear_selection[space].data['isIn']
239 | if 1 in isIn:
240 | variableCell.ic.set_var_x_range(space, variableCell.name, dict(xmin = np.array([]), xmax = np.array([])))
241 | variableCell.ic.delete_sel_var_idx_dims_values(variableCell.name)
242 | for sp in variableCell.spaces:
243 | variableCell.ic.add_space_threads(threading.Thread(target = partial(CellDiscreteHandler._clear_selection_cds_update, variableCell, sp), daemon = True))
244 | variableCell.ic.space_threads_join()
245 |
246 | @staticmethod
247 | def _clear_selection_cds_update(variableCell, space):
248 | x_range = variableCell.ic.get_var_x_range(space, variableCell.name)
249 | xmin_list = x_range['xmin']
250 | xmax_list = x_range['xmax']
251 | if len(xmin_list):
252 | variableCell.update_selection_cds(space, xmin_list[0], xmax_list[0])
253 | else:
254 | variableCell.selection[space].data = dict(x = np.array([]), y = np.array([]), y0 = np.array([]))
255 | variableCell.ic.delete_sel_var_inds(space, variableCell.name)
256 | variableCell.compute_intersection_of_samples(space)
257 | variableCell.ic.selection_threads_join(space)
258 |
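
The discrete handler depends on two helpers defined elsewhere in the package: pmf, which must return a dict with 'x', 'y' and 'y0' keys (value, empirical probability, and a zero baseline for the segment glyphs), and find_indices, which returns the indices of samples lying in a closed interval. Minimal sketches consistent with that usage, offered as assumptions rather than the package's actual code:

    import numpy as np

    def pmf_sketch(samples):
        """Empirical probability mass function of a 1-D sample array."""
        values, counts = np.unique(np.asarray(samples), return_counts=True)
        probs = counts / counts.sum()
        return dict(x=values, y=probs, y0=np.zeros(len(values)))

    def find_indices_sketch(samples, predicate, xmin, xmax):
        """Indices of samples satisfying predicate; xmin/xmax kept for signature parity."""
        return [i for i, s in enumerate(np.asarray(samples)) if predicate(s)]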
--------------------------------------------------------------------------------
/ipme/classes/cell/utils/cell_pred_check_handler.py:
--------------------------------------------------------------------------------
1 | from ipme.utils.constants import COLORS, BORDER_COLORS, PLOT_HEIGHT, PLOT_WIDTH, SIZING_MODE
2 | from ipme.utils.functions import get_finite_samples, get_samples_for_pred_check, get_hist_bins_range
3 | from ipme.utils.stats import hist
4 |
5 | import numpy as np
6 |
7 | from bokeh.plotting import figure
8 | from bokeh.models import ColumnDataSource, HoverTool, Legend
9 | from bokeh import events
10 |
11 | from functools import partial
12 |
13 | class CellPredCheckHandler:
14 |
15 | def __init__(self):
16 | pass
17 |
18 | @staticmethod
19 | def initialize_glyphs_interactive(predcheckCell, space):
20 | q = predcheckCell.plot[space].quad(top='top', bottom='bottom', left='left', right='right', source=predcheckCell.source[space], \
21 | fill_color=COLORS[0], line_color="white", fill_alpha=1.0, name="full")
22 | seg = predcheckCell.plot[space].segment(x0 ='x0', y0 ='y0', x1='x1', y1='y1', source=predcheckCell.seg[space], \
23 | color="black", line_width=2, name="seg")
24 | q_sel = predcheckCell.plot[space].quad(top='top', bottom='bottom', left='left', right='right', source=predcheckCell.reconstructed[space], \
25 | fill_color=COLORS[1], line_color="white", fill_alpha=0.7, name="sel")
26 | ## Add Legends
27 | data = predcheckCell.seg[space].data['x0']
28 | pvalue = predcheckCell.pvalue[space].data["pv"]
29 | if len(data) and len(pvalue):
30 | legend = Legend(items=[ (predcheckCell.func + "(obs) = " + format(data[0], '.2f'), [seg]),
31 | ("p-value = "+format(pvalue[0],'.4f'), [q]),
32 | ], location="top_left")
33 | predcheckCell.plot[space].add_layout(legend, 'above')
34 | ## Add Tooltips for hist
35 | #####TODO:Correct overlap of tooltips#####
36 | TOOLTIPS = [
37 | ("top", "@top"),
38 | ("right","@right"),
39 | ("left","@left"),
40 | ]
41 | hover = HoverTool( tooltips=TOOLTIPS,renderers=[q,q_sel])
42 | predcheckCell.plot[space].tools.append(hover)
43 |
44 | @staticmethod
45 | def initialize_glyphs_static(predcheckCell, space):
46 | q = predcheckCell.plot[space].quad(top='top', bottom='bottom', left='left', right='right', source=predcheckCell.source[space], \
47 | fill_color=COLORS[0], line_color="white", fill_alpha=1.0, name="full")
48 | seg = predcheckCell.plot[space].segment(x0 ='x0', y0 ='y0', x1='x1', y1='y1', source=predcheckCell.seg[space], \
49 | color="black", line_width=2, name="seg")
50 | ## Add Legends
51 | data = predcheckCell.seg[space].data['x0']
52 | pvalue = predcheckCell.pvalue[space].data["pv"]
53 | if len(data) and len(pvalue):
54 | legend = Legend(items=[ (predcheckCell.func + "(obs) = " + format(data[0], '.2f'), [seg]),
55 | ("p-value = "+format(pvalue[0],'.4f'), [q]),
56 | ], location="top_left")
57 | predcheckCell.plot[space].add_layout(legend, 'above')
58 | ## Add Tooltips for hist
59 | #####TODO:Correct overlap of tooltips#####
60 | TOOLTIPS = [
61 | ("top", "@top"),
62 | ("right","@right"),
63 | ("left","@left"),
64 | ]
65 | hover = HoverTool( tooltips=TOOLTIPS,renderers=[q])
66 | predcheckCell.plot[space].tools.append(hover)
67 |
68 | @staticmethod
69 | def initialize_fig(predcheckCell, space):
70 | predcheckCell.plot[space] = figure(tools = "wheel_zoom,reset", toolbar_location = 'right', plot_width = PLOT_WIDTH, plot_height = PLOT_HEIGHT, sizing_mode = SIZING_MODE)
71 | predcheckCell.plot[space].toolbar.logo = None
72 | predcheckCell.plot[space].yaxis.visible = False
73 | predcheckCell.plot[space].xaxis.axis_label = predcheckCell.func + "(" + predcheckCell.name + ")"
74 | predcheckCell.plot[space].border_fill_color = BORDER_COLORS[0]
75 | predcheckCell.plot[space].xaxis[0].ticker.desired_num_ticks = 3
76 |
77 | @staticmethod
78 | def initialize_fig_interactive(predcheckCell, space):
79 | CellPredCheckHandler.initialize_fig(predcheckCell, space)
80 | predcheckCell.ic.sample_inds[space].on_change('data', partial(predcheckCell.sample_inds_callback, space))
81 |
82 | @staticmethod
83 | def initialize_fig_static(predcheckCell, space):
84 | CellPredCheckHandler.initialize_fig(predcheckCell, space)
85 |
86 | @staticmethod
87 | def initialize_cds(predcheckCell, space):
88 | ## ColumnDataSource for full sample set
89 | data, samples = predcheckCell.get_samples_for_cur_idx_dims_values(space)
90 | predcheckCell.samples[space] = ColumnDataSource(data=dict(x=samples))
91 | #data func
92 | if ~np.isfinite(data).all():
93 | data = get_finite_samples(data)
94 | data_func = get_samples_for_pred_check(data, predcheckCell.func)
95 | #samples func
96 | if ~np.isfinite(samples).all():
97 | samples = get_finite_samples(samples)
98 | samples_func = get_samples_for_pred_check(samples, predcheckCell.func)
99 | if samples_func.size:
100 | #pvalue
101 | pv = np.count_nonzero(samples_func>=data_func) / len(samples_func)
102 | #histogram
103 | type = predcheckCell._data.get_var_dist_type(predcheckCell.name)
104 | if type == "Continuous":
105 | bins, range = get_hist_bins_range(samples_func, predcheckCell.func, type)
106 | else:
107 | bins, range = get_hist_bins_range(samples_func, predcheckCell.func, type, ref_length = None, ref_values=np.unique(samples.flatten()))
108 |
109 | his, edges = hist(samples_func, bins=bins, range=range, density = True)
110 | #cds
111 | predcheckCell.pvalue[space] = ColumnDataSource(data=dict(pv=[pv]))
112 | predcheckCell.source[space] = ColumnDataSource(data=dict(left=edges[:-1], top=his, right=edges[1:], bottom=np.zeros(len(his))))
113 | predcheckCell.seg[space] = ColumnDataSource(data=dict(x0=[data_func], x1=[data_func], y0=[0], y1=[his.max() + 0.1 * his.max()]))
114 | else:
115 | predcheckCell.pvalue[space] = ColumnDataSource(data=dict(pv=[]))
116 | predcheckCell.source[space] = ColumnDataSource(data=dict(left=[], top=[], right=[], bottom=[]))
117 | predcheckCell.seg[space] = ColumnDataSource(data=dict(x0=[], x1=[], y0=[], y1=[]))
118 | predcheckCell.ic.initialize_sample_inds(space, dict(inds = [False]*len(predcheckCell.samples[space].data['x'])), dict(non_inds = [True]*len(predcheckCell.samples[space].data['x'])))
119 |
120 | @staticmethod
121 | def initialize_cds_interactive(predcheckCell, space):
122 | CellPredCheckHandler.initialize_cds(predcheckCell, space)
123 | ## ColumnDataSource for restricted sample set
124 | predcheckCell.pvalue_rec[space] = ColumnDataSource(data=dict(pv=[]))
125 | predcheckCell.pvalue_rec[space].on_change('data', partial(predcheckCell.update_legends, space))
126 | predcheckCell.reconstructed[space] = ColumnDataSource(data=dict(left=[], top=[], right=[], bottom=[]))
127 |
128 | @staticmethod
129 | def initialize_cds_static(predcheckCell, space):
130 | CellPredCheckHandler.initialize_cds(predcheckCell, space)
131 |
132 | @staticmethod
133 | def update_cds_interactive(predcheckCell, space):
134 | """
135 | Updates interaction-related ColumnDataSources (cds).
136 | """
137 | predcheckCell.update_sel_samples_cds(space)
138 |
139 | @staticmethod
140 | def update_cds_static(predcheckCell, space):
141 | """
142 | Update samples cds in the static mode
143 | """
144 | ## ColumnDataSource for full sample set
145 | data, samples = predcheckCell.get_samples_for_cur_idx_dims_values(space)
146 | inds, _ = predcheckCell.ic.get_sample_inds(space)
147 | if True in inds:
148 | samples = samples[inds]
149 | predcheckCell.samples[space].data = dict(x=samples)
150 | #data func
151 | if ~np.isfinite(data).all():
152 | data = get_finite_samples(data)
153 | data_func = get_samples_for_pred_check(data, predcheckCell.func)
154 | #samples func
155 | if ~np.isfinite(samples).all():
156 | samples = get_finite_samples(samples)
157 | samples_func = get_samples_for_pred_check(samples, predcheckCell.func)
158 | if samples_func.size:
159 | #pvalue
160 | pv = np.count_nonzero(samples_func>=data_func) / len(samples_func)
161 | #histogram
162 | type = predcheckCell._data.get_var_dist_type(predcheckCell.name)
163 | if type == "Continuous":
164 | bins, range = get_hist_bins_range(samples_func, predcheckCell.func, type)
165 | else:
166 | bins, range = get_hist_bins_range(samples_func, predcheckCell.func, type, ref_length = None, ref_values=np.unique(samples.flatten()))
167 |
168 | his, edges = hist(samples_func, bins=bins, range=range, density = True)
169 | #cds
170 | predcheckCell.pvalue[space].data = dict(pv=[pv])
171 | predcheckCell.source[space].data = dict(left=edges[:-1], top=his, right=edges[1:], bottom=np.zeros(len(his)))
172 | predcheckCell.seg[space].data = dict(x0=[data_func], x1=[data_func], y0=[0], y1=[his.max() + 0.1 * his.max()])
173 | else:
174 | predcheckCell.pvalue[space].data = dict(pv=[])
175 | predcheckCell.source[space].data = dict(left=[], top=[], right=[], bottom=[])
176 | predcheckCell.seg[space].data = dict(x0=[], x1=[], y0=[], y1=[])
177 |
178 | ## ONLY FOR INTERACTIVE CASE
179 | @staticmethod
180 | def update_source_cds_interactive(predcheckCell, space):
181 | """
182 | Updates samples ColumnDataSource (cds).
183 | """
184 | ## ColumnDataSource for full sample set
185 | data, samples = predcheckCell.get_samples_for_cur_idx_dims_values(space)
186 | predcheckCell.samples[space].data = dict(x=samples)
187 | #data func
188 | if ~np.isfinite(data).all():
189 | data = get_finite_samples(data)
190 | data_func = get_samples_for_pred_check(data, predcheckCell.func)
191 | #samples func
192 | if ~np.isfinite(samples).all():
193 | samples = get_finite_samples(samples)
194 | samples_func = get_samples_for_pred_check(samples, predcheckCell.func)
195 | if samples_func.size:
196 | #pvalue
197 | pv = np.count_nonzero(samples_func>=data_func) / len(samples_func)
198 | #histogram
199 | type = predcheckCell._data.get_var_dist_type(predcheckCell.name)
200 | if type == "Continuous":
201 | bins, range = get_hist_bins_range(samples_func, predcheckCell.func, type)
202 | else:
203 | bins, range = get_hist_bins_range(samples_func, predcheckCell.func, type, ref_length = None, ref_values=np.unique(samples.flatten()))
204 |
205 | his, edges = hist(samples_func, bins=bins, range=range, density = True)
206 | #cds
207 | predcheckCell.pvalue[space].data = dict(pv=[pv])
208 | predcheckCell.source[space].data = dict(left=edges[:-1], top=his, right=edges[1:], bottom=np.zeros(len(his)))
209 | predcheckCell.seg[space].data = dict(x0=[data_func], x1=[data_func], y0=[0], y1=[his.max() + 0.1 * his.max()])
210 | else:
211 |             predcheckCell.pvalue[space].data = dict(pv=[])
212 | predcheckCell.source[space].data = dict(left=[], top=[], right=[], bottom=[])
213 | predcheckCell.seg[space].data = dict(x0=[], x1=[], y0=[], y1=[])
214 |
215 | @staticmethod
216 | def update_sel_samples_cds_interactive(predcheckCell, space):
217 | """
218 | Updates reconstructed ColumnDataSource (cds).
219 | """
220 | samples = predcheckCell.samples[space].data['x']
221 | max_full_hist = predcheckCell.source[space].data['top'].max()
222 | if samples.size:
223 | inds,_ = predcheckCell.ic.get_sample_inds(space)
224 | if True in inds:
225 | sel_sample = samples[inds]
226 | if ~np.isfinite(sel_sample).all():
227 | sel_sample = get_finite_samples(sel_sample)
228 | sel_sample_func = get_samples_for_pred_check(sel_sample, predcheckCell.func)
229 | #data func
230 | data_func = predcheckCell.seg[space].data['x0'][0]
231 | #pvalue in restricted space
232 | sel_pv = np.count_nonzero(sel_sample_func >= data_func) / len(sel_sample_func)
233 | #compute updated histogram
234 | min_p = predcheckCell.source[space].data['left'][0]
235 | max_p = predcheckCell.source[space].data['right'][-1]
236 | min_c = sel_sample_func.min()
237 | max_c = sel_sample_func.max()
238 | if min_c < min_p or max_c > max_p:
239 | ref_len = predcheckCell.source[space].data['right'][0] - min_p
240 | bins, range = get_hist_bins_range(sel_sample_func, predcheckCell.func, predcheckCell._type, ref_length=ref_len)
241 | else:
242 | range = (min_p,max_p)
243 | bins = len(predcheckCell.source[space].data['right'])
244 | his, edges = hist(sel_sample_func, bins=bins, range=range)
245 | ##max selected hist
246 | max_sel_hist = his.max()
247 | #update reconstructed cds
248 | predcheckCell.pvalue_rec[space].data = dict(pv=[sel_pv])
249 | predcheckCell.reconstructed[space].data = dict(left=edges[:-1], top=his, right =edges[1:], bottom=np.zeros(len(his)))
250 | predcheckCell.seg[space].data['y1'] = [max_sel_hist + 0.1 * max_sel_hist]
251 | else:
252 | predcheckCell.pvalue_rec[space].data = dict(pv=[])
253 | predcheckCell.reconstructed[space].data = dict(left=[], top=[], right=[], bottom=[])
254 | predcheckCell.seg[space].data['y1'] = [max_full_hist + 0.1 * max_full_hist]
255 | else:
256 | predcheckCell.pvalue_rec[space].data = dict(pv=[])
257 | predcheckCell.reconstructed[space].data = dict(left=[], top=[], right=[], bottom=[])
258 | predcheckCell.seg[space].data['y1'] = [max_full_hist + 0.1 * max_full_hist]
259 |
260 |
261 |
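
The p-value computed throughout this handler is a posterior predictive p-value: the share of predictive draws whose summary statistic (min, max, mean or std) is at least as large as the same statistic of the observed data. The core calculation, isolated as a small sketch with synthetic stand-in data:

    import numpy as np

    rng = np.random.default_rng(0)
    observed = rng.normal(size=50)              # stand-in for the observed data
    predictive = rng.normal(size=(4000, 50))    # stand-in for predictive draws

    func = np.min                               # one of min / max / mean / std
    data_func = func(observed)
    samples_func = func(predictive, axis=1)     # one statistic per predictive draw
    p_value = np.count_nonzero(samples_func >= data_func) / len(samples_func)
    print("posterior predictive p-value for min:", round(p_value, 4))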
--------------------------------------------------------------------------------
/ipme/classes/cell/utils/cell_scatter_handler.py:
--------------------------------------------------------------------------------
1 | from logging import error
2 | from ipme.utils.constants import COLORS, BORDER_COLORS, PLOT_HEIGHT, PLOT_WIDTH, SIZING_MODE
3 |
4 | from bokeh.models import ColumnDataSource, HoverTool
5 | from bokeh.plotting import figure
6 |
7 | import arviz as az
8 | from functools import partial
9 |
10 | class CellScatterHandler:
11 |
12 | def __init__(self):
13 | pass
14 |
15 | @staticmethod
16 | def initialize_glyphs_interactive(scatterCell, space):
17 | so = scatterCell.plot[space].circle(x="x", y="y", source = scatterCell.non_sel_samples[space], size=4, color=COLORS[0], line_color=None, fill_alpha = 0.1)
18 | scatterCell.plot[space].patches(xs="x", ys="y", source = scatterCell.contours[space], line_color="line_color", fill_alpha="fill_alpha")
19 | re = scatterCell.plot[space].circle(x="x", y="y", source = scatterCell.sel_samples[space], size=4, color=COLORS[1], line_color=None, fill_alpha = 0.4, name="re")
20 | ##Tooltips
21 | TOOLTIPS = [("x", "@x"), ("y","@y"),]
22 | hover = HoverTool( tooltips = TOOLTIPS, renderers = [so,re], mode = 'mouse')
23 | scatterCell.plot[space].tools.append(hover)
24 |
25 | @staticmethod
26 | def initialize_glyphs_static(scatterCell, space):
27 | so = scatterCell.plot[space].circle(x="x", y="y", source = scatterCell.samples[space], size=7, color=COLORS[0], line_color=None, fill_alpha = 0.1)
28 | scatterCell.plot[space].patches(xs="x", ys="y", source = scatterCell.contours[space], line_color="line_color", fill_alpha="fill_alpha")
29 | ##Tooltips
30 | TOOLTIPS = [("x", "@x"), ("y","@y"),]
31 | hover = HoverTool( tooltips = TOOLTIPS, renderers = [so], mode = 'mouse')
32 | scatterCell.plot[space].tools.append(hover)
33 |
34 | @staticmethod
35 | def initialize_fig(scatterCell, space):
36 | var1 = scatterCell.vars[0]
37 | var2 = scatterCell.vars[1]
38 | scatterCell.plot[space] = figure(x_range = scatterCell.x_range[var2][space], y_range = scatterCell.x_range[var1][space], tools = "wheel_zoom,reset,box_zoom", toolbar_location = 'right',
39 | plot_width = PLOT_WIDTH, plot_height = PLOT_HEIGHT, sizing_mode = SIZING_MODE)#tools = [], toolbar_location = None
40 | scatterCell.plot[space].border_fill_color = BORDER_COLORS[0]
41 | scatterCell.plot[space].min_border = 15
42 | scatterCell.plot[space].xaxis.axis_label = var2
43 | scatterCell.plot[space].yaxis.axis_label = var1
44 | # scatterCell.plot[space].yaxis.visible = False
45 | scatterCell.plot[space].toolbar.logo = None
46 | scatterCell.plot[space].xaxis[0].ticker.desired_num_ticks = 3
47 |
48 | @staticmethod
49 | def initialize_fig_interactive(scatterCell, space):
50 | CellScatterHandler.initialize_fig(scatterCell, space)
51 | ##on_change
52 | scatterCell.ic.sample_inds_update[space].on_change('data', partial(scatterCell.sample_inds_callback, space))
53 |
54 | @staticmethod
55 | def initialize_fig_static(scatterCell, space):
56 | CellScatterHandler.initialize_fig(scatterCell, space)
57 | ##on_change
58 | scatterCell.ic.sample_inds_update[space].on_change('data', partial(scatterCell.sample_inds_callback, space))
59 |
60 | @staticmethod
61 | def initialize_cds(scatterCell, space):
62 | var1 = scatterCell.vars[0]
63 | var2 = scatterCell.vars[1]
64 | samples1 = scatterCell.get_samples_for_cur_idx_dims_values(var1, space)
65 | samples2 = scatterCell.get_samples_for_cur_idx_dims_values(var2, space)
66 | scatterCell.samples[space] = ColumnDataSource(data = dict(x = samples2, y = samples1))
67 | patch_x, patch_y = CellScatterHandler.get_contours(samples2, samples1)
68 | scatterCell.contours[space] = ColumnDataSource(data = dict(x = patch_x, y = patch_y, line_color=["black"]*len(patch_x), fill_alpha=[0]*len(patch_x)))
69 | scatterCell.ic.initialize_sample_inds(space, dict(inds = [False]*len(scatterCell.samples[space].data['x'])), dict(non_inds = [True]*len(scatterCell.samples[space].data['x'])))
70 |
71 | @staticmethod
72 | def initialize_cds_interactive(scatterCell, space):
73 | CellScatterHandler.initialize_cds(scatterCell, space)
74 | inds, non_inds = scatterCell.ic.get_sample_inds(space)
75 | sel_y = scatterCell.samples[space].data['y'][inds]
76 | non_sel_y = scatterCell.samples[space].data['y'][non_inds]
77 | sel_samples = scatterCell.samples[space].data['x'][inds]
78 | non_sel_samples = scatterCell.samples[space].data['x'][non_inds]
79 | scatterCell.sel_samples[space] = ColumnDataSource(data = dict( x = sel_samples, y = sel_y))
80 | scatterCell.non_sel_samples[space] = ColumnDataSource(data = dict( x = non_sel_samples, y = non_sel_y))
81 |
82 | @staticmethod
83 | def initialize_cds_static(scatterCell, space):
84 | CellScatterHandler.initialize_cds(scatterCell, space)
85 | #########TEST###########
86 |
87 | @staticmethod
88 | def update_cds_interactive(scatterCell, space):
89 | """
90 | Updates interaction-related ColumnDataSources (cds).
91 | """
92 | scatterCell.update_sel_samples_cds(space)
93 |
94 | @staticmethod
95 | def update_cds_static(scatterCell, space):
96 | """
97 | Update samples cds in the static mode
98 | """
99 | var1 = scatterCell.vars[0]
100 | var2 = scatterCell.vars[1]
101 | samples1 = scatterCell.get_samples_for_cur_idx_dims_values(var1, space)
102 | samples2 = scatterCell.get_samples_for_cur_idx_dims_values(var2, space)
103 | inds, _ = scatterCell.ic.get_sample_inds(space)
104 | if True in inds:
105 | sel_sample1 = samples1[inds]
106 | sel_sample2 = samples2[inds]
107 | scatterCell.samples[space].data = dict(x=sel_sample2, y=sel_sample1)
108 | patch_x, patch_y = CellScatterHandler.get_contours(sel_sample2, sel_sample1)
109 | scatterCell.contours[space].data = dict(x = patch_x, y = patch_y, line_color=["black"]*len(patch_x), fill_alpha=[0]*len(patch_x))
110 | else:
111 | scatterCell.samples[space].data = dict(x=samples2, y=samples1)
112 | patch_x, patch_y = CellScatterHandler.get_contours(samples2, samples1)
113 | scatterCell.contours[space].data = dict(x = patch_x, y = patch_y, line_color=["black"]*len(patch_x), fill_alpha=[0]*len(patch_x))
114 |
115 | ## ONLY FOR INTERACTIVE CASE
116 | @staticmethod
117 | def update_source_cds_interactive(scatterCell, space):
118 | """
119 | Updates samples ColumnDataSource (cds).
120 | """
121 | var1 = scatterCell.vars[0]
122 | var2 = scatterCell.vars[1]
123 | samples1 = scatterCell.get_samples_for_cur_idx_dims_values(var1, space)
124 | samples2 = scatterCell.get_samples_for_cur_idx_dims_values(var2, space)
125 | scatterCell.samples[space].data = dict(x=samples2, y=samples1)
126 | patch_x, patch_y = CellScatterHandler.get_contours(samples2, samples1)
127 | scatterCell.contours[space].data = dict(x = patch_x, y = patch_y, line_color=["black"]*len(patch_x), fill_alpha=[0]*len(patch_x))
128 |
129 | @staticmethod
130 | def update_sel_samples_cds_interactive(scatterCell, space):
131 | """
132 | Updates reconstructed ColumnDataSource (cds).
133 | """
134 | samples1 = scatterCell.samples[space].data['x']
135 | samples2 = scatterCell.samples[space].data['y']
136 | inds, non_inds = scatterCell.ic.get_sample_inds(space)
137 | # update sel samples
138 | sel_sample1 = samples1[inds]
139 | sel_sample2 = samples2[inds]
140 | scatterCell.sel_samples[space].data = dict(x=sel_sample1, y=sel_sample2)
141 | # update transparency
142 | for re in scatterCell.plot[space].renderers:
143 | if re.name == "re" and len(sel_sample1) < 50:
144 | re.glyph.fill_alpha = 0.8
145 | elif re.name == "re" and len(sel_sample1) < 200:
146 | re.glyph.fill_alpha = 0.5
147 | elif re.name == "re":
148 | re.glyph.fill_alpha = 0.2
149 | # update non_sel samples
150 | non_sel_sample1 = samples1[non_inds]
151 | non_sel_sample2 = samples2[non_inds]
152 | scatterCell.non_sel_samples[space].data = dict(x=non_sel_sample1, y=non_sel_sample2)
153 |
154 | # @staticmethod
155 | # def set_transparency(samples1, samples2, sel_samples1, sel_samples2):
156 | # s_x_min = samples1.min()
157 | # s_x_max = samples1.max()
158 | # s_y_min = samples2.min()
159 | # s_y_max = samples2.max()
160 |
161 | # sels_x_min = sel_samples1.min()
162 | # sels_x_max = sel_samples1.max()
163 | # sels_y_min = sel_samples2.min()
164 | # sels_y_max = sel_samples2.max()
165 |
166 | @staticmethod
167 | def get_contours(x, y):
168 | try:
169 | _, contour_glyphs = az.plot_kde(x, y,
170 | # hdi_probs=[0.393, 0.865, 0.989], # 1, 2 and 3 sigma contours
171 | contour_kwargs={"line_color":"black", "line_alpha":1},
172 | contourf_kwargs={"fill_alpha": 0, "cmap": "viridis"},
173 | backend="bokeh", return_glyph = True, show = False )
174 | patch_x = []
175 | patch_y = []
176 | for renderer in contour_glyphs:
177 | patch_x.append(renderer.data_source.data['x'].tolist())
178 | patch_y.append(renderer.data_source.data['y'].tolist())
179 | return patch_x, patch_y
180 | except ValueError:
181 | return [],[]
--------------------------------------------------------------------------------
/ipme/classes/cell/utils/cell_widgets.py:
--------------------------------------------------------------------------------
1 | from ipme.utils.functions import get_dim_names_options
2 | from .global_reset import GlobalReset
3 |
4 | from bokeh.models.widgets import Select, Button
5 |
6 | from functools import partial
7 | import threading
8 |
9 | class CellWidgets:
10 |
11 | def __init__(self):
12 | pass
13 |
14 | @staticmethod
15 | def initialize_widgets(cell):
16 | for space in cell.spaces:
17 | cell.widgets[space] = {}
18 | for var in cell.idx_dims:
19 | for _, d_dim in cell.idx_dims[var].items():
20 | n1, n2, opt1, opt2 = get_dim_names_options(d_dim)
21 | if n1 not in cell.widgets[space]:
22 | cell.widgets[space][n1] = Select(title = n1, value = opt1[0], options = opt1)
23 | cell.widgets[space][n1].on_change("value", partial(cell.widget_callback, w_title = n1, space = space))
24 | if var not in cell.cur_idx_dims_values:
25 | cell.cur_idx_dims_values[var] = {}
26 | if n1 not in cell.cur_idx_dims_values[var]:
27 | inds = [i for i,v in enumerate(d_dim.values) if v == opt1[0]]
28 | cell.cur_idx_dims_values[var][n1] = inds
29 | if n2:
30 | if n2 not in cell.widgets[space]:
31 | cell.widgets[space][n2] = Select(title = n2, value = opt2[0], options=opt2)
32 | cell.widgets[space][n2].on_change("value", partial(cell.widget_callback, w_title = n2, space = space))
33 | cell.ic.idx_widgets_mapping(space, d_dim, n1, n2)
34 | if n2 not in cell.cur_idx_dims_values[var]:
35 | cell.cur_idx_dims_values[var][n2] = [0]
36 |
37 | # def _widget_callback(self, attr, old, new, w_title, space):
38 | # """
39 | # Callback called when an indexing dimension is set to
40 | # a new coordinate (e.g through indexing dimensions widgets).
41 | # """
42 | # if old == new:
43 | # return
44 | # self._ic.add_widget_threads(threading.Thread(target=partial(self._widget_callback_thread, new, w_title, space), daemon=True))
45 | # self._ic._widget_lock_event.set()
46 | #
47 | # def _widget_callback_thread(self, new, w_title, space):
48 | # inds = -1
49 | # w2_title = ""
50 | # values = []
51 | # w1_w2_idx_mapping = self._ic._get_w1_w2_idx_mapping()
52 | # w2_w1_val_mapping = self._ic._get_w2_w1_val_mapping()
53 | # w2_w1_idx_mapping = self._ic._get_w2_w1_idx_mapping()
54 | # widgets = self._widgets[space]
55 | # if space in w1_w2_idx_mapping and \
56 | # w_title in w1_w2_idx_mapping[space]:
57 | # for w2_title in w1_w2_idx_mapping[space][w_title]:
58 | # name = w_title+"_idx_"+w2_title
59 | # if name in self._idx_dims:
60 | # values = self._idx_dims[name].values
61 | # elif w_title in self._idx_dims:
62 | # values = self._idx_dims[w_title].values
63 | # elif space in w2_w1_idx_mapping and \
64 | # w_title in w2_w1_idx_mapping[space]:
65 | # for w1_idx in w2_w1_idx_mapping[space][w_title]:
66 | # w1_value = widgets[w1_idx].value
67 | # values = w2_w1_val_mapping[space][w_title][w1_value]
68 | # inds = [i for i,v in enumerate(values) if v == new]
69 | # if inds == -1 or len(inds) == 0:
70 | # return
71 | # self._cur_idx_dims_values[w_title] = inds
72 | # if w2_title and w2_title in self._cur_idx_dims_values:
73 | # self._cur_idx_dims_values[w2_title] = [0]
74 | # if self._mode == 'i':
75 | # self._update_source_cds(space)
76 | # self._ic._set_global_update(True)
77 | # self._update_cds_interactive(space)
78 | # elif self._mode == 's':
79 | # self._update_cds_static(space)
80 |
81 | @staticmethod
82 | def _widget_callback_int(variableCell, attr, old, new, w_title, space):
83 | """
84 | Callback called when an indexing dimension is set to
85 | a new coordinate (e.g. through indexing dimensions widgets).
86 | """
87 | if old == new:
88 | return
89 | variableCell.ic.add_widget_threads(threading.Thread(target = partial(CellWidgets._widget_callback_thread_inter, variableCell, new, w_title, space), daemon = True))
90 | variableCell.ic.widget_lock_event.set()
91 |
92 | @staticmethod
93 | def _widget_callback_static(variableCell, attr, old, new, w_title, space):
94 | """
95 | Callback called when an indexing dimension is set to
96 | a new coordinate (e.g. through indexing dimensions widgets).
97 | """
98 | if old == new:
99 | return
100 | variableCell.ic.add_widget_threads(threading.Thread(target = partial(CellWidgets._widget_callback_thread_static, variableCell, new, w_title, space), daemon = True))
101 | variableCell.ic.widget_lock_event.set()
102 |
103 | def _widget_callback_thread_inter(variableCell, new, w_title, space):
104 | w1_w2_idx_mapping = variableCell.ic.get_w1_w2_idx_mapping()
105 | w2_w1_val_mapping = variableCell.ic.get_w2_w1_val_mapping()
106 | w2_w1_idx_mapping = variableCell.ic.get_w2_w1_idx_mapping()
107 | widgets = variableCell.widgets[space]
108 | for var in variableCell.idx_dims:
109 | inds = -1
110 | w2_title = ""
111 | values = []
112 | if space in w1_w2_idx_mapping and w_title in w1_w2_idx_mapping[space]:
113 | for w2_title in w1_w2_idx_mapping[space][w_title]:##review this
114 | name = w_title+"_idx_"+w2_title
115 | if name in variableCell.idx_dims[var]:
116 | values = variableCell.idx_dims[var][name].values
117 | if len(values) == 0 and w_title in variableCell.idx_dims[var]:
118 | values = variableCell.idx_dims[var][w_title].values
119 | if len(values) == 0 and space in w2_w1_idx_mapping and w_title in w2_w1_idx_mapping[space]:
120 | for w1_idx in w2_w1_idx_mapping[space][w_title]:
121 | w1_value = widgets[w1_idx].value
122 | values = w2_w1_val_mapping[space][w_title][w1_value]
123 | inds = [i for i,v in enumerate(values) if v == new]
124 | if inds == -1 or len(inds) == 0:
125 | return
126 | if w_title in variableCell.cur_idx_dims_values[var]:
127 | variableCell.cur_idx_dims_values[var][w_title] = inds
128 | if w2_title and w2_title in variableCell.cur_idx_dims_values[var]:
129 | variableCell.cur_idx_dims_values[var][w2_title] = [0]
130 | variableCell.update_source_cds(space)
131 | variableCell.ic.set_global_update(True)
132 | variableCell.update_cds(space)
133 |
134 | def _widget_callback_thread_static(variableCell, new, w_title, space):
135 | w1_w2_idx_mapping = variableCell.ic.get_w1_w2_idx_mapping()
136 | w2_w1_val_mapping = variableCell.ic.get_w2_w1_val_mapping()
137 | w2_w1_idx_mapping = variableCell.ic.get_w2_w1_idx_mapping()
138 | widgets = variableCell.widgets[space]
139 | for var in variableCell.idx_dims:
140 | inds = -1
141 | w2_title = ""
142 | values = []
143 | if space in w1_w2_idx_mapping and w_title in w1_w2_idx_mapping[space]:
144 | for w2_title in w1_w2_idx_mapping[space][w_title]:
145 | name = w_title+"_idx_"+w2_title
146 | if name in variableCell.idx_dims[var]:
147 | values = variableCell.idx_dims[var][name].values
148 | if len(values) == 0 and w_title in variableCell.idx_dims[var]:
149 | values = variableCell.idx_dims[var][w_title].values
150 | # elif w_title in variableCell.idx_dims:
151 | # values = variableCell.idx_dims[w_title].values
152 | if len(values) == 0 and space in w2_w1_idx_mapping and w_title in w2_w1_idx_mapping[space]:
153 | # elif space in w2_w1_idx_mapping and w_title in w2_w1_idx_mapping[space]:
154 | for w1_idx in w2_w1_idx_mapping[space][w_title]:
155 | w1_value = widgets[w1_idx].value
156 | values = w2_w1_val_mapping[space][w_title][w1_value]
157 | inds = [i for i,v in enumerate(values) if v == new]
158 | if inds == -1 or len(inds) == 0:
159 | return
160 | if w_title in variableCell.cur_idx_dims_values[var]:
161 | variableCell.cur_idx_dims_values[var][w_title] = inds
162 | if w2_title and w2_title in variableCell.cur_idx_dims_values[var]:
163 | variableCell.cur_idx_dims_values[var][w2_title] = [0]
164 | variableCell.update_cds(space)
165 |
166 | @staticmethod
167 | def widget_callback_interactive(variableCell, attr, old, new, w_title, space):
168 | CellWidgets._widget_callback_int(variableCell, attr, old, new, w_title, space)
169 |
170 | @staticmethod
171 | def widget_callback_static(variableCell, attr, old, new, w_title, space):
172 | CellWidgets._widget_callback_static(variableCell, attr, old, new, w_title, space)
173 |
174 | @staticmethod
175 | def link_cells_widgets(grid):
176 | for c_id, cell in grid.cells.items():
177 | cell_spaces = cell.get_spaces()
178 | for space in cell_spaces:
179 | for w_id, w in cell.get_widgets_in_space(space).items():
180 | if w_id in grid.cells_widgets:
181 | if space in grid.cells_widgets[w_id]:
182 | grid.cells_widgets[w_id][space].append(c_id)
183 | else:
184 | grid.cells_widgets[w_id][space] = [c_id]
185 | ## Every new widget is linked to the corresponding widget (of same name)
186 | ## of the 1st space in grid.cells_widgets[w_id]
187 | ## Find target cell to link with current cell
188 | f_space = list(grid.cells_widgets[w_id].keys())[0]
189 | CellWidgets._link_widget_to_target(grid, w, w_id, f_space)
190 | else:
191 | grid.cells_widgets[w_id] = {}
192 | grid.cells_widgets[w_id][space] = [c_id]
193 | f_space = list(grid.cells_widgets[w_id].keys())[0]
194 | if f_space != space:
195 | CellWidgets._link_widget_to_target(grid, w, w_id, f_space)
196 | else:
197 | w = grid.cells[c_id].get_widget(space, w_id)
198 | w.on_change('value', partial(grid.ic.menu_item_click_callback, grid, space, w_id))
199 |
200 | @staticmethod
201 | def _link_widget_to_target(grid, w, w_id, f_space):
202 | if len(grid.cells_widgets[w_id][f_space]):
203 | t_c_id = grid.cells_widgets[w_id][f_space][0]
204 | t_w = grid.cells[t_c_id].get_widget(f_space, w_id)
205 | if t_w is not None and hasattr(t_w,'js_link'):
206 | t_w.js_link('value', w, 'value')
207 |
208 | @staticmethod
209 | def set_plotted_widgets_interactive(grid):
210 | grid.plotted_widgets = {}
211 | for w_id, space_widgets_dict in grid.cells_widgets.items():
212 | w_spaces = list(space_widgets_dict.keys())
213 | if len(w_spaces):
214 | f_space = w_spaces[0]
215 | f_w_list = space_widgets_dict[f_space]
216 | if len(f_w_list):
217 | c_id = f_w_list[0]
218 | for space in w_spaces:
219 | if space not in grid.plotted_widgets:
220 | grid.plotted_widgets[space] = {}
221 | grid.plotted_widgets[space][w_id] = grid.cells[c_id].get_widget(f_space, w_id)
222 | b = Button(label='Reset Diagram', button_type="primary")
223 | b.on_click(partial(GlobalReset.global_reset_callback, grid))
224 | for space in grid.get_grids():
225 | if space not in grid.plotted_widgets:
226 | grid.plotted_widgets[space] = {}
227 | grid.plotted_widgets[space]["resetButton"] = b
228 |
229 | @staticmethod
230 | def set_plotted_widgets_static(grid):
231 | grid.plotted_widgets = {}
232 | for w_id, space_widgets_dict in grid.cells_widgets.items():
233 | w_spaces = list(space_widgets_dict.keys())
234 | if len(w_spaces):
235 | f_space = w_spaces[0]
236 | f_w_list = space_widgets_dict[f_space]
237 | if len(f_w_list):
238 | c_id = f_w_list[0]
239 | for space in w_spaces:
240 | if space not in grid.plotted_widgets:
241 | grid.plotted_widgets[space] = {}
242 | grid.plotted_widgets[space][w_id] = grid.cells[c_id].get_widget(f_space, w_id)
243 | for space in grid.get_grids():
244 | if space not in grid.plotted_widgets:
245 | grid.plotted_widgets[space] = {}
246 |
--------------------------------------------------------------------------------
/ipme/classes/cell/utils/global_reset.py:
--------------------------------------------------------------------------------
1 | import threading
2 | from functools import partial
3 |
4 | class GlobalReset:
5 |
6 | def __init__(self):
7 | pass
8 |
9 | @staticmethod
10 | def global_reset_callback(grid, event):
11 | grid.ic.reset_sel_var_inds()
12 | grid.ic.reset_sel_space()
13 | grid.ic.reset_sel_var_idx_dims_values()
14 | grid.ic.reset_var_x_range()
15 | grid.ic.set_global_update(True)
16 | for sp in grid.get_grids():
17 | grid.ic.add_space_threads(threading.Thread(target = partial(GlobalReset._global_reset_thread, grid.ic, sp), daemon = True))
18 | grid.ic.space_threads_join()
19 | grid.ic.set_global_update(False)
20 |
21 | @staticmethod
22 | def _global_reset_thread(ic, space):
23 | ic.reset_sample_inds(space)
24 | ic.selection_threads_join(space)
25 |
--------------------------------------------------------------------------------
/ipme/classes/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/evdoxiataka/ipme/f3398596c6af547908f39683eb1830a6bc081482/ipme/classes/data/__init__.py
--------------------------------------------------------------------------------
/ipme/classes/data/data.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import json
3 |
4 | from ...interfaces.data_interface import Data_Interface
5 | from .dimension import Dimension
6 |
7 | class Data(Data_Interface):
8 |
9 | def __init__(self, inference_path):
10 | """
11 | Parameters:
12 | --------
13 | inference_path A String of the inference data file path.
14 | Sets:
15 | --------
16 | _inferencedata A structure of the inference data.
17 | _graph A structure of the model's variables graph data.
18 | _observed_variables A List of Strings of the observed variables.
19 | _idx_dimensions A Dict {<var_name>:{<dim_name>:Dimension obj}} defining
20 | the indexing dimensions of the data.
21 | _spaces A List of Strings in {'prior', 'posterior'} of all
22 | the available MCMC sample spaces in the inference data
23 | """
24 | Data_Interface.__init__(self, inference_path)
25 |
26 | def _load_inference_data(self, datapath):
27 | """
28 | Inference data is retrieved from memory in .npz format
29 |
30 | Parameters:
31 | --------
32 | datapath A String of the inference data path.
33 | Returns:
34 | --------
35 | A Dictionary of inference data
36 | """
37 | try:
38 | return np.load(datapath)
39 | except IOError:
40 | print("File %s cannot be loaded" % datapath)
41 | return None
42 |
43 | def _get_graph(self):
44 | """
45 | A graph of the probabilistic model in json format is
46 | retrieved from inference data.
47 |
48 | Returns:
49 | --------
50 | A Dictionary of graph data
51 | """
52 | graph=""
53 | try:
54 | header = self._inferencedata['header.json']
55 | header_js = json.loads(header)
56 | graph = header_js["inference_data"]["sample_stats"]["attrs"]["graph"]
57 | except KeyError:
58 | print("Inference_data has no key 'header.json'")
59 | return None
60 | return json.loads(graph.replace("'", "\""))
61 | #return json.loads(graph)
62 |
63 | def _get_obeserved_variables(self):
64 | return [k for k,v in self._graph.items() if v['type'] == 'observed']
65 |
66 | def _get_idx_dimensions(self):
67 | """
68 | Returns:
69 | --------
70 | idx_dimensions A Dict {<var_name>: {<dim_name>: Dimension object}}
71 | """
72 | idx_dimensions={}
73 | for k,v in self._graph.items():
74 | if not k.endswith("__"):
75 | vdims = v["dims"]
76 | vcoords = v["coords"]
77 | if len(vdims):
78 | idx_dimensions[v["name"]] = {}
79 | for d in vdims:
80 | if d in vcoords:
81 | idx_dimensions[v["name"]][d] = Dimension(d, values = vcoords[d])
82 | return idx_dimensions
83 |
84 | def _get_spaces_from_data(self):
85 | """
86 | Returns:
87 | --------
88 | spaces A List of Strings in {'prior','posterior','prior_predictive','posterior_predictive','observed_data',
89 | 'constant_data', 'sample_stats', 'log_likelihood', 'predictions', 'predictions_constant_data'}
90 | """
91 | spaces=[]
92 | if 'header.json' in self._inferencedata:
93 | header = self._inferencedata['header.json']
94 | header_js = json.loads(header)
95 | if 'inference_data' in header_js:
96 | for space in ['prior','posterior','prior_predictive','posterior_predictive','observed_data','constant_data', 'sample_stats', 'log_likelihood', 'predictions', 'predictions_constant_data']:
97 | if space in header_js['inference_data']:
98 | spaces.append(space)
99 | return spaces
100 |
101 | def is_observed_variable(self,var_name):
102 | if var_name in self._observed_variables:
103 | return True
104 | else:
105 | return False
106 |
107 | def get_var_type(self, var_name):
108 | """"
109 | Return any in {"observed","free","deterministic"}
110 | """
111 | return self._graph[var_name]["type"]
112 |
113 | def get_var_dist(self, var_name):
114 | if "dist" in self._graph[var_name]["distribution"]:
115 | return self._graph[var_name]["distribution"]["dist"]
116 | else:
117 | return self._graph[var_name]["type"]
118 |
119 | def get_var_dist_type(self, var_name):
120 | """"
121 | Return any in {"Continuous","Discrete"}
122 | """
123 | if "type" in self._graph[var_name]["distribution"]:
124 | return self._graph[var_name]["distribution"]["type"]
125 | else:
126 | return ""
127 |
128 | def get_var_parents(self, var_name):
129 | if "parents" in self._graph[var_name]:
130 | return self._graph[var_name]["parents"]
131 | else:
132 | return []
133 |
134 | def get_samples(self, var_name, space=['prior','posterior']):
135 | """
136 | Returns the samples of variable <var_name> of the given space(s).
137 |
138 | Parameters:
139 | --------
140 | var_name A String of the model's variables name
141 | space Either a List of Strings or a String with String in {'prior','posterior'}
142 | Returns:
143 | --------
144 | A Dictionary of the form {<space>: <samples>}
145 | <space> A String from {"prior", "posterior"}
146 | <samples> A numpy.ndarray of samples of the parameter.
147 | e.g. PyMC3 shape=N, sample=M
148 | for i in M
149 | for j in N:
150 | Element (i,j) = (ith sample, jth prior/posterior distribution draw).
151 | If a String of space is given, it returns the numpy.ndarray.
152 |
153 | """
154 | array_name=""
155 | header = self._inferencedata['header.json']
156 | header_js = json.loads(header)
157 | if isinstance(space, list):
158 | data = {}
159 | for sp in space:
160 | if sp in self._spaces and self._is_var_in_space(var_name,sp):
161 | array_name = header_js['inference_data'][sp]['array_names'][var_name]
162 | if 'chain' in header_js['inference_data'][sp]['vars'][var_name]['dims']:
163 | data[sp] = np.mean(self._inferencedata[array_name],axis=0)
164 | else:
165 | data[sp] = self._inferencedata[array_name]
166 | return data
167 | elif isinstance(space, str):
168 | data = np.asarray([])
169 | if space in self._spaces and self._is_var_in_space(var_name,space):
170 | array_name = header_js['inference_data'][space]['array_names'][var_name]
171 | if 'chain' in header_js['inference_data'][space]['vars'][var_name]['dims']:
172 | data = np.mean(self._inferencedata[array_name],axis=0)
173 | else:
174 | data = self._inferencedata[array_name]
175 | return data
176 | else:
177 | raise ValueError("space argument of get_samples should be either a List of Strings or a String")
178 |
179 | def get_observations(self, var_name):
180 | """
181 | Returns the observations of variable <var_name>.
182 |
183 | Parameters:
184 | --------
185 | var_name A String of the model's variables name
186 | Returns:
187 | --------
188 | A numpy.ndarray of observations of the parameter.
189 | """
190 | array_name=""
191 | header = self._inferencedata['header.json']
192 | header_js = json.loads(header)
193 | data = None
194 | if var_name in header_js['inference_data']['observed_data']['array_names']:
195 | array_name = header_js['inference_data']['observed_data']['array_names'][var_name]
196 | dims = header_js['inference_data']['observed_data']['vars'][var_name]['dims']
197 | if 'chain' in dims:
198 | data = np.mean(self._inferencedata[array_name], axis=0)
199 | else:
200 | data = self._inferencedata[array_name]
201 | return data
202 | else:
203 | return data
204 |
205 | def get_range(self, var_name, space=['prior','posterior']):
206 | """
207 | Returns the range of samples of variable <var_name> of the given space(s).
208 |
209 | Parameters:
210 | --------
211 | var_name A String of the model's variables name
212 | space Either a List of Strings or a String with String in {'prior','posterior'}
213 | Returns:
214 | --------
215 | Tuple (min,max)
216 |
217 | """
218 | if self.get_var_type(var_name) == "observed":
219 | if space == "posterior" and "posterior_predictive" in self.get_spaces():
220 | space="posterior_predictive"
221 | elif space == "prior" and "prior_predictive" in self.get_spaces():
222 | space="prior_predictive"
223 | data = self.get_samples(var_name, space)
224 | min=0
225 | max=0
226 | if data.size:
227 | min = np.amin(data)
228 | max = np.amax(data)
229 | return (min - 0.1*(max-min),max + 0.1*(max-min))
230 |
231 | def get_varnames_per_graph_level(self, vars):
232 | """
233 | Matches variable names to graph levels.
234 |
235 | Returns:
236 | --------
237 | A Dict of the model's parameters names per graph level.
238 | Level 0 corresponds to the higher level.
239 | {<level>: List of variable names}, level=0,1,2...
240 | """
241 | nodes = self._add_nodes_to_graph(self._get_observed_nodes(), 0)
242 | varnames_per_graph_level = {}
243 | if vars == 'all':
244 | varnames_per_graph_level = self._reverse_nodes_levels(nodes)
245 | else:
246 | vars_per_level = self._reverse_nodes_levels(nodes)
247 | level = 0
248 | for l in sorted(vars_per_level):
249 | for var in vars:
250 | if var in vars_per_level[l] :
251 | if level not in varnames_per_graph_level:
252 | varnames_per_graph_level[level] = []
253 | varnames_per_graph_level[level].append(var)
254 | level+=1
255 | return varnames_per_graph_level
256 |
257 | def _is_var_in_space(self, var_name, space):
258 | """
259 | Returns True if variable is in space and False if not.
260 |
261 | Parameters:
262 | --------
263 | var_name A String of the model's variables name.
264 | space A String in {'prior','posterior','posterior_predictive','prior_predictive'}.
265 | Returns:
266 | --------
267 | A Boolean
268 | """
269 | header = self._inferencedata['header.json']
270 | header_js = json.loads(header)
271 | if var_name in header_js['inference_data'][space]['array_names']:
272 | return True
273 | else:
274 | return False
275 |
276 | def _get_observed_nodes(self):
277 | """
278 | Get the observed nodes of the graph.
279 |
280 | Parameters:
281 | --------
282 | graph A Dictionary
283 | Returns:
284 | --------
285 | A List of the model's observed nodes of the graph (dict objects)
286 | """
287 | try:
288 | return [v for _,v in self._graph.items() if v['type'] == 'observed']
289 | except KeyError:
290 | print("Graph has no key 'type'")
291 | return None
292 |
293 | def _get_graph_nodes(self, varnames):
294 | """
295 | Get the nodes of the graph indicated by a list of varnames.
296 |
297 | Parameters:
298 | --------
299 | graph A Dictionary
300 | varnames A List of Strings
301 | Returns:
302 | --------
303 | A List of the model's nodes of the graph (Dictionary objects)
304 | """
305 | nodes = []
306 | for vn in varnames:
307 | if vn in self._graph:
308 | nodes.append(self._graph[vn])
309 | return nodes
310 |
311 | @staticmethod
312 | def _reverse_nodes_levels(nodes):
313 | """
314 | Reverses the nodes' levels so that level 0 corresponds
315 | to the highest grid row.
316 |
317 | Parameters:
318 | --------
319 | nodes A Dictionary
320 | Returns:
321 | --------
322 | A Dictionary of the nodes with reversed keys
323 | """
324 | max_level = max(nodes.keys())
325 | return dict((max_level-k, v) for k, v in nodes.items())
326 |
327 | def _add_nodes_to_graph(self, level_nodes, level):
328 | """
329 | Adds nodes to graph levels recursively.
330 |
331 | Parameters:
332 | --------
333 | level_nodes A List of nodes (dict) of the same graph <level>.
334 | level An Int denoting the level of the graph
335 | where <level_nodes> belong to.
336 | Returns:
337 | --------
338 | A Dict of the model's parameters names per graph level.
339 | Level 0 corresponds to the lowest level.
340 | {<level>: List of variable names}, level=0,1,2...
341 | """
342 | nodes = {}
343 | try:
344 | for v in level_nodes:
345 | if level not in nodes:
346 | nodes[level] = [v['name']]
347 | elif v['name'] not in nodes[level]:
348 | nodes[level].append(v['name'])
349 |
350 | parents_nodes = self._get_graph_nodes(v['parents'])
351 | if(len(parents_nodes)):
352 | # if incl_nodes == 'all':
353 | # par_nodes = self._add_nodes_to_graph(parents_nodes, level+1, 'all')
354 | # else:
355 | # parents_nodes_to_go_ahead = []
356 | # rest_nodes = []
357 | # for node in incl_nodes:
358 | # if node in parents_nodes:
359 | # parents_nodes_to_go_ahead.append(node)
360 | # else:
361 | # rest_nodes.append(node)
362 | par_nodes = self._add_nodes_to_graph(parents_nodes, level+1)
363 | for k in par_nodes:
364 | if k in nodes:
365 | for n in par_nodes[k]:
366 | if n not in nodes[k]:
367 | nodes[k].append(n)
368 | else:
369 | nodes[k] = par_nodes[k]
370 | return nodes
371 | except KeyError:
372 | print("Graph node has no key 'name' or 'parents' ")
373 | return None
374 |
--------------------------------------------------------------------------------
/ipme/classes/data/dimension.py:
--------------------------------------------------------------------------------
1 | class Dimension:
2 | def __init__(self, name, label='', range=(None,None), step=0, unit='', values=[]):
3 | self.name=name
4 | if label!='' and label is not None:
5 | self.label=label
6 | else:
7 | self.label=name
8 | self.range=range
9 | self.step=step
10 | self.unit=unit
11 | self.values=values
--------------------------------------------------------------------------------
/ipme/classes/graph.py:
--------------------------------------------------------------------------------
1 | from .data.data import Data
2 | from .grid.graph_grid import GraphGrid
3 | from .grid.predictive_ckecks_grid import PredictiveChecksGrid
4 | from .interaction_control.interaction_control import IC
5 |
6 | import panel as pn
7 |
8 | class Graph():
9 | def __init__(self, data_path, mode = "i", vars = 'all', spaces = 'all', predictive_checks = []):
10 | """
11 | Parameters:
12 | --------
13 | data_path A String of the zip file with the inference data.
14 | mode A String in {'i','s'}: defines the type of diagram
15 | (interactive or static).
16 | vars A List of variables to be presented in the graph
17 | spaces A List of spaces to be included in graph
18 | predictive_checks A List of observed variables to plot predictive checks.
19 | Sets:
20 | --------
21 | _mode A String in {"i","s"}, "i":interactive, "s":static.
22 | _graph A Panel component object to visualize model's graph.
23 | """
24 | self.ic = IC(Data(data_path))
25 | if mode not in ["s","i"]:
26 | raise ValueError("ValueError: mode should take a value in {'i','s'}")
27 | self._mode = mode
28 | self._vars = vars
29 | self._spaces = spaces
30 | self._pred_checks = predictive_checks
31 | self._graph_grid = self._create_graph_grid()
32 | self._predictive_checks_grid = self._create_pred_checks_grid()
33 | self._graph = self._create_graph()
34 |
35 | def _create_graph_grid(self):
36 | """
37 | Creates a GraphGrid object representing the model as a
38 | collection of Panel grids (one per space) and a
39 | collection of plotted widgets.
40 | """
41 | return GraphGrid(self.ic, self._mode, self._vars, self._spaces)
42 |
43 | def _create_pred_checks_grid(self):
44 | """
45 | Creates a PredictiveChecks object representing the model's
46 | predictive checks for min, max, mean, std of predictions as a
47 | collection of Panel grids (one per space) and a
48 | collection of plotted widgets.
49 | """
50 | return PredictiveChecksGrid(self.ic, self._mode, self._spaces, self._pred_checks)
51 |
52 | def _create_graph(self):
53 | """
54 | Creates one Tab per space (posterior, prior) presenting the graph.
55 | """
56 | tabs = pn.Tabs(sizing_mode='stretch_both')#sizing_mode='stretch_both'
57 | ## Tabs for prior-posterior graph
58 | g_grids = self._graph_grid.get_grids()
59 | g_plotted_widgets = self._graph_grid.get_plotted_widgets()
60 | for space in g_grids:
61 | g_col = pn.Column(g_grids[space])
62 | if space in g_plotted_widgets:
63 | widgetBox = pn.WidgetBox(*list(g_plotted_widgets[space].values()),sizing_mode = 'scale_both')
64 | w_col = pn.Column(widgetBox, width_policy='max', max_width=300, width=250)
65 | tabs.append((space, pn.Row(w_col, g_col)))#, height_policy='max', max_height=800
66 | else:
67 | tabs.append((space, pn.Row(g_col)))
68 | ## Tabs for predictive checks
69 | if self._pred_checks:
70 | pc_grids = self._predictive_checks_grid.get_grids()
71 | for var in pc_grids:
72 | for space in pc_grids[var]:
73 | g_col = pn.Column(pc_grids[var][space])
74 | tabs.append((var+'_'+space+'_predictive_checks', pn.Row(g_col)))
75 | #tabs.append((space+'_predictive_checks', pn.Row(c.get_plot(space,add_info=False), sizing_mode='stretch_both')))
76 | return tabs
77 |
78 | def set_coordinates(self, dim, options, value):
79 | self.ic.set_coordinates(self._graph_grid, dim, options, value)
80 | # try:
81 | # if coord_name in self.cells_widgets:
82 | # # space_widgets = self.cells_widgets[coord_name]
83 | # for space in self.cells_widgets[coord_name]:
84 | # c_id_list = self.cells_widgets[coord_name][space]
85 | # for c_id in c_id_list:
86 | # w = self.cells[c_id].get_widget(space, coord_name)
87 | # # old_v = w.value
88 | # w.value = new_val
89 | # w.trigger('value', new_val, new_val)
90 | # except IndexError:
91 | # raise IndexError()
92 |
93 | def get_selection_interactions(self):
94 | return self.ic.get_selection_interactions()
95 |
96 | def get_widgets_interactions(self):
97 | return self.ic.get_widgets_interactions()
98 |
99 | def get_graph(self):
100 | return self._graph
101 |
102 | def get_graph_grid(self):
103 | return self._graph_grid
104 |
105 | def get_pred_checks_grid(self):
106 | return self._predictive_checks_grid
107 |
--------------------------------------------------------------------------------
/ipme/classes/grid/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/evdoxiataka/ipme/f3398596c6af547908f39683eb1830a6bc081482/ipme/classes/grid/__init__.py
--------------------------------------------------------------------------------
/ipme/classes/grid/graph_grid.py:
--------------------------------------------------------------------------------
1 | from ...interfaces.grid import Grid
2 | from ipme.classes.cell.interactive_continuous_cell import InteractiveContinuousCell
3 | from ipme.classes.cell.interactive_discrete_cell import InteractiveDiscreteCell
4 | from ipme.classes.cell.static_continuous_cell import StaticContinuousCell
5 | from ipme.classes.cell.static_discrete_cell import StaticDiscreteCell
6 |
7 | from ...utils.constants import MAX_NUM_OF_COLS_PER_ROW, MAX_NUM_OF_VARS_PER_ROW, COLS_PER_VAR
8 | import panel as pn
9 |
10 | class GraphGrid(Grid):
11 | def _create_grids(self):
12 | """
13 | Creates one Cell object per variable. Cell object is the smallest
14 | visualization unit in the grid. Moreover, it creates one Panel GridSpec
15 | object per space.
16 |
17 | Sets:
18 | --------
19 | _cells A Dict {<var_name>:Cell object}.
20 | _grids A Dict of pn.GridSpec objects:
21 | {<space>:pn.GridSpec}
22 | """
23 | graph_grid_map = self._create_graph_grid_mapping()
24 | for row, map_data in graph_grid_map.items():
25 | level = map_data[0]
26 | vars_list = map_data[1]
27 | level_previous = -1
28 | if (row-1) in graph_grid_map:
29 | level_previous = graph_grid_map[row-1][0]
30 | if level != level_previous:
31 | col = int((MAX_NUM_OF_COLS_PER_ROW - len(vars_list)*COLS_PER_VAR) / 2.)
32 | else:
33 | col = int((MAX_NUM_OF_COLS_PER_ROW - MAX_NUM_OF_VARS_PER_ROW*COLS_PER_VAR) / 2.)
34 | for i,var_name in enumerate(vars_list):
35 | start_point = ( row, int(col + i*COLS_PER_VAR) )
36 | end_point = ( row+1, int(col + (i+1)*COLS_PER_VAR) )
37 | #col_l = int(col_f + (i+1)*COLS_PER_VAR)
38 | # grid_bgrd_col = level
39 | if self._mode == "i":
40 | if self._data.get_var_dist_type(var_name) == "Continuous":
41 | c = InteractiveContinuousCell(var_name, self.ic)
42 | else:
43 | c = InteractiveDiscreteCell(var_name, self.ic)
44 | elif self._mode == "s":
45 | if self._data.get_var_dist_type(var_name) == "Continuous":
46 | c = StaticContinuousCell(var_name, self.ic)
47 | else:
48 | c = StaticDiscreteCell(var_name, self.ic)
49 | self.cells[var_name] = c
50 | ##Add to grid
51 | cell_spaces = c.get_spaces()
52 | for space in cell_spaces:
53 | # if space in self._spaces_to_included and space not in self.spaces:
54 | # self.spaces.append(space)
55 | if space in self.spaces or self.spaces == 'all':
56 | if space not in self._grids:
57 | self._grids[space] = pn.GridSpec(sizing_mode = 'stretch_both')
58 | self._grids[space][ start_point[0]:end_point[0], start_point[1]:end_point[1] ] = pn.Column(c.get_plot(space, add_info = True), width=220, height=220)
59 | self.ic.num_cells = len(self.cells)
60 |
61 | def _create_graph_grid_mapping(self):
62 | """
63 | Maps the graph levels and the variables to Panel GridSpec rows/cols.
64 | Both <row>=0 and <level>=0 correspond to the highest row/level.
65 |
66 | Returns:
67 | --------
68 | A Dict {<row>: (<level>, List of varnames)}
69 | """
70 | _varnames_per_graph_level = self._data.get_varnames_per_graph_level(self._vars)
71 | graph_grid_map = {}
72 | grid_level = 0
73 | for graph_level in sorted(_varnames_per_graph_level):
74 | num_vars = len(_varnames_per_graph_level[graph_level])
75 | row = grid_level
76 | indx = 0
77 | while num_vars > MAX_NUM_OF_VARS_PER_ROW:
78 | while row in graph_grid_map:
79 | row+=1
80 | graph_grid_map[row] = (grid_level,_varnames_per_graph_level[graph_level][indx:indx+MAX_NUM_OF_VARS_PER_ROW])
81 | row += 1
82 | indx += MAX_NUM_OF_VARS_PER_ROW
83 | num_vars -= MAX_NUM_OF_VARS_PER_ROW
84 | while row in graph_grid_map:
85 | row+=1
86 | graph_grid_map[row] = (graph_level,_varnames_per_graph_level[graph_level][indx:indx+num_vars])
87 | return graph_grid_map
88 |
--------------------------------------------------------------------------------
/ipme/classes/grid/predictive_ckecks_grid.py:
--------------------------------------------------------------------------------
1 | from ...interfaces.grid import Grid
2 | from ipme.classes.cell.interactive_pred_ckeck_cell import InteractivePredCheckCell
3 | from ...utils.constants import MAX_NUM_OF_COLS_PER_ROW, COLS_PER_VAR
4 | import panel as pn
5 |
6 | class PredictiveChecksGrid(Grid):
7 | def __init__(self, control, mode, spaces, predictive_ckecks = []):
8 | """
9 | Parameters:
10 | --------
11 | control An IC object.
12 | mode A String in {"i","s"}, "i":interactive, "s":static.
13 | predictive_ckecks A List of observed variables to plot predictive checks.
14 | Sets:
15 | --------
16 | _data A Data object.
17 | mode A String in {"i","s"}, "i":interactive, "s":static.
18 | _grids A Dict of pn.GridSpec objects:
19 | {<var_name>:{<space>:pn.GridSpec}}
20 | cells A Dict {<pred_check>:Cell object},
21 | where pred_check in {'min','max','mean','std'}.
22 | cells_widgets A Dict dict1 of the form (key1,value1) = (<widget_name>, dict2)
23 | dict2 of the form (key1,value1) = (<space>, List of tuples (<cell_id>,<widget>)
24 | of the widgets with same name).
25 | plotted_widgets A List of widget objects to be plotted.
26 | """
27 | self._pred_checks = predictive_ckecks
28 | Grid.__init__(self, control, mode, spaces = spaces)
29 |
30 | def _create_grids(self):
31 | """
32 | Creates a 2x2 grid of the prior and posterior predictive checks
33 | for min, max, mean and std function.
34 |
35 | Sets:
36 | --------
37 | _grids A Dict of pn.GridSpec objects:
38 | {:{:pn.GridSpec}}
39 | """
40 | for var in self._pred_checks:
41 | if self._data.is_observed_variable(var):
42 | # if self._mode == "i":
43 | c_min = InteractivePredCheckCell(var, self.ic, "min")
44 | c_max = InteractivePredCheckCell(var, self.ic, "max")
45 | c_mean = InteractivePredCheckCell(var,self.ic, "mean")
46 | c_std = InteractivePredCheckCell(var, self.ic, "std")
47 | self.cells['min'] = c_min
48 | self.cells['max'] = c_max
49 | self.cells['mean'] = c_mean
50 | self.cells['std'] = c_std
51 | ##Add to grid
52 | cell_spaces = c_min.get_spaces()
53 | self._grids[var] = {}
54 | for space in cell_spaces:
55 | if space in self.spaces or self.spaces == 'all':
56 | if space not in self._grids[var]:
57 | self._grids[var][space] = pn.GridSpec(sizing_mode='stretch_both')
58 | for row in [0,1]:
59 | for i in [0,1]:
60 | col = int((MAX_NUM_OF_COLS_PER_ROW - 2.*COLS_PER_VAR) / 2.)
61 | start_point = ( row, int(col + i*COLS_PER_VAR) )
62 | end_point = ( row+1, int(col + (i+1)*COLS_PER_VAR) )
63 | if row == 0 and i == 0:
64 | self._grids[var][space][ start_point[0]:end_point[0], start_point[1]:end_point[1] ] = \
65 | pn.Column(c_min.get_plot(space), width=220, height=220)
66 | elif row == 0 and i == 1:
67 | self._grids[var][space][ start_point[0]:end_point[0], start_point[1]:end_point[1] ] = \
68 | pn.Column(c_max.get_plot(space), width=220, height=220)
69 | elif row == 1 and i == 0:
70 | self._grids[var][space][ start_point[0]:end_point[0], start_point[1]:end_point[1] ] = \
71 | pn.Column(c_mean.get_plot(space), width=220, height=220)
72 | elif row == 1 and i == 1:
73 | self._grids[var][space][ start_point[0]:end_point[0], start_point[1]:end_point[1] ] = \
74 | pn.Column(c_std.get_plot(space), width=220, height=220)
75 | else:
76 | raise ValueError("Declared predictive check variable {} is not an observed variable".format(var))
77 |
--------------------------------------------------------------------------------
/ipme/classes/grid/scatter_matrix_grid.py:
--------------------------------------------------------------------------------
1 | from ...interfaces.grid import Grid
2 | from ipme.classes.cell.interactive_scatter_cell import InteractiveScatterCell
3 | from ipme.classes.cell.static_scatter_cell import StaticScatterCell
4 | from ipme.classes.cell.interactive_continuous_cell import InteractiveContinuousCell
5 | from ipme.classes.cell.interactive_discrete_cell import InteractiveDiscreteCell
6 | from ipme.classes.cell.static_continuous_cell import StaticContinuousCell
7 | from ipme.classes.cell.static_discrete_cell import StaticDiscreteCell
8 |
9 | from ...utils.constants import COLS_PER_VAR
10 | import panel as pn
11 |
12 | class ScatterMatrixGrid(Grid):
13 | def _create_grids(self):
14 | """
15 | Creates one Cell object per variable. Cell object is the smallest
16 | visualization unit in the grid. Moreover, it creates one Panel GridSpec
17 | object per space.
18 |
19 | Sets:
20 | --------
21 | _cells A Dict {<var_name>:Cell object}.
22 | _grids A Dict of pn.GridSpec objects:
23 | {<space>:pn.GridSpec}
24 | """
25 | for row in range(len(self._vars)):
26 | for col in range(len(self._vars)):
27 | if col > row:
28 | break
29 | start_point = ( row, int(col*COLS_PER_VAR) )
30 | end_point = ( row+1, int((col+1)*COLS_PER_VAR) )
31 | var =""
32 | if col == row:
33 | ##plot VariableCell
34 | var_name = self._vars[col]
35 | if self._mode == "i":
36 | if self._data.get_var_dist_type(var_name) == "Continuous":
37 | c = InteractiveContinuousCell(var_name, self.ic)
38 | else:
39 | c = InteractiveDiscreteCell(var_name, self.ic)
40 | elif self._mode == "s":
41 | if self._data.get_var_dist_type(var_name) == "Continuous":
42 | c = StaticContinuousCell(var_name, self.ic)
43 | else:
44 | c = StaticDiscreteCell(var_name, self.ic)
45 | var = var_name
46 | else:
47 | ##plot pair scatter
48 | var1 = self._vars[row]
49 | var2 = self._vars[col]
50 | if self._mode == "i":
51 | c = InteractiveScatterCell([var1, var2], self.ic)
52 | elif self._mode == "s":
53 | c = StaticScatterCell([var1, var2], self.ic)
54 | var = var1+"_"+var2
55 | self.cells[var] = c
56 | ##Add to grid
57 | cell_spaces = c.get_spaces()
58 | for space in cell_spaces:
59 | # if space not in self.spaces:
60 | # self.spaces.append(space)
61 | if space in self.spaces or self.spaces == 'all':
62 | if space not in self._grids:
63 | self._grids[space] = pn.GridSpec(sizing_mode = 'stretch_both')
64 | self._grids[space][ start_point[0]:end_point[0], start_point[1]:end_point[1] ] = pn.Column(c.get_plot(space), width=220, height=220)
65 |
--------------------------------------------------------------------------------
/ipme/classes/interaction_control/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/evdoxiataka/ipme/f3398596c6af547908f39683eb1830a6bc081482/ipme/classes/interaction_control/__init__.py
--------------------------------------------------------------------------------
/ipme/classes/scatter_matrix.py:
--------------------------------------------------------------------------------
1 | from .data.data import Data
2 | from .grid.scatter_matrix_grid import ScatterMatrixGrid
3 | from .interaction_control.interaction_control import IC
4 |
5 | import panel as pn
6 |
7 | class ScatterMatrix():
8 | def __init__(self, data_path, mode = "i", vars = [], spaces = 'all'):
9 | """
10 | Parameters:
11 | --------
12 | data_path A String of the zip file with the inference data.
13 | mode A String in {'i','s'}: defines the type of diagram
14 | (interactive or static).
15 | vars A List of model variables to be included in the plot.
16 | spaces A List of spaces to be included in graph
17 | Sets:
18 | --------
19 | _mode A String in {"i","s"}, "i":interactive, "s":static.
20 | _scatter_matrix A Panel component object to visualize model's scatter matrix.
21 | """
22 | self.ic = IC(Data(data_path))
23 | if mode not in ["s","i"]:
24 | raise ValueError("ValueError: mode should take a value in {'i','s'}")
25 | self._mode = mode
26 | self._vars = vars
27 | self._spaces = spaces
28 | self._scatter_matrix_grid = self._create_scatter_matrix_grid()
29 | self._scatter_matrix = self._create_scatter_matrix()
30 |
31 | def _create_scatter_matrix_grid(self):
32 | """
33 | Creates a ScatterMatrixGrid object representing the model as a
34 | collection of Panel grids (one per space) and a
35 | collection of plotted widgets.
36 | """
37 | return ScatterMatrixGrid(self.ic, self._mode, self._vars, self._spaces)
38 |
39 | def _create_scatter_matrix(self):
40 | """
41 | Creates one Tab per space (posterior, prior) presenting the scatter matrix.
42 | """
43 | tabs = pn.Tabs(sizing_mode='stretch_both')#sizing_mode='stretch_both'
44 | ## Tabs for prior-posterior scatter matrix
45 | g_grids = self._scatter_matrix_grid.get_grids()
46 | g_plotted_widgets = self._scatter_matrix_grid.get_plotted_widgets()
47 | for space in g_grids:
48 | g_col = pn.Column(g_grids[space])
49 | if space in g_plotted_widgets:
50 | widgetBox = pn.WidgetBox(*list(g_plotted_widgets[space].values()),sizing_mode = 'scale_both')
51 | w_col = pn.Column(widgetBox, width_policy='max', max_width=300, width=250)
52 | tabs.append((space, pn.Row(w_col, g_col)))#, height_policy='max', max_height=800
53 | else:
54 | tabs.append((space, pn.Row(g_col)))
55 | return tabs
56 |
57 | def set_coordinates(self, dim, options, value):
58 | self.ic.set_coordinates(self._scatter_matrix_grid, dim, options, value)
59 |
60 | def set_coordinate(self, dim, value):
61 | self.ic.set_coordinate(self._scatter_matrix_grid, dim, value)
62 |
63 | def get_selection_interactions(self):
64 | return self.ic.get_selection_interactions()
65 |
66 | def get_widgets_interactions(self):
67 | return self.ic.get_widgets_interactions()
68 |
69 | def get_selection_ranges(self):
70 | """
71 | Returns List of tuples (xmin,xmax) of selection box
72 | """
73 | return self.ic.get_selection_ranges()
74 |
75 | def get_scatter_matrix(self):
76 | return self._scatter_matrix
77 |
78 | def get_scatter_matrix_grid(self):
79 | return self._scatter_matrix_grid
80 |
81 | # def get_pred_checks_grid(self):
82 | # return self._create_pred_checks_grid
83 |
--------------------------------------------------------------------------------
/ipme/cli.py:
--------------------------------------------------------------------------------
1 | """Console script for imd."""
2 | import sys
3 | import click
4 |
5 |
6 | @click.command()
7 | def main(args=None):
8 | """Console script for imd."""
9 | click.echo("Replace this message by putting your code into "
10 | "imd.cli.main")
11 | click.echo("See click documentation at https://click.palletsprojects.com/")
12 | return 0
13 |
14 |
15 | if __name__ == "__main__":
16 | sys.exit(main()) # pragma: no cover
17 |
--------------------------------------------------------------------------------
/ipme/interfaces/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/evdoxiataka/ipme/f3398596c6af547908f39683eb1830a6bc081482/ipme/interfaces/__init__.py
--------------------------------------------------------------------------------
/ipme/interfaces/cell.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | from ipme.classes.cell.utils.cell_widgets import CellWidgets
4 | from bokeh.io.export import get_screenshot_as_png
5 |
6 | class Cell(ABC):
7 | def __init__(self, vars, control):
8 | """
9 | Each cell will occupy a certain number of grid columns and will lie on a certain grid row.
10 | Parameters:
11 | --------
12 | vars A List of variableNames of the model.
13 | control An IC object
14 | Sets:
15 | --------
16 | vars
17 | ic
18 | spaces A List of Strings in {"prior","posterior"}.
19 | idx_dims A Dict {<var_name>:{<dim_name>:Dimension obj}}.
20 | cur_idx_dims_values A Dict {<var_name>:{<dim_name>: Current value of <dim_name>}}.
21 |
22 | plot A Dict {<space>: (bokeh) plot object}.
23 | widgets A Dict {<space>: {<widget_name>: A (bokeh) widget object} }.
24 | """
25 | self.vars = vars
26 | self.ic = control
27 | self._data = control.data
28 | self.spaces = self._define_spaces()
29 |
30 | #idx_dims-related variables
31 | self.idx_dims = self._data.get_idx_dimensions(self.vars)
32 | self.cur_idx_dims_values = {}
33 |
34 | self.plot = {}
35 | self.widgets = {}
36 | self._initialize_widgets()
37 | self._initialize_plot()
38 |
39 | def _define_spaces(self):
40 | data_spaces = self._data.get_spaces()
41 | spaces = []
42 | if "prior" in data_spaces:
43 | spaces.append("prior")
44 | if "posterior" in data_spaces:
45 | spaces.append("posterior")
46 | return spaces
47 |
48 | def _initialize_widgets(self):
49 | CellWidgets.initialize_widgets(self)
50 |
51 | @abstractmethod
52 | def widget_callback(self, attr, old, new, w_title, space):
53 | pass
54 |
55 | @abstractmethod
56 | def _initialize_plot(self):
57 | pass
58 |
59 | ## GETTERS
60 | def get_widgets(self):
61 | return self.widgets
62 |
63 | def get_widgets_in_space(self, space):
64 | if space in self.widgets:
65 | return self.widgets[space]
66 | else:
67 | return []
68 |
69 | def get_widget(self, space, id):
70 | try:
71 | return self.widgets[space][id]
72 | except KeyError:
73 | return None
74 |
75 | def get_plot(self, space):
76 | if space in self.plot:
77 | return self.plot[space]
78 | else:
79 | return None
80 |
81 | def get_screenshot(self, space):
82 | if space in self.plot:
83 | return get_screenshot_as_png(self.plot[space])
84 | else:
85 | return None
86 |
87 | def get_spaces(self):
88 | return self.spaces
89 |
--------------------------------------------------------------------------------
/ipme/interfaces/data_interface.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | class Data_Interface(ABC):
4 |
5 | def __init__(self, inference_path):
6 | """
7 | Parameters:
8 | --------
9 | inference_path A String of the inference data file path.
10 | Sets:
11 | --------
12 | _inferencedata A structure of the inference data.
13 | _graph A structure of the model's variables graph data.
14 | _observed_variables A List of Strings of the observed variables.
15 | _idx_dimensions A Dict {<var_name>:{<dim_name>:Dimension obj}}.
16 | _spaces A List of Strings in {'prior', 'posterior'} of all
17 | the available MCMC sample spaces in the inference data
18 | """
19 | self._inferencedata = self._load_inference_data(inference_path)
20 | self._graph = self._get_graph()
21 | self._observed_variables = self._get_obeserved_variables()
22 | self.all_variables = []
23 | self._idx_dimensions = self._get_idx_dimensions()
24 | self._spaces = self._get_spaces_from_data()
25 |
26 | @abstractmethod
27 | def _load_inference_data(self,datapath):
28 | pass
29 |
30 | @abstractmethod
31 | def _get_graph(self):
32 | pass
33 |
34 | @abstractmethod
35 | def _get_obeserved_variables(self):
36 | pass
37 |
38 | @abstractmethod
39 | def _get_idx_dimensions(self):
40 | pass
41 |
42 | @abstractmethod
43 | def _get_spaces_from_data(self):
44 | pass
45 |
46 | @abstractmethod
47 | def is_observed_variable(self,var_name):
48 | pass
49 |
50 | @abstractmethod
51 | def get_var_type(self,var_name):
52 | pass
53 |
54 | @abstractmethod
55 | def get_var_dist(self,var_name):
56 | pass
57 |
58 | @abstractmethod
59 | def get_var_dist_type(self,var_name):
60 | pass
61 |
62 | @abstractmethod
63 | def get_var_parents(self,var_name):
64 | pass
65 |
66 | @abstractmethod
67 | def get_samples(self,var_name,space=['prior','posterior']):
68 | pass
69 |
70 | @abstractmethod
71 | def get_range(self, var_name, space=['prior','posterior']):
72 | pass
73 |
74 | @abstractmethod
75 | def get_varnames_per_graph_level(self, vars):
76 | pass
77 |
78 | def get_data(self):
79 | return self._data
80 |
81 | def get_obeserved_variables(self):
82 | return self._observed_variables
83 |
84 | def get_inferencedata(self):
85 | return self._inferencedata
86 |
87 | def get_idx_dimensions(self, var_names):
88 | idx_dims = {}
89 | for var in var_names:
90 | if var in self._idx_dimensions:
91 | idx_dims[var] = {}
92 | for dim in self._idx_dimensions[var]:
93 | idx_dims[var][dim] = self._idx_dimensions[var][dim]
94 | return idx_dims
95 |
96 | def get_indx_for_idx_dim(self, var_name, d_name, d_value):
97 | indx=-1
98 | if var_name in self._idx_dimensions and d_name in self._idx_dimensions[var_name]:
99 | dvalues = self._idx_dimensions[var_name][d_name].values
100 | if d_value in dvalues:
101 | indx = dvalues.index(d_value)
102 | return indx
103 |
104 | def get_spaces(self):
105 | return self._spaces
106 |
107 |
--------------------------------------------------------------------------------
/ipme/interfaces/grid.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from ..classes.cell.utils.cell_widgets import CellWidgets
3 |
4 | class Grid(ABC):
5 | def __init__(self, control, mode, vars = 'all', spaces = 'all'):
6 | """
7 | Parameters:
8 | --------
9 | control An IC object.
10 | mode A String in {"i","s"}, "i":interactive, "s":static.
11 | vars A List of variables to be presented in the graph
12 | spaces A List of spaces to be included in graph
13 | Sets:
14 | --------
15 | ic
16 | _data
17 | _mode
18 | _grids A Dict of pn.GridSpec objects
19 | Either {<var_name>:{<space>:pn.GridSpec}}
20 | Or {<space>:pn.GridSpec}
21 | cells A Dict either {<var_name>:Cell object} or {<pred_check>:Cell object},
22 | where pred_check in {'min','max','mean','std'}.
23 | spaces A List of available spaces in this grid. Elements in {'prior','posterior'}
24 | cells_widgets A Dict dict1 of the form (key1,value1) = (<widget_name>, dict2)
25 | dict2 of the form (key1,value1) = (<space>, List of <cell_ids>).
26 | plotted_widgets A Dict of the form {<space>: List of widget objects to be plotted}.
27 | """
28 | self.ic = control
29 | self._data = control.data
30 | self._mode = mode
31 | self._vars = vars
32 | # self._spaces_to_included = spaces
33 | self._grids = {}
34 | self.cells = {}
35 | self.spaces = spaces
36 | self._create_grids()
37 |
38 | self.cells_widgets = {}
39 | self.plotted_widgets = {}
40 | self._add_widgets()
41 |
42 | @abstractmethod
43 | def _create_grids(self):
44 | pass
45 |
46 | def _add_widgets(self):
47 | CellWidgets.link_cells_widgets(self)
48 | if self._mode == "i":
49 | CellWidgets.set_plotted_widgets_interactive(self)
50 | else:
51 | CellWidgets.set_plotted_widgets_static(self)
52 |
53 | ## GETTERS
54 | def get_grids(self):
55 | return self._grids
56 |
57 | def get_plotted_widgets(self):
58 | return self.plotted_widgets
59 |
60 | def get_cells(self):
61 | return self.cells
62 |
63 |
--------------------------------------------------------------------------------
/ipme/interfaces/predictive_check_cell.py:
--------------------------------------------------------------------------------
1 | from ..interfaces.cell import Cell
2 |
3 | import threading
4 | from abc import abstractmethod
5 |
6 | class PredictiveCheckCell(Cell):
7 | def __init__(self, name, control):
8 | """
9 | Parameters:
10 | --------
11 | name A String of the observed variable's name.
12 | control An InteractiveControl object
13 | Sets:
14 | --------
15 | name
16 | source
17 | reconstructed
18 | samples
19 | seg, pvalue, pvalue_rec
20 |
21 | """
22 | self.name = name
23 | self.source = {}
24 | self.reconstructed = {}
25 | self.samples = {}
26 | self.seg = {}
27 | self.pvalue = {}
28 | self.pvalue_rec = {}
29 | Cell.__init__(self, [name], control)
30 |
31 | ## DATA
32 | def get_samples_for_cur_idx_dims_values(self, space):
33 | """
34 | Returns the observed data and predictive samples of the observed variable
35 | in space <space>.
36 |
37 | Returns:
38 | --------
39 | A Tuple (data,samples): data-> observed data and samples-> predictive samples.
40 | """
41 | data = self.ic.data.get_samples(self.name, 'observed_data')
42 | if self.ic.data.get_var_type(self.name) == "observed":
43 | if space == "posterior" and "posterior_predictive" in self.ic.data.get_spaces():
44 | space="posterior_predictive"
45 | elif space == "prior" and "prior_predictive" in self.ic.data.get_spaces():
46 | space="prior_predictive"
47 | samples = self.ic.data.get_samples(self.name, space)
48 | return data, samples
49 |
50 | ## INITIALIZATIONS
51 | def _initialize_plot(self):
52 | for space in self.spaces:
53 | self.initialize_cds(space)
54 | self.initialize_fig(space)
55 | self.initialize_glyphs(space)
56 |
57 | @abstractmethod
58 | def initialize_fig(self, space):
59 | pass
60 |
61 | @abstractmethod
62 | def initialize_cds(self, space):
63 | pass
64 |
65 | @abstractmethod
66 | def initialize_glyphs(self, space):
67 | pass
68 |
69 | ## WIDGETS
70 | @abstractmethod
71 | def widget_callback(self, attr, old, new, w_title, space):
72 | pass
73 |
74 | ## UPDATE
75 | @abstractmethod
76 | def update_cds(self, space):
77 | pass
78 |
79 | def sample_inds_callback(self, space, attr, old, new):
80 | """
81 | Updates cds when indices of selected samples -- Cell._sample_inds--
82 | are updated.
83 | """
84 | self.ic.add_selection_threads(space, threading.Thread(target = self._sample_inds_thread, args = (space,), daemon = True))
85 | self.ic.sel_lock_event.set()
86 |
87 | def _sample_inds_thread(self, space):
88 | self.update_cds(space)
89 |
--------------------------------------------------------------------------------
/ipme/interfaces/scatter_cell.py:
--------------------------------------------------------------------------------
1 | from ..interfaces.cell import Cell
2 | from ..utils.stats import find_x_range
3 |
4 | import numpy as np
5 | import threading
6 | from abc import abstractmethod
7 |
8 | class ScatterCell(Cell):
9 | def __init__(self, vars, control):
10 | """
11 | Parameters:
12 | --------
13 | vars A List of variableNames of the model.
14 | control A Control object
15 | Sets:
16 | -----
17 | x_range Figures axes x_range
18 | """
19 | self.samples = {}
20 | self.contours = {}
21 | self._all_samples = {}
22 | self.x_range = {}
23 | Cell.__init__(self, vars, control)
24 |
25 | ## DATA
26 | def get_samples(self, space):
27 | """
28 | Retrieves the MCMC samples of each variable in self.vars into a numpy.ndarray and
29 | sets an entry into the self._all_samples Dict.
30 | """
31 | for var in self.vars:
32 | space_gsam = space
33 | if self._data.get_var_type(var) == "observed":
34 | if space == "posterior" and "posterior_predictive" in self._data.get_spaces():
35 | space_gsam = "posterior_predictive"
36 | elif space == "prior" and "prior_predictive" in self._data.get_spaces():
37 | space_gsam = "prior_predictive"
38 | if var not in self._all_samples:
39 | self._all_samples[var] = {}
40 | self._all_samples[var][space] = self._data.get_samples(var, space_gsam).T
41 | # compute x_range
42 | self.x_range[var] = {}
43 | self.x_range[var][space] = find_x_range(self._all_samples[var][space])
44 | # self.x_range[var][space] = find_x_range(self.get_samples_for_cur_idx_dims_values(var, space))
45 |
46 | def get_samples_for_cur_idx_dims_values(self, var_name, space):
47 | """
48 | Returns a numpy.ndarray of the MCMC samples of the
49 | parameter for current index dimensions values.
50 |
51 | Returns:
52 | --------
53 | A numpy.ndarray.
54 | """
55 | if var_name in self._all_samples:
56 | data = self._all_samples[var_name]
57 | if space in data:
58 | data = data[space]
59 | else:
60 | raise ValueError("cell {}-{}: space {} not in self._all_samples[{}].keys() {}".format(self.vars[0],self.vars[1],space,var_name,data.keys()))
61 | else:
62 | raise ValueError("var_name {} not in self._all_samples.keys() {}".format(var_name, self._all_samples.keys()))
63 | if var_name in self.cur_idx_dims_values:
64 | for _, dim_value in self.cur_idx_dims_values[var_name].items():
65 | data = data[dim_value]
66 | return np.squeeze(data).T
67 |
68 | ## INITIALIZATION
69 | def _initialize_plot(self):
70 | for space in self.spaces:
71 | self.get_samples(space)
72 | self.initialize_cds(space)
73 | self.initialize_fig(space)
74 | self.initialize_glyphs(space)
75 |
76 | @abstractmethod
77 | def initialize_fig(self, space):
78 | pass
79 |
80 | @abstractmethod
81 | def initialize_cds(self, space):
82 | pass
83 |
84 | @abstractmethod
85 | def initialize_glyphs(self, space):
86 | pass
87 |
88 | ## WIDGETS
89 | @abstractmethod
90 | def widget_callback(self, attr, old, new, w_title, space):
91 | pass
92 |
93 | ## UPDATE
94 | @abstractmethod
95 | def update_cds(self, space):
96 | pass
97 |
98 | def sample_inds_callback(self, space, attr, old, new):
99 | """
100 | Updates cds when indices of selected samples -- Cell._sample_inds--
101 | are updated.
102 | """
103 | self.ic.add_selection_threads(space, threading.Thread(target = self._sample_inds_thread, args = (space,), daemon = True))
104 | self.ic.sel_lock_event.set()
105 |
106 | def _sample_inds_thread(self, space):
107 | self.update_cds(space)
108 |
109 |
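
The nested-dict indexing done by get_samples_for_cur_idx_dims_values can be seen in isolation in the sketch below; the variable name, dimension names, and array shapes are made up for illustration.

    # Illustration only: the indexing pattern used by
    # get_samples_for_cur_idx_dims_values, on toy data. The dict layout mirrors
    # self._all_samples / self.cur_idx_dims_values; "theta" is hypothetical.
    import numpy as np

    all_samples = {"theta": {"posterior": np.random.rand(3, 2, 1000)}}  # index dims first, draws last
    cur_idx_dims_values = {"theta": {"county": 1, "floor": 0}}          # selected index per dimension

    data = all_samples["theta"]["posterior"]
    for _, dim_value in cur_idx_dims_values["theta"].items():
        data = data[dim_value]          # peel off one indexed dimension per selection
    print(np.squeeze(data).T.shape)     # -> (1000,): the draws for the selected indices
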
--------------------------------------------------------------------------------
/ipme/interfaces/variable_cell.py:
--------------------------------------------------------------------------------
1 | from ..interfaces.cell import Cell
2 | from ..utils.stats import find_x_range
3 | from ..utils.constants import BORDER_COLORS
4 |
5 | from bokeh.models import Toggle, Div
6 | from bokeh.layouts import layout
7 | from bokeh.io.export import get_screenshot_as_png
8 |
9 | import numpy as np
10 | import threading
11 | from abc import abstractmethod
12 |
13 | class VariableCell(Cell):
14 | def __init__(self, name, control):
15 | """
16 | Parameters:
17 | --------
18 | name A String (the name of a model variable).
19 | control A Control object
20 | Sets:
21 | -----
22 | x_range Figures axes x_range
23 | _toggle A Dict {space: (bokeh) Toggle button controlling the figure's visibility}.
24 | _div A Dict {space: (bokeh) Div holding parameter-related information}.
25 | """
26 | self.name = name
27 | self.source = {}
28 | self.samples = {}
29 | self.data = {}
30 | self._all_samples = {}
31 | self._all_data = {}
32 | self.x_range = {}
33 | Cell.__init__(self, [name], control)
34 | self._toggle = {}
35 | self._div = {}
36 | self._initialize_toggle_div()
37 |
38 | ## DATA
39 | def get_samples(self, space):
40 | """
41 | Retrieves the MCMC samples of each variable in self.vars into a numpy.ndarray and
42 | sets an entry into the self._all_samples Dict.
43 | """
44 | for var in self.vars:
45 | space_gsam = space
46 | if self._data.get_var_type(var) == "observed":
47 | if space == "posterior" and "posterior_predictive" in self._data.get_spaces():
48 | space_gsam = "posterior_predictive"
49 | elif space == "prior" and "prior_predictive" in self._data.get_spaces():
50 | space_gsam = "prior_predictive"
51 | if var not in self._all_samples:
52 | self._all_samples[var] = {}
53 | self._all_samples[var][space] = self._data.get_samples(var, space_gsam).T
54 | # get observed data
55 | data = self._data.get_observations(var)
56 | if data is not None:
57 | self._all_data[var] = data
58 | # compute x_range
59 | self.x_range[var] = {}
60 | self.x_range[var][space] = find_x_range(self._all_samples[var][space])
61 | # self.x_range[var][space] = find_x_range(self.get_samples_for_cur_idx_dims_values(var, space))
62 |
63 | def get_samples_for_cur_idx_dims_values(self, var_name, space):
64 | """
65 | Returns a numpy.ndarray of the MCMC samples of the
66 | parameter for current index dimensions values.
67 |
68 | Returns:
69 | --------
70 | A numpy.ndarray.
71 | """
72 | if var_name in self._all_samples:
73 | data = self._all_samples[var_name]
74 | if space in data:
75 | data = data[space]
76 | else:
77 | raise ValueError("cell {}: space {} not in self._all_samples[{}].keys() {}".format(self.name, space, var_name, data.keys()))
78 | else:
79 | raise ValueError("var_name {} not in self._all_samples.keys() {}".format(var_name, self._all_samples.keys()))
80 | if var_name in self.cur_idx_dims_values:
81 | for _, dim_value in self.cur_idx_dims_values[var_name].items():
82 | data = data[dim_value]
83 | if data.shape == (1,)*len(data.shape):
84 | return data.flatten()
85 | else:
86 | return np.squeeze(data).T
87 |
88 | def get_data_for_cur_idx_dims_values(self, var_name):
89 | """
90 | Returns a numpy.ndarray of the observations of the
91 | parameter for current index dimensions values.
92 |
93 | Returns:
94 | --------
95 | A numpy.ndarray.
96 | """
97 | if var_name in self._all_data:
98 | data = self._all_data[var_name]
99 | else:
100 | return None
101 | if var_name in self.cur_idx_dims_values:
102 | for _, dim_value in self.cur_idx_dims_values[var_name].items():
103 | data = data[dim_value]
104 | if data.shape == (1,)*len(data.shape):
105 | return data.flatten()
106 | else:
107 | return np.squeeze(data).T
108 |
109 | ## INITIALIZATION
110 | def _initialize_plot(self):
111 | for space in self.spaces:
112 | self.get_samples(space)
113 | self.initialize_cds(space)
114 | self.initialize_fig(space)
115 | self.initialize_glyphs(space)
116 |
117 | @abstractmethod
118 | def initialize_fig(self, space):
119 | pass
120 |
121 | @abstractmethod
122 | def initialize_cds(self, space):
123 | pass
124 |
125 | @abstractmethod
126 | def initialize_glyphs(self, space):
127 | pass
128 |
129 | ## WIDGETS
130 | @abstractmethod
131 | def widget_callback(self, attr, old, new, w_title, space):
132 | pass
133 |
134 | ## UPDATE
135 | @abstractmethod
136 | def update_cds(self, space):
137 | pass
138 |
139 | def sample_inds_callback(self, space, attr, old, new):
140 | """
141 | Updates cds when indices of selected samples -- Cell._sample_inds--
142 | are updated.
143 | """
144 | self.ic.add_selection_threads(space, threading.Thread(target = self._sample_inds_thread, args = (space,), daemon = True))
145 | self.ic.sel_lock_event.set()
146 |
147 | def _sample_inds_thread(self, space):
148 | self.update_cds(space)
149 |
150 | def compute_intersection_of_samples(self, space):
151 | """
152 | Computes intersection of sample points based on user's
153 | restrictions per parameter.
154 | """
155 | sel_var_inds = self.ic.get_sel_var_inds(space = space)
156 | sp_keys = list(sel_var_inds.keys())
157 | inds_list = np.full((len(self.samples[space].data['x']),), False)
158 | if len(sp_keys)>1:
159 | sets = []
160 | for var in sp_keys:
161 | sets.append(set(sel_var_inds[var]))
162 | inds_set = set.intersection(*sorted(sets, key = len))
163 | inds_list[list(inds_set)] = True
164 | elif len(sp_keys) == 1:
165 | inds_list[sel_var_inds[sp_keys[0]]] = True
166 | non_inds_list = list(~inds_list)
167 | self.ic.set_sample_inds(space, dict(inds = list(inds_list)), dict(non_inds = non_inds_list))
168 |
169 | def _initialize_toggle_div(self):
170 | """
171 | Creates the toggle headers of each variable node.
172 | """
173 | for space in self.spaces:
174 | width = self.plot[space].plot_width
175 | height = 40
176 | sizing_mode = self.plot[space].sizing_mode
177 | label = self.name + " ~ " + self._data.get_var_dist(self.name)
178 | text = """parents: %s
dims: %s"""%(self._data.get_var_parents(self.name), list(self._data.get_idx_dimensions(self.name)))
179 | if sizing_mode == 'fixed':
180 | self._toggle[space] = Toggle(label = label, active = False,
181 | width = width, height = height, sizing_mode = sizing_mode, margin = (0,0,0,0))
182 | self._div[space] = Div(text = text,
183 | width = width, height = height, sizing_mode = sizing_mode, margin = (0,0,0,0), background = BORDER_COLORS[0] )
184 | elif sizing_mode == 'scale_width' or sizing_mode == 'stretch_width':
185 | self._toggle[space] = Toggle(label = label, active = False,
186 | height = height, sizing_mode = sizing_mode, margin = (0,0,0,0))
187 | self._div[space] = Div(text = text,
188 | height = height, sizing_mode = sizing_mode, margin = (0,0,0,0), background = BORDER_COLORS[0] )
189 | elif sizing_mode == 'scale_height' or sizing_mode == 'stretch_height':
190 | self._toggle[space] = Toggle(label = label, active = False,
191 | width = width, sizing_mode = sizing_mode, margin = (0,0,0,0))
192 | self._div[space] = Div(text = text,
193 | width = width, sizing_mode = sizing_mode, margin = (0,0,0,0), background = BORDER_COLORS[0] )
194 | else:
195 | self._toggle[space] = Toggle(label = label, active = False,
196 | sizing_mode = sizing_mode, margin = (0,0,0,0))
197 | self._div[space] = Div(text = text, sizing_mode = sizing_mode, margin = (0,0,0,0), background = BORDER_COLORS[0] )
198 | self._toggle[space].js_link('active', self.plot[space], 'visible')
199 |
200 |
201 | def get_max_prob(self, space):
202 | """
203 | Gets highest point --max probability-- of cds
204 | """
205 | max_sv = -1
206 | max_rv = -1
207 | if self.source[space].data['y'].size:
208 | max_sv = self.source[space].data['y'].max()
209 | if hasattr(self,'reconstructed') and self.reconstructed[space].data['y'].size:
210 | max_rv = self.reconstructed[space].data['y'].max()
211 | max_v = max([max_sv,max_rv])
212 | return max_v if max_v!=-1 else None
213 |
214 | def get_plot(self, space, add_info = False):
215 | if space in self.plot:
216 | if add_info and space in self._toggle and space in self._div:
217 | return layout([self._toggle[space]], [self._div[space]], [self.plot[space]])
218 | else:
219 | return self.plot[space]
220 | else:
221 | return None
222 |
223 | def get_screenshot(self, space, add_info=False):
224 | if space in self.plot:
225 | if add_info and space in self._toggle and space in self._div:
226 | return get_screenshot_as_png(layout([self._toggle[space]], [self._div[space]], [self.plot[space]]))
227 | else:
228 | return get_screenshot_as_png(self.plot[space])
229 | else:
230 | return None
231 |
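
The selection logic in compute_intersection_of_samples boils down to intersecting the per-variable index sets (starting from the smallest one) and turning the result into a boolean mask plus its complement. A toy illustration, with hypothetical variable names and without an InteractionControl object:

    # Illustration only: the set-intersection step behind
    # compute_intersection_of_samples, on made-up selections.
    import numpy as np

    n_samples = 10
    sel_var_inds = {"mu": [0, 2, 3, 7], "sigma": [2, 3, 5, 7, 9]}   # selected draw indices per variable

    inds_list = np.full((n_samples,), False)
    sets = [set(inds) for inds in sel_var_inds.values()]
    inds_set = set.intersection(*sorted(sets, key=len))             # start from the smallest set
    inds_list[list(inds_set)] = True

    print(list(np.where(inds_list)[0]))   # -> [2, 3, 7]: draws kept by every restriction
    print(list(~inds_list))               # the "non_inds" complement passed to set_sample_inds
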
--------------------------------------------------------------------------------
/ipme/methods.py:
--------------------------------------------------------------------------------
1 | from .classes.graph import Graph
2 | from .classes.scatter_matrix import ScatterMatrix
3 |
4 | def graph(data_path, mode = "i", vars = 'all', spaces = 'all', predictive_checks = []):
5 | graph = Graph(data_path, mode, vars, spaces, predictive_checks)
6 | graph.get_graph().show()
7 |
8 | def scatter_matrix(data_path, mode = "i", vars = [], spaces = 'all'):
9 | scatter_matrix = ScatterMatrix(data_path, mode, vars, spaces)
10 | scatter_matrix.get_scatter_matrix().show()
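
These two functions are the package's public entry points. A hedged usage sketch follows; the file name and variable names are placeholders, and the exact format expected for data_path is defined by the package's data layer rather than shown here.

    # Illustration only: how the two public entry points are typically called.
    # "inference_data.npz" and the variable names are placeholders.
    import ipme

    # interactive graph of the whole model, with a predictive check on an
    # observed variable assumed here to be called "y"
    ipme.graph("inference_data.npz", vars="all", spaces="all", predictive_checks=["y"])

    # scatter matrix restricted to two (hypothetical) parameters
    ipme.scatter_matrix("inference_data.npz", vars=["mu", "sigma"], spaces="all")
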
--------------------------------------------------------------------------------
/ipme/utils/__init__.py:
--------------------------------------------------------------------------------
1 | __author__ = """Evdoxia Taka"""
2 | __email__ = 'e.taka.1@research.gla.ac.uk'
3 | __version__ = '0.1.0'
--------------------------------------------------------------------------------
/ipme/utils/constants.py:
--------------------------------------------------------------------------------
1 | """ Cell Interface
2 |
3 | Colors-related constants:
4 |
5 | Same color palette used in arviz-darkgrid style:
6 | https://arviz-devs.github.io/arviz/examples/matplotlib/mpl_styles.html
7 |
8 | Color-blind friendly cycle designed using https://colorcyclepicker.mpetroff.net/
9 | """
10 | # colors
11 | COLORS = ['#2a2eec', '#fa7c17', '#328c06', '#c10c90', '#933708', '#65e5f3', '#e6e135', '#1ccd6a', '#bd8ad5', '#b16b57']
12 | BORDER_COLORS=['#d8d8d8','#FFFFFF']
13 |
14 | ## Plot sizing-related constants
15 | PLOT_WIDTH = 220
16 | PLOT_HEIGHT = 220
17 | SIZING_MODE = "fixed"
18 |
19 | ## Rug plot
20 | RUG_DIST_RATIO = 5.0
21 | RUG_SIZE = 10 #in screen units
22 |
23 | ## Data plot
24 | DATA_DIST_RATIO = 3.0
25 | DATA_SIZE = 6 #in screen units
26 |
27 | """ Grid Interface
28 |
29 | """
30 | ##
31 | MAX_NUM_OF_COLS_PER_ROW = 12
32 | COLS_PER_VAR = 2
33 | MAX_NUM_OF_VARS_PER_ROW = 5
--------------------------------------------------------------------------------
/ipme/utils/functions.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from math import gcd, ceil
3 |
4 | def lcm(list_of_int):
5 | """
6 | Get the least common multiple (LCM) of a list of integers.
7 |
8 | Parameters:
9 | --------
10 | list_of_int A List of integers
11 | Returns:
12 | --------
13 | An integer (the lcm of list_of_int).
14 | """
15 | try:
16 | lcm = list_of_int[0]
17 | for i in list_of_int[1:]:
18 | lcm = lcm * i // gcd(lcm, i)
19 | return int(lcm)
20 | except IndexError:
21 | return None
22 |
23 | def find_indices(lst, condition, xmin=0, xmax=0):
24 | # return [i for i, elem in enumerate(lst) if condition(elem)]
25 | # return [ _ for _ in itertools.compress(list(range(0,len(lst))), map(condition,lst)) ]
26 | return list(np.where((lst>=xmin) & (lst<=xmax))[0])
27 |
28 | def find_inds_before_after(lst, el):
29 | inds_sm=find_indices(lst, lambda e: e<= el)
30 | if len(inds_sm):
31 | ind_before=inds_sm[-1]
32 | else:
33 | ind_before=-1
34 | inds_bi=find_indices(lst, lambda e: e>= el)
35 | if len(inds_bi):
36 | ind_after=inds_bi[0]
37 | else:
38 | ind_after=-1
39 | return (ind_before,ind_after)
40 |
41 | def find_highest_point( x, y):
42 | x=np.asarray(x, dtype=np.float64)
43 | y=np.asarray(y, dtype=np.float64)
44 | if len(y):
45 | max_idx=np.argmax(y)
46 | return (x[max_idx],y[max_idx])
47 | else:
48 | return ()
49 |
50 | def get_samples_for_pred_check(samples, func):
51 | # samples = np.asarray(samples)
52 | shape = samples.shape
53 | if 0 not in shape:
54 | if len(shape) == 1:
55 | axis = 0
56 | else:
57 | axis = 1
58 | if len(shape) > 2:
59 | fir_dim = shape[0]
60 | sec_dim = 1
61 | for i in np.arange(1,len(shape),1):
62 | sec_dim = sec_dim*shape[i]
63 | samples=samples.reshape(fir_dim,sec_dim)
64 | if func == "min":
65 | samples = samples.min(axis=axis)
66 | elif func == "max":
67 | samples = samples.max(axis=axis)
68 | elif func == "mean":
69 | samples = samples.mean(axis=axis)
70 | elif func == "std":
71 | samples = samples.std(axis=axis)
72 | else:
73 | samples = np.empty([1, 2])
74 | if ~np.isfinite(samples).all():
75 | samples = get_finite_samples(samples)
76 | else:
77 | samples = np.empty([1, 2])
78 | return samples
79 |
80 | def get_finite_samples(np_array):
81 | if isinstance(np_array, np.ndarray):
82 | shape = len(np_array.shape)
83 | if shape == 1:
84 | np_array = np_array[np.isfinite(np_array)]
85 | elif shape > 1:
86 | samples_idx = np.isfinite(np_array).all(axis=shape-1)
87 | for axis in np.arange(shape-2,0,-1):
88 | samples_idx = samples_idx.all(axis=axis)
89 | np_array = np_array[samples_idx]
90 | return np_array
91 |
92 | def get_hist_bins_range(samples, func, var_type, ref_length = None, ref_values=None):
93 | """
94 | Parameters:
95 | --------
96 | samples Flatten finite samples
97 | func Predictive check criterion {'min','max','mean','std'}
98 | var_type Variable type in {'Discrete','Continuous'}
99 | ref_length A reference length for bin to estimate the number of bins
100 | ref_values A numpy.ndarray with the unique values of a Discrete variable
101 | """
102 | if (func == 'min' or func == 'max') and var_type == "Discrete":
103 | if ref_values is not None:
104 | if len(ref_values)<20:
105 | min_v = ref_values.min()
106 | max_v = ref_values.max()
107 | bins = len(ref_values)
108 | if bins > 1:
109 | range = ( min_v, max_v + (max_v - min_v) / (bins - 1))
110 | else:
111 | range = ( min_v, min_v+1)
112 | return (bins, range)
113 | else:
114 | values = np.unique(samples)
115 | if len(values) < 20:
116 | min_v = values.min()
117 | max_v = values.max()
118 | bins = len(values)
119 | if bins > 1:
120 | range = ( min_v, max_v + (max_v - min_v) / (bins - 1))
121 | else:
122 | range = ( min_v, min_v+1)
123 | return (bins, range)
124 | range = (samples.min(),samples.max())
125 | if ref_length:
126 | bins = ceil((range[1] - range[0]) / ref_length)
127 | range = (range[0], range[0] + bins*ref_length)
128 | else:
129 | bins = 20
130 | return (bins, range)
131 |
132 | def get_dim_names_options(dim):
133 | """
134 | dim: imd.Dimension object
135 | """
136 | name1 = dim.name
137 | name2 = None
138 | options1 = dim.values
139 | options2 = []
140 | if "_idx_" in name1:
141 | idx = name1.find("_idx_")
142 | st_n1 = idx + 5
143 | end_n1 = len(name1)
144 | name2 = name1[st_n1:end_n1]
145 | name1 = name1[0:idx]
146 | values = np.array(dim.values)
147 | options1 = np.unique(values).tolist()
148 | if len(options1):
149 | tmp = np.arange(np.count_nonzero(values == options1[0]))
150 | options2 = list(map(str,tmp))
151 | return (name1, name2, options1, options2)
152 |
153 | def get_w2_w1_val_mapping(dim):
154 | """
155 | dim: imd.Dimension object
156 | Returns:
157 | -------
158 | A Dict {: A List of for this }
159 | """
160 | options1 = dim.values
161 | values = np.array(dim.values)
162 | options1 = np.unique(values)
163 | val_dict = {}
164 | if len(options1):
165 | for v1 in options1:
166 | tmp = np.arange(np.count_nonzero(values == v1))
167 | val_dict[v1] = list(map(str,tmp))
168 | return val_dict
169 |
170 | def get_stratum_range(samples, stratum):
171 | median = np.median(samples)
172 | if stratum == 0 or stratum == 1:
173 | inds_l = np.where(samples<=median)[0]
174 | median_l = np.median(samples[inds_l])
175 | if stratum == 0:
176 | xmin = np.min(samples).item()
177 | xmax = median_l
178 | elif stratum == 1:
179 | xmin = median_l
180 | xmax = median
181 | elif stratum == 2 or stratum == 3:
182 | inds_h = np.where(samples>=median)[0]
183 | median_h = np.median(samples[inds_h])
184 | if stratum == 2:
185 | xmin = median
186 | xmax = median_h
187 | elif stratum == 3:
188 | xmin = median_h
189 | xmax = np.max(samples).item()
190 | else:
191 | xmin = np.min(samples).item()
192 | xmax = np.max(samples).item()
193 | return (xmin,xmax)
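
For a discrete variable with fewer than 20 distinct values, get_hist_bins_range uses one bin per value and pads the range by one bin width so the largest value gets its own bin. A small worked example (the sample values are arbitrary):

    # Illustration only: what get_hist_bins_range returns for a small discrete
    # variable with 6 unique values.
    import numpy as np
    from ipme.utils.functions import get_hist_bins_range

    samples = np.array([0, 1, 1, 2, 3, 3, 4, 5])          # 6 unique values < 20
    bins, rng = get_hist_bins_range(samples, func="min", var_type="Discrete")
    print(bins, rng)   # -> 6 (0, 6.0): bins = number of unique values,
                       #    range = (min, max + (max - min) / (bins - 1))
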
--------------------------------------------------------------------------------
/ipme/utils/js_code.py:
--------------------------------------------------------------------------------
1 | HOVER_CODE="""
2 | const data = {'x': [], 'y': [], 'isIn': []}
3 | data['x']=source.data.x
4 | data['y']=source.data.y
5 | for (var i = 0; i < … (remainder of js_code.py missing from the source dump)
--------------------------------------------------------------------------------
/ipme/utils/stats.py:
--------------------------------------------------------------------------------
… (lines 1-111 of stats.py missing from the source dump; the listing resumes mid-function)
112 | … > 1:
113 | new_shape = tuple([-1] + list(x_shape))
114 | y = y.reshape(new_shape)
115 |
116 | hpd_ = az.hpd(y, credible_interval=credible_interval, circular=False, multimodal=False)
117 |
118 | if smooth:
119 | x_data = np.linspace(x.min(), x.max(), 200)
120 | x_data[0] = (x_data[0] + x_data[1]) / 2
121 | hpd_interp = griddata(x, hpd_, x_data)
122 | y_data = savgol_filter(hpd_interp, axis=0, window_length=55, polyorder=2)
123 | return (np.concatenate((x_data, x_data[::-1])),np.concatenate((y_data[:, 0], y_data[:, 1][::-1])))
124 | else:
125 | return (np.concatenate((x, x[::-1])),np.concatenate((hpd_[:,0], hpd_[:, 1][::-1])))
126 |
127 | def hpd(y, credible_interval=0.94):
128 | y = np.asarray(y)
129 | y_shape = y.shape
130 | if 0 in y_shape:
131 | return (np.array([]),np.array([]))
132 | hpd_ = az.hpd(y, credible_interval=credible_interval, circular=False, multimodal=False)
133 | if hpd_.ndim == 1:
134 | hpd_ = np.expand_dims(hpd_, axis=0)
135 | return (hpd_[:,0],hpd_[:,1])
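
The hpd helper wraps az.hpd, the HPD-interval API of the ArviZ releases this code was written against (later renamed az.hdi). A usage sketch, assuming such an ArviZ version is installed:

    # Illustration only: calling the hpd helper on draws of a single parameter.
    # Assumes an older ArviZ release where az.hpd(..., credible_interval=...)
    # is still available.
    import numpy as np
    from ipme.utils.stats import hpd

    draws = np.random.normal(size=4000)            # posterior draws of one parameter
    lower, upper = hpd(draws, credible_interval=0.94)
    print(lower, upper)                            # one-element arrays: the 94% HPD bounds
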
--------------------------------------------------------------------------------
/requirements_dev.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/evdoxiataka/ipme/f3398596c6af547908f39683eb1830a6bc081482/requirements_dev.txt
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [bumpversion]
2 | current_version = 0.1.0
3 | commit = True
4 | tag = True
5 |
6 | [bumpversion:file:setup.py]
7 | search = version='{current_version}'
8 | replace = version='{new_version}'
9 |
10 | [bumpversion:file:ipme/__init__.py]
11 | search = __version__ = '{current_version}'
12 | replace = __version__ = '{new_version}'
13 |
14 | [bdist_wheel]
15 | universal = 1
16 |
17 | [flake8]
18 | exclude = docs
19 |
20 | [aliases]
21 | # Define setup.py command aliases here
22 |
23 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | """The setup script."""
4 |
5 | from setuptools import setup, find_packages
6 |
7 | # with open('README.md') as readme_file:
8 | # readme = readme_file.read()
9 |
10 | with open('HISTORY.rst') as history_file:
11 | history = history_file.read()
12 |
13 | requirements = ['Click>=7.0', ]
14 |
15 | setup_requirements = [ ]
16 |
17 | test_requirements = [ ]
18 |
19 | setup(
20 | author="Evdoxia Taka",
21 | author_email='e.taka.1@research.gla.ac.uk',
22 | python_requires='>=3.5',
23 | classifiers=[
24 | 'Development Status :: 2 - Pre-Alpha',
25 | 'Intended Audience :: Developers',
26 | 'License :: OSI Approved :: MIT License',
27 | 'Natural Language :: English',
28 | 'Programming Language :: Python :: 3',
29 | 'Programming Language :: Python :: 3.5',
30 | 'Programming Language :: Python :: 3.6',
31 | 'Programming Language :: Python :: 3.7',
32 | 'Programming Language :: Python :: 3.8',
33 | ],
34 | description="Interactive probabilistic models explorer is an interactive tool for visualizing and exploring Bayesian probabilistic programming models and inference data.",
35 | entry_points={
36 | 'console_scripts': [
37 | 'ipme=ipme.cli:main',
38 | ],
39 | },
40 | install_requires=requirements,
41 | license="MIT license",
42 | # long_description=readme + '\n\n' + history,
43 | include_package_data=True,
44 | keywords='Bayesian probabilistic modeling, Bayesian inference, Markov Chain Monte Carlo, interactive visualization, uncertainty visualization, interpretability',
45 | name='ipme',
46 | packages=find_packages(include=['ipme','ipme*']),
47 | setup_requires=setup_requirements,
48 | test_suite='tests',
49 | tests_require=test_requirements,
50 | url='https://github.com/evdoxiataka/ipme',
51 | version='0.1.0',
52 | zip_safe=False,
53 | )
54 |
--------------------------------------------------------------------------------
/tox.ini:
--------------------------------------------------------------------------------
1 | [tox]
2 | envlist = py35, py36, py37, py38, flake8
3 |
4 | [travis]
5 | python =
6 | 3.8: py38
7 | 3.7: py37
8 | 3.6: py36
9 | 3.5: py35
10 |
11 | [testenv:flake8]
12 | basepython = python
13 | deps = flake8
14 | commands = flake8 ipme tests
15 |
16 | [testenv]
17 | setenv =
18 | PYTHONPATH = {toxinidir}
19 |
20 | commands = python setup.py test
21 |
--------------------------------------------------------------------------------