├── .gitignore ├── LICENSE ├── README.md ├── causallab ├── __init__.py ├── discovery.py ├── experiment.py ├── plot.py ├── serve.py ├── theme.py ├── utils.py └── view.py ├── requirements.txt ├── setup.cfg └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by .ignore support plugin (hsz.mobi) 2 | .idea 3 | 4 | ### Python template 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | pip-wheel-metadata/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | dt_output/ 134 | log/ 135 | trial_store/ 136 | tmp/ 137 | catboost_info/ 138 | 139 | #dispatchers 140 | logs/ 141 | workdir/ 142 | dask-worker-space/ 143 | 144 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Overview 2 | 3 | CausalLab is an Interactive Causal Analysis Tool. 4 | 5 | 6 | ## Installation 7 | 8 | The latest [YLearn](https://github.com/DataCanvasIO/YLearn) is required to run CausalLab, so install it from the latest source code before installing CausalLab: 9 | 10 | ```console 11 | pip install "torch<2.0.0" "pyro-ppl<1.8.5" gcastle 12 | pip install git+https://github.com/DataCanvasIO/YLearn.git 13 | ``` 14 | 15 | Now, one can install CausalLab from the source: 16 | 17 | ```console 18 | git clone https://github.com/DataCanvasIO/CausalLab 19 | cd CausalLab 20 | pip install . 
21 | ``` 22 | 23 | ## Startup 24 | 25 | Run `causal_lab` to startup CausalLab http server on localhost with default port(5006): 26 | 27 | ```console 28 | causal_lab 29 | ``` 30 | 31 | 32 | To accept request from other computers, specify local `host_ip` and `port` to startup CausalLab http server: 33 | 34 | ```console 35 | causal_lab --address --port --allow-websocket-origin=: 36 | ``` 37 | 38 | eg: 39 | 40 | ```console 41 | causal_lab --address 172.20.51.203 --port 15006 --allow-websocket-origin=172.20.51.203:15006 42 | ``` 43 | 44 | 45 | ## License 46 | See the [LICENSE](LICENSE) file for license rights and limitations (Apache-2.0). 47 | -------------------------------------------------------------------------------- /causallab/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.1.0' 2 | -------------------------------------------------------------------------------- /causallab/discovery.py: -------------------------------------------------------------------------------- 1 | from itertools import product 2 | 3 | import numpy as np 4 | import pandas as pd 5 | 6 | from ylearn.bayesian import _base 7 | from ylearn.bayesian._dag import DiGraph 8 | 9 | 10 | def _alg_notears(data): 11 | from ylearn.causal_discovery._discovery import CausalDiscovery 12 | cd = CausalDiscovery(hidden_layer_dim=[data.shape[1], ]) 13 | return cd(data) 14 | 15 | 16 | def _alg_pc_stable(data): 17 | from ylearn.causal_discovery._proxy_gcastle import GCastleProxy 18 | cd = GCastleProxy(learner='PC', variant='stable') 19 | return cd(data) 20 | 21 | 22 | def _alg_pc_original(data): 23 | from ylearn.causal_discovery._proxy_gcastle import GCastleProxy 24 | cd = GCastleProxy(learner='PC', variant='original') 25 | return cd(data) 26 | 27 | 28 | def _alg_ges_bdeu(data): 29 | from ylearn.causal_discovery._proxy_gcastle import GCastleProxy 30 | cd = GCastleProxy(learner='GES', criterion='bdeu') 31 | return cd(data) 32 | 33 | 34 | def 
_alg_ges_bic(data): 35 | from ylearn.causal_discovery._proxy_gcastle import GCastleProxy 36 | cd = GCastleProxy(learner='GES', criterion='bic') 37 | return cd(data) 38 | 39 | 40 | def _alg_icalingam(data): 41 | from ylearn.causal_discovery._proxy_gcastle import GCastleProxy 42 | cd = GCastleProxy(learner='ICALiNGAM') 43 | return cd(data) 44 | 45 | 46 | def _alg_mcsl(data): 47 | from ylearn.causal_discovery._proxy_gcastle import GCastleProxy 48 | cd = GCastleProxy(learner='MCSL') 49 | return cd(data) 50 | 51 | 52 | def _alg_grandag(data): 53 | from ylearn.causal_discovery._proxy_gcastle import GCastleProxy 54 | cd = GCastleProxy(learner='GraNDAG', input_dim=data.shape[1]) 55 | return cd(data) 56 | 57 | 58 | discoverers = { 59 | 'PC(Stable)': _alg_pc_stable, 60 | 'PC(Original)': _alg_pc_original, 61 | 'GES(bedu)': _alg_ges_bdeu, 62 | 'GES(bid)': _alg_ges_bic, 63 | 'ICALiNGAM': _alg_icalingam, 64 | # 'MCSL': _alg_mcsl, 65 | # 'GraNDAG': _alg_grandag, 66 | 'NoTears': _alg_notears, 67 | } 68 | 69 | 70 | class CausationHolder: 71 | """ 72 | Parameters 73 | ---------- 74 | node_states: dict, key is node name, value is node state 75 | 76 | Attributes 77 | ---------- 78 | position: dict, key is node name, value is node position tuple (x,y) 79 | threshold: int 80 | matrices: matrix found by discovery algorithms 81 | enabled: list of tuple(cause,effect) 82 | disabled: list of tuple(cause,effect) 83 | """ 84 | 85 | def __init__(self, node_states): 86 | assert isinstance(node_states, dict) 87 | 88 | self.node_states = node_states 89 | self.position = {} 90 | self.threshold = 1 91 | self.matrices = {} 92 | self.enabled = [] 93 | self.disabled = [] 94 | 95 | def reset(self): 96 | self.threshold = 1 97 | self.matrices = {} 98 | self.enabled = set() 99 | self.disabled = set() 100 | 101 | def add_matrix(self, name, matrix): 102 | assert isinstance(name, str) 103 | assert isinstance(matrix, pd.DataFrame) 104 | assert set(matrix.columns.tolist()) == set(matrix.index.tolist()) 105 | 
assert set(matrix.columns.tolist()).issubset(set(self.node_states.keys())) 106 | 107 | self.matrices[name] = matrix 108 | 109 | def disable(self, cause, effect): 110 | assert cause in set(self.node_states.keys()) 111 | assert effect in set(self.node_states.keys()) 112 | 113 | if (cause, effect) in self.enabled: 114 | self.enabled.remove((cause, effect)) 115 | self.disabled.append((cause, effect)) 116 | 117 | def enable(self, cause, effect): 118 | assert cause in set(self.node_states.keys()) 119 | assert effect in set(self.node_states.keys()) 120 | 121 | if (cause, effect) in self.disabled: 122 | self.disabled.remove((cause, effect)) 123 | self.enabled.append((cause, effect)) 124 | 125 | def remove_disabled(self, cause, effect): 126 | assert cause in set(self.node_states.keys()) 127 | assert effect in set(self.node_states.keys()) 128 | 129 | if (cause, effect) in self.disabled: 130 | self.disabled.remove((cause, effect)) 131 | 132 | def remove_enabled(self, cause, effect): 133 | assert cause in set(self.node_states.keys()) 134 | assert effect in set(self.node_states.keys()) 135 | 136 | if (cause, effect) in self.enabled: 137 | self.enabled.remove((cause, effect)) 138 | 139 | @property 140 | def is_empty(self): 141 | """ 142 | Weather cause-effect does not exist. 
143 | """ 144 | return len(self.matrices) == 0 and len(self.enabled) == 0 145 | 146 | @property 147 | def causal_matrix(self): 148 | """ 149 | get the final causal matrix 150 | """ 151 | matrix = None 152 | if len(self.matrices) > 0: 153 | for m in self.matrices.values(): 154 | if matrix is None: 155 | matrix = m 156 | else: 157 | matrix = matrix + m 158 | if self.threshold is not None and self.threshold > 0: 159 | values = np.where(matrix.values >= self.threshold, matrix.values, 0) 160 | matrix = pd.DataFrame(values, columns=matrix.columns, index=matrix.index) 161 | for c, e in self.disabled: 162 | matrix[e][c] = 0 163 | else: 164 | nodes = self.node_states.keys() 165 | matrix = pd.DataFrame(np.zeros((len(nodes), len(nodes)), dtype='int'), 166 | columns=nodes, index=nodes) 167 | 168 | for c, e in self.enabled: 169 | if matrix[e][c] < 1: 170 | matrix[e][c] = 1 171 | 172 | return matrix.copy() 173 | 174 | @property 175 | def graph(self): 176 | """ 177 | return the DiGraph object. 178 | 179 | edge attributes: 180 | ----------------- 181 | weight: int, the number of algorithms found the relation 182 | expert: int, 1: enabled by expert, 0: enabled by discovery algorithms 183 | """ 184 | 185 | def node_attribute(n): 186 | shape = 'box' if isinstance(self.node_states[n], _base.CategoryNodeState) else 'ellipse' 187 | attr = dict(shape=shape) 188 | 189 | if self.position is not None and n in self.position.keys(): 190 | x, y = self.position[n] 191 | attr['x'] = x 192 | attr['y'] = y 193 | return attr 194 | 195 | m = self.causal_matrix 196 | columns = m.columns.tolist() 197 | nodes = [(n, node_attribute(n)) for n in columns] 198 | edges = [(c, e, 199 | dict(weight=m[e][c], 200 | expert=int((c, e) in self.enabled), 201 | ) 202 | ) 203 | for c, e in product(columns, columns) 204 | if c != e and m[e][c] > 0 205 | ] # list of tuple(start, end, edge_data) 206 | 207 | g = DiGraph(None) 208 | g.add_nodes_from(nodes) 209 | g.add_edges_from(edges) 210 | 211 | return g 212 | 
-------------------------------------------------------------------------------- /causallab/experiment.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import gzip 3 | import pickle 4 | from io import BytesIO 5 | 6 | from ylearn.bayesian import DataLoader 7 | from ylearn.bayesian import _base 8 | from ylearn.sklearn_ex import DataCleaner 9 | from causallab.discovery import CausationHolder 10 | 11 | 12 | class BNExperiment(_base.BObject): 13 | """ 14 | Causal Lab Experiment settings 15 | """ 16 | 17 | def __init__(self, train_data, test_data, causation, bn): 18 | if train_data is not None: 19 | train_data, _ = DataCleaner().fit_transform(train_data, y=None) 20 | if test_data is not None: 21 | test_data, _ = DataCleaner().fit_transform(test_data, y=None) 22 | 23 | if causation is None and train_data is not None: 24 | causation = CausationHolder(DataLoader.state_of(train_data)) 25 | 26 | self.train_data = train_data 27 | self.test_data = test_data 28 | self.causation = causation 29 | self.bn = bn 30 | 31 | @staticmethod 32 | def load(file_path): 33 | """ 34 | load experiment from file system 35 | """ 36 | with gzip.open(file_path, 'rb') as f: 37 | return pickle.load(f) 38 | 39 | def save(self, file_path): 40 | """ 41 | save experiment into file system, compress data with gzip. 
42 | """ 43 | with gzip.open(file_path, 'wb')as f: 44 | pickle.dump(self, f, protocol=pickle.HIGHEST_PROTOCOL) 45 | 46 | @staticmethod 47 | def decode(data): 48 | """ 49 | decode experiment from base64 str 50 | """ 51 | data = base64.b64decode(data) 52 | data = gzip.decompress(data) 53 | buf = BytesIO(data) 54 | obj = pickle.load(buf) 55 | return obj 56 | 57 | def encode(self): 58 | """ 59 | encode experiment with base64 60 | :return: encoded base64 str 61 | """ 62 | buf = BytesIO() 63 | pickle.dump(self, buf, protocol=pickle.HIGHEST_PROTOCOL) 64 | data = buf.getvalue() 65 | data = gzip.compress(data) 66 | data = base64.b64encode(data) 67 | return data 68 | -------------------------------------------------------------------------------- /causallab/plot.py: -------------------------------------------------------------------------------- 1 | import os 2 | from datetime import datetime 3 | 4 | from bokeh import models as M 5 | from bokeh.io import show, output_notebook 6 | from bokeh.plotting import column, row 7 | from bokeh.plotting import curdoc 8 | from bokeh.themes import Theme 9 | from sklearn.model_selection import train_test_split 10 | 11 | from ylearn.bayesian import _base 12 | from ylearn.utils import logging, is_notebook 13 | from . import utils 14 | from .experiment import BNExperiment 15 | from .theme import my_theme 16 | from .view import BNInterventionView, BNTrainingView, BNEffectView, BNPredictionView, BNPropertyView 17 | from .view import ViewItemNames, DataExplorationView, CausationView 18 | 19 | logger = logging.get_logger(__name__) 20 | 21 | 22 | class BNNotebookPlotter(_base.BObject): 23 | def __init__(self, bn): 24 | assert is_notebook(), f'Plot can only be displayed on notebook.' 
        assert bn is not None

        self.bn = bn

    def plot(self, *, width=1200, height=800):
        """Render the BN property view inline in the current notebook cell."""
        output_notebook(hide_banner=True, verbose=False)
        view = BNPropertyView(bn=self.bn)
        layout = view.get_layout()
        # fixed sizing so the layout does not try to stretch inside the notebook
        layout.width = width
        layout.height = height
        layout.sizing_mode = 'fixed'
        return show(layout)


class BNExperimentPlotter(_base.BObject):
    """Builds the Bokeh document for a CausalLab experiment session."""
    # name of the root layout container in the Bokeh document
    root_name = 'myroot'

    def __init__(self, *, data_file, test_file, experiment_file, work_dir):
        # assert data_file is not None or experiment_file is not None

        self.data_file = data_file
        self.test_file = test_file
        self.experiment_file = experiment_file
        self.work_dir = work_dir

    def plot(self):
        """
        Build the Bokeh document: load/create the experiment from the
        configured files, or fall back to the startup (file-picker) layout
        when neither an experiment file nor a data file was given.
        """
        if self.experiment_file is not None and len(self.experiment_file) > 0:
            # load experiment
            exp = BNExperiment.load(self.experiment_file)
        elif self.data_file is not None and len(self.data_file) > 0:
            # create experiment with data_file and test_file
            df_train, df_test = None, None

            if self.data_file is not None:
                df_train = utils.load_data(self.data_file)
            if self.test_file is not None:
                df_test = utils.load_data(self.test_file)

            # no explicit test file: hold out 30% of the train data
            if df_train is not None and df_test is None:
                df_train, df_test = train_test_split(df_train, test_size=0.3, random_state=123)

            exp = BNExperiment(
                train_data=df_train,
                test_data=df_test,
                causation=None,
                bn=None,
            )
        else:
            exp = None

        doc = curdoc()
        # print('>' * 30)
        # for k, v in doc.session_context.request.headers.items():
        #     print(k, ': ', v)
        # print('>' * 30)
        doc.theme = Theme(json=my_theme)
        doc.title = 'Causal Lab'

        if exp is None:
            layout = self.to_startup_layout()
        else:
            layout = self.to_experiment_layout(exp)

        # single named root so to_startup_layout can swap its children later
        myroot = column(layout, name=self.root_name, sizing_mode='stretch_both')
        doc.add_root(myroot)

    def to_experiment_layout(self, experiment):
        """
        Build the tabbed experiment UI; each tab's view is created lazily on
        first activation, and state (causation / bn) is carried over from the
        previously active view.
        """
        titles = ['Data', 'Discovery', 'Training', 'Causal Effect', #'Prediction', 'Intervention',
                  ]
        views = [
            DataExplorationView,  # (train_data=train_data, test_data=test_data),
            CausationView,  # (data=train_data, causation=causation),
            BNTrainingView,  # (bn=bn, data=train_data, causation=causation),
            BNEffectView,  # (bn=bn, data=test_data),
            # BNPredictionView,  # (bn=bn, data=test_data),
            # BNInterventionView,  # (bn=bn, data=train_data),
            # # BNPropertyView,  # (bn=bn, data=test_data),
        ]

        # NOTE(review): the original HTML markup of this placeholder string was
        # lost in this dump; the literal is reproduced as found — verify against
        # the upstream source.
        place_holder = M.Div(text='
    loading ...')
        panels = [
            M.TabPanel(title=title, child=column(place_holder, sizing_mode='stretch_both'))
            for title in titles]
        tabs = M.Tabs(tabs=panels, sizing_mode='stretch_both')
        last_view = None

        def create_view(i):
            # instantiate the view for tab i, propagating state from the last view
            nonlocal last_view

            if last_view is not None:
                if isinstance(last_view, CausationView):
                    experiment.causation = last_view.causation
                elif isinstance(last_view, (BNTrainingView, BNInterventionView)):
                    experiment.bn = last_view.bn
                    print('switch bn to', experiment.bn)

            cls = views[i]
            if cls is DataExplorationView:
                view = DataExplorationView(train_data=experiment.train_data, test_data=experiment.test_data)
            elif cls is CausationView:
                view = CausationView(data=experiment.train_data, causation=experiment.causation)
            elif cls is BNTrainingView:
                view = BNTrainingView(bn=experiment.bn, data=experiment.train_data, causation=experiment.causation)
            elif cls is BNInterventionView:
                view = cls(bn=experiment.bn, data=experiment.train_data)
            elif cls is BNEffectView or cls is BNPredictionView:
                view = cls(bn=experiment.bn, data=experiment.test_data)
            elif cls is BNPropertyView:
                view = cls(bn=experiment.bn, data=experiment.test_data)
            else:
                raise ValueError(f'???{cls}')

            last_view = view
            return view

        def on_tab_active(attr, old_value, new_value):
            # build the newly active tab's view; reset the others to the placeholder
            idx_active = new_value
            for i, panel in enumerate(panels):
                if i == idx_active:
                    view = create_view(i)
                    layout = view.get_layout()
                    layout = self.add_graph_tool(layout, experiment)
                else:
                    layout = place_holder
                panel.child.children = [layout]

        tabs.on_change('active', on_tab_active)

        # init settings
        on_tab_active('active', -1, 0)

        return tabs

    def add_graph_tool(self, layout, experiment):
        """
        Add a toolbar action to the main graph figure (when present) that
        saves the experiment on the server and then triggers a browser
        download of the saved file.
        """
        fig = layout.select_one({'name': ViewItemNames.main_graph})
        if fig is None:
            return layout

        # hidden data source used as a two-way signalling channel:
        # selecting index 0 triggers the python 'save' callback,
        # selecting index 1 triggers the JS 'download' callback.
        ds_action = M.ColumnDataSource(data=dict(
            # x=[10, 10],
            # y=[0, 20],
            action=['save', 'download'],
            value=['', ''],
        ))
        js_cb_on_save = M.CustomJS(
            args=dict(ds=ds_action),
            code="""
            ds.selected.indices = [];
            ds.selected.indices = [0]; // trigger python callback
            // console.log(ds.data);
            """.strip()
        )
        js_cb_on_action = M.CustomJS(
            args=dict(ds=ds_action),
            code="""
            const indices = ds.selected.indices;
            const data = ds.data;
            if( indices.length>0){
                const idx = indices[0];
                if( data["action"][idx] == "download" ){
                    const filename = data["value"][idx];
                    console.log("download", filename);
                    const link = document.createElement('a');
                    link.href = "../download/" + filename;
                    link.download = filename;
                    link.target = '_blank';
                    link.style.visibility = 'hidden';
                    link.dispatchEvent(new MouseEvent('click'));
                }
            }
            """.strip()
        )

        def on_action(attr, old_value, new_value):
            # python side of the signalling: save, then hand off to JS download
            print('on_ds_flag_change', attr, old_value, new_value)
            assert isinstance(new_value, (list, tuple))
            if len(new_value) == 0 or ds_action.data['action'][new_value[0]] != 'save':
                # print('skip')
                return

            tag = datetime.now().strftime('%Y%m%d%H%M')
            file_name = f'experiment_{tag}.pkl.gz'
            experiment.save(os.path.join(self.work_dir, file_name))
            values = ds_action.data['value'].copy()
            values[1] = file_name
            ds_action.data['value'] = values
            ds_action.selected.indices = [1]  # trigger js callback
            logger.info(f'download experiment as file {file_name}')

        ds_action.selected.on_change('indices', on_action)
        ds_action.selected.js_on_change('indices', js_cb_on_action)

        save_tool = M.CustomAction(
221 | icon='save', 222 | description='Save Experiment', 223 | callback=js_cb_on_save, 224 | ) 225 | 226 | fig.add_tools(save_tool) 227 | return layout 228 | 229 | def to_startup_layout(self): 230 | file_train_data = M.FileInput(accept=['.csv', '.parquet'], sizing_mode='stretch_width') 231 | file_test_data = M.FileInput(accept=['.csv', '.parquet'], sizing_mode='stretch_width') 232 | file_experiment = M.FileInput(accept='.pkl.gz', sizing_mode='stretch_width') 233 | 234 | btn_start = M.Button(label='Start', button_type="primary", align='center', ) 235 | 236 | div_train_data_msg = M.Div() 237 | div_test_data_msg = M.Div() 238 | div_experiment_msg = M.Div() 239 | div_msg = M.Div(text='', sizing_mode='stretch_width') 240 | 241 | widgets_open = [ 242 | M.Div(), 243 | M.Div(text='Experiment file:'), 244 | file_experiment, 245 | div_experiment_msg, 246 | ] 247 | widgets_new = [ 248 | M.Div(), 249 | M.Div(text='Train data:'), 250 | file_train_data, 251 | div_train_data_msg, 252 | M.Div(), 253 | M.Div(text='Test data:'), 254 | file_test_data, 255 | div_test_data_msg, 256 | ] 257 | panel_new = M.TabPanel(title='New', child=column(widgets_new)) 258 | panel_open = M.TabPanel(title='Open', child=column(widgets_open)) 259 | tab = M.Tabs(tabs=[panel_new, panel_open], width=800, sizing_mode='stretch_width') 260 | 261 | layout = row( 262 | M.Div(text='', sizing_mode='stretch_width'), 263 | column( 264 | M.Div(text='

Experiment

'), 265 | tab, 266 | div_msg, 267 | btn_start, 268 | width=400, 269 | ), 270 | M.Div(text='', sizing_mode='stretch_width'), 271 | sizing_mode='stretch_width') 272 | 273 | ctx = {} 274 | 275 | def on_change_train_data(attr, old_value, new_value): 276 | filename = file_train_data.filename 277 | value = file_train_data.value 278 | df = utils.load_b64data(value, filename) 279 | div_train_data_msg.text = f'shape:{df.shape}' 280 | div_msg.text = '' 281 | ctx['train_data'] = df 282 | 283 | def on_change_test_data(attr, old_value, new_value): 284 | filename = file_test_data.filename 285 | value = file_test_data.value 286 | df = utils.load_b64data(value, filename) 287 | div_test_data_msg.text = f'shape:{df.shape}' 288 | div_msg.text = '' 289 | ctx['test_data'] = df 290 | 291 | def on_change_experiment(attr, old_value, new_value): 292 | # filename = file_experiment.filename 293 | value = file_experiment.value 294 | exp = BNExperiment.decode(value) 295 | div_experiment_msg.text = f'loaded.' 296 | div_msg.text = '' 297 | ctx['experiment'] = exp 298 | 299 | def on_btn_start(): 300 | if tab.active == 0: # new 301 | if 'train_data' not in ctx.keys(): 302 | div_msg.text = 'Not found train_data' 303 | return 304 | train_data = ctx['train_data'] 305 | if 'test_data' in ctx.keys(): 306 | test_data = ctx['test_data'] 307 | else: 308 | train_data, test_data = train_test_split(train_data, test_size=0.2, random_state=123) 309 | exp = BNExperiment(train_data=train_data, test_data=test_data, causation=None, bn=None) 310 | else: # open 311 | if 'experiment' not in ctx.keys(): 312 | div_msg.text = 'Not found experiment' 313 | return 314 | exp = ctx['experiment'] 315 | 316 | doc = curdoc() 317 | myroot = doc.select_one({'name': self.root_name}) 318 | exp_layout = self.to_experiment_layout(exp) 319 | myroot.children = [exp_layout] 320 | 321 | file_train_data.on_change('filename', on_change_train_data) 322 | file_test_data.on_change('filename', on_change_test_data) 323 | 
def run_cleaner(path, file_pattern='*.tmp', interval=60, keep_duration=3600):
    """
    Start a daemon thread that periodically deletes stale files.

    Every *interval* seconds the thread scans ``path`` for files matching
    *file_pattern* and removes those whose mtime is older than
    *keep_duration* seconds.

    :param path: directory to clean (may not exist yet; re-checked every cycle)
    :param file_pattern: glob pattern of removable files
    :param interval: seconds between two scans
    :param keep_duration: minimal file age (in seconds) before removal
    :return: the started daemon ``threading.Thread``
    """
    from threading import Thread
    import time
    import glob
    import os
    from ylearn.utils import logging
    logger = logging.getLogger('run_cleaner')

    def clean():
        last_at = 0
        while True:
            now = time.time()
            # poll in <=1s slices until a full interval has elapsed, so the
            # daemon reacts promptly if the process is shutting down
            to_sleep = min(interval - abs(now - last_at), 1.0)
            if to_sleep > 0:
                time.sleep(to_sleep)
                continue

            if not os.path.exists(path):
                last_at = time.time()
                continue

            for f in glob.glob(os.path.join(path, file_pattern), recursive=False):
                assert f.startswith(path)
                try:
                    # getmtime may race with concurrent deletion, so it is
                    # kept inside the try block together with remove()
                    if os.path.getmtime(f) + keep_duration < now:
                        os.remove(f)
                        logger.info(f'{f} removed')
                except OSError:
                    # fix: was `logger.warm(...)` (no such method) inside a
                    # bare `except:` — the failure path raised AttributeError
                    logger.warning(f'failed to remove file {f}')
            last_at = time.time()

    t = Thread(target=clean, daemon=True)
    t.start()
    # returning the thread is backward-compatible (callers ignored the
    # previous implicit None) and makes the cleaner observable in tests
    return t
def bkapp():
    """Bokeh application entry point: parse CLI options and run the plotter."""
    from causallab.plot import BNExperimentPlotter

    arg_parser = argparse.ArgumentParser('')
    init_argparser(arg_parser)
    opts = arg_parser.parse_args()

    # expand '~' in the working directory before handing it to the plotter
    work_dir = os.path.expanduser(opts.work_dir) if opts.work_dir else opts.work_dir

    BNExperimentPlotter(
        experiment_file=opts.experiment,
        data_file=opts.data,
        test_file=opts.test,
        work_dir=work_dir,
    ).plot()
def load_b64data(data, filename, *, reset_index=False, reader_mapping=None, **kwargs):
    """
    Load a dataframe from base64-encoded file content (e.g. a bokeh FileInput).

    :param data: base64-encoded bytes of the uploaded file
    :param filename: original file name; its extension selects the reader
    :param reset_index: drop and rebuild the dataframe index when True
    :param reader_mapping: optional {extension: reader} override
    :param kwargs: forwarded to the selected pandas reader
    :return: pandas.DataFrame
    :raises ValueError: if the file extension has no registered reader
    """
    if reader_mapping is None:
        reader_mapping = {
            'csv': partial(pd.read_csv, low_memory=False),
            'txt': partial(pd.read_csv, low_memory=False),
            'parquet': pd.read_parquet,
            'par': pd.read_parquet,
            'json': pd.read_json,
            'pkl': pd.read_pickle,
            'pickle': pd.read_pickle,
        }

    # normalize case so 'DATA.CSV' resolves the same reader as 'data.csv'
    fmt = path.splitext(filename)[-1].lstrip('.').lower()

    if fmt not in reader_mapping.keys():
        # fix: message previously read 'data format{fmt}' with no separator
        raise ValueError(f'Not supported data format: {fmt}')
    fn = reader_mapping[fmt]

    buf = BytesIO(base64.b64decode(data))
    df = fn(buf, **kwargs)

    if reset_index:
        df.reset_index(drop=True, inplace=True)

    return df
def trun(target, args=None, kwargs=None, on_success=None, on_error=None):
    """
    Execute *target* on a freshly started thread.

    The call is routed through ``_proc_stub`` so that the optional
    ``on_success``/``on_error`` hooks fire when the target finishes.

    :return: the started threading.Thread
    """
    stub = partial(
        _proc_stub,
        target,
        args=args,
        kwargs=kwargs,
        on_success=on_success,
        on_error=on_error,
    )
    worker = threading.Thread(target=stub)
    worker.start()
    return worker
def _discovery(df, algs, callback=None):
    """
    Run the selected causal-discovery algorithms over *df* and sum their
    result matrices.

    :param df: input dataframe; preprocessed internally before discovery
    :param algs: non-empty list/tuple of algorithm names; every name must be
        a key of ``discovery.discoverers``
    :param callback: optional ``callback(alg, matrix)`` invoked after each
        algorithm finishes
    :return: element-wise sum of the matrices produced by all algorithms
    """
    from ylearn.sklearn_ex import general_preprocessor
    from .discovery import discoverers

    assert (isinstance(algs, (list, tuple))
            and len(algs) > 0
            and all(map(lambda a: a in discoverers.keys(), algs)))

    # preprocess with number scaling enabled before running the discoverers
    gp = general_preprocessor(number_scaler=True)
    data = gp.fit_transform(df)

    matrix = None
    for alg in algs:
        print('>>>>run:', alg)
        start_at = time.time()
        m = discoverers[alg](data)
        # pad each run to at least ~3 seconds — presumably so per-algorithm
        # progress stays visible in the UI; TODO confirm intent
        while abs(time.time() - start_at) < 3:
            time.sleep(0.1)
        if callback is not None:
            callback(alg, m)
        if matrix is None:
            matrix = m
        else:
            matrix += m

    return matrix
isinstance(value, int): 91 | return str(value) 92 | else: 93 | return f'{float(value):.3f}' 94 | 95 | @staticmethod 96 | def format_dict(dic, title=None): 97 | if len(dic) > 0: 98 | items = [PlotView.format_kv(k, v) for k, v in dic.items()] 99 | else: 100 | items = ['
  • < None >
  • '] 101 | 102 | html = '\n'.join(['
      '] + items + ['
    ']) 103 | if title is not None: 104 | html = f'

    {title}

    \n' + html 105 | return html 106 | 107 | @staticmethod 108 | def format_kv(k, v): 109 | # return f'
  • {key}: {value}' 110 | if isinstance(v, float): 111 | return f'
  • {k}:  {v:.6f}' 112 | elif isinstance(v, np.ndarray) and v.ndim == 0: 113 | return f'
  • {k}:  {v:.6f}' 114 | else: 115 | return f'
  • {k}:  {v}' 116 | 117 | @staticmethod 118 | def decorate_pd_html(html): 119 | style = """ 120 | 133 | """.strip() 134 | return '\n'.join(['
    ', style, html, '
    @staticmethod
    def default_column_formatter(df, col, **kwargs):
        """Pick a bokeh table cell formatter matching the dtype of ``df[col]``."""
        dtype_kind = df[col].dtype.kind
        if dtype_kind == 'f':
            # floats: thousands separator + 3 decimals
            number_format = '0,0.000'
        elif dtype_kind in 'iu':
            # signed/unsigned integers: thousands separator only
            number_format = '0,0'
        else:
            return M.StringFormatter(**kwargs)
        return M.NumberFormatter(format=number_format, **kwargs)
main_layout, *, width=1000): 195 | radio_dataset = M.RadioGroup(labels=['Train Data', 'Test Data'], active=0) 196 | file_test = M.FileInput(accept='.csv,.txt,.json', multiple=False) 197 | 198 | def on_change_dataset(attr, old_value, new_value): 199 | if new_value == 0: 200 | data = self.train_data 201 | else: 202 | data = self.test_data 203 | print('data', data.shape) 204 | table_layout = self.get_table_layout(data, sizing_mode='stretch_both') 205 | column_layout = self.get_column_layout(data, sizing_mode='stretch_both') 206 | tabs = [M.TabPanel(title='Detail', child=table_layout), 207 | M.TabPanel(title='Column', child=column_layout), 208 | ] 209 | main_layout.tabs = tabs 210 | 211 | def on_change_file(attr, old_value, new_value): 212 | print(attr, 'new value:', len(new_value)) 213 | print(attr, 'file name:', file_test.filename) 214 | print(attr, 'mime_type:', file_test.mime_type) 215 | 216 | file_test.on_change('value', on_change_file) 217 | file_test.on_change('mime_type', on_change_file) 218 | file_test.on_change('filename', on_change_file) 219 | radio_dataset.on_change('active', on_change_dataset) 220 | widgets = [M.Div(text='

    Datasets:

    '), 221 | radio_dataset, 222 | file_test, 223 | ] 224 | return column(widgets, width=width) 225 | 226 | def get_table_layout(self, df, **kwargs): 227 | fmt = self.default_column_formatter 228 | table_columns = [M.TableColumn(field=c, 229 | formatter=fmt(df, col=c), ) 230 | for c in df.columns.tolist()] 231 | ds_table = M.ColumnDataSource(df) 232 | table = M.DataTable(columns=table_columns, source=ds_table, editable=False, **kwargs) 233 | return table 234 | 235 | def get_column_layout(self, df, **kwargs): 236 | state = DataLoader.state_of(df) 237 | columns = df.columns.tolist() 238 | state_html = [self.state_to_html(None, state=state[c]) for c in columns] 239 | df_summary = pd.DataFrame(dict(column=columns, state=state_html)) 240 | 241 | ds_summary = M.ColumnDataSource(df_summary) 242 | table_columns = [M.TableColumn(field='column', formatter=M.StringFormatter()), 243 | M.TableColumn(field='state', formatter=M.HTMLTemplateFormatter())] 244 | table = M.DataTable(columns=table_columns, source=ds_summary, editable=False, 245 | row_height=150, 246 | **kwargs) 247 | return table 248 | 249 | 250 | class GraphPlotView(PlotView): 251 | """ 252 | PlotView with graph and optional dataframe 253 | """ 254 | 255 | def __init__(self, *, data=None, node_states=None): 256 | assert node_states is not None or data is not None 257 | 258 | if node_states is not None: 259 | super().__init__(node_states) 260 | else: 261 | super().__init__(DataLoader.state_of(data)) 262 | 263 | self.data = data 264 | 265 | if data is not None: 266 | self.ds_table = self._to_table_ds(data) 267 | self.ds_node, self.ds_edge = self._to_graph_ds() 268 | else: 269 | self.ds_table = None 270 | self.ds_node, self.ds_edge = self._to_graph_ds() 271 | 272 | @property 273 | def data_title(self): 274 | return 'Test Data' 275 | 276 | def get_graph(self): 277 | raise NotImplemented() 278 | 279 | def _to_table_ds(self, df): 280 | source = M.ColumnDataSource(df.copy()) 281 | source.selected.indices = [0] 282 | return 
source 283 | 284 | def _get_node_edge_layout(self, graph=None, prog=None, node_pos=None, dot_options=None): 285 | if graph is None: 286 | graph = self.get_graph() 287 | 288 | if prog is None: 289 | if len(graph.get_edges()) == 0: 290 | prog = 'neato' 291 | else: 292 | prog = 'dot' 293 | 294 | node_layout, edge_layout = graph.pydot_layout( 295 | prog=prog, node_pos=node_pos, dot_options=dot_options) 296 | return node_layout, edge_layout 297 | 298 | def _to_graph_ds(self): 299 | graph = self.get_graph() 300 | if graph is None: 301 | return None, None 302 | 303 | ds_table = self.ds_table 304 | 305 | # nodes names and values 306 | nodes = graph.get_nodes() 307 | if ds_table is not None and self.is_data_layout_enabled and len(ds_table.selected.indices) > 0: 308 | row_idx = ds_table.selected.indices[0] 309 | values = [self.format_value(n, ds_table.data[n][row_idx]) for n in nodes] 310 | else: 311 | values = [''] * len(nodes) 312 | 313 | # nodes and edges layout 314 | node_layout, edge_layout = self._get_node_edge_layout() 315 | 316 | # create node datasource 317 | ds_node = M.ColumnDataSource(data=dict( 318 | x=[node_layout[n]['x'] for n in nodes], 319 | y=[node_layout[n]['y'] for n in nodes], 320 | width=[node_layout[n]['width'] for n in nodes], 321 | height=[node_layout[n]['height'] for n in nodes], 322 | shape=[self.node_shape(n) for n in nodes], 323 | node=nodes, 324 | value=values, 325 | line_width=[1] * len(nodes), 326 | )) 327 | 328 | # create edge datasource 329 | edges = {k: [] for k in [ 330 | 'start', 'end', 'xs', 'ys', 331 | 'arrow_x_start', 'arrow_y_start', 'arrow_x_end', 'arrow_y_end', 332 | 'dash', 333 | ]} 334 | for (s, e), v in edge_layout.items(): 335 | xs, ys = utils.smooth_line(v['x'], v['y']) 336 | edges['start'].append(s) 337 | edges['end'].append(e) 338 | edges['xs'].append(xs) 339 | edges['ys'].append(ys) 340 | edges['arrow_x_start'].append(xs[-2]) 341 | edges['arrow_y_start'].append(ys[-2]) 342 | edges['arrow_x_end'].append(xs[-1]) 343 | 
edges['arrow_y_end'].append(ys[-1]) 344 | edges['dash'].append('solid') # default style 345 | ds_edge = M.ColumnDataSource(data=edges) 346 | 347 | def on_data_table_row_change(attr, old_value, new_value): 348 | logger.debug(f'on_data_table_row_change, new_value={new_value}') 349 | assert isinstance(new_value, (list, tuple)) and len(new_value) > 0 350 | 351 | if len(new_value) > 0: 352 | idx = new_value[0] 353 | nodes_ = ds_node.data['node'] 354 | new_values = [self.format_value(n, ds_table.data[n][idx]) for n in nodes_] 355 | ds_node.data['value'] = new_values 356 | 357 | if ds_table is not None and self.is_data_layout_enabled and self.is_py_callback_enabled: 358 | ds_table.selected.on_change('indices', on_data_table_row_change) 359 | 360 | return ds_node, ds_edge 361 | 362 | def get_main_layout(self): 363 | if self.data is not None and self.is_data_layout_enabled: 364 | table_height = 250 365 | data_layout = self.get_data_layout(height=table_height) 366 | graph_layout = self.get_graph_layout() 367 | return column(graph_layout, data_layout) 368 | else: 369 | graph_layout = self.get_graph_layout() 370 | return graph_layout 371 | 372 | def get_graph_layout(self): 373 | # main figure 374 | fig = figure(title='Graph', 375 | toolbar_location='above', 376 | tools=[], 377 | outline_line_color='lightgray', 378 | sizing_mode='stretch_both', 379 | match_aspect=True, 380 | name=ViewItemNames.main_graph, 381 | ) 382 | # fig.toolbar.logo = None # disable bokeh logo 383 | fig.axis.visible = False 384 | 385 | def on_change_inner_width(attr, old_value, new_value): 386 | print(attr, old_value, new_value) 387 | try: 388 | start = fig.x_range.start 389 | fig.x_range.end = start + new_value / 1.5 390 | except: 391 | pass 392 | 393 | def on_change_inner_height(attr, old_value, new_value): 394 | print(attr, old_value, new_value) 395 | try: 396 | start = fig.y_range.start 397 | fig.y_range.end = start + new_value / 1.5 398 | except: 399 | pass 400 | 401 | fig.on_change('inner_width', 
    def _draw_graph_node(self, fig):
        """
        draw nodes shape

        Discrete (category) nodes render as rectangles, continuous ones as
        ellipses; the split is decided client-side by filtering the 'shape'
        column of the node datasource.
        """
        ds_node = self.ds_node

        # define data views
        # NOTE(review): the filters run in the browser (CustomJSFilter), so
        # the rect/ellipse split stays correct without a server round-trip
        js_filter_ellipse = M.CustomJSFilter(code='''
        const data = source.data['shape'];
        const indices = data.map(v => v=="ellipse");
        return indices;
        ''')
        js_filter_rect = M.CustomJSFilter(code='''
        const data = source.data['shape'];
        const indices = data.map(v => v!="ellipse");
        return indices;
        ''')
        vw_node_ellipse = M.CDSView(filter=js_filter_ellipse)
        vw_node_rect = M.CDSView(filter=js_filter_rect)

        # shared glyph styling; 'width'/'height'/'line_width' are per-node
        # columns of ds_node, the colors are fixed
        shape_options = dict(
            width='width', height='height',
            line_width='line_width',
            fill_color="#F3C797",
            line_color='#B7472A',
            # selection_color='gray',
            # set visual properties for non-selected glyphs
            nonselection_fill_alpha=0.6,
            # nonselection_fill_color="lightgray",
            # nonselection_line_color="firebrick",
            # nonselection_line_alpha=1.0,
            name=ViewItemNames.graph_node,
        )

        fig.rect(source=ds_node, view=vw_node_rect, **shape_options)
        fig.ellipse(source=ds_node, view=vw_node_ellipse, **shape_options)

        return fig
    def _draw_graph_line(self, fig):
        """
        draw edges

        Edges render as dashed multi-lines plus an open arrow head per edge;
        the arrow coordinates come from the 'arrow_*' columns prepared in
        ``_to_graph_ds`` (last two points of each smoothed path).
        """
        ds_edge = self.ds_edge

        fig.multi_line(source=ds_edge, line_color='#B7472A', line_dash='dash', )
        # arrow head drawn between the last two points of each edge path
        oh = M.OpenHead(line_color='#B7472A', line_width=1, size=10)
        arr = M.Arrow(end=oh,
                      x_start='arrow_x_start', y_start='arrow_y_start',
                      x_end='arrow_x_end', y_end='arrow_y_end',
                      line_color='#B7472A', line_width=1,
                      source=ds_edge,
                      # level='underlay',
                      )
        fig.add_layout(arr)

        return fig
    def _to_graph_ds(self):
        """
        Build the node/edge datasources, then enrich the edge datasource with
        causation-specific columns:

        * ``weight``/``expert`` — copied from the graph's edge attributes
          (0.0/0 when the rendered edge is absent from the graph)
        * ``marker``/``marker_x``/``marker_y`` — index and coordinates of a
          moving dot along each edge path (animated by the periodic callback
          installed in ``_draw_graph_line``)
        """
        ds_node, ds_edge = super()._to_graph_ds()

        # add edge data of weight & expert from graph
        weight = []
        expert = []
        marker_x = []
        marker_y = []
        marker = []
        graph = self.get_graph()
        for start, end, xs, ys in zip(ds_edge.data['start'], ds_edge.data['end'],
                                      ds_edge.data['xs'], ds_edge.data['ys']):
            if graph.has_edge(start, end):
                edge_data = graph[start][end]
                weight.append(edge_data['weight'])
                expert.append(edge_data['expert'])
            else:
                weight.append(0.0)
                expert.append(0)
            # every edge starts its marker at the first point of its path
            marker.append(0)
            marker_x.append(xs[0])
            marker_y.append(ys[0])

        ds_edge.data['weight'] = weight
        ds_edge.data['expert'] = expert
        ds_edge.data['marker'] = marker
        ds_edge.data['marker_x'] = marker_x
        ds_edge.data['marker_y'] = marker_y

        return ds_node, ds_edge
width='weight') 599 | oh = M.OpenHead(line_color='#B7472A', line_width=1, size=10) 600 | arr = M.Arrow(end=oh, 601 | x_start='arrow_x_start', y_start='arrow_y_start', 602 | x_end='arrow_x_end', y_end='arrow_y_end', 603 | line_color='#B7472A', line_width=1, 604 | source=ds_edge, 605 | # level='underlay', 606 | ) 607 | fig.add_layout(arr) 608 | 609 | fig.circle(x='marker_x', y='marker_y', radius=2, 610 | line_color='#B7472A', fill_color='#B7472A', 611 | source=ds_edge) 612 | 613 | def update_marker_xy(): 614 | data = ds_edge.data 615 | marker = [] 616 | marker_x = [] 617 | marker_y = [] 618 | for i, (m, xs, ys) in enumerate(zip(data['marker'], data['xs'], data['ys'])): 619 | m_new = m + 1 if (m + 3) < len(xs) else 0 620 | marker.append(m_new) 621 | marker_x.append(xs[m_new]) 622 | marker_y.append(ys[m_new]) 623 | data['marker'] = marker 624 | data['marker_x'] = marker_x 625 | data['marker_y'] = marker_y 626 | 627 | def setup_cb(): 628 | if CausationView.doc_cb_move_marker is not None: 629 | try: 630 | curdoc().remove_periodic_callback(CausationView.doc_cb_move_marker) 631 | except: 632 | pass 633 | CausationView.doc_cb_move_marker = None 634 | cb = curdoc().add_periodic_callback(update_marker_xy, 200) 635 | CausationView.doc_cb_move_marker = cb 636 | 637 | curdoc().add_timeout_callback(setup_cb, 2000) 638 | return fig 639 | 640 | def _config_graph(self, fig): 641 | super()._config_graph(fig) 642 | 643 | node_renders = fig.select({'name': ViewItemNames.graph_node}) 644 | if node_renders is None: 645 | return fig 646 | 647 | # add PointDrawTool to drag-drop nodes 648 | draw_tool = M.PointDrawTool(renderers=node_renders, drag=True, add=False, ) 649 | fig.add_tools(draw_tool) 650 | 651 | # fig.toolbar.active_tap = draw_tool 652 | 653 | def on_data_change(attr, old_value, new_value): 654 | # find_moved_nodes 655 | eps = 0.01 656 | nodes_moved = [] 657 | for node, xnew, ynew, xold, yold in \ 658 | zip(old_value['node'], old_value['x'], old_value['y'], new_value['x'], 
new_value['y']): 659 | if abs(xnew - xold) > eps or abs(ynew - yold) > eps: 660 | nodes_moved.append(node) 661 | 662 | if len(nodes_moved) == 0: 663 | return 664 | 665 | # update_graph_layout 666 | logger.info(f'found moved nodes: {nodes_moved} , call update_graph_layout') 667 | node_pos = {n: (x, y) 668 | for n, x, y in zip(new_value['node'], new_value['x'], new_value['y'])} 669 | self.update_graph_layout( 670 | graph=self.get_graph(), prog='neato', node_pos=node_pos, 671 | dot_options=dict(splines='spline')) 672 | 673 | self.ds_node.on_change('data', on_data_change) 674 | 675 | return fig 676 | 677 | def get_side_layout(self, main_layout, *, width): 678 | from causallab.discovery import discoverers 679 | 680 | ce_sep = ' > ' 681 | 682 | def ce_value(cause, effect): 683 | return f'{cause}{ce_sep}{effect}' 684 | 685 | def parse_ce_value(v): 686 | assert isinstance(v, str) and v.find(ce_sep) > 0 687 | i = v.find(ce_sep) 688 | cause = v[:i] 689 | effect = v[i + len(ce_sep):] 690 | return cause, effect 691 | 692 | causation, ds_node, ds_edge = self.causation, self.ds_node, self.ds_edge 693 | 694 | algs = list(discoverers.keys()) 695 | nodes = self.ds_node.data['node'] 696 | 697 | choice_algs = M.MultiChoice( 698 | options=algs, 699 | value=algs[:1], 700 | # max_items=3, 701 | placeholder='click to select', 702 | sizing_mode='stretch_width', 703 | ) 704 | container_algs = column(M.Div(text='')) 705 | btn_discovery = M.Button( 706 | label='Run', 707 | button_type="primary", 708 | ) 709 | container_progress = column(M.Div(text='')) 710 | spinner_threshold = M.Spinner( 711 | title='threshold of discovery', 712 | value=1, step=1, 713 | low=1, high=max(len(causation.matrices), 1), 714 | ) 715 | select_cause = M.Select( 716 | title='cause', 717 | value=nodes[0], 718 | options=nodes, 719 | sizing_mode='stretch_width', 720 | ) 721 | select_effect = M.Select( 722 | title='effect', 723 | value=nodes[1], 724 | options=nodes, 725 | sizing_mode='stretch_width', 726 | ) 727 | 
btn_add_cause_effect = M.Button( 728 | label='Add', 729 | button_type="success", 730 | align='end', 731 | width=50, sizing_mode='fixed', 732 | ) 733 | select_enabled = M.MultiSelect( 734 | options=[ce_value(c, e) for c, e in causation.enabled], 735 | sizing_mode='stretch_width', 736 | ) 737 | btn_remove_enabled = M.Button( 738 | label='', 739 | icon=M.BuiltinIcon('x', size="1.0em", color="white"), 740 | button_type="success", 741 | align='center', 742 | width=50, sizing_mode='fixed', 743 | ) 744 | select_selection = M.MultiSelect( 745 | title='Selected:', 746 | description='click on graph to select causal relations', 747 | sizing_mode='stretch_width' 748 | ) 749 | btn_add_disabled = M.Button( 750 | label='Del', 751 | button_type="success", 752 | align='center', 753 | width=50, sizing_mode='fixed', 754 | ) 755 | select_disabled = M.MultiSelect( 756 | options=[ce_value(c, e) for c, e in causation.disabled], 757 | sizing_mode='stretch_width', 758 | ) 759 | btn_remove_disabled = M.Button( 760 | label='', 761 | icon=M.BuiltinIcon('x', size="1.0em", color="white"), 762 | button_type="success", 763 | align='center', 764 | width=50, sizing_mode='fixed', 765 | ) 766 | widgets = [ 767 | M.Div(text='

    Causation


    '), 768 | M.Div(text='» Discovery:'), 769 | choice_algs, 770 | # container_algs, 771 | btn_discovery, 772 | # container_progress, 773 | # M.Paragraph(text=''), 774 | M.Div(text='

    » Edit

    '), 775 | spinner_threshold, 776 | M.Div(text='Add causal relationship:'), 777 | row(select_cause, select_effect, btn_add_cause_effect, 778 | width=width - 10, sizing_mode='stretch_width'), 779 | M.Div(text='Added:'), 780 | row(select_enabled, btn_remove_enabled, sizing_mode='stretch_width'), 781 | M.Div(text='Remove causal relationship:'), 782 | # M.Div(text='Selected:'), 783 | row(select_selection, btn_add_disabled, sizing_mode='stretch_width'), 784 | M.Div(text='Removed:'), 785 | row(select_disabled, btn_remove_disabled, sizing_mode='stretch_width'), 786 | ] 787 | 788 | def set_state_btn_add_cause_effect(cause=None, effect=None): 789 | if cause is None: 790 | cause = select_cause.value 791 | if effect is None: 792 | effect = select_effect.value 793 | 794 | btn_add_cause_effect.disabled = \ 795 | cause == effect \ 796 | or ce_value(cause, effect) in select_enabled.options 797 | 798 | def set_state_bth_remove_enabled(): 799 | btn_remove_enabled.disabled = len(select_enabled.value) == 0 800 | 801 | def set_state_bth_add_disabled(): 802 | btn_add_disabled.disabled = \ 803 | len(select_selection.value) == 0 \ 804 | or any(map(lambda c: c in select_disabled.options, select_selection.value)) 805 | 806 | def set_state_bth_remove_disabled(): 807 | btn_remove_disabled.disabled = len(select_disabled.value) == 0 808 | 809 | def on_change_threshold(attr, old_value, new_value): 810 | causation.threshold = spinner_threshold.value 811 | self.update_graph_layout() 812 | 813 | def on_change_cause(attr, old_value, new_value): 814 | set_state_btn_add_cause_effect(cause=new_value) 815 | 816 | def on_change_effect(attr, old_value, new_value): 817 | set_state_btn_add_cause_effect(effect=new_value) 818 | 819 | def on_change_enabled(attr, old_value, new_value): 820 | set_state_bth_remove_enabled() 821 | 822 | def on_change_disabled(attr, old_value, new_value): 823 | set_state_bth_remove_disabled() 824 | 825 | def on_change_selection(attr, old_value, new_value): 826 | 
set_state_bth_add_disabled() 827 | 828 | def on_node_change(attr, old_value, new_value): 829 | if layout.disabled: 830 | return 831 | 832 | if len(new_value) > 0: 833 | idx = new_value[0] 834 | node = nodes[idx] 835 | 836 | graph = self.causation.graph 837 | options = [ce_value(n, node) for n in graph.get_parents(node)] + \ 838 | [ce_value(node, n) for n in graph.get_children(node)] 839 | else: 840 | options = [] 841 | 842 | select_selection.options = options 843 | select_selection.value = [] 844 | set_state_bth_add_disabled() 845 | 846 | def on_edge_change(attr, old_value, new_value): 847 | if layout.disabled: 848 | return 849 | 850 | if len(new_value) > 0: 851 | idx = new_value[0] 852 | start = ds_edge.data['start'][idx] 853 | end = ds_edge.data['end'][idx] 854 | options = [ce_value(start, end)] 855 | else: 856 | options = [] 857 | 858 | select_selection.options = options 859 | select_selection.value = [] 860 | set_state_bth_add_disabled() 861 | 862 | def on_add_cause_effect_click(): 863 | causation.enable(select_cause.value, select_effect.value) 864 | select_enabled.options = [ce_value(c, e) for c, e in causation.enabled] 865 | select_disabled.options = [ce_value(c, e) for c, e in causation.disabled] 866 | 867 | set_state_btn_add_cause_effect() 868 | self.update_graph_layout() 869 | 870 | def on_remove_enabled_click(): 871 | for v in select_enabled.value: 872 | causation.remove_enabled(*parse_ce_value(v)) 873 | 874 | select_enabled.options = [ce_value(c, e) for c, e in causation.enabled] 875 | select_disabled.options = [ce_value(c, e) for c, e in causation.disabled] 876 | select_enabled.value = [] 877 | 878 | set_state_bth_remove_enabled() 879 | set_state_btn_add_cause_effect() 880 | self.update_graph_layout() 881 | 882 | def on_add_disabled_click(): 883 | for v in select_selection.value: 884 | causation.disable(*parse_ce_value(v)) 885 | 886 | select_enabled.options = [ce_value(c, e) for c, e in causation.enabled] 887 | select_disabled.options = [ce_value(c, e) 
for c, e in causation.disabled] 888 | 889 | set_state_bth_add_disabled() 890 | self.update_graph_layout() 891 | 892 | def on_remove_disabled_click(): 893 | for v in select_disabled.value: 894 | causation.remove_disabled(*parse_ce_value(v)) 895 | 896 | select_enabled.options = [ce_value(c, e) for c, e in causation.enabled] 897 | select_disabled.options = [ce_value(c, e) for c, e in causation.disabled] 898 | select_disabled.value = [] 899 | 900 | set_state_bth_add_disabled() 901 | set_state_bth_remove_disabled() 902 | self.update_graph_layout() 903 | 904 | def on_discovery_click(): 905 | matrices = {} 906 | done = [] 907 | node_pos = {n: (x, y) 908 | for n, x, y in zip(ds_node.data['node'], ds_node.data['x'], ds_node.data['y'])} 909 | graph_stub = deepcopy(self.get_graph()) 910 | graph_stub.remove_edges_from(graph_stub.get_edges()) 911 | 912 | def on_discovered(alg, matrix): 913 | matrices[alg] = matrix 914 | 915 | def on_discover_success(bn): 916 | done.append(1) 917 | 918 | def on_discover_error(e): 919 | done.append(0) 920 | 921 | def random_edge(): 922 | start, end = 0, 0 923 | while start == end or graph_stub.has_edge(nodes[start], nodes[end]): 924 | start, end = np.random.randint(low=0, high=len(nodes), size=2) 925 | return nodes[start], nodes[end] 926 | 927 | def update_layout(): 928 | found = matrices.copy() 929 | if len(found) > 0: 930 | matrices.clear() 931 | print('update layout with', found.keys()) 932 | for alg, matrix in found.items(): 933 | causation.add_matrix(alg, matrix) 934 | 935 | if len(done) > 0: 936 | curdoc().remove_periodic_callback(cb) 937 | btn_discovery.disabled = False 938 | self.update_graph_layout() 939 | else: 940 | edges = graph_stub.get_edges() 941 | if len(edges) * 2 > len(nodes): 942 | graph_stub.remove_edges_from(edges) 943 | edge_start, edge_end = random_edge() 944 | graph_stub.add_edge(edge_start, edge_end, weight=1.0, expert=0) 945 | self.update_graph_layout( 946 | graph=graph_stub, prog='neato', node_pos=node_pos, 947 | 
dot_options=dict(splines='spline')) 948 | 949 | cb = curdoc().add_periodic_callback(update_layout, 100) 950 | btn_discovery.disabled = True 951 | causation.matrices.clear() 952 | spinner_threshold.value = 1 953 | spinner_threshold.high = len(choice_algs.value) 954 | utils.trun(_discovery, 955 | args=[self.data, choice_algs.value], 956 | kwargs=dict( 957 | callback=on_discovered, 958 | ), 959 | on_success=on_discover_success, 960 | on_error=on_discover_error, 961 | ) 962 | 963 | # initialize button state 964 | set_state_btn_add_cause_effect() 965 | set_state_bth_remove_enabled() 966 | set_state_bth_add_disabled() 967 | set_state_bth_remove_disabled() 968 | 969 | # bind event handlers 970 | spinner_threshold.on_change('value', on_change_threshold) 971 | select_cause.on_change('value', on_change_cause) 972 | select_effect.on_change('value', on_change_effect) 973 | select_enabled.on_change('value', on_change_enabled) 974 | select_selection.on_change('value', on_change_selection) 975 | select_disabled.on_change('value', on_change_disabled) 976 | 977 | ds_node.selected.on_change('indices', on_node_change) 978 | ds_edge.selected.on_change('indices', on_edge_change) 979 | 980 | btn_add_cause_effect.on_click(on_add_cause_effect_click) 981 | btn_remove_enabled.on_click(on_remove_enabled_click) 982 | btn_add_disabled.on_click(on_add_disabled_click) 983 | btn_remove_disabled.on_click(on_remove_disabled_click) 984 | btn_discovery.on_click(on_discovery_click) 985 | 986 | # return 987 | layout = column(widgets, width=width) 988 | return layout 989 | 990 | def update_graph_layout(self, graph=None, prog=None, node_pos=None, dot_options=None): 991 | causation, ds_table, ds_node, ds_edge = self.causation, self.ds_table, self.ds_node, self.ds_edge 992 | if graph is None: 993 | graph = causation.graph 994 | node_layout, edge_layout = self._get_node_edge_layout( 995 | graph, prog=prog, node_pos=node_pos, dot_options=dot_options) 996 | 997 | nodes = graph.get_nodes() 998 | if 
    def update_graph_layout(self, graph=None, prog=None, node_pos=None, dot_options=None):
        """Recompute the node/edge layout and push it into the datasources.

        graph: defaults to the causation graph; ``prog``/``node_pos``/
        ``dot_options`` are forwarded to the layout engine
        (``_get_node_edge_layout``).
        """
        causation, ds_table, ds_node, ds_edge = self.causation, self.ds_table, self.ds_node, self.ds_edge
        if graph is None:
            graph = causation.graph
        node_layout, edge_layout = self._get_node_edge_layout(
            graph, prog=prog, node_pos=node_pos, dot_options=dot_options)

        nodes = graph.get_nodes()
        # show values of the currently selected table row (if any) on the nodes
        if len(ds_table.selected.indices) > 0:
            row_idx = ds_table.selected.indices[0]
            values = [self.format_value(n, ds_table.data[n][row_idx]) for n in nodes]
        else:
            values = [''] * len(nodes)

        new_node_data = dict(
            x=[node_layout[n]['x'] for n in nodes],
            y=[node_layout[n]['y'] for n in nodes],
            width=[node_layout[n]['width'] for n in nodes],
            height=[node_layout[n]['height'] for n in nodes],
            shape=[self.node_shape(n) for n in nodes],
            node=nodes,
            value=values,
        )

        new_edge_data = {k: [] for k in [
            'start', 'end', 'xs', 'ys',
            'arrow_x_start', 'arrow_y_start', 'arrow_x_end', 'arrow_y_end',
            'weight', 'expert',
            'marker', 'marker_x', 'marker_y',
        ]}
        for (s, e), v in edge_layout.items():
            # smooth the raw edge path, then derive arrow-head (last segment)
            # and animation-marker (first point) positions from it
            xs, ys = utils.smooth_line(v['x'], v['y'])
            edge_data = graph[s][e]
            new_edge_data['start'].append(s)
            new_edge_data['end'].append(e)
            new_edge_data['xs'].append(xs)
            new_edge_data['ys'].append(ys)
            new_edge_data['arrow_x_start'].append(xs[-2])
            new_edge_data['arrow_y_start'].append(ys[-2])
            new_edge_data['arrow_x_end'].append(xs[-1])
            new_edge_data['arrow_y_end'].append(ys[-1])
            new_edge_data['weight'].append(edge_data['weight'])
            new_edge_data['expert'].append(edge_data['expert'])
            new_edge_data['marker'].append(0)
            new_edge_data['marker_x'].append(xs[0])
            new_edge_data['marker_y'].append(ys[0])

        self.patch_ds(ds_node, new_node_data)
        self.patch_ds(ds_edge, new_edge_data)

    def patch_ds(self, ds, new_data):
        """Update a ColumnDataSource from ``new_data``: column-by-column when
        the row count is unchanged, otherwise replace ``.data`` wholesale so
        all columns keep consistent lengths."""
        assert isinstance(ds, M.ColumnDataSource)
        assert isinstance(new_data, dict)

        n_old = len(next(iter(ds.data.values())))
        n_new = len(next(iter(new_data.values())))

        if n_old == n_new:
            for k, v in new_data.items():
                ds.data[k] = v
        else:
            ds.data = pd.DataFrame(new_data)
pd.DataFrame(new_data) 1052 | 1053 | @property 1054 | def is_data_layout_enabled(self): 1055 | # return False 1056 | return True 1057 | 1058 | 1059 | class BNPlotView(GraphPlotView): 1060 | """ 1061 | PlotView with BayesianNetwork 1062 | """ 1063 | 1064 | def __init__(self, *, bn=None, data=None): 1065 | assert bn is not None or data is not None 1066 | 1067 | self.bn = bn 1068 | 1069 | if bn is not None: 1070 | super().__init__(data=data, node_states=bn.state_) 1071 | else: 1072 | super().__init__(data=data, node_states=DataLoader.state_of(data)) 1073 | 1074 | def get_graph(self): 1075 | return self.bn.graph if self.bn is not None else None 1076 | 1077 | def _to_graph_ds(self): 1078 | ds_node, ds_edge = super()._to_graph_ds() 1079 | if ds_node is None or ds_edge is None: 1080 | return ds_node, ds_edge 1081 | 1082 | bn = self.bn 1083 | 1084 | # nodes names and values 1085 | # nodes = graph.get_nodes(True) 1086 | nodes = ds_node.data['node'] 1087 | # values, predicted_values = self._get_node_values(nodes, ds_table=ds_table, idx=None) 1088 | 1089 | # node states 1090 | state_html = [self.state_to_html(n) for n in nodes] 1091 | ds_node.data['state'] = state_html 1092 | 1093 | # module 1094 | module = [bn.model_.get_node_function_cls(n).__name__ if bn is not None else '' 1095 | for n in nodes] 1096 | ds_node.data['module'] = module 1097 | 1098 | # upstream 1099 | upstream_html, upstream_num = nmap(self.parent_to_html, nodes) 1100 | ds_node.data['upstream'] = upstream_html 1101 | ds_node.data['upstream_n'] = upstream_num 1102 | 1103 | # fitted params 1104 | params = self.get_node_params() 1105 | params_html = [self.params_to_html(params[n]) for n in nodes] 1106 | ds_node.data['params'] = params_html 1107 | 1108 | # intervened 1109 | interventions = self.get_bn_interventions() 1110 | intervention_values = [self.format_value(n, interventions.get(n)) for n in nodes] 1111 | ds_node.data['intervention'] = intervention_values 1112 | ds_node.data['line_width'] = 
[self._line_width(n in interventions.keys()) for n in nodes] 1113 | ds_edge.data['dash'] = [ 1114 | self._line_dash(e in interventions.keys()) for e in ds_edge.data['end'] 1115 | ] 1116 | 1117 | def on_node_data_change(attr, old_value, new_value): 1118 | if logger.is_debug_enabled(): 1119 | logger.debug(f'on_node_data_change, new_value={new_value}') 1120 | 1121 | # update edge datasource 'dash' when node is intervened 1122 | intervened_new = dict(zip(nodes, new_value['intervention'])) 1123 | ends = ds_edge.data['end'] 1124 | dash_old = ds_edge.data['dash'] 1125 | dash_new = [self._line_dash(intervened_new[e] != '') for e in ends] 1126 | dash_patches = [(i, dn) for i, (do, dn) in enumerate(zip(dash_old, dash_new)) if do != dn] 1127 | if len(dash_patches) > 0: 1128 | ds_edge.patch(patches=dict(dash=dash_patches)) 1129 | 1130 | if self.is_py_callback_enabled: 1131 | ds_node.on_change('data', on_node_data_change) 1132 | 1133 | return ds_node, ds_edge 1134 | 1135 | def get_node_params(self): 1136 | bn = self.bn 1137 | params = defaultdict(OrderedDict) 1138 | if bn is not None and bn._is_fitted: 1139 | for k, v in bn.fitted_params.items(): 1140 | if isinstance(v, torch.Tensor): 1141 | v = v.numpy() 1142 | ks = k.split('__') 1143 | node, param_name = ks[0], '.'.join(ks[1:]) 1144 | params[node][param_name] = v 1145 | return params 1146 | 1147 | def get_bn_interventions(self): 1148 | bn = self.bn 1149 | 1150 | if bn is not None and bn._is_fitted: 1151 | return bn.interventions 1152 | else: 1153 | return {} 1154 | 1155 | @staticmethod 1156 | def params_to_html(params_dict): 1157 | r = ['
    '] 1158 | for k, v in params_dict.items(): 1159 | r.append(f' » {k}:') 1160 | if isinstance(v, np.ndarray) and v.ndim == 2: 1161 | r.append(BNPlotView.decorate_pd_html(pd.DataFrame(v)._repr_html_())) 1162 | elif isinstance(v, np.ndarray) and v.ndim == 1: 1163 | r.append(f'

    ') 1164 | r.append('[ ' + ', '.join(map(lambda vi: f'{vi:.6f}', v.tolist())) + ' ]') 1165 | r.append(f'

    ') 1166 | else: # scalar 1167 | r.append(f'

    ') 1168 | r.append(f'{v:.6f}') 1169 | r.append(f'

    ') 1170 | r.append(' 0: 1176 | html = '

    ' + ', '.join(parents) + '

    ' 1177 | else: 1178 | html = '

    <None>

    ' 1179 | return html, len(parents) 1180 | 1181 | def summary_html(self): 1182 | graph = self.get_graph() 1183 | params = self.get_node_params() 1184 | 1185 | n_params = np.sum([len(nps) for nps in params.values()]) 1186 | n_elements = np.sum([np.prod(v.shape) for nps in params.values() for v in nps.values()]) 1187 | 1188 | html = ['Graph summary'] 1189 | html.extend([ 1190 | '
      ', 1191 | f'
    • Node number: {len(graph.get_nodes())}
    • ', 1192 | f'
    • Edge number: {len(graph.get_edges())}
    • ', 1193 | f'
    • Param number: {int(n_params)}
    • ', 1194 | f'
    • Param elements: {int(n_elements)}
    • ', 1195 | '
    ', 1196 | ]) 1197 | 1198 | return '\n'.join(html) 1199 | 1200 | @staticmethod 1201 | def _line_dash(intervened): 1202 | return 'dashed' if intervened else 'solid' 1203 | 1204 | @staticmethod 1205 | def _line_width(intervened): 1206 | return 2 if intervened else 1 1207 | 1208 | def _get_table_columns(self): 1209 | _fmt = self.default_column_formatter 1210 | nodes = self.ds_node.data['node'] 1211 | df = self.data 1212 | ds_columns = self.ds_table.data.keys() 1213 | table_columns = [] 1214 | for node in nodes: 1215 | if node in ds_columns: 1216 | table_columns.append(M.TableColumn( 1217 | field=node, 1218 | formatter=_fmt(df, node), 1219 | )) 1220 | return table_columns 1221 | 1222 | def get_table_layout(self, **kwargs): 1223 | table_columns = self._get_table_columns() 1224 | table = M.DataTable(columns=table_columns, source=self.ds_table, editable=False, **kwargs) 1225 | title = M.Div(text=f"{self.data_title}") 1226 | 1227 | layout = column(title, table) 1228 | return layout 1229 | 1230 | def _draw_graph_text(self, fig): 1231 | """ 1232 | draw node text 1233 | """ 1234 | ds_node = self.ds_node 1235 | 1236 | js_filter_intervened = M.CustomJSFilter(code=''' 1237 | const intervention = source.data['intervention']; 1238 | const indices = intervention.map(v => v.length>0); 1239 | return indices; 1240 | ''') 1241 | js_filter_not_intervened = M.CustomJSFilter(code=''' 1242 | const intervention = source.data['intervention']; 1243 | const indices = intervention.map( (v,i) => v.length==0); 1244 | return indices; 1245 | ''') 1246 | vw_node_intervened = M.CDSView(filter=js_filter_intervened) 1247 | vw_node_not_intervened = M.CDSView(filter=js_filter_not_intervened) 1248 | 1249 | text_options = dict( 1250 | text_align='center', 1251 | text_baseline='middle', 1252 | ) 1253 | 1254 | # 2 elements for vw_node_not_intervened 1255 | fig.text(text='node', source=ds_node, view=vw_node_not_intervened, 1256 | y_offset=-10, # text_font_style='bold', 1257 | **text_options) 1258 | 
class BNTrainingView(BNPlotView):
    """PlotView for fitting a Bayesian network via SVI on the causation graph."""

    def __init__(self, *, bn=None, data=None, causation=None):
        # causation provides the (user-edited) graph to train on
        self.causation = causation

        super().__init__(bn=bn, data=data)

    def get_graph(self):
        # train on the causation-editor graph, not the (possibly absent) bn graph
        return self.causation.graph

    def get_side_layout(self, main_layout, *, width):
        """Side panel: training hyper-parameters, a Fit button, and a live
        loss plot streamed from a background fitting thread."""
        graph = self.get_graph()
        if not graph.is_dag:
            # FIXME(review): the original warning markup was lost in source
            # extraction — restore the exact string from version control.
            msg = M.Div(text='Warning'
                             'The graph is not a valid DAG.')
            return column([msg], width=width)

        # slider_epochs = M.Slider(
        #     title='epochs',
        #     value=100, start=1, end=1000, step=1,
        #     sizing_mode='stretch_width',
        # )
        # slider_lr = M.Slider(
        #     title='learning_rate',
        #     value=0.01, start=0.001, end=0.3, step=0.001, format='.000',
        #     sizing_mode='stretch_width',
        # )
        num_epochs = M.NumericInput(
            title='epochs:',
            mode='int',
            value=100, low=1, high=1000,
            sizing_mode='stretch_width',
        )
        num_lr = M.NumericInput(
            title='learning_rate:',
            mode='float',
            value=0.01, low=0.001, high=0.5,
            sizing_mode='stretch_width',
        )
        choice_loss = M.Select(
            title='loss:',
            options=['ELBO', 'CausalEffect_ELBO'],
            sizing_mode='stretch_width',
        )
        btn_fit = M.Button(
            label='Fit',
            button_type="primary",
        )
        container_progress = column(M.Div(text=''), sizing_mode='stretch_width')
        widgets = [
            M.Div(text='Settings:'),
            # slider_epochs,
            # slider_lr,
            num_epochs,
            num_lr,
            choice_loss,
            M.Paragraph(text=''),
            btn_fit,
            container_progress
        ]

        def get_fitting_plot(epochs):
            # create a fresh loss figure; it is streamed into during fitting
            source = M.ColumnDataSource(data=dict(i=[], loss=[]))
            fig = figure(title='loss:',
                         toolbar_location=None,
                         # tools="pan,tap",
                         outline_line_color='lightgray',
                         x_range=(0, epochs),
                         height=width,
                         sizing_mode='stretch_width',
                         )
            container_progress.children = [fig, ]

            fig.line(x='i', y='loss', source=source)
            return fig, source

        def on_btn_click():
            fig, ds = get_fitting_plot(num_epochs.value)
            losses = []  # (epoch, loss) pairs queued by the worker thread
            done = []    # non-empty once fitting finished (1=success, 0=error)

            def on_fitting(i, loss):
                # called from the worker thread per epoch
                losses.append([i, loss])

            def update_fitting():
                # periodic (UI thread): flush queued losses into the plot
                new_data = losses.copy()
                losses.clear()

                if len(new_data) > 0:
                    i, loss = zip(*new_data)
                    patches = dict(
                        i=i,
                        loss=loss,
                    )
                    ds.stream(patches)

                if done:
                    # fitting finished: stop polling and re-enable the button
                    curdoc().remove_periodic_callback(cb)
                    btn_fit.disabled = False

            def on_fit_success(bn):
                done.append(1)
                self.bn = bn
                print(bn)

            def on_fit_error(e):
                done.append(0)
                print(e)

            cb = curdoc().add_periodic_callback(update_fitting, 100)
            print(cb)
            btn_fit.disabled = True

            # fit a fresh model in a background thread (utils.trun)
            bn = SviBayesianNetwork(graph)
            utils.trun(lambda: bn.fit(self.data,
                                      epochs=num_epochs.value,
                                      lr=num_lr.value,
                                      celoss=choice_loss.value == 'CausalEffect_ELBO',
                                      inplace=False,
                                      verbose=on_fitting,
                                      random_state=123,
                                      ),
                       on_success=on_fit_success,
                       on_error=on_fit_error,
                       )

        btn_fit.on_click(on_btn_click)

        layout = column(widgets, width=width)

        return layout
class FittedBNPlotView(BNPlotView):
    """
    PlotView with fitted BayesianNetwork
    """

    def get_layout(self):
        # only render the full layout when a fitted model exists
        bn = self.bn
        if bn is not None and bn._is_fitted:
            return super().get_layout()
        else:
            # FIXME(review): original warning markup lost in extraction —
            # restore the exact string from version control.
            msg = M.Div(text='Warning'
                             'The model is not available.')
            return column([msg])


class BNPropertyView(FittedBNPlotView):
    """Read-only property panel: graph summary plus per-node
    state/module/params shown when a node is selected."""

    def get_side_layout(self, main_layout, *, width):
        summary = M.Div(text=self.summary_html())

        header = M.Div(text='Node name:')
        state = M.Div(text='...')
        module = M.Div(text='Module:')

        params_content = M.Div(text='...')  # style={'background': '#dddddd'})

        node_properties = column([
            header,
            state,
            module,
            M.Div(text='''Parameters:'''),
            params_content
        ], visible=False)

        ds_node = self.ds_node
        # JS-side toggle: node selected -> show its properties, else the summary.
        # FIXME(review): HTML fragments inside this JS string were garbled in
        # extraction; verify it against version control.
        js_on_node_change = M.CustomJS(
            args=dict(ds_node=ds_node,
                      div_summary=summary, div_props=node_properties,
                      div_header=header, div_state=state,
                      div_module=module, div_params=params_content),
            code='''
        const node_selected = cb_obj.indices.length > 0;
        div_props.visible = node_selected;
        div_summary.visible = !node_selected;

        if(node_selected){
            const idx = cb_obj.indices[0];
            const data = ds_node.data;
            div_header.text = ' Node '+ data['node'][idx] +':';
            div_module.text = ' Module: » '+ data['module'][idx];
            div_state.text = data['state'][idx];
            div_params.text = data['params'][idx];
        }
        '''.strip()
        )
        ds_node.selected.js_on_change('indices', js_on_node_change)

        layout = column(summary, node_properties, width=width)
        return layout


class BNPredictionView(FittedBNPlotView):
    """Prediction view: adds per-node predicted-value table columns and
    prediction-aware node rendering."""

    PPV = '_pred_of_'  # prefix of predicted values

    def pred_name(self, node):
        # name of the table column holding the prediction for `node`
        return f'{self.PPV}{node}'

    def _to_table_ds(self, df):
        # graph = bn.graph
        # nodes = graph.get_nodes()
        # outcome_nodes = filter(lambda n: len(graph.get_parents(n)) > 0, nodes)

        # add an (initially NaN) prediction column per data column
        df = df.copy()
        for n in df.columns.tolist():  # outcome_nodes:
            df[self.pred_name(n)] = np.nan
        source = M.ColumnDataSource(df)
        return source

    def _to_graph_ds(self):
        """Extend the node datasource with a 'predictive' column kept in sync
        with the table's prediction columns."""
        ds_table = self.ds_table
        ds_node, ds_edge = super()._to_graph_ds()

        if ds_node is None or ds_edge is None:
            return ds_node, ds_edge

        # append data item: 'predictive'
        nodes = ds_node.data['node']
        ds_node.data['predictive'] = [''] * len(nodes)

        def on_data_table_change(attr, old_value, new_value):
            logger.debug(f'on_data_table_change, new_value={new_value}')

            # update node 'predictive' when predicted value is ready
            idx = ds_table.selected.indices[0] if len(ds_table.selected.indices) > 0 else 0
            _, predicted_values_new = self._get_node_values(nodes, ds_table=new_value, idx=idx)
            predicted_values_old = ds_node.data['predictive']
            value_pairs = zip(predicted_values_old, predicted_values_new)
            data_patches = [(i, vn) for i, (vo, vn) in enumerate(value_pairs) if vo != vn]
            if len(data_patches) > 0:
                ds_node.patch(patches=dict(predictive=data_patches))

        if ds_table is not None and self.is_py_callback_enabled:
            self.ds_table.on_change('data', on_data_table_change)

        return ds_node, ds_edge
on_data_table_change) 1512 | 1513 | return ds_node, ds_edge 1514 | 1515 | def _get_table_columns(self): 1516 | _fmt = self.default_column_formatter 1517 | 1518 | nodes = self.ds_node.data['node'] 1519 | df = self.data 1520 | ds_columns = self.ds_table.data.keys() 1521 | table_columns = [] 1522 | for node in nodes: 1523 | if node in ds_columns: 1524 | table_columns.append(M.TableColumn( 1525 | field=node, 1526 | formatter=_fmt(df, node), 1527 | )) 1528 | if self.pred_name(node) in ds_columns: 1529 | table_columns.append(M.TableColumn( 1530 | field=self.pred_name(node), 1531 | title=f'PredOf_{node}', 1532 | formatter=_fmt(df, node, font_style='italic', text_color='blue'), 1533 | visible=False, 1534 | )) 1535 | return table_columns 1536 | 1537 | def _get_node_values(self, nodes, ds_table=None, idx=None): 1538 | assert ds_table is None or isinstance(ds_table, (M.DataSource, dict)) 1539 | 1540 | if ds_table is not None: 1541 | if isinstance(ds_table, M.DataSource): 1542 | table_data = ds_table.data 1543 | else: 1544 | table_data = ds_table 1545 | if idx is None: 1546 | if len(ds_table.selected.indices) > 0: 1547 | idx = ds_table.selected.indices[0] 1548 | else: 1549 | idx = 0 # use 1st line 1550 | values = [self.format_value(n, table_data[n][idx]) for n in nodes] 1551 | pred_names = map(self.pred_name, nodes) 1552 | predicted_values = [ 1553 | self.format_value(n, table_data[pn][idx]) if pn in table_data.keys() else '' 1554 | for n, pn in zip(nodes, pred_names) 1555 | ] 1556 | else: 1557 | values = [''] * len(nodes) 1558 | predicted_values = [''] * len(nodes) 1559 | 1560 | return values, predicted_values 1561 | 1562 | def _draw_graph_text(self, fig): 1563 | """ 1564 | draw node text 1565 | """ 1566 | ds_node = self.ds_node 1567 | 1568 | js_filter_intervened = M.CustomJSFilter(code=''' 1569 | const intervention = source.data['intervention']; 1570 | const indices = intervention.map(v => v.length>0); 1571 | return indices; 1572 | ''') 1573 | js_filter_predicted = 
M.CustomJSFilter(code=''' 1574 | const intervention = source.data['intervention']; 1575 | const predictive = source.data['predictive']; 1576 | const indices = predictive.map( (v,i) => v.length>0 && intervention[i].length==0); 1577 | return indices; 1578 | ''') 1579 | js_filter_not_intervened = M.CustomJSFilter(code=''' 1580 | const intervention = source.data['intervention']; 1581 | const predictive = source.data['predictive']; 1582 | const indices = intervention.map( (v,i) => v.length==0 && predictive[i].length==0); 1583 | return indices; 1584 | ''') 1585 | vw_node_intervened = M.CDSView(filter=js_filter_intervened) 1586 | vw_node_predicted = M.CDSView(filter=js_filter_predicted) 1587 | # vw_node_intervened_or_predicted = \ 1588 | # M.CDSView(filter=M.UnionFilter(operands=[js_filter_intervened, js_filter_predicted])) 1589 | vw_node_not_intervened = M.CDSView(filter=js_filter_not_intervened) 1590 | 1591 | text_options = dict( 1592 | text_align='center', 1593 | text_baseline='middle', 1594 | ) 1595 | fig.text(text='node', source=ds_node, view=vw_node_not_intervened, 1596 | y_offset=-10, # text_font_style='bold', 1597 | **text_options) 1598 | fig.text(text='value', source=ds_node, view=vw_node_not_intervened, 1599 | y_offset=15, 1600 | **text_options) 1601 | 1602 | fig.text(text='node', source=ds_node, view=vw_node_intervened, 1603 | y_offset=-15, # text_font_style='bold', 1604 | **text_options) 1605 | fig.text(text='intervention', source=ds_node, view=vw_node_intervened, 1606 | y_offset=10, 1607 | **text_options) 1608 | fig.text(text='value', source=ds_node, view=vw_node_intervened, 1609 | y_offset=25, text_alpha=0.5, text_font_size='12px', 1610 | **text_options) 1611 | 1612 | fig.text(text='node', source=ds_node, view=vw_node_predicted, 1613 | y_offset=-15, 1614 | **text_options) 1615 | fig.text(text='predictive', source=ds_node, view=vw_node_predicted, 1616 | y_offset=10, text_font_style='italic', 1617 | **text_options) 1618 | fig.text(text='value', source=ds_node, 
                 view=vw_node_predicted,
                 y_offset=25, text_alpha=0.5, text_font_size='12px',
                 **text_options)

        return fig

    def get_side_layout(self, main_layout, *, width):
        """Build the prediction side panel.

        The panel holds an outcome MultiChoice, a 'predict' button and a Div
        for the scores; callbacks wire node selection in the graph to the
        outcome choice, run `self.bn.predict` and patch the table datasource.
        """
        ds_node, ds_table, test_data = self.ds_node, self.ds_table, self.data
        graph = self.bn.graph
        nodes = ds_node.data['node']
        interventions = self.bn.interventions

        def _predictable(node):
            # a node can be an outcome only if it has parents in the graph
            # and no intervention is currently applied to it
            return len(graph.get_parents(node)) > 0 and node not in interventions.keys()

        outcome_nodes = list(filter(_predictable, nodes))

        outcome_title = M.Div(text='Outcome:')
        outcome_widget = M.MultiChoice(
            options=outcome_nodes, max_items=2,
            placeholder='click to select',
            sizing_mode='stretch_width',
        )
        div_scores = M.Div(text='')
        btn = M.Button(label='predict', button_type="primary", disabled=True)
        layout = column(outcome_title, outcome_widget, btn, div_scores,
                        width=width)

        def on_node_change(attr, old_value, new_value):
            # tapping a graph node pre-fills the outcome choice, but only
            # when nothing is selected yet and the node is a valid option
            if layout.disabled:
                return

            if len(new_value) > 0:
                idx = new_value[0]
                node = nodes[idx]
                if len(outcome_widget.value) == 0 and node in outcome_widget.options:
                    outcome_widget.value = [node]

        def on_value_change(attr, old_value, new_value):
            # enable 'predict' only when at least one outcome is chosen
            assert isinstance(new_value, (list, tuple))
            btn.disabled = len(new_value) == 0

        def on_btn_click():
            if len(outcome_widget.value) == 0:
                return

            node_names = outcome_widget.value
            logger.info(f'predict with {node_names}')

            # drop the outcome columns from the input, keeping the ground
            # truth aside for scoring
            df = test_data.copy()
            y_trues = {}
            pred_names = []
            for c in node_names:
                if c in df.columns.tolist():
                    y_trues[c] = df.pop(c)
            # do predicting
            y_preds = self.bn.predict(df, outcome=node_names)

            for c in node_names:
                pn = self.pred_name(c)
                pred_names.append(pn)

                # update table datasource
                ds_table.data[pn] = y_preds[c].values

            # show scoring
            score_html = '\n'.join(
                self.format_dict(calc_score(y_true, y_preds=y_preds[c]), title=c)
                for c, y_true in y_trues.items()
            )
            div_scores.text = score_html

            # toggle visibility of prediction columns in the data table
            # (assumes the table widget is the last child of the last row
            # of main_layout -- TODO confirm against get_main_layout)
            ppv = self.PPV
            widget_table = main_layout.children[-1].children[-1]
            for c in widget_table.columns:
                if c.field in pred_names:
                    c.visible = True
                elif c.field.startswith(ppv):
                    c.visible = False

        outcome_widget.on_change('value', on_value_change)
        ds_node.selected.on_change('indices', on_node_change)
        btn.on_click(on_btn_click)

        return layout


class BNInterventionView(FittedBNPlotView):
    """Plot view that lets the user apply do-interventions to the fitted BN."""

    @property
    def data_title(self):
        return 'Train Data'

    def get_side_layout(self, main_layout, *, width):
        """Build the intervention side panel: applied interventions, a node
        chooser, per-node value editors and a 'do' button."""
        ds_node, ds_table, test_data = self.ds_node, self.ds_table, self.data

        bn = self.bn
        graph = bn.graph
        nodes = ds_node.data['node']
        # NOTE(review): `values` appears unused here; on_choice_change
        # re-reads ds_node.data['value'] itself
        values = ds_node.data['value']
        # only nodes with children can meaningfully be intervened on
        intervenable_nodes = list(filter(lambda n: len(graph.get_children(n)) > 0, nodes))

        intervention_applied = M.Div(text=self.intervention_to_html())
        node_choice = M.MultiChoice(
            options=intervenable_nodes, max_items=3,
            placeholder='press to select node',
            sizing_mode='stretch_width',
        )
        intervention_placeholder = M.Div(text='no selected')
        intervention_items = column(intervention_placeholder, sizing_mode='stretch_width')

        btn_do = M.Button(label='do', button_type="primary", disabled=True)
        widgets = [
            M.Div(text='Applied interventions:'),
            intervention_applied,
            M.Div(text='New intervention:'),
            node_choice,
            intervention_items,
            btn_do,
        ]

        # closure state shared by the callbacks below:
        # widget cache per node, current selection, and pending values
        intervention_children = {}
        node_selected = []
        intervention_settings = 
{}

        def on_choice_change(attr, old_value, new_value):
            # rebuild the editor widget list to match the chosen nodes;
            # widgets are cached in intervention_children so previous edits
            # survive de-selection/re-selection
            node_values = ds_node.data['value']
            new_children = []
            for n in new_value:
                if n not in intervention_children.keys():
                    idx = nodes.index(n)
                    widget = self._node_editable_widget(n, node_values[idx], width=width - 10)
                    intervention_children[n] = widget
                    widget.on_change('value', partial(on_intervention_change, n))
                    intervention_settings[n] = widget.value
                new_children.append(intervention_children[n])
            if len(new_children) == 0:
                new_children.append(intervention_placeholder)
            intervention_items.children = new_children

            node_selected.clear()
            node_selected.extend(new_value)
            btn_do.disabled = len(node_selected) == 0

        def on_intervention_change(node, attr, old_value, new_value):
            # remember the edited value for the given node
            intervention_settings[node] = new_value

        def on_do():
            # apply the configured intervention to the BN (in place) and
            # refresh both the applied-interventions Div and the node plot
            assert len(node_selected) > 0
            logger.info(f'do intervention with {node_selected}')

            intervention = {n: intervention_settings[n] for n in node_selected}
            bn.do(intervention, data=test_data, inplace=True)
            intervention_applied.text = self.intervention_to_html()

            # refresh node values from the selected table row (default: row 0)
            if ds_table.selected.indices:
                idx = ds_table.selected.indices[0]
            else:
                idx = 0
            self.patch_node_ds(row_data=idx, params=True, intervention=True)

        node_choice.on_change('value', on_choice_change)
        btn_do.on_click(on_do)

        return column(*widgets, width=width)

    def patch_node_ds(self, *, row_data=None, params=False, intervention=False):
        """Refresh parts of the node datasource.

        row_data: int row index -- update displayed node values from that row.
        params: True (recompute) or dict -- update the params HTML.
        intervention: True -- update intervention labels and line widths.
        """
        ds_table = self.ds_table
        ds_node = self.ds_node
        nodes = ds_node.data['node']

        if row_data is not None:
            assert isinstance(row_data, int)
            idx = row_data
            new_data = [self.format_value(n, ds_table.data[n][idx]) for n in nodes]
            ds_node.data['value'] = new_data

        if 
params:
            if not isinstance(params, dict):
                params = self.get_node_params()
            params_html = [self.params_to_html(params[n]) for n in nodes]
            ds_node.data['params'] = params_html

        if intervention:
            interventions = self.bn.interventions
            intervention_new = [self.format_value(n, interventions.get(n)) for n in nodes]
            ds_node.data['intervention'] = intervention_new
            # intervened nodes are drawn with a different outline width
            ds_node.data['line_width'] = [self._line_width(n in interventions.keys()) for n in nodes]

    def intervention_to_html(self):
        """Render the currently applied interventions as HTML."""
        return self.format_dict(self.bn.interventions)

    def _node_editable_widget(self, node, value=None, **kwargs):
        """Build a value editor for one node.

        Categorical state -> Select over the class labels;
        continuous state -> Slider over the node's [min, max] range.
        """
        state = self.bn.state_[node]
        options = dict(title=f'{node}:')

        if isinstance(state, _base.CategoryNodeState):
            options.update(value=str(value), options=list(map(str, state.classes.tolist())))
            options.update(kwargs)
            widget = M.Select(**options)
        else:
            step = .0001  # state.max - state.max / 20
            fmt = "0[.]0000"
            if value is not None:
                value = float(value)
            options.update(value=value, start=state.min, end=state.max, step=step, format=fmt)
            options.update(kwargs)
            widget = M.Slider(**options)

        return widget


class BNEffectView(FittedBNPlotView):
    """Plot view for estimating causal effects (ITE/ATE) of treatments on outcomes."""

    def get_side_layout(self, main_layout, *, width):
        ds_node, ds_table, test_data = self.ds_node, self.ds_table, self.data
        bn = self.bn
        graph = bn.graph
        nodes = ds_node.data['node']
        # treatments need children, outcomes need parents in the causal graph
        treatment_nodes = list(filter(lambda n: len(graph.get_children(n)) > 0, nodes))
        outcome_nodes = list(filter(lambda n: len(graph.get_parents(n)) > 0, nodes))

        treatment_title = M.Div(text='Treatments:')
        treatment_widget = M.MultiChoice(
            options=treatment_nodes, max_items=2,
            sizing_mode='stretch_width',
        )
        treatment_placeholder = M.Div(text='no selected')
1844 | treatment_items = column(treatment_placeholder, sizing_mode='stretch_width') 1845 | outcome_title = M.Div(text='
    Outcome:') 1846 | outcome_widget = M.MultiChoice( 1847 | options=outcome_nodes, max_items=2, 1848 | sizing_mode='stretch_width', 1849 | ) 1850 | btn_estimate = M.Button(label='estimate', button_type="primary", disabled=True) 1851 | effect_items = column(M.Div(text='effect'), sizing_mode='stretch_width') 1852 | 1853 | widgets = [treatment_title, treatment_widget, treatment_items, 1854 | outcome_title, outcome_widget, 1855 | btn_estimate, effect_items] 1856 | layout = column(*widgets, width=width) 1857 | 1858 | treatment_children = {} 1859 | treatment_selected = [] 1860 | outcome_selected = [] 1861 | treats = {} 1862 | controls = {} 1863 | 1864 | def on_node_change(attr, old_value, new_value): 1865 | assert isinstance(new_value, (list, tuple)) 1866 | if layout.disabled: 1867 | return 1868 | 1869 | if len(new_value) > 0: 1870 | idx = new_value[0] 1871 | node = ds_node.data['node'][idx] 1872 | if len(treatment_widget.value) == 0 and node in treatment_widget.options: 1873 | treatment_widget.value = [node] 1874 | elif len(outcome_widget.value) == 0 and node in outcome_widget.options: 1875 | outcome_widget.value = [node] 1876 | 1877 | def on_treatment_change(attr, old_value, new_value): 1878 | new_children = [] 1879 | for n in new_value: 1880 | if n not in treatment_children.keys(): 1881 | layout, tc = self._node_treatment(n, width=width - 20) 1882 | treatment_children[n] = (layout, tc) 1883 | tc[0].on_change('value', partial(on_treat_control_change, n, 0)) # control 1884 | tc[1].on_change('value', partial(on_treat_control_change, n, 1)) # treat 1885 | controls[n] = tc[0].value 1886 | treats[n] = tc[1].value 1887 | new_children.append(treatment_children[n][0]) 1888 | if len(new_children) == 0: 1889 | new_children.append(treatment_placeholder) 1890 | treatment_items.children = new_children 1891 | 1892 | treatment_selected.clear() 1893 | treatment_selected.extend(new_value) 1894 | btn_estimate.disabled = len(treatment_selected) == 0 or len(outcome_widget.value) == 0 
1895 | 1896 | def on_treat_control_change(node, t_or_c, attr, old_value, new_value): 1897 | if t_or_c == 0: 1898 | controls[node] = new_value 1899 | else: 1900 | treats[node] = new_value 1901 | 1902 | def on_outcome_change(attr, old_value, new_value): 1903 | outcome_selected.clear() 1904 | outcome_selected.extend(new_value) 1905 | btn_estimate.disabled = len(treatment_selected) == 0 or len(outcome_widget.value) == 0 1906 | 1907 | def on_btn_estimate(): 1908 | assert len(treatment_selected) > 0 1909 | logger.info(f'estimate with treatment={treatment_selected}, outcome={outcome_selected}') 1910 | 1911 | treat, control = nmap(lambda n: (treats[n], controls[n]), treatment_selected) 1912 | df_test = test_data.copy() 1913 | for c in outcome_selected: 1914 | if c in df_test.columns.tolist(): 1915 | df_test.pop(c) 1916 | 1917 | progress = [0] 1918 | done = [] 1919 | ite = [] 1920 | 1921 | def on_progress(n): 1922 | progress[0] = n 1923 | 1924 | def update_progress(): 1925 | ds_progress.data['right'] = progress 1926 | 1927 | if done: 1928 | curdoc().remove_periodic_callback(cb) 1929 | btn_estimate.disabled = False 1930 | if ite: 1931 | fig = get_ite_plot(ite[-1], outcome_selected) 1932 | effect_items.children = [fig] 1933 | 1934 | def on_success(effect): 1935 | ite.append(effect) 1936 | done.append(1) 1937 | 1938 | def on_error(e): 1939 | done.append(0) 1940 | 1941 | n_sample = 200 1942 | fig, ds_progress = get_estimating_plot(n_sample) 1943 | effect_items.children = [fig] 1944 | 1945 | cb = curdoc().add_periodic_callback(update_progress, 100) 1946 | print(cb) 1947 | btn_estimate.disabled = True 1948 | 1949 | utils.trun(lambda: bn.estimate(df_test, 1950 | outcome=outcome_selected, 1951 | treatment=treatment_selected, 1952 | treat=treat, 1953 | control=control, 1954 | num_samples=n_sample, 1955 | verbose=on_progress, 1956 | random_state=101, 1957 | ), 1958 | on_success=on_success, 1959 | on_error=on_error, 1960 | ) 1961 | 1962 | def get_estimating_plot(x_limit): 1963 | 
            source = M.ColumnDataSource(data=dict(y=['progress'], right=[0]))
            # thin horizontal bar that fills from 0 to x_limit as the
            # estimation progresses
            fig = figure(title='progress:',
                         toolbar_location=None,
                         outline_line_color='lightgray',
                         x_range=(0, x_limit),
                         y_range=['progress'],
                         height=20,
                         sizing_mode='stretch_width',
                         )
            fig.axis.visible = False
            fig.hbar(y='y', right='right', source=source)
            return fig, source

        def get_ite_plot(ite, outcome):
            # histogram of the individual treatment effects per outcome,
            # with a vertical line marking the mean (ATE)
            from bokeh.palettes import Category10
            fig = figure(toolbar_location=None,
                         # tools="pan,tap",
                         outline_line_color='lightgray',
                         height=width,
                         sizing_mode='stretch_width',
                         )
            ###
            ate = ['ATE:']
            # NOTE(review): Category10[3][i] limits this to 3 outcomes even
            # though max_items caps selection at 2 -- confirm intended
            for i, c in enumerate(outcome):
                state = bn.state_[c]
                if isinstance(state, _base.CategoryNodeState):
                    col = f'{c}_{state.classes[-1]}'  # plot the ite for last item
                else:
                    col = c

                x = ite[col].values
                ate.append(f' {col}: {x.mean():.6f}')
                bins = np.linspace(x.min(), x.max(), 100)
                hist, edges = np.histogram(x, density=True, bins=bins)
                fig.quad(top=hist, bottom=0, left=edges[:-1], right=edges[1:],
                         alpha=0.5, legend_label=col,  # f'{col}(mean {ate:.6f})',
                         fill_color=Category10[3][i],
                         )
                ate_line = M.Span(location=x.mean(), dimension='height',
                                  line_color='#B7472A',
                                  # line_color=Category10[3][i],
                                  line_width=2,
                                  )
                fig.add_layout(ate_line)
            fig.title = '\n'.join(ate)
            fig.legend.visible = len(outcome) > 1
            return fig

        ds_node.selected.on_change('indices', on_node_change)
        treatment_widget.on_change('value', on_treatment_change)
        outcome_widget.on_change('value', on_outcome_change)
        btn_estimate.on_click(on_btn_estimate)

        return layout

    def _node_treatment(self, node, width=None, **kwargs):
        """Build control/treat value editors for one treatment node.

        Returns (layout, (control_widget, treat_widget)).
        """
        state = self.bn.state_[node]
        options = kwargs.copy()
options.update(sizing_mode='stretch_width') 2022 | if isinstance(state, _base.CategoryNodeState): 2023 | classes = list(map(str, state.classes.tolist())) 2024 | options.update(options=classes) 2025 | widget_c = M.Select(title='control:', value=classes[0], **options) 2026 | widget_t = M.Select(title='treat:', value=classes[-1], **options) 2027 | else: 2028 | fmt = "0[.]000" 2029 | # step = .001 # state.max - state.max / 20 2030 | # options.update(start=state.min, end=state.max, step=step, format=fmt) 2031 | # widget_c = M.Slider(title='control', value=state.min, **options) 2032 | # widget_t = M.Slider(title='treat', value=state.max, **options) 2033 | options.update(mode='float', low=state.min, high=state.max, format=fmt) 2034 | widget_c = M.NumericInput(title='control:', value=state.min, **options) 2035 | widget_t = M.NumericInput(title='control:', value=state.max, **options) 2036 | 2037 | layout = column( 2038 | M.Div(text=f'{node}:'), 2039 | row(widget_c, widget_t, sizing_mode='stretch_width'), 2040 | width=width, 2041 | sizing_mode='stretch_width', 2042 | ) 2043 | return layout, (widget_c, widget_t) 2044 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy>=1.16.5 2 | pandas>=0.25.3 3 | bokeh==3.0.* 4 | tornado 5 | ylearn 6 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataCanvasIO/CausalLab/be5957af99c97e7b31eb4a8dfa045d437c1fb05e/setup.cfg -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | 3 | from __future__ import absolute_import 4 | 5 | from setuptools import find_packages 6 | from setuptools import setup 7 | 
home_url = 'https://github.com/DataCanvasIO/CausalLab'


def read_requirements(file_path='requirements.txt'):
    """Read a pip requirements file into a list of requirement strings.

    Blank lines and '#' comment lines are dropped; returns [] when the
    file does not exist.
    """
    import os

    if not os.path.exists(file_path):
        return []

    with open(file_path, 'r')as f:
        lines = f.readlines()

    lines = [x.strip('\n').strip(' ') for x in lines]
    lines = list(filter(lambda x: len(x) > 0 and not x.startswith('#'), lines))

    return lines


def read_extra_requirements():
    """Collect extras_require from requirements-<key>.txt files.

    Each requirements-<key>.txt becomes extras[key]; an 'all' extra is
    synthesized as the sorted union of every other extra.
    """
    import glob
    import re

    extra = {}

    for file_name in glob.glob('requirements-*.txt'):
        key = re.search('requirements-(.+).txt', file_name).group(1)
        req = read_requirements(file_name)
        if req:
            extra[key] = req

    if extra and 'all' not in extra.keys():
        extra['all'] = sorted({v for req in extra.values() for v in req})

    return extra


def read_description(file_path='README.md'):
    """Return the README content used as the package long description."""
    with open(file_path, encoding='utf-8') as f:
        desc = f.read()
    return desc


import causallab

version = causallab.__version__

MIN_PYTHON_VERSION = '>=3.8'

# long_description = open('README.md', encoding='utf-8').read()
long_description = read_description()

requires = read_requirements()
extras_require = read_extra_requirements()

setup(
    name='causallab',
    version=version,
    description='An Interactive Causal Analysis Tool',
    long_description=long_description,
    long_description_content_type="text/markdown",
    url=home_url,
    author='DataCanvas Community',
    author_email='yangjian@zetyun.com',
    license='Apache License 2.0',
    install_requires=requires,
    python_requires=MIN_PYTHON_VERSION,
    extras_require=extras_require,
    classifiers=[
        'Operating System :: OS Independent',
        'Intended Audience :: Developers',
        'Intended Audience :: Education',
        'Intended Audience :: Science/Research',
        'Programming 
Language :: Python',
        'Programming Language :: Python :: 3.8',
        'Programming Language :: Python :: 3.9',
        'Programming Language :: Python :: 3.10',
        'Topic :: Scientific/Engineering',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'Topic :: Software Development',
        'Topic :: Software Development :: Libraries',
        'Topic :: Software Development :: Libraries :: Python Modules',
    ],
    packages=find_packages(exclude=('docs', 'tests*')),
    package_data={
    },
    entry_points={
        # `causal_lab` console script starts the bokeh app server
        'console_scripts': [
            'causal_lab = causallab.serve:main',
        ]
    },
    zip_safe=False,
    include_package_data=True,
)
--------------------------------------------------------------------------------