├── .gitignore ├── LICENSE ├── README.md ├── data ├── __init__.py ├── basic.py ├── distributions.py ├── erdos_renyi.py └── generator.py ├── main.py ├── models ├── autoreg_base.py ├── bge_model.py ├── factorised_base.py └── vcn.py ├── requirements.txt └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints/ 2 | .vscode/ 3 | __pycache__/ 4 | .DS_Store 5 | *.ipynb 6 | ======= 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | pip-wheel-metadata/ 30 | share/python-wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .nox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | *.py,cover 57 | .hypothesis/ 58 | .pytest_cache/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | db.sqlite3-journal 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 101 | __pypackages__/ 102 | 103 | # Celery stuff 104 | celerybeat-schedule 105 | celerybeat.pid 106 | 107 | # SageMath parsed files 108 | *.sage.py 109 | 110 | # Environments 111 | .env 112 | .venv 113 | env/ 114 | venv/ 115 | ENV/ 116 | env.bak/ 117 | venv.bak/ 118 | 119 | # Spyder project settings 120 | .spyderproject 121 | .spyproject 122 | 123 | # Rope project settings 124 | .ropeproject 125 | 126 | # mkdocs documentation 127 | /site 128 | 129 | # mypy 130 | .mypy_cache/ 131 | .dmypy.json 132 | dmypy.json 133 | 134 | # Pyre type checker 135 | .pyre/ 136 | #Apple 137 | .DS_Store 138 | 139 | out/ 140 | results/ 141 | create_job.py 142 | run_eval_all.py 143 | jobs/ 144 | run_eval_all_.py 145 | del.py 146 | results_new/ 147 | results_tp/ 148 | results_iclr/ 149 | weights_tp/ 150 | plots/ 151 | data/MNIST/ 152 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Variational Causal Networks 2 | Pytorch implementation of [Variational Causal Networks: Approximate Bayesian Inference over Causal Structures](https://arxiv.org/abs/2106.07635) (Annadani et al. 2021). 3 | 4 | [Yashas Annadani](https://yashasannadani.com), [Jonas Rothfuss](https://las.inf.ethz.ch/people/jonas-rothfuss), [Alexandre Lacoste](https://ca.linkedin.com/in/alexandre-lacoste-4032465), [Nino Scherrer](https://ch.linkedin.com/in/ninoscherrer), [Anirudh Goyal](https://anirudh9119.github.io/), [Yoshua Bengio](https://mila.quebec/en/yoshua-bengio/), [Stefan Bauer](https://www.is.mpg.de/~sbauer) 5 | 6 | 7 | ## Installation 8 | You can install the dependencies using 9 | `pip install -r requirements.txt 10 | ` 11 | 12 | Create Directory structure which looks as follows: `[save_path]/er_1/` 13 | 14 | ## Examples 15 | 16 | Run 17 | 18 | `python main.py --num_nodes [num_nodes] --data_seed [data_seed] --anneal --save_path [save_path]` 19 | 20 | In the paper we run the model on 20 different data seeds to obtain confidence intervals. If you would like to compare with factorised distribution, run: 21 | 22 | `python main.py --num_nodes [num_nodes] --data_seed [data_seed] --anneal --save_path [save_path] --no_autoreg_base` 23 | 24 | ## Contact 25 | 26 | If you have any questions, please address them to: Yashas Annadani `yashas.annadani@gmail.com` 27 | 28 | 29 | 30 | If you use this work, please cite: 31 | 32 | @article{annadani2021variational, 33 | title={Variational Causal Networks: Approximate Bayesian Inference over Causal Structures}, 34 | author={Annadani, Yashas and Rothfuss, Jonas and Lacoste, Alexandre and Scherrer, Nino and Goyal, Anirudh and Bengio, Yoshua and Bauer, Stefan}, 35 | journal={arXiv preprint arXiv:2106.07635}, 36 | year={2021} 37 | } 38 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /data/basic.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import time 4 | import numpy as np 5 | import scipy 6 | import tqdm 7 | from scipy.stats import multivariate_normal 8 | import igraph as ig 9 | import itertools 10 | import torch 11 | from utils import all_combinations, mat_to_graph 12 | 13 | class BasicModel: 14 | """ 15 | Basic observational model 16 | Given 17 | p(G) 18 | 19 | Implements 20 | 21 | p(theta | G) 22 | p(x | theta, G) 23 | 24 | """ 25 | 26 | def __init__(self, *, g_dist, verbose=False, seed = None): 27 | super(BasicModel, self).__init__() 28 | 29 | self.verbose = verbose 30 | self.g_dist = g_dist 31 | self.reseed(seed = seed) 32 | 33 | def reseed(self, seed = None): 34 | if seed is None: 35 | return 36 | np.random.seed(seed) 37 | torch.manual_seed(seed) 38 | 39 | def sample_parameters(self, g): 40 | """Samples parameters given igraph.Graph g 41 | For each variable i, sample parameters for every possible state of parents 42 | Returns: 43 | theta 44 | """ 45 | raise NotImplementedError 46 | 47 | def sample_obs(self, n_samples, g, theta = None, toporder=None): 48 | """Samples `n_samples` observations given index i of graph and theta 49 | n_samples : int 50 | g : graph 51 | theta : [n_edges] 52 | Returns: 53 | x : [n_samples, n_vars] 54 | """ 55 | if theta is None: 56 | theta = self.sample_parameters(g) 57 | if toporder is None: 58 | toporder = g.topological_sorting() 59 | 60 | x = np.zeros((n_samples, len(g.vs))) 61 | z = scipy.stats.norm.rvs(loc=0.0, scale=self.sig_obs, size=(n_samples, len(g.vs))) 62 | 63 | # ancestral sampling 64 | for j in toporder: 65 | parent_edges = g.incident(j, mode='in') 66 | parents = list(g.es[e].source for e in parent_edges) 67 | if parents: 68 | mean = x[:, parents] @ theta[parents, j] 69 | x[:, j] = mean + z[:, j] 70 | else: 71 | x[:, j] = z[:, j] 72 | return x.astype(np.float32) 73 | 74 | 75 | def log_prob_parameters(self, theta, g): 76 | """Computes p(theta | G)""" 77 | 78 | raise NotImplementedError 79 | 80 | def log_likelihood(self, x, theta, g): 81 | """Computes p(x | theta, G)""" 82 | 83 | raise NotImplementedError 84 | 85 | 86 | def log_marginal_likelihood_given_g(self, g, x): 87 | """Computes log p(x | G) 88 | x : [n_samples, n_vars] 89 | g: graph 90 | """ 91 | 92 | raise NotImplementedError 93 | 94 | 95 | def log_marginal_likelihood(self, x, all_g, z_g=None, numpy = False): 96 | """Computes log p(x) in closed form using conjugacy properties of Dirichlet-Categorical 97 | x : [n_samples, n_vars] 98 | all_g : list of all possible igraph.Graph objects in domain; is exhaustively summed over 99 | """ 100 | 101 | # log p(x, G) 102 | log_prob_obs_g = np.zeros(len(all_g)) 103 | 104 | # normalizing constant for log p(G) using exhaustive normalization 105 | if z_g is None: 106 | z_g = self.g_dist.log_normalization_constant(all_g=all_g) 107 | 108 | # log p(x, G) 109 | for i, g in enumerate(tqdm.tqdm(all_g, desc='p(X) log_marginal_likelihood', disable=not self.verbose)): 110 | if numpy: 111 | g = mat_to_graph(g) 112 | # log p(x, G) = log (p(G)/Z) + log p(x | G) 113 | log_prob_obs_g[i] = self.g_dist.unnormalized_log_prob(g=g) - z_g \ 114 | + self.log_marginal_likelihood_given_g(g=g, x=x) 115 | 116 | # log p(x) = log(sum_G exp(log p(x, G))) 117 | return scipy.special.logsumexp(log_prob_obs_g) 118 | 119 | def log_posterior_graph_given_obs(self, g, x, z_g, _log_marginal_likelihood = None): 120 | """Computes p(G | D) given the previously computed normalization constant 121 | x : [..., n_vars] 122 | i : int (graph) 123 | """ 124 | 125 | log_prob_g = self.g_dist.unnormalized_log_prob(g=g) - z_g 126 | log_marginal_likelihood_given_g = self.log_marginal_likelihood_given_g( 127 | g=g, x=x) 128 | if _log_marginal_likelihood is None: 129 | return log_prob_g + log_marginal_likelihood_given_g 130 | return log_prob_g + log_marginal_likelihood_given_g - _log_marginal_likelihood 131 | 132 | def sample_posterior_weights_given_obs(self, g, x): 133 | """Computes p(theta | G, D) 134 | x : [..., n_vars] 135 | i : int (graph) 136 | """ 137 | 138 | raise NotImplementedError 139 | #### 140 | # Monte Carlo Integration to validate (closed-form computation) of marginal likelihood 141 | #### 142 | 143 | def log_prob_parameters_mc(self, theta, n_samples=3e4): 144 | """Approximates p(theta) using Monte Carlo integration 145 | theta : parameters 146 | """ 147 | 148 | logliks = [] 149 | for tt in range(int(n_samples)): 150 | 151 | # sample from p(G) 152 | g = self.g_dist.sample_G() 153 | 154 | # evaluate log prob p(theta | G) 155 | logliks.append(self.log_prob_parameters(theta=theta, g=g)) 156 | 157 | # print 158 | if not tt % int(n_samples / 1000) and tt > 0: 159 | curr = scipy.special.logsumexp( 160 | np.array(logliks[:tt + 1]) - np.log(tt + 1)) 161 | print(f'iter = {tt}: log p(theta | G) [MC] = {curr}', end='\r') 162 | 163 | log_prob_obs = scipy.special.logsumexp( 164 | np.array(logliks) - np.log(n_samples)) 165 | return log_prob_obs 166 | 167 | def log_marginal_likelihood_given_g_mc(self, x, g, n_samples=3e4): 168 | """Approximates p(x | G) using Monte Carlo integration 169 | x : [n_samples, n_vars] 170 | g : graph 171 | """ 172 | 173 | logliks = [] 174 | for tt in range(int(n_samples)): 175 | 176 | # sample from p(theta | G) 177 | theta = self.sample_parameters(g=g) 178 | 179 | # evaluate likelihood log p(X | theta, G) 180 | logliks.append(self.log_likelihood(x=x, theta=theta, g=g)) 181 | 182 | # print 183 | if not tt % int(n_samples / 1000) and tt > 0: 184 | curr = scipy.special.logsumexp( 185 | np.array(logliks[:tt + 1]) - np.log(tt + 1)) 186 | print(f'iter = {tt}: log p(X | G) [MC] = {curr}', end='\r') 187 | 188 | log_prob_obs = scipy.special.logsumexp( 189 | np.array(logliks) - np.log(n_samples)) 190 | return log_prob_obs 191 | 192 | def log_marginal_likelihood_mc(self, x, n_samples=3e4): 193 | """Approximates normalization constant p(x) using Monte Carlo integration 194 | x : [n_samples, n_vars] 195 | """ 196 | 197 | logliks = [] 198 | for tt in range(int(n_samples)): 199 | 200 | # sample from p(G, theta) = p(G) p(theta | G) 201 | g = self.g_dist.sample_G() 202 | theta = self.sample_parameters(g=g) 203 | 204 | # evaluate likelihood log p(X | theta, G) 205 | logliks.append(self.log_likelihood(x=x, theta=theta, g=g)) 206 | 207 | # print 208 | if not tt % int(n_samples / 1000) and tt > 0: 209 | curr = scipy.special.logsumexp( 210 | (logliks[:tt + 1] - np.log(tt + 1))) 211 | print(f'iter = {tt}: log p(X) [MC] = {curr}', end='\r') 212 | print() 213 | log_prob_obs = scipy.special.logsumexp((logliks - np.log(n_samples))) 214 | return log_prob_obs 215 | -------------------------------------------------------------------------------- /data/distributions.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | import tqdm 4 | import scipy 5 | from scipy.stats import multivariate_normal 6 | import itertools 7 | import networkx as nx 8 | 9 | from utils import expm_np, all_combinations 10 | import torch 11 | 12 | class GraphDistribution: 13 | """ 14 | Class to represent distributions over graphs. 15 | """ 16 | 17 | def __init__(self, n_vars, verbose=False): 18 | self.n_vars = n_vars 19 | self.verbose = verbose 20 | 21 | def sample_G(self, return_mat=False): 22 | """ 23 | Samples graph according to distribution 24 | 25 | n: number of vertices 26 | Returns: 27 | g: igraph.Graph 28 | """ 29 | raise NotImplementedError 30 | 31 | def unnormalized_log_prob(self, g): 32 | """ 33 | g: igraph.Graph object 34 | Returns: 35 | float log p(G) + const, i.e. unnormalized 36 | """ 37 | raise NotImplementedError 38 | 39 | def log_normalization_constant(self, all_g): 40 | """ 41 | Computes normalization constant for log p(G), i.e. `Z = log(sum_G p(g))` 42 | all_g: list of igraph.Graph objects 43 | Returns: 44 | float 45 | """ 46 | log_prob_g_unn = np.zeros(len(all_g)) 47 | for i, g in enumerate(tqdm.tqdm(all_g, desc='p(G) log_normalization_constant', disable=not self.verbose)): 48 | log_prob_g_unn[i] = self.unnormalized_log_prob(g=g) 49 | log_prob_sum_g = scipy.special.logsumexp(log_prob_g_unn) 50 | return log_prob_sum_g 51 | 52 | 53 | 54 | 55 | 56 | 57 | class UniformDAGDistributionRejection(GraphDistribution): 58 | """ 59 | Uniform distribution over DAGs 60 | """ 61 | 62 | def __init__(self, n_vars, verbose=False): 63 | super(UniformDAGDistributionRejection, self).__init__(n_vars=n_vars, verbose=verbose) 64 | self.n_vars = n_vars 65 | self.verbose = verbose 66 | 67 | def sample_G(self, return_mat=False): 68 | """Samples uniformly random DAG""" 69 | while True: 70 | mat = np.random.choice(2, size=self.n_vars * self.n_vars).reshape(self.n_vars, self.n_vars) 71 | if expm_np(mat) == 0: 72 | if return_mat: 73 | return mat 74 | else: 75 | return nx.DiGraph(mat) 76 | 77 | def unnormalized_log_prob(self, g): 78 | """ 79 | p(G) ~ 1 80 | """ 81 | 82 | return 0.0 83 | 84 | class GibbsUniformDAGDistribution(GraphDistribution): 85 | """ 86 | Almost Uniform distribution over DAGs based on the DAG constraint 87 | """ 88 | 89 | def __init__(self, n_vars, gibbs_temp=10., sparsity_factor = 0.0, verbose=False): 90 | super(GibbsUniformDAGDistribution, self).__init__(n_vars=n_vars, verbose=verbose) 91 | self.n_vars = n_vars 92 | self.verbose = verbose 93 | self.gibbs_temp = gibbs_temp 94 | self.sparsity_factor = sparsity_factor 95 | self.z_g = None 96 | 97 | def sample_G(self, return_mat=False): 98 | """Samples almost uniformly random DAG""" 99 | raise NotImplementedError 100 | 101 | def unnormalized_log_prob(self, g): 102 | """ 103 | p(G) ~ 1 104 | """ 105 | mat = g 106 | dagness = expm_np(mat, self.n_vars) 107 | return -self.gibbs_temp*dagness - self.sparsity_factor*np.sum(mat) 108 | 109 | class GibbsDAGDistributionFull(GraphDistribution): 110 | """ 111 | Almost Uniform distribution over DAGs based on the DAG constraint 112 | """ 113 | 114 | def __init__(self, n_vars, gibbs_temp=10., sparsity_factor = 0.0, verbose=False): 115 | super(GibbsDAGDistributionFull, self).__init__(n_vars=n_vars, verbose=verbose) 116 | assert n_vars<=4, 'Cannot use this for higher dimensional variables, Try UniformDAGDistributionRejection instead' 117 | self.n_vars = n_vars 118 | self.verbose = verbose 119 | self.gibbs_temp = gibbs_temp 120 | self.sparsity_factor = sparsity_factor 121 | all_g = all_combinations(n_vars, return_adj = True) #Do not store this in interest of memory 122 | dagness = np.zeros(len(all_g)) 123 | for i, j in enumerate(all_g): 124 | dagness[i] = expm_np(j, self.n_vars) 125 | self.logits = -gibbs_temp*dagness - sparsity_factor*np.sum(all_g, axis = (-1, -2)) 126 | self.z_g = scipy.special.logsumexp(self.logits) 127 | 128 | def sample_G(self, return_mat=False): 129 | """Samples almost uniformly random DAG""" 130 | all_g = all_combinations(self.n_vars, return_adj = True) 131 | mat_id = torch.distributions.Categorical(logits = torch.tensor(self.logits)).sample() 132 | mat = all_g[mat_id] 133 | if return_mat: 134 | return mat 135 | else: 136 | return nx.DiGraph(mat) 137 | 138 | def unnormalized_log_prob(self, g): 139 | """ 140 | p(G) ~ 1 141 | """ 142 | mat = g 143 | dagness = expm_np(mat, self.n_vars) 144 | return -self.gibbs_temp*dagness - self.sparsity_factor*np.sum(mat) -------------------------------------------------------------------------------- /data/erdos_renyi.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from .generator import Generator 4 | import networkx as nx 5 | import graphical_models 6 | 7 | class ER(Generator): 8 | """Generate erdos renyi random graphs using networkx's native random graph builder 9 | Args: 10 | num_nodes - Number of Nodes in the graph 11 | exp_edges - Expected Number of edges in Erdos Renyi graph 12 | noise_type - Type of exogenous variables 13 | noise_sigma - Std of the noise type 14 | num_sampels - number of observations 15 | mu_prior - prior of weights mean(gaussian) 16 | sigma_prior - prior of weights sigma (gaussian) 17 | seed - random seed for data 18 | """ 19 | 20 | def __init__(self, num_nodes, exp_edges = 1, noise_type='isotropic-gaussian', noise_sigma = 1.0, num_samples=1000, mu_prior = 2.0, sigma_prior = 1.0, seed = 10): 21 | self.noise_sigma = noise_sigma 22 | p = float(exp_edges)/ (num_nodes-1) 23 | acyclic = 0 24 | mmec = 0 25 | count = 1 26 | while not (acyclic and mmec): 27 | if exp_edges <= 2: 28 | self.graph = nx.generators.random_graphs.fast_gnp_random_graph(num_nodes, p, directed = True, seed = seed*count) 29 | else: 30 | self.graph = nx.generators.random_graphs.gnp_random_graph(num_nodes, p, directed = True, seed = seed*count) 31 | acyclic = expm_np(nx.to_numpy_matrix(self.graph), num_nodes) == 0 32 | if acyclic: 33 | mmec = num_mec(self.graph) >=2 34 | count += 1 35 | super().__init__(num_nodes, len(self.graph.edges), noise_type, num_samples, mu_prior = mu_prior , sigma_prior = sigma_prior, seed = seed) 36 | self.init_sampler() 37 | self.samples = self.sample(self.num_samples) 38 | 39 | def __getitem__(self, index): 40 | return self.samples[index] 41 | 42 | def matrix_poly_np(matrix, d): 43 | x = np.eye(d) + matrix/d 44 | return np.linalg.matrix_power(x, d) 45 | 46 | def expm_np(A, m): 47 | expm_A = matrix_poly_np(A, m) 48 | h_A = np.trace(expm_A) - m 49 | return h_A 50 | 51 | def num_mec(m): 52 | a = graphical_models.DAG.from_nx(m) 53 | skeleton = a.cpdag() ##Find the skeleton 54 | all_dags = skeleton.all_dags() #Find all DAGs in MEC 55 | return len(all_dags) 56 | -------------------------------------------------------------------------------- /data/generator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import networkx as nx 4 | PRESETS = ['chain', 'collider','fork', 'random'] 5 | NOISE_TYPES = ['gaussian', 'isotropic-gaussian', 'exponential', 'gumbel'] 6 | VARIABLE_TYPES = ['gaussian', 'non-gaussian', 'categorical'] 7 | 8 | class Generator(torch.utils.data.Dataset): 9 | 10 | """ Base class for generating different graphs and performing ancestral sampling""" 11 | 12 | def __init__(self, num_nodes, num_edges, noise_type, num_samples, mu_prior = None, sigma_prior = None, seed = None): 13 | self.num_nodes = num_nodes 14 | self.num_edges = num_edges 15 | assert noise_type in NOISE_TYPES, 'Noise types must correspond to {} but got {}'.format(NOISE_TYPES, noise_type) 16 | self.noise_type = noise_type 17 | self.num_samples = num_samples 18 | self.mu_prior = mu_prior 19 | self.sigma_prior = sigma_prior 20 | if seed is not None: 21 | self.reseed(seed) 22 | if not "self.weighted_adjacency_matrix" in locals(): 23 | self.sample_weights() 24 | self.build_graph() 25 | 26 | def reseed(self, seed = None): 27 | torch.manual_seed(seed) 28 | np.random.seed(seed) 29 | 30 | def __getitem__(self, index): 31 | raise NotImplementedError 32 | 33 | def build_graph(self): 34 | """ Initilises the adjacency matrix and the weighted adjacency matrix""" 35 | 36 | self.adjacency_matrix = nx.to_numpy_matrix(self.graph) 37 | self.weighted_adjacency_matrix = self.adjacency_matrix.copy() 38 | edge_pointer = 0 39 | for i in nx.topological_sort(self.graph): 40 | parents = list(self.graph.predecessors(i)) 41 | if len(parents) == 0: 42 | continue 43 | else: 44 | for j in parents: 45 | self.weighted_adjacency_matrix[j, i] = self.weights[edge_pointer] 46 | edge_pointer += 1 47 | 48 | def init_sampler(self): 49 | if self.noise_type.endswith('gaussian'): 50 | #Identifiable 51 | if self.noise_type == 'isotropic-gaussian': 52 | noise_std= [self.noise_sigma]*self.num_nodes 53 | elif self.noise_type == 'gaussian': 54 | noise_std = np.linspace(0.1, 3., self.num_nodes) 55 | for i in range(self.num_nodes): 56 | self.graph.nodes[i]['sampler'] = torch.distributions.normal.Normal(0., noise_std[i]) 57 | 58 | elif self.noise_type == 'exponential': 59 | noise_std= [self.noise_sigma]*self.num_nodes 60 | for i in range(self.num_nodes): 61 | self.graph.nodes[i]['sampler'] = torch.distributions.exponential.Exponential(noise_std[i]) 62 | 63 | def sample_weights(self): 64 | """Sample the edge weights""" 65 | 66 | if self.mu_prior is not None: 67 | 68 | self.weights = torch.distributions.normal.Normal(self.mu_prior, self.sigma_prior).sample([self.num_edges]) 69 | else: 70 | dist = torch.distributions.uniform.Uniform(-5, 5) 71 | self.weights = torch.zeros(self.num_edges) 72 | for k in range(self.num_edges): 73 | sample = 0. 74 | while sample > -0.5 and sample < 0.5: 75 | sample = dist.sample() 76 | self.weights[k] = sample 77 | 78 | print(self.weights) 79 | 80 | def sample(self, num_samples, graph = None, node = None, value = None): 81 | """Sample observations given a graph 82 | num_samples: Scalar 83 | graph: networkx DiGraph 84 | node: If intervention is performed, specify which node 85 | value: value set to node after intervention 86 | 87 | Outputs: Observations [num_samples x num_nodes] 88 | """ 89 | 90 | if graph is None: 91 | graph = self.graph 92 | 93 | samples = torch.zeros(num_samples, self.num_nodes) 94 | edge_pointer = 0 95 | for i in nx.topological_sort(graph): 96 | if i == node: 97 | noise = torch.tensor([value]*num_samples) 98 | else: 99 | noise = self.graph.nodes[i]['sampler'].sample([num_samples]) 100 | parents = list(self.graph.predecessors(i)) 101 | if len(parents) == 0: 102 | samples[:,i] = noise 103 | else: 104 | curr = 0. 105 | for j in parents: 106 | curr += self.weighted_adjacency_matrix[j, i]*samples[:,j] 107 | edge_pointer += 1 108 | curr += noise 109 | samples[:, i] = curr 110 | return samples 111 | 112 | def intervene(self, num_samples, node = None, value = None): 113 | 114 | """Perform intervention to obtain a mutilated graph""" 115 | 116 | if node is None: 117 | node = torch.randint(self.num_nodes, (1,)) 118 | if value is None: 119 | #value = torch.distributions.uniform.Uniform(-5,5).sample() 120 | value = torch.tensor(2.0) 121 | 122 | mutated_graph = self.adjacency_matrix.copy() 123 | mutated_graph[:, node] = 0. #Cut off all the parents 124 | 125 | return self.sample(num_samples, nx.DiGraph(mutated_graph), node.item(), value.item()), node, value 126 | 127 | def __len__(self): 128 | return self.num_samples 129 | 130 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import os.path as osp 3 | import numpy as np 4 | import torch 5 | import argparse 6 | from datetime import datetime 7 | import pickle as pkl 8 | import shutil 9 | import networkx as nx 10 | import time 11 | 12 | import utils 13 | import matplotlib.pyplot as plt 14 | from models import vcn, autoreg_base, factorised_base, bge_model 15 | from data import erdos_renyi, distributions 16 | import graphical_models 17 | from sklearn import metrics 18 | 19 | def parse_args(): 20 | parser = argparse.ArgumentParser(description='Variational Causal Networks') 21 | parser.add_argument('--save_path', type=str, default = 'results_anneal/', 22 | help='Path to save result files') 23 | parser.add_argument('--no_autoreg_base', action='store_true', default=False, 24 | help='Use factorisable disrtibution') 25 | parser.add_argument('--seed', type=int, default=10, 26 | help='random seed (default: 10)') 27 | parser.add_argument('--data_seed', type=int, default=20, 28 | help='random seed for generating data(default: 20)') 29 | parser.add_argument('--batch_size', type=int, default=1000, 30 | help='Batch Size for training') 31 | parser.add_argument('--lr', type=float, default=1e-2, 32 | help='Learning rate') 33 | parser.add_argument('--gibbs_temp', type=float, default=1000.0, 34 | help='Temperature for the Graph Gibbs Distribution') 35 | parser.add_argument('--sparsity_factor', type=float, default=0.001, 36 | help='Hyperparameter for sparsity regularizer') 37 | parser.add_argument('--epochs', type=int, default=30000, 38 | help='Number of iterations to train') 39 | parser.add_argument('--num_nodes', type=int, default=2, 40 | help='Number of nodes in the causal model') 41 | parser.add_argument('--num_samples', type=int, default=100, 42 | help='Total number of samples in the synthetic data') 43 | parser.add_argument('--noise_type', type=str, default='isotropic-gaussian', 44 | help='Type of noise of causal model') 45 | parser.add_argument('--noise_sigma', type=float, default=1.0, 46 | help='Std of Noise Variables') 47 | parser.add_argument('--theta_mu', type=float, default=2.0, 48 | help='Mean of Parameter Variables') 49 | parser.add_argument('--theta_sigma', type=float, default=1.0, 50 | help='Std of Parameter Variables') 51 | parser.add_argument('--data_type', type=str, default='er', 52 | help='Type of data') 53 | parser.add_argument('--exp_edges', type=float, default=1.0, 54 | help='Expected number of edges in the random graph') 55 | parser.add_argument('--eval_only', action='store_true', default=False, 56 | help='Perform Just Evaluation') 57 | parser.add_argument('--anneal', action='store_true', default=False, 58 | help='Perform gibbs temp annealing') 59 | 60 | args = parser.parse_args() 61 | args.data_size = args.num_nodes * (args.num_nodes-1) 62 | root = args.save_path 63 | list_dir = os.listdir(args.save_path) 64 | args.save_path = os.path.join(args.save_path, args.data_type + '_' + str(int(args.exp_edges)), str(args.num_nodes) + '_' + str(args.seed) + '_' + str(args.data_seed) + '_' + str(args.num_samples) + '_' + \ 65 | str(args.sparsity_factor) +'_' + str(args.gibbs_temp) + '_' + str(args.no_autoreg_base)) 66 | if not os.path.exists(args.save_path): 67 | os.makedirs(args.save_path) 68 | if args.num_nodes == 2: 69 | args.exp_edges = 0.8 70 | 71 | args.gibbs_temp_init = 10. 72 | args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 73 | 74 | torch.manual_seed(args.seed) 75 | np.random.seed(args.seed) 76 | 77 | return args 78 | 79 | def auroc(model, ground_truth, num_samples = 1000): 80 | """Compute the AUROC of the model as given in 81 | https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0009202""" 82 | 83 | gt = utils.adj_mat_to_vec(torch.from_numpy(ground_truth).unsqueeze(0), model.num_nodes).numpy().squeeze() 84 | num_nodes = model.num_nodes 85 | bs = 10000 86 | i = 0 87 | samples = [] 88 | with torch.no_grad(): 89 | while i