├── .gitignore ├── LICENSE ├── README.md ├── causal_discovery_algs ├── __init__.py ├── brai.py ├── fci.py ├── icd.py ├── pc.py ├── rai.py └── ts_icd.py ├── causal_discovery_utils ├── __init__.py ├── cond_indep_tests.py ├── constraint_based.py ├── data_utils.py ├── performance_measures.py └── stat_utils.py ├── causal_reasoning ├── __init__.py └── cleann_explainer.py ├── example_data ├── Alarm1_data │ ├── Alarm1_graph.txt │ └── Alarm1_s500_v1.txt └── Alarm1_testdata │ └── Alarm1_testdata_s500_v1.txt ├── experiment_utils ├── __init__.py ├── explanation.py ├── synthetic_graphs.py └── threshold_select_ci_test.py ├── graphical_models ├── __init__.py ├── arrow_head_types.py ├── basic_equivalance_class_graph.py ├── basic_graph.py ├── dag.py ├── partial_ancestral_graph.py ├── partially_dag.py ├── possible_dsep_tree.py └── undirected_graph.py ├── imgs ├── ExampleAnimationICD.gif ├── ExamplePAG.png └── FrameworkBlockDiagram.png ├── notebooks ├── causal_discovery_from_time_series.ipynb ├── causal_discovery_under_causal_sufficiency.ipynb ├── causal_discovery_with_a_perfect_oracle.ipynb ├── causal_discovery_with_latent_confounders.ipynb ├── causal_reasoning_with_CLEANN_explanations.ipynb ├── imgs │ └── TimeSeriesMeasurmentSites.png └── partial_ancestral_graphs.ipynb ├── plot_utils ├── __init__.py ├── draw_graph.py └── graph_layout.py ├── pyproject.toml ├── requirements.txt ├── setup.cfg └── unit_tests └── graphical_models ├── test_basic_equivalance_class_graph.py ├── test_dag.py └── test_partial_ancestral_graph.py /.gitignore: -------------------------------------------------------------------------------- 1 | # pycharm files 2 | .idea/ 3 | 4 | # vs-code files 5 | .vscode/ 6 | 7 | # jupyter notebook caches 8 | .ipynb_checkpoints/ 9 | 10 | # mac stores 11 | .DS_Store 12 | 13 | # Created by https://www.toptal.com/developers/gitignore/api/python 14 | # Edit at https://www.toptal.com/developers/gitignore?templates=python 15 | 16 | ### Python ### 17 | # Byte-compiled / optimized / DLL files 18 | __pycache__/ 19 | *.py[cod] 20 | *$py.class 21 | 22 | # C extensions 23 | *.so 24 | 25 | # Distribution / packaging 26 | .Python 27 | build/ 28 | develop-eggs/ 29 | dist/ 30 | downloads/ 31 | eggs/ 32 | .eggs/ 33 | lib/ 34 | lib64/ 35 | parts/ 36 | sdist/ 37 | var/ 38 | wheels/ 39 | share/python-wheels/ 40 | *.egg-info/ 41 | .installed.cfg 42 | *.egg 43 | MANIFEST 44 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 
22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Causality Lab 2 | 3 | This repository contains research code for novel causal discovery algorithms developed at Intel Labs, as well as other common algorithms, 4 | and classes for developing and examining new algorithms for causal structure learning. 5 | 6 | **Update (December 2023)**: [CLEANN](https://arxiv.org/abs/2310.20307 "Rohekar Raanan, Gurwicz Yaniv, and Nisimov Shami. NeurIPS 2023") is a novel algorithm presented at [NeurIPS 2023](https://neurips.cc/ "Advances in Neural Information Processing Systems"). It generates causal explanations for the outcomes of existing pre-trained Transformer neural networks. At its core, it is based on the novel causal interpretation of self-attention presented in the paper, and executes attention-based causal-discovery (ABCD). 7 | [This notebook](notebooks/causal_reasoning_with_CLEANN_explanations.ipynb) demonstrates, using a simple example, how to use CLEANN. 8 | 9 | ## Table of Contents 10 | 11 | - [Algorithms and Baselines](#algorithms-and-baselines) 12 | - [Developing and Examining Algorithms](#developing-and-examining-algorithms) 13 | - [Installation](#installation) 14 | - [Usage Example](#usage-example) 15 | - [References](#references) 16 | 17 | 18 | ## Algorithms and Baselines 19 | 20 | Included algorithms learn causal structures from observational data, and reason using these learned causal graphs. 21 | There are three families of algorithms: 22 | 23 | 1. **Causal discovery under causal sufficiency and Bayesian network structure learning** 24 | 1. PC algorithm (Spirtes et al., 2000) 25 | 2. RAI algorithm, Recursive Autonomy Identification ([Yehezkel and Lerner, 2009](https://www.jmlr.org/papers/volume10/yehezkel09a/yehezkel09a.pdf)). This algorithm is used for learning the structure in the B2N algorithm ([Rohekar et al., NeurIPS 2018b](https://arxiv.org/pdf/1806.09141.pdf)) 26 | 3. B-RAI algorithm, Bootstrap/Bayesian-RAI for uncertainty estimation ([Rohekar et al., NeurIPS 2018a](https://arxiv.org/abs/1809.04828)). This algorithm is used for learning the structure of BRAINet ([Rohekar et al., NeurIPS 2019](https://arxiv.org/abs/1905.13195)) 27 | 28 | 2. **Causal discovery in the presence of latent confounders and selection bias** 29 | 1. FCI algorithm, Fast Causal Inference (Spirtes et al., 2000) 30 | 2. ICD algorithm, Iterative Causal Discovery ([Rohekar et al., NeurIPS 2021](https://arxiv.org/abs/2111.04095)) 31 | 3. TS-ICD algorithm, ICD for time-series data ([Rohekar et al., ICML 2023](https://arxiv.org/abs/2306.00624)) 32 | 3. **Causal reasoning** 33 | 1. CLEANN algorithm, Causal Explanation from Attention in Neural Networks ([Rohekar et al., 2023](https://arxiv.org/abs/2310.20307 "Rohekar Raanan, Gurwicz Yaniv, and Nisimov Shami. NeurIPS 2023"), [Nisimov et al., 2022](https://arxiv.org/abs/2210.10621 "Nisimov Shami, Rohekar Raanan, Gurwicz Yaniv, Koren Guy, and Novik Gal. CONSEQUENCES, RecSys 2022")). 34 | 35 | ![Example ICD](imgs/ExampleAnimationICD.gif) 36 | 37 | 38 | ## Developing and Examining Algorithms 39 | 40 | This repository includes several classes and methods for implementing new algorithms and testing them. These can be grouped into three categories: 41 | 42 | 1. **Simulation**: 43 | 1. [Random DAG sampling](experiment_utils/synthetic_graphs.py) 44 | 2. [Observational data sampling](experiment_utils/synthetic_graphs.py) 45 | 2. 
**Causal structure learning**: 46 | 1. [Classes for handling graphical models](graphical_models) (e.g., methods for graph traversal and calculating graph properties). Supported graph types: 47 | 1. Directed acyclic graph (DAG): commonly used for representing causal DAGs 48 | 2. Partially directed graph (PDAG/CPDAG): a Markov equivalence class of DAGs under causal sufficiency 49 | 3. Undirected graph (UG): usually used for representing adjacency in the graph (the skeleton) 50 | 4. Ancestral graph (PAG/MAG): a MAG is an equivalence class of DAGs, and a PAG is an equivalence class of MAGs (Richardson and Spirtes, 2002). 51 | 2. [Statistical tests (CI tests)](causal_discovery_utils/cond_indep_tests.py) operating on data, and a perfect CI oracle (see [causal discovery with a perfect oracle](notebooks/causal_discovery_with_a_perfect_oracle.ipynb)) 52 | 3. **Performance evaluations**: 53 | 1. [Graph structural accuracy](causal_discovery_utils/performance_measures.py) 54 | 1. Skeleton accuracy: FNR, FPR, structural Hamming distance 55 | 2. Orientation accuracy 56 | 3. Overall graph accuracy: BDeu score 57 | 2. [Computational cost](causal_discovery_utils/cond_indep_tests.py): counters for CI tests (internal caching ensures that each unique test is counted only once) 58 | 3. [Plots for DAGs and ancestral graphs](plot_utils). 59 | 60 | A new algorithm can be developed by inheriting classes of existing algorithms (e.g., B-RAI inherits RAI) or by creating a new class. 61 | The only method that must be implemented is `learn_structure()`; a minimal subclass sketch is given at the end of the Usage Example section below. For conditional independence testing, 62 | we implemented conditional mutual information, a partial-correlation statistical test, and d-separation (a perfect oracle). 63 | Additionally, a Bayesian score (BDeu) can be used for evaluating the posterior probability of DAGs given data. 64 | 65 | ![Block Diagram](imgs/FrameworkBlockDiagram.png) 66 | 67 | 68 | ## Installation 69 | 70 | This code has been tested on Ubuntu 18.04 LTS and macOS Catalina, with Python 3.5. 71 | We recommend installing and running it in a virtualenv. 72 | 73 | ``` 74 | sudo -E pip3 install virtualenv 75 | virtualenv -p python3 causal_env 76 | . causal_env/bin/activate 77 | 78 | git clone https://github.com/IntelLabs/causality-lab.git 79 | cd causality-lab 80 | pip install -r requirements.txt 81 | ``` 82 | 83 | ## Usage Example 84 | 85 | ### Learning a Causal Structure from Observed Data 86 | 87 | All causal structure learning algorithms are classes with a `learn_structure()` method that learns the causal graph. 88 | The learned causal graph is a public class member, simply called `graph`, which is an instance of a graph class. 89 | The structure learning algorithms do not have direct access to the data; instead, they call a statistical test that accesses the data. 90 | 91 | Let's look at the following example: causal structure learning with ICD using a given dataset (`LearnStructICD` is exported by the `causal_discovery_algs` package, and the CI tests, such as `CondIndepParCorr`, reside in [causal_discovery_utils/cond_indep_tests.py](causal_discovery_utils/cond_indep_tests.py)). 92 | 93 | ```python 94 | par_corr_test = CondIndepParCorr(dataset, threshold=0.01) # CI test with the given significance level 95 | icd = LearnStructICD(nodes_set, par_corr_test) # instantiate an ICD learner 96 | icd.learn_structure() # learn the causal graph 97 | ``` 98 | 99 | For complete examples, see the [causal discovery with latent confounders](notebooks/causal_discovery_with_latent_confounders.ipynb) and [causal discovery under causal sufficiency](notebooks/causal_discovery_under_causal_sufficiency.ipynb) notebooks. 
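The framework can also be extended with new learners. As noted in the Developing and Examining Algorithms section, a new algorithm only has to implement `learn_structure()`. The toy class below sketches this: it removes an edge whenever two variables are unconditionally independent. The class name `LearnStructMarginal` is hypothetical (it is not part of the library); the sketch follows the constructor and CI-test calling pattern used by the PC and ICD learners in this repository.

```python
from itertools import combinations
from causal_discovery_utils.constraint_based import LearnStructBase
from graphical_models import PDAG


class LearnStructMarginal(LearnStructBase):  # hypothetical example class, for illustration only
    """Toy learner: keep an edge only if no unconditional independence is found."""

    def __init__(self, nodes_set, ci_test):
        super().__init__(PDAG, nodes_set=nodes_set, ci_test=ci_test)
        self.graph.create_complete_graph(nodes_set)  # start from a fully connected graph

    def learn_structure(self):  # the single method a new learner must implement
        for node_i, node_j in combinations(self.graph.nodes_set, 2):
            if self.ci_test.cond_indep(node_i, node_j, ()):  # marginal CI test (empty conditioning set)
                self.graph.delete_edge(node_i, node_j)  # remove the edge from the learned graph
                self.sepset.set_sepset(node_i, node_j, ())  # record the separating set for later orientation
```

Edge orientation (e.g., `orient_v_structures()` and `maximally_orient_pattern()`) can then be layered on top, as the bundled algorithms do.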
100 | The learned structures can then be plotted - see a complete example for creating a PAG, calculating its properties, and plotting it in the [partial ancestral graphs](notebooks/partial_ancestral_graphs.ipynb) notebook. 101 | 102 | ![PAG plot example](imgs/ExamplePAG.png) 103 | 104 | 105 | 106 | ## References 107 | 108 | * Rohekar, Raanan, Yaniv Gurwicz, and Shami Nisimov. "Causal Interpretation of Self-Attention in Pre-Trained Transformers". Advances in Neural Information Processing Systems (NeurIPS) 36, 2023. 109 | * Rohekar, Raanan Y., Shami Nisimov, Yaniv Gurwicz, and Gal Novik. "From Temporal to Contemporaneous Iterative Causal Discovery in the Presence of Latent Confounders". International Conference on Machine Learning (ICML), 2023. 110 | * Nisimov, Shami, Raanan Y. Rohekar, Yaniv Gurwicz, Guy Koren, and Gal Novik. "CLEAR: Causal Explanations from Attention in Neural Recommenders". Causality, Counterfactuals and Sequential Decision-Making for Recommender Systems (CONSEQUENCES) workshop at RecSys, 2022. 111 | * Rohekar, Raanan Y., Shami Nisimov, Yaniv Gurwicz, and Gal Novik. "Iterative Causal Discovery in the Possible Presence of Latent Confounders and Selection Bias". Advances in Neural Information Processing Systems (NeurIPS) 34, 2021. 112 | * Rohekar, Raanan Y., Yaniv Gurwicz, Shami Nisimov, and Gal Novik. "Modeling Uncertainty by Learning a Hierarchy of Deep Neural Connections". Advances in Neural Information Processing Systems (NeurIPS) 32: 4244-4254, 2019. 113 | * Rohekar, Raanan Y., Yaniv Gurwicz, Shami Nisimov, Guy Koren, and Gal Novik. "Bayesian Structure Learning by Recursive Bootstrap". Advances in Neural Information Processing Systems (NeurIPS) 31: 10525-10535, 2018a. 114 | * Rohekar, Raanan Y., Shami Nisimov, Yaniv Gurwicz, Guy Koren, and Gal Novik. "Constructing Deep Neural Networks by Bayesian Network Structure Learning". Advances in Neural Information Processing Systems (NeurIPS) 31: 3047-3058, 2018b. 115 | * Yehezkel, Raanan, and Boaz Lerner. "Bayesian Network Structure Learning by Recursive Autonomy Identification". Journal of Machine Learning Research (JMLR) 10, no. 7, 2009. 116 | * Richardson, Thomas, and Peter Spirtes. "Ancestral graph Markov models". The Annals of Statistics, 30 (4): 962–1030, 2002. 117 | * Spirtes, Peter, Clark N. Glymour, Richard Scheines, and David Heckerman. "Causation, Prediction, and Search". MIT Press, 2000. 
118 | -------------------------------------------------------------------------------- /causal_discovery_algs/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .fci import LearnStructFCI 3 | from .icd import LearnStructICD 4 | from .rai import LearnStructRAI 5 | from .pc import LearnStructPC 6 | from .brai import LearnStructBRAI 7 | -------------------------------------------------------------------------------- /causal_discovery_algs/brai.py: -------------------------------------------------------------------------------- 1 | from causal_discovery_algs.rai import LearnStructRAI 2 | from graphical_models import PDAG, DAG 3 | import numpy as np 4 | from itertools import combinations 5 | import graphical_models.arrow_head_types as Mark # incoming arrow head-types 6 | from causal_discovery_utils.performance_measures import score_bdeu 7 | from causal_discovery_utils.data_utils import get_var_size 8 | from causal_discovery_utils.constraint_based import SeparationSet 9 | 10 | 11 | class CookieNode: 12 | """ 13 | Cookie node: holds one possible graph learned with CI-test up to a specific order. 14 | Consists of a MultiHypNode (multiple-hypothesis-node) for each slice of the cookie 15 | (The children of a cookie-node are multi.-hyp. nodes and the children of a multi.-hyp. node are cookie nodes) 16 | """ 17 | def __init__(self, multi_hyp_ancestors=None, multi_hyp_descendant=None, cpdag=None, slices=None, extra_data=None): 18 | self.multi_hyp_ancestors = [] # a list of multi.-hyp nodes, each node corresponds to a sub-set of nodes 19 | self.multi_hyp_descendant = None # a multi.-hyp node corresponding to the descendant sub-set of nodes 20 | self.cpdag = None 21 | self.slices = None 22 | self.extra_data = extra_data 23 | 24 | if multi_hyp_ancestors is not None: 25 | for ancestor in multi_hyp_ancestors: 26 | self.add_multi_hyp_ancestor(ancestor) 27 | if multi_hyp_descendant is not None: 28 | self.add_multi_hyp_descendant(multi_hyp_descendant) 29 | if cpdag is not None: 30 | self.set_cookie_graph(cpdag) 31 | if slices is not None: 32 | self.set_cookie_slices(slices) 33 | 34 | def add_multi_hyp_ancestor(self, node): # add children corresponding to ancestor-groups 35 | assert isinstance(node, MultiHypNode) 36 | self.multi_hyp_ancestors.append(node) 37 | 38 | def add_multi_hyp_descendant(self, node): # add a child corresponding to the descendant-group 39 | assert isinstance(node, MultiHypNode) 40 | self.multi_hyp_descendant = node 41 | 42 | def set_cookie_slices(self, slices): 43 | assert isinstance(slices, dict) 44 | self.slices = slices 45 | 46 | def set_cookie_graph(self, cpdag): 47 | assert isinstance(cpdag, PDAG) 48 | self.cpdag = cpdag # a partially refined pdag 49 | 50 | 51 | class MultiHypNode: 52 | """ 53 | Multiple-hypothesis node: contains multiple cookie-nodes each describing a set of nodes 54 | The children of a multi.-hyp. node are cookie-nodes, and the children of a cookie node are multi.-hyp. 
nodes 55 | When sampling, only one of the children should be selected 56 | """ 57 | def __init__(self, cookie_children=None, multi_hyp_data=None, extra_data=None): 58 | 59 | self.children = [] # list of cookie nodes 60 | self.multi_hyp_data = None 61 | self.extra_data = extra_data 62 | self.selected_cookie_idx = None # set on each call that samples a graph 63 | 64 | if cookie_children is not None: 65 | self.is_leaf = False 66 | for cookie in cookie_children: 67 | self.addCookie(cookie) 68 | else: # a leaf node 69 | self.is_leaf = True 70 | 71 | if multi_hyp_data is not None: 72 | self.set_data(multi_hyp_data) 73 | 74 | def addCookie(self, node): 75 | assert isinstance(node, CookieNode) 76 | self.children.append(node) 77 | self.is_leaf = False 78 | 79 | def set_data(self, multi_hyp_data): 80 | assert isinstance(multi_hyp_data, dict) 81 | assert isinstance(multi_hyp_data['endogenous'], set) 82 | assert isinstance(multi_hyp_data['exogenous'], set) 83 | assert isinstance(multi_hyp_data['ci-order'], int) 84 | if multi_hyp_data['score'] is not None: 85 | assert isinstance(multi_hyp_data['score'], float) 86 | if self.is_leaf: 87 | assert isinstance(multi_hyp_data['graph'], PDAG) 88 | assert isinstance(multi_hyp_data['sepset'], SeparationSet) 89 | self.multi_hyp_data = multi_hyp_data 90 | 91 | @property 92 | def en_nodes(self): 93 | return self.multi_hyp_data['endogenous'] 94 | 95 | 96 | class LearnStructBRAI(LearnStructRAI): 97 | def __init__(self, nodes_set, ci_test, num_of_hyp, scoring_data=None, node_size=None, scoring_function=None): 98 | super().__init__(nodes_set, ci_test) 99 | self.graph.create_complete_graph(nodes_set) # Create a fully connected graph 100 | assert (num_of_hyp > 0) 101 | self.num_of_hyp = num_of_hyp 102 | 103 | # get data from the CI test 104 | self.data = ci_test.data.copy() 105 | self.num_records = ci_test.num_records 106 | self.num_nodes = ci_test.num_vars 107 | 108 | self.graph_generating_tree = MultiHypNode() 109 | 110 | self.scoring_data = None 111 | self.node_size = None 112 | 113 | # scoring function with arguments: fun(dag, scoring_data, node_sizes, nodes), returns a log-probability 114 | # default scoring function is BDeu (assumes discrete variables); a user-supplied function overrides it 115 | self.scoring_function = score_bdeu if scoring_function is None else scoring_function 116 | 117 | if scoring_data is None: 118 | self.is_scored = False 119 | else: 120 | self.is_scored = True # turn off if scoring the graph is not required 121 | self.scoring_data = scoring_data 122 | if node_size is not None: 123 | self.node_size = node_size 124 | elif scoring_data is not None: 125 | self.node_size = get_var_size(scoring_data) 126 | else: 127 | self.node_size = None 128 | 129 | def learn_structure(self): 130 | """ 131 | Main structure learning function. 132 | :return: The root of the learned GGT (graph generating tree with a MultiHypNode as root and leaves) 133 | """ 134 | 135 | # initialize for the 1st recursive call 136 | en_nodes = self.graph.nodes_set 137 | ex_nodes = set() 138 | 139 | ggt_root = self.learn_recursively(en_nodes=en_nodes, ex_nodes=ex_nodes, order=0) 140 | self.graph_generating_tree = ggt_root # root of GGT, a node of type MultiHypNode 141 | self.sample_cpdag(temperature=0) # get the MAP CPDAG 142 | 143 | def learn_recursively(self, en_nodes, ex_nodes, order): 144 | """ 145 | The following steps are performed: 146 | 1. For each of the num_of_hyp hypotheses: 147 | a. Create a bootstrap sample of the training data 148 | b. Refine and orient using CI tests of a specific condition set size 149 | c. 
Identify ancestors and descendant groups 150 | d. Call recursively for the ancestor and descendant groups with CI order+1 151 | 2. Update GGT (tree structure) 152 | :param en_nodes: Endogenous nodes 153 | :param ex_nodes: Exogenous nodes 154 | :param order: CI test order, i.e., condition set size 155 | :return: A MultiHypNode that is ONE of the following: 156 | 1) a parent of "num_of_hyp" CookieNodes 157 | 2) a leaf node with values: CPDAG sub-graph and Bayesian score 158 | """ 159 | 160 | if self._exit_cond(en_nodes, order): 161 | # reached a leaf 162 | if self.is_scored: 163 | leaf_pdag = PDAG(self.graph.nodes_set) 164 | leaf_pdag.add_edges_from(self.graph, en_nodes, ex_nodes) 165 | leaf_dag = DAG(leaf_pdag.nodes_set) 166 | is_dag = leaf_pdag.convert_to_dag(leaf_dag) 167 | if is_dag: 168 | leaf_score = self.scoring_function(leaf_dag, self.scoring_data, self.node_size, en_nodes) 169 | else: 170 | leaf_score = -float('inf') # CPDAG does not admit any DAG extension 171 | else: 172 | leaf_score = 0.0 # a float, satisfying the type check in MultiHypNode.set_data 173 | 174 | multi_hyp_data = { 175 | 'endogenous': en_nodes, 176 | 'exogenous': ex_nodes, 177 | 'ci-order': order, 178 | 'graph': self.graph.copy(), 179 | 'sepset': self.sepset.copy(en_nodes | ex_nodes), 180 | 'score': leaf_score 181 | } 182 | leaf = MultiHypNode(cookie_children=None, multi_hyp_data=multi_hyp_data) 183 | return leaf 184 | 185 | cpdag_initial = self.graph # remember the initial graph because each hypothesis overwrites self.graph 186 | sepset_initial = self.sepset # remember the initial sepset because each hypothesis overwrites self.sepset 187 | 188 | # generate multiple hypotheses for the further refinement of the graph 189 | # Recursive calls for each autonomous sub-graph 190 | cookie_nodes_list = [] 191 | for hyp_id in range(self.num_of_hyp): 192 | self.graph = cpdag_initial.copy() # initialize each hypothesis with the initial graph (erase previous hyp.) 
193 | self.sepset = sepset_initial.copy() # initialize each hypothesis with the initial separation set 194 | cookie_node = CookieNode() 195 | 196 | # Create a bootstrap sample of the original training data 197 | idx_sampled = np.random.choice(self.num_records, self.num_records, replace=True) # sample data records 198 | self.ci_test.data = self.data[idx_sampled, :] # set a bootstrap sample to be used for CI testing 199 | 200 | # refine using self.ci_test with condition set size equal to "order", and orient 201 | self._refine_and_orient(en_nodes=en_nodes, ex_nodes=ex_nodes, order=order) 202 | 203 | # split into ancestors/descendant autonomous sub-graphs (cookie slices: descendant_set, and ancestors_sets) 204 | descendant_set, ancestors_sets, a_nodes = self._split_ancestors_descendant(en_nodes=en_nodes) 205 | 206 | # record cookie structure 207 | cookie_slices = { 208 | 'descendant-slice': descendant_set, 209 | 'ancestor-slices': ancestors_sets 210 | } 211 | cookie_node.set_cookie_slices(cookie_slices) 212 | cookie_node.set_cookie_graph(self.graph.copy()) 213 | 214 | # learn for ancestors sub-sets 215 | for ancestor in ancestors_sets: 216 | multi_hyp_ancestor = self.learn_recursively(en_nodes=ancestor, ex_nodes=ex_nodes, 217 | order=order + 1) # recursive call (ancestor) 218 | cookie_node.add_multi_hyp_ancestor(multi_hyp_ancestor) 219 | 220 | # learn for descendant sub-set 221 | multi_hyp_descendant = self.learn_recursively(en_nodes=descendant_set, ex_nodes=a_nodes | ex_nodes, 222 | order=order + 1) # recursive call (descendant) 223 | cookie_node.add_multi_hyp_descendant(multi_hyp_descendant) 224 | 225 | cookie_nodes_list.append(cookie_node) 226 | 227 | multi_hyp_data = {'endogenous': en_nodes, 'exogenous': ex_nodes, 'ci-order': order, 'score': None} 228 | multi_hyp_node = MultiHypNode(cookie_children=cookie_nodes_list, multi_hyp_data=multi_hyp_data) 229 | return multi_hyp_node 230 | 231 | def sample_cpdag(self, temperature=1): 232 | self.graph.create_empty_graph() 233 | self.sepset.erase() 234 | if self.is_scored: 235 | score = self._sample_cpdag_recursive(self.graph_generating_tree, self.graph, self.sepset, temperature) 236 | else: 237 | self._sample_cpdag_recursive_no_score(self.graph_generating_tree) 238 | score = None 239 | 240 | self._re_orient_skeleton() 241 | return score 242 | 243 | def _sample_cpdag_recursive(self, ggt_root, pdag: PDAG, sepset, temperature): 244 | """ 245 | Update input PDAG and return score 246 | :param ggt_root: root of GGT (root of the sub-tree) 247 | :param pdag: PDAG object to be updated 248 | :param sepset: SeparationSet object to be updated 249 | :return: (log-)score of the updated portion of the PDAG 250 | """ 251 | en_nodes = ggt_root.multi_hyp_data['endogenous'] 252 | ex_nodes = ggt_root.multi_hyp_data['exogenous'] 253 | 254 | if ggt_root.is_leaf: # update PDAG and return score 255 | pdag.add_edges_from(source_pdag=ggt_root.multi_hyp_data['graph'], 256 | en_nodes=en_nodes, ex_nodes=ex_nodes) # get sub-graph stored in the leaf 257 | sepset.copy_from(source_sepset=ggt_root.multi_hyp_data['sepset'], 258 | nodes=en_nodes | ex_nodes) # get separation sets stored in the leaf 259 | return ggt_root.multi_hyp_data['score'] 260 | 261 | # we have several cookies. We need to get their scores and then return one w.r.t. 
the scores 262 | 263 | cookie_pdags = [] 264 | cookie_sepsets = [] 265 | cookie_scores = [] 266 | for cookie_node in ggt_root.children: # loop through cookies 267 | pdag_i = PDAG(self.graph.nodes_set) # create a new PDAG to hold the sampled PDAG of the cookie 268 | sepset_i = SeparationSet(self.sepset.nodes_set) 269 | score_i = 0 270 | 271 | # sample a sub-graph for each of the ancestor sub-sets 272 | for ancestor_multi_hyp_node in cookie_node.multi_hyp_ancestors: 273 | score_i += self._sample_cpdag_recursive( # update pdag_i 274 | ggt_root=ancestor_multi_hyp_node, pdag=pdag_i, sepset=sepset_i, temperature=temperature) 275 | 276 | # sample a sub-graph for the descendant sub-set 277 | descendant_multi_hyp_node = cookie_node.multi_hyp_descendant # for consistent naming 278 | score_i += self._sample_cpdag_recursive( # update sepset_i 279 | ggt_root=descendant_multi_hyp_node, pdag=pdag_i, sepset=sepset_i, temperature=temperature) 280 | 281 | cookie_pdags.append(pdag_i) 282 | cookie_sepsets.append(sepset_i) 283 | cookie_scores.append(score_i) 284 | 285 | # sample a cookie 286 | if temperature > 0: # temperature in Boltzmann distribution 287 | max_score = np.max(cookie_scores) 288 | 289 | if max_score != -float('inf'): 290 | log_scores = cookie_scores - max_score # NumPy scalar broadcasting turns the list into an array 291 | scores = np.exp(log_scores/temperature) 292 | scores /= scores.sum() 293 | sampled_cookie_idx = np.random.choice(len(scores), p=scores) # weighted sampling 294 | else: 295 | sampled_cookie_idx = np.random.choice(len(cookie_scores)) # uniform sampling 296 | elif temperature == 0: # get highest scoring CPDAG 297 | sampled_cookie_idx = np.argmax(cookie_scores) # arg-max: select the maximal score 298 | else: 299 | sampled_cookie_idx = None # this case should not happen in normal use 300 | assert (temperature >= 0) 301 | 302 | ggt_root.selected_cookie_idx = sampled_cookie_idx 303 | sampled_pdag = cookie_pdags[sampled_cookie_idx] 304 | sampled_sepset = cookie_sepsets[sampled_cookie_idx] 305 | sampled_score = cookie_scores[sampled_cookie_idx] 306 | ggt_root.multi_hyp_data['score'] = sampled_score # score of the sampled path starting from the sampled cookie 307 | 308 | # update PDAG and sepset, and return score (trash cookie PDAGs and sepsets) 309 | pdag.add_edges_from(source_pdag=sampled_pdag, en_nodes=en_nodes, ex_nodes=ex_nodes) 310 | sepset.copy_from(source_sepset=sampled_sepset, nodes=en_nodes | ex_nodes) 311 | return sampled_score 312 | 313 | def _sample_cpdag_recursive_no_score(self, ggt_root): 314 | if ggt_root.is_leaf: 315 | cpdag_leaf = ggt_root.multi_hyp_data['graph'] 316 | sepset_leaf = ggt_root.multi_hyp_data['sepset'] 317 | en_nodes_leaf = ggt_root.multi_hyp_data['endogenous'] 318 | ex_nodes_leaf = ggt_root.multi_hyp_data['exogenous'] 319 | self.graph.add_edges_from(source_pdag=cpdag_leaf, 320 | en_nodes=en_nodes_leaf, ex_nodes=ex_nodes_leaf) # get sub-graph stored in the leaf 321 | self.sepset.copy_from(source_sepset=sepset_leaf, nodes=en_nodes_leaf | ex_nodes_leaf) 322 | return 323 | 324 | cookie_idx = np.random.choice(len(ggt_root.children)) 325 | ggt_root.selected_cookie_idx = cookie_idx # record the choice on the node, as in the scored variant 326 | cookie_node = ggt_root.children[cookie_idx] 327 | 328 | # sample a sub-graph for each of the ancestor sub-sets 329 | for ancestor_multi_hyp_node in cookie_node.multi_hyp_ancestors: 330 | self._sample_cpdag_recursive_no_score(ancestor_multi_hyp_node) 331 | 332 | # sample a sub-graph for the descendant sub-set 333 | descendant_multi_hyp_node = cookie_node.multi_hyp_descendant # for consistent naming 334 | 
self._sample_cpdag_recursive_no_score(descendant_multi_hyp_node) 335 | 336 | def calc_graph_uncertainty(self, num_of_samples, threshold, temperature=1): 337 | graphs_list = [] 338 | 339 | # sample cpdags 340 | for _ in range(num_of_samples): 341 | self.sample_cpdag(temperature) 342 | graphs_list.append(self.graph.copy()) # copy, since sample_cpdag mutates self.graph in place 343 | 344 | # calculate skeleton 345 | skeleton = PDAG(nodes_set=self.graph.nodes_set) 346 | for node_i, node_j in combinations(self.graph.nodes_set, 2): 347 | count = 0. 348 | for cpdag in graphs_list: 349 | if cpdag.is_connected(node_i, node_j): 350 | count += 1. 351 | 352 | if count > (threshold*num_of_samples): 353 | skeleton.add_edges(parents_set={node_i}, target_node=node_j, arrowhead_type=Mark.Undirected) 354 | 355 | return skeleton 356 | 357 | 358 | 359 | 360 | 361 | 362 | 363 | -------------------------------------------------------------------------------- /causal_discovery_algs/fci.py: -------------------------------------------------------------------------------- 1 | from causal_discovery_utils.constraint_based import LearnStructBase 2 | from causal_discovery_algs.pc import LearnStructPC 3 | from graphical_models import PAG, arrow_head_types as Mark 4 | from itertools import combinations 5 | 6 | 7 | class LearnStructFCI(LearnStructBase): 8 | def __init__(self, nodes_set, ci_test, 9 | is_selection_bias=True, is_tail_completeness=True): 10 | super().__init__(PAG, nodes_set=nodes_set, ci_test=ci_test) 11 | 12 | assert isinstance(is_selection_bias, bool) 13 | self.is_selection_bias = is_selection_bias # if False, orientation rules R5, R6, R7 are not executed. 14 | assert isinstance(is_tail_completeness, bool) 15 | self.is_tail_completeness = is_tail_completeness # if False, orientation rules R8, R9, R10 are not executed 16 | 17 | self.graph.create_complete_graph(Mark.Circle, nodes_set) # Create a fully connected graph with edges: o--o 18 | self.pc_alg = LearnStructPC(nodes_set, ci_test) # initialize a PC object for learning the skeleton 19 | self.found_D_Sep_link = False # indicates if the learner removed an edge that the PC stage didn't remove 20 | 21 | def learn_structure(self): 22 | """ 23 | Learn a partial ancestral graph (PAG) using the fast causal inference (FCI) algorithm 24 | :return: 25 | """ 26 | # initial graph is a fully connected one with o--o edges between every pair of nodes 27 | # learn an initial skeleton using the same procedure as in the PC algorithm 28 | self._learn_pc_skeleton() 29 | 30 | # the resulting graph consists of only o--o edges 31 | # find and orient v-structures 32 | self.graph.orient_v_structures(self.sepset) 33 | 34 | # the resulting graph has only o--o, o-->, or <--> edges 35 | # find and remove edges between pairs of variables that are d-separated by some subset of Possible-D-SEP sets 36 | self.found_D_Sep_link = self._refine_pc_skeleton() 37 | 38 | # re-orient 39 | self.graph.reset_orientations(default_mark=Mark.Circle) 40 | self.graph.orient_v_structures(self.sepset) 41 | self.graph.maximally_orient_pattern(rules_set=[1, 2, 3, 4]) 42 | if self.is_selection_bias: 43 | self.graph.maximally_orient_pattern(rules_set=[5, 6, 7]) 44 | if self.is_tail_completeness: 45 | self.graph.maximally_orient_pattern(rules_set=[8, 9, 10]) 46 | 47 | def _learn_pc_skeleton(self): 48 | """ 49 | Learn an initial skeleton. 
This procedure is identical to the one of the PC algorithm 50 | :return: 51 | """ 52 | 53 | self.pc_alg.learn_skeleton() 54 | self.sepset.copy_from(self.pc_alg.sepset, self.graph.nodes_set) 55 | self.graph.create_empty_graph() 56 | self.graph.copy_skeleton_from_pdag(self.pc_alg.graph) # create edges with o-marks: X o--o Y 57 | self.graph.sepset = self.sepset 58 | 59 | def _refine_pc_skeleton(self): 60 | """ 61 | Refine the skeleton (v-structures are oriented) recovered by the PC algorithm 62 | using subset of possible-d-sep set. 63 | 64 | :return: True if the graph was modified by this method 65 | """ 66 | found_indep = False 67 | pds_list = dict() 68 | 69 | # Prepare the possible-d-sep set for each of the nodes 70 | for node_x in self.graph.nodes_set: 71 | pds_list[node_x] = possible_d_sep = self._create_pds_set(node_x) # self.get_pds(node_x) 72 | 73 | # Test CI for the graph edges 74 | for node_x in self.graph.nodes_set: 75 | possible_d_sep = pds_list[node_x] 76 | adjacent_nodes = self.graph.find_adjacent_nodes(node_x) 77 | for node_y in adjacent_nodes: 78 | found_indep |= self._test_ci_increasing(node_x, node_y, possible_d_sep - {node_y}) 79 | 80 | return found_indep 81 | 82 | def _test_ci_increasing(self, node_x, node_y, pds_super_set): 83 | """ 84 | Search for a minimal separating set by gradually increasing conditioning set size. 85 | :param node_x: a node on one side of the tested edge 86 | :param node_y: a node on the other side of the tested edge 87 | :param pds_super_set: a super-set of nodes from which to construct conditioning sets 88 | :return: True if an edge was deleted, False if no independence was found 89 | """ 90 | cond_indep = self.ci_test.cond_indep # for better readability 91 | for ci_size in range(len(pds_super_set)+1): # loop over condition set sizes; increasing set sizes 92 | for cond_set in combinations(pds_super_set, ci_size): # loop over condition sets of a fixed size 93 | if cond_indep(node_x, node_y, cond_set): 94 | self.graph.delete_edge(node_x, node_y) 95 | self.sepset.set_sepset(node_x, node_y, cond_set) 96 | return True 97 | 98 | return False 99 | 100 | def _create_pds_set(self, node_edge): 101 | """ 102 | Construct a possible-d-sep set for node_edge 103 | 104 | :param node_edge: node on the edge being CI tested 105 | :return: a possible-d-sep 106 | """ 107 | 108 | # Three lists are maintained: "first_nodes", "second_nodes", "neighbors". 
109 | # Corresponding elements from the lists, 110 | # "node_1" in "first_nodes", 111 | # "node_2" in "second_nodes", and 112 | # "node_3" in "neighbors", 113 | # form a path "node_1" --- "node_2" --- "node_3" 114 | # If this path is "legal" then "node_2" is in the possible-d-sep set and added to the PDS-tree 115 | 116 | # create an adjacency matrix (ignore edge-marks) 117 | adj_graph = self.graph.get_skeleton_graph() 118 | 119 | # initialize "first nodes" and "second nodes" lists 120 | neighbors = adj_graph.get_neighbors(node_edge) 121 | second_nodes = neighbors.copy() 122 | first_nodes = [node_edge for _ in range(len(second_nodes))] 123 | 124 | # initialize possible-d-sep list of nodes 125 | pds_nodes = neighbors.copy() # initially: the neighbors of the node 126 | for node_nb in neighbors: 127 | adj_graph.remove_edge(node_edge, node_nb) # make sure the search doesn't loop back to the root 128 | 129 | while len(second_nodes) > 0: 130 | node_1 = first_nodes.pop(0) 131 | node_2 = second_nodes.pop(0) 132 | 133 | neighbors = adj_graph.get_neighbors(node_2) 134 | 135 | for node_3 in neighbors: 136 | if self.graph.is_possible_collider(node_x=node_1, node_middle=node_2, node_y=node_3): # test sub-path 137 | adj_graph.remove_edge(node_2, node_3) 138 | first_nodes.append(node_2) 139 | second_nodes.append(node_3) 140 | pds_nodes.append(node_3) 141 | 142 | possible_d_sep_set = set(pds_nodes) 143 | possible_d_sep_set.discard(node_edge) 144 | return possible_d_sep_set 145 | -------------------------------------------------------------------------------- /causal_discovery_algs/icd.py: -------------------------------------------------------------------------------- 1 | from causal_discovery_utils.constraint_based import LearnStructBase, unique_element_iterator 2 | from graphical_models import PAG, PDSTree, arrow_head_types as Mark 3 | from itertools import combinations, chain 4 | 5 | 6 | class LearnStructICD(LearnStructBase): 7 | def __init__(self, nodes_set, ci_test, is_pre_calc_cond_set=False, 8 | is_selection_bias=True, is_tail_completeness=True): 9 | super().__init__(PAG, nodes_set=nodes_set, ci_test=ci_test) 10 | 11 | # initialize ICD 12 | self.graph.create_complete_graph(Mark.Circle, nodes_set) # Create a fully connected graph with edges: o--o 13 | self.test_cond_ancestor = True # requires nodes in the conditioning set to be possible ancestors 14 | 15 | assert isinstance(is_pre_calc_cond_set, bool) 16 | self.is_pre_calc_pds = is_pre_calc_cond_set 17 | assert isinstance(is_selection_bias, bool) 18 | self.is_selection_bias = is_selection_bias # if False, orientation rules R5, R6, R7 are not executed. 19 | assert isinstance(is_tail_completeness, bool) 20 | self.is_tail_completeness = is_tail_completeness # if False, orientation rules R8, R9, R10 are not executed 21 | self.edge_key = lambda x, y: (x, y) if x < y else (y, x) 22 | self.conditioning_set = {self.edge_key(*edge): set() for edge in combinations(nodes_set, 2)} 23 | self._state = dict(done=False, cond_set_size=0) 24 | 25 | def reset_graph_orientations(self): 26 | """ 27 | Erase all edge marks replacing them with the circle mark. 28 | """ 29 | self.graph.reset_orientations(default_mark=Mark.Circle) 30 | 31 | def learn_structure(self) -> None: 32 | """ 33 | Learn a partial ancestral graph (PAG) using the iterative causal discovery (ICD) algorithm. 
34 | 35 | :return: 36 | """ 37 | 38 | # reset state 39 | self._state = dict(done=False, # Latest iteration result 40 | cond_set_size=0) # next iteration r-value: desired search-radius & conditioning-set-size 41 | 42 | done = False 43 | while not done: 44 | # Perform ICD single iteration 45 | done, _ = self.learn_structure_iteration() 46 | 47 | def learn_structure_iteration(self): 48 | """ 49 | Execute a single ICD-iteration increasing the representation level of the PAG by 1: 50 | 1. Run a single ICD iteration with parameter r (internal) 51 | 2. Prepare for the next iteration: r := r + 1 52 | 53 | :return: a 2-tuple: done, current graph's r-value. 54 | done is True if ICD concluded and no more iterations are required/allowed. 55 | At this stage self.graph is an r-representing PAG. 56 | """ 57 | 58 | if self._state['done']: 59 | raise RuntimeError("ICD already concluded. Cannot run more iterations") 60 | 61 | # -1- Run a single ICD iteration ------------------------------------------------------------------------------- 62 | # for efficiency, handle special cases of ICD iterations (conditioning set is empty or contains a single node) 63 | if self._state['cond_set_size'] == 0: # empty conditioning set 64 | self._learn_struct_base_step_0() 65 | elif self._state['cond_set_size'] == 1: # a single node in the conditioning set 66 | self._state['done'] = self._learn_struct_base_step_1() 67 | else: # general ICD iteration 68 | if self.is_pre_calc_pds: 69 | self._pre_calc_conditioning(self._state['cond_set_size']) 70 | self._state['done'] = self._learn_struct_incremental_step(self._state['cond_set_size']) 71 | 72 | r_value = self._state['cond_set_size'] # the graph's representation level 73 | 74 | # -2- Prepare for the next iteration: r := r + 1 --------------------------------------------------------------- 75 | self._state['cond_set_size'] += 1 # for the next iteration: increase r (radius & conditioning-set-size) 76 | 77 | return self._state['done'], r_value # return r-value of latest iteration, i.e., self.graph is r-representing 78 | 79 | def _pre_calc_conditioning(self, cond_set_size): 80 | for node_i, node_j in combinations(self.graph.nodes_set, 2): 81 | if self.graph.is_connected(node_i, node_j): 82 | self.conditioning_set[self.edge_key(node_i, node_j)] = self._get_pdsep_range_sets( 83 | node_i, node_j, cond_set_size) 84 | 85 | def _learn_struct_incremental_step(self, cond_set_size=None): 86 | """ 87 | Learn a single increment, a single ICD step. This treats the generic case for conditioning set sizes >= 2. 88 | :param cond_set_size: create a list of possible conditioning sets of this size, taking into account the 89 | removal of previous edges during this step. Ignored if class-member 'pre_calc_pds' is True 90 | :return: True if the resulting PAG is completed (no more edges can be removed) 91 | """ 92 | if cond_set_size is None: 93 | assert self.is_pre_calc_pds is True 94 | cond_indep = self.ci_test.cond_indep 95 | source_pag = self.graph # Not a copy!!! 
thus, edge deletions affect consequent CI queries 96 | done = True 97 | for node_i, node_j in combinations(source_pag.nodes_set, 2): 98 | if not source_pag.is_connected(node_i, node_j): 99 | continue 100 | 101 | if self.is_pre_calc_pds: 102 | cond_sets = self.conditioning_set[self.edge_key(node_i, node_j)] 103 | else: 104 | cond_sets = self._get_pdsep_range_sets(node_i, node_j, cond_set_size) 105 | 106 | for cond in cond_sets: 107 | done = False # reset 'done' signaling to continue to the next ICD-iteration after the current one 108 | cond_set = cond[0] # get the set of nodes (in [1] there is the sum-of-minimal-distances) 109 | cond_tup = tuple(cond_set) 110 | if cond_indep(node_i, node_j, cond_tup): 111 | self.graph.delete_edge(node_i, node_j) # remove directed/undirected edge 112 | self.sepset.set_sepset(node_i, node_j, cond_tup) 113 | break # stop searching for independence as we found one and updated the graph accordingly 114 | 115 | # Orient edges 116 | # ------------ 117 | if not done: # re-orient the skeleton only if it was modified 118 | self.reset_graph_orientations() # self.graph.reset_orientations(default_mark=Mark.Circle) 119 | self.graph.orient_v_structures(self.sepset) # corresponds to rule R0 120 | self.graph.maximally_orient_pattern(rules_set=[1, 2, 3, 4]) 121 | else: # algorithm concluded, orient all edges for obtaining completeness 122 | if self.is_selection_bias: 123 | self.graph.maximally_orient_pattern(rules_set=[5, 6, 7]) # when selection-bias may be present 124 | if self.is_tail_completeness: 125 | self.graph.maximally_orient_pattern(rules_set=[8, 9, 10]) # for tail-completeness 126 | 127 | return done 128 | 129 | def _learn_struct_base_step_0(self): 130 | """ 131 | Execute ICD iteration with r = 0. That is, test unconditional independence between every pair of nodes and 132 | remove corresponding edges. Then, orient the graph. The result is a 0-representing PAG. 133 | 134 | :return: 135 | """ 136 | cond_indep = self.ci_test.cond_indep 137 | source_cpdag = self.graph # Not a copy!!! Thus, edge deletions affect consequent CI queries 138 | 139 | # r = 0: unconditional (marginal) independence tests 140 | for node_i, node_j in combinations(source_cpdag.nodes_set, 2): 141 | if cond_indep(node_i, node_j, ()): 142 | self.graph.delete_edge(node_i, node_j) # remove directed/undirected edge 143 | self.sepset.set_sepset(node_i, node_j, ()) 144 | 145 | self.graph.orient_v_structures(self.sepset) 146 | self.graph.maximally_orient_pattern(rules_set=[1, 2, 3, 4]) 147 | 148 | def _learn_struct_base_step_1(self): 149 | """ 150 | Execute ICD iteration with r = 1. That is, test independence between every pair of nodes conditioned on a single 151 | node, and remove corresponding edges. Then, orient the graph. The result is a 1-representing PAG. 152 | 153 | :return: True if done and no more iterations are required; otherwise False indicating the PAG is not completed. 154 | """ 155 | cond_indep = self.ci_test.cond_indep 156 | source_cpdag = self.graph # Not a copy!!! 
Thus, edge deletions affect consequent CI queries 157 | 158 | # r = 1: conditional independence tests order 1 159 | cond_set_size = 1 160 | done = True 161 | for node_i, node_j in combinations(source_cpdag.nodes_set, 2): 162 | if not source_cpdag.is_connected(node_i, node_j): 163 | continue 164 | 165 | pot_parents_i = self.graph.find_adjacent_nodes(node_i) - {node_j} 166 | pot_parents_j = self.graph.find_adjacent_nodes(node_j) - {node_i} 167 | 168 | cond_sets_i = combinations(pot_parents_i, cond_set_size) 169 | cond_sets_j = combinations(pot_parents_j, cond_set_size) 170 | cond_sets = unique_element_iterator( # unique of 171 | chain(cond_sets_i, cond_sets_j) # neighbors of node_i OR neighbors of node_j 172 | ) 173 | 174 | for cond_set in cond_sets: 175 | done = False 176 | if cond_indep(node_i, node_j, cond_set): 177 | self.graph.delete_edge(node_i, node_j) # remove directed/undirected edge 178 | self.sepset.set_sepset(node_i, node_j, cond_set) 179 | break # stop searching for independence as we found one and updated the graph accordingly 180 | 181 | self.reset_graph_orientations() # self.graph.reset_orientations(default_mark=Mark.Circle) 182 | self.graph.orient_v_structures(self.sepset) 183 | self.graph.maximally_orient_pattern(rules_set=[1, 2, 3, 4]) 184 | if self.is_selection_bias: 185 | self.graph.maximally_orient_pattern(rules_set=[5, 6, 7]) # when selection-bias may be present 186 | if self.is_tail_completeness: 187 | self.graph.maximally_orient_pattern(rules_set=[8, 9, 10]) # for tail-completeness 188 | 189 | return done 190 | 191 | def _get_pdsep_range_sets(self, node_i, node_j, cond_set_size): 192 | """ 193 | Create a list of conditioning sets that comply with the ICD-Sep conditions 194 | 195 | :param node_i: node on one side of the tested edge 196 | :param node_j: node on the other side of the tested edge 197 | :param cond_set_size: requested conditioning set size (ICD-Sep condition 1) 198 | :return: a list of conditioning sets to consider when testing CI between node_i and node_j 199 | """ 200 | # create PDS-trees for the tested nodes 201 | pds_tree_i, possible_d_sep_i = create_pds_tree(self.graph, node_i, max_depth=cond_set_size) 202 | pds_tree_j, possible_d_sep_j = create_pds_tree(self.graph, node_j, max_depth=cond_set_size) 203 | # pds_tree_i, possible_d_sep_i = self._create_pds_tree(node_i, max_depth=cond_set_size) 204 | # pds_tree_j, possible_d_sep_j = self._create_pds_tree(node_j, max_depth=cond_set_size) 205 | 206 | cond_sets_list_init = pds_tree_i.get_subsets_list(set_nodes=possible_d_sep_i, subset_size=cond_set_size) 207 | cond_sets_list_init += pds_tree_j.get_subsets_list(set_nodes=possible_d_sep_j, subset_size=cond_set_size) 208 | 209 | cond_sets_list = [] 210 | for cond in cond_sets_list_init: 211 | cond_set = cond[0] 212 | if (node_i in cond_set) or (node_j in cond_set): 213 | continue 214 | 215 | if not self._is_cond_set_possible_ancestor(cond_set, node_i, node_j): 216 | continue 217 | 218 | cond_sets_list.append(cond) 219 | 220 | # sort the list with respect to the sum-of-minimal-distances 221 | cond_sets_list.sort(key=lambda x: x[1]) 222 | return cond_sets_list 223 | 224 | def _is_cond_set_possible_ancestor(self, cond_set, node_i, node_j): 225 | """ 226 | Test ICD-Sep condition 3. That is, test if all the nodes in the conditioning set are possible ancestors of 
228 | 229 | :param cond_set: the conditioning set under examination 230 | :param node_i: node on one side of the tested edge 231 | :param node_j: node on the other side of the tested edge 232 | :return: True if the condition is satisfied, otherwise False 233 | """ 234 | for z in cond_set: 235 | if not ((self.graph.is_possible_ancestor(ancestor_node=z, descendant_node=node_i)) or 236 | (self.graph.is_possible_ancestor(ancestor_node=z, descendant_node=node_j))): 237 | return False 238 | return True 239 | 240 | 241 | def create_pds_tree(source_pag, node_root, en_nodes=None, max_depth=None): 242 | """ 243 | Create a PDS-tree rooted at node_root. 244 | 245 | :param source_pag: the partial ancestral graph from which to construct the PDS-tree 246 | :param node_root: root of the PDS tree 247 | :param en_nodes: nodes of interest 248 | :param max_depth: maximal depth of the tree (search radius around the root) 249 | :return: a PDS-tree 250 | """ 251 | 252 | # Three lists are maintained: "first_nodes", "second_nodes", "neighbors". 253 | # Corresponding elements from the lists, 254 | # "node_1" in "first_nodes", 255 | # "node_2" in "second_nodes", and 256 | # "node_3" in "neighbors", 257 | # form a path "node_1" --- "node_2" --- "node_3" 258 | # If this path is "legal" then "node_2" is in the possible-d-sep set and added to the PDS-tree 259 | 260 | if en_nodes is not None: 261 | assert node_root in en_nodes 262 | assert en_nodes.issubset(source_pag.nodes_set) 263 | 264 | pds_tree = PDSTree(node_root) # initialize 265 | 266 | # create an adjacency matrix (ignore edge-marks) 267 | adj_graph = source_pag.get_skeleton_graph(en_nodes=en_nodes) 268 | 269 | # initialize "first nodes" and "second nodes" lists 270 | neighbors = adj_graph.get_neighbors(node_root) 271 | second_nodes = neighbors.copy() 272 | first_nodes = [node_root for _ in range(len(second_nodes))] 273 | 274 | # initialize possible-d-sep list of nodes 275 | pds_nodes = neighbors.copy() # initially: the neighbors of the node 276 | for node_nb in neighbors: 277 | adj_graph.remove_edge(node_root, node_nb) # make sure the search doesn't loop back to the root 278 | 279 | # ----- for creating a PDS-tree -----\ 280 | if max_depth is None: # do not limit depth 281 | max_depth = len(adj_graph.nodes_set) - 1 282 | # create "first_nodes" and "second_nodes" trees 283 | first_nodes_trees = [pds_tree for _ in range(len(second_nodes))] 284 | for node in pds_nodes: 285 | pds_tree.add_branch(node) # add nodes to the PDS-tree 286 | second_nodes_trees = pds_tree.children.copy() # update "node_2 trees" list 287 | # now, both node_1_trees and node_2_trees have corresponding elements 288 | # -End: for creating a PDS-tree -----/ 289 | 290 | while len(second_nodes) > 0: 291 | node_1 = first_nodes.pop(0) 292 | node_2 = second_nodes.pop(0) 293 | 294 | # ----- for creating a PDS-tree ----- 295 | node_2_tree = second_nodes_trees.pop(0) 296 | if node_2_tree.depth_level >= max_depth: 297 | continue # skip the current pair: node_1 *--> node_2 (do not search <--* node_3 ) 298 | # -End: for creating a PDS-tree ----- 299 | 300 | neighbors = adj_graph.get_neighbors(node_2) 301 | 302 | for node_3 in neighbors: 303 | if source_pag.is_possible_collider(node_x=node_1, node_middle=node_2, node_y=node_3): # test sub-path 304 | adj_graph.remove_edge(node_2, node_3) 305 | first_nodes.append(node_2) 306 | second_nodes.append(node_3) 307 | pds_nodes.append(node_3) 308 | 309 | # ----- for creating a PDS-tree ----- 310 | node_2_tree.add_branch(node_3) 311 | added_branch = 
node_2_tree.get_child_branch(node_3) # get the added child branch 312 | second_nodes_trees.append(added_branch) 313 | first_nodes_trees.append(node_2_tree) 314 | # -End: for creating a PDS-tree ----- 315 | 316 | possible_d_sep_set = set(pds_nodes) 317 | possible_d_sep_set.discard(node_root) 318 | return pds_tree, possible_d_sep_set 319 | -------------------------------------------------------------------------------- /causal_discovery_algs/pc.py: -------------------------------------------------------------------------------- 1 | from itertools import combinations, chain 2 | from causal_discovery_utils.constraint_based import LearnStructBase, unique_element_iterator 3 | from graphical_models import PDAG 4 | 5 | 6 | class LearnStructPC(LearnStructBase): 7 | def __init__(self, nodes_set, ci_test): 8 | super().__init__(PDAG, nodes_set=nodes_set, ci_test=ci_test) 9 | self.graph.create_complete_graph(nodes_set) # Create a fully connected graph 10 | self.overwrite_starting_graph = True # if True, the sequence at which the CIs are tested affects the result 11 | 12 | def learn_structure(self): 13 | """ 14 | Learn a CPDAG (completed partially directed graph) using the PC algorithm 15 | :return: 16 | """ 17 | self.learn_skeleton() 18 | 19 | self.orient_v_structures() 20 | self.graph.convert_bidirected_to_undirected() # treat bi-directed (spurious) as undirected 21 | 22 | self.graph.maximally_orient_pattern([1, 2, 3]) 23 | 24 | def _exit_cond(self, order): 25 | """ 26 | Check if the max fan-in is lower or equal to the order (exit-cond. is met) 27 | :param order: condition set size of the CI-test 28 | :return: True if exit condition is met 29 | """ 30 | for node in self.graph.nodes_set: 31 | if self.graph.fan_in(node) > order: # if a node have a large enough number of parents, exit cond. is false 32 | return False 33 | else: 34 | return True # didn't find a node with a large enough number of parents for CI test, so exit 35 | 36 | def learn_skeleton(self): 37 | cond_indep = self.ci_test.cond_indep 38 | 39 | if self.overwrite_starting_graph: 40 | source_cpdag = self.graph # Not a copy!!! 
thus, edge deletions affect consequent CI queries 41 | else: 42 | source_cpdag = self.graph.copy() # slower, but removes the dependence on the sequence of CI testing 43 | 44 | cond_set_size = 0 45 | while not self._exit_cond(cond_set_size): 46 | for node_i, node_j in combinations(source_cpdag.nodes_set, 2): 47 | if not source_cpdag.is_connected(node_i, node_j): 48 | continue 49 | 50 | pot_parents_i = source_cpdag.undirected_neighbors(node_i) - {node_j} 51 | pot_parents_j = source_cpdag.undirected_neighbors(node_j) - {node_i} 52 | cond_sets_i = combinations(pot_parents_i, cond_set_size) 53 | cond_sets_j = combinations(pot_parents_j, cond_set_size) 54 | cond_sets = unique_element_iterator( # unique of 55 | chain(cond_sets_i, cond_sets_j) # neighbors of node_i OR neighbors of node_j 56 | ) 57 | 58 | for cond_set in cond_sets: 59 | if cond_indep(node_i, node_j, cond_set): 60 | self.graph.delete_edge(node_i, node_j) # remove directed/undirected edge 61 | self.sepset.set_sepset(node_i, node_j, cond_set) 62 | break # stop searching for independence as we found one and updated the graph accordingly 63 | 64 | cond_set_size += 1 # now go again over all the edges and try to remove using a condition set size +1 65 | 66 | def orient_v_structures(self): 67 | # ToDo: Move this function to the PDAG class 68 | # create a copy of edges 69 | pre_neighbors = dict() 70 | for node in self.graph.nodes_set: 71 | pre_neighbors[node] = self.graph.undirected_neighbors(node).copy() # undirected neighbors pre graph changes 72 | 73 | # check each node if it can serve as new collider for a disjoint neighbors 74 | for node_z in self.graph.nodes_set: 75 | # check undirected neighbors 76 | xy_nodes = pre_neighbors[node_z] # undirected neighbors 77 | for node_x, node_y in combinations(xy_nodes, 2): 78 | if self.graph.is_connected(node_x, node_y): 79 | continue # skip this pair as they are connected 80 | if node_z not in self.sepset.get_sepset(node_x, node_y): 81 | self.graph.orient_edge(source_node=node_x, target_node=node_z) # orient X --> Z 82 | self.graph.orient_edge(source_node=node_y, target_node=node_z) # orient Y --> Z 83 | -------------------------------------------------------------------------------- /causal_discovery_algs/rai.py: -------------------------------------------------------------------------------- 1 | from itertools import combinations, chain 2 | 3 | import graphical_models.arrow_head_types as Mark # incoming arrow head-types 4 | from graphical_models import PDAG 5 | from causal_discovery_utils.constraint_based import LearnStructBase, unique_element_iterator 6 | 7 | 8 | class LearnStructRAI(LearnStructBase): 9 | """ 10 | RAI structure learning algorithm 11 | 12 | Example: 13 | import pandas as pd 14 | pd_data = pd.read_csv('Alarm1_s5000_v1.txt', header=None, sep=' ') 15 | data_alarm = pd_data.values # a 2D numpy vector. len(data)=num. of cases, len(data[0])=num. 
of variables 16 | n_samples, n_vars = data_alarm.shape 17 | ci_test = CondIndepCMI(dataset=data_alarm, threshold=0.012) 18 | nodes = set(range(n_vars)) 19 | rai = LearnStructRAI(nodes_set=nodes, ci_test=ci_test) 20 | """ 21 | def __init__(self, nodes_set, ci_test): 22 | super().__init__(PDAG, nodes_set=nodes_set, ci_test=ci_test) 23 | self.graph.create_complete_graph(nodes_set) # Create a fully connected graph 24 | self.overwrite_starting_graph = True # if True, the sequence at which the CIs are tested affects the result 25 | 26 | def learn_structure(self): 27 | """ 28 | Learn a CPDAG (completed partially directed graph) using the recursive autonomy identification (RAI) algorithm 29 | :return: 30 | """ 31 | # initialize for the 1st recursive call 32 | en_nodes = self.graph.nodes_set 33 | ex_nodes = set() 34 | 35 | self._learn_recursively(en_nodes=en_nodes, ex_nodes=ex_nodes, order=0) 36 | self.graph.maximally_orient_pattern((1, 2, 3)) 37 | 38 | self._re_orient_skeleton() 39 | 40 | def _re_orient_skeleton(self): 41 | """ 42 | Remove all edge directions and re-orient using rules R1, R2, R3 43 | :return: 44 | """ 45 | cpdag_final = PDAG(self.graph.nodes_set) 46 | for node in self.graph.nodes_set: 47 | connected_nodes_set = self.graph.parents(node) | self.graph.undirected_neighbors(node) 48 | cpdag_final.add_edges(parents_set=connected_nodes_set, target_node=node, arrowhead_type=Mark.Undirected) 49 | 50 | for node in self.graph.nodes_set: 51 | parents_set = self.graph.parents(node) 52 | for (parent_i, parent_j) in combinations(parents_set, 2): 53 | if not self.graph.is_connected(parent_i, parent_j): 54 | cpdag_final.orient_edge(source_node=parent_i, target_node=node) # orient v-structure 55 | cpdag_final.orient_edge(source_node=parent_j, target_node=node) 56 | 57 | cpdag_final.maximally_orient_pattern((1, 2, 3)) # use orientation rules R1, R2, and R3 58 | self.graph = cpdag_final 59 | 60 | def _exit_cond(self, en_nodes, order): 61 | """ 62 | Check if the max fan-in is lower than or equal to the order (exit-cond. is met) 63 | :param en_nodes: nodes of the sub-graph 64 | :param order: condition set size of the CI-test 65 | :return: True if exit condition is met 66 | """ 67 | for node in en_nodes: 68 | if self.graph.fan_in(node) > order: # if a node has a large enough number of parents, the exit cond. is false 69 | return False 70 | else: 71 | return True # didn't find a node with a large enough number of parents for CI test, so exit 72 | 73 | def _learn_recursively(self, en_nodes, ex_nodes, order): 74 | """ 75 | The following steps are performed: 76 | 1. Refine and orient using CI tests of a specific condition set size 77 | a. Test CI between exogenous nodes and endogenous nodes, remove edges, and orient edges 78 | b. Test CI among the endogenous nodes, remove edges, and orient 79 | 2. Identify ancestors and descendant groups 80 | 3.
Call recursively for the ancestor and descendant groups with CI order+1 81 | :param en_nodes: Endogenous nodes 82 | :param ex_nodes: Exogenous nodes 83 | :param order: CI test order, i.e., condition set size 84 | :return: 85 | """ 86 | 87 | # test exit condition 88 | if self._exit_cond(en_nodes, order): 89 | return 90 | 91 | # refine and orient with condition set size equal to "order" 92 | self._refine_and_orient(en_nodes=en_nodes, ex_nodes=ex_nodes, order=order, cond_indep=self.ci_test.cond_indep) 93 | 94 | # split into ancestors/descendant autonomous sub-graphs 95 | d_nodes, list_of_ancestors_sets, a_nodes = self._split_ancestors_descendant(en_nodes=en_nodes) 96 | 97 | # Recursive calls for each autonomous sub-graph 98 | for ancestor_set in list_of_ancestors_sets: 99 | self._learn_recursively(en_nodes=ancestor_set, ex_nodes=ex_nodes, 100 | order=order+1) # recursive call (ancestor) 101 | self._learn_recursively(en_nodes=d_nodes, ex_nodes=a_nodes | ex_nodes, 102 | order=order+1) # recursive call (descendant) 103 | 104 | def _refine_and_orient(self, en_nodes, ex_nodes, order, cond_indep=None): 105 | """ 106 | Refine by removing edges between nodes that are conditionally independent given some set. 107 | Note: This is the core element in the RAI and BRAI algorithms, and it is called recursively. 108 | :param en_nodes: Endogenous nodes 109 | :param ex_nodes: Exogenous nodes 110 | :param order: CI test order (condition set size) 111 | :param cond_indep: CI test 112 | :return: 113 | """ 114 | self._refine_exogenous_effect(en_nodes=en_nodes, ex_nodes=ex_nodes, order=order, cond_indep=cond_indep) 115 | self.maximally_orient_edges(en_nodes=en_nodes) 116 | self._refine_endogenous(en_nodes=en_nodes, order=order, cond_indep=cond_indep) 117 | self.maximally_orient_edges(en_nodes=en_nodes) 118 | 119 | def _split_ancestors_descendant(self, en_nodes): 120 | """ 121 | Split the nodes into a descendant nodes-set and a list of disconnected ancestor node-sub-sets 122 | :param en_nodes: set of nodes to split 123 | :return: descendant set, list of ancestors sub-sets, set of all ancestor nodes 124 | """ 125 | d_nodes = self._get_lowest_topological_set(en_nodes) 126 | a_nodes = en_nodes - d_nodes # sets of nodes that need to be separated 127 | # list_of_ancestors_sets = self._get_unconnected_subgraphs(all_nodes=a_nodes) # get unconnected ancestor sets 128 | list_of_ancestors_sets = self.graph.find_unconnected_subgraphs(a_nodes) # get unconnected ancestor sets 129 | return d_nodes, list_of_ancestors_sets, a_nodes 130 | 131 | def _refine_exogenous_effect(self, en_nodes, ex_nodes, order, cond_indep=None): 132 | """ 133 | Test each edge from an exogenous node to an endogenous node 134 | :param en_nodes: Endogenous nodes 135 | :param ex_nodes: Exogenous nodes 136 | :param order: CI test order, i.e., condition set size 137 | :param cond_indep: an oracle that answers conditional independence queries 138 | :return: 139 | """ 140 | if cond_indep is None: 141 | cond_indep = self.ci_test.cond_indep 142 | 143 | if self.overwrite_starting_graph: 144 | source_cpdag = self.graph # Not a copy!!!
thus, edge deletions affect consequent CI queries 145 | else: 146 | source_cpdag = self.graph.copy() # slower, but removes the dependence on the sequence of CI testing 147 | 148 | for node in en_nodes: 149 | for ex in ex_nodes: 150 | if not source_cpdag.is_connected(ex, node): 151 | continue 152 | 153 | pot_parents_node = (source_cpdag.parents(node) | source_cpdag.undirected_neighbors(node)) - {ex} 154 | pot_parents_ex = (source_cpdag.parents(ex) | source_cpdag.undirected_neighbors(ex)) - {node} 155 | cond_sets_node = combinations(pot_parents_node, order) 156 | cond_sets_ex = combinations(pot_parents_ex, order) 157 | cond_sets = unique_element_iterator( 158 | chain(cond_sets_node, cond_sets_ex) 159 | ) 160 | for cset in cond_sets: # note that cond_sets is a generator of tuples (not sets) 161 | if cond_indep(ex, node, cset): # CI test: test for conditional independence 162 | self.graph.delete_edge(ex, node) # remove the edge ex --> node 163 | self.sepset.set_sepset(ex, node, cset) 164 | break # stop searching for independence as we found one and updated the graph accordingly 165 | 166 | def _refine_endogenous(self, en_nodes, order, cond_indep=None): 167 | """ 168 | Remove edges between pairs of conditionally independent endogenous variables. Condition set consists of nodes 169 | from endogenous and exogenous nodes and has a specific size. Test edges X --> Y and X --- Y 170 | :param en_nodes: Endogenous nodes 171 | :param order: Condition set size 172 | :param cond_indep: an oracle that answers conditional independence queries 173 | :return: 174 | """ 175 | if cond_indep is None: 176 | cond_indep = self.ci_test.cond_indep 177 | 178 | if self.overwrite_starting_graph: 179 | source_cpdag = self.graph # Not a copy!!! thus, edge deletions affect consequent CI queries 180 | else: 181 | source_cpdag = self.graph.copy() # slower, but removes the dependence on the sequence of CI testing 182 | 183 | for node_i, node_j in combinations(en_nodes, 2): 184 | if not source_cpdag.is_connected(node_i, node_j): 185 | continue 186 | 187 | pot_parents_i = (source_cpdag.parents(node_i) | source_cpdag.undirected_neighbors(node_i)) - {node_j} 188 | pot_parents_j = (source_cpdag.parents(node_j) | source_cpdag.undirected_neighbors(node_j)) - {node_i} 189 | cond_sets_i = combinations(pot_parents_i, order) 190 | cond_sets_j = combinations(pot_parents_j, order) 191 | cond_sets = unique_element_iterator( 192 | chain(cond_sets_i, cond_sets_j) 193 | ) 194 | 195 | for cset in cond_sets: 196 | if cond_indep(node_i, node_j, cset): 197 | self.graph.delete_edge(node_i, node_j) # remove directed/undirected edge 198 | self.sepset.set_sepset(node_i, node_j, cset) 199 | break # stop searching for independence as we found one and updated the graph accordingly 200 | 201 | def _get_lowest_topological_set(self, en_nodes): 202 | """ 203 | Return the set of nodes having the lowest topological order. 204 | In a directed edge, the parent has a higher topological order than the child. 205 | In an undirected edge, both nodes on the end points have equal topological order. 
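Illustrative example (hypothetical edges over nodes A, B, C): given A --> B and B --- C, node A, being a
parent, has a higher topological order, while B and C share the lowest order because the undirected edge
equates them; the expected lowest set is then {B, C}.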
206 | :param en_nodes: the set of endogenous nodes to split 207 | :return: A set of nodes having the lowest topological order (equally) 208 | """ 209 | 210 | #return en_nodes # ToDo: Consider removing this section 211 | 212 | # ToDo: Consider removing this section 213 | if False: 214 | generic_parents = set() # parents of someone in the endogenous set 215 | for node in en_nodes: 216 | generic_parents.update(self.graph.parents(node) & en_nodes) 217 | 218 | leaves = en_nodes - generic_parents 219 | # the "leaves" set may still contain nodes that are adjacent to "generic_parents", rendering them equal topological 220 | # order, thus not a descendant. Remove nodes that have a path to a generic-parent through only undirected edges 221 | lowest_set = set() 222 | for leaf in leaves: 223 | if self.graph.is_reachable_any_undirected(leaf, generic_parents, en_nodes): 224 | continue 225 | else: 226 | lowest_set.add(leaf) 227 | 228 | return lowest_set 229 | 230 | # ToDo: Consider removing this section 231 | ordered_sets = self.graph.find_partial_topological_order(en_nodes) 232 | # return ordered_sets[0] # ToDo: consider returning this 233 | if len(ordered_sets) > 1: # if there are two or more sets 234 | return en_nodes - ordered_sets[-1] # return all the nodes not in the highest order 235 | else: 236 | return ordered_sets[0] # there is only one set (no separation was found) 237 | 238 | def maximally_orient_edges(self, en_nodes): 239 | """ 240 | Maximally orient edges starting anywhere but ending (including undirected) at the endogenous nodes. 241 | First, v-structures are identified and oriented (rule R0). Then, four rules, R1-R4, are repeatedly applied. 242 | :param en_nodes: Endogenous nodes 243 | :return: 244 | """ 245 | # ToDo: Move this to the PAG class 246 | self.orient_v_structures(en_nodes) # [R0]: orient v-structures 247 | 248 | self.graph.maximally_orient_pattern([1, 2, 3, 4]) 249 | 250 | self.graph.convert_bidirected_to_undirected(en_nodes) # treat bi-directed (spurious) as undirected 251 | 252 | def orient_v_structures(self, en_nodes): 253 | """ 254 | Orient edges starting anywhere but ending (including undirected) at the endogenous nodes. 255 | Note that in our case, the situation X --- Z <-- Y can happen and node Z is tested if it can be a collider.
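Illustrative example: if X --- Z --- Y with X and Y disconnected and Z not in sepset(X, Y), both edges are
oriented into Z, yielding X --> Z <-- Y; in the mixed case X --- Z <-- Y, only the undirected edge is
oriented, giving X --> Z.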
256 | :param en_nodes: The set of nodes from which to search for a collider 257 | :return: 258 | """ 259 | # ToDo: Move this to the PAG class 260 | # create a copy of edges 261 | pre_neighbors = dict() 262 | pre_parents = dict() 263 | for node in self.graph.nodes_set: 264 | pre_neighbors[node] = self.graph.undirected_neighbors(node).copy() # undirected neighbors pre graph changes 265 | 266 | # check each node if it can serve as new collider for a disjoint neighbors 267 | for node_z in en_nodes: 268 | # check undirected neighbors 269 | xy_nodes = pre_neighbors[node_z] # undirected neighbors 270 | for node_x, node_y in combinations(xy_nodes, 2): 271 | if self.graph.is_connected(node_x, node_y): 272 | continue # skip this pair as they are connected 273 | if node_z not in self.sepset.get_sepset(node_x, node_y): 274 | self.graph.orient_edge(source_node=node_x, target_node=node_z) # orient X --> Z 275 | self.graph.orient_edge(source_node=node_y, target_node=node_z) # orient Y --> Z 276 | 277 | for node in self.graph.nodes_set: 278 | pre_neighbors[node] = self.graph.undirected_neighbors(node).copy() # undirected neighbors pre graph changes 279 | pre_parents[node] = self.graph.parents(node).copy() # undirected parents pre graph changes 280 | 281 | for node_z in en_nodes: 282 | # check the case when one of the neighbors is already a parent 283 | parents_z = pre_parents[node_z] # parents before the graph was modified 284 | neighbors_z = pre_neighbors[node_z] # remaining undirected edges X --- Z 285 | for node_x in neighbors_z: 286 | for node_y in parents_z: 287 | if self.graph.is_connected(node_x, node_y): 288 | continue # skip this pair as they are connected 289 | if node_z not in self.sepset.get_sepset(node_x, node_y): 290 | self.graph.orient_edge(source_node=node_x, target_node=node_z) # orient (only) X --> Z 291 | -------------------------------------------------------------------------------- /causal_discovery_algs/ts_icd.py: -------------------------------------------------------------------------------- 1 | from causal_discovery_algs.icd import LearnStructICD, unique_element_iterator 2 | from graphical_models import PAG, arrow_head_types as Mark 3 | from itertools import product, combinations, chain 4 | 5 | 6 | class LearnStructTSICD(LearnStructICD): 7 | def __init__(self, nodes_sets_list, ci_test, is_homology=True, initial_pag: PAG = None, is_pre_calc_cond_set=False, 8 | is_selection_bias=False, is_tail_completeness=True): 9 | """ 10 | Initialization: create a complete PAG with o--> edges between past-future and o--o between contemporaneous nodes 11 | :param nodes_sets_list: list of list of nodes. 12 | first sub-list is the present, second sub-list is one time-step back and so on. 13 | That is: [ [node names at t=0], [node names at t=-1], ..., [node names at t=-k], where k is max-lag value. 14 | the order in the sub-lists matters. nodes at the same location in different sub-lists correspond to the 15 | same node at different time step. 16 | :param ci_test: a ci test that handles time-series data 17 | :param is_homology: if True then enforce homology, otherwise do not modify past-past edges (default: True) 18 | :param initial_pag: if provided, use this graph as the initial graph instead of a complete graph 19 | :param is_pre_calc_cond_set: if True, conditioning sets are precalculated at the beginning (default: False) 20 | :param is_selection_bias: is selection bias possibly present? 
(default: False) 21 | :param is_tail_completeness: should the algorithm apply rules R8, R9, R10 from Zhang 2008b (default: True) 22 | """ 23 | win_len = len(nodes_sets_list) # time steps: t=0, t=-1, ..., t=-(win_len-1) 24 | assert win_len > 1 # at least two time-steps are expected 25 | 26 | num_nodes_t0 = len(nodes_sets_list[0]) # number of nodes in a single time-step 27 | assert all((isinstance(nodes_t, list) and len(nodes_t) == num_nodes_t0) 28 | for nodes_t in nodes_sets_list) # check equal length of sub-lists, and that they are list types 29 | 30 | nodes_set = {node for nodes_t_i in nodes_sets_list for node in nodes_t_i} # unwrap all the nodes into one set 31 | assert len(nodes_set) == win_len * num_nodes_t0 # duplicate node names are not allowed 32 | 33 | super().__init__(nodes_set=nodes_set, ci_test=ci_test, is_pre_calc_cond_set=is_pre_calc_cond_set, 34 | is_selection_bias=is_selection_bias, is_tail_completeness=is_tail_completeness) 35 | 36 | # initialize the graph: o--> edges between nodes in different time-steps, and o--o between contemporaneous nodes 37 | if initial_pag is None: # create an initial graph 38 | initial_pag = PAG(nodes_set) 39 | initial_pag.create_complete_graph(Mark.Circle, nodes_set) 40 | past_future_idx = [(t_past, t_future) for t_future in range(win_len - 1) 41 | for t_past in range(t_future + 1, win_len)] 42 | for past_idx, future_idx in past_future_idx: 43 | for node_past, node_future in product(nodes_sets_list[past_idx], nodes_sets_list[future_idx]): 44 | # Note: past instance t-1 has larger indexes than future instance t 45 | initial_pag.replace_edge_mark(node_source=node_past, node_target=node_future, 46 | requested_edge_mark=Mark.Directed) # past o--> future (may become bidirected) 47 | 48 | # copy initial graph and store fixed orientations 49 | assert isinstance(initial_pag, PAG) 50 | self.graph.create_empty_graph() 51 | for node_i, node_j in combinations(initial_pag.nodes_set, 2): # copy graph.
ToDo: consider adding to PAG class 52 | if initial_pag.is_connected(node_i, node_j): 53 | self.graph.add_edge(node_i, node_j, 54 | initial_pag.get_edge_mark(node_j, node_i), 55 | initial_pag.get_edge_mark(node_i, node_j)) 56 | 57 | self.fixed_orientations = dict() 58 | for node_i, node_j in combinations(initial_pag.nodes_set, 2): # store fixed orientations 59 | if initial_pag.is_connected(node_i, node_j): 60 | self.fixed_orientations[(node_i, node_j)] = initial_pag.get_edge_mark(node_i, node_j) 61 | self.fixed_orientations[(node_j, node_i)] = initial_pag.get_edge_mark(node_j, node_i) 62 | 63 | self.is_homology = is_homology 64 | self.nodes_sets_list = nodes_sets_list 65 | self.time_window_len = win_len 66 | self.num_nodes_t = num_nodes_t0 67 | 68 | # create a dictionary for time-stamps of each node 69 | self.node2time = {node: time_idx for time_idx, nodes_t_i in enumerate(nodes_sets_list) for node in nodes_t_i} 70 | 71 | # create a dictionary for node identity that is encoded in the index in the sub-list 72 | self.node2idx = {node: node_idx for nodes_t_i in nodes_sets_list for node_idx, node in enumerate(nodes_t_i)} 73 | 74 | # add nodes' dictionaries to the PAG instance (too) 75 | self.graph.node2time = self.node2time.copy() 76 | self.graph.node2idx = self.node2idx.copy() 77 | 78 | # create iterators over nodes 79 | past_nodes_iter = (past_node for time_past in range(1, win_len) for past_node in nodes_sets_list[time_past]) 80 | present_nodes = nodes_sets_list[0] 81 | self.possible_cross_lag_edges = list(product(iter(present_nodes), past_nodes_iter)) # past-present node pairs 82 | self.possible_contemporaneous_edges = list(combinations(present_nodes, 2)) # present-time pairs-of-nodes 83 | 84 | def set_fixed_orientations(self): 85 | for node_i, node_j in combinations(self.graph.nodes_set, 2): 86 | if self.graph.is_connected(node_i, node_j): 87 | self.graph.replace_edge_mark(node_source=node_i, node_target=node_j, 88 | requested_edge_mark=self.fixed_orientations[(node_i, node_j)]) 89 | self.graph.replace_edge_mark(node_source=node_j, node_target=node_i, 90 | requested_edge_mark=self.fixed_orientations[(node_j, node_i)]) 91 | 92 | def get_edge_homology(self, node_0, node_1) -> list: 93 | """ 94 | A homology map finds all the node-pairs that are homologous to a given pair 95 | :param node_0: node at present time 96 | :param node_1: node in present time (contemporaneous) or in the past (cross-lag) 97 | :return: a list of node-pairs homologous to (node_0, node_1) 98 | """ 99 | assert self.node2time[node_0] == 0 # node_0 must be from the current time-stamp 100 | time_1 = self.node2time[node_1] 101 | node_1_idx = self.node2idx[node_1] 102 | node_0_idx = self.node2idx[node_0] # index of the node in the present (t=0) 103 | homology = [] 104 | # loop through new time indexes for node_0 and node_1 (iteratively move one step back in time) 105 | for time_0_new, time_1_new in zip(range(1, self.time_window_len - time_1), range(time_1 + 1, self.time_window_len)): 106 | node_0_new = self._get_node_name(time_0_new, node_0_idx) 107 | node_1_new = self._get_node_name(time_1_new, node_1_idx) 108 | homology.append((node_0_new, node_1_new)) 109 | return homology 110 | 111 | def get_ci_homology(self, node_0, node_1, conditioning_tuple) -> list: 112 | """ 113 | A homology map that finds all the node-pairs and conditioning sets that are homologous. 114 | Homology is searched for only in those cases where the node-pair and the conditioning set are all inside the time window. 115 | This results in cases where nodes in the past are not returned as conditionally independent.
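Illustrative example (assuming a window over t=0, t=-1, t=-2): if X(t=0) and Y(t=-1) are found independent
given {Z(t=0)}, the homologous triplet is (X(t=-1), Y(t=-2), {Z(t=-1)}); shifting one more step back would
leave the window, so no further triplets are returned.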
116 | :param node_0: node at present time 117 | :param node_1: node in present time (contemporaneous) or in the past (cross-lag) 118 | :param conditioning_tuple: conditioning set in present time or in the past 119 | :return: a list of tuples, where each tuple corresponds to one homologous triplet (node, node, conditioning set) 120 | """ 121 | assert self.node2time[node_0] == 0 122 | time_1 = self.node2time[node_1] 123 | time_conditioning = [self.node2time[conditioning_node] for conditioning_node in conditioning_tuple] 124 | node_0_idx = self.node2idx[node_0] # index of the node in the present (t=0) 125 | node_1_idx = self.node2idx[node_1] 126 | conditioning_idx = [self.node2idx[conditioning_node] for conditioning_node in conditioning_tuple] 127 | past_time = max(max(time_conditioning), time_1) # furthest back in time. TODO: should be only time_1 128 | homology = [] 129 | for idx, _ in enumerate(range(past_time+1, self.time_window_len)): 130 | node_0_new = self._get_node_name(1+idx, node_0_idx) 131 | node_1_new = self._get_node_name(time_1+1+idx, node_1_idx) 132 | conditioning_tuple_new = tuple(self._get_node_name(node_time+1+idx, node_idx) 133 | for node_time, node_idx in zip(time_conditioning, conditioning_idx)) 134 | homology.append((node_0_new, node_1_new, conditioning_tuple_new)) 135 | return homology 136 | 137 | def _get_node_name(self, time_idx, node_idx): 138 | return self.nodes_sets_list[time_idx][node_idx] 139 | 140 | def _pre_calc_conditioning(self, cond_set_size): 141 | self._pre_calc_conditioning_for_edges(cond_set_size, self.possible_contemporaneous_edges) 142 | self._pre_calc_conditioning_for_edges(cond_set_size, self.possible_cross_lag_edges) 143 | 144 | def _pre_calc_conditioning_for_edges(self, cond_set_size, possible_edges): # ToDo: consider moving to ICD class 145 | for node_i, node_j in possible_edges: 146 | if self.graph.is_connected(node_i, node_j): 147 | self.conditioning_set[self.edge_key(node_i, node_j)] = self._get_pdsep_range_sets( 148 | node_i, node_j, cond_set_size) 149 | 150 | def _learn_struct_incremental_step(self, cond_set_size=None): 151 | """ 152 | Learn a single increment, a single ICD step. This treats the generic case for conditioning set sizes >= 2. 153 | :param cond_set_size: create a list of possible conditioning sets of this size, taking into account the 154 | removal of previous edges during this step. 
Ignored if class-member 'pre_calc_pds' is True 155 | :return: True if the resulting PAG is completed (no more edges can be removed) 156 | """ 157 | if cond_set_size is None: 158 | assert self.is_pre_calc_pds is True 159 | 160 | # reduce exogenous 161 | done1 = self._learn_struct_incremental_for_edges(cond_set_size, self.possible_cross_lag_edges) 162 | # reduce contemporaneous 163 | done2 = self._learn_struct_incremental_for_edges(cond_set_size, self.possible_contemporaneous_edges) 164 | done = done1 and done2 165 | 166 | # Orient edges 167 | self.orient_graph(complete_orientation=done) 168 | return done 169 | 170 | def _learn_struct_incremental_for_edges(self, cond_set_size, possible_edges): 171 | done = True 172 | for node_i, node_j in possible_edges: 173 | if not self.graph.is_connected(node_i, node_j): 174 | continue 175 | 176 | if self.is_pre_calc_pds: 177 | cond_sets = self.conditioning_set[self.edge_key(node_i, node_j)] 178 | else: 179 | cond_sets = self._get_pdsep_range_sets(node_i, node_j, cond_set_size) 180 | 181 | for cond in cond_sets: 182 | done = False # reset 'done' signaling to continue to the next ICD-iteration after the current one 183 | cond_set = cond[0] # get the set of nodes (in [1] there is the sum-of-minimal-distances) 184 | cond_tup = tuple(cond_set) 185 | is_edge_removed = self._test_edge_and_remove_in_homology(node_i, node_j, cond_tup) 186 | if is_edge_removed: 187 | break # stop searching for independence as we found one and updated the graph accordingly 188 | return done 189 | 190 | def _test_edge_and_remove_in_homology(self, node_i, node_j, cond_tup) -> bool: 191 | if self.ci_test.cond_indep(node_i, node_j, cond_tup): 192 | self.graph.delete_edge(node_i, node_j) # remove directed/undirected edge 193 | self.sepset.set_sepset(node_i, node_j, cond_tup) 194 | if self.is_homology: 195 | if len(cond_tup) == 0: 196 | homology = self.get_edge_homology(node_0=node_i, node_1=node_j) 197 | for (node_0, node_1) in homology: 198 | self.graph.delete_edge(node_0, node_1) 199 | self.sepset.set_sepset(node_0, node_1, cond_tup) 200 | else: 201 | homology = self.get_ci_homology(node_0=node_i, node_1=node_j, conditioning_tuple=cond_tup) 202 | for (node_0, node_1, cond_tup_new) in homology: 203 | self.graph.delete_edge(node_0, node_1) 204 | self.sepset.set_sepset(node_0, node_1, cond_tup_new) 205 | return True # edge was removed! (independence was found!) 206 | else: 207 | return False # edge was not removed (independence was not found) 208 | 209 | def _learn_struct_base_step_0(self): 210 | """ 211 | Execute ICD iteration with r = 0. That is, test unconditional independence between every pair of nodes and 212 | remove corresponding edges. Then, orient the graph. The result is a 0-representing PAG. 213 | 214 | :return: 215 | """ 216 | # r = 0: unconditional (marginal) independence tests 217 | for node_i, node_j in self.possible_cross_lag_edges: 218 | self._test_edge_and_remove_in_homology(node_i, node_j, ()) 219 | 220 | for node_i, node_j in self.possible_contemporaneous_edges: 221 | self._test_edge_and_remove_in_homology(node_i, node_j, ()) 222 | 223 | self.orient_graph(complete_orientation=False) 224 | 225 | def _learn_struct_base_step_1(self): 226 | """ 227 | Execute ICD iteration with r = 1. That is, test independence between every pair of nodes conditioned on a single 228 | node, and remove corresponding edges. Then, orient the graph. The result is a 1-representing PAG. 
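Illustrative example: for a connected pair (X, Y), the candidate conditioning sets are single nodes adjacent
to X or to Y (excluding X and Y themselves); the first node Z for which the CI test declares X and Y
independent removes the edge, records sepset(X, Y) = {Z}, and, when homology is enforced, also removes the
homologous edges.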
229 | 230 | :return: True if done and no more iterations are required; otherwise False indicating the PAG is not completed. 231 | """ 232 | source_cpdag = self.graph # Not a copy!!! Thus, edge deletions affect consequent CI queries 233 | 234 | # r = 1: conditional independence tests order 1 235 | cond_set_size = 1 236 | done = True 237 | 238 | # Test Cross-Lag Edges 239 | for node_i, node_j in self.possible_cross_lag_edges: 240 | if not source_cpdag.is_connected(node_i, node_j): 241 | continue 242 | 243 | pot_parents_i = self.graph.find_adjacent_nodes(node_i) - {node_j} 244 | pot_parents_j = self.graph.find_adjacent_nodes(node_j) - {node_i} 245 | 246 | cond_sets_i = combinations(pot_parents_i, cond_set_size) 247 | cond_sets_j = combinations(pot_parents_j, cond_set_size) 248 | cond_sets = unique_element_iterator( # unique of 249 | chain(cond_sets_i, cond_sets_j) # neighbors of node_i OR neighbors of node_j 250 | ) 251 | 252 | for cond_set in cond_sets: 253 | done = False 254 | is_edge_removed = self._test_edge_and_remove_in_homology(node_i, node_j, cond_set) 255 | if is_edge_removed: 256 | break # stop searching for independence as we found one and updated the graph accordingly 257 | 258 | # Test Contemporaneous Edges 259 | for node_i, node_j in self.possible_contemporaneous_edges: 260 | if not source_cpdag.is_connected(node_i, node_j): 261 | continue 262 | 263 | pot_parents_i = self.graph.find_adjacent_nodes(node_i) - {node_j} 264 | pot_parents_j = self.graph.find_adjacent_nodes(node_j) - {node_i} 265 | 266 | cond_sets_i = combinations(pot_parents_i, cond_set_size) 267 | cond_sets_j = combinations(pot_parents_j, cond_set_size) 268 | cond_sets = unique_element_iterator( # unique of 269 | chain(cond_sets_i, cond_sets_j) # neighbors of node_i OR neighbors of node_j 270 | ) 271 | 272 | for cond_set in cond_sets: 273 | done = False 274 | is_edge_removed = self._test_edge_and_remove_in_homology(node_i, node_j, cond_set) 275 | if is_edge_removed: 276 | break # stop searching for independence as we found one and updated the graph accordingly 277 | 278 | self.orient_graph(complete_orientation=done) 279 | return done 280 | 281 | def orient_graph(self, complete_orientation=True): 282 | self.graph.reset_orientations(default_mark=Mark.Circle) 283 | self.set_fixed_orientations() 284 | self.graph.orient_v_structures(self.sepset) 285 | self._fill_orientation_homology() 286 | self.graph.maximally_orient_pattern(rules_set=[1, 2, 3, 4]) 287 | if complete_orientation: 288 | if self.is_selection_bias: 289 | self.graph.maximally_orient_pattern(rules_set=[5, 6, 7]) # when selection-bias may be present 290 | if self.is_tail_completeness: 291 | self.graph.maximally_orient_pattern(rules_set=[8, 9, 10]) # for tail-completeness 292 | self._fill_orientation_homology() 293 | 294 | def _fill_orientation_homology(self): 295 | """ 296 | Two consecutive stages: 297 | 1. Create a list of invariant edge-marks (head or tail) 298 | 2. 
For each edge-mark set edge-marks throughout the homology 299 | :return: 300 | """ 301 | # Step 1: Create a list of invariant edge-marks (head or tail) 302 | invariant_marks = [] # list of invariant edge marks in the graph 303 | for node1, node2 in combinations(self.graph.nodes_set, 2): 304 | if not self.graph.is_connected(node1, node2): 305 | continue 306 | edge_mark12 = self.graph.get_edge_mark(node1, node2) 307 | edge_mark21 = self.graph.get_edge_mark(node2, node1) 308 | if edge_mark12 != Mark.Circle: 309 | invariant_marks.append((node1, node2, edge_mark12)) 310 | if edge_mark21 != Mark.Circle: 311 | invariant_marks.append((node2, node1, edge_mark21)) 312 | 313 | # Step 2: For each edge-mark set edge-marks throughout the homology 314 | for parent, child, edge_mark in invariant_marks: 315 | parent_time = self.node2time[parent] 316 | child_time = self.node2time[child] 317 | min_time = min(parent_time, child_time) 318 | parent_base = self.nodes_sets_list[parent_time - min_time][self.node2idx[parent]] 319 | child_base = self.nodes_sets_list[child_time - min_time][self.node2idx[child]] 320 | if self.node2time[child_base] == 0: 321 | edge_homology = self.get_edge_homology(child_base, parent_base) 322 | elif self.node2time[parent_base] == 0: 323 | edge_homology = self.get_edge_homology(parent_base, child_base) 324 | edge_homology = [(b, a) for (a, b) in edge_homology] # reverse a,b 325 | else: 326 | raise RuntimeError("unexpected error") 327 | edge_homology.append((child_base, parent_base)) 328 | 329 | for (child_node, parent_node) in edge_homology: 330 | self.graph.replace_edge_mark(node_source=parent_node, node_target=child_node, 331 | requested_edge_mark=edge_mark) 332 | -------------------------------------------------------------------------------- /causal_discovery_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/causality-lab/8ad638a2057e3bdf35108b6e63f547dd7f6a95a1/causal_discovery_utils/__init__.py -------------------------------------------------------------------------------- /causal_discovery_utils/cond_indep_tests.py: -------------------------------------------------------------------------------- 1 | # This file contains conditional independence tests 2 | 3 | import math 4 | import numpy as np 5 | from causal_discovery_utils.data_utils import calc_stats 6 | from causal_discovery_utils.data_utils import get_var_size 7 | from graphical_models import DAG, UndirectedGraph, PAG 8 | from scipy import stats 9 | 10 | 11 | class CacheCI: 12 | """ 13 | A cache for CI tests. 14 | """ 15 | def __init__(self, num_vars=None): 16 | """ 17 | Initialize cache 18 | :param num_vars: Number of variables; if None, cache is not initialized 19 | """ 20 | if num_vars is None: 21 | self._cache = None 22 | else: 23 | self._cache = dict() 24 | # for each pair create a dictionary that holds the cached ci test. 
The sorted condition set is the hash key 25 | for i in range(num_vars - 1): 26 | for j in range(i + 1, num_vars): 27 | hkey, _ = self.get_hkeys(i, j, ()) # get a key for the (i, j) pair (simply order them) 28 | self._cache[hkey] = dict() 29 | 30 | def get_hkeys(self, x, y, zz): 31 | """ 32 | Return keys for hashing the variable-pair and the condition set 33 | :param x: 1st variable 34 | :param y: 2nd variable 35 | :param zz: Set of variables that constitutes the condition set 36 | :return: a key for the variable-pair and a key for the condition set 37 | """ 38 | hkey = (x, y) if x < y else (y, x) 39 | hkey_cond_set = tuple(sorted(zz)) 40 | return hkey, hkey_cond_set 41 | 42 | def set_cache_result(self, x, y, zz, res): 43 | """ 44 | Set (overriding any previous value) a result to be cached 45 | :param x: 1st variable 46 | :param y: 2nd variable 47 | :param zz: Variables that constitute the condition set 48 | :param res: Result to be cached 49 | :return: 50 | """ 51 | assert self._cache is not None 52 | 53 | hkey, hkey_cond_set = self.get_hkeys(x, y, zz) # get keys for hashing 54 | self._cache[hkey][hkey_cond_set] = res # cache, override previous result 55 | 56 | def get_cache_result(self, x, y, zz): 57 | """ 58 | Get previously cached result 59 | :param x: 1st variable 60 | :param y: 2nd variable 61 | :param zz: Variables that constitute the condition set 62 | :return: Cached result. None if nothing was cached previously 63 | """ 64 | if self._cache is None: # was the cache data structure initialized? 65 | return None 66 | 67 | hkey, hkey_cond_set = self.get_hkeys(x, y, zz) 68 | 69 | if hkey not in self._cache.keys(): # check if variable-pair cache data structure was created 70 | return None 71 | 72 | if hkey_cond_set not in self._cache[hkey].keys(): # check if the result was ever cached 73 | return None 74 | 75 | return self._cache[hkey][hkey_cond_set] 76 | 77 | def del_cache(self, x, y, zz): 78 | """ 79 | Remove a cached entry. 80 | :param x: 1st variable 81 | :param y: 2nd variable 82 | :param zz: Variables that constitute the condition set 83 | :return: Cached result that was deleted 84 | """ 85 | if self._cache is None: # was the cache data structure initialized?
86 | return None 87 | 88 | hkey, hkey_cond_set = self.get_hkeys(x, y, zz) 89 | 90 | if hkey not in self._cache.keys(): # check if variable-pair cache data structure was created 91 | return None 92 | 93 | if hkey_cond_set not in self._cache[hkey].keys(): # check is result was ever cached 94 | return None 95 | 96 | return self._cache[hkey].pop(hkey_cond_set) 97 | 98 | 99 | class DSep: 100 | """ 101 | An optimal CI oracle that uses the true DAG and returns d-separation result 102 | """ 103 | def __init__(self, true_dag: DAG, count_tests=False, use_cache=False, verbose=False): 104 | assert isinstance(true_dag, DAG) 105 | self.true_dag = true_dag 106 | 107 | self.verbose = verbose 108 | 109 | num_nodes = len(true_dag.nodes_set) 110 | 111 | self.count_tests = count_tests 112 | if count_tests: 113 | self.test_counter = [0 for _ in range(num_nodes-1)] 114 | else: 115 | self.test_counter = None 116 | 117 | self.is_cache = use_cache 118 | if use_cache: 119 | self.cache_ci = CacheCI(num_nodes) 120 | else: 121 | self.cache_ci = CacheCI(None) 122 | 123 | def cond_indep(self, x, y, zz): 124 | res = self.cache_ci.get_cache_result(x, y, zz) 125 | 126 | if res is None: 127 | res = self.true_dag.dsep(x, y, zz) 128 | if self.verbose: 129 | print('d-sep(', x, ',', y, '|', zz, ')', '=', res) 130 | if self.is_cache: 131 | self.cache_ci.set_cache_result(x, y, zz, res) 132 | if self.count_tests: 133 | self.test_counter[len(zz)] += 1 # update counter only if the test was not previously cached 134 | return res 135 | 136 | 137 | class GraphCondIndep: 138 | """ 139 | GraphCondIndep: a CI test that derive its result from a given graph. 140 | Depending on the graph type, an appropriate criterion is used: 141 | DAG type: d-separation criterion 142 | PAG type: m-separation criterion 143 | """ 144 | def __init__(self, reference_graph, static_conditioning=None, count_tests=False, use_cache=False, verbose=False): 145 | """ 146 | Initialize GraphCondIndep, a CI test that derive its result from a given graph. 147 | 148 | :param reference_graph: a graph from which independence relations are inferred. Only DAG and PAG are supported. 149 | :param static_conditioning: a set of nodes that will always be included in the conditioning set. 150 | :param count_tests: if True, count the number of CI test queries (default: False). Mainly for debug 151 | :param use_cache: if True, cache CI tests' results (default: False). Used for avoiding redundant CI tests. 152 | :param verbose: Verbose flag (default: False). 
Mainly for debug 153 | """ 154 | self.reference_graph = reference_graph 155 | self.verbose = verbose 156 | 157 | if type(reference_graph) == DAG: 158 | self.ci_criterion = reference_graph.dsep 159 | elif type(reference_graph) == PAG: 160 | self.ci_criterion = reference_graph.is_m_separated 161 | else: 162 | raise TypeError('Unsupported graph type.') 163 | 164 | if static_conditioning is None or type(static_conditioning) == tuple: 165 | self.static_conditioning = static_conditioning 166 | else: 167 | raise TypeError('Static conditioning, if defined, should be a tuple.') 168 | 169 | num_nodes = len(reference_graph.nodes_set) 170 | self.count_tests = count_tests 171 | if count_tests: 172 | self.test_counter = [0 for _ in range(num_nodes - 1)] 173 | else: 174 | self.test_counter = None 175 | 176 | self.is_cache = use_cache 177 | if use_cache: 178 | self.cache_ci = CacheCI(num_nodes) 179 | else: 180 | self.cache_ci = CacheCI(None) 181 | 182 | def cond_indep(self, x, y, zz_conditioning): 183 | if self.static_conditioning is None: 184 | zz = zz_conditioning 185 | else: 186 | zz = tuple(set(zz_conditioning + self.static_conditioning)) 187 | 188 | res = self.cache_ci.get_cache_result(x, y, zz) 189 | 190 | if res is None: 191 | res = self.ci_criterion(x, y, zz) 192 | if self.verbose: 193 | print(self.ci_criterion.__name__, '(', x, ',', y, '|', zz, ')', '=', res) 194 | if self.is_cache: 195 | self.cache_ci.set_cache_result(x, y, zz, res) 196 | if self.count_tests: 197 | self.test_counter[len(zz)] += 1 # update counter only if the test was not previously cached 198 | return res 199 | 200 | 201 | class StatCondIndep: 202 | def __init__(self, 203 | dataset, threshold, database_type, weights=None, 204 | retained_edges=None, count_tests=False, use_cache=False, verbose=False, 205 | num_records=None, num_vars=None): 206 | """ 207 | Base class for statistical conditional independence tests 208 | :param dataset: 209 | :param threshold: 210 | :param database_type: data type (e,g., int) 211 | :param weights: an array of values indicating weight of each individual data sample 212 | :param retained_edges: an undirected graph containing edges between nodes that are dependent (not to be tested) 213 | :param count_tests: if True, count the number of CI test queries (default: False). 
Mainly for debug 214 | """ 215 | self.verbose = verbose 216 | 217 | if dataset is not None: 218 | assert num_records is None and num_vars is None 219 | data = np.array(dataset, dtype=database_type) 220 | num_records, num_vars = data.shape 221 | else: 222 | data = None 223 | assert num_records is not None and num_records > 0 224 | assert num_vars is not None and num_vars > 0 225 | 226 | if retained_edges is None: 227 | self.retained_graph = UndirectedGraph(set(range(num_vars))) 228 | self.retained_graph.create_empty_graph() 229 | else: 230 | self.retained_graph = retained_edges 231 | 232 | node_size = None 233 | if data is not None: 234 | node_size = get_var_size(data) 235 | 236 | self.data = data 237 | self.num_records = num_records 238 | self.num_vars = num_vars 239 | self.node_size = node_size 240 | self.threshold = threshold 241 | self.weights = weights 242 | 243 | # Initialize counter of CI tests per conditioning set size 244 | self.count_tests = count_tests 245 | if count_tests: 246 | self.test_counter = [0 for _ in range(num_vars-1)] 247 | else: 248 | self.test_counter = None 249 | 250 | # Initialize cache 251 | self.is_cache = use_cache 252 | if use_cache: 253 | self.cache_ci = CacheCI(num_vars) 254 | else: 255 | self.cache_ci = CacheCI(None) 256 | 257 | def cond_indep(self, x, y, zz): 258 | if self.is_edge_retained(x, y): 259 | return False # do not test and return: "not independent" 260 | 261 | statistic = self.cache_ci.get_cache_result(x, y, zz) 262 | 263 | if statistic is None: 264 | statistic = self.calc_statistic(x, y, zz) # calculate correlation level 265 | self._debug_process(x, y, zz, statistic) 266 | self._cache_it(x, y, zz, statistic) 267 | 268 | res = statistic > self.threshold # test if p-value is greater than the threshold 269 | return res 270 | 271 | def calc_statistic(self, y, x, zz): 272 | return None # you must override this function in inherited classes 273 | 274 | def _debug_process(self, x, y, zz, res): 275 | """ 276 | Handles all tasks required for debug 277 | """ 278 | if self.verbose: 279 | print('Test: ', 'CI(', x, ',', y, '|', zz, ')', '=', res) 280 | if self.count_tests: 281 | self.test_counter[len(zz)] += 1 282 | 283 | def _cache_it(self, x, y, zz, res): 284 | """ 285 | Handles all task required after calculating the CI statistic 286 | """ 287 | if self.is_cache and (res is not None): 288 | self.cache_ci.set_cache_result(x, y, zz, res) 289 | 290 | def is_edge_retained(self, x, y): 291 | return self.retained_graph.is_connected(x, y) 292 | 293 | 294 | class CondIndepParCorr(StatCondIndep): 295 | def __init__(self, threshold, dataset, weights=None, retained_edges=None, count_tests=False, use_cache=False, 296 | num_records=None, num_vars=None, verbose=False): 297 | if weights is not None: 298 | raise Exception('weighted Partial-correlation is not supported. 
Please avoid using weights.') 299 | super().__init__(dataset, threshold, database_type=float, weights=weights, retained_edges=retained_edges, 300 | count_tests=count_tests, use_cache=use_cache, num_records=num_records, num_vars=num_vars, 301 | verbose=verbose) 302 | 303 | self.correlation_matrix = None 304 | if self.data is not None: 305 | self.correlation_matrix = np.corrcoef(self.data, rowvar=False) # np.corrcoef(self.data.T) 306 | self.data = None # no need to store the data, as we have the correlation matrix 307 | 308 | def calc_statistic(self, x, y, zz): 309 | corr_coef = self.correlation_matrix # for readability 310 | if len(zz) == 0: 311 | if corr_coef[x, y] >= 1.0: 312 | return 0 313 | 314 | par_corr = corr_coef[x, y] 315 | elif len(zz) == 1: 316 | z = zz[0] 317 | 318 | if corr_coef[x, z] >= 1.0 or corr_coef[y, z] >= 1.0: 319 | return 0 320 | 321 | par_corr = ( 322 | (corr_coef[x, y] - corr_coef[x, z] * corr_coef[y, z]) / 323 | np.sqrt((1 - np.power(corr_coef[x, z], 2)) * (1 - np.power(corr_coef[y, z], 2))) 324 | ) 325 | else: # zz contains 2 or more variables 326 | all_var_idx = (x, y) + zz 327 | corr_coef_subset = corr_coef[np.ix_(all_var_idx, all_var_idx)] 328 | inv_corr_coef = -np.linalg.inv(corr_coef_subset) # consider using pinv instead of inv 329 | par_corr = inv_corr_coef[0, 1] / np.sqrt(abs(inv_corr_coef[0, 0] * inv_corr_coef[1, 1])) 330 | 331 | if abs(par_corr) >= 1.0: # if outside the range [-1,+1] assume the variables are dependent 332 | return 0 333 | 334 | degrees_of_freedom = self.num_records - (len(zz) + 2) # degrees of freedom to be used to calculate p-value 335 | 336 | # # Calculate based on the t-distribution 337 | # t_statistic = par_corr * np.sqrt(degrees_of_freedom / (1.-par_corr*par_corr)) # approximately t-distributed 338 | # statistic = 2 * stats.t.sf(abs(t_statistic), degrees_of_freedom) # p-value 339 | 340 | # Estimation based on Fisher z-transform 341 | z = 0.5 * np.log1p(2 * par_corr / (1 - par_corr)) # Fisher Z-transform, 0.5*log( (1+par_corr)/(1-par_corr) ) 342 | val_for_cdf = abs(np.sqrt(degrees_of_freedom - 1) * z) # approximately normally distributed 343 | statistic = 2 * (1 - stats.norm.cdf(val_for_cdf)) # p-value 344 | 345 | return statistic 346 | 347 | 348 | class CondIndepCMI(StatCondIndep): 349 | def __init__(self, dataset, threshold, weights=None, retained_edges=None, count_tests=False, use_cache=False): 350 | self.weight_data_type = float 351 | if weights is not None: 352 | weights = np.array(weights, dtype=self.weight_data_type) 353 | # if np.min(weights) < 0: 354 | # raise Exception('Negative sample weights are not allowed') 355 | # if np.abs(np.sum(weights) - 1.0) > np.finfo(self.weight_data_type).eps: 356 | # raise Exception('Sample weights do not sum to 1.0') 357 | # weights *= dataset.shape[0] 358 | super().__init__(dataset, threshold, database_type=int, weights=weights, retained_edges=retained_edges, 359 | count_tests=count_tests, use_cache=use_cache) 360 | 361 | def cond_indep(self, x, y, zz): 362 | res = super().cond_indep(x, y, zz) 363 | return not res # invert the decision because the statistic is correlation level and not p-value 364 | 365 | def calc_statistic(self, x, y, zz): 366 | """ 367 | Calculate conditional mutual information for discrete variables 368 | :param x: 1st variable (index) 369 | :param y: 2nd variable (index) 370 | :param zz: condition set, a tuple. 
e.g., if zz contains a single value zz = (val,) 371 | :return: Empirical conditional mutual information 372 | """ 373 | all_var_idx = (x, y) + zz 374 | dd = self.data[:, all_var_idx] 375 | var_size = [self.node_size[node_i] for node_i in all_var_idx] 376 | 377 | hist_count = calc_stats(data=dd, var_size=var_size, weights=self.weights) 378 | if hist_count is None: # memory error 379 | return 0 380 | hist_count = np.reshape(hist_count, [var_size[0], var_size[1], -1], 381 | order='F') # 3rd axis is the states of condition set 382 | cmi = self._calc_cmi_from_counts(hist_count) 383 | # 384 | # xsize, ysize, csize = hist_count.shape 385 | # 386 | # # Calculate conditional mutual information 387 | # cmi = 0 388 | # for zi in range(csize): 389 | # cnt = hist_count[:, :, zi] 390 | # cnum = cnt.sum() 391 | # for node_i in range(self.node_size[x]): 392 | # for node_j in range(self.node_size[y]): 393 | # if cnt[node_i, node_j] > 0: 394 | # cnt_val = cnt[node_i, node_j] 395 | # cx = cnt[:, node_j].sum() # sum over y for specific x-state 396 | # cy = cnt[node_i, :].sum() # sum over x for specific y-state 397 | # 398 | # lg = math.log(cnt_val*cnum / (cx * cy)) 399 | # cmi_ = lg*cnt_val/self.num_records 400 | # cmi += cmi_ 401 | return cmi 402 | 403 | def _calc_cmi_from_counts(self, hist_count): 404 | xsize, ysize, csize = hist_count.shape 405 | 406 | # Calculate conditional mutual information 407 | cmi = 0 408 | for zi in range(csize): 409 | cnt = hist_count[:, :, zi] 410 | cnum = cnt.sum() 411 | for node_i in range(xsize): 412 | for node_j in range(ysize): 413 | if cnt[node_i, node_j] > 0: 414 | cnt_val = cnt[node_i, node_j] 415 | cx = cnt[:, node_j].sum() # sum over y for specific x-state 416 | cy = cnt[node_i, :].sum() # sum over x for specific y-state 417 | 418 | lg = math.log(cnt_val*cnum / (cx * cy)) 419 | cmi_ = lg*cnt_val/self.num_records 420 | cmi += cmi_ 421 | return cmi 422 | -------------------------------------------------------------------------------- /causal_discovery_utils/constraint_based.py: -------------------------------------------------------------------------------- 1 | from itertools import combinations 2 | 3 | 4 | class LearnStructBase: 5 | """ 6 | Base class for constraint-based structure learning algorithms 7 | """ 8 | def __init__(self, graph_class, nodes_set, ci_test): 9 | if not isinstance(nodes_set, set): 10 | raise ValueError('nodes_set should be a set type') 11 | self.ci_test = ci_test 12 | self.sepset = SeparationSet(nodes_set) 13 | self.graph = graph_class(nodes_set) # e.g., graph_class=PDAG (for CPDAG under causal sufficiency) or PAG 14 | self.graph.sepset = self.sepset # link the algorithm's updated sepset to the graph (by reference) 15 | 16 | def get_graph(self): 17 | return self.graph.get_graph() 18 | 19 | 20 | class SeparationSet: 21 | def __init__(self, nodes_set): 22 | self._sepset = dict() 23 | self.nodes_set = nodes_set 24 | for (i, j) in combinations(nodes_set, 2): 25 | hkey = self.get_hash_key(i, j) 26 | self._sepset[hkey] = set() 27 | 28 | @staticmethod 29 | def get_hash_key(node_1, node_2): 30 | """ 31 | Get the hash key used to store separation sets 32 | :param node_1: 33 | :param node_2: 34 | :return: Hash key for sepset dictionary 35 | """ 36 | hkey = (node_1, node_2) if node_1 < node_2 else (node_2, node_1) 37 | return hkey 38 | 39 | def erase(self): 40 | for key in self._sepset: 41 | self._sepset[key] = set() 42 | 43 | def set_sepset(self, node_1, node_2, sepset): 44 | hkey = self.get_hash_key(node_1, node_2) 45 | self._sepset[hkey] = set(sepset) 46 | 
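# Usage sketch (illustrative; hypothetical node indices): after finding that
# 0 and 1 are independent given {2}, an algorithm records it via
# set_sepset(0, 1, {2}) and retrieves it later with get_sepset(0, 1);
# get_hash_key orders the pair, so the lookup works for either argument order.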
47 | def get_sepset(self, node_1, node_2): 48 | hkey = self.get_hash_key(node_1, node_2) 49 | return self._sepset[hkey] 50 | 51 | def copy(self, nodes=None, target_sepset=None): 52 | if nodes is None: 53 | nodes = self.nodes_set 54 | 55 | if target_sepset is None: 56 | target_sepset = SeparationSet(nodes) 57 | 58 | for (i, j) in combinations(nodes, 2): 59 | hkey_source = self.get_hash_key(i, j) 60 | hkey_target = target_sepset.get_hash_key(i, j) 61 | target_sepset._sepset[hkey_target] = self._sepset[hkey_source].copy() # create a copy of the separation-set 62 | 63 | return target_sepset 64 | 65 | def copy_from(self, source_sepset, nodes): 66 | """ 67 | Selectively copy values from another SeparationSet instance 68 | :param source_sepset: source SeparationSet instance 69 | :param nodes: Nodes of interest (separation sets for pairs of these nodes will be copied) 70 | :return: 71 | """ 72 | for (i, j) in combinations(nodes, 2): 73 | hkey_target = self.get_hash_key(i, j) 74 | self._sepset[hkey_target] = source_sepset.get_sepset(i, j).copy() # copy separation-sets from external 75 | 76 | 77 | def unique_element_iterator(chained_iterators): 78 | """ 79 | Return the unique elements of the chained iterators 80 | :param chained_iterators: an iterator with possibly repeating elements, e.g., chain(iter_a, iter_b) 81 | :return: an iterator with unique (unordered) elements 82 | """ 83 | seen = set() 84 | for e in chained_iterators: 85 | if e in seen: 86 | continue 87 | 88 | seen.add(e) 89 | yield e 90 | -------------------------------------------------------------------------------- /causal_discovery_utils/data_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def get_var_size(data): 5 | num_records, num_vars = data.shape 6 | node_size = np.zeros((num_vars,), dtype='int') 7 | for var in range(num_vars): 8 | node_size[var] = data[:, var].max() + 1 9 | # node_size = data.max(axis=0)+1 # number of node states (smallest state 0) 10 | return node_size 11 | 12 | 13 | def calc_stats(data, var_size, weights=None): 14 | """ 15 | Calculate the counts of instances in the data 16 | :param data: a dataset of categorical features 17 | :param var_size: a vector defining the cardinalities of the features 18 | :param weights: a vector of non-negative weights for data samples. 19 | :return: a flat array of counts, one entry per joint state of the features (None on memory error) 20 | """ 21 | sz_cum_prod = [1] 22 | for node_i in range(len(var_size) - 1): 23 | sz_cum_prod += [sz_cum_prod[-1] * var_size[node_i]] 24 | 25 | sz_prod = sz_cum_prod[-1] * var_size[-1] 26 | 27 | data_idx = np.dot(data, sz_cum_prod) 28 | try: 29 | # hist_count, _ = np.histogram(data_idx, np.arange(sz_prod + 1), weights=weights) 30 | hist_count = np.bincount(data_idx, minlength=sz_prod, weights=weights) 31 | except MemoryError: 32 | print('Out of memory') 33 | return None 34 | return hist_count 35 | 36 | 37 | def unroll_temporal_data(data_full, observed_nodes_list, window_len, t_step=1, reverse_temporal_order=False): 38 | """ 39 | Unroll temporally sorted data samples into the defined time window. For example, if the time window is 2, then 40 | each two consecutive data samples are concatenated into a single sample. The new samples are sorted temporally. 41 | :param data_full: Temporally sorted data samples. Each sample consists of jointly measured values. 42 | :param observed_nodes_list: indexes of columns in the original data that correspond to 'observed' variables.
43 | :param window_len: Length (number of time-stamps) of the unrolling window (current-time + window_len-1 past-steps). 44 | :param t_step: Number of time-step skips between unrolled samples (default is 1). 45 | :param reverse_temporal_order: if True, then indexes [0..n] represent the latest instances and 46 | higher indexes represent past instances. If False, [0..n] represent indexes of the earliest instances and 47 | higher indexes represent future instances. 48 | :return: Unrolled, temporally sorted data samples, together with index lists of all, and of observed, unrolled variables. 49 | """ 50 | n_samples = data_full.shape[0] # number of data samples 51 | n_contemporaneous_nodes = data_full.shape[1] # number of contemporaneous (jointly measured) variables 52 | num_nodes_unrolled = window_len * n_contemporaneous_nodes # number of variables in a single unrolled sample 53 | 54 | # calculate the starting time index for each unrolled sample 55 | starts = [xid * n_contemporaneous_nodes for xid in np.arange(0, n_samples + 1 - window_len, t_step)] 56 | 57 | # create unrolled data 58 | data_full_unrolled = np.zeros((len(starts), num_nodes_unrolled)) # initialize unrolled data 59 | for i, st in enumerate(starts): 60 | data_full_unrolled[i] = data_full.flat[st:st + num_nodes_unrolled] 61 | 62 | # indexes of variables in the unrolled data 63 | _nodes_sets_list_full = np.reshape(range(num_nodes_unrolled), (window_len, n_contemporaneous_nodes)) 64 | _nodes_sets_list = [_nodes_sets_list_full[i, observed_nodes_list].tolist() for i in range(window_len)] # observed 65 | 66 | if reverse_temporal_order: 67 | # reverse time order such that indexes [0..n_contemporaneous_nodes-1] represent the latest instances, 68 | # whereas larger indexes represent earlier instances. 69 | data_full_unrolled_reversed = np.zeros_like(data_full_unrolled) 70 | for idxes, idxes_rev in zip(_nodes_sets_list, reversed(_nodes_sets_list)): 71 | data_full_unrolled_reversed[:, idxes] = data_full_unrolled[:, idxes_rev] 72 | data_full_unrolled = data_full_unrolled_reversed 73 | else: 74 | # increasing time order such that indexes [0..n_contemporaneous_nodes-1] represent the earliest instances, 75 | # whereas larger indexes represent future instances.
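# Illustrative sketch (added; the values are hypothetical): with two contemporaneous variables [a, b],
# window_len=2 and t_step=1, the samples [[a0, b0], [a1, b1], [a2, b2]] unroll to
# [[a0, b0, a1, b1], [a1, b1, a2, b2]]; with reverse_temporal_order=True the per-time-stamp
# column blocks are swapped, e.g., the first unrolled sample becomes [a1, b1, a0, b0].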
76 | pass 77 | 78 | return data_full_unrolled, _nodes_sets_list_full, _nodes_sets_list 79 | -------------------------------------------------------------------------------- /causal_discovery_utils/performance_measures.py: -------------------------------------------------------------------------------- 1 | from itertools import combinations 2 | import numpy as np 3 | from scipy.special import gammaln 4 | from graphical_models import PDAG, DAG 5 | from causal_discovery_utils.data_utils import calc_stats 6 | from causal_discovery_utils.cond_indep_tests import DSep # perfect CI oracle, used here to find the true PAG 7 | from causal_discovery_algs.icd import LearnStructICD # used here to find the true PAG 8 | 9 | 10 | def find_true_pag(true_dag, true_observed_set): 11 | perfect_ci_test = DSep(true_dag=true_dag) 12 | icd_true = LearnStructICD(nodes_set=true_observed_set, ci_test=perfect_ci_test) 13 | icd_true.learn_structure() # find the PAG using the ICD algorithm 14 | return icd_true.graph 15 | 16 | 17 | def calc_skeleton_accuracy(graph_tested, graph_correct): 18 | num_true_positive = 0 19 | num_false_negative = 0 # missing edges 20 | num_false_positive = 0 # extra edges 21 | num_true_negative = 0 22 | num_edges_true = 0 # count the number of edges in the true graph (== false_negative + true_positive) 23 | 24 | for (node_i, node_j) in combinations(graph_correct._graph, 2): 25 | 26 | # calculate edge errors 27 | if graph_correct.is_connected(node_i, node_j): 28 | num_edges_true += 1 # count the number of edges in the true PAG 29 | if graph_tested.is_connected(node_i, node_j): 30 | num_true_positive += 1 31 | else: 32 | num_false_negative += 1 33 | 34 | else: # there is no edge in the true graph 35 | if graph_tested.is_connected(node_i, node_j): 36 | num_false_positive += 1 37 | else: 38 | num_true_negative += 1 39 | 40 | edge_precision = num_true_positive / (num_false_positive + num_true_positive) 41 | edge_recall = num_true_positive / num_edges_true 42 | edge_f1 = 2 * edge_precision * edge_recall / (edge_precision + edge_recall) # 2 / (1/precision + 1/recall) 43 | FPR = num_false_positive / (num_false_positive+num_true_negative) # false positive rate (FPR) 44 | FNR = num_false_negative / (num_false_negative+num_true_positive) # false negative rate (FNR) 45 | skeleton_accuracy = { 46 | 'edge_precision': edge_precision, 47 | 'edge_recall': edge_recall, 48 | 'edge_F1': edge_f1, 49 | 'FPR': FPR, 50 | 'FNR': FNR 51 | } 52 | return skeleton_accuracy 53 | 54 | 55 | def structural_hamming_distance_cpdag(tested_graph: PDAG, true_graph: PDAG): 56 | """ 57 | Measure the structural Hamming distance (SHD) between two CPDAGs.
The following are calculated: 58 | * Edges (directed or undirected) 59 | ** Missing: an edge missing from the tested graph but existing in the true graph 60 | ** Extra: an edge in the tested graph but missing from the true graph 61 | * Arrowhead (for edges existing in both graphs) 62 | ** Missing: undirected in tested graph but directed in true graph 63 | ** Extra: directed in tested graph but undirected in true graph 64 | ** Reversed: directed in opposite direction 65 | 66 | The total SHD is the sum of all values in the returned dictionary 67 | 68 | :param tested_graph: the learned CPDAG to be evaluated 69 | :param true_graph: the ground-truth CPDAG 70 | :return: A nested dictionary with the total SHD and its breakdown per error type 71 | """ 72 | if (not isinstance(true_graph, PDAG)) or (not isinstance(tested_graph, PDAG)): 73 | raise ValueError 74 | 75 | shd_edge = {'missing': 0, 'extra': 0} 76 | shd_arrowhead = {'missing': 0, 'extra': 0, 'reversed': 0} 77 | for (node_i, node_j) in combinations(true_graph._graph, 2): 78 | if not tested_graph.is_connected(node_i, node_j): # if edge is missing from the tested graph 79 | if true_graph.is_connected(node_i, node_j): # if edge exists in true graph 80 | shd_edge['missing'] += 1 81 | elif not true_graph.is_connected(node_i, node_j): # edge exists in tested graph; check if missing from true 82 | shd_edge['extra'] += 1 83 | else: # edge exists in both true and tested graphs 84 | # now check direction error 85 | if node_i in true_graph.undirected_neighbors(node_j): # if the edge is undirected in the true graph 86 | if (node_i in tested_graph.parents(node_j)) or (node_j in tested_graph.parents(node_i)): 87 | shd_arrowhead['extra'] += 1 # the edge in the tested graph is directed 88 | elif node_i in tested_graph.undirected_neighbors(node_j): # directed in true; check if undirected in tested 89 | shd_arrowhead['missing'] += 1 90 | else: # both edges are directed 91 | (source, target) = (node_i, node_j) if node_i in true_graph.parents(node_j) else (node_j, node_i) 92 | if target in tested_graph.parents(source): 93 | shd_arrowhead['reversed'] += 1 # the edges are not directed in the same direction (i --> j) 94 | 95 | shd_total = sum(shd_edge.values()) + sum(shd_arrowhead.values()) 96 | 97 | return {'total': shd_total, 'edge': shd_edge, 'arrowhead': shd_arrowhead} 98 | 99 | 100 | def score_bdeu(dag: DAG, data, node_size, en_nodes=None): 101 | """ 102 | Calculate the BDeu score of a DAG 103 | :param dag: DAG to be scored 104 | :param data: dataset of discrete random variables from which to calculate the score 105 | :param node_size: sizes of the nodes: number of possible values for each variable in the dataset 106 | :param en_nodes: the score will be calculated for the sub-graph induced by these nodes 107 | :return: BDeu score 108 | """ 109 | if dag is None: 110 | return -float('inf') 111 | 112 | assert isinstance(dag, DAG) # graph must be a DAG 113 | 114 | if en_nodes is None: 115 | en_nodes = dag.nodes_set 116 | 117 | score = 0 118 | for node in en_nodes: 119 | parents = dag.parents(node) 120 | family = tuple(parents) + (node, ) # a tuple of family nodes where the child ("node") is last 121 | family_sizes = [node_size[node_i] for node_i in family] 122 | family_data = data[:, family] 123 | counts = calc_stats(family_data, family_sizes) 124 | if counts is None: # memory error 125 | return -float('inf') 126 | counts = np.reshape(counts, [-1, family_sizes[-1]], order='F') # 2nd axis is the states of the child 127 | 128 | prior = np.ones_like(counts) 129 | prior = prior/prior.sum() 130 | 131 | score += score_family_dirichlet(counts=counts, prior=prior) 132 |
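    # Note (added for clarity): for each family, score_family_dirichlet computes the log marginal
    # likelihood of a Dirichlet-multinomial model,
    #   sum_j [ lnGamma(a_j) - lnGamma(a_j + N_j) + sum_k ( lnGamma(a_jk + N_jk) - lnGamma(a_jk) ) ],
    # where j indexes parent configurations, k indexes child states, a_jk are the prior pseudo-counts
    # (here uniform and summing to 1), a_j = sum_k a_jk, and N_jk, N_j are the corresponding observed counts.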
133 | return score 134 | 135 | 136 | def score_family_dirichlet(counts, prior): 137 | """ 138 | Score a family: a node and its parents 139 | 140 | :param counts: a matrix of counts, where the 2nd dimension belongs to the child in the family 141 | :param prior: prior pseudo-counts, a matrix with the same shape as counts 142 | :return: score of the family 143 | """ 144 | lu = (gammaln(prior + counts) - gammaln(prior)).sum(axis=1) 145 | alpha_ij = prior.sum(axis=1) 146 | n_ij = counts.sum(axis=1) 147 | lv = gammaln(alpha_ij) - gammaln(alpha_ij + n_ij) 148 | 149 | family_log_likelihood = (lu + lv).sum() 150 | return family_log_likelihood 151 | 152 | 153 | def calc_structural_accuracy_pag(pag_tested, pag_correct): 154 | """ 155 | Calculate structural accuracy: 156 | - Edge accuracy: precision and recall 157 | - Orientation accuracy: number of correctly identified edge-marks (circle: 'o--', head: '<--', tail: '---') 158 | :param pag_tested: the learned PAG to be evaluated 159 | :param pag_correct: the true PAG 160 | :return: a dictionary of the form: 161 | { 'FPR': value, 'FNR': value, 'edge_precision': value, 'edge_recall': value, 'edge_F1': value, 'orientation_correctness': value } 162 | """ 163 | 164 | num_orient_correct = 0 165 | num_orient_total = 0 # number of edge-marks in the true PAG 166 | 167 | for (node_i, node_j) in combinations(pag_correct._graph, 2): 168 | if pag_correct.is_connected(node_i, node_j): 169 | num_orient_total += 2 170 | for edge_mark in pag_correct.edge_mark_types: # check which edge mark is present 171 | if node_i in pag_correct._graph[node_j][edge_mark] and node_i in pag_tested._graph[node_j][edge_mark]: 172 | num_orient_correct += 1 173 | if node_j in pag_correct._graph[node_i][edge_mark] and node_j in pag_tested._graph[node_i][edge_mark]: 174 | num_orient_correct += 1 175 | 176 | edge_accuracy = calc_skeleton_accuracy(pag_tested, pag_correct) 177 | edge_precision = edge_accuracy['edge_precision'] 178 | edge_recall = edge_accuracy['edge_recall'] 179 | edge_f1 = edge_accuracy['edge_F1'] 180 | causal_accuracy = num_orient_correct/num_orient_total # percentage of correct orientations 181 | 182 | result = { 183 | 'FPR': edge_accuracy['FPR'], 184 | 'FNR': edge_accuracy['FNR'], 185 | 'edge_precision': edge_precision, 186 | 'edge_recall': edge_recall, 187 | 'edge_F1': edge_f1, 188 | 'orientation_correctness': causal_accuracy 189 | } 190 | return result 191 | 192 | 193 | def calc_skeleton_fnr_fpr(graph_tested, graph_correct): 194 | # a thin wrapper around calc_skeleton_accuracy that returns only the FPR and FNR fields 195 | skeleton_acc = calc_skeleton_accuracy(graph_tested=graph_tested, graph_correct=graph_correct) 196 | res = {'FPR': skeleton_acc['FPR'], 'FNR': skeleton_acc['FNR']} 197 | return res 198 | -------------------------------------------------------------------------------- /causal_discovery_utils/stat_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def cov_to_corr(cov_matrix: np.ndarray): 5 | """ 6 | Convert a square covariance matrix (COV) into a correlation matrix. An element (i, j) of the correlation matrix is 7 | equal to COV(i,j) / sqrt(COV(i,i)COV(j,j)). 8 | 9 | :param cov_matrix: input covariance matrix in numpy.ndarray format 10 | :return: correlation matrix.
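    Example (illustrative):
        >>> cov_to_corr(np.array([[4.0, 2.0], [2.0, 9.0]]))
        array([[1.        , 0.33333333],
               [0.33333333, 1.        ]])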
11 | """ 12 | assert cov_matrix.ndim == 2 # matrix 13 | assert cov_matrix.shape[0] == cov_matrix.shape[1] # square 14 | min_variance = 1e-8 # minimal allowed standard deviation 15 | diag = np.sqrt(np.diag(cov_matrix)) 16 | assert np.all(diag > min_variance) # fail if any variance is at or below the supported minimum 17 | inv_std = 1.0 / diag 18 | 19 | # Calculate the correlation matrix 20 | correlation_matrix = cov_matrix * inv_std * inv_std[:, np.newaxis] # Cor = (D^-1) @ Cov @ (D^-1), D=sqrt(diag(Cov)) 21 | 22 | return correlation_matrix 23 | -------------------------------------------------------------------------------- /causal_reasoning/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .cleann_explainer import CLEANN -------------------------------------------------------------------------------- /causal_reasoning/cleann_explainer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from causal_discovery_algs import LearnStructICD 3 | from causal_discovery_algs.icd import create_pds_tree 4 | from causal_discovery_utils.cond_indep_tests import CondIndepParCorr 5 | from causal_discovery_utils.stat_utils import cov_to_corr 6 | 7 | 8 | class CLEANN: 9 | def __init__(self, attention_matrix: np.ndarray, num_samples, p_val_th: float, explanation_tester, 10 | nodes_set=None, search_minimal=True, structure_learning_class=LearnStructICD): 11 | """ 12 | Initialize a CLEANN explainer [https://arxiv.org/abs/2310.20307]. 13 | :param attention_matrix: a square self-attention matrix from which to infer inter-token relations. 14 | :param num_samples: number of samples from which the attention was calculated (e.g., config.hidden_size). 15 | :param p_val_th: p-value threshold for the partial-correlation-based independence test. 16 | :param explanation_tester: an externally-defined function that tests if a subset of tokens is an explanation. 17 | It takes as input: a list of all tokens, a list of indexes of the token subset considered as the explanation, 18 | and the index of the target token for which we are seeking an explanation. 19 | It outputs True if the explanation is confirmed, otherwise False. 20 | If None is set, then the largest potential explanation will be returned as an explanation by the explain method. 21 | :param nodes_set: a list of tokens, where each token will be a node in the learned graph. 22 | :param search_minimal: If True, only the explanations with the minimal size will be returned. If False, all 23 | the explanations found from the graph will be returned. 24 | :param structure_learning_class: structure learning class to instantiate. Default: LearnStructICD.
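    Example (illustrative sketch; the attention matrix, sample count and threshold below are hypothetical):
        >>> attn = np.random.rand(10, 10)  # a square self-attention matrix over 10 tokens
        >>> explainer = CLEANN(attn, num_samples=768, p_val_th=0.05, explanation_tester=None)
        >>> explanations = explainer.explain(target_node_idx=9)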
25 | """ 26 | 27 | # calculate correlation matrix from attention matrix 28 | cov_matrix = np.matmul(attention_matrix, attention_matrix.transpose()) # COV = A @ transpose(A) 29 | corr_mat = cov_to_corr(cov_matrix) 30 | 31 | # prepare for learning a graph 32 | num_vars, _ = corr_mat.shape # number of graph-nodes 33 | if nodes_set is None: 34 | nodes_set = set(range(num_vars)) 35 | self.nodes_set = nodes_set 36 | self.ci_test = CondIndepParCorr( 37 | threshold=p_val_th, dataset=None, num_records=num_samples, num_vars=num_vars, 38 | count_tests=True, use_cache=True) 39 | self.ci_test.correlation_matrix = corr_mat 40 | self.StructureLearning = structure_learning_class 41 | self.graph = None 42 | 43 | # initialize for evaluating explanations 44 | self.results = dict() # explanations found by the 'explain' method will be stored in this dictionary 45 | self.is_explanation = explanation_tester 46 | self._search_minimal = search_minimal 47 | 48 | def learn_graph(self): 49 | icd_alg = self.StructureLearning(nodes_set=self.nodes_set, ci_test=self.ci_test) # init structure learner 50 | icd_alg.learn_structure() 51 | return icd_alg.graph 52 | 53 | def explain(self, target_node_idx, max_set_size=None, max_range=None): 54 | """ 55 | Identify an explanation for the given target node. The result is stored in self.results[target_node_idx]['explanations']. 56 | :param target_node_idx: index of the node to be explained. 57 | :param max_set_size: setting a value limits the search to look for explanations 58 | having at most max_set_size nodes. 59 | :param max_range: setting a value limits the search to look for explanations 60 | such that the distance on the graph between the explanation nodes and the target node is at most max_range, 61 | i.e., only nodes within max_range graph-distance from the target are considered for the explanation. 62 | :return: a list of minimal explanations (all explanations have the same size).
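    Example of a returned value (illustrative; the numbers are hypothetical): [[{2, 5}, 3]] means that
    the token set {2, 5} explains the target token, with a summed graph-distance of 3 between the
    explanation tokens and the target.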
63 | """ 64 | 65 | # learn a graph if one hasn't been learned already 66 | if self.graph is None: 67 | self.graph = self.learn_graph() 68 | 69 | # create a PDS-tree rooted at the target node 70 | pds_tree, full_explain_set = create_pds_tree(self.graph, target_node_idx, max_depth=max_range) 71 | max_pds_tree_depth = pds_tree.get_max_depth() 72 | results = dict() 73 | results['pds_tree'] = pds_tree 74 | results['full_explanation_set'] = full_explain_set 75 | results['max_pds_tree_depth'] = max_pds_tree_depth 76 | 77 | if max_set_size is None: 78 | max_size = len(full_explain_set) 79 | else: 80 | max_size = max_set_size 81 | 82 | explanations_list = [] 83 | if self.is_explanation is None: 84 | if len(full_explain_set) <= max_size: 85 | explanations_list.append([full_explain_set, max_size]) 86 | else: 87 | found_explanation = False 88 | for set_size in range(1, max_size+1): 89 | sets_list = pds_tree.get_subsets_list(set_nodes=full_explain_set, subset_size=set_size) 90 | sets_list.sort(key=lambda x: x[1]) # sort with respect to the sum of minimal distances 91 | for possible_explanation_set in sets_list: 92 | if self.is_explanation(list(possible_explanation_set[0]), target_node_idx): 93 | explanations_list.append(possible_explanation_set) 94 | found_explanation = True 95 | if found_explanation and self._search_minimal: 96 | break 97 | 98 | results['explanations'] = explanations_list 99 | self.results[target_node_idx] = results 100 | return explanations_list 101 | -------------------------------------------------------------------------------- /example_data/Alarm1_data/Alarm1_graph.txt: -------------------------------------------------------------------------------- 1 | 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 2 | 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 3 | 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 4 | 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 5 | 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 6 | 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 7 | 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 8 | 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 9 | 0 0 0 0 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 10 | 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 11 | 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 12 | 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 13 | 0 0 0 0 1 0 1 1 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 14 | 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 15 | 0 0 0 1 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 16 | 0 0 0 0 0 0 0 0 0 1 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 17 | 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 18 | 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 19 | 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 20 | 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 21 | 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 22 | 1 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 23 | 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 24 | 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0
0 0 0 0 0 0 0 0 0 0 0 25 | 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 26 | 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 27 | 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 28 | 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 29 | 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 30 | 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 31 | 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 32 | 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 33 | 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 34 | 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 35 | 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 0 0 0 0 36 | 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 0 0 37 | 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 0 0 38 | -------------------------------------------------------------------------------- /experiment_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/causality-lab/8ad638a2057e3bdf35108b6e63f547dd7f6a95a1/experiment_utils/__init__.py -------------------------------------------------------------------------------- /experiment_utils/explanation.py: -------------------------------------------------------------------------------- 1 | from itertools import combinations 2 | 3 | 4 | def exhaustive_search_explanation(relevant_tokens_pos, target_token, is_explanation, search_minimal=True): 5 | """ 6 | Exhaustive search for the minimal subset of tokens that complies with the explanation-definition 7 | for the target node. 8 | 9 | :param relevant_tokens_pos: indexes of tokens from which to search for a minimal subset of explanation. 10 | :param target_token: the index of the target token needed to be explained. 11 | :param is_explanation: a function that tests if a subset is an explanation for a target token. It contains the 12 | definition of explanation for the target token. 
It takes as input: 13 | tokens_list: a list of tokens 14 | explanation_token_pos: indexes of those tokens to be considered as explanation 15 | target_pos: index of the target node to be explained 16 | :param search_minimal: 17 | :return: 18 | """ 19 | found_flag = False 20 | minimal_explanations = [] 21 | n_nodes = len(relevant_tokens_pos) 22 | for set_size in range(1, n_nodes): 23 | for explanation_subset in combinations(relevant_tokens_pos, set_size): 24 | if target_token in explanation_subset: 25 | continue 26 | if is_explanation(list(explanation_subset), target_token): 27 | minimal_explanations.append(explanation_subset) 28 | found_flag = True 29 | if search_minimal and found_flag: 30 | break # do not search for larger explaining sets 31 | return minimal_explanations -------------------------------------------------------------------------------- /experiment_utils/synthetic_graphs.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | from graphical_models import DAG 4 | 5 | 6 | def create_random_dag(num_nodes, expected_neighborhood_size): 7 | nodes_set = set(range(num_nodes)) 8 | dag = DAG(nodes_set) 9 | 10 | # The probability of the presence of a directed edge is Bernoulli(expected_neighborhood_size/(num_nodes-1)) 11 | for node_parent in range(num_nodes-1): 12 | for node_child in range(node_parent+1, num_nodes): 13 | is_edge = random.random() < (expected_neighborhood_size / (num_nodes-1)) 14 | if is_edge: 15 | dag.add_edges({node_parent}, node_child) 16 | 17 | return dag 18 | 19 | 20 | def create_random_dag_max_fan(num_nodes, max_fan_in): 21 | """ 22 | Create a random DAG with bounded fan-in (number of parents per node) 23 | :param num_nodes: Number of nodes in the graph 24 | :param max_fan_in: Maximal number of parents per child 25 | :return: A random DAG with bounded fan-in 26 | """ 27 | nodes_set = set(range(num_nodes)) 28 | dag = DAG(nodes_set) # create an empty graph 29 | 30 | for current_node_id in range(1, num_nodes): 31 | num_parents = random.randint(0, min(current_node_id, max_fan_in)) # sample number of parents 32 | if num_parents > 0: 33 | parents = set(random.sample(range(current_node_id), num_parents)) # sample parents 34 | dag.add_edges(parents_set=parents, target_node=current_node_id) 35 | 36 | return dag 37 | 38 | 39 | def create_random_connected_dag(num_nodes, expected_neighborhood_size, num_dags_timeout=10000): 40 | is_conn = False 41 | acc = 0 42 | while not is_conn: 43 | if num_dags_timeout < acc: 44 | return None # timed out: did not find a connected DAG 45 | 46 | dag_rand = create_random_dag(num_nodes, expected_neighborhood_size) 47 | is_conn = dag_rand.is_graph_connected() 48 | acc += 1 49 | 50 | return dag_rand 51 | 52 | 53 | def select_latent_variables(graph: DAG): 54 | """ 55 | Find nodes in a DAG that can serve as latent confounders. They comply with: 56 | 1. They don't have incoming edges (parentless) 57 | 2. 
Each of them is a parent of at least two nodes 58 | :param graph: A DAG for which to find the possible latents 59 | :return: A set of variables that can serve as latent confounders 60 | """ 61 | # find parentless nodes 62 | parentless_set = set() 63 | for node in graph.nodes_set: 64 | if len(graph.parents(node)) == 0: 65 | parentless_set.add(node) 66 | 67 | # find parentless that have at least 2 children 68 | parents_set = set() 69 | for parent in parentless_set: 70 | if len(graph.find_children(parent)) >= 2: 71 | parents_set.add(parent) 72 | 73 | return parents_set 74 | 75 | 76 | def create_random_dag_with_latents(n_nodes, conn_coeff): 77 | # sample a connected DAG 78 | dag_samp = create_random_connected_dag(n_nodes, conn_coeff, num_dags_timeout=1000000) 79 | 80 | # find nodes that can serve as latents (parentless, and parents of at least two observed nodes) 81 | potential_latents = select_latent_variables(dag_samp) 82 | 83 | # sample 50% of the potential latents 84 | lat_set = set( 85 | random.sample(list(potential_latents), len(potential_latents) // 2) # a list, as random.sample requires a sequence 86 | ) 87 | obs_set = dag_samp.nodes_set - lat_set 88 | return dag_samp, obs_set, lat_set 89 | 90 | 91 | def sample_data_from_dag(in_dag, num_samples, min_edge_weight, max_edge_weight): 92 | """ 93 | Sample data from a linear SEM. A linear SEM is created from a DAG. 94 | A node is the sum of a normally distributed noise term and the weighted sum of the values of its parents. 95 | :param in_dag: The DAG structure of the linear SCM 96 | :param num_samples: number of samples (dataset records) 97 | :param min_edge_weight: lowest absolute value of edge weight (linear coefficient) 98 | :param max_edge_weight: highest absolute value of the weight (linear coefficient) 99 | :return: Sampled dataset in the form of a 2D NumPy array 100 | """ 101 | # ToDo: create a dedicated module for probabilistic graphical models. These should take graph structure as input 102 | data = np.random.randn(num_samples, len(in_dag.nodes_set)) # sample noise: N(0,1) 103 | topological_order = in_dag.find_topological_order() 104 | for node in topological_order: 105 | parents_set = in_dag.parents(node) 106 | for node_parent in parents_set: 107 | weight_sign = 2 * random.randint(0, 1) - 1 # select positive or negative range for the edge weight 108 | weight = weight_sign * np.random.uniform(min_edge_weight, 109 | max_edge_weight) # considers negative weights as well 110 | data[:, node] += weight * data[:, node_parent] # add the linear effect of the parents 111 | 112 | return data 113 | -------------------------------------------------------------------------------- /experiment_utils/threshold_select_ci_test.py: -------------------------------------------------------------------------------- 1 | from causal_discovery_utils.performance_measures import score_bdeu 2 | from graphical_models import DAG 3 | 4 | def search_threshold_bdeu(alg_class, train_data, ci_test_class, th_range_list, use_cache=True): 5 | """ 6 | A grid-search for the threshold that maximizes the BDeu score when learning a DAG structure 7 | :param alg_class: Class of the algorithm to be used to learn the graph 8 | :param train_data: Data that is used for calculating the BDeu score 9 | :param ci_test_class: Class of CI test for which we are searching a threshold 10 | :param th_range_list: A list of candidate thresholds 11 | :param use_cache: CI test statistic is cached and re-used when evaluating different thresholds (default=True). 12 | Do not use caching for the B-RAI algorithm as it changes the data during its operation.
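    Illustrative call (added; the learner and CI-test class names below are placeholders, not a verified API):
        best_th, scores = search_threshold_bdeu(SomeLearnerClass, train_data, SomeCITestClass, [1e-4, 1e-3, 1e-2])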
13 | :return: The threshold that returned the structure having the highest score, and a list of all candidate scores 14 | """ 15 | _n_samples, _n_vars = train_data.shape 16 | _nodes = set(range(_n_vars)) 17 | best_th = float("inf") 18 | best_score = -float("inf") 19 | score_list = [] 20 | 21 | _ci_test = ci_test_class(dataset=train_data, threshold=None, use_cache=use_cache) # conditional independence oracle 22 | for i, _th in enumerate(th_range_list): 23 | _ci_test.threshold = _th 24 | _alg = alg_class(nodes_set=_nodes, ci_test=_ci_test) # algorithm instance 25 | _alg.learn_structure() # learn structure 26 | _dag = DAG(_alg.graph.nodes_set) 27 | is_dag = _alg.graph.convert_to_dag(_dag) 28 | if is_dag: 29 | _score = score_bdeu(_dag, train_data, _ci_test.node_size) 30 | else: 31 | _score = -float("inf") 32 | if _score >= best_score: # if several thresholds result in equal scores, get the highest threshold 33 | best_score = _score 34 | best_th = _th 35 | score_list.append(_score) 36 | 37 | return best_th, score_list 38 | -------------------------------------------------------------------------------- /graphical_models/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .undirected_graph import UndirectedGraph 3 | from .dag import DAG 4 | from .partially_dag import PDAG 5 | from .partial_ancestral_graph import PAG 6 | from .possible_dsep_tree import PDSTree -------------------------------------------------------------------------------- /graphical_models/arrow_head_types.py: -------------------------------------------------------------------------------- 1 | # Arrow head types used in MixedGraph and classes that inherit from it, e.g., partially directed graphs, CPDAG and PAG 2 | Undirected = '---' # X---Y (for CPDAG and PAG) 3 | Directed = '<--' # X-->Y (for CPDAG and PAG) 4 | Circle = 'o--' # Xo--*Y (for PAGs) 5 | Tail = Undirected 6 | 7 | # In PAGs there are 6 edge types: o--o, o---, o-->, --->, <-->, ----.
In MAGs: --->, <-->, ---- 8 | -------------------------------------------------------------------------------- /graphical_models/basic_equivalance_class_graph.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from itertools import combinations 3 | from .undirected_graph import UndirectedGraph 4 | 5 | 6 | class MixedGraph: 7 | """ 8 | A graph for representing equivalence classes such as CPDAG and PAG 9 | """ 10 | def __init__(self, nodes_set, edge_mark_types): 11 | assert isinstance(nodes_set, set) 12 | 13 | self.edge_mark_types = set(edge_mark_types) 14 | 15 | self._graph = dict() 16 | self.nodes_set = nodes_set 17 | self.create_empty_graph(self.nodes_set) 18 | 19 | # graph initialization functions ---------------------------------------------------------------------------------- 20 | def create_empty_graph(self, nodes_set=None): 21 | if nodes_set is None: 22 | nodes_set = self.nodes_set 23 | else: 24 | assert isinstance(nodes_set, set) 25 | 26 | for node in nodes_set: 27 | self._graph[node] = dict() 28 | for head_type in self.edge_mark_types: # loop over arrow head_types 29 | self._graph[node][head_type] = set() 30 | 31 | def create_complete_graph(self, edge_mark, nodes_set=None): 32 | if nodes_set is None: 33 | nodes_set = self.nodes_set 34 | else: 35 | assert isinstance(nodes_set, set) 36 | 37 | self.create_empty_graph(nodes_set) # first, clear all arrow-heads 38 | 39 | for node in nodes_set: 40 | self._graph[node][edge_mark] = nodes_set - {node} # connect all nodes into the current node 41 | 42 | # --- graph query functions --------------------------------------------------------------------------------------- 43 | def is_empty(self, nodes_set=None): 44 | """ 45 | Test if the graph is empty 46 | :return: True if the graph is empty; False if there exists at least one edge 47 | """ 48 | if nodes_set is None: 49 | nodes_set = self.nodes_set 50 | 51 | for node in nodes_set: 52 | for edge_mark in self.edge_mark_types: 53 | if self._graph[node][edge_mark]: 54 | return False # an edge was found, graph is not empty 55 | else: 56 | return True # completed looping over all the nodes and didn't find an edge 57 | 58 | def number_of_edges(self, return_missing=False): 59 | num_edges = 0 60 | missing_edges = 0 61 | for node_i, node_j in combinations(self.nodes_set, 2): 62 | if self.is_connected(node_i, node_j): 63 | num_edges += 1 64 | else: 65 | missing_edges += 1 66 | 67 | if return_missing: 68 | return num_edges, missing_edges 69 | else: 70 | return num_edges 71 | 72 | def is_any_edge_mark(self, node_source, node_target): 73 | """ 74 | Test if there is any edge-mark at "node_target" on the edge between node_source and node_target 75 | :param node_source: 76 | :param node_target: 77 | :return: True if there is some edge-mark, False otherwise (no edge-mark; not to be confused with undirected-mark) 78 | """ 79 | for edge_mark in self.edge_mark_types: # test all edge marks 80 | if node_source in self._graph[node_target][edge_mark]: 81 | return True 82 | else: 83 | return False 84 | 85 | def get_edge_mark(self, node_parent, node_child): 86 | for edge_mark in self.edge_mark_types: # test all edge marks 87 | if node_parent in self._graph[node_child][edge_mark]: 88 | return edge_mark 89 | else: 90 | return None 91 | 92 | def is_connected(self, node_i, node_j): 93 | """ 94 | Test if two nodes are adjacent in the graph. That is, if they are connected by any edge type.
95 | :param node_i: 96 | :param node_j: 97 | :return: True if the nodes are adjacent; otherwise, False 98 | """ 99 | assert node_i != node_j 100 | 101 | for (node_p, node_c) in [(node_i, node_j), (node_j, node_i)]: # switch roles "parent"-"child" 102 | for edge_mark in self.edge_mark_types: # test all edge marks 103 | if node_p in self._graph[node_c][edge_mark]: 104 | return True 105 | 106 | return False 107 | 108 | def is_edge(self, node_i, node_j, edge_mark_at_i, edge_mark_at_j): 109 | """ 110 | Test the existence of an edge with the given edge-marks. 111 | :param node_i: 112 | :param node_j: 113 | :param edge_mark_at_i: 114 | :param edge_mark_at_j: 115 | :return: True if the specific edge exists; otherwise, False. 116 | """ 117 | assert (edge_mark_at_i in self.edge_mark_types) and (edge_mark_at_j in self.edge_mark_types) 118 | 119 | if node_j in self._graph[node_i][edge_mark_at_i] and node_i in self._graph[node_j][edge_mark_at_j]: 120 | return True 121 | else: 122 | return False 123 | 124 | def is_graph_connected(self, nodes_set=None): 125 | # ToDo: Check correctness 126 | if nodes_set is None: 127 | nodes_set = self.nodes_set 128 | 129 | assert len(nodes_set) > 1 130 | 131 | nodes_to_reach = nodes_set.copy() # copy, since the input set is passed by reference 132 | starting_nodes = {nodes_to_reach.pop()} # start from an arbitrary node 133 | 134 | while len(starting_nodes) > 0: 135 | node_start = starting_nodes.pop() 136 | adjacent_nodes = self.find_adjacent_nodes(node_start, nodes_to_reach) 137 | nodes_to_reach = nodes_to_reach - adjacent_nodes 138 | if len(nodes_to_reach) == 0: 139 | return True # reached all the nodes in the graph 140 | starting_nodes.update(adjacent_nodes) 141 | 142 | return False 143 | 144 | def find_adjacent_nodes(self, node_i, pool_nodes=None, edge_type=None): 145 | """ 146 | Find all the nodes that are connected to node_i (by incoming or outgoing edges). 147 | :param node_i: 148 | :param pool_nodes: a set of nodes from which to find the adjacent ones (default: all graph nodes) 149 | :param edge_type: a tuple (alpha, beta) defining the allowed connecting edge, 150 | where alpha is the edge-mark at node_i and beta is the edge-mark at the neighbors. 151 | Default is None indicating that any edge-mark is allowed.
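    For example (added for clarity), edge_type=('<--', '---') selects neighbors n connected as n ---> node_i, i.e., with an arrowhead ('<--') at node_i and a tail ('---') at the neighbor.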
152 | :return: the set of adjacent nodes (intersected with pool_nodes when it is given) 153 | """ 154 | if edge_type is None: 155 | connected_nodes = set() 156 | for edge_mark in self.edge_mark_types: 157 | connected_nodes.update(self._graph[node_i][edge_mark]) 158 | else: 159 | mark_origin = edge_type[0] 160 | mark_neighbor = edge_type[1] 161 | connected_nodes = set(filter( 162 | lambda neighbor: node_i in self._graph[neighbor][mark_neighbor], 163 | self._graph[node_i][mark_origin] 164 | )) 165 | 166 | if pool_nodes is not None: 167 | connected_nodes = connected_nodes & pool_nodes 168 | return connected_nodes 169 | 170 | def find_reachable_set(self, anchor_node, nodes_pool, edge_type_list): 171 | """ 172 | Find the set of nodes that are reachable from a node via specific edge-types 173 | :param anchor_node: A node from which to start reaching 174 | :param nodes_pool: a set of nodes tested to be reachable 175 | :param edge_type_list: a list of edge types, e.g., [('<--', '---'), ('<--', '-->')] 176 | :return: a set of nodes that are reachable from the anchor node 177 | """ 178 | neighbors_set = set() 179 | 180 | if len(nodes_pool) == 0: 181 | return neighbors_set 182 | 183 | # find immediately reachable neighbors 184 | if edge_type_list is None: # any edge type 185 | neighbors_set = self.find_adjacent_nodes(anchor_node, nodes_pool, None) 186 | else: 187 | for edge_type in edge_type_list: # specific edge types 188 | neighbors_set.update(self.find_adjacent_nodes(anchor_node, nodes_pool, edge_type)) 189 | 190 | if len(neighbors_set) == 0: 191 | return neighbors_set 192 | 193 | reachable_set = neighbors_set.copy() 194 | updated_nodes_pool = nodes_pool - neighbors_set 195 | 196 | for neighbor in neighbors_set: 197 | neighbor_reach = self.find_reachable_set(neighbor, updated_nodes_pool, edge_type_list) 198 | reachable_set.update(neighbor_reach) 199 | updated_nodes_pool.difference_update(neighbor_reach) # remove neighbor_reach from the pool 200 | 201 | return reachable_set 202 | 203 | def find_unconnected_subgraphs(self, en_nodes=None, sym_edge_mark=None) -> list: 204 | """ 205 | Find groups of nodes that belong to unconnected sub-graphs (connected components) 206 | :param sym_edge_mark: connectivity is defined by symmetric edges that carry the provided edge-mark 207 | at both endpoints; e.g., Mark.Directed guides the search to consider only bi-directed edges as connections. 208 | :param en_nodes: Nodes that belong to the (unconnected) graph that need to be clustered 209 | Note that if you provide an edge-mark, only symmetric edges are considered, in contrast to the None default. 210 | Default: None, meaning that any edge qualifies as a connection (not just symmetric ones).
211 | :return: disjoint subsets of en_nodes that belong to distinct sub-graphs (connected components) 212 | """ 213 | if en_nodes is None: 214 | en_nodes = self.nodes_set 215 | 216 | connected_sets = [] 217 | nodes = en_nodes.copy() 218 | 219 | edge_type_list = None 220 | if sym_edge_mark in self.edge_mark_types: 221 | edge_type_list = [(sym_edge_mark, sym_edge_mark)] 222 | 223 | while len(nodes) > 0: 224 | node_i = nodes.pop() 225 | reachable_set = self.find_reachable_set(node_i, nodes, edge_type_list) 226 | nodes.difference_update(reachable_set) 227 | reachable_set.add(node_i) 228 | connected_sets.append(reachable_set) 229 | 230 | return connected_sets 231 | 232 | def get_skeleton_graph(self, en_nodes=None) -> UndirectedGraph: 233 | if en_nodes is None: 234 | en_nodes = self.nodes_set 235 | 236 | adj_graph = UndirectedGraph(en_nodes.copy()) 237 | for node_i, node_j in combinations(en_nodes, 2): 238 | if self.is_connected(node_i, node_j): 239 | adj_graph.add_edge(node_i, node_j) 240 | return adj_graph 241 | 242 | # --- graph modification functions -------------------------------------------------------------------------------- 243 | def delete_edge(self, node_i, node_j): 244 | for edge_mark in self.edge_mark_types: # loop through all edge marks 245 | self._graph[node_i][edge_mark].discard(node_j) 246 | self._graph[node_j][edge_mark].discard(node_i) 247 | 248 | def replace_edge_mark(self, node_source, node_target, requested_edge_mark): 249 | assert requested_edge_mark in self.edge_mark_types 250 | 251 | # remove any edge-mark 252 | for edge_mark in self.edge_mark_types: 253 | self._graph[node_target][edge_mark].discard(node_source) 254 | 255 | # set requested edge-mark 256 | self._graph[node_target][requested_edge_mark].add(node_source) 257 | 258 | def reset_orientations(self, default_mark, nodes_set=None): 259 | """ 260 | Reset all orientations, e.g., convert all edges into o--o edges, where "o" is the default edge-mark 261 | :param default_mark: an edge-mark to place instead of the existing edge-marks 262 | :param nodes_set: Only edges between pairs of nodes from this set will be converted (default: all edges) 263 | :return: 264 | """ 265 | assert default_mark in self.edge_mark_types 266 | if nodes_set is None: 267 | nodes_set = self.nodes_set 268 | 269 | for (node_x, node_y) in combinations(nodes_set, 2): 270 | if self.is_connected(node_x, node_y): 271 | self.replace_edge_mark(node_x, node_y, default_mark) 272 | self.replace_edge_mark(node_y, node_x, default_mark) 273 | 274 | def add_edge(self, node_i, node_j, edge_mark_at_i, edge_mark_at_j): 275 | """ 276 | Add an edge with the requested edge-marks. 277 | :param node_i: 278 | :param node_j: 279 | :param edge_mark_at_i: 280 | :param edge_mark_at_j: 281 | :return: 282 | """ 283 | 284 | assert not self.is_connected(node_i, node_j) # fail if the edge already exists 285 | assert (edge_mark_at_i in self.edge_mark_types) and (edge_mark_at_j in self.edge_mark_types) 286 | 287 | self._graph[node_i][edge_mark_at_i].add(node_j) 288 | self._graph[node_j][edge_mark_at_j].add(node_i) 289 | 290 | def get_skeleton_mat(self): 291 | """ 292 | Return the adjacency matrix of the graph skeleton, in a square numpy matrix format.
293 | :return: 294 | """ 295 | num_nodes = len(self.nodes_set) 296 | adj_mat = np.zeros((num_nodes, num_nodes), dtype=int) 297 | node_index_map = {node: i for i, node in enumerate(sorted(list(self.nodes_set)))} 298 | 299 | for node in self._graph: 300 | for edge_mark in self.edge_mark_types: # test all edge marks 301 | for node_p in self._graph[node][edge_mark]: 302 | adj_mat[node_index_map[node_p]][node_index_map[node]] = 1 303 | 304 | return adj_mat 305 | 306 | # --- plotting tools ---------------------------------------------------------------------------------------------- 307 | def __str__(self): 308 | text_print = 'Edge-marks on the graph edges:\n' 309 | for node in self.nodes_set: 310 | for edge_mark in self.edge_mark_types: 311 | if len(self._graph[node][edge_mark]) > 0: 312 | text_print += ('Edges: ' + str(node) + ' ' + edge_mark + '*' + 313 | ' ' + str(self._graph[node][edge_mark]) + '\n') 314 | return text_print 315 | -------------------------------------------------------------------------------- /graphical_models/basic_graph.py: -------------------------------------------------------------------------------- 1 | from itertools import combinations 2 | 3 | 4 | class Graph: 5 | """ 6 | A Graph with a single edge-head style. Used by DAG and UndirectedGraph as a base class 7 | """ 8 | def __init__(self, nodes_set): 9 | assert isinstance(nodes_set, set) 10 | 11 | self._graph = dict() 12 | self.nodes_set = nodes_set 13 | self.create_empty_graph(self.nodes_set) 14 | 15 | # --- graph initialization functions ------------------------------------------------------------------------------ 16 | def create_empty_graph(self, nodes_set=None): 17 | if nodes_set is None: 18 | nodes_set = self.nodes_set 19 | else: 20 | assert isinstance(nodes_set, set) 21 | 22 | for node in nodes_set: 23 | self._graph[node] = set() 24 | 25 | # --- graph query functions --------------------------------------------------------------------------------------- 26 | def is_connected(self, node_i, node_j): 27 | if (node_i in self._graph[node_j]) or (node_j in self._graph[node_i]): 28 | return True 29 | else: 30 | return False 31 | 32 | def number_of_edges(self, nodes_subset=None, return_missing=False): 33 | if nodes_subset is None: 34 | nodes_subset = self.nodes_set 35 | num_edges = 0 36 | missing_edges = 0 37 | for node_i, node_j in combinations(nodes_subset, 2): 38 | if self.is_connected(node_i, node_j): 39 | num_edges += 1 40 | else: 41 | missing_edges += 1 42 | 43 | if return_missing: 44 | return num_edges, missing_edges 45 | else: 46 | return num_edges 47 | 48 | def get_neighbors(self, node_i): 49 | neighbors = [] 50 | for node_j in (self.nodes_set - {node_i}): 51 | if self.is_connected(node_i, node_j): 52 | neighbors.append(node_j) 53 | return neighbors 54 | -------------------------------------------------------------------------------- /graphical_models/dag.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from itertools import combinations 3 | from .basic_graph import Graph 4 | from .undirected_graph import UndirectedGraph 5 | from . 
import arrow_head_types as Mark 6 | 7 | _ErrorCyclicGraph = 'Graph is cyclic' 8 | 9 | 10 | class DAG(Graph): 11 | """ 12 | A directed acyclic graph 13 | Example: 14 | nodes_set1 = set(range(5)) 15 | dag = DAG(nodes_set1) 16 | dag.add_edges({0}, 1) 17 | dag.add_edges({0}, 2) 18 | dag.add_edges({1, 2}, 3) 19 | dag.add_edges({3}, 4) 20 | 21 | print('Is acyclic?', dag.is_acyclic()) 22 | print('(0, 4) d-separated by {1, 2}?', dag.dsep(0, 4, {1, 2})) 23 | """ 24 | def max_parents(self): 25 | max_parents = 0 26 | for node in self.nodes_set: 27 | num_parents = len(self.parents(node)) 28 | if num_parents > max_parents: 29 | max_parents = num_parents 30 | 31 | return max_parents 32 | 33 | def parents(self, node): 34 | return self._graph[node] 35 | 36 | def find_children(self, node_parent, nodes_set=None): 37 | # ToDo: inefficient, should re-implement 38 | if nodes_set is None: 39 | nodes_set = self.nodes_set 40 | 41 | children_set = set() 42 | for node in nodes_set: 43 | if node_parent in self._graph[node]: 44 | children_set.add(node) 45 | 46 | return children_set 47 | 48 | def find_adjacent_nodes(self, node): 49 | return self.parents(node) | self.find_children(node) 50 | 51 | def is_connected(self, node_i, node_j): 52 | if (node_i in self.parents(node_j)) or (node_j in self.parents(node_i)): 53 | return True 54 | else: 55 | return False 56 | 57 | def is_ancestor(self, descendant_node, tested_node): 58 | if descendant_node == tested_node: 59 | return True # a node is defined to be its own ancestor 60 | 61 | parents_set = self.parents(descendant_node) 62 | if len(parents_set) == 0: 63 | return False # no parents, descendant_node is a root 64 | 65 | if tested_node in parents_set: 66 | return True # found the tested_node 67 | 68 | for parent in parents_set: 69 | if self.is_ancestor(descendant_node=parent, tested_node=tested_node): 70 | return True # found tested_node to be an ancestor of one of the parents 71 | else: 72 | return False # tested_node is not an ancestor of any of the parents 73 | 74 | def is_acyclic(self): 75 | for node in self._graph: 76 | parents_set = self.parents(node) 77 | for parent in parents_set: 78 | # test if a node is an ancestor of its parents 79 | if self.is_ancestor(descendant_node=parent, tested_node=node): 80 | return False 81 | else: 82 | return True 83 | 84 | def is_graph_connected(self, nodes_set=None): 85 | # ToDo: Check correctness and improve efficiency 86 | if nodes_set is None: 87 | nodes_set = self.nodes_set 88 | 89 | assert len(nodes_set) > 1 90 | 91 | nodes_to_reach = nodes_set.copy() # copy, since the input set is passed by reference 92 | starting_nodes = {nodes_to_reach.pop()} # start from an arbitrary node 93 | 94 | while len(starting_nodes) > 0: 95 | node_start = starting_nodes.pop() 96 | parent_nodes = self.parents(node_start) & nodes_to_reach 97 | nodes_to_reach = nodes_to_reach - parent_nodes 98 | 99 | children_nodes = self.find_children(node_start, nodes_to_reach) 100 | nodes_to_reach = nodes_to_reach - children_nodes 101 | 102 | if len(nodes_to_reach) == 0: 103 | return True # reached all the nodes in the graph 104 | 105 | starting_nodes.update(parent_nodes) 106 | starting_nodes.update(children_nodes) 107 | 108 | return False 109 | 110 | def get_ancestors(self, node, candidate_nodes=None): 111 | if candidate_nodes is None: 112 | candidate_nodes = self.nodes_set 113 | parents_set = self.parents(node) & candidate_nodes 114 | if len(parents_set) == 0: 115 | return {node} 116 | 117 | ancestors = set() 118 | for parent in parents_set: 119 |
ancestors.update(self.get_ancestors(parent, candidate_nodes - parents_set)) 120 | 121 | ancestors.add(node) # a node is considered its own ancestor 122 | return ancestors 123 | 124 | def dsep(self, node_i, node_j, condition_set): 125 | """ 126 | Test d-separation by following these steps: 127 | 1. Find the ancestors of node_i, node_j, and the nodes in the condition_set 128 | 2. moralize the sub-graph consisting of the ancestors, resulting in an undirected sub-graph 129 | 3. test separation by blocking all the undirected paths through the condition_set 130 | :param node_i: 131 | :param node_j: 132 | :param condition_set: 133 | :return: True if the node_i and node_j are d-separated by condition_set 134 | """ 135 | 136 | # 1. Find the nodes of the ancestors of node_i, node_j, and the nodes in the condition_set 137 | # a node is defined to be its own ancestor, thus node_i, node_j, and condition_set will be included 138 | ancestors = set() 139 | ancestors.update(self.get_ancestors(node_i)) 140 | ancestors.update(self.get_ancestors(node_j)) 141 | for cond_node in condition_set: 142 | ancestors.update(self.get_ancestors(cond_node)) 143 | 144 | # 2. Moralize the sub-graph consisting of the ancestors, resulting in an undirected sub-graph 145 | moral_graph = UndirectedGraph(ancestors) # undirected graph 146 | for node in ancestors: 147 | parents_set = self.parents(node) & ancestors 148 | for parent in parents_set: 149 | moral_graph.add_edge(parent, node) # create undirected edges between node and its parents 150 | for (parent_k, parent_l) in combinations(parents_set, 2): 151 | if not self.is_connected(parent_k, parent_l): 152 | moral_graph.add_edge(parent_k, parent_l) # "marry" unconnected parents with an undirected edge 153 | 154 | # 3. Test separation by blocking all the undirected paths through the condition_set 155 | return not moral_graph.is_reachable(node_i, node_j, condition_set) 156 | 157 | def convert_to_cpdag(self, cpdag): 158 | """ 159 | Convert the DAG to a CPDAG by copying the skeleton and v-structures. Then, the remaining undirected edges are 160 | oriented by rules R1, R2, R3. 161 | :param cpdag: an externally instantiated PDAG that will be filled with the result 162 | """ 163 | if not self.is_acyclic(): 164 | raise ValueError(_ErrorCyclicGraph) 165 | 166 | # copy skeleton 167 | for node in self.nodes_set: 168 | parents_set = self.parents(node) 169 | cpdag.add_edges(parents_set=parents_set, target_node=node, arrowhead_type=Mark.Undirected) 170 | 171 | for node in self.nodes_set: 172 | parents_set = self.parents(node) 173 | for (parent_i, parent_j) in combinations(parents_set, 2): 174 | if not self.is_connected(parent_i, parent_j): 175 | cpdag.orient_edge(source_node=parent_i, target_node=node) # orient v-structure 176 | cpdag.orient_edge(source_node=parent_j, target_node=node) 177 | 178 | cpdag.maximally_orient_pattern({1, 2, 3}) # use orientation rules R1, R2, and R3 179 | 180 | def get_adj_mat(self, en_nodes_list=None): 181 | """ 182 | Return the adjacency matrix, in numpy matrix format 183 | :param en_nodes_list: (optional) an ordered list of nodes to which the matrix indexes will correspond. 184 | A partial list of graph nodes can be provided. The size of the output matrix will be num.nodes X num.nodes.
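    For example (added for clarity), for a DAG with the single edge 0 --> 1 and en_nodes_list=[0, 1], the returned matrix is [[0, 1], [0, 0]], i.e., entry [parent, child] equals 1.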
185 | :return: 1) Adjacency matrix, and 186 | 2) if no list was provided as input, also returns an ordered list of node identifiers 187 | """ 188 | if en_nodes_list is None: 189 | nodes_sorted_list = sorted(self.nodes_set) 190 | else: 191 | assert isinstance(en_nodes_list, list) 192 | for node in en_nodes_list: 193 | assert node in self.nodes_set 194 | nodes_sorted_list = en_nodes_list 195 | 196 | num_nodes = len(nodes_sorted_list) 197 | adj_mat = np.zeros((num_nodes, num_nodes), dtype=int) 198 | node_index_map = {node: i for i, node in enumerate(nodes_sorted_list)} 199 | 200 | for node in nodes_sorted_list: 201 | parents_set = [node_index_map[n] for n in self.parents(node)] 202 | adj_mat[parents_set, node_index_map[node]] = 1 203 | 204 | # return the proper values 205 | if en_nodes_list is None: 206 | return adj_mat, nodes_sorted_list # return both the adjacency matrix and the ordered list of nodes 207 | else: 208 | return adj_mat # return only the adjacency matrix since the ordered list of nodes was input 209 | 210 | def find_topological_order(self, en_nodes=None) -> list: 211 | topological_groups = self.find_topological_order_groups(en_nodes) 212 | return [node for group in topological_groups for node in group] 213 | 214 | def find_topological_order_groups(self, en_nodes=None) -> list: 215 | if en_nodes is None: 216 | en_nodes = self.nodes_set 217 | 218 | if len(en_nodes) == 0: 219 | return [] 220 | 221 | parents_set = set() 222 | for node in en_nodes: 223 | parents_set.update(self.parents(node) & en_nodes) # update the set of nodes that are parents of someone 224 | 225 | leaves_set = en_nodes - parents_set # nodes that are not parents of any endogenous node 226 | assert len(leaves_set) > 0 # there should be at least one leaf in an acyclic graph 227 | 228 | high_topological_order = self.find_topological_order_groups(parents_set) # recursive call 229 | return high_topological_order + [leaves_set] 230 | 231 | # --- functions that modify the graph ----------------------------------------------------------------------------- 232 | def init_from_adj_mat(self, adj_mat: np.ndarray, nodes_order: list = None): 233 | num_vars = adj_mat.shape[0] 234 | if nodes_order is not None: 235 | assert isinstance(nodes_order, list) 236 | assert num_vars == len(nodes_order) 237 | else: 238 | nodes_order = list(range(num_vars)) 239 | 240 | self.create_empty_graph() # delete all pre-existing edges 241 | 242 | parents_list, children_list = adj_mat.nonzero() 243 | 244 | for (parent, child) in zip(parents_list, children_list): # convert adjacency matrix to DAG 245 | self.add_edges( 246 | parents_set={nodes_order[parent]}, 247 | target_node=nodes_order[child] 248 | ) 249 | 250 | def add_edges(self, parents_set, target_node): 251 | assert isinstance(parents_set, set) 252 | 253 | if len(parents_set - self._graph.keys()) != 0: 254 | raise ValueError('Parents set includes nodes that are not in the graph') 255 | 256 | if target_node not in self._graph: 257 | raise ValueError('Target node is not in the graph') 258 | 259 | self._graph[target_node].update(parents_set) 260 | -------------------------------------------------------------------------------- /graphical_models/possible_dsep_tree.py: -------------------------------------------------------------------------------- 1 | from itertools import combinations 2 | import math 3 | 4 | 5 | _ErrorChildNotExist = 'The child branch does not exist' 6 | _ErrorAddExistBranch = "The child already exists in the PDS-tree" 7 | 8 | 9 | class PDSTree: 10 | """ 11 | A tree structure for 
--------------------------------------------------------------------------------
/graphical_models/possible_dsep_tree.py:
--------------------------------------------------------------------------------
1 | from itertools import combinations
2 | import math
3 | 
4 | 
5 | _ErrorChildNotExist = 'The child branch does not exist'
6 | _ErrorAddExistBranch = "The child already exists in the PDS-tree"
7 | 
8 | 
9 | class PDSTree:
10 |     """
11 |     A tree structure for storage and retrieval of Possible-D-Sep nodes.
12 |     """
13 |     def __init__(self, node_root):
14 |         self.origin = node_root
15 |         self.children = []
16 |         self.dict_child = {}  # dictionary that maps a node to the index of its child branch in the children list
17 |         self.depth_level = 0
18 | 
19 |     def get_child_branch(self, child_origin):
20 |         if child_origin not in self.dict_child:
21 |             raise ValueError(_ErrorChildNotExist)
22 | 
23 |         child_idx = self.dict_child[child_origin]
24 |         return self.children[child_idx]
25 | 
26 |     def add_branch(self, branch_root):
27 |         """
28 |         Add a child node (it will serve as the root of a sub-tree originating from it)
29 |         :param branch_root: a node identifier
30 |         :return:
31 |         """
32 |         if branch_root in self.dict_child:
33 |             raise ValueError(_ErrorAddExistBranch)
34 | 
35 |         self.dict_child[branch_root] = len(self.children)  # create an index value for this child (not a key)
36 |         pds_tree_child = PDSTree(branch_root)
37 |         pds_tree_child.depth_level = self.depth_level + 1
38 |         self.children.append(pds_tree_child)  # add the child to the list of children
39 | 
40 |     def get_max_depth(self):  # TODO: correct this function to return a 0-based depth (root is 0)
41 |         """
42 |         Get the maximal depth (number of nodes)
43 |         :return: maximal depth: number of nodes from the root to the deepest leaf (inclusive)
44 |         """
45 |         if len(self.children) == 0:  # a leaf node
46 |             return 1
47 | 
48 |         max_child_depth = 1
49 |         for child in self.children:
50 |             current_child_depth = child.get_max_depth()  # max depth of the tree originating from the current child
51 |             if current_child_depth > max_child_depth:
52 |                 max_child_depth = current_child_depth
53 | 
54 |         return max_child_depth + 1
55 | 
56 |     def get_minimal_distance(self, node):
57 |         minimal_dist = math.inf
58 |         for child_branch in self.children:
59 |             if child_branch.origin == node:
60 |                 return child_branch.depth_level  # a child was found; deeper matches can only be at a greater or equal distance
61 |             else:
62 |                 dist = child_branch.get_minimal_distance(node)
63 |                 minimal_dist = min(minimal_dist, dist)
64 | 
65 |         return minimal_dist  # the minimal distance found in the sub-trees, or math.inf if the node is nowhere in the tree
66 | 
67 |     def is_pds_path(self, subset_nodes):
68 |         if len(subset_nodes) == 0:
69 |             return True
70 |         for branch_x in self.children:
71 |             if branch_x.origin in subset_nodes:
72 |                 path_found = branch_x.is_pds_path(subset_nodes - {branch_x.origin})
73 |                 if path_found:
74 |                     return True
75 |         else:  # for-else: no child branch completed a path through the remaining subset
76 |             return False
77 | 
78 |     def is_legal_cond_set(self, subset_nodes):
79 |         """
80 |         Test ICD-Sep condition 2-b: for every node in the conditioning set there exists a pds path such that
81 |         all the nodes on the path are also members of the same conditioning set.
82 |         :param subset_nodes: conditioning set to be inspected
83 |         :return: True if the conditioning set complies with ICD-Sep condition 2-b.
84 |         """
85 |         # check if every node in subset_nodes is reachable from the root using paths composed of only subset_nodes
86 |         for node in subset_nodes:
87 |             if not self.is_reachable(node, possible_path_nodes=subset_nodes):  # ICD-Sep Condition 2-b
88 |                 return False
89 |         else:  # for-else: every node passed the reachability test
90 |             return True
91 | 
92 |     def is_reachable(self, target_node, possible_path_nodes):
93 |         if len(possible_path_nodes) == 0:
94 |             return False
95 |         for branch_x in self.children:
96 |             if branch_x.origin == target_node:
97 |                 return True
98 |             if branch_x.origin in possible_path_nodes:
99 |                 is_found = branch_x.is_reachable(target_node, possible_path_nodes)
100 |                 if is_found:
101 |                     return True
102 | 
103 |         return False
104 | 
105 |     def get_subsets_list(self, set_nodes, subset_size):
106 |         min_dist = {node: self.get_minimal_distance(node) for node in set_nodes}  # minimal distances given set_nodes
107 | 
108 |         subsets_list = []  # each element in this list is a 2-element list [ {subset}, distance ]
109 | 
110 |         # create a list of all legal subsets
111 |         for subset_nodes in combinations(set_nodes, subset_size):
112 |             if self.is_legal_cond_set(subset_nodes):
113 |                 # sum minimal distances
114 |                 dist_sum = 0
115 |                 for node in subset_nodes:
116 |                     dist_sum += min_dist[node]
117 | 
118 |                 subsets_list.append([set(subset_nodes), dist_sum])
119 | 
120 |         return subsets_list
121 | 
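
Usage sketch (illustrative; not a repository file). The root of the PDS-tree is the anchor node, each branch follows a possible-d-sep path, and a conditioning set is legal (ICD-Sep condition 2-b) only if each of its members is reachable from the root through members of the same set; node names are arbitrary:

from graphical_models.possible_dsep_tree import PDSTree

tree = PDSTree('X')                          # root: the anchor node
tree.add_branch('A')                         # X -- A
tree.add_branch('B')                         # X -- B
tree.get_child_branch('A').add_branch('C')   # X -- A -- C

print(tree.get_max_depth())                  # 3: number of nodes on the longest root-to-leaf path (X, A, C)
print(tree.is_legal_cond_set({'A', 'C'}))    # True: C is reached via A, which is also in the set
print(tree.is_legal_cond_set({'B', 'C'}))    # False: every path to C leaves the set
print(tree.get_subsets_list({'A', 'B', 'C'}, subset_size=2))
# [[{'A', 'B'}, 2], [{'A', 'C'}, 3]] -- the legal subsets with their summed minimal distances (order may vary)
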
--------------------------------------------------------------------------------
/graphical_models/undirected_graph.py:
--------------------------------------------------------------------------------
1 | from .basic_graph import Graph
2 | 
3 | _ErrorUnknownNode = 'Node is not in the graph'
4 | 
5 | 
6 | class UndirectedGraph(Graph):
7 |     """
8 |     An undirected graphical model.
9 |     """
10 | 
11 |     # --- graph modification functions --------------------------------------------------------------------------------
12 |     def add_edge(self, node_i, node_j):
13 |         if (node_i not in self.nodes_set) or (node_j not in self.nodes_set):
14 |             raise ValueError(_ErrorUnknownNode)
15 | 
16 |         self._graph[node_i].add(node_j)
17 |         self._graph[node_j].add(node_i)
18 | 
19 |     def remove_edge(self, node_i, node_j):
20 |         if (node_i not in self.nodes_set) or (node_j not in self.nodes_set):
21 |             raise ValueError(_ErrorUnknownNode)
22 | 
23 |         self._graph[node_i].discard(node_j)
24 |         self._graph[node_j].discard(node_i)
25 | 
26 |     def disconnect_node(self, node):
27 |         neighbors_set = set(self._graph[node])  # copy the neighbor set: remove_edge() mutates self._graph[node] during the loop
28 |         for neighbor in neighbors_set:
29 |             self.remove_edge(neighbor, node)
30 | 
31 |     # --- graph query functions ---------------------------------------------------------------------------------------
32 |     def is_reachable(self, node_start, node_end, visited_nodes_in=None):
33 |         """
34 |         Test if there is a path between two nodes (node_start, node_end) that does not pass through a given set of nodes.
35 |         :param node_start: one end-point of the tested path
36 |         :param node_end: second end-point of the tested path
37 |         :param visited_nodes_in: (forbidden nodes) set of nodes that block the tested paths
38 |             (e.g., nodes that were already visited)
39 |         :return: True if a path is found
40 |         """
41 |         if node_start == node_end:
42 |             return True  # reached the destination
43 | 
44 |         if visited_nodes_in is None:
45 |             visited_nodes = {node_start}
46 |         else:
47 |             visited_nodes = set(visited_nodes_in)  # create a copy and ensure it is of set type
48 |             visited_nodes.add(node_start)
49 | 
50 |         unvisited_neighbors = self._graph[node_start] - visited_nodes
51 |         for neighbor in unvisited_neighbors:
52 |             if self.is_reachable(neighbor, node_end, visited_nodes):
53 |                 return True  # found a path from an (unvisited) neighbor to the target node
54 |             visited_nodes.add(neighbor)
55 |         else:  # for-else: the loop exhausted all the neighbors
56 |             return False  # went through all the neighbors and didn't find a path to the target
57 | 
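
Usage sketch (illustrative; not a repository file). It assumes the Graph base class accepts the node set in its constructor, as the other graph classes in this package do:

from graphical_models.undirected_graph import UndirectedGraph

ug = UndirectedGraph({1, 2, 3, 4})
ug.add_edge(1, 2)
ug.add_edge(2, 3)
ug.add_edge(3, 4)

print(ug.is_reachable(1, 4))                        # True: the path 1 - 2 - 3 - 4
print(ug.is_reachable(1, 4, visited_nodes_in={3}))  # False: the only path passes through the blocking node 3

ug.disconnect_node(2)                               # removes the edges 1 - 2 and 2 - 3
print(ug.is_reachable(1, 3))                        # False
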
--------------------------------------------------------------------------------
/imgs/ExampleAnimationICD.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/causality-lab/8ad638a2057e3bdf35108b6e63f547dd7f6a95a1/imgs/ExampleAnimationICD.gif
--------------------------------------------------------------------------------
/imgs/ExamplePAG.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/causality-lab/8ad638a2057e3bdf35108b6e63f547dd7f6a95a1/imgs/ExamplePAG.png
--------------------------------------------------------------------------------
/imgs/FrameworkBlockDiagram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/causality-lab/8ad638a2057e3bdf35108b6e63f547dd7f6a95a1/imgs/FrameworkBlockDiagram.png
--------------------------------------------------------------------------------
/notebooks/imgs/TimeSeriesMeasurmentSites.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/causality-lab/8ad638a2057e3bdf35108b6e63f547dd7f6a95a1/notebooks/imgs/TimeSeriesMeasurmentSites.png
--------------------------------------------------------------------------------
/plot_utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .draw_graph import draw_graph, draw_temporal_graph, draw_pds_tree
2 | 
--------------------------------------------------------------------------------
/plot_utils/graph_layout.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from math import sqrt, cos, sin
3 | 
4 | 
5 | class BaseLayout:
6 |     """
7 |     Base class for layouts.
8 |     """
9 |     def __init__(self, graph, win_left_right, win_top_bottom, group_sort=None):
10 |         self.num_nodes = len(graph.nodes_set)
11 |         self.left = win_left_right[0]
12 |         self.right = win_left_right[1]
13 |         self.top = win_top_bottom[0]
14 |         self.bottom = win_top_bottom[1]
15 |         self.group_sort = group_sort
16 |         self.graph = graph
17 | 
18 |         # set initial positions
19 |         if group_sort is not None:
20 |             nodes_order = [node for group in group_sort for node in group]
21 |             xpos = np.linspace(self.left, self.right, self.num_nodes)
22 |             ypos = np.linspace(self.top, self.bottom, self.num_nodes)
23 |         else:
24 |             nodes_order = list(graph.nodes_set)
25 |             xpos = np.random.uniform(self.left, self.right, self.num_nodes)*0.2
26 |             ypos = np.random.uniform(self.top, self.bottom, self.num_nodes)*0.2
27 | 
28 |         # dictionary with nodes as keys and an [x, y] position numpy-vector as value
29 |         self.pos = dict()
30 |         for idx in range(self.num_nodes):
31 |             self.pos[nodes_order[idx]] = np.array([xpos[idx], ypos[idx]])
32 | 
33 | 
34 | class ForceDirectedLayout(BaseLayout):
35 |     def __init__(self, graph, win_left_right, win_top_bottom, group_sort=None, k_const=None, num_iterations=100):
36 |         """
37 |         Node positions are calculated using the force-directed algorithm by Fruchterman and Reingold (1991)
38 |         :param graph: a graph to draw
39 |         :param win_left_right: a tuple (left, right) with border coordinates
40 |         :param win_top_bottom: a tuple (top, bottom) with border coordinates
41 |         :param group_sort: a topologically sorted list of groups of nodes.
42 |             Nodes at the beginning of the list will tend to be higher in the layout.
43 |         :param k_const: constant of the repulsion and attraction forces
44 |             (default is None, indicating that the value is determined automatically)
45 |         :param num_iterations: number of iterations (default: 100; can be reduced to improve runtime)
46 |         """
47 |         super().__init__(graph, win_left_right, win_top_bottom, group_sort)
48 | 
49 |         np.random.seed(123)  # fixed seed so that layouts are reproducible across runs
50 |         self.num_iter = num_iterations
51 | 
52 |         if k_const is None:
53 |             k_const = sqrt((self.right-self.left)*(self.bottom-self.top) / self.num_nodes)  # area / number of nodes
54 |         self.k_const = k_const
55 | 
56 |         self.attraction = lambda dist: dist*dist / self.k_const
57 |         self.repulsion = lambda dist: self.k_const*self.k_const / dist
58 | 
59 |         init_layout = CircleLayout(graph, win_left_right, win_top_bottom, group_sort)
60 |         self.pos = init_layout.calc_layout()  # start the iterations from a circle layout
61 | 
62 |     def _calc_repulsive_forces(self):
63 |         repulse = {node: np.zeros(2, dtype=float) for node in self.graph.nodes_set}  # dictionary
64 |         for node_i in self.graph.nodes_set:
65 |             # accumulate the pairwise repulsion acting on node_i (initialized above to the zero vector)
66 |             for node_j in self.graph.nodes_set-{node_i}:
67 |                 dv = self.pos[node_i] - self.pos[node_j]
68 |                 dist = np.sqrt(dv[0]*dv[0] + dv[1]*dv[1])  # Euclidean distance
69 |                 dist = max(dist, 1)
70 |                 repulse[node_i] += (dv / dist) * self.repulsion(dist)
71 |         return repulse
72 | 
73 |     def _calc_attracting_forces(self):
74 |         attract = {node: np.zeros(2, dtype=float) for node in self.graph.nodes_set}  # dictionary
75 |         for node_i in self.graph.nodes_set:
76 |             for node_j in self.graph.find_adjacent_nodes(node_i):  # loop through the neighbors
77 |                 dv = self.pos[node_i] - self.pos[node_j]
78 |                 dist = np.sqrt(dv[0] * dv[0] + dv[1] * dv[1])  # Euclidean distance
79 |                 dist = max(dist, 1)
80 |                 attract[node_i] -= (dv / dist) * self.attraction(dist)
81 |         return attract
82 | 
83 |     def calc_layout(self, num_iterations=None):
84 |         """
85 |         Main function for calculating the layout.
86 |         :param num_iterations: default is None, indicating that the class attribute value should be used.
87 |         :return: the final position values (the class attribute is updated as well).
88 |         """
89 |         if num_iterations is None:
90 |             num_iterations = self.num_iter
91 |         first_max_disp = self.k_const
92 |         for i in range(num_iterations):
93 |             max_disp = first_max_disp / (i+1)  # gradually decrease the maximal displacement ("cooling")
94 |             repulsing = self._calc_repulsive_forces()
95 |             attracting = self._calc_attracting_forces()
96 |             for node in self.graph.nodes_set:
97 |                 disp = (repulsing[node] + attracting[node])
98 |                 disp_norm = max(np.sqrt(disp[0]*disp[0] + disp[1]*disp[1]), 1e-12)  # avoid division by zero when the forces cancel out
99 |                 self.pos[node] += (disp/disp_norm) * min(disp_norm, max_disp)
100 |                 self.pos[node][0] = min(self.right, max(self.left, self.pos[node][0]))  # clip to the window
101 |                 self.pos[node][1] = min(self.bottom, max(self.top, self.pos[node][1]))
102 |         return self.pos
103 | 
104 | 
105 | class CircleLayout(BaseLayout):
106 |     def __init__(self, graph, win_left_right, win_top_bottom, group_sort=None):
107 |         super().__init__(graph, win_left_right, win_top_bottom, group_sort)
108 | 
109 |     def calc_layout(self):
110 |         if self.group_sort is not None:
111 |             nodes_order = [node for group in self.group_sort for node in group]
112 |         else:
113 |             nodes_order = list(self.graph.nodes_set)
114 | 
115 |         xrad = (self.right - self.left)/2
116 |         yrad = (self.bottom - self.top)/2
117 |         center = ((self.right + self.left)/2, (self.bottom + self.top)/2)  # 2-tuple
118 | 
119 |         angle_list = np.linspace(0, 2*np.pi, self.num_nodes+1)
120 |         for idx in range(self.num_nodes):
121 |             node = nodes_order[idx]
122 |             angle = angle_list[idx]
123 |             self.pos[node][0] = center[0] + xrad * cos(angle)
124 |             self.pos[node][1] = center[1] + yrad * sin(angle)
125 | 
126 |         return self.pos
127 | 
128 | 
129 | class ColumnLayout(BaseLayout):
130 |     def __init__(self, graph, win_left_right, win_top_bottom, group_sort=None):
131 |         super().__init__(graph, win_left_right, win_top_bottom, group_sort)
132 | 
133 |     def calc_layout(self):
134 |         if self.group_sort is None:
135 |             nodes_order = list(self.graph.nodes_set)
136 |             group_sort = [nodes_order]
137 |         else:  # self.group_sort is not None
138 |             group_sort = self.group_sort
139 | 
140 |         # find the size of the largest group of nodes
141 |         max_group_len = 0
142 |         for group in group_sort:
143 |             if len(group) > max_group_len:
144 |                 max_group_len = len(group)
145 | 
146 |         # calculate positions. The first group is left-most
147 |         n_groups = len(group_sort)
148 |         y_offset = self.bottom
149 |         if n_groups > 1:
150 |             x_step = (self.right - self.left) / (n_groups-1)
151 |             x_offset = self.left
152 |         else:
153 |             x_step = 0
154 |             x_offset = (self.right + self.left) / 2.
155 |         y_step = (self.top - self.bottom) / (max_group_len - 1) if max_group_len > 1 else 0  # guard against division by zero when every group has a single node
156 |         for i_group, group in enumerate(group_sort):
157 |             for i_node, node in enumerate(group):
158 |                 self.pos[node][0] = x_step * i_group + x_offset
159 |                 self.pos[node][1] = y_step * i_node + y_offset
160 | 
161 |         return self.pos
162 | 
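
Usage sketch (illustrative; not a repository file). It lays out a small PAG in a 100 x 100 window; it assumes that the pool and edge-type arguments of find_adjacent_nodes default to "any edge", which is how _calc_attracting_forces above calls it:

from graphical_models import PAG, arrow_head_types as Mark
from plot_utils.graph_layout import ForceDirectedLayout, CircleLayout

pag = PAG(set(range(5)))
pag.add_edge(0, 1, Mark.Tail, Mark.Directed)
pag.add_edge(1, 2, Mark.Circle, Mark.Circle)
pag.add_edge(2, 3, Mark.Directed, Mark.Directed)
pag.add_edge(3, 4, Mark.Circle, Mark.Directed)

layout = ForceDirectedLayout(pag, win_left_right=(0, 100), win_top_bottom=(0, 100))
positions = layout.calc_layout()    # dict mapping each node to an np.array([x, y]), clipped to the window

ring = CircleLayout(pag, (0, 100), (0, 100)).calc_layout()    # the same nodes placed on an ellipse
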
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=42"]
3 | build-backend = "setuptools.build_meta"
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | scipy
3 | matplotlib
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [metadata]
2 | name = causality_lab
3 | version = 0.0.0
4 | author = Intel Labs
5 | description = Research code of novel causal discovery algorithms developed at Intel Labs.
6 | long_description = file: README.md
7 | long_description_content_type = text/markdown
8 | url = https://github.com/IntelLabs/causality-lab
9 | keywords = AI, ML, causal discovery, causality
10 | classifiers =
11 |     License :: OSI Approved :: Apache Software License
12 |     Programming Language :: Python
13 |     Programming Language :: Python :: 3
14 |     Programming Language :: Python :: 3.7
15 |     Programming Language :: Python :: 3.8
16 |     Programming Language :: Python :: 3.9
17 |     Programming Language :: Python :: 3.10
18 |     Programming Language :: Python :: Implementation :: CPython
19 |     Programming Language :: Python :: Implementation :: PyPy
20 | 
21 | [options]
22 | packages = find:
23 | python_requires = >=3.7
24 | install_requires = file: requirements.txt
25 | 
26 | [options.packages.find]
27 | exclude =
28 |     unit_tests
29 |     imgs
30 |     notebooks
31 |     example_data
32 | 
--------------------------------------------------------------------------------
/unit_tests/graphical_models/test_basic_equivalance_class_graph.py:
--------------------------------------------------------------------------------
1 | from unittest import TestCase
2 | from graphical_models import PAG, DAG, arrow_head_types as Mark
3 | from itertools import combinations
4 | 
5 | 
6 | class TestECGraph(TestCase):
7 |     def test_find_adjacent_nodes(self):
8 |         # define a PAG in which node 0 is adjacent to every other node, and no other edges exist
9 |         nodes = set(range(10))
10 |         pag = PAG(nodes)
11 |         pag.add_edge(0, 1, Mark.Tail, Mark.Directed)
12 |         pag.add_edge(0, 6, Mark.Tail, Mark.Directed)
13 |         pag.add_edge(0, 4, Mark.Tail, Mark.Directed)
14 |         pag.add_edge(0, 3, Mark.Directed, Mark.Directed)
15 |         pag.add_edge(0, 5, Mark.Directed, Mark.Directed)
16 |         pag.add_edge(0, 2, Mark.Circle, Mark.Directed)
17 |         pag.add_edge(0, 7, Mark.Directed, Mark.Tail)
18 |         pag.add_edge(0, 8, Mark.Circle, Mark.Circle)
19 |         pag.add_edge(0, 9, Mark.Tail, Mark.Tail)
20 | 
21 |         # test common use (consider any edge-mark/edge-type)
22 |         neighbors = pag.find_adjacent_nodes(0, nodes)
23 |         self.assertSetEqual(nodes - {0}, neighbors)
24 |         neighbors = pag.find_adjacent_nodes(0, nodes - {1, 2})
25 |         self.assertSetEqual(nodes - {0, 1, 2}, neighbors)
26 | 
27 |         # find neighbors connected via a specific edge-type
28 |         neighbors = pag.find_adjacent_nodes(0, None, (Mark.Tail, Mark.Tail))
29 |         self.assertSetEqual({9}, neighbors)
30 |         neighbors = 
pag.find_adjacent_nodes(0, None, (Mark.Tail, Mark.Directed)) 31 | self.assertSetEqual({1, 6, 4}, neighbors) 32 | neighbors = pag.find_adjacent_nodes(0, None, (Mark.Directed, Mark.Directed)) 33 | self.assertSetEqual({3, 5}, neighbors) 34 | neighbors = pag.find_adjacent_nodes(0, None, (Mark.Directed, Mark.Tail)) 35 | self.assertSetEqual({7}, neighbors) 36 | neighbors = pag.find_adjacent_nodes(0, None, (Mark.Directed, Mark.Circle)) 37 | self.assertSetEqual(set(), neighbors) 38 | neighbors = pag.find_adjacent_nodes(0, None, (Mark.Circle, Mark.Circle)) 39 | self.assertSetEqual({8}, neighbors) 40 | neighbors = pag.find_adjacent_nodes(0, None, (Mark.Circle, Mark.Directed)) 41 | self.assertSetEqual({2}, neighbors) 42 | neighbors = pag.find_adjacent_nodes(0, None, (Mark.Circle, Mark.Tail)) 43 | self.assertSetEqual(set(), neighbors) 44 | 45 | def test_find_reachable_set(self): 46 | nodes = set(range(10)) 47 | pag = PAG(nodes) 48 | pag.add_edge(5, 1, Mark.Tail, Mark.Directed) 49 | pag.add_edge(5, 2, Mark.Directed, Mark.Circle) 50 | pag.add_edge(1, 4, Mark.Tail, Mark.Directed) 51 | pag.add_edge(1, 9, Mark.Tail, Mark.Directed) 52 | pag.add_edge(1, 3, Mark.Circle, Mark.Circle) 53 | pag.add_edge(4, 2, Mark.Circle, Mark.Tail) 54 | pag.add_edge(1, 2, Mark.Circle, Mark.Circle) 55 | pag.add_edge(2, 6, Mark.Circle, Mark.Directed) 56 | pag.add_edge(6, 7, Mark.Circle, Mark.Circle) 57 | pag.add_edge(7, 8, Mark.Tail, Mark.Directed) 58 | pag.add_edge(9, 0, Mark.Directed, Mark.Directed) 59 | 60 | reachable = pag.find_reachable_set(anchor_node=5, nodes_pool=nodes, 61 | edge_type_list=[(Mark.Tail, Mark.Directed), (Mark.Circle, Mark.Directed)]) 62 | self.assertSetEqual({1, 4, 9}, reachable) 63 | 64 | def test_find_unconnected_subgraphs(self): 65 | print('Test finding unconnected sub-graph') 66 | nodes = set(range(10)) 67 | pag = PAG(nodes) 68 | pag.add_edge(0, 4, Mark.Directed, Mark.Circle) 69 | pag.add_edge(4, 2, Mark.Directed, Mark.Directed) 70 | pag.add_edge(2, 8, Mark.Directed, Mark.Directed) 71 | pag.add_edge(3, 5, Mark.Directed, Mark.Directed) 72 | pag.add_edge(6, 5, Mark.Circle, Mark.Tail) 73 | pag.add_edge(7, 9, Mark.Circle, Mark.Circle) 74 | 75 | sub_graphs = pag.find_unconnected_subgraphs() 76 | self.assertEqual(4, len(sub_graphs)) 77 | self.assertIn({0, 2, 4, 8}, sub_graphs) 78 | self.assertIn({3, 5, 6}, sub_graphs) 79 | self.assertIn({7, 9}, sub_graphs) 80 | self.assertIn({1}, sub_graphs) 81 | 82 | # test on a subset of nodes. 
Nodes outside this subset are considered "blocking" nodes 83 | sub_graphs = pag.find_unconnected_subgraphs(en_nodes=nodes-{2,1}) 84 | self.assertEqual(4, len(sub_graphs)) 85 | self.assertIn({8}, sub_graphs) 86 | self.assertIn({0, 4}, sub_graphs) 87 | self.assertIn({3, 5, 6}, sub_graphs) 88 | self.assertIn({7, 9}, sub_graphs) 89 | 90 | dc_components = pag.find_unconnected_subgraphs(en_nodes=nodes, sym_edge_mark=Mark.Directed) 91 | self.assertEqual(7, len(dc_components)) 92 | self.assertIn({0}, dc_components) 93 | self.assertIn({1}, dc_components) 94 | self.assertIn({2, 4, 8}, dc_components) 95 | self.assertIn({3, 5}, dc_components) 96 | self.assertIn({6}, dc_components) 97 | self.assertIn({7}, dc_components) 98 | self.assertIn({9}, dc_components) 99 | 100 | dc_components = pag.find_unconnected_subgraphs(en_nodes=nodes-{2}, sym_edge_mark=Mark.Directed) 101 | self.assertEqual(8, len(dc_components)) 102 | self.assertIn({0}, dc_components) 103 | self.assertIn({1}, dc_components) 104 | self.assertIn({4}, dc_components) 105 | self.assertIn({8}, dc_components) 106 | self.assertIn({3, 5}, dc_components) 107 | self.assertIn({6}, dc_components) 108 | self.assertIn({7}, dc_components) 109 | self.assertIn({9}, dc_components) 110 | -------------------------------------------------------------------------------- /unit_tests/graphical_models/test_dag.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | from graphical_models import DAG 3 | import numpy as np 4 | 5 | class TestDAG(TestCase): 6 | def test_find_topological_order_groups(self): 7 | latent = {'A', 'B', 'C', 'D', 'E'} 8 | observed = {'Sun', 'W', 'T', 'Z', 'V', 'X', 'U', 'Y'} 9 | dag = DAG(latent | observed) 10 | dag.add_edges(parents_set={'A', 'B', 'T', 'V'}, target_node='X') 11 | dag.add_edges(parents_set={'E', 'U'}, target_node='Y') 12 | dag.add_edges(parents_set={'B', 'Z'}, target_node='U') 13 | dag.add_edges(parents_set={'E', 'D'}, target_node='V') 14 | dag.add_edges(parents_set={'Sun', 'C', 'D'}, target_node='Z') 15 | dag.add_edges(parents_set={'W', 'C'}, target_node='T') 16 | dag.add_edges(parents_set={'W', 'A'}, target_node='Sun') 17 | 18 | # correct topological order (found by recursively eliminating leaves) 19 | correct_groups_order = [{'W', 'A'}, {'D', 'C', 'Sun'}, {'B', 'E', 'Z'}, {'V', 'U', 'T'}, {'X', 'Y'}] 20 | 21 | group_list = dag.find_topological_order_groups(latent | observed) 22 | for group_test, group_correct in zip(group_list, correct_groups_order): 23 | self.assertSetEqual(group_test, group_correct) 24 | 25 | def test_get_adj_mat(self): 26 | nodes_list = [2, 4, 6, 1] 27 | dag = DAG(nodes_set=set(nodes_list)) 28 | dag.add_edges(parents_set={2, 4}, target_node=6) # a v-structure 29 | dag.add_edges(parents_set={6}, target_node=1) 30 | 31 | adj_mat1, nodes_list1 = dag.get_adj_mat() 32 | self.assertListEqual(nodes_list1, sorted(nodes_list)) 33 | adj_true = np.array([[0, 0, 0, 0], [0, 0, 0, 1], [0, 0, 0, 1], [1, 0, 0, 0]]) 34 | self.assertFalse(np.any(adj_true != adj_mat1)) 35 | 36 | adj_mat2 = dag.get_adj_mat(en_nodes_list=nodes_list) 37 | adj_true = np.array([[0, 0, 1, 0], [0, 0, 1, 0], [0, 0, 0, 1], [0, 0, 0, 0]]) 38 | self.assertFalse(np.any(adj_true != adj_mat2)) 39 | 40 | adj_mat3 = dag.get_adj_mat(en_nodes_list=[6, 2, 4, 1]) 41 | adj_true = np.array([[0, 0, 0, 1], [1, 0, 0, 0], [1, 0, 0, 0], [0, 0, 0, 0]]) 42 | self.assertFalse(np.any(adj_true != adj_mat3)) 43 | 44 | latent = {'A', 'B', 'C', 'D', 'E'} 45 | observed = {'Sun', 'W', 'T', 'Z', 'V', 'X', 'U', 
'Y'} 46 | dag = DAG(latent | observed) 47 | dag.add_edges(parents_set={'A', 'B', 'T', 'V'}, target_node='X') 48 | dag.add_edges(parents_set={'E', 'U'}, target_node='Y') 49 | dag.add_edges(parents_set={'B', 'Z'}, target_node='U') 50 | dag.add_edges(parents_set={'E', 'D'}, target_node='V') 51 | dag.add_edges(parents_set={'Sun', 'C', 'D'}, target_node='Z') 52 | dag.add_edges(parents_set={'W', 'C'}, target_node='T') 53 | dag.add_edges(parents_set={'W', 'A'}, target_node='Sun') 54 | adj = dag.get_adj_mat(en_nodes_list=['A', 'B', 'C', 'D', 'E', 'Sun', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z']) 55 | adj_true = np.array([ 56 | [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0], 57 | [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0], 58 | [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1], 59 | [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1], 60 | [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0], 61 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], 62 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], 63 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], 64 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], 65 | [0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0], 66 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 67 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 68 | [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0]]) 69 | self.assertFalse(np.any(adj_true != adj)) 70 | 71 | 72 | 73 | 74 | --------------------------------------------------------------------------------