├── src ├── __init__.py ├── utils │ ├── metrics.py │ └── plotting.py ├── environments │ ├── experiment.py │ ├── generic_environments.py │ └── environment.py ├── experimental_design │ ├── optimization.py │ ├── exp_designer_base.py │ ├── exp_designer_abci_categorical_gp.py │ └── exp_designer_abci_dibs_gp.py ├── scripts │ └── run_single_env.py ├── abci_base.py ├── abci_categorical_gp.py ├── models │ ├── mechanisms.py │ ├── gp_model.py │ └── graph_models.py └── abci_dibs_gp.py ├── .gitignore ├── environment.yaml ├── LICENCE ├── README.md └── notebooks ├── generate_benchmark_envs.ipynb ├── example_abci_dibs_gp.ipynb ├── plot_benchmark_results.ipynb └── example_abci_categorical_gp.ipynb /src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | **idea/ 2 | **__pycache__/ 3 | /results 4 | /data 5 | /figures 6 | /notebooks/.ipynb_checkpoints/ 7 | -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: abci 2 | channels: 3 | - conda-forge 4 | - gpytorch 5 | - pytorch 6 | - defaults 7 | dependencies: 8 | - botorch=0.5.1 9 | - cudatoolkit=10.2.89 10 | - jupyter=1.0.0 11 | - matplotlib=3.3.4 12 | - networkx=2.5 13 | - numpy=1.19.2 14 | - pip=21.2.2 15 | - python=3.7.10 16 | - pytorch=1.9.1 17 | - scikit-learn=1.0.2 18 | - scipy=1.6.1 19 | - torchaudio=0.9.1 20 | - torchvision=0.10.1 21 | - pip: 22 | - cdt==0.5.23 23 | - gpytorch==1.5.1 24 | - pyro-ppl==1.7.0 25 | -------------------------------------------------------------------------------- /src/utils/metrics.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | import torch 3 | from cdt.metrics import SHD 4 | from sklearn.metrics import precision_recall_curve, roc_curve, auc 5 | 6 | 7 | def auroc(posterior_edge_probs: torch.Tensor, true_adj_mat: torch.Tensor): 8 | assert posterior_edge_probs.squeeze().shape == true_adj_mat.squeeze().shape 9 | edge_probs = posterior_edge_probs.detach().view(-1).numpy() 10 | targets = true_adj_mat.int().view(-1).numpy() 11 | 12 | fpr, tpr, _ = roc_curve(targets, edge_probs) 13 | return auc(fpr, tpr) 14 | 15 | 16 | def auprc(posterior_edge_probs: torch.Tensor, true_adj_mat: torch.Tensor): 17 | assert posterior_edge_probs.squeeze().shape == true_adj_mat.squeeze().shape 18 | edge_probs = posterior_edge_probs.detach().view(-1).numpy() 19 | targets = true_adj_mat.int().view(-1).numpy() 20 | 21 | precision, recall, _ = precision_recall_curve(targets, edge_probs) 22 | return auc(recall, precision) 23 | 24 | 25 | def shd(target_graph: nx.DiGraph, predicted_graph: nx.DiGraph): 26 | return torch.tensor(SHD(target_graph, predicted_graph)).float() 27 | -------------------------------------------------------------------------------- /LICENCE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Christian Toth 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the 
Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /src/environments/experiment.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | 3 | import networkx as nx 4 | 5 | from src.models.graph_models import get_parents 6 | from src.models.mechanisms import * 7 | 8 | 9 | class Experiment: 10 | def __init__(self, interventions: dict, data: Dict[str, torch.tensor]): 11 | num_batches, batch_size = list(data.values())[0].shape[0:2] 12 | assert all([node_data.shape == (num_batches, batch_size, 1) for node_data in data.values()]) 13 | self.interventions = interventions 14 | self.num_batches = num_batches 15 | self.batch_size = batch_size 16 | self.data = data 17 | 18 | 19 | def gather_data(experiments: List[Experiment], node: str, graph: nx.DiGraph = None, parents: List[str] = None, 20 | mode: str = 'joint'): 21 | assert graph is not None or parents is not None 22 | assert mode in {'joint', 'independent_batches', 'independent_samples'}, print('Invalid gather mode: ', mode) 23 | 24 | # gather targets 25 | if mode == 'independent_batches': 26 | batch_size = experiments[0].batch_size 27 | assert all([experiment.batch_size == batch_size for experiment in experiments]), print('Batch size mismatch!') 28 | targets = [exp.data[node].squeeze(-1) for exp in experiments if node not in exp.interventions] 29 | elif mode == 'independent_samples': 30 | targets = [exp.data[node].view(-1, 1) for exp in experiments if node not in exp.interventions] 31 | else: # mode == 'joint' 32 | targets = [exp.data[node].reshape(-1) for exp in experiments if node not in exp.interventions] 33 | 34 | # check if we have data for this node 35 | if not targets: 36 | return None, None 37 | targets = torch.cat(targets, dim=0) 38 | 39 | # check if we have parents 40 | parents = sorted(parents) if graph is None else get_parents(node, graph) 41 | if not parents: 42 | return None, targets 43 | 44 | # gather parent data 45 | num_parents = len(parents) 46 | if mode == 'independent_batches': 47 | inputs = [torch.cat([experiment.data[parent] for parent in parents], dim=-1) for experiment in experiments if 48 | node not in experiment.interventions] 49 | elif mode == 'independent_samples': 50 | inputs = [torch.cat([experiment.data[parent] for parent in parents], dim=-1).view(-1, 1, num_parents) for 51 | experiment in experiments if node not in experiment.interventions] 52 | else: # mode == 'joint' 53 | inputs = [torch.cat([experiment.data[parent] for parent in parents], dim=-1).view(-1, num_parents) for 54 | experiment in experiments if node not in experiment.interventions] 55 | 56 | inputs = torch.cat(inputs, dim=0) 57 | 58 | return inputs, targets 59 | 
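As a usage illustration (not part of the original repository), the following minimal sketch shows how `Experiment` and `gather_data` fit together; the node names, shapes, and values are assumptions made up for this example:

```
import torch

from src.environments.experiment import Experiment, gather_data

# two observational batches of four samples each for a hypothetical two-node system
obs_exp = Experiment(interventions={},
                     data={'X0': torch.randn(2, 4, 1), 'X1': torch.randn(2, 4, 1)})

# one interventional batch where X0 is clamped to 1.5, so X0 contributes no training data
int_exp = Experiment(interventions={'X0': torch.tensor(1.5)},
                     data={'X0': torch.full((1, 3, 1), 1.5), 'X1': torch.randn(1, 3, 1)})

# gather training inputs/targets for the mechanism X0 -> X1 across all experiments
inputs, targets = gather_data([obs_exp, int_exp], 'X1', parents=['X0'], mode='joint')
print(inputs.shape, targets.shape)  # torch.Size([11, 1]) torch.Size([11])
```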
-------------------------------------------------------------------------------- /src/experimental_design/optimization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from botorch.acquisition import UpperConfidenceBound 3 | from botorch.models import SingleTaskGP 4 | from botorch.optim.fit import fit_gpytorch_torch 5 | from gpytorch.mlls import ExactMarginalLogLikelihood 6 | 7 | 8 | def gp_ucb(utility: callable, bounds: torch.Tensor, num_total_candidates=8, num_initial_candidates=1): 9 | assert bounds.shape == (2, 1), print(bounds.shape) 10 | 11 | # generate initial candidates 12 | candidate_list = [(torch.rand(1) * (bounds[1] - bounds[0]) + bounds[0]).unsqueeze(1) for _ in 13 | range(num_initial_candidates)] 14 | utility_list = [utility(candidate.squeeze().item()).squeeze() for candidate in candidate_list] 15 | 16 | # search for better candidates 17 | try: 18 | xrange = torch.linspace(bounds[0].item(), bounds[1].item(), 50).view(-1, 1, 1) 19 | for i in range(num_total_candidates - num_initial_candidates): 20 | # fit acquisition function GP 21 | gp = SingleTaskGP(torch.cat(candidate_list).view(-1, 1), torch.stack(utility_list).view(-1, 1)) 22 | mll = ExactMarginalLogLikelihood(gp.likelihood, gp) 23 | mll.train() 24 | fit_gpytorch_torch(mll, options={'disp': False, 'maxiter': 50}) 25 | # fit_gpytorch_scipy(mll) 26 | 27 | # get new candidate 28 | ucb = UpperConfidenceBound(gp, beta=1.)(xrange) 29 | shuffle_idc = torch.randperm(ucb.numel()) 30 | argmax = ucb[shuffle_idc].argmax() 31 | candidate = xrange[shuffle_idc[argmax]] 32 | # candidate, acq_value = optimize_acqf(ucb, bounds=bounds, q=1, num_restarts=1, raw_samples=100) 33 | 34 | # record candidate and objective 35 | candidate_list.append(candidate) 36 | utility_list.append(utility(candidate.squeeze().item()).squeeze()) 37 | except Exception as e: 38 | print('Exception occured when running GP UCB:') 39 | print(e) 40 | print('Continuing with the best candidate from ', candidate_list) 41 | print('with utilities ', utility_list) 42 | 43 | # return best candidate and objective values 44 | best_candidate = candidate_list[torch.stack(utility_list).argmax().item()].squeeze() 45 | best_utility = torch.stack(utility_list).max().item() 46 | return best_candidate, best_utility 47 | 48 | 49 | def random_search(utility: callable, bounds: torch.Tensor, num_candidates=10): 50 | assert bounds.shape == (2, 1), print(bounds.shape) 51 | 52 | # generate initial candidates 53 | candidate_list = [torch.rand(1) * (bounds[1] - bounds[0]) + bounds[0] for _ in range(num_candidates)] 54 | utility_list = [utility(candidate.item()).squeeze() for candidate in candidate_list] 55 | 56 | # return best candidate and objective values 57 | best_candidate = candidate_list[torch.stack(utility_list).argmax().item()] 58 | best_utility = torch.stack(utility_list).max().item() 59 | return best_candidate, best_utility 60 | 61 | 62 | def grid_search(utility: callable, bounds: torch.Tensor, num_candidates=10): 63 | assert bounds.shape == (2, 1), print(bounds.shape) 64 | 65 | # generate initial candidates 66 | candidate_list = torch.linspace(bounds[0].item(), bounds[1].item(), num_candidates) 67 | utility_list = [utility(candidate.item()).squeeze() for candidate in candidate_list] 68 | 69 | # return best candidate and objective values 70 | best_candidate = candidate_list[torch.stack(utility_list).argmax().item()] 71 | best_utility = torch.stack(utility_list).max().item() 72 | return best_candidate, best_utility 73 | 
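To make the interfaces above concrete, here is a small usage sketch (not part of the original repository) with a toy utility function; in ABCI, the utility would be one of the information-gain estimates provided by the experiment designers:

```
import torch

from src.experimental_design.optimization import grid_search, gp_ucb

# toy utility with its maximum at x = 1.5, returned as a tensor as the optimizers expect
def utility(x: float) -> torch.Tensor:
    return torch.tensor(-(x - 1.5) ** 2)

bounds = torch.tensor([[0.], [3.]])  # shape (2, 1): lower and upper bound of the intervention value

best_x, best_u = grid_search(utility, bounds, num_candidates=21)
print(f'grid search: x* = {best_x.item():.2f}, utility = {best_u:.3f}')

best_x, best_u = gp_ucb(utility, bounds, num_total_candidates=8, num_initial_candidates=2)
print(f'gp-ucb: x* = {best_x.item():.2f}, utility = {best_u:.3f}')
```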
--------------------------------------------------------------------------------
/src/experimental_design/exp_designer_base.py:
--------------------------------------------------------------------------------
1 | import time
2 | from collections import namedtuple
3 | from typing import Dict, Tuple, Set
4 | 
5 | import torch.distributed.rpc as rpc
6 | import torch.optim
7 | 
8 | from src.experimental_design.optimization import random_search, grid_search, gp_ucb
9 | 
10 | Design = namedtuple('Design', ['interventions', 'info_gain'])
11 | 
12 | 
13 | class ExpDesignerBase:
14 |     def __init__(self, intervention_bounds: Dict[str, Tuple[float, float]], opt_strategy: str = 'gp-ucb',
15 |                  distributed=False):
16 |         self.worker_id = rpc.get_worker_info().id if distributed else 0
17 |         self.intervention_bounds = intervention_bounds
18 |         if opt_strategy not in {'gp-ucb', 'random', 'grid'}:
19 |             print('Invalid optimization strategy ' + opt_strategy + '. Doing Bayesian optimization instead.')
20 |             opt_strategy = 'gp-ucb'
21 |         self.opt_strategy = opt_strategy
22 |         self.utility = None
23 | 
24 |     def init_design_process(self, args: dict):
25 |         raise NotImplementedError
26 | 
27 |     def run_distributed(self, experimenter_rref, args: dict):
28 |         experimenter_rref.rpc_sync().report_status(self.worker_id, 'Initializing design process...')
29 |         self.init_design_process(args)
30 |         experimenter_rref.rpc_sync().report_status(self.worker_id, 'Finished initializing design process...')
31 | 
32 |         target_node = experimenter_rref.rpc_sync().get_target(self.worker_id)
33 |         while target_node:
34 |             design = self.design_experiment(target_node)
35 |             experimenter_rref.rpc_sync().report_design(self.worker_id, target_node, design)
36 |             target_node = experimenter_rref.rpc_sync().get_target(self.worker_id)
37 | 
38 |     def design_experiment(self, target_node: str):
39 | 
40 |         # if no target is given, report the info gain of an observational sample
41 |         if target_node == 'OBSERVATIONAL':
42 |             try:
43 |                 score = self.utility({})
44 |             except Exception as e:
45 |                 print(f'Exception occurred in ExpDesignerBase.design_experiment() when computing the score for the '
46 |                       f'observational target:')
47 |                 print(e)
48 |                 score = torch.tensor(0.)
49 |             return Design({}, score)
50 | 
51 |         # otherwise, design an experiment for the target node
52 |         bounds = torch.Tensor(self.intervention_bounds[target_node]).view(2, 1)
53 |         if self.opt_strategy == 'random':
54 |             target_value, score = random_search(lambda x: self.utility({target_node: x}), bounds)
55 |         elif self.opt_strategy == 'grid':
56 |             target_value, score = grid_search(lambda x: self.utility({target_node: x}), bounds)
57 |         else:
58 |             target_value, score = gp_ucb(lambda x: self.utility({target_node: x}), bounds)
59 | 
60 |         return Design({target_node: target_value}, score)
61 | 
62 |     def get_best_experiment(self, target_nodes: Set[str]):
63 |         best_intervention = {}
64 |         best_score = self.utility({})
65 |         print(f'Expected information gain for observational sample is {best_score}.')
66 |         for target_node in target_nodes:
67 |             print(f'Start experiment design for node {target_node} at {time.strftime("%H:%M:%S")}')
68 |             design = self.design_experiment(target_node)
69 |             print(f'Expected information gain for {design.interventions} is {design.info_gain}.')
70 |             if design.info_gain > best_score:
71 |                 best_score = design.info_gain
72 |                 best_intervention = design.interventions
73 | 
74 |         return best_intervention, best_score
75 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Active Bayesian Causal Inference
2 | 
3 | This repository contains the implementation of the Active Bayesian Causal Inference (ABCI) framework for non-linear additive Gaussian noise models, as described in our [NeurIPS'22 ABCI paper](https://arxiv.org/abs/2206.02063). In summary, it provides functionality for generating ground-truth environments, running ABCI itself, and generating plots as in the paper. We also provide example notebooks that illustrate the basic usage of the code base and get you started quickly. Feel free to reach out if you have questions about the paper or code!
4 | 
5 | > [!NOTE]
6 | > In case you are interested in *running our framework on static datasets*, we provide an improved and extended implementation at [https://github.com/chritoth/bci-arco-gp](https://github.com/chritoth/bci-arco-gp). The main differences are parameter handling via a global config file, more efficient training and inference, additional functionality for storing and loading datasets, and a novel graph inference model based on orders as an alternative to DiBS. Check it out!
7 | 
8 | 
9 | ## Getting Started
10 | 
11 | ### Python Environment Setup
12 | 
13 | These instructions should help you set up a suitable Python environment. We recommend using [Miniconda](https://docs.conda.io/en/latest/miniconda.html) to easily recreate the Python environment. You can install the latest version like so:
14 | 
15 | ```
16 | wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh
17 | chmod +x Miniconda3-latest-Linux-x86_64.sh
18 | ./Miniconda3-latest-Linux-x86_64.sh
19 | ```
20 | Once you have Miniconda installed, create a virtual environment from the included environment description in `environment.yaml` like so:
21 | 
22 | ```
23 | conda env create -f environment.yaml
24 | ```
25 | Next, activate the conda environment via
26 | ```
27 | conda activate abci
28 | ```
29 | and finally set your Python path to the project root:
30 | ```
31 | export PYTHONPATH="${PYTHONPATH}:/path/to/abci"
32 | ```
33 | 
34 | ### Running the Code
35 | 
36 | You can get started by playing around with the example notebooks `example_abci_categorical_gp.ipynb` or `example_abci_dibs_gp.ipynb` to get the gist of how to use the code base. For larger examples, you first need to generate benchmark environments (e.g. with `generate_benchmark_envs.ipynb`), then run ABCI by starting one of the scripts in `./src/scripts/`, and finally plot the results with `plot_benchmark_results.ipynb`. If you prefer to run ABCI from the command line, you can use the script `./src/scripts/run_single_env.py` (see `python run_single_env.py -h` for usage instructions).
37 | 
38 | You can implement your own ground truth models by building upon the `Environment` base class in `./src/environments/environment.py` (see the sketch below). In principle, it is also possible to run the Bayesian causal inference part of this implementation on a static dataset, without the active learning part.
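For illustration, a custom environment could look like the following minimal sketch, which mirrors the pattern of the classes in `./src/environments/generic_environments.py` (the class name and graph structure here are made up for the example):

```
from typing import List, Set

import networkx as nx

from src.environments.environment import Environment, InterventionalDistributionsQuery


class TriangleGraph(Environment):
    # example 3-node environment with edges X0 -> X1, X0 -> X2, X1 -> X2
    def __init__(self, num_nodes: int = None,
                 mechanism_model='gp-model',
                 frac_non_intervenable_nodes: float = None,
                 non_intervenable_nodes: Set = None,
                 num_test_samples_per_intervention: int = 50,
                 num_test_queries: int = 0,
                 interventional_queries: List[InterventionalDistributionsQuery] = None):
        super().__init__(3, mechanism_model, frac_non_intervenable_nodes, non_intervenable_nodes,
                         num_test_samples_per_intervention, num_test_queries, interventional_queries)

    def construct_graph(self, num_nodes: int) -> nx.DiGraph:
        graph = nx.DiGraph()
        graph.add_edges_from([('X0', 'X1'), ('X0', 'X2'), ('X1', 'X2')])
        return graph
```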
39 | 
40 | ### Project Layout
41 | 
42 | The following gives a brief overview of the organization and contents of this project. Note: it should generally be clear where to change the default paths in the scripts and notebooks, but if you don't want to spend time on that, just use the default project structure.
43 | 
44 | ```
45 | │
46 | ├── README.md                       <- This readme file.
47 | │
48 | ├── environment.yaml                <- The Python environment spec for running the code in this project.
49 | │
50 | ├── data                            <- Directory for generated ground truth models.
51 | │
52 | ├── figures                         <- Output directory for generated figures.
53 | │
54 | ├── notebooks                       <- Jupyter notebooks for running interactive experiments and analysis.
55 | │
56 | ├── results                         <- Simulation results.
57 | │
58 | ├── src                             <- Contains the Python source code of this project.
59 | │   ├── __init__.py                 <- Makes src a Python module.
60 | │   ├── abci_base.py                <- ABCI base class.
61 | │   ├── abci_categorical_gp.py      <- ABCI with categorical distribution over graphs & GP models.
62 | │   ├── abci_dibs_gp.py             <- ABCI with DiBS approximate graph inference & GP models.
63 | │   ├── environments                <- Everything pertaining to ground truth environments.
64 | │   ├── experimental_design         <- Everything pertaining to experimental design (utility functions, optimization, ...).
65 | │   ├── models                      <- Everything pertaining to models (DiBS, GPs, ...).
66 | │   ├── scripts                     <- Scripts for running experiments.
67 | │   ├── utils                       <- Utils for plotting, metrics, ...
68 | │ 69 | ``` 70 | -------------------------------------------------------------------------------- /src/scripts/run_single_env.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import socketserver 5 | import string 6 | import time 7 | 8 | import torch.distributed.rpc as rpc 9 | import torch.multiprocessing as mp 10 | 11 | from src.abci_categorical_gp import ABCICategoricalGP 12 | from src.abci_dibs_gp import ABCIDiBSGP 13 | from src.environments.generic_environments import * 14 | 15 | MODELS = {'abci-dibs-gp', 'abci-dibs-gp-linear', 'abci-categorical-gp', 'abci-categorical-gp-linear'} 16 | 17 | 18 | def spawn_abci_model(abci_model, env, policy, num_workers): 19 | assert abci_model in MODELS, print(f'Invalid ABCI model {abci_model}') 20 | 21 | if abci_model == 'abci-dibs-gp': 22 | return ABCIDiBSGP(env, policy, num_workers=num_workers) 23 | 24 | if abci_model == 'abci-dibs-gp-linear': 25 | return ABCIDiBSGP(env, policy, num_workers=num_workers, linear=True) 26 | 27 | if abci_model == 'abci-categorical-gp': 28 | return ABCICategoricalGP(env, policy, num_workers=num_workers) 29 | 30 | if abci_model == 'abci-categorical-gp-linear': 31 | return ABCICategoricalGP(env, policy, num_workers=num_workers, linear=True) 32 | 33 | 34 | def get_free_port(): 35 | with socketserver.TCPServer(("localhost", 0), None) as s: 36 | free_port = s.server_address[1] 37 | return free_port 38 | 39 | 40 | def generate_job_id(): 41 | return ''.join([random.choice(string.ascii_letters + string.digits) for _ in range(6)]) 42 | 43 | 44 | def run_worker(rank: int, env: Environment, master_port: str, output_dir: str, policy: str, num_experiments: int, 45 | batch_size: int, num_initial_obs_samples: int, num_workers: int, job_id: str, abci_model: str): 46 | os.environ['MASTER_ADDR'] = 'localhost' 47 | os.environ['MASTER_PORT'] = master_port 48 | torch.set_num_interop_threads(1) 49 | torch.set_num_threads(1) 50 | 51 | if rank == 0: 52 | rpc.init_rpc('Experimenter', 53 | rank=rank, 54 | world_size=num_workers, 55 | rpc_backend_options=rpc.TensorPipeRpcBackendOptions(num_worker_threads=num_workers, rpc_timeout=0)) 56 | try: 57 | abci = spawn_abci_model(abci_model, env, policy, num_workers) 58 | abci.run(num_experiments, batch_size, num_initial_obs_samples=num_initial_obs_samples, outdir=output_dir, 59 | job_id=job_id) 60 | except Exception as e: 61 | print(e) 62 | else: 63 | outpath = f'{output_dir}{abci_model}-{policy}-{env.name}-{job_id}-exp-{num_experiments}.pth' 64 | abci.save(outpath) 65 | else: 66 | rpc.init_rpc(f'ExperimentDesigner{rank}', 67 | rank=rank, 68 | world_size=num_workers, 69 | rpc_backend_options=rpc.TensorPipeRpcBackendOptions(num_worker_threads=num_workers, rpc_timeout=0)) 70 | 71 | rpc.shutdown() 72 | 73 | 74 | def run_single_env(env_file: str, output_dir: str, policy: str, num_experiments: int, batch_size: int, 75 | num_initial_obs_samples: int, num_workers: int, abci_model: str): 76 | print(torch.__config__.parallel_info()) 77 | mp.set_sharing_strategy('file_system') 78 | 79 | # load env 80 | env = Environment.load(env_file) 81 | 82 | assert abci_model in MODELS, print(f'Invalid ABCI model {abci_model}!') 83 | job_id = generate_job_id() 84 | 85 | # run benchmark env 86 | print('\n-------------------------------------------------------------------------------------------') 87 | print(f'--------- Running {abci_model.upper()} ({policy}) on Environment {env.name} w/ job ID {job_id} ----------') 88 | print(f'--------- 
Number of Experiments: {num_experiments} ----------') 89 | print(f'--------- Batch Size: {batch_size} ----------') 90 | print(f'--------- Number of Initial Observational Samples: {num_initial_obs_samples} ----------') 91 | print(f'--------- Starting time: {time.strftime("%H:%M:%S")} ----------') 92 | print('-------------------------------------------------------------------------------------------\n') 93 | 94 | if num_workers > 1: 95 | master_port = str(get_free_port()) 96 | print(f'Starting {num_workers} workers on port ' + master_port) 97 | try: 98 | mp.spawn( 99 | run_worker, 100 | args=(env, master_port, output_dir, policy, num_experiments, batch_size, num_initial_obs_samples, 101 | num_workers, job_id, abci_model), 102 | nprocs=num_workers, 103 | join=True 104 | ) 105 | except Exception as e: 106 | print(e) 107 | else: 108 | try: 109 | abci = spawn_abci_model(abci_model, env, policy, num_workers=1) 110 | abci.run(num_experiments, batch_size, num_initial_obs_samples=num_initial_obs_samples, outdir=output_dir, 111 | job_id=job_id) 112 | except Exception as e: 113 | print(e) 114 | else: 115 | outpath = f'{output_dir}{abci_model}-{policy}-{env.name}-{job_id}-exp-{num_experiments}.pth' 116 | abci.save(outpath) 117 | 118 | 119 | # parse arguments when run from shell 120 | if __name__ == "__main__": 121 | parser = argparse.ArgumentParser('ABCI usage on single environment:') 122 | parser.add_argument('env_file', type=str, help=f'Path to environment file.') 123 | parser.add_argument('output_dir', type=str, help='Output directory.') 124 | parser.add_argument('policy', type=str, help=f'ABCI policy.') 125 | parser.add_argument('--num_experiments', type=int, default=50, help='Number of experiments per environment.') 126 | parser.add_argument('--batch_size', type=int, default=1, help='Number of samples drawn in each experiment.') 127 | parser.add_argument('--num_initial_obs_samples', type=int, default=0, 128 | help='Number of initial observational samples drawn before policy becomes active.') 129 | parser.add_argument('--num_workers', type=int, default=1, help='Number of worker threads per environment.') 130 | parser.add_argument('--model', default='abci-dibs-gp', type=str, choices=MODELS, help=f'Available models: {MODELS}') 131 | 132 | args = vars(parser.parse_args()) 133 | run_single_env(args['env_file'], args['output_dir'], args['policy'], args['num_experiments'], args['batch_size'], 134 | args['num_initial_obs_samples'], args['num_workers'], args['model']) 135 | -------------------------------------------------------------------------------- /src/abci_base.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import torch.distributed.rpc as rpc 4 | from torch.distributed.rpc import RRef, remote 5 | 6 | from src.environments.environment import * 7 | from src.experimental_design.exp_designer_base import Design 8 | 9 | 10 | class ABCIBase: 11 | def __init__(self, env: Environment, policy='observational', num_workers: int = 1): 12 | self.env = env 13 | self.policy = policy 14 | 15 | # init distributed experiment design 16 | self.num_workers = num_workers 17 | self.designed_experiments = {} 18 | self.open_targets = set() 19 | if num_workers > 1: 20 | self.worker_id = rpc.get_worker_info().id 21 | if self.worker_id == 0: 22 | self.experimenter_rref = RRef(self) 23 | self.designer_rrefs = [] 24 | for worker_id in range(1, num_workers): 25 | info = rpc.get_worker_info(f'ExperimentDesigner{worker_id}') 26 | self.designer_rrefs.append(remote(info, 
self.experiment_designer_factory)) 27 | 28 | # init lists 29 | self.experiments = [] 30 | self.loss_list = [] 31 | self.eshd_list = [] 32 | self.graph_ll_list = [] 33 | self.info_gain_list = [] 34 | self.graph_entropy_list = [] 35 | self.auroc_list = [] 36 | self.auprc_list = [] 37 | self.observational_test_ll_list = [] 38 | self.interventional_test_ll_lists = {node: [] for node in self.env.node_labels} 39 | self.observational_kld_list = [] 40 | self.interventional_kld_lists = {node: [] for node in self.env.node_labels} 41 | self.query_kld_list = [] 42 | 43 | def experiment_designer_factory(self): 44 | raise NotImplementedError 45 | 46 | def run(self, num_experiments=10, batch_size=1, update_interval=5, log_interval=5, num_initial_obs_samples=1): 47 | raise NotImplementedError 48 | 49 | def get_random_intervention(self, fixed_value: float = None): 50 | target_node = random.choice(list(self.env.intervenable_nodes) + ['OBSERVATIONAL']) 51 | if target_node == 'OBSERVATIONAL': 52 | return {} 53 | if fixed_value is None: 54 | bounds = self.env.intervention_bounds[target_node] 55 | target_value = torch.rand(1) * (bounds[1] - bounds[0]) + bounds[0] 56 | else: 57 | target_value = torch.tensor(fixed_value) 58 | return {target_node: target_value} 59 | 60 | def report_design(self, worker_id: int, design_key: str, design: Design): 61 | print(f'Worker {worker_id} designed {design.interventions} with info gain {design.info_gain}', flush=True) 62 | self.designed_experiments[design_key] = design 63 | 64 | def report_status(self, worker_id: int, message: str): 65 | print(f'Worker {worker_id} reports at {time.strftime("%H:%M:%S")}: {message}', flush=True) 66 | 67 | def get_target(self, worker_id: int): 68 | target = self.open_targets.pop() if self.open_targets else None 69 | print(f'Worker {worker_id} asks for new target at {time.strftime("%H:%M:%S")}. 
Assigning new target {target}.', 70 | flush=True) 71 | return target 72 | 73 | def design_experiment_distributed(self, args): 74 | self.open_targets = self.env.intervenable_nodes | {'OBSERVATIONAL'} 75 | self.designed_experiments.clear() 76 | 77 | # start workers 78 | futs = [] 79 | for designer_rref in self.designer_rrefs: 80 | futs.append(designer_rref.rpc_async().run_distributed(self.experimenter_rref, args)) 81 | 82 | # init and run designer of master process 83 | designer = self.experiment_designer_factory() 84 | designer.init_design_process(args) 85 | 86 | target_node = self.get_target(self.worker_id) 87 | while target_node: 88 | design = designer.design_experiment(target_node) 89 | self.report_design(self.worker_id, target_node, design) 90 | target_node = self.get_target(self.worker_id) 91 | 92 | # wait until all workers have finished 93 | for fut in futs: 94 | fut.wait() 95 | 96 | # pick most promising experiment 97 | print('Experiment design process has finished:') 98 | best_intervention = {} 99 | best_info_gain = self.designed_experiments['OBSERVATIONAL'].info_gain 100 | for _, design in self.designed_experiments.items(): 101 | print(f'Interventions {design.interventions} expect info gain {design.info_gain}', flush=True) 102 | if design.info_gain > best_info_gain: 103 | best_info_gain = design.info_gain 104 | best_intervention = design.interventions 105 | 106 | return best_intervention, best_info_gain 107 | 108 | def param_dict(self): 109 | env_param_dict = self.env.param_dict() 110 | params = {'env_param_dict': env_param_dict, 111 | 'policy': self.policy, 112 | 'num_workers': self.num_workers, 113 | 'experiments': self.experiments, 114 | 'loss_list': self.loss_list, 115 | 'eshd_list': self.eshd_list, 116 | 'graph_ll_list': self.graph_ll_list, 117 | 'info_gain_list': self.info_gain_list, 118 | 'graph_entropy_list': self.graph_entropy_list, 119 | 'auroc_list': self.auroc_list, 120 | 'auprc_list': self.auprc_list, 121 | 'observational_test_ll_list': self.observational_test_ll_list, 122 | 'interventional_test_ll_lists': self.interventional_test_ll_lists, 123 | 'observational_kld_list': self.observational_kld_list, 124 | 'interventional_kld_lists': self.interventional_kld_lists, 125 | 'query_kld_list': self.query_kld_list} 126 | return params 127 | 128 | def load_param_dict(self, param_dict): 129 | self.policy = param_dict['policy'] 130 | self.num_workers = param_dict['num_workers'] 131 | self.experiments = param_dict['experiments'] 132 | self.loss_list = param_dict['loss_list'] 133 | self.eshd_list = param_dict['eshd_list'] 134 | self.graph_ll_list = param_dict['graph_ll_list'] 135 | self.info_gain_list = param_dict['info_gain_list'] 136 | self.graph_entropy_list = param_dict['graph_entropy_list'] 137 | self.auroc_list = param_dict['auroc_list'] 138 | self.auprc_list = param_dict['auprc_list'] 139 | self.observational_test_ll_list = param_dict['observational_test_ll_list'] 140 | self.interventional_test_ll_lists = param_dict['interventional_test_ll_lists'] 141 | self.observational_kld_list = param_dict['observational_kld_list'] 142 | self.interventional_kld_lists = param_dict['interventional_kld_lists'] 143 | self.query_kld_list = param_dict['query_kld_list'] 144 | -------------------------------------------------------------------------------- /src/utils/plotting.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | from typing import List, Optional 4 | 5 | import matplotlib.pyplot as plt 6 | import scipy.stats as sst 
7 | import torch 8 | 9 | 10 | def init_plot_style(): 11 | """Initialize the plot style for pyplot. 12 | """ 13 | plt.rcParams.update({'figure.figsize': (12, 9)}) 14 | plt.rcParams.update({'lines.linewidth': 2}) 15 | plt.rcParams.update({'lines.markersize': 25}) 16 | plt.rcParams.update({'lines.markeredgewidth': 2}) 17 | plt.rcParams.update({'axes.labelpad': 10}) 18 | plt.rcParams.update({'xtick.major.width': 2.5}) 19 | plt.rcParams.update({'xtick.major.size': 15}) 20 | plt.rcParams.update({'xtick.minor.size': 10}) 21 | plt.rcParams.update({'ytick.major.width': 2.5}) 22 | plt.rcParams.update({'ytick.minor.width': 2.5}) 23 | plt.rcParams.update({'ytick.major.size': 15}) 24 | plt.rcParams.update({'ytick.minor.size': 15}) 25 | 26 | # for font settings see also https://stackoverflow.com/questions/2537868/sans-serif-math-with-latex-in-matplotlib 27 | plt.rcParams.update({'font.size': 50}) 28 | plt.rcParams.update({'font.family': 'sans-serif'}) 29 | plt.rcParams.update({'text.usetex': True}) 30 | plt.rcParams['text.latex.preamble'] = '\n'.join([ 31 | r'\usepackage{amsmath,amssymb,amsfonts,amsthm}', 32 | r'\usepackage{siunitx}', # i need upright \micro symbols, but you need... 33 | r'\sisetup{detect-all}', # ...this to force siunitx to actually use your fonts 34 | r'\usepackage{helvet}', # set the normal font here 35 | r'\usepackage{sansmath}', # load up the sansmath so that math -> helvet 36 | r'\sansmath' # <- tricky! -- gotta actually tell tex to use! 37 | ]) 38 | 39 | 40 | def parse_file_name(filename: str): 41 | tokens = filename.split('-') 42 | if tokens[-2] == 'exp': 43 | job_id = tokens[-3] 44 | env_id = tokens[-4] 45 | exp_num = int(tokens[-1][:-4]) 46 | else: 47 | job_id = tokens[-1][:-4] 48 | env_id = tokens[-2] 49 | exp_num = 0 50 | return env_id, job_id, exp_num 51 | 52 | 53 | class Simulation: 54 | def __init__(self, env_name: str, num_nodes: int, timestamp: str, abci_model: str, policy: str, 55 | num_experiments: int, 56 | plot_kwargs: Optional[dict] = None): 57 | self.env_name = env_name 58 | self.num_nodes = num_nodes 59 | self.timestamp = timestamp 60 | self.abci_model = abci_model 61 | self.policy = policy 62 | self.num_experiments = num_experiments 63 | self.stats = None 64 | self.plot_kwargs = dict() if plot_kwargs is None else plot_kwargs 65 | 66 | def get_result_files(self, base_dir: str = '../results/'): 67 | results_dir = base_dir + f'{self.env_name}/{self.num_nodes}_nodes/' \ 68 | f'{self.timestamp}_{self.abci_model}_{self.policy}/' 69 | abci_files = [entry for entry in os.scandir(results_dir) if 70 | entry.is_file() and os.path.basename(entry)[-4:] == '.pth'] 71 | 72 | common_env_files = dict() 73 | for f in abci_files: 74 | env_id, job_id, exp_num = parse_file_name(os.path.basename(f)) 75 | if exp_num != self.num_experiments: 76 | continue 77 | 78 | if env_id in common_env_files: 79 | common_env_files[env_id].append(os.path.abspath(f)) 80 | else: 81 | common_env_files[env_id] = [os.path.abspath(f)] 82 | 83 | return common_env_files 84 | 85 | def load_results(self, stats_names: List[str], base_dir: str = '../results/'): 86 | common_env_files = self.get_result_files(base_dir) 87 | num_environments = len(common_env_files) 88 | print(f'Loading results from {num_environments} environments.') 89 | 90 | stats_lists = dict() 91 | for env_id, env_files in common_env_files.items(): 92 | print(f'Got {len(env_files)} runs for environment {env_id}.') 93 | env_wise_stat_lists = dict() 94 | for abci_file in env_files: 95 | param_dict = torch.load(abci_file) 96 | 97 | for stat_name in 
stats_names: 98 | if stat_name in {'interventional_test_ll', 'interventional_kld'}: 99 | stat_token = stat_name + '_lists' 100 | else: 101 | stat_token = stat_name + '_list' 102 | 103 | if stat_token not in param_dict: 104 | print(f'No results for {stat_token} available in {abci_file}!') 105 | continue 106 | 107 | if stat_name == 'interventional_test_ll': 108 | data = -torch.tensor(list(param_dict[stat_token].values())).sum(dim=0) 109 | elif stat_name == 'interventional_kld': 110 | data = torch.tensor(list(param_dict[stat_token].values())).mean(dim=0) 111 | else: 112 | data = torch.tensor(param_dict[stat_token]) 113 | if stat_name in {'graph_ll', 'observational_test_ll'}: 114 | data = -data 115 | 116 | if stat_name in env_wise_stat_lists: 117 | env_wise_stat_lists[stat_name].append(data) 118 | else: 119 | env_wise_stat_lists[stat_name] = [data] 120 | 121 | # aggregate env-wise results 122 | for stat_name in env_wise_stat_lists: 123 | with torch.no_grad(): 124 | data = torch.stack(env_wise_stat_lists[stat_name], dim=0) 125 | 126 | reduce = lambda x: x 127 | # reduce = lambda x: x.mean(dim=0, keepdims=True) 128 | 129 | if stat_name in stats_lists: 130 | stats_lists[stat_name].append(reduce(data)) 131 | else: 132 | stats_lists[stat_name] = [reduce(data)] 133 | 134 | self.stats = {stat_name: torch.cat(stat_list, dim=0) for stat_name, stat_list in stats_lists.items()} 135 | return self.stats 136 | 137 | def plot_simulation_data(self, ax, stat_name: str): 138 | if self.stats is None: 139 | print('Nothing to plot...') 140 | return 141 | 142 | data = self.stats[stat_name] 143 | num_envs, num_exps = data.shape 144 | exp_numbers = torch.arange(1, num_exps + 1) 145 | 146 | # compute 95% CIs 147 | mean = data.mean(dim=0) 148 | std_err = data.std(unbiased=True, dim=0) / math.sqrt(num_envs) + 1e-8 149 | lower, upper = sst.t.interval(.95, df=num_envs - 1, loc=mean, scale=std_err) 150 | 151 | ax.plot(exp_numbers, mean.detach(), **self.plot_kwargs) 152 | # if self.policy not in {'observational', 'random-fixed-value'}: 153 | ax.fill_between(exp_numbers, upper, lower, alpha=0.2, color=self.plot_kwargs['c']) 154 | -------------------------------------------------------------------------------- /notebooks/generate_benchmark_envs.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Generate and store benchmark environments\n", 8 | "\n", 9 | "In this notebook can be used to generate benchmark environments." 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "outputs": [], 16 | "source": [ 17 | "%load_ext autoreload\n", 18 | "%autoreload 2\n", 19 | "import os\n", 20 | "import shutil\n", 21 | "\n", 22 | "import torch.distributions as dist\n", 23 | "\n", 24 | "from src.environments.generic_environments import *" 25 | ], 26 | "metadata": { 27 | "collapsed": false, 28 | "pycharm": { 29 | "name": "#%% imports\n" 30 | } 31 | } 32 | }, 33 | { 34 | "cell_type": "markdown", 35 | "source": [ 36 | "Generate the environments." 
37 | ], 38 | "metadata": { 39 | "collapsed": false, 40 | "pycharm": { 41 | "name": "#%% md\n" 42 | } 43 | } 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": null, 48 | "outputs": [], 49 | "source": [ 50 | "# benchmark parameters\n", 51 | "env_class = CRGraph\n", 52 | "num_envs = 5\n", 53 | "n_list = [5]\n", 54 | "frac_non_intervenable_nodes= None\n", 55 | "non_intervenable_nodes= set()\n", 56 | "# non_intervenable_nodes = set([\"X2\"])\n", 57 | "num_test_samples_per_intervention = 50\n", 58 | "num_test_queries = 50\n", 59 | "# interventional_queries = None\n", 60 | "interventional_queries = [InterventionalDistributionsQuery(['X4'], {'X2':dist.Uniform(2., 5.)})]\n", 61 | "\n", 62 | "descriptor = ''\n", 63 | "# descriptor = '_X2'\n", 64 | "\n", 65 | "# generation setup\n", 66 | "env_dir = '../data/' + env_class.__name__ + descriptor + '/' # dir where to store the generated envs\n", 67 | "delete_existing = False # delete existing benchmarks\n", 68 | "\n", 69 | "# generating the benchmark envs from here on\n", 70 | "i = 0\n", 71 | "total_graphs = num_envs * len(n_list)\n", 72 | "for num_nodes in n_list:\n", 73 | " # generate/empty folder for envs of same type\n", 74 | " n_dir = env_dir + f'{num_nodes}_nodes/'\n", 75 | " if os.path.isdir(n_dir):\n", 76 | " if not delete_existing:\n", 77 | " print('\\nDirectory \\'' + n_dir + '\\' already exists, not generating benchmarks...')\n", 78 | " continue\n", 79 | "\n", 80 | " print('\\nDirectory \\'' + n_dir + '\\' already exists, delete existing benchmarks...')\n", 81 | " for root, dirs, files in os.walk(n_dir):\n", 82 | " for file in files:\n", 83 | " os.remove(os.path.join(root, file))\n", 84 | " for folder in dirs:\n", 85 | " shutil.rmtree(os.path.join(root, folder))\n", 86 | "\n", 87 | " os.makedirs(n_dir, exist_ok=True)\n", 88 | "\n", 89 | " # generate benchmark envs\n", 90 | " for _ in range(num_envs):\n", 91 | " i = i + 1\n", 92 | " env = env_class(num_nodes=num_nodes,\n", 93 | " frac_non_intervenable_nodes=frac_non_intervenable_nodes,\n", 94 | " non_intervenable_nodes=non_intervenable_nodes,\n", 95 | " num_test_samples_per_intervention=num_test_samples_per_intervention,\n", 96 | " num_test_queries=num_test_queries,\n", 97 | " interventional_queries=interventional_queries)\n", 98 | " env_path = n_dir + env.name + '.pth'\n", 99 | " env.save(env_path)\n", 100 | " print(f'\\rGenerated {i}/{total_graphs} environments.', end='')\n" 101 | ], 102 | "metadata": { 103 | "collapsed": false, 104 | "pycharm": { 105 | "name": "#%%\n" 106 | } 107 | } 108 | }, 109 | { 110 | "cell_type": "markdown", 111 | "source": [ 112 | "Take existing environments, restrict their set of intervenable nodes and store them seperately." 
113 | ], 114 | "metadata": { 115 | "collapsed": false, 116 | "pycharm": { 117 | "name": "#%% md\n" 118 | } 119 | } 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": null, 124 | "outputs": [], 125 | "source": [ 126 | "# benchmark parameters\n", 127 | "env_class = CRGraph\n", 128 | "n_list = [5]\n", 129 | "# non_intervenable_nodes= set()\n", 130 | "non_intervenable_nodes = set([\"X2\"])\n", 131 | "\n", 132 | "# descriptor = ''\n", 133 | "descriptor = '_X2'\n", 134 | "\n", 135 | "# generation setup\n", 136 | "source_env_dir = '../data/' + env_class.__name__ + '/' # dir where origianl envs are stored\n", 137 | "target_env_dir = '../data/' + env_class.__name__ + descriptor + '/' # dir where to store the generated envs\n", 138 | "delete_existing = False # delete existing benchmarks\n", 139 | "\n", 140 | "i = 0\n", 141 | "for num_nodes in n_list:\n", 142 | " # check if source envs available\n", 143 | " source_n_dir = source_env_dir + f'{num_nodes}_nodes/'\n", 144 | " if not os.path.isdir(source_n_dir):\n", 145 | " print(f'Source directory {source_n_dir} does not exist!')\n", 146 | " continue\n", 147 | "\n", 148 | " # generate/empty folder for target envs\n", 149 | " target_n_dir = target_env_dir + f'{num_nodes}_nodes/'\n", 150 | " if os.path.isdir(target_n_dir):\n", 151 | " if not delete_existing:\n", 152 | " print('\\nTarget directory \\'' + target_n_dir + '\\' already exists, not generating benchmarks...')\n", 153 | " continue\n", 154 | "\n", 155 | " print('\\nTarget directory \\'' + target_n_dir + '\\' already exists, delete existing benchmarks...')\n", 156 | " for root, dirs, files in os.walk(target_n_dir):\n", 157 | " for file in files:\n", 158 | " os.remove(os.path.join(root, file))\n", 159 | " for folder in dirs:\n", 160 | " shutil.rmtree(os.path.join(root, folder))\n", 161 | "\n", 162 | " os.makedirs(target_n_dir, exist_ok=True)\n", 163 | "\n", 164 | " # load source envs\n", 165 | " env_files = [entry for entry in os.scandir(source_n_dir) if entry.is_file() and os.path.basename(entry)[-4:] == '.pth']\n", 166 | " for i, f in enumerate(env_files):\n", 167 | " env = env_class.load(os.path.abspath(f))\n", 168 | " env.non_intervenable_nodes = non_intervenable_nodes\n", 169 | " env.intervenable_nodes = set(env.node_labels) - env.non_intervenable_nodes\n", 170 | " env.save(target_n_dir + os.path.basename(f))\n", 171 | "\n", 172 | " print(f'\\rProcessed {i+1}/{len(env_files)} environments in {source_n_dir}.', end='')\n", 173 | " print('')\n" 174 | ], 175 | "metadata": { 176 | "collapsed": false, 177 | "pycharm": { 178 | "name": "#%%\n" 179 | } 180 | } 181 | } 182 | ], 183 | "metadata": { 184 | "kernelspec": { 185 | "display_name": "Python 3", 186 | "language": "python", 187 | "name": "python3" 188 | }, 189 | "language_info": { 190 | "codemirror_mode": { 191 | "name": "ipython", 192 | "version": 3 193 | }, 194 | "file_extension": ".py", 195 | "mimetype": "text/x-python", 196 | "name": "python", 197 | "nbconvert_exporter": "python", 198 | "pygments_lexer": "ipython3", 199 | "version": "3.6.8" 200 | } 201 | }, 202 | "nbformat": 4, 203 | "nbformat_minor": 1 204 | } -------------------------------------------------------------------------------- /src/experimental_design/exp_designer_abci_categorical_gp.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Tuple, List, Optional 2 | 3 | import networkx as nx 4 | import torch.optim 5 | 6 | from src.environments.environment import Experiment, InterventionalDistributionsQuery 7 
| from src.experimental_design.exp_designer_base import ExpDesignerBase 8 | from src.models.gp_model import GaussianProcessModel 9 | from src.models.graph_models import CategoricalModel 10 | 11 | 12 | class ExpDesignerABCICategoricalGP(ExpDesignerBase): 13 | model: Optional[GaussianProcessModel] 14 | graph_posterior: Optional[CategoricalModel] 15 | 16 | def __init__(self, intervention_bounds: Dict[str, Tuple[float, float]], opt_strategy: str = 'gp-ucb', 17 | distributed=False): 18 | super().__init__(intervention_bounds, opt_strategy, distributed) 19 | self.model = None 20 | self.graph_posterior = None 21 | 22 | def init_design_process(self, args: dict): 23 | self.model = args['mechanism_model'] 24 | self.model.entropy_cache.clear() 25 | self.graph_posterior = args['graph_posterior'] 26 | 27 | if args['policy'] == 'scm-info-gain': 28 | def utility(interventions: dict): 29 | return self.scm_info_gain(interventions, args['batch_size'], args['num_exp_per_graph'], args['mode']) 30 | elif args['policy'] == 'graph-info-gain': 31 | def utility(interventions: dict): 32 | return self.graph_info_gain(interventions, args['batch_size'], args['num_exp_per_graph'], args['mode']) 33 | elif args['policy'] == 'intervention-info-gain': 34 | def utility(interventions: dict): 35 | return self.intervention_info_gain(interventions, 36 | args['experiments'], 37 | args['interventional_queries'], 38 | args['outer_mc_graphs'], 39 | args['outer_log_weights'], 40 | args['num_mc_queries'], 41 | args['num_batches_per_query'], 42 | args['batch_size'], 43 | args['num_exp_per_graph'], 44 | args['mode']) 45 | else: 46 | assert False, 'Invalid policy ' + args['policy'] + '!' 47 | self.utility = utility 48 | 49 | def compute_graph_posterior_mlls(self, experiments: List[Experiment], graphs: List[nx.DiGraph], 50 | mode='independent_batches', reduce=True): 51 | graph_mlls = [self.model.mll(experiments, graph, prior_mode=False, use_cache=True, mode=mode, reduce=reduce) for 52 | graph in graphs] 53 | graph_mlls = torch.stack(graph_mlls) 54 | return graph_mlls 55 | 56 | def graph_info_gain(self, interventions: dict, batch_size: int = 1, num_exp_per_graph: int = 1, mode='n-best', 57 | num_mc_graphs=20): 58 | with torch.no_grad(): 59 | graphs, log_weights = self.graph_posterior.get_mc_graphs(mode, num_mc_graphs) 60 | 61 | graph_info_gains = torch.zeros(len(graphs)) 62 | for graph_idx, graph in enumerate(graphs): 63 | simulated_experiments = [self.model.sample(interventions, batch_size, num_exp_per_graph, graph)] 64 | 65 | self.model.clear_posterior_mll_cache() 66 | posterior_mlls = self.compute_graph_posterior_mlls(simulated_experiments, graphs, reduce=False) 67 | data_mll = (posterior_mlls + log_weights.unsqueeze(-1)).logsumexp(dim=0) 68 | graph_info_gains[graph_idx] += (posterior_mlls[graph_idx] - data_mll).mean() 69 | 70 | return log_weights.exp() @ graph_info_gains 71 | 72 | def scm_info_gain(self, interventions: dict, batch_size: int = 1, num_exp_per_graph: int = 1, mode='n-best', 73 | num_mc_graphs=20): 74 | with torch.no_grad(): 75 | graphs, log_weights = self.graph_posterior.get_mc_graphs(mode, num_mc_graphs) 76 | 77 | graph_info_gains = torch.zeros(len(graphs)) 78 | for graph_idx, graph in enumerate(graphs): 79 | simulated_experiments = [self.model.sample(interventions, batch_size, num_exp_per_graph, graph)] 80 | 81 | self.model.clear_posterior_mll_cache() 82 | posterior_mlls = self.compute_graph_posterior_mlls(simulated_experiments, graphs, reduce=False) 83 | data_mll = (posterior_mlls + 
log_weights.unsqueeze(-1)).logsumexp(dim=0) 84 | entropy = self.model.expected_noise_entropy(interventions, graph, use_cache=True) 85 | graph_info_gains[graph_idx] += -entropy * batch_size - data_mll.mean() 86 | 87 | return log_weights.exp() @ graph_info_gains 88 | 89 | def intervention_info_gain(self, 90 | interventions: dict, 91 | experiments: List[Experiment], 92 | interventional_queries: List[InterventionalDistributionsQuery], 93 | outer_mc_graphs: List[nx.DiGraph], 94 | outer_log_weights: torch.Tensor, 95 | num_mc_queries: int = 1, 96 | num_batches_per_query: int = 1, 97 | batch_size: int = 1, 98 | num_exp_per_graph: int = 1, 99 | mode='n-best', num_inner_mc_graphs=20): 100 | with torch.no_grad(): 101 | # get mc graphs and compute log weights 102 | inner_mc_graphs, inner_log_weights = self.graph_posterior.get_mc_graphs(mode, num_inner_mc_graphs) 103 | num_outer_graphs = len(outer_mc_graphs) 104 | num_inner_graphs = len(inner_mc_graphs) 105 | 106 | graph_info_gains = torch.zeros(num_outer_graphs) 107 | for graph_idx, graph in enumerate(outer_mc_graphs): 108 | # simulate experiments and compute data mll 109 | self.model.set_data(experiments) 110 | self.model.clear_posterior_mll_cache() 111 | simulated_experiments = [self.model.sample(interventions, batch_size, num_exp_per_graph, graph)] 112 | posterior_mlls = self.compute_graph_posterior_mlls(simulated_experiments, inner_mc_graphs, reduce=False) 113 | assert posterior_mlls.shape == (num_inner_graphs, num_exp_per_graph) 114 | data_mll = (posterior_mlls + inner_log_weights.unsqueeze(-1)).logsumexp(dim=0).mean() 115 | 116 | # simulate queries 117 | self.model.set_data(experiments + simulated_experiments) 118 | self.model.clear_posterior_mll_cache() 119 | 120 | # simulate queries & compute query lls 121 | mc_queries = self.model.sample_queries(interventional_queries, num_mc_queries, 122 | num_batches_per_query, graph) 123 | 124 | query_lls = torch.stack([self.model.query_log_probs(mc_queries, g) for g in inner_mc_graphs]) 125 | num_mc_queries = len(mc_queries[0].sample_queries) 126 | num_batches_per_query = mc_queries[0].sample_queries[0].num_batches 127 | assert query_lls.shape == (num_inner_graphs, num_mc_queries, num_batches_per_query) 128 | 129 | # compute and update per graph info gain 130 | tmp = posterior_mlls.view(num_inner_graphs, 1, 1, num_exp_per_graph) + query_lls.unsqueeze(-1) + \ 131 | inner_log_weights.view(num_inner_graphs, 1, 1, 1) 132 | graph_info_gains[graph_idx] += tmp.logsumexp(dim=0).mean() 133 | graph_info_gains[graph_idx] = graph_info_gains[graph_idx] - data_mll 134 | 135 | return outer_log_weights.exp() @ graph_info_gains 136 | -------------------------------------------------------------------------------- /notebooks/example_abci_dibs_gp.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "collapsed": true, 7 | "pycharm": { 8 | "name": "#%% md\n" 9 | } 10 | }, 11 | "source": [ 12 | "# Example usage of ABCI-DiBS-GP\n", 13 | "\n", 14 | "This notebook illustrates the example usage of ABCI using DiBS for approximate graph posterior inference\n", 15 | "and a GP mechanism model." 
16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": null, 21 | "outputs": [], 22 | "source": [ 23 | "# imports\n", 24 | "%reload_ext autoreload\n", 25 | "%autoreload 2\n", 26 | "\n", 27 | "import matplotlib.pyplot as plt\n", 28 | "import torch.distributions as dist\n", 29 | "from matplotlib.ticker import MaxNLocator\n", 30 | "\n", 31 | "from src.abci_dibs_gp import ABCIDiBSGP as ABCI\n", 32 | "from src.environments.generic_environments import *\n", 33 | "from src.models.gp_model import *" 34 | ], 35 | "metadata": { 36 | "collapsed": false, 37 | "pycharm": { 38 | "name": "#%%\n" 39 | } 40 | } 41 | }, 42 | { 43 | "cell_type": "markdown", 44 | "source": [ 45 | "First, we generate a ground truth environment/SCM.\n" 46 | ], 47 | "metadata": { 48 | "collapsed": false 49 | } 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": null, 54 | "outputs": [], 55 | "source": [ 56 | "# specify the number of nodes and (optionally) a query of interventional variables\n", 57 | "num_nodes = 5\n", 58 | "interventional_queries = None\n", 59 | "# interventional_queries = [InterventionalDistributionsQuery(['X2'], {'X1': dist.Uniform(2., 5.)})]\n", 60 | "\n", 61 | "# generate the ground truth environment\n", 62 | "env = BarabasiAlbert(num_nodes,\n", 63 | " num_test_queries=50,\n", 64 | " interventional_queries=interventional_queries)\n", 65 | "\n", 66 | "# plot true graph\n", 67 | "nx.draw(env.graph, nx.circular_layout(env.graph), labels=dict(zip(env.graph.nodes, env.graph.nodes)))" 68 | ], 69 | "metadata": { 70 | "collapsed": false, 71 | "pycharm": { 72 | "name": "#%% init environment\n" 73 | } 74 | } 75 | }, 76 | { 77 | "cell_type": "markdown", 78 | "source": [ 79 | "Here, we create an ABCI instance with the desired experimental design policy." 80 | ], 81 | "metadata": { 82 | "collapsed": false 83 | } 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": null, 88 | "outputs": [], 89 | "source": [ 90 | "policy = 'graph-info-gain'\n", 91 | "abci = ABCI(env, policy, num_particles=5, num_mc_graphs=40, num_workers=1, dibs_plus=True, linear=False)" 92 | ], 93 | "metadata": { 94 | "collapsed": false, 95 | "pycharm": { 96 | "name": "#%%\n" 97 | } 98 | } 99 | }, 100 | { 101 | "cell_type": "markdown", 102 | "source": [ 103 | "We can now run a number of ABCI loops." 104 | ], 105 | "metadata": { 106 | "collapsed": false 107 | } 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": null, 112 | "outputs": [], 113 | "source": [ 114 | "num_experiments = 2\n", 115 | "batch_size = 3\n", 116 | "\n", 117 | "abci.run(num_experiments, batch_size, num_initial_obs_samples=3)" 118 | ], 119 | "metadata": { 120 | "collapsed": false, 121 | "pycharm": { 122 | "name": "#%%\n" 123 | } 124 | } 125 | }, 126 | { 127 | "cell_type": "markdown", 128 | "source": [ 129 | "Here, we plot the training stats and results." 
130 | ], 131 | "metadata": { 132 | "collapsed": false, 133 | "pycharm": { 134 | "name": "#%% md\n" 135 | } 136 | } 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": null, 141 | "outputs": [], 142 | "source": [ 143 | "print(f'Number of observational batches: {len([e for e in abci.experiments if e.interventions == {}])}')\n", 144 | "for node in env.node_labels:\n", 145 | " print(\n", 146 | " f'Number of interventional batches on {node}: {len([e for e in abci.experiments if node in e.interventions])}')\n", 147 | "\n", 148 | "# plot expected SHD over experiments\n", 149 | "ax = plt.figure().gca()\n", 150 | "plt.plot(abci.eshd_list)\n", 151 | "plt.xlabel('Number of Experiments')\n", 152 | "plt.ylabel('Expected SHD')\n", 153 | "ax.xaxis.set_major_locator(MaxNLocator(integer=True))\n", 154 | "\n", 155 | "# plot auroc over experiments\n", 156 | "ax = plt.figure().gca()\n", 157 | "plt.plot(abci.auroc_list)\n", 158 | "plt.xlabel('Number of Experiments')\n", 159 | "plt.ylabel('AUROC')\n", 160 | "ax.xaxis.set_major_locator(MaxNLocator(integer=True))\n", 161 | "\n", 162 | "# plot auprc over experiments\n", 163 | "ax = plt.figure().gca()\n", 164 | "plt.plot(abci.auprc_list)\n", 165 | "plt.xlabel('Number of Experiments')\n", 166 | "plt.ylabel('AUPRC')\n", 167 | "ax.xaxis.set_major_locator(MaxNLocator(integer=True))\n", 168 | "\n", 169 | "# plot Query KLD over experiments\n", 170 | "ax = plt.figure().gca()\n", 171 | "plt.plot(abci.query_kld_list)\n", 172 | "plt.xlabel('Number of Experiments')\n", 173 | "plt.ylabel('Query KLD')\n", 174 | "ax.xaxis.set_major_locator(MaxNLocator(integer=True))\n", 175 | "plt.tight_layout()" 176 | ], 177 | "metadata": { 178 | "collapsed": false, 179 | "pycharm": { 180 | "name": "#%% print training stats\n" 181 | } 182 | } 183 | }, 184 | { 185 | "cell_type": "markdown", 186 | "source": [ 187 | "Finally, we can have a look at the learned vs. true mechanisms." 188 | ], 189 | "metadata": { 190 | "collapsed": false 191 | } 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": null, 196 | "outputs": [], 197 | "source": [ 198 | "# plot X_i -> X_j true vs. 
predicted\n", 199 | "i = 0\n", 200 | "j = 1\n", 201 | "xdata, ydata = gather_data(abci.experiments, f'X{j}', parents=[f'X{i}'])\n", 202 | "xrange = torch.linspace(-7., 7., 100).unsqueeze(-1)\n", 203 | "ytrue = env.mechanisms[f'X{j}'](xrange).detach()\n", 204 | "mech = abci.mechanism_model.get_mechanism(f'X{j}', parents=[f'X{i}'])\n", 205 | "mech.set_data(xdata, ydata)\n", 206 | "ypred = mech(xrange).detach()\n", 207 | "\n", 208 | "plt.figure()\n", 209 | "plt.plot(xdata, ydata, 'rx', label='Experimental Data')\n", 210 | "plt.plot(xrange, ytrue, label='X->Y true')\n", 211 | "plt.plot(xrange, ypred, label='X->Y prediction')\n", 212 | "plt.xlabel(f'X{i}')\n", 213 | "plt.ylabel(f'X{j}')\n", 214 | "plt.legend()\n", 215 | "plt.tight_layout()" 216 | ], 217 | "metadata": { 218 | "collapsed": false, 219 | "pycharm": { 220 | "name": "#%%\n" 221 | } 222 | } 223 | }, 224 | { 225 | "cell_type": "code", 226 | "execution_count": null, 227 | "outputs": [], 228 | "source": [ 229 | "# plot bivariate mechanisms\n", 230 | "node = 'X2'\n", 231 | "num_points = 100\n", 232 | "xrange = torch.linspace(-7., 7., num_points)\n", 233 | "yrange = torch.linspace(-7., 7., num_points)\n", 234 | "xgrid, ygrid = torch.meshgrid(xrange, yrange)\n", 235 | "inputs = torch.stack((xgrid, ygrid), dim=2).view(-1, 2)\n", 236 | "ztrue = env.mechanisms[node](inputs).detach().view(num_points, num_points).numpy()\n", 237 | "\n", 238 | "parents = ['X0', 'X1']\n", 239 | "mech = abci.mechanism_model.get_mechanism(node, parents=parents)\n", 240 | "sample_inputs, sample_targets = gather_data(abci.experiments, node, parents=parents)\n", 241 | "mech.set_data(sample_inputs, sample_targets)\n", 242 | "zpred = mech(inputs)\n", 243 | "zpred = zpred.detach().view(num_points, num_points).numpy()\n", 244 | "\n", 245 | "zmin = ztrue.min().item()\n", 246 | "zmax = ztrue.max().item()\n", 247 | "print(f'Function values for {node} in range [{zmin, zmax}].')\n", 248 | "\n", 249 | "levels = torch.linspace(zmin, zmax, 30).numpy()\n", 250 | "fig, axes = plt.subplots(1, 2)\n", 251 | "cp1 = axes[0].contourf(xgrid, ygrid, ztrue, cmap=plt.get_cmap('jet'), levels=levels, vmin=zmin, vmax=zmax,\n", 252 | " antialiased=False)\n", 253 | "cp2 = axes[1].contourf(xgrid, ygrid, zpred, cmap=plt.get_cmap('jet'), levels=levels, vmin=zmin, vmax=zmax,\n", 254 | " antialiased=False)\n", 255 | "\n", 256 | "axes[0].plot(sample_inputs[:, 0], sample_inputs[:, 1], 'kx')\n", 257 | "axes[0].set_xlabel(parents[0])\n", 258 | "axes[1].set_xlabel(parents[0])\n", 259 | "axes[0].set_ylabel(parents[1])\n", 260 | "_ = fig.colorbar(cp2)" 261 | ], 262 | "metadata": { 263 | "collapsed": false, 264 | "pycharm": { 265 | "name": "#%%\n" 266 | } 267 | } 268 | } 269 | ], 270 | "metadata": { 271 | "kernelspec": { 272 | "display_name": "Python 3", 273 | "language": "python", 274 | "name": "python3" 275 | }, 276 | "language_info": { 277 | "codemirror_mode": { 278 | "name": "ipython", 279 | "version": 2 280 | }, 281 | "file_extension": ".py", 282 | "mimetype": "text/x-python", 283 | "name": "python", 284 | "nbconvert_exporter": "python", 285 | "pygments_lexer": "ipython2", 286 | "version": "2.7.6" 287 | } 288 | }, 289 | "nbformat": 4, 290 | "nbformat_minor": 0 291 | } -------------------------------------------------------------------------------- /src/environments/generic_environments.py: -------------------------------------------------------------------------------- 1 | from typing import Set, List 2 | 3 | import networkx as nx 4 | import torch 5 | 6 | from src.environments.environment import 
Environment, InterventionalDistributionsQuery 7 | 8 | 9 | class ErdosRenyi(Environment): 10 | def __init__(self, num_nodes: int, 11 | p: float = None, 12 | mechanism_model='gp-model', 13 | frac_non_intervenable_nodes: float = None, 14 | non_intervenable_nodes: Set = None, 15 | num_test_samples_per_intervention: int = 50, 16 | num_test_queries: int = 0, 17 | interventional_queries: List[InterventionalDistributionsQuery] = None): 18 | assert num_nodes >= 2 19 | self.edge_prob = p if p is not None else 4. / (num_nodes - 1) if num_nodes > 5 else 0.5 20 | super().__init__(num_nodes, mechanism_model, frac_non_intervenable_nodes, non_intervenable_nodes, 21 | num_test_samples_per_intervention, num_test_queries, interventional_queries) 22 | 23 | def construct_graph(self, num_nodes: int) -> nx.DiGraph: 24 | adj_mat = torch.bernoulli(torch.tensor(self.edge_prob).expand(num_nodes, num_nodes)) 25 | adj_mat = torch.triu(adj_mat, diagonal=1).bool() 26 | order = torch.randperm(num_nodes) 27 | 28 | graph = nx.DiGraph() 29 | graph.add_nodes_from([f'X{i}' for i in range(num_nodes)]) 30 | for i in range(num_nodes): 31 | for j in range(i + 1, num_nodes): 32 | if adj_mat[i, j]: 33 | graph.add_edge(f'X{order[i]}', f'X{order[j]}') 34 | 35 | return graph 36 | 37 | 38 | class BarabasiAlbert(Environment): 39 | def __init__(self, num_nodes: int, 40 | num_parents_per_node: int = 2, 41 | mechanism_model='gp-model', 42 | frac_non_intervenable_nodes: float = None, 43 | non_intervenable_nodes: Set = None, 44 | num_test_samples_per_intervention: int = 50, 45 | num_test_queries: int = 0, 46 | interventional_queries: List[InterventionalDistributionsQuery] = None): 47 | assert num_nodes >= 2 48 | self.num_parents_per_node = num_parents_per_node 49 | super().__init__(num_nodes, mechanism_model, frac_non_intervenable_nodes, non_intervenable_nodes, 50 | num_test_samples_per_intervention, num_test_queries, interventional_queries) 51 | 52 | def construct_graph(self, num_nodes: int) -> nx.DiGraph: 53 | graph = nx.generators.barabasi_albert_graph(num_nodes, self.num_parents_per_node) 54 | adj_mat = torch.tensor(nx.to_numpy_array(graph)) 55 | adj_mat = torch.triu(adj_mat, diagonal=1).bool() 56 | 57 | graph = nx.DiGraph() 58 | for i in range(num_nodes): 59 | for j in range(i + 1, num_nodes): 60 | if adj_mat[i, j]: 61 | graph.add_edge(f'X{i}', f'X{j}') 62 | 63 | return graph 64 | 65 | 66 | class CRGraph(Environment): 67 | def __init__(self, num_nodes: int = None, 68 | mechanism_model='gp-model', 69 | frac_non_intervenable_nodes: float = None, 70 | non_intervenable_nodes: Set = None, 71 | num_test_samples_per_intervention: int = 50, 72 | num_test_queries: int = 0, 73 | interventional_queries: List[InterventionalDistributionsQuery] = None): 74 | super().__init__(5, mechanism_model, frac_non_intervenable_nodes, non_intervenable_nodes, 75 | num_test_samples_per_intervention, num_test_queries, interventional_queries) 76 | 77 | def construct_graph(self, num_nodes: int) -> nx.DiGraph: 78 | graph = nx.DiGraph() 79 | graph.add_edge('X0', 'X1') 80 | graph.add_edge('X0', 'X2') 81 | graph.add_edge('X0', 'X3') 82 | graph.add_edge('X1', 'X2') 83 | graph.add_edge('X1', 'X3') 84 | graph.add_edge('X2', 'X4') 85 | graph.add_edge('X3', 'X4') 86 | return graph 87 | 88 | 89 | class BiDiag(Environment): 90 | def __init__(self, num_nodes: int, 91 | mechanism_model='gp-model', 92 | frac_non_intervenable_nodes: float = None, 93 | non_intervenable_nodes: Set = None, 94 | num_test_samples_per_intervention: int = 50, 95 | num_test_queries: int = 0, 96 | 
interventional_queries: List[InterventionalDistributionsQuery] = None): 97 | assert num_nodes >= 2 98 | super().__init__(num_nodes, mechanism_model, frac_non_intervenable_nodes, non_intervenable_nodes, 99 | num_test_samples_per_intervention, num_test_queries, interventional_queries) 100 | 101 | def construct_graph(self, num_nodes: int) -> nx.DiGraph: 102 | graph = nx.DiGraph() 103 | graph.add_edge('X0', 'X1') 104 | for i in range(2, num_nodes): 105 | graph.add_edge(f'X{i - 2}', f'X{i}') 106 | graph.add_edge(f'X{i - 1}', f'X{i}') 107 | return graph 108 | 109 | 110 | class Chain(Environment): 111 | def __init__(self, num_nodes: int, 112 | mechanism_model='gp-model', 113 | frac_non_intervenable_nodes: float = None, 114 | non_intervenable_nodes: Set = None, 115 | num_test_samples_per_intervention: int = 50, 116 | num_test_queries: int = 0, 117 | interventional_queries: List[InterventionalDistributionsQuery] = None): 118 | assert num_nodes >= 2 119 | super().__init__(num_nodes, mechanism_model, frac_non_intervenable_nodes, non_intervenable_nodes, 120 | num_test_samples_per_intervention, num_test_queries, interventional_queries) 121 | 122 | def construct_graph(self, num_nodes: int) -> nx.DiGraph: 123 | graph = nx.DiGraph() 124 | for i in range(1, num_nodes): 125 | graph.add_edge(f'X{i - 1}', f'X{i}') 126 | return graph 127 | 128 | 129 | class Collider(Environment): 130 | def __init__(self, num_nodes: int, 131 | mechanism_model='gp-model', 132 | frac_non_intervenable_nodes: float = None, 133 | non_intervenable_nodes: Set = None, 134 | num_test_samples_per_intervention: int = 50, 135 | num_test_queries: int = 0, 136 | interventional_queries: List[InterventionalDistributionsQuery] = None): 137 | assert num_nodes >= 2 138 | super().__init__(num_nodes, mechanism_model, frac_non_intervenable_nodes, non_intervenable_nodes, 139 | num_test_samples_per_intervention, num_test_queries, interventional_queries) 140 | 141 | def construct_graph(self, num_nodes: int) -> nx.DiGraph: 142 | graph = nx.DiGraph() 143 | for i in range(1, num_nodes): 144 | graph.add_edge(f'X{i}', 'X0') 145 | return graph 146 | 147 | 148 | class Full(Environment): 149 | def __init__(self, num_nodes: int, 150 | mechanism_model='gp-model', 151 | frac_non_intervenable_nodes: float = None, 152 | non_intervenable_nodes: Set = None, 153 | num_test_samples_per_intervention: int = 50, 154 | num_test_queries: int = 0, 155 | interventional_queries: List[InterventionalDistributionsQuery] = None): 156 | assert num_nodes >= 2 157 | super().__init__(num_nodes, mechanism_model, frac_non_intervenable_nodes, non_intervenable_nodes, 158 | num_test_samples_per_intervention, num_test_queries, interventional_queries) 159 | 160 | def construct_graph(self, num_nodes: int) -> nx.DiGraph: 161 | graph = nx.DiGraph() 162 | for i in range(num_nodes): 163 | root_node = [f'X{i}'] * (num_nodes - i) 164 | target_nodes = [f'X{j}' for j in range(i + 1, num_nodes)] 165 | graph.add_edges_from(zip(root_node, target_nodes)) 166 | 167 | return graph 168 | 169 | 170 | class Independent(Environment): 171 | def __init__(self, num_nodes: int, 172 | mechanism_model='gp-model', 173 | frac_non_intervenable_nodes: float = None, 174 | non_intervenable_nodes: Set = None, 175 | num_test_samples_per_intervention: int = 50, 176 | num_test_queries: int = 0, 177 | interventional_queries: List[InterventionalDistributionsQuery] = None): 178 | assert num_nodes >= 2 179 | super().__init__(num_nodes, mechanism_model, frac_non_intervenable_nodes, non_intervenable_nodes, 180 | 
num_test_samples_per_intervention, num_test_queries, interventional_queries)
181 | 
182 |     def construct_graph(self, num_nodes: int) -> nx.DiGraph:
183 |         graph = nx.DiGraph()
184 |         graph.add_nodes_from([f'X{i}' for i in range(num_nodes)])
185 |         return graph
186 | 
187 | 
188 | class Jungle(Environment):
189 |     def __init__(self, num_nodes: int,
190 |                  mechanism_model='gp-model',
191 |                  frac_non_intervenable_nodes: float = None,
192 |                  non_intervenable_nodes: Set = None,
193 |                  num_test_samples_per_intervention: int = 50,
194 |                  num_test_queries: int = 0,
195 |                  interventional_queries: List[InterventionalDistributionsQuery] = None):
196 |         assert num_nodes >= 2
197 |         super().__init__(num_nodes, mechanism_model, frac_non_intervenable_nodes, non_intervenable_nodes,
198 |                          num_test_samples_per_intervention, num_test_queries, interventional_queries)
199 | 
200 |     def construct_graph(self, num_nodes: int) -> nx.DiGraph:
201 |         graph = nx.DiGraph()
202 |         graph.add_edge('X0', 'X1')
203 |         graph.add_edge('X0', 'X2')
204 |         for i in range(3, num_nodes):
205 |             parent = int((i + 1) / 2) - 1
206 |             grandparent = int((parent + 1) / 2) - 1
207 |             graph.add_edge(f'X{parent}', f'X{i}')
208 |             graph.add_edge(f'X{grandparent}', f'X{i}')
209 | 
210 |         return graph
211 | 
--------------------------------------------------------------------------------
/notebooks/plot_benchmark_results.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "markdown",
5 |    "source": [
6 |     "# Plot benchmark results\n",
7 |     "\n",
8 |     "This notebook can be used to plot and compare the benchmark results reported in the ABCI paper."
9 |    ],
10 |    "metadata": {
11 |     "collapsed": false
12 |    }
13 |   },
14 |   {
15 |    "cell_type": "code",
16 |    "execution_count": null,
17 |    "metadata": {
18 |     "collapsed": true,
19 |     "pycharm": {
20 |      "name": "#%% imports\n"
21 |     }
22 |    },
23 |    "outputs": [],
24 |    "source": [
25 |     "%reload_ext autoreload\n",
26 |     "%autoreload 2\n",
27 |     "\n",
28 |     "import matplotlib.pyplot as plt\n",
29 |     "from matplotlib.ticker import MaxNLocator\n",
30 |     "\n",
31 |     "from src.utils.plotting import init_plot_style, Simulation"
32 |    ]
33 |   },
34 |   {
35 |    "cell_type": "code",
36 |    "execution_count": null,
37 |    "outputs": [],
38 |    "source": [
39 |     "# the simulations you want to plot/compare\n",
40 |     "simulations = []\n",
41 |     "simulations.append(Simulation('CRGraph', 5, '20220728_0022', 'abci-dibs-gp', 'observational',\n",
42 |     "                              num_experiments=50, plot_kwargs={'label':'OBS', 'marker':'s', 'c':'Plum'}))\n",
43 |     "simulations.append(Simulation('CRGraph', 5, '20221005_1523', 'abci-dibs-gp', 'scm-info-gain',\n",
44 |     "                              num_experiments=50, plot_kwargs={'label':'RAND-FIXED', 'marker':'+', 'c':'DarkTurquoise'}))\n",
45 |     "simulations.append(Simulation('CRGraph', 5, '20220728_0022', 'abci-dibs-gp', 'random',\n",
46 |     "                              num_experiments=50, plot_kwargs={'label':'RAND', 'marker':'^', 'c':'Goldenrod'}))\n",
47 |     "simulations.append(Simulation('CRGraph', 5, '20220728_0025', 'abci-dibs-gp', 'graph-info-gain',\n",
48 |     "                              num_experiments=50, plot_kwargs={'label':r'$\\text{U}_\\text{CD}$', 'marker':'o', 'c':'MediumSeaGreen'}))\n",
49 |     "simulations.append(Simulation('CRGraph', 5, '20220728_0025', 'abci-dibs-gp', 'scm-info-gain',\n",
50 |     "                              num_experiments=50, plot_kwargs={'label':r'$\\text{U}_\\text{CML}$', 'marker':'x', 'c':'Tomato'}))\n",
51 |     "simulations.append(Simulation('CRGraph', 5, '20220728_0022', 'abci-dibs-gp', 'intervention-info-gain',\n",
52 |     "                              num_experiments=50, 
plot_kwargs={'label':r'$\\text{U}_\\text{CR}$', 'marker':'*', 'c':'CornflowerBlue'}))\n", 53 | "\n", 54 | "# the stats to extract from the simulation results\n", 55 | "stats_names = ['eshd', 'graph_ll', 'graph_entropy', 'auroc', 'auprc', 'observational_test_ll', 'observational_kld',\n", 56 | " 'interventional_test_ll', 'interventional_kld', 'query_kld']\n", 57 | "\n", 58 | "for sim in simulations:\n", 59 | " sim.load_results(stats_names)" 60 | ], 61 | "metadata": { 62 | "collapsed": false, 63 | "pycharm": { 64 | "name": "#%% load data\n" 65 | } 66 | } 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": null, 71 | "outputs": [], 72 | "source": [ 73 | "init_plot_style()\n", 74 | "\n", 75 | "save_plots = False\n", 76 | "dpi = 600\n", 77 | "fig_format = 'png'\n", 78 | "fig_dir = '../figures/'\n", 79 | "figdate = '20230101'\n", 80 | "fig_name = 'CRGraph-5'\n", 81 | "\n", 82 | "# axis labels for the given stats\n", 83 | "stat_labels = {'eshd': 'Expected SHD', 'graph_ll': 'Graph KLD', 'graph_entropy': 'Graph Entropy', 'auroc': 'AUROC',\n", 84 | " 'auprc':'AUPRC', 'observational_test_ll': 'NLL of Observational Test Data',\n", 85 | " 'interventional_test_ll': 'NLL of Interventional Test Data', 'observational_kld': 'Observational KLD',\n", 86 | " 'interventional_kld': 'Avg. Interventional KLD', 'query_kld': 'Query KLD'}\n", 87 | "# file identifier token for the given stats\n", 88 | "stat_tokens = {'eshd': 'ESHD', 'graph_ll': 'GRAPH-KLD', 'graph_entropy': 'Graph Entropy', 'auroc': 'AUROC',\n", 89 | " 'auprc':'AUPRC', 'observational_test_ll': 'OBS-NLL','interventional_test_ll':'INTR-NLL',\n", 90 | " 'observational_kld': 'OBS-KLD', 'interventional_kld': 'AVG-INTR-KLD', 'query_kld':'QUERY-KLD'}\n" 91 | ], 92 | "metadata": { 93 | "collapsed": false, 94 | "pycharm": { 95 | "name": "#%% init plot params\n" 96 | } 97 | } 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": null, 102 | "outputs": [], 103 | "source": [ 104 | "stat_name = 'eshd'\n", 105 | "\n", 106 | "# plot stats over experiments\n", 107 | "ax = plt.figure(figsize=(36,12)).gca()\n", 108 | "for sim in simulations:\n", 109 | " sim.plot_simulation_data(ax, stat_name)\n", 110 | "plt.xlabel('Number of Experiments')\n", 111 | "plt.ylabel(stat_labels[stat_name])\n", 112 | "ax.xaxis.set_major_locator(MaxNLocator(integer=True))\n", 113 | "plt.legend(loc='lower center', bbox_to_anchor=(0.5, -0.4), ncol=len(simulations))\n", 114 | "plt.xlim([0.8, simulations[0].stats[stat_name].shape[-1] + 0.2])\n", 115 | "plt.tight_layout()\n", 116 | "\n", 117 | "if save_plots:\n", 118 | " plt.savefig(fig_dir + f'{figdate}-{fig_name}-{stat_tokens[stat_name]}.{fig_format}', dpi = dpi, bbox_inches='tight')\n" 119 | ], 120 | "metadata": { 121 | "collapsed": false, 122 | "pycharm": { 123 | "name": "#%% plot single stat\n" 124 | } 125 | } 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": null, 130 | "outputs": [], 131 | "source": [ 132 | "\n", 133 | "# CRGraph\n", 134 | "stats_names = ['eshd', 'interventional_kld', 'query_kld']\n", 135 | "yranges = [[-0.2, 4.], [0.5, 6.], [-0.05, 1.1]]\n", 136 | "\n", 137 | "# plot entropy of graph posterior\n", 138 | "fig, axs = plt.subplots(1, 3, figsize=(36,9))\n", 139 | "for i, stat_name in enumerate(stats_names):\n", 140 | " for sim in simulations:\n", 141 | " sim.plot_simulation_data(axs[i], stat_name)\n", 142 | "\n", 143 | " axs[i].set_title(stat_labels[stat_name], loc='center', y=1.01)\n", 144 | " # format x axis\n", 145 | " axs[i].set_xlim([0.8, simulations[0].stats[stat_name].shape[-1] 
+ 0.2])\n", 146 | " axs[i].xaxis.set_major_locator(MaxNLocator(10,integer=True))\n", 147 | "\n", 148 | " # format y axis\n", 149 | " axs[i].set_ylim(yranges[i])\n", 150 | "\n", 151 | "\n", 152 | "axs[1].set_xlabel('Number of Experiments')\n", 153 | "axs[0].legend(loc='lower left', bbox_to_anchor=(0.1, -0.4), ncol=len(simulations), frameon=False)\n", 154 | "\n", 155 | "if save_plots:\n", 156 | " plt.savefig(fig_dir + f'{figdate}-{fig_name}-MIXED-STATS.{fig_format}', dpi = dpi, bbox_inches='tight')\n" 157 | ], 158 | "metadata": { 159 | "collapsed": false, 160 | "pycharm": { 161 | "name": "#%% plot multiple stats\n" 162 | } 163 | } 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": null, 168 | "outputs": [], 169 | "source": [ 170 | "# the simulations you want to plot/compare\n", 171 | "simulations2 = []\n", 172 | "simulations2.append(Simulation('CRGraph', 5, '20220728_0022', 'abci-dibs-gp', 'observational',\n", 173 | " num_experiments=50, plot_kwargs={'label':'OBS', 'marker':'s', 'c':'Plum'}))\n", 174 | "simulations2.append(Simulation('CRGraph', 5, '20221005_1523', 'abci-dibs-gp', 'scm-info-gain',\n", 175 | " num_experiments=50, plot_kwargs={'label':'RAND-FIXED', 'marker':'+', 'c':'DarkTurquoise'}))\n", 176 | "simulations2.append(Simulation('CRGraph', 5, '20220728_0022', 'abci-dibs-gp', 'random',\n", 177 | " num_experiments=50, plot_kwargs={'label':'RAND', 'marker':'^', 'c':'Goldenrod'}))\n", 178 | "simulations2.append(Simulation('CRGraph', 5, '20220728_0025', 'abci-dibs-gp', 'graph-info-gain',\n", 179 | " num_experiments=50, plot_kwargs={'label':r'$\\text{U}_\\text{CD}$', 'marker':'o', 'c':'MediumSeaGreen'}))\n", 180 | "simulations2.append(Simulation('CRGraph', 5, '20220728_0025', 'abci-dibs-gp', 'scm-info-gain',\n", 181 | " num_experiments=50, plot_kwargs={'label':r'$\\text{U}_\\text{CML}$', 'marker':'x', 'c':'Tomato'}))\n", 182 | "simulations2.append(Simulation('CRGraph', 5, '20220728_0022', 'abci-dibs-gp', 'intervention-info-gain',\n", 183 | " num_experiments=50, plot_kwargs={'label':r'$\\text{U}_\\text{CR}$', 'marker':'*', 'c':'CornflowerBlue'}))\n", 184 | "\n", 185 | "for sim in simulations2:\n", 186 | " sim.load_results(stats_names)" 187 | ], 188 | "metadata": { 189 | "collapsed": false, 190 | "pycharm": { 191 | "name": "#%% load a second set of simulations to be compared with the first set of simulations (e.g., comparing different envs)\n" 192 | } 193 | } 194 | }, 195 | { 196 | "cell_type": "code", 197 | "execution_count": null, 198 | "outputs": [], 199 | "source": [ 200 | "stat_name = 'query_kld'\n", 201 | "yranges = [[-0.05, 1.0], [-0.05, 1.0]]\n", 202 | "\n", 203 | "# plot entropy of graph posterior\n", 204 | "fig, axs = plt.subplots(1, 2, figsize=(36,13), sharey=True)\n", 205 | "for i, sims in enumerate((simulations, simulations2)):\n", 206 | " for sim in sims:\n", 207 | " sim.plot_simulation_data(axs[i], stat_name)\n", 208 | "\n", 209 | " # format x axis\n", 210 | " axs[i].set_xlim([0.8, simulations[0].stats[stat_name].shape[-1] + 0.2])\n", 211 | " axs[i].xaxis.set_major_locator(MaxNLocator(10,integer=True))\n", 212 | "\n", 213 | " # format y axis\n", 214 | " axs[i].set_ylim(yranges[i])\n", 215 | "\n", 216 | "\n", 217 | "axs[0].set_xlabel('Number of Experiments')\n", 218 | "axs[0].legend(loc='lower left', bbox_to_anchor=(0.75, -0.25), ncol=len(simulations), frameon=False)\n", 219 | "plt.tight_layout()\n", 220 | "plt.subplots_adjust(wspace=0.05)\n", 221 | "\n", 222 | "if save_plots:\n", 223 | " plt.savefig(fig_dir + f'{figdate}-{fig_name}-QUERY_KLD.{fig_format}', 
dpi = dpi, bbox_inches='tight')\n" 224 | ], 225 | "metadata": { 226 | "collapsed": false, 227 | "pycharm": { 228 | "name": "#%% plot single stat for two different sets of simulations\n" 229 | } 230 | } 231 | } 232 | ], 233 | "metadata": { 234 | "kernelspec": { 235 | "display_name": "Python 3", 236 | "language": "python", 237 | "name": "python3" 238 | }, 239 | "language_info": { 240 | "codemirror_mode": { 241 | "name": "ipython", 242 | "version": 2 243 | }, 244 | "file_extension": ".py", 245 | "mimetype": "text/x-python", 246 | "name": "python", 247 | "nbconvert_exporter": "python", 248 | "pygments_lexer": "ipython2", 249 | "version": "2.7.6" 250 | } 251 | }, 252 | "nbformat": 4, 253 | "nbformat_minor": 0 254 | } -------------------------------------------------------------------------------- /notebooks/example_abci_categorical_gp.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Example usage of ABCI-Categorical-GP\n", 8 | "\n", 9 | "This notebook illustrates the example usage of ABCI with a categorical\n", 10 | "distribution over graphs and a GP mechanism model. This setup scales up to systems with\n", 11 | "four variables." 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "metadata": { 18 | "pycharm": { 19 | "name": "#%%\n" 20 | } 21 | }, 22 | "outputs": [], 23 | "source": [ 24 | "# imports\n", 25 | "%reload_ext autoreload\n", 26 | "%autoreload 2\n", 27 | "\n", 28 | "import matplotlib.pyplot as plt\n", 29 | "import torch.distributions as dist\n", 30 | "from matplotlib.ticker import MaxNLocator\n", 31 | "\n", 32 | "from src.abci_categorical_gp import ABCICategoricalGP as ABCI\n", 33 | "from src.environments.generic_environments import *\n", 34 | "from src.models.gp_model import get_graph_key, gather_data\n" 35 | ] 36 | }, 37 | { 38 | "cell_type": "markdown", 39 | "metadata": {}, 40 | "source": [ 41 | "First, we generate a ground truth environment/SCM." 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "metadata": { 48 | "pycharm": { 49 | "name": "#%%\n" 50 | } 51 | }, 52 | "outputs": [], 53 | "source": [ 54 | "# specify the number of nodes and (optionally) a query of interventional variables\n", 55 | "num_nodes = 3\n", 56 | "# interventional_queries = None\n", 57 | "interventional_queries = [InterventionalDistributionsQuery(['X2'], {'X1': dist.Uniform(3, 4.)})]\n", 58 | "\n", 59 | "# generate the ground truth environment\n", 60 | "env = BiDiag(num_nodes,\n", 61 | " num_test_queries=10,\n", 62 | " interventional_queries=interventional_queries)\n", 63 | "\n", 64 | "# plot true graph\n", 65 | "nx.draw(env.graph, nx.planar_layout(env.graph), labels=dict(zip(env.graph.nodes, env.graph.nodes)))\n" 66 | ] 67 | }, 68 | { 69 | "cell_type": "markdown", 70 | "metadata": {}, 71 | "source": [ 72 | "Next, we can examine the ground truth mechanisms." 
73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": null, 78 | "metadata": { 79 | "pycharm": { 80 | "name": "#%%\n" 81 | } 82 | }, 83 | "outputs": [], 84 | "source": [ 85 | "# plotting a univariate mechanism\n", 86 | "node = 'X1' # target node\n", 87 | "num_points = 100\n", 88 | "xrange = torch.linspace(-10., 10., num_points).unsqueeze(1)\n", 89 | "ytrue = env.mechanisms[node](xrange, prior_mode=False).detach()\n", 90 | "\n", 91 | "plt.figure()\n", 92 | "plt.plot(xrange, ytrue)\n", 93 | "plt.xlabel('X0')\n", 94 | "plt.ylabel(node)\n", 95 | "plt.tight_layout()" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": null, 101 | "metadata": { 102 | "pycharm": { 103 | "name": "#%%\n" 104 | } 105 | }, 106 | "outputs": [], 107 | "source": [ 108 | "# plotting a bivariate mechanism\n", 109 | "node = 'X2' # target node\n", 110 | "num_points = 100\n", 111 | "xrange = torch.linspace(-7., 7., num_points)\n", 112 | "yrange = torch.linspace(-7., 7., num_points)\n", 113 | "xgrid, ygrid = torch.meshgrid(xrange, yrange)\n", 114 | "ztrue = env.mechanisms[node](torch.stack((xgrid, ygrid), dim=2).view(-1, 2))\n", 115 | "zmin = ztrue.min().item()\n", 116 | "zmax = ztrue.max().item()\n", 117 | "print(f'Function values for {node} in range [{zmin, zmax}].')\n", 118 | "ztrue = ztrue.detach().view(num_points, num_points).numpy()\n", 119 | "\n", 120 | "levels = torch.linspace(zmin, zmax, 30).numpy()\n", 121 | "fig, ax = plt.subplots()\n", 122 | "cp1 = ax.contourf(xgrid, ygrid, ztrue, cmap=plt.get_cmap('jet'), levels=levels, vmin=zmin, vmax=zmax,\n", 123 | " antialiased=False)\n", 124 | "ax.set_xlabel('X0')\n", 125 | "ax.set_ylabel('X1')\n", 126 | "_ = fig.colorbar(cp1)\n", 127 | "\n", 128 | "# _, ax = plt.subplots(subplot_kw={\"projection\": \"3d\"})\n", 129 | "# surf = ax.plot_surface(xgrid, ygrid, ztrue, cmap=plt.get_cmap('jet'), linewidth=0, antialiased=False)\n", 130 | "# ax.set_xlabel('X0')\n", 131 | "# ax.set_ylabel('X1')\n", 132 | "# ax.set_zlabel(node)" 133 | ] 134 | }, 135 | { 136 | "cell_type": "markdown", 137 | "metadata": {}, 138 | "source": [ 139 | "Here, we create an ABCI instance with the desired experimental design policy." 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": null, 145 | "metadata": { 146 | "pycharm": { 147 | "name": "#%%\n" 148 | } 149 | }, 150 | "outputs": [], 151 | "source": [ 152 | "policy = 'observational'\n", 153 | "abci = ABCI(env, policy)" 154 | ] 155 | }, 156 | { 157 | "cell_type": "markdown", 158 | "metadata": {}, 159 | "source": [ 160 | "We can now run a number of ABCI loops." 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": null, 166 | "metadata": { 167 | "pycharm": { 168 | "name": "#%%\n" 169 | } 170 | }, 171 | "outputs": [], 172 | "source": [ 173 | "num_experiments = 3\n", 174 | "batch_size = 3\n", 175 | "\n", 176 | "abci.run(num_experiments, batch_size, num_initial_obs_samples=3)" 177 | ] 178 | }, 179 | { 180 | "cell_type": "markdown", 181 | "metadata": {}, 182 | "source": [ 183 | "Here, we plot the training stats and results." 
184 | ] 185 | }, 186 | { 187 | "cell_type": "code", 188 | "execution_count": null, 189 | "metadata": { 190 | "pycharm": { 191 | "name": "#%%\n" 192 | } 193 | }, 194 | "outputs": [], 195 | "source": [ 196 | "print(f'Number of observational batches: {len([e for e in abci.experiments if e.interventions == {}])}')\n", 197 | "for node in env.node_labels:\n", 198 | " print(\n", 199 | " f'Number of interventional batches on {node}: {len([e for e in abci.experiments if node in e.interventions])}')\n", 200 | "\n", 201 | "# plot expected SHD over experiments\n", 202 | "ax = plt.figure().gca()\n", 203 | "plt.plot(abci.eshd_list)\n", 204 | "plt.xlabel('Number of Experiments')\n", 205 | "plt.ylabel('Expected SHD')\n", 206 | "ax.xaxis.set_major_locator(MaxNLocator(integer=True))\n", 207 | "\n", 208 | "# plot true graph NLL over experiments\n", 209 | "ax = plt.figure().gca()\n", 210 | "plt.plot(-torch.tensor(abci.graph_ll_list))\n", 211 | "plt.xlabel('Number of Experiments')\n", 212 | "plt.ylabel('True Graph NLL')\n", 213 | "ax.xaxis.set_major_locator(MaxNLocator(integer=True))\n", 214 | "\n", 215 | "# plot graph posterior\n", 216 | "graphs = abci.graph_posterior.sort_by_prob()[0:10]\n", 217 | "probs = [abci.graph_posterior.log_prob(g).exp().detach() for g in graphs]\n", 218 | "graph_keys = [get_graph_key(g) for g in graphs]\n", 219 | "\n", 220 | "plt.figure()\n", 221 | "plt.xticks(rotation=90)\n", 222 | "plt.bar(graph_keys, probs)\n", 223 | "plt.ylabel(r'Graph Posterior, $p(G|D)$')\n", 224 | "\n", 225 | "# plot graph posterior entropy over experiments\n", 226 | "ax = plt.figure().gca()\n", 227 | "plt.plot(abci.graph_entropy_list, label='entropy estimate')\n", 228 | "plt.xlabel('Number of Experiments')\n", 229 | "plt.ylabel('Entropy of Graph Posterior')\n", 230 | "ax.xaxis.set_major_locator(MaxNLocator(integer=True))\n", 231 | "plt.legend()\n", 232 | "\n", 233 | "# plot Query KLD over experiments\n", 234 | "ax = plt.figure().gca()\n", 235 | "plt.plot(abci.query_kld_list)\n", 236 | "plt.xlabel('Number of Experiments')\n", 237 | "plt.ylabel('Query KLD')\n", 238 | "ax.xaxis.set_major_locator(MaxNLocator(integer=True))\n", 239 | "plt.tight_layout()" 240 | ] 241 | }, 242 | { 243 | "cell_type": "markdown", 244 | "metadata": {}, 245 | "source": [ 246 | "Finally, we can have a look at the learned vs. true mechanisms." 247 | ] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "execution_count": null, 252 | "metadata": { 253 | "pycharm": { 254 | "name": "#%%\n" 255 | } 256 | }, 257 | "outputs": [], 258 | "source": [ 259 | "# plot X_i -> X_j true vs. 
predicted\n", 260 | "i = 0\n", 261 | "j = 1\n", 262 | "xdata, ydata = gather_data(abci.experiments, f'X{j}', parents=[f'X{i}'])\n", 263 | "xrange = torch.linspace(xdata.min(), xdata.max(), 100).unsqueeze(1)\n", 264 | "ytrue = env.mechanisms[f'X{j}'](xrange).detach()\n", 265 | "mech = abci.mechanism_model.get_mechanism(f'X{j}', parents=[f'X{i}'])\n", 266 | "mech.set_data(xdata, ydata)\n", 267 | "ypred = mech(xrange).detach()\n", 268 | "\n", 269 | "plt.figure()\n", 270 | "plt.plot(xdata, ydata, 'rx', label='Experimental Data')\n", 271 | "plt.plot(xrange, ytrue, label=f'X{i}->X{j} true')\n", 272 | "plt.plot(xrange, ypred, label=f'X{i}->X{j} prediction')\n", 273 | "plt.xlabel(f'X{i}')\n", 274 | "plt.ylabel(f'X{j}')\n", 275 | "plt.legend()\n", 276 | "plt.tight_layout()" 277 | ] 278 | }, 279 | { 280 | "cell_type": "code", 281 | "execution_count": null, 282 | "metadata": { 283 | "pycharm": { 284 | "name": "#%%\n" 285 | } 286 | }, 287 | "outputs": [], 288 | "source": [ 289 | "# plot bivariate mechanisms\n", 290 | "node = 'X2'\n", 291 | "num_points = 100\n", 292 | "xrange = torch.linspace(-10., 10., num_points)\n", 293 | "yrange = torch.linspace(-10., 10., num_points)\n", 294 | "xgrid, ygrid = torch.meshgrid(xrange, yrange)\n", 295 | "inputs = torch.stack((xgrid, ygrid), dim=2).view(-1, 2)\n", 296 | "ztrue = env.mechanisms[node](inputs).detach().view(num_points, num_points).numpy()\n", 297 | "\n", 298 | "parents = ['X0', 'X1']\n", 299 | "mech = abci.mechanism_model.get_mechanism(node, parents=parents)\n", 300 | "zpred = mech(inputs)\n", 301 | "zpred = zpred.detach().view(num_points, num_points).numpy()\n", 302 | "\n", 303 | "zmin = ztrue.min().item()\n", 304 | "zmax = ztrue.max().item()\n", 305 | "print(f'Function values for {node} in range [{zmin, zmax}].')\n", 306 | "\n", 307 | "levels = torch.linspace(zmin, zmax, 30).numpy()\n", 308 | "fig, axes = plt.subplots(1, 2)\n", 309 | "cp1 = axes[0].contourf(xgrid, ygrid, ztrue, cmap=plt.get_cmap('jet'), levels=levels, vmin=zmin, vmax=zmax)\n", 310 | "cp2 = axes[1].contourf(xgrid, ygrid, zpred, cmap=plt.get_cmap('jet'), levels=levels, vmin=zmin, vmax=zmax)\n", 311 | "\n", 312 | "inputs, targets = gather_data(abci.experiments, node, parents=parents)\n", 313 | "axes[0].plot(inputs[:, 0], inputs[:, 1], 'kx')\n", 314 | "\n", 315 | "axes[0].set_xlabel(parents[0])\n", 316 | "axes[1].set_xlabel(parents[0])\n", 317 | "axes[0].set_ylabel(parents[1])\n", 318 | "_ = fig.colorbar(cp2)" 319 | ] 320 | } 321 | ], 322 | "metadata": { 323 | "kernelspec": { 324 | "display_name": "Python 3 (ipykernel)", 325 | "language": "python", 326 | "name": "python3" 327 | }, 328 | "language_info": { 329 | "codemirror_mode": { 330 | "name": "ipython", 331 | "version": 3 332 | }, 333 | "file_extension": ".py", 334 | "mimetype": "text/x-python", 335 | "name": "python", 336 | "nbconvert_exporter": "python", 337 | "pygments_lexer": "ipython3", 338 | "version": "3.7.10" 339 | } 340 | }, 341 | "nbformat": 4, 342 | "nbformat_minor": 1 343 | } -------------------------------------------------------------------------------- /src/environments/environment.py: -------------------------------------------------------------------------------- 1 | import random 2 | import string 3 | from typing import Dict, List, Set, Optional 4 | 5 | import networkx as nx 6 | 7 | from src.environments.experiment import Experiment, gather_data 8 | from src.models.graph_models import get_parents 9 | from src.models.mechanisms import * 10 | 11 | 12 | class InterventionalDistributionsQuery: 13 | def __init__(self, 
query_nodes: List[str], intervention_targets: Dict[str, dist.Distribution], 14 | sample_queries: List[Experiment] = None): 15 | self.query_nodes = query_nodes 16 | self.intervention_targets = intervention_targets 17 | self.sample_queries = sample_queries 18 | 19 | def sample_intervention(self): 20 | return {target: distr.sample() for target, distr in self.intervention_targets.items()} 21 | 22 | def set_sample_queries(self, sample_queries: List[Experiment]): 23 | self.sample_queries = sample_queries 24 | 25 | def clone(self): 26 | return InterventionalDistributionsQuery(self.query_nodes, self.intervention_targets) 27 | 28 | def param_dict(self): 29 | params = {'query_nodes': self.query_nodes, 30 | 'intervention_targets': self.intervention_targets, 31 | 'sample_queries': self.sample_queries} 32 | return params 33 | 34 | @classmethod 35 | def load_param_dict(cls, param_dict): 36 | return InterventionalDistributionsQuery(param_dict['query_nodes'], param_dict['intervention_targets'], 37 | param_dict['sample_queries']) 38 | 39 | 40 | class Environment: 41 | def __init__(self, num_nodes: int, 42 | mechanism_model: Optional[str] = 'gp-model', 43 | frac_non_intervenable_nodes: float = None, 44 | non_intervenable_nodes: Set = None, 45 | num_test_samples_per_intervention: int = 50, 46 | num_test_queries: int = 30, 47 | interventional_queries: List[InterventionalDistributionsQuery] = None, 48 | graph: nx.DiGraph = None): 49 | 50 | # generate unique env name 51 | seed = ''.join([random.choice(string.ascii_letters + string.digits) for _ in range(8)]) 52 | self.name = self.__class__.__name__ + f'-{num_nodes}-{seed}' 53 | 54 | # construct graph 55 | self.num_nodes = num_nodes 56 | self.graph = self.construct_graph(num_nodes) if graph is None else graph 57 | self.topological_order = list(nx.topological_sort(self.graph)) 58 | self.node_labels = sorted(list(set(self.graph.nodes))) 59 | 60 | # generate mechanisms 61 | self.mechanism_model = mechanism_model 62 | if mechanism_model is not None: 63 | mechanisms = [] 64 | for node in self.node_labels: 65 | parents = get_parents(node, self.graph) 66 | mechanisms.append(self.create_mechanism(len(parents))) 67 | self.mechanisms = dict(zip(self.node_labels, mechanisms)) 68 | else: 69 | self.mechanisms = None 70 | 71 | # optional: restrict intervenable nodes 72 | self.non_intervenable_nodes = set() 73 | if frac_non_intervenable_nodes is not None: 74 | num_non_intervenable_nodes = int(num_nodes * frac_non_intervenable_nodes) 75 | node_idc = torch.randperm(num_nodes)[:num_non_intervenable_nodes] 76 | self.non_intervenable_nodes = set(self.node_labels[i] for i in node_idc) 77 | if non_intervenable_nodes is not None: 78 | self.non_intervenable_nodes |= non_intervenable_nodes 79 | 80 | self.intervenable_nodes = set(self.node_labels) - self.non_intervenable_nodes 81 | 82 | # set intervention bounds for experiment design 83 | self.intervention_bounds = dict(zip(self.node_labels, [(-7., 7.) 
for _ in range(self.num_nodes)])) 84 | 85 | # generate observational/interventional test data 86 | self.observational_test_data = self.interventional_test_data = None 87 | self.num_test_samples_per_intervention = num_test_samples_per_intervention 88 | if num_test_samples_per_intervention > 0: 89 | self.observational_test_data = [self.sample({}, 1, num_test_samples_per_intervention)] 90 | self.interventional_test_data = dict() 91 | for node in self.node_labels: 92 | bounds = self.intervention_bounds[node] 93 | intr_values = torch.rand(num_test_samples_per_intervention) * (bounds[1] - bounds[0]) + bounds[0] 94 | experiments = [self.sample({node: intr_values[i]}, 1) for i in range(num_test_samples_per_intervention)] 95 | self.interventional_test_data.update({node: experiments}) 96 | 97 | # generate query test data 98 | self.num_test_queries = num_test_queries 99 | self.interventional_queries = interventional_queries 100 | self.query_ll = torch.tensor(0.) 101 | 102 | if num_test_queries > 0 and self.interventional_queries is not None: 103 | with torch.no_grad(): 104 | for query in self.interventional_queries: 105 | experiments = [] 106 | for i in range(num_test_queries): 107 | interventions = query.sample_intervention() 108 | experiments.append(self.sample(interventions, 1)) 109 | 110 | query.set_sample_queries(experiments) 111 | 112 | query_lls = torch.zeros(num_test_queries, len(self.interventional_queries)) 113 | for i in range(num_test_queries): 114 | for query_idx, query in enumerate(self.interventional_queries): 115 | query_node = query.query_nodes[0] # ToDo: supports only single query node!!! 116 | targets = query.sample_queries[i].data[query_node].squeeze(-1) 117 | imll = self.interventional_mll(targets, query_node, query.sample_queries[i].interventions) 118 | query_lls[i, query_idx] = imll 119 | 120 | self.query_ll = query_lls.sum(dim=1).mean() 121 | 122 | def create_mechanism(self, num_parents: int): 123 | if self.mechanism_model == 'gp-model': 124 | return GaussianProcess(num_parents, static=True) if num_parents > 0 else GaussianRootNode(static=True) 125 | 126 | assert False, print(f'Invalid mechanism model {self.mechanism_model}!') 127 | 128 | def sample(self, interventions: dict, batch_size: int, num_batches: int = 1) -> Experiment: 129 | data = dict() 130 | for node in self.topological_order: 131 | # check if node is intervened upon 132 | if node in interventions: 133 | samples = torch.ones(num_batches, batch_size, 1) * interventions[node] 134 | else: 135 | mech = self.mechanisms[node] 136 | 137 | # sample from mechanism 138 | parents = get_parents(node, self.graph) 139 | if not parents: 140 | samples = mech.sample(torch.empty(num_batches, batch_size, 1)) 141 | else: 142 | x = torch.cat([data[parent] for parent in parents], dim=-1) 143 | assert x.shape == (num_batches, batch_size, mech.in_size), print(f'Invalid shape {x.shape}!') 144 | samples = mech.sample(x) 145 | 146 | # store samples 147 | data[node] = samples 148 | 149 | return Experiment(interventions, data) 150 | 151 | def log_likelihood(self, experiments: List[Experiment]) -> torch.Tensor: 152 | ll = torch.tensor(0.) 153 | for node in self.node_labels: 154 | # gather data from the experiments 155 | parents = get_parents(node, self.graph) 156 | inputs, targets = gather_data(experiments, node, parents=parents, mode='independent_samples') 157 | 158 | # check if we have any data for this node and compute log-likelihood 159 | mechanism_ll = torch.tensor(0.) 
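            # targets is None when every experiment intervened on this node, so such nodes contribute zero;
            # the mll call below is guarded so that an exception in one mechanism does not abort the whole evaluation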
160 |             if targets is not None:
161 |                 try:
162 |                     mechanism_ll = self.mechanisms[node].mll(inputs, targets, prior_mode=False)
163 |                 except Exception as e:
164 |                     print(f'Exception occurred in Environment.log_likelihood() when computing LL for mechanism {node}:')
165 |                     print(e)
166 | 
167 |             ll += mechanism_ll
168 |         return ll
169 | 
170 |     def interventional_mll(self, targets, node: str, interventions: dict, num_mc_samples=200, reduce=True):
171 |         assert targets.dim() == 2, print(f'Invalid shape {targets.shape}')
172 |         num_batches, batch_size = targets.shape
173 | 
174 |         parents = get_parents(node, self.graph)
175 |         mechanism = self.mechanisms[node]
176 | 
177 |         if len(parents) == 0:
178 |             # if we have a root node, the interventional MLL is simple
179 |             ll = mechanism.mll(None, targets, prior_mode=False, reduce=False)
180 |             assert ll.shape == (num_batches,), print(f'Invalid shape {ll.shape}!')
181 |             return ll.sum() if reduce else ll
182 | 
183 |         # otherwise, do MC estimate via ancestral sampling
184 |         samples = self.sample(interventions, batch_size, num_mc_samples)
185 |         # assemble inputs and targets
186 |         inputs, _ = gather_data([samples], node, parents=parents, mode='independent_batches')
187 |         inputs = inputs.unsqueeze(0).expand(num_batches, -1, -1, -1)
188 |         assert inputs.shape == (num_batches, num_mc_samples, batch_size, len(parents))
189 |         targets = targets.unsqueeze(1).expand(-1, num_mc_samples, batch_size)
190 |         assert targets.shape == (num_batches, num_mc_samples, batch_size)
191 |         # compute interventional ll
192 |         ll = mechanism.mll(inputs, targets, prior_mode=False, reduce=False).squeeze(-1)
193 |         assert ll.shape == (num_batches, num_mc_samples), print(f'Invalid shape: {ll.shape}')
194 |         ll = ll.logsumexp(dim=1) - math.log(num_mc_samples)
195 |         return ll.sum() if reduce else ll
196 | 
197 |     def construct_graph(self, num_nodes: int) -> nx.DiGraph:
198 |         raise NotImplementedError
199 | 
200 |     def param_dict(self):
201 |         mechanism_param_dict = {key: m.param_dict() for key, m in self.mechanisms.items()}
202 |         if self.interventional_queries is not None:
203 |             intr_query_param_dicts = [query.param_dict() for query in self.interventional_queries]
204 |         else:
205 |             intr_query_param_dicts = None
206 |         params = {'num_nodes': self.num_nodes,
207 |                   'mechanism_model': self.mechanism_model,
208 |                   'graph': self.graph,
209 |                   'name': self.name,
210 |                   'mechanism_param_dict': mechanism_param_dict,
211 |                   'non_intervenable_nodes': self.non_intervenable_nodes,
212 |                   'intervention_bounds': self.intervention_bounds,
213 |                   'num_test_samples_per_intervention': self.num_test_samples_per_intervention,
214 |                   'observational_test_data': self.observational_test_data,
215 |                   'interventional_test_data': self.interventional_test_data,
216 |                   'num_test_queries': self.num_test_queries,
217 |                   'intr_query_param_dicts': intr_query_param_dicts,
218 |                   'query_ll': self.query_ll}
219 |         return params
220 | 
221 |     def load_param_dict(self, param_dict):
222 |         self.num_nodes = param_dict['num_nodes']
223 |         self.mechanism_model = param_dict['mechanism_model']
224 |         self.graph = param_dict['graph']
225 |         self.topological_order = list(nx.topological_sort(self.graph))
226 |         self.node_labels = sorted(list(set(self.graph.nodes)))
227 |         self.name = param_dict['name']
228 |         self.non_intervenable_nodes = param_dict['non_intervenable_nodes']
229 |         self.intervenable_nodes = set(self.node_labels) - self.non_intervenable_nodes
230 |         self.intervention_bounds = param_dict['intervention_bounds']
231 |         self.num_test_samples_per_intervention = 
param_dict['num_test_samples_per_intervention'] 232 | self.observational_test_data = param_dict['observational_test_data'] 233 | self.interventional_test_data = param_dict['interventional_test_data'] 234 | self.num_test_queries = param_dict['num_test_queries'] 235 | self.query_ll = param_dict['query_ll'] 236 | self.mechanisms = dict() 237 | for key, d in param_dict['mechanism_param_dict'].items(): 238 | self.mechanisms[key] = self.create_mechanism(d['in_size']) 239 | self.mechanisms[key].load_param_dict(d) 240 | 241 | if param_dict['intr_query_param_dicts'] is not None: 242 | self.interventional_queries = [] 243 | for query_param_dict in param_dict['intr_query_param_dicts']: 244 | self.interventional_queries.append(InterventionalDistributionsQuery.load_param_dict(query_param_dict)) 245 | 246 | def save(self, path): 247 | torch.save(self.param_dict(), path) 248 | 249 | @classmethod 250 | def load(cls, path): 251 | param_dict = torch.load(path) 252 | env = Environment(param_dict['num_nodes'], mechanism_model=None, num_test_samples_per_intervention=0, 253 | num_test_queries=0, graph=param_dict['graph']) 254 | env.load_param_dict(param_dict) 255 | return env 256 | -------------------------------------------------------------------------------- /src/experimental_design/exp_designer_abci_dibs_gp.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Tuple, List, Optional 2 | 3 | import networkx as nx 4 | import torch.optim 5 | 6 | from src.environments.environment import Experiment, InterventionalDistributionsQuery 7 | from src.experimental_design.exp_designer_base import ExpDesignerBase 8 | from src.models.gp_model import GaussianProcessModel 9 | from src.models.graph_models import get_graph_key 10 | 11 | 12 | class ExpDesignerABCIDiBSGP(ExpDesignerBase): 13 | model: Optional[GaussianProcessModel] 14 | 15 | def __init__(self, intervention_bounds: Dict[str, Tuple[float, float]], opt_strategy: str = 'gp-ucb', 16 | distributed=False): 17 | super().__init__(intervention_bounds, opt_strategy, distributed) 18 | self.model = None 19 | 20 | def init_design_process(self, args: dict): 21 | self.model = args['mechanism_model'] 22 | 23 | if args['policy'] == 'scm-info-gain': 24 | def utility(interventions: dict): 25 | self.model.entropy_cache.clear() 26 | return self.scm_info_gain(interventions, 27 | args['inner_mc_graphs'], 28 | args['log_inner_graph_weights'], 29 | args['log_inner_particle_weights'], 30 | args['outer_mc_graphs'], 31 | args['outer_graph_weights'], 32 | args['outer_particle_weights'], 33 | args['batch_size'], 34 | args['num_exp_per_graph']) 35 | elif args['policy'] == 'graph-info-gain': 36 | def utility(interventions: dict): 37 | return self.graph_info_gain(interventions, 38 | args['inner_mc_graphs'], 39 | args['log_inner_graph_weights'], 40 | args['log_inner_particle_weights'], 41 | args['outer_mc_graphs'], 42 | args['outer_graph_weights'], 43 | args['outer_particle_weights'], 44 | args['batch_size'], 45 | args['num_exp_per_graph']) 46 | elif args['policy'] == 'intervention-info-gain': 47 | def utility(interventions: dict): 48 | return self.intervention_info_gain(interventions, 49 | args['experiments'], 50 | args['interventional_queries'], 51 | args['inner_mc_graphs'], 52 | args['log_inner_graph_weights'], 53 | args['log_inner_particle_weights'], 54 | args['outer_mc_graphs'], 55 | args['outer_graph_weights'], 56 | args['outer_particle_weights'], 57 | args['num_mc_queries'], 58 | args['num_batches_per_query'], 59 | 
args['batch_size'], 60 | args['num_exp_per_graph']) 61 | else: 62 | assert False, 'Invalid policy ' + args['policy'] + '!' 63 | 64 | self.utility = utility 65 | 66 | def compute_graph_posterior_mlls(self, experiments: List[Experiment], graphs: List[List[nx.DiGraph]], 67 | mode='independent_batches', reduce=True) -> torch.Tensor: 68 | num_particles = len(graphs) 69 | num_mc_graphs = len(graphs[0]) 70 | graph_mlls = [self.model.mll(experiments, graph, prior_mode=False, use_cache=True, mode=mode, reduce=reduce) 71 | for i in range(num_particles) for graph in graphs[i]] 72 | graph_mlls = torch.stack(graph_mlls).view(num_particles, num_mc_graphs, *graph_mlls[0].shape) 73 | return graph_mlls 74 | 75 | def graph_info_gain(self, interventions: dict, 76 | inner_mc_graphs: List[List[nx.DiGraph]], 77 | log_inner_graph_weights: torch.Tensor, 78 | log_inner_particle_weights: torch.Tensor, 79 | outer_mc_graphs: List[List[nx.DiGraph]], 80 | outer_graph_weights: torch.Tensor, 81 | outer_particle_weights: torch.Tensor, 82 | batch_size: int = 1, 83 | num_exp_per_graph: int = 1): 84 | num_particles = len(outer_mc_graphs) 85 | num_outer_mc_graphs = len(outer_mc_graphs[0]) 86 | 87 | with torch.no_grad(): 88 | graph_info_gains = torch.zeros(num_particles, num_outer_mc_graphs) 89 | for particle_idx, graphs in enumerate(outer_mc_graphs): 90 | for graph_idx, graph in enumerate(graphs): 91 | # check if graph is acyclic 92 | if get_graph_key(graph) not in self.model.topological_orders: 93 | print(f'Worker {self.worker_id}: could not sample from cyclic graph.') 94 | continue 95 | 96 | simulated_experiments = [self.model.sample(interventions, batch_size, num_exp_per_graph, graph)] 97 | 98 | self.model.clear_posterior_mll_cache() 99 | outer_posterior_mll = self.model.mll(simulated_experiments, graph, prior_mode=False, 100 | use_cache=True, mode='independent_batches', reduce=False) 101 | outer_posterior_mll = outer_posterior_mll.sum() 102 | 103 | # compute log p(D|D_old) 104 | inner_posterior_mlls = self.compute_graph_posterior_mlls(simulated_experiments, 105 | inner_mc_graphs, 106 | reduce=False) 107 | particle_posterior_mlls = (inner_posterior_mlls + 108 | log_inner_graph_weights.unsqueeze(-1)).logsumexp(dim=1) 109 | data_mll = (log_inner_particle_weights.unsqueeze(-1) + particle_posterior_mlls).logsumexp(dim=0) 110 | data_mll = data_mll.sum() 111 | 112 | graph_info_gains[particle_idx, graph_idx] += outer_posterior_mll - data_mll 113 | 114 | # compute info gain 115 | graph_info_gains /= num_exp_per_graph 116 | expected_info_gain = outer_particle_weights @ (outer_graph_weights * graph_info_gains).sum(dim=1) 117 | return expected_info_gain 118 | 119 | def scm_info_gain(self, interventions: dict, 120 | inner_mc_graphs: List[List[nx.DiGraph]], 121 | log_inner_graph_weights: torch.Tensor, 122 | log_inner_particle_weights: torch.Tensor, 123 | outer_mc_graphs: List[List[nx.DiGraph]], 124 | outer_graph_weights: torch.Tensor, 125 | outer_particle_weights: torch.Tensor, 126 | batch_size: int = 1, 127 | num_exp_per_graph: int = 1): 128 | num_particles = len(outer_mc_graphs) 129 | num_outer_mc_graphs = len(outer_mc_graphs[0]) 130 | 131 | with torch.no_grad(): 132 | graph_info_gains = torch.zeros(num_particles, num_outer_mc_graphs) 133 | for particle_idx, graphs in enumerate(outer_mc_graphs): 134 | for graph_idx, graph in enumerate(graphs): 135 | # check if graph is acyclic 136 | if get_graph_key(graph) not in self.model.topological_orders: 137 | print(f'Worker {self.worker_id}: could not sample from cyclic graph.') 138 | continue 
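                    # the remainder of the loop body simulates outcomes of the candidate intervention under this
                    # outer MC graph and scores them against the inner posterior to estimate the information gain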
139 | 140 | simulated_experiments = [self.model.sample(interventions, batch_size, num_exp_per_graph, graph)] 141 | 142 | self.model.clear_posterior_mll_cache() 143 | inner_posterior_mlls = self.compute_graph_posterior_mlls(simulated_experiments, 144 | inner_mc_graphs, 145 | reduce=False) 146 | 147 | # compute log p(D|D_old) 148 | particle_posterior_mlls = (inner_posterior_mlls + 149 | log_inner_graph_weights.unsqueeze(-1)).logsumexp(dim=1) 150 | data_mll = (log_inner_particle_weights.unsqueeze(-1) + particle_posterior_mlls).logsumexp(dim=0) 151 | data_mll = data_mll.sum() 152 | 153 | entropy = self.model.expected_noise_entropy(interventions, outer_mc_graphs[particle_idx][graph_idx], 154 | use_cache=True) 155 | graph_info_gains[particle_idx, graph_idx] += -entropy * batch_size - data_mll / num_exp_per_graph 156 | 157 | # compute info gain 158 | expected_info_gain = outer_particle_weights @ (outer_graph_weights * graph_info_gains).sum(dim=1) 159 | return expected_info_gain 160 | 161 | def intervention_info_gain(self, interventions: dict, 162 | experiments: List[Experiment], 163 | interventional_queries: List[InterventionalDistributionsQuery], 164 | inner_mc_graphs: List[List[nx.DiGraph]], 165 | log_inner_graph_weights: torch.Tensor, 166 | log_inner_particle_weights: torch.Tensor, 167 | outer_mc_graphs: List[List[nx.DiGraph]], 168 | outer_graph_weights: torch.Tensor, 169 | outer_particle_weights: torch.Tensor, 170 | num_mc_queries: int = 1, 171 | num_batches_per_query: int = 1, 172 | batch_size: int = 1, 173 | num_exp_per_graph: int = 1): 174 | # get mc graphs and compute log weights 175 | num_particles = len(outer_mc_graphs) 176 | num_outer_graphs = len(outer_mc_graphs[0]) 177 | num_inner_graphs = len(inner_mc_graphs[0]) 178 | 179 | # check if all inner graphs are acyclic and zero the weights of the cyclic ones 180 | for particle_idx in range(num_particles): 181 | for graph_idx, graph in enumerate(inner_mc_graphs[particle_idx]): 182 | if get_graph_key(graph) not in self.model.topological_orders: 183 | print(f'Worker {self.worker_id}: cannot evaluate query ll for cyclic graph.') 184 | log_inner_graph_weights[particle_idx, graph_idx] = torch.tensor(-1e8) 185 | 186 | # compute the info gain 187 | with torch.no_grad(): 188 | graph_info_gains = torch.zeros(num_particles, num_outer_graphs) 189 | for particle_idx in range(num_particles): 190 | for graph_idx, graph in enumerate(outer_mc_graphs[particle_idx]): 191 | # check if graph is acyclic 192 | if get_graph_key(graph) not in self.model.topological_orders: 193 | print(f'Worker {self.worker_id}: could not sample from cyclic graph.') 194 | continue 195 | 196 | # simulate experiments and compute data mll 197 | self.model.set_data(experiments) 198 | self.model.clear_posterior_mll_cache() 199 | simulated_experiments = [self.model.sample(interventions, batch_size, num_exp_per_graph, graph)] 200 | posterior_mlls = self.compute_graph_posterior_mlls(simulated_experiments, inner_mc_graphs, 201 | reduce=False) 202 | assert posterior_mlls.shape == (num_particles, num_inner_graphs, num_exp_per_graph) 203 | 204 | particle_posterior_mlls = (posterior_mlls + log_inner_graph_weights.unsqueeze(-1)).logsumexp(dim=1) 205 | data_mll = (log_inner_particle_weights.unsqueeze(-1) + particle_posterior_mlls).logsumexp(dim=0) 206 | data_mll = data_mll.mean() 207 | 208 | # simulate queries 209 | self.model.set_data(experiments + simulated_experiments) 210 | self.model.clear_posterior_mll_cache() 211 | 212 | # simulate queries & compute query lls 213 | mc_queries = 
self.model.sample_queries(interventional_queries, num_mc_queries,
214 |                                                            num_batches_per_query, graph)
215 | 
216 |                     query_lls = []
217 |                     for i in range(num_particles):
218 |                         for g in inner_mc_graphs[i]:
219 |                             # check if graph is acyclic
220 |                             if get_graph_key(g) in self.model.topological_orders:
221 |                                 query_lls.append(self.model.query_log_probs(mc_queries, g))
222 |                             else:
223 |                                 query_lls.append(torch.tensor(0.))
224 | 
225 |                     query_lls = torch.stack(query_lls).view(num_particles, num_inner_graphs, num_mc_queries,
226 |                                                             num_batches_per_query)
227 | 
228 |                     # compute and update per graph info gain
229 |                     tmp = posterior_mlls.view(num_particles, num_inner_graphs, 1, 1, num_exp_per_graph) + \
230 |                           query_lls.unsqueeze(-1) + \
231 |                           log_inner_graph_weights.view(num_particles, num_inner_graphs, 1, 1, 1)
232 |                     tmp = tmp.logsumexp(dim=1)
233 |                     tmp = (tmp + log_inner_particle_weights.view(num_particles, 1, 1, 1)).logsumexp(dim=0)
234 |                     graph_info_gains[particle_idx, graph_idx] += tmp.mean()
235 | 
236 |                     graph_info_gains[particle_idx, graph_idx] = graph_info_gains[particle_idx, graph_idx] - data_mll
237 | 
238 |         expected_info_gain = outer_particle_weights @ (outer_graph_weights * graph_info_gains).sum(dim=1)
239 |         return expected_info_gain
240 | 
--------------------------------------------------------------------------------
/src/abci_categorical_gp.py:
--------------------------------------------------------------------------------
1 | from typing import Callable
2 | 
3 | import torch.optim
4 | 
5 | from src.abci_base import ABCIBase
6 | from src.environments.environment import *
7 | from src.experimental_design.exp_designer_abci_categorical_gp import ExpDesignerABCICategoricalGP
8 | from src.models.gp_model import GaussianProcessModel
9 | from src.models.graph_models import CategoricalModel, graph_to_adj_mat
10 | from src.utils.metrics import shd, auroc, auprc
11 | 
12 | 
13 | class ABCICategoricalGP(ABCIBase):
14 |     policies = {'observational', 'random', 'random-fixed-value', 'graph-info-gain', 'scm-info-gain',
15 |                 'intervention-info-gain', 'oracle'}
16 | 
17 |     def __init__(self, env: Environment, policy, num_workers: int = 1, linear: bool = False):
18 |         assert policy in self.policies, print(f'Invalid policy {policy}!')
19 |         super().__init__(env, policy, num_workers)
20 | 
21 |         # store params
22 |         self.linear = linear
23 | 
24 |         # init models
25 |         self.graph_prior = CategoricalModel(self.env.node_labels)
26 |         self.graph_posterior = CategoricalModel(self.env.node_labels)
27 |         self.mechanism_model = GaussianProcessModel(env.node_labels, linear=linear)
28 | 
29 |         # init mechanisms for all graphs
30 |         for graph in self.graph_prior.graphs:
31 |             self.mechanism_model.init_mechanisms(graph)
32 | 
33 |     def experiment_designer_factory(self):
34 |         distributed = self.num_workers > 1
35 |         return ExpDesignerABCICategoricalGP(self.env.intervention_bounds, opt_strategy='gp-ucb',
36 |                                             distributed=distributed)
37 | 
38 |     def run(self, num_experiments=10, batch_size=3, update_interval=5, log_interval=1, num_initial_obs_samples=1,
39 |             checkpoint_interval: int = 10, outdir: str = None, job_id: str = ''):
40 | 
41 |         # pre-compute env test data stats
42 |         with torch.no_grad():
43 |             num_test_samples = self.env.num_test_samples_per_intervention
44 |             env_obs_test_ll = self.env.log_likelihood(self.env.observational_test_data) / num_test_samples
45 |             env_intr_test_lls = {}
46 |             for node, experiments in self.env.interventional_test_data.items():
47 |                 env_intr_test_lls[node] = self.env.log_likelihood(experiments) / num_test_samples
48 | 
49 |         # run experiments
50 |         for epoch in 
range(num_experiments): 51 | print(f'Starting experiment cycle {epoch + 1}/{num_experiments}...') 52 | 53 | # pick intervention according to policy 54 | print(f'Design and perform experiment...', flush=True) 55 | info_gain = None 56 | if self.policy == 'observational' or len(self.experiments) == 0 and num_initial_obs_samples > 0: 57 | interventions = {} 58 | elif self.policy == 'random': 59 | interventions = self.get_random_intervention() 60 | elif self.policy == 'random-fixed-value': 61 | interventions = self.get_random_intervention(0.) 62 | elif self.policy == 'oracle': 63 | interventions, info_gain = self.get_oracle_intervention(batch_size) 64 | else: 65 | if self.policy == 'graph-info-gain' or self.policy == 'scm-info-gain': 66 | args = {'mechanism_model': self.mechanism_model, 67 | 'graph_posterior': self.graph_posterior, 68 | 'batch_size': batch_size, 69 | 'num_exp_per_graph': 500, 70 | 'policy': self.policy, 71 | 'mode': 'full'} 72 | elif self.policy == 'intervention-info-gain': 73 | outer_mc_graphs, outer_log_weights = self.graph_posterior.get_mc_graphs('n-best', 5) 74 | args = {'mechanism_model': self.mechanism_model, 75 | 'graph_posterior': self.graph_posterior, 76 | 'experiments': self.experiments, 77 | 'interventional_queries': self.env.interventional_queries, 78 | 'outer_mc_graphs': outer_mc_graphs, 79 | 'outer_log_weights': outer_log_weights, 80 | 'num_mc_queries': 5, 81 | 'num_batches_per_query': 3, 82 | 'batch_size': batch_size, 83 | 'num_exp_per_graph': 50, 84 | 'policy': self.policy, 85 | 'mode': 'full'} 86 | else: 87 | assert False, print(f'Invalid policy {self.policy}!') 88 | 89 | if self.num_workers > 1: 90 | interventions, info_gain = self.design_experiment_distributed(args) 91 | else: 92 | designer = self.experiment_designer_factory() 93 | designer.init_design_process(args) 94 | interventions, info_gain = designer.get_best_experiment(self.env.intervenable_nodes) 95 | 96 | # record expected information gain of chosen intervention 97 | if info_gain is None: 98 | info_gain = torch.tensor(-1.) 
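            # policies that do not compute an expected information gain (e.g., observational or random
            # interventions) are recorded with a sentinel value of -1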
99 | self.info_gain_list.append(info_gain) 100 | 101 | # perform experiment 102 | num_experiments_conducted = len(self.experiments) 103 | num_samples = batch_size 104 | if num_experiments_conducted == 0 and num_initial_obs_samples > 0: 105 | num_samples = num_initial_obs_samples 106 | self.experiments.append(self.env.sample(interventions, num_samples)) 107 | 108 | # set training data for mechanisms 109 | self.mechanism_model.set_data(self.experiments) 110 | 111 | # update mechanism hyperparameters 112 | hyperparam_update_interval = 1 113 | if hyperparam_update_interval > 0: 114 | if num_experiments_conducted <= 1: 115 | update_time = 1 116 | else: 117 | update_time = (num_experiments_conducted // hyperparam_update_interval) * hyperparam_update_interval 118 | 119 | self.mechanism_model.update_gp_hyperparameters(update_time, self.experiments, set_data=True) 120 | 121 | # update graph posterior 122 | if not self.policy == 'oracle': 123 | print(f'Updating graph posterior...', flush=True) 124 | self.graph_posterior = self.compute_graph_posterior(self.experiments, use_cache=True) 125 | 126 | print(f'Logging evaluation stats...', flush=True) 127 | 128 | # record graph posterior entropy 129 | self.graph_entropy_list.append(self.graph_posterior.entropy().detach()) 130 | 131 | # record expected SHD 132 | with torch.no_grad(): 133 | eshd = self.graph_posterior_expectation(lambda g: shd(self.env.graph, g)) 134 | self.eshd_list.append(eshd) 135 | 136 | # record env graph LL 137 | self.graph_ll_list.append(self.graph_posterior.log_prob(self.env.graph)) 138 | 139 | # record observational test LLs 140 | def test_ll(graph): 141 | return self.mechanism_model.mll(self.env.observational_test_data, graph, prior_mode=False, 142 | use_cache=True, mode='joint') 143 | 144 | self.mechanism_model.clear_posterior_mll_cache() 145 | with torch.no_grad(): 146 | ll = self.graph_posterior_expectation(test_ll) 147 | self.observational_test_ll_list.append(ll) 148 | 149 | # record interventional test LLs 150 | for node, experiments in self.env.interventional_test_data.items(): 151 | def test_ll(graph): 152 | return self.mechanism_model.mll(experiments, graph, prior_mode=False, use_cache=True, mode='joint') 153 | 154 | self.mechanism_model.clear_posterior_mll_cache() 155 | with torch.no_grad(): 156 | ll = self.graph_posterior_expectation(test_ll) 157 | self.interventional_test_ll_lists[node].append(ll) 158 | 159 | # record observational KLD 160 | def test_ll(graph): 161 | return self.mechanism_model.mll(self.env.observational_test_data, graph, prior_mode=False, 162 | use_cache=True, mode='independent_samples', reduce=False) 163 | 164 | self.mechanism_model.clear_posterior_mll_cache() 165 | with torch.no_grad(): 166 | ll = self.graph_posterior_expectation(test_ll, logspace=True).mean() 167 | self.observational_kld_list.append(env_obs_test_ll - ll) 168 | 169 | # record interventional test KLDs 170 | for node, experiments in self.env.interventional_test_data.items(): 171 | def test_ll(graph): 172 | return self.mechanism_model.mll(experiments, graph, prior_mode=False, use_cache=True, 173 | mode='independent_samples', reduce=False) 174 | 175 | self.mechanism_model.clear_posterior_mll_cache() 176 | with torch.no_grad(): 177 | ll = self.graph_posterior_expectation(test_ll, logspace=True).mean() 178 | self.interventional_kld_lists[node].append(env_intr_test_lls[node] - ll) 179 | 180 | # record AUROC/AUPRC scores 181 | with torch.no_grad(): 182 | posterior_edge_probs = self.graph_posterior.edge_probs() 183 | true_adj_mat = 
graph_to_adj_mat(self.env.graph, self.env.node_labels) 184 | self.auroc_list.append(auroc(posterior_edge_probs, true_adj_mat)) 185 | self.auprc_list.append(auprc(posterior_edge_probs, true_adj_mat)) 186 | 187 | # record query KLD 188 | if self.env.interventional_queries is not None: 189 | def test_query_ll(graph): 190 | query_lls = self.mechanism_model.query_log_probs(self.env.interventional_queries, graph, 200) 191 | return query_lls 192 | 193 | self.mechanism_model.clear_posterior_mll_cache() 194 | with torch.no_grad(): 195 | ll = self.graph_posterior_expectation(test_query_ll, logspace=True).mean() 196 | self.query_kld_list.append(self.env.query_ll - ll) 197 | 198 | if outdir is not None and 0 < epoch < num_experiments - 1 and (epoch + 1) % checkpoint_interval == 0: 199 | model = 'abci-categorical-gp-linear' if self.linear else 'abci-categorical-gp' 200 | outpath = outdir + model + '-' + self.policy + '-' + self.env.name + f'-{job_id}-exp-{epoch + 1}.pth' 201 | self.save(outpath) 202 | 203 | if log_interval > 0 and epoch % log_interval == 0: 204 | print(f'Experiment {epoch + 1}/{num_experiments}, ESHD is {eshd.item()}', flush=True) 205 | 206 | def get_oracle_intervention(self, num_samples: int, num_candidates_per_node: int = 10): 207 | current_entropy = self.graph_posterior.entropy() 208 | 209 | self.experiments.append(self.env.sample({}, num_samples)) 210 | posterior = self.compute_graph_posterior(self.experiments) 211 | best_intervention = {} 212 | best_info_gain = current_entropy - posterior.entropy() 213 | best_posterior = posterior 214 | 215 | for node in self.env.node_labels: 216 | bounds = self.env.intervention_bounds[node] 217 | candidates = torch.linspace(bounds[0], bounds[1], num_candidates_per_node) 218 | for i in range(num_candidates_per_node): 219 | self.experiments[-1] = self.env.sample({node: candidates[i]}, num_samples) 220 | posterior = self.compute_graph_posterior(self.experiments) 221 | info_gain = current_entropy - posterior.entropy() 222 | if info_gain > best_info_gain: 223 | best_info_gain = info_gain 224 | best_intervention = {node: candidates[i]} 225 | best_posterior = posterior 226 | 227 | self.graph_posterior = best_posterior 228 | return best_intervention, best_info_gain 229 | 230 | def compute_graph_posterior(self, experiments: List[Experiment], use_cache: bool = False) -> CategoricalModel: 231 | posterior = CategoricalModel(self.env.node_labels) 232 | self.mechanism_model.clear_prior_mll_cache() 233 | with torch.no_grad(): 234 | for graph in posterior.graphs: 235 | mll = self.mechanism_model.mll(experiments, graph, prior_mode=True, use_cache=use_cache).squeeze() 236 | posterior.set_log_prob(mll + self.graph_prior.log_prob(graph), graph) 237 | 238 | posterior.normalize() 239 | return posterior 240 | 241 | def graph_posterior_expectation(self, func: Callable[[nx.DiGraph], torch.Tensor], logspace=False): 242 | with torch.no_grad(): 243 | # compute function values 244 | func_values = [func(graph) for graph in self.graph_posterior.graphs] 245 | func_output_shape = func_values[0].shape 246 | func_output_dim = len(func_output_shape) 247 | func_values = torch.stack(func_values).view(self.graph_posterior.num_graphs, *func_output_shape) 248 | 249 | # compute expectation 250 | if logspace: 251 | log_graph_weights = torch.tensor([self.graph_posterior.log_prob(graph) for graph in 252 | self.graph_posterior.graphs]) 253 | log_graph_weights = log_graph_weights.view(self.graph_posterior.num_graphs, *([1] * func_output_dim)) 254 | 255 | expected_value = (log_graph_weights + 
func_values).logsumexp(dim=0) 256 | return expected_value 257 | 258 | graph_weights = torch.tensor([self.graph_posterior.log_prob(graph).exp() for graph in 259 | self.graph_posterior.graphs]) 260 | graph_weights = graph_weights.view(self.graph_posterior.num_graphs, *([1] * func_output_dim)) 261 | 262 | expected_value = (graph_weights * func_values).sum(dim=0) 263 | return expected_value 264 | 265 | def param_dict(self): 266 | params = super().param_dict() 267 | params.update({'linear': self.linear, 268 | 'mechanism_model_params': self.mechanism_model.param_dict(), 269 | 'graph_prior_params': self.graph_prior.param_dict(), 270 | 'graph_posterior_params': self.graph_posterior.param_dict()}) 271 | return params 272 | 273 | def load_param_dict(self, param_dict): 274 | super().load_param_dict(param_dict) 275 | self.linear = param_dict['linear'] 276 | self.mechanism_model.load_param_dict(param_dict['mechanism_model_params']) 277 | self.graph_prior.load_param_dict(param_dict['graph_prior_params']) 278 | self.graph_posterior.load_param_dict(param_dict['graph_posterior_params']) 279 | 280 | def save(self, path): 281 | torch.save(self.param_dict(), path) 282 | 283 | @classmethod 284 | def load(cls, path, num_workers: int = 1): 285 | param_dict = torch.load(path) 286 | 287 | env_param_dict = param_dict['env_param_dict'] 288 | env = Environment(env_param_dict['num_nodes'], mechanism_model=None, num_test_samples_per_intervention=0, 289 | num_test_queries=0, graph=env_param_dict['graph']) 290 | env.load_param_dict(env_param_dict) 291 | 292 | abci = ABCICategoricalGP(env, param_dict['policy'], num_workers, param_dict['linear']) 293 | abci.load_param_dict(param_dict) 294 | return abci 295 | -------------------------------------------------------------------------------- /src/models/mechanisms.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import gpytorch 4 | import torch 5 | import torch.distributions as dist 6 | from torch.nn import Module 7 | 8 | 9 | class Mechanism(Module): 10 | """ 11 | Class that represents a generic mechanism including a likelihood/noise model in an SCM. 12 | 13 | Attributes 14 | ---------- 15 | in_size : int 16 | Number of mechanism inputs. 17 | """ 18 | 19 | def __init__(self, in_size: int): 20 | """ 21 | Parameters 22 | ---------- 23 | in_size : int 24 | Number of mechanism inputs. 25 | """ 26 | super().__init__() 27 | self.in_size = in_size 28 | 29 | def _check_args(self, inputs: torch.Tensor = None, targets: torch.Tensor = None): 30 | """ 31 | Checks the generic argument shapes (inputs and targets) and their compatibility. 32 | 33 | Parameters 34 | ---------- 35 | inputs : torch.Tensor 36 | Mechanism inputs. 37 | targets : torch.Tensor 38 | Mechanism targets, e.g., for evaluating the marginal log-likelihood. 39 | """ 40 | if inputs is not None: 41 | assert inputs.dim() >= 2 and inputs.shape[-1] == self.in_size, print( 42 | f'Ill-shaped inputs: {inputs.shape}') 43 | if targets is not None: 44 | assert targets.dim() >= 1, print(f'Ill-shaped targets: {targets.shape}') 45 | if targets is not None and inputs is not None: 46 | assert inputs.shape[:-1] == targets.shape, print(f'Batch size mismatch: {inputs.shape} vs.' 47 | f' {targets.shape}') 48 | 49 | def forward(self, inputs: torch.Tensor, prior_mode: bool = False): 50 | """ 51 | Computes the mechanism output for a given input tensor. Must be implemented by all child classes. 52 | 53 | Parameters 54 | ---------- 55 | inputs : torch.Tensor 56 | Mechanism inputs. 
57 | prior_mode : bool 58 | Whether to evaluate the mechanism with prior or posterior parameters. 59 | """ 60 | raise NotImplementedError 61 | 62 | def sample(self, inputs: torch.Tensor, prior_mode: bool = False): 63 | """ 64 | Generates samples a given input tensor according to the implemented likelihood model. Must be implemented by 65 | all child classes. 66 | 67 | Parameters 68 | ---------- 69 | inputs : torch.Tensor 70 | Mechanism inputs. 71 | prior_mode : bool 72 | Whether to evaluate the mechanism with prior or posterior parameters. 73 | """ 74 | raise NotImplementedError 75 | 76 | 77 | class GaussianRootNode(Mechanism): 78 | def __init__(self, mu_0: float = 0., kappa_0: float = 0.1, alpha_0: float = 50., beta_0: float = 25., 79 | static=False): 80 | super().__init__(in_size=0) 81 | 82 | # init prior and posterior hyper-parameters 83 | self.mu_0 = self.mu_n = torch.tensor(mu_0) 84 | self.kappa_0 = self.kappa_n = torch.tensor(kappa_0) 85 | self.alpha_0 = self.alpha_n = torch.tensor(alpha_0) 86 | self.beta_0 = self.beta_n = torch.tensor(beta_0) 87 | self.lam_0 = None 88 | self.train_targets = None 89 | 90 | self.static = static 91 | if static: 92 | self.init_as_static() 93 | 94 | def compute_posterior_params(self, targets: torch.Tensor, prior_mode=False): 95 | self._check_args(targets=targets) 96 | 97 | full_targets = targets 98 | if not prior_mode and self.train_targets is not None: 99 | full_targets = torch.cat((targets, self.train_targets.expand(*targets.shape[:-1], -1)), dim=-1) 100 | 101 | n = full_targets.shape[-1] 102 | empirical_means = full_targets.mean(dim=-1) 103 | 104 | kappa_n = self.kappa_0 + n 105 | mu_n = (self.kappa_0 * self.mu_0 + n * empirical_means) / kappa_n 106 | alpha_n = self.alpha_0 + 0.5 * n 107 | beta_n = self.beta_0 + 0.5 * (full_targets - empirical_means.unsqueeze(-1)).pow_(2).sum(dim=-1) + \ 108 | 0.5 * self.kappa_0 * n * (empirical_means - self.mu_0).pow_(2) / kappa_n 109 | 110 | return mu_n, kappa_n.expand(mu_n.shape), alpha_n.expand(mu_n.shape), beta_n 111 | 112 | def init_as_static(self): 113 | self.lam_0 = dist.Gamma(self.alpha_0, self.beta_0).sample() 114 | self.mu_0 = dist.Normal(0., (self.kappa_0 * self.lam_0).pow(-0.5)).sample() 115 | 116 | def set_data(self, inputs: torch.Tensor, targets: torch.Tensor): 117 | self._check_args(targets=targets) 118 | assert targets.dim() == 1, print('Can only work with one set of posterior params!') 119 | self.train_targets = targets 120 | self.mu_n, self.kappa_n, self.alpha_n, self.beta_n = self.compute_posterior_params(targets, prior_mode=True) 121 | 122 | def forward(self, inputs: torch.Tensor, prior_mode=False): 123 | assert inputs.dim() >= 2 124 | output_shape = (*inputs.shape[:-1], 1) 125 | 126 | if self.static or prior_mode: 127 | return self.mu_0 * torch.ones(output_shape) 128 | 129 | return self.mu_n * torch.ones(output_shape) 130 | 131 | def sample(self, inputs: torch.Tensor, prior_mode=False): 132 | assert inputs.dim() >= 2 133 | output_shape = (*inputs.shape[:-1], 1) 134 | 135 | if self.static: 136 | # sample from true distribution 137 | y_dist = dist.Normal(self.mu_0, self.lam_0.pow(-0.5)) 138 | return y_dist.sample(torch.Size(output_shape)) 139 | 140 | # sample from marginal likelihood 141 | if prior_mode: 142 | mu_n, kappa_n, alpha_n, beta_n = (self.mu_0, self.kappa_0, self.alpha_0, self.beta_0) 143 | else: 144 | mu_n, kappa_n, alpha_n, beta_n = (self.mu_n, self.kappa_n, self.alpha_n, self.beta_n) 145 | 146 | lambdas = dist.Gamma(alpha_n, beta_n).sample(output_shape[:-2]) 147 | mus = 
dist.Normal(mu_n.expand_as(lambdas), (kappa_n * lambdas).pow(-0.5)).sample() 148 | 149 | y_dist = dist.Normal(mus, lambdas.pow(-0.5)) 150 | samples = y_dist.sample(output_shape[-2:-1]).unsqueeze(-1).transpose(0, -1).view(output_shape) 151 | return samples 152 | 153 | def mll(self, inputs: torch.Tensor, targets: torch.Tensor, prior_mode=False, reduce=True): 154 | self._check_args(targets=targets) 155 | output_shape = targets.shape[:-1] 156 | if self.static: 157 | # evaluate true log-likelihood 158 | y_dist = dist.Normal(self.mu_0, self.lam_0.pow(-0.5)) 159 | lls = y_dist.log_prob(targets).squeeze(-1) 160 | else: 161 | if prior_mode: 162 | kappa_n, alpha_n, beta_n = (self.kappa_0, self.alpha_0, self.beta_0) 163 | else: 164 | kappa_n, alpha_n, beta_n = (self.kappa_n, self.alpha_n, self.beta_n) 165 | 166 | _, kappa_m, alpha_m, beta_m = self.compute_posterior_params(targets, prior_mode) 167 | lls = torch.lgamma(alpha_m) - torch.lgamma(alpha_n) + alpha_n * beta_n.log() - alpha_m * beta_m.log() + \ 168 | 0.5 * (kappa_n.log() - kappa_m.log()) - 0.5 * targets.shape[-1] * math.log(2. * math.pi) 169 | 170 | assert lls.shape == output_shape, print(lls.shape) 171 | if reduce: 172 | return lls.sum() 173 | return lls 174 | 175 | def expected_noise_entropy(self, prior_mode: bool = False) -> torch.Tensor: 176 | if self.static: 177 | return dist.Normal(self.mu_0, self.lam_0.pow(-0.5)).entropy() 178 | 179 | # expected noise entropy exact 180 | alpha, beta = (self.alpha_0, self.beta_0) if prior_mode else (self.alpha_n, self.beta_n) 181 | return 0.5 * (math.log(2. * math.pi * math.e) - torch.digamma(alpha) + beta.log()) 182 | 183 | # expected noise entropy point estimate (mean variance of inverse gamma posterior) 184 | # return 0.5 * (2. * math.pi * beta/(alpha + 1.) * math.e).log().squeeze() 185 | 186 | def param_dict(self): 187 | params = {'in_size': 0, 188 | 'mu_0': self.mu_0, 189 | 'kappa_0': self.kappa_0, 190 | 'alpha_0': self.alpha_0, 191 | 'beta_0': self.beta_0, 192 | 'lam_0': self.lam_0, 193 | 'mu_n': self.mu_n, 194 | 'kappa_n': self.kappa_n, 195 | 'alpha_n': self.alpha_n, 196 | 'beta_n': self.beta_n, 197 | 'static': self.static} 198 | 199 | return params 200 | 201 | def load_param_dict(self, param_dict): 202 | self.mu_0 = param_dict['mu_0'] 203 | self.kappa_0 = param_dict['kappa_0'] 204 | self.alpha_0 = param_dict['alpha_0'] 205 | self.beta_0 = param_dict['beta_0'] 206 | self.lam_0 = param_dict['lam_0'] 207 | self.mu_n = param_dict['mu_n'] 208 | self.kappa_n = param_dict['kappa_n'] 209 | self.alpha_n = param_dict['alpha_n'] 210 | self.beta_n = param_dict['beta_n'] 211 | self.static = param_dict['static'] 212 | 213 | 214 | class GaussianProcess(Mechanism): 215 | class ExactGPModelRQKernel(gpytorch.models.ExactGP): 216 | # ATTENTION: do not name the HP priors "noise_prior", "outputscale_prior" or "lengthscale_prior" 217 | noise_var_prior = dist.Gamma(50., 500.) 218 | outscale_prior = dist.Gamma(100., 10.) 219 | 220 | def __init__(self, train_x, train_y, likelihood, in_size: int): 221 | super().__init__(train_x, train_y, likelihood) 222 | self.mean_module = gpytorch.means.ZeroMean() 223 | self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RQKernel()) 224 | 225 | # draw kernel parameters 226 | self.lscale_prior = dist.Gamma(30. * in_size, 30.) 
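            # Quick numerical check of the hyperparameter priors above (illustrative only): their
            # means are exactly the point estimates plugged in by select_hyperparameters() when
            # prior_mode=True (cf. the commented "true kernel parameters for testing" further below).
            #
            #   >>> import torch.distributions as dist
            #   >>> dist.Gamma(50., 500.).mean        # noise variance prior mean
            #   tensor(0.1000)
            #   >>> dist.Gamma(100., 10.).mean        # output scale prior mean
            #   tensor(10.)
            #   >>> dist.Gamma(30. * 2, 30.).mean     # lengthscale prior mean for in_size = 2
            #   tensor(2.)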
227 | self.posterior_noise, self.posterior_outputscale, self.posterior_lengthscale = self.draw_prior_hyperparams() 228 | 229 | def forward(self, x): 230 | mean = self.mean_module(x) 231 | covar = self.covar_module(x) 232 | return gpytorch.distributions.MultivariateNormal(mean, covar) 233 | 234 | def hyperparam_log_prior(self, prior_mode: bool = False): 235 | self.select_hyperparameters(prior_mode) 236 | return self.noise_var_prior.log_prob(self.likelihood.noise) + \ 237 | self.outscale_prior.log_prob(self.covar_module.outputscale) + \ 238 | self.lscale_prior.log_prob(self.covar_module.base_kernel.lengthscale) 239 | 240 | def draw_prior_hyperparams(self): 241 | return self.noise_var_prior.sample(), self.outscale_prior.sample(), self.lscale_prior.sample() 242 | 243 | def select_hyperparameters(self, prior_mode: bool = False): 244 | if prior_mode: 245 | # self.likelihood.noise, self.covar_module.outputscale, self.covar_module.base_kernel.lengthscale = \ 246 | # self.draw_prior_hyperparams() 247 | self.likelihood.noise = self.noise_var_prior.mean 248 | self.covar_module.outputscale = self.outscale_prior.mean 249 | self.covar_module.base_kernel.lengthscale = self.lscale_prior.mean 250 | else: 251 | self.likelihood.noise = self.posterior_noise 252 | self.covar_module.outputscale = self.posterior_outputscale 253 | self.covar_module.base_kernel.lengthscale = self.posterior_lengthscale 254 | 255 | def param_dict(self): 256 | params = {'posterior_noise': self.posterior_noise, 257 | 'posterior_outputscale': self.posterior_outputscale, 258 | 'posterior_lengthscale': self.posterior_lengthscale} 259 | return params 260 | 261 | def load_param_dict(self, param_dict): 262 | self.posterior_noise = param_dict['posterior_noise'] 263 | self.posterior_outputscale = param_dict['posterior_outputscale'] 264 | self.posterior_lengthscale = param_dict['posterior_lengthscale'] 265 | 266 | class ExactGPModelLinearKernel(gpytorch.models.ExactGP): 267 | # ATTENTION: do not name the HP priors "noise_prior", "outputscale_prior" or "lengthscale_prior" 268 | noise_var_prior = dist.Gamma(50., 500.) 269 | outscale_prior = dist.Gamma(100., 10.) 
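        # This linear-kernel variant mirrors ExactGPModelRQKernel above and effectively turns each
        # mechanism into Bayesian linear regression: gpytorch's LinearKernel (scale parameter
        # `variance`) yields the Gram matrix k(x, x') = variance * <x, x'>. Minimal sanity check
        # with plain torch (illustrative only, not used by the model):
        #
        #   >>> import torch
        #   >>> x = torch.tensor([[1., 2.], [3., 4.]])
        #   >>> 10. * x @ x.T                         # Gram matrix for variance = 10
        #   tensor([[ 50., 110.],
        #           [110., 250.]])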
270 | 
271 |         def __init__(self, train_x, train_y, likelihood):
272 |             super().__init__(train_x, train_y, likelihood)
273 |             self.mean_module = gpytorch.means.ZeroMean()
274 |             self.covar_module = gpytorch.kernels.LinearKernel()
275 | 
276 |             # draw kernel parameters
277 |             self.posterior_noise, self.posterior_outputscale = self.draw_prior_hyperparams()
278 | 
279 |         def forward(self, x):
280 |             mean = self.mean_module(x)
281 |             covar = self.covar_module(x)
282 |             return gpytorch.distributions.MultivariateNormal(mean, covar)
283 | 
284 |         def hyperparam_log_prior(self, prior_mode: bool = False):
285 |             self.select_hyperparameters(prior_mode)
286 |             return self.noise_var_prior.log_prob(self.likelihood.noise) + \
287 |                    self.outscale_prior.log_prob(self.covar_module.variance)
288 | 
289 |         def draw_prior_hyperparams(self):
290 |             return self.noise_var_prior.sample(), self.outscale_prior.sample()
291 | 
292 |         def select_hyperparameters(self, prior_mode: bool = False):
293 |             if prior_mode:
294 |                 # self.likelihood.noise, self.covar_module.variance = \
295 |                 #     self.draw_prior_hyperparams()
296 |                 self.likelihood.noise = self.noise_var_prior.mean
297 |                 self.covar_module.variance = self.outscale_prior.mean
298 |             else:
299 |                 self.likelihood.noise = self.posterior_noise
300 |                 self.covar_module.variance = self.posterior_outputscale
301 | 
302 |         def param_dict(self):
303 |             params = {'posterior_noise': self.posterior_noise,
304 |                       'posterior_outputscale': self.posterior_outputscale}
305 |             return params
306 | 
307 |         def load_param_dict(self, param_dict):
308 |             self.posterior_noise = param_dict['posterior_noise']
309 |             self.posterior_outputscale = param_dict['posterior_outputscale']
310 | 
311 |     def __init__(self, in_size: int, static=False, linear=False):
312 |         super().__init__(in_size)
313 | 
314 |         # initialize likelihood and gp model
315 |         likelihood = gpytorch.likelihoods.GaussianLikelihood()
316 |         if linear:
317 |             self.gp = GaussianProcess.ExactGPModelLinearKernel(None, None, likelihood)
318 |         else:
319 |             self.gp = GaussianProcess.ExactGPModelRQKernel(None, None, likelihood, in_size)
320 |         self.static = static
321 |         self.linear = linear
322 | 
323 |         # can set true kernel parameters for testing
324 |         # self.gp.likelihood.noise = 0.1
325 |         # self.gp.covar_module.outputscale = 10.
326 |         # self.gp.covar_module.base_kernel.lengthscale = self.in_size
327 | 
328 |         if static:
329 |             self.init_as_static()
330 | 
331 |     def init_as_static(self):
332 |         # generate support points and sample training targets from the GP prior
333 |         num_train = 50 * self.in_size
334 |         train_x = 20.
* (torch.rand((num_train, self.in_size)) - 0.5) 335 | self.eval() 336 | with gpytorch.settings.prior_mode(True): 337 | self.gp.select_hyperparameters(prior_mode=False) 338 | f_dist = self.gp(train_x) 339 | y_dist = self.gp.likelihood(f_dist) 340 | train_y = y_dist.sample().detach() 341 | 342 | # update GP data 343 | self.set_data(train_x, train_y) 344 | 345 | def set_data(self, inputs: torch.Tensor, targets: torch.Tensor): 346 | self._check_args(inputs, targets) 347 | self.gp.set_train_data(inputs, targets, strict=False) 348 | 349 | def forward(self, inputs: torch.Tensor, prior_mode=False): 350 | self._check_args(inputs) 351 | output_shape = (*inputs.shape[:-1], 1) 352 | 353 | self.eval() 354 | with gpytorch.settings.prior_mode(prior_mode): 355 | self.gp.select_hyperparameters(prior_mode) 356 | f_dist = self.gp(inputs) 357 | return f_dist.mean.view(output_shape) 358 | 359 | def sample(self, inputs: torch.Tensor, prior_mode=False): 360 | self._check_args(inputs) 361 | output_shape = (*inputs.shape[:-1], 1) 362 | 363 | self.eval() 364 | with gpytorch.settings.prior_mode(prior_mode): 365 | self.gp.select_hyperparameters(prior_mode) 366 | f_dist = self.gp(inputs) 367 | y_dist = self.gp.likelihood(f_dist.mean) if self.static else self.gp.likelihood(f_dist) 368 | return y_dist.sample().view(output_shape) 369 | 370 | def mll(self, inputs: torch.Tensor, targets: torch.Tensor, prior_mode=False, reduce=True): 371 | self._check_args(inputs, targets) 372 | output_shape = targets.shape[:-1] 373 | with gpytorch.settings.prior_mode(prior_mode): 374 | self.gp.select_hyperparameters(prior_mode) 375 | f_dist = self.gp(inputs) 376 | 377 | if self.static: 378 | y_dist = self.gp.likelihood(f_dist.mean) 379 | mlls = y_dist.log_prob(targets).squeeze(-1) 380 | else: 381 | y_dist = self.gp.likelihood(f_dist) 382 | mlls = y_dist.log_prob(targets) 383 | assert mlls.shape == output_shape, print(f'Invalid shape {mlls.shape}!') 384 | 385 | if reduce: 386 | return mlls.sum() 387 | return mlls 388 | 389 | def expected_noise_entropy(self, prior_mode: bool = False) -> torch.Tensor: 390 | # use point estimate with the MAP variance 391 | self.gp.select_hyperparameters(prior_mode) 392 | return 0.5 * (2. 
* math.pi * self.gp.likelihood.noise * math.e).log().squeeze() 393 | 394 | def param_dict(self): 395 | gp_param_dict = self.gp.param_dict() 396 | params = {'in_size': self.in_size, 397 | 'static': self.static, 398 | 'linear': self.linear, 399 | 'gp_param_dict': gp_param_dict} 400 | 401 | if self.static: 402 | params['train_inputs'] = self.gp.train_inputs 403 | params['train_targets'] = self.gp.train_targets 404 | return params 405 | 406 | def load_param_dict(self, param_dict): 407 | self.in_size = param_dict['in_size'] 408 | self.static = param_dict['static'] 409 | self.linear = param_dict['linear'] 410 | self.gp.load_param_dict(param_dict['gp_param_dict']) 411 | 412 | if self.static: 413 | self.gp.train_inputs = param_dict['train_inputs'] 414 | self.gp.train_targets = param_dict['train_targets'] 415 | -------------------------------------------------------------------------------- /src/models/gp_model.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import math 3 | from typing import Tuple, Optional, List 4 | 5 | import networkx as nx 6 | import torch 7 | 8 | from src.environments.environment import Experiment, gather_data, InterventionalDistributionsQuery 9 | from src.models.graph_models import get_graph_key, get_parents 10 | from src.models.mechanisms import Mechanism, GaussianProcess, GaussianRootNode 11 | 12 | 13 | def get_mechanism_key(node, parents: List) -> str: 14 | parents = sorted(parents) 15 | key = str(node) + '<-' + ','.join([str(parent) for parent in parents]) 16 | return key 17 | 18 | 19 | def resolve_mechanism_key(key: str) -> Tuple[str, List[str]]: 20 | idx = key.find('<-') 21 | assert idx > 0, print('Invalid key: ' + key) 22 | node = key[:idx] 23 | parents = key[idx + 2:].split(',') if len(key) > idx + 2 else [] 24 | return node, parents 25 | 26 | 27 | class GaussianProcessModel: 28 | def __init__(self, node_labels: List[str], linear: bool = False): 29 | self.node_labels = sorted(list(set(node_labels))) 30 | self.linear = linear 31 | self.mechanisms = dict() 32 | self.mechanism_init_times = dict() 33 | self.mechanism_update_times = dict() 34 | self.topological_orders = dict() 35 | self.topological_orders_init_times = dict() 36 | self.prior_mll_cache = dict() 37 | self.posterior_mll_cache = dict() 38 | self.entropy_cache = dict() 39 | 40 | def get_parameters(self, keys: List[str] = None): 41 | keys = self.mechanisms.keys() if keys is None else keys 42 | param_lists = [list(self.mechanisms[key].parameters()) for key in keys] 43 | return list(itertools.chain(*param_lists)) 44 | 45 | def init_mechanisms(self, graph: nx.DiGraph, init_time: int = 0): 46 | graph_key = get_graph_key(graph) 47 | if graph_key not in self.topological_orders and nx.is_directed_acyclic_graph(graph): 48 | self.topological_orders[graph_key] = list(nx.topological_sort(graph)) 49 | if graph_key in self.topological_orders: 50 | self.topological_orders_init_times[graph_key] = init_time 51 | 52 | initialized_mechanisms = [] 53 | for node in graph: 54 | parents = get_parents(node, graph) 55 | key = get_mechanism_key(node, parents) 56 | initialized_mechanisms.append(key) 57 | self.mechanism_init_times[key] = init_time 58 | if key not in self.mechanisms: 59 | self.mechanism_update_times[key] = 0 60 | self.mechanisms[key] = self.create_mechanism(len(parents)) 61 | 62 | return initialized_mechanisms 63 | 64 | def discard_mechanisms(self, current_time: int, max_age: int): 65 | keys = [key for key, time in self.mechanism_init_times.items() if 
current_time - time > max_age] 66 | print(f'Discarding {len(keys)} old mechanisms...') 67 | for key in keys: 68 | del self.mechanisms[key] 69 | del self.mechanism_init_times[key] 70 | del self.mechanism_update_times[key] 71 | 72 | keys = [key for key, time in self.topological_orders_init_times.items() if current_time - time > max_age] 73 | print(f'Discarding {len(keys)} old topological orders...') 74 | for key in keys: 75 | del self.topological_orders[key] 76 | del self.topological_orders_init_times[key] 77 | 78 | def eval(self, keys: List[str] = None): 79 | keys = self.mechanisms.keys() if keys is None else keys 80 | for key in keys: 81 | self.mechanisms[key].eval() 82 | 83 | def train(self, keys: List[str] = None): 84 | keys = self.mechanisms.keys() if keys is None else keys 85 | for key in keys: 86 | self.mechanisms[key].train() 87 | 88 | def clear_prior_mll_cache(self, keys: List[str] = None): 89 | if keys is None: 90 | self.prior_mll_cache.clear() 91 | else: 92 | for key in keys: 93 | if key in self.prior_mll_cache: 94 | del self.prior_mll_cache[key] 95 | 96 | def clear_posterior_mll_cache(self, keys: List[str] = None): 97 | if keys is None: 98 | self.posterior_mll_cache.clear() 99 | else: 100 | for key in keys: 101 | if key in self.posterior_mll_cache: 102 | del self.posterior_mll_cache[key] 103 | 104 | def get_mechanism(self, node, graph: nx.DiGraph = None, parents: List[str] = None) -> Mechanism: 105 | assert graph is not None or parents is not None 106 | 107 | # get unique mechanism key 108 | parents = get_parents(node, graph) if parents is None else list(set(parents)) 109 | key = get_mechanism_key(node, parents) 110 | 111 | # return mechanism if it already exists 112 | if key in self.mechanisms: 113 | return self.mechanisms[key] 114 | 115 | # if mechanism does not yet exists in the model, create a new mechanism 116 | num_parents = len(parents) 117 | self.mechanisms[key] = self.create_mechanism(num_parents) 118 | return self.mechanisms[key] 119 | 120 | def node_mll(self, experiments: List[Experiment], node: str, graph: nx.DiGraph, prior_mode=False, 121 | use_cache=False, mode='joint', reduce=True) -> torch.Tensor: 122 | cache = self.prior_mll_cache if prior_mode else self.posterior_mll_cache 123 | parents = get_parents(node, graph) 124 | key = get_mechanism_key(node, parents) 125 | mll = torch.tensor(0.) 126 | if not use_cache or key not in cache: 127 | # gather data from the experiments 128 | inputs, targets = gather_data(experiments, node, parents=parents, mode=mode) 129 | # check if we have any data for this node 130 | if targets is not None: 131 | # compute log-likelihood 132 | mechanism = self.get_mechanism(node, parents=parents) 133 | try: 134 | mll = mechanism.mll(inputs, targets, prior_mode, reduce=reduce) 135 | except Exception as e: 136 | print( 137 | f'Exception occured in GaussianProcessModel.mll() when computing MLL for mechanism {key} ' 138 | f'with prior mode {prior_mode} and use cache {use_cache}:') 139 | print(e) 140 | # cache mll 141 | cache[key] = mll 142 | else: 143 | mll = cache[key] 144 | 145 | return mll 146 | 147 | def mll(self, experiments: List[Experiment], graph: nx.DiGraph, prior_mode=False, use_cache=False, 148 | mode='joint', reduce=True) -> torch.Tensor: 149 | mll = torch.tensor(0.) 
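        # The graph marginal log-likelihood factorises over nodes: for a DAG G,
        #     log p(D | G) = sum_j log p(D_j | D_{pa_G(j)}, G),
        # where each term is the node_mll() of the corresponding mechanism defined above.
        # Minimal usage sketch (assumes `model` is a GaussianProcessModel whose node labels match
        # the graph, and `experiments` is a list of Experiment objects; labels are placeholders):
        #
        #   >>> import networkx as nx
        #   >>> graph = nx.DiGraph([('X0', 'X1')])
        #   >>> mll = model.mll(experiments, graph, prior_mode=True)   # scalar torch.Tensor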
150 | for node in self.node_labels: 151 | mll = mll + self.node_mll(experiments, node, graph, prior_mode, use_cache, mode=mode, reduce=reduce) 152 | return mll 153 | 154 | def expected_noise_entropy(self, interventions, graph: nx.DiGraph, use_cache=False) -> torch.Tensor: 155 | entropy = torch.tensor(0.) 156 | for node in self.node_labels: 157 | if node in interventions: 158 | continue 159 | 160 | parents = get_parents(node, graph) 161 | key = get_mechanism_key(node, parents) 162 | if not use_cache or key not in self.entropy_cache: 163 | # compute and cache entropy 164 | mechanism = self.get_mechanism(node, parents=parents) 165 | mechanism_entropy = mechanism.expected_noise_entropy() 166 | self.entropy_cache[key] = mechanism_entropy 167 | else: 168 | # take entropy from cache 169 | mechanism_entropy = self.entropy_cache[key] 170 | 171 | entropy += mechanism_entropy 172 | return entropy 173 | 174 | def get_num_mechanisms(self): 175 | return len(self.mechanisms) 176 | 177 | def mechanism_mlls(self, experiments: List[Experiment], keys: List[str] = None, prior_mode=False) -> torch.Tensor: 178 | # if no keys are given compute mlls for all mechanisms 179 | keys = self.mechanisms.keys() if keys is None else keys 180 | 181 | mlls = torch.tensor(0.) 182 | for key in keys: 183 | node, parents = resolve_mechanism_key(key) 184 | 185 | # gather data from the experiments 186 | inputs, targets = gather_data(experiments, node, parents=parents, mode='joint') 187 | 188 | # check if we have any data for this node 189 | if targets is None: 190 | continue 191 | 192 | # compute log-likelihood 193 | mechanism = self.mechanisms[key] 194 | try: 195 | mlls += mechanism.mll(inputs, targets, prior_mode) / targets.numel() 196 | except Exception as e: 197 | print( 198 | f'Exception occured in GaussianProcessModel.mechanism_mlls() when computing MLL for mechanism ' 199 | f'{key} with prior mode {prior_mode}:') 200 | print(e) 201 | if isinstance(mechanism, GaussianProcess): 202 | print('Resampling GP hyperparameters...') 203 | mechanism.init_hyperparams() 204 | return mlls 205 | 206 | def mechanism_log_hp_priors(self, keys: List[str] = None) -> torch.Tensor: 207 | # if no keys are given compute mlls for all mechanisms 208 | keys = self.mechanisms.keys() if keys is None else keys 209 | 210 | gps = [self.mechanisms[key] for key in keys if isinstance(self.mechanisms[key], GaussianProcess)] 211 | if not gps: 212 | return torch.tensor(0.) 
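        # These hyperparameter log-priors act as a regulariser: optimize_gp_hyperparams() below
        # minimises
        #     loss = -( sum_k mll_k / N_k  +  sum_GPs hyperparam_log_prior() ),
        # i.e. the GP hyperparameters are fit by MAP / penalised marginal likelihood rather than
        # pure maximum marginal likelihood (cf. mechanism_mlls() above, which normalises each
        # mechanism's MLL by its number of data points N_k).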
213 | 214 | log_priors = torch.stack([gp.gp.hyperparam_log_prior() for gp in gps]) 215 | return log_priors.sum() 216 | 217 | def create_mechanism(self, num_parents: int) -> Mechanism: 218 | return GaussianProcess(num_parents, linear=self.linear) if num_parents > 0 else GaussianRootNode() 219 | 220 | def set_data(self, experiments: List[Experiment], keys: List[str] = None): 221 | # if no keys are given set data for all mechanisms 222 | keys = self.mechanisms.keys() if keys is None else keys 223 | 224 | for key in keys: 225 | node, parents = resolve_mechanism_key(key) 226 | 227 | # gather data from the experiments 228 | inputs, targets = gather_data(experiments, node, parents=parents, mode='joint') 229 | 230 | # check if we have any data for this node 231 | if targets is None: 232 | continue 233 | 234 | # set GP data 235 | self.mechanisms[key].set_data(inputs, targets) 236 | 237 | def sample(self, interventions: dict, batch_size: int, num_batches: int, graph: nx.DiGraph) -> Experiment: 238 | data = dict() 239 | for node in self.topological_orders[get_graph_key(graph)]: 240 | # check if node is intervened upon 241 | if node in interventions: 242 | node_samples = torch.ones(num_batches, batch_size, 1) * interventions[node] 243 | else: 244 | # sample from mechanism 245 | parents = get_parents(node, graph) 246 | mechanism = self.get_mechanism(node, parents=parents) 247 | if not parents: 248 | node_samples = mechanism.sample(torch.empty(num_batches, batch_size, 1)) 249 | else: 250 | x = torch.cat([data[parent] for parent in parents], dim=-1) 251 | assert x.shape == (num_batches, batch_size, mechanism.in_size) 252 | node_samples = mechanism.sample(x) 253 | 254 | # store samples 255 | data[node] = node_samples 256 | 257 | return Experiment(interventions, data) 258 | 259 | def interventional_mll(self, targets, node: str, interventions: dict, graph: nx.DiGraph, num_mc_samples=50, 260 | reduce=True): 261 | assert targets.dim() == 2, print(f'Invalid shape {targets.shape}') 262 | num_batches, batch_size = targets.shape 263 | 264 | parents = get_parents(node, graph) 265 | mechanism = self.get_mechanism(node, parents=parents) 266 | 267 | if len(parents) == 0: 268 | # if we have a root note imll is simple 269 | mll = mechanism.mll(None, targets, prior_mode=False, reduce=False) 270 | assert mll.shape == (num_batches,), print(f'Invalid shape {mll.shape}!') 271 | return mll.sum() if reduce else mll 272 | 273 | # otherwise, do MC estimate via ancestral sampling 274 | samples = self.sample(interventions, batch_size, num_mc_samples, graph) 275 | # assemble inputs and targets 276 | inputs, _ = gather_data([samples], node, parents=parents, mode='independent_batches') 277 | inputs = inputs.unsqueeze(0).expand(num_batches, -1, -1, -1) 278 | assert inputs.shape == (num_batches, num_mc_samples, batch_size, len(parents)) 279 | targets = targets.unsqueeze(1).expand(-1, num_mc_samples, batch_size) 280 | assert targets.shape == (num_batches, num_mc_samples, batch_size) 281 | # compute interventional mll 282 | mll = mechanism.mll(inputs, targets, prior_mode=False, reduce=False) 283 | assert mll.shape == (num_batches, num_mc_samples), print(f'Invalid shape {mll.shape}!') 284 | mll = mll.logsumexp(dim=1) - math.log(num_mc_samples) 285 | return mll.sum() if reduce else mll 286 | 287 | def sample_queries(self, queries: List[InterventionalDistributionsQuery], num_mc_queries: int, 288 | num_batches_per_query: int, graph: nx.DiGraph): 289 | 290 | interventional_queries = [query.clone() for query in queries] 291 | with 
torch.no_grad(): 292 | for query in interventional_queries: 293 | experiments = [] 294 | for i in range(num_mc_queries): 295 | interventions = query.sample_intervention() 296 | experiments.append(self.sample(interventions, 1, num_batches_per_query, graph)) 297 | 298 | query.set_sample_queries(experiments) 299 | 300 | return interventional_queries 301 | 302 | def query_log_probs(self, queries: List[InterventionalDistributionsQuery], graph: nx.DiGraph, 303 | num_imll_mc_samples: int = 50): 304 | num_queries = len(queries) 305 | num_mc_queries = len(queries[0].sample_queries) 306 | num_batches_per_query = queries[0].sample_queries[0].num_batches 307 | 308 | query_lls = torch.zeros(num_mc_queries, num_queries, num_batches_per_query) 309 | for i in range(num_mc_queries): 310 | for query_idx, query in enumerate(queries): 311 | query_node = query.query_nodes[0] # ToDo: supports only single query node!!! 312 | targets = query.sample_queries[i].data[query_node].squeeze(-1) 313 | imll = self.interventional_mll(targets, query_node, query.sample_queries[i].interventions, graph, 314 | num_mc_samples=num_imll_mc_samples, reduce=False) 315 | query_lls[i, query_idx] = imll 316 | 317 | return query_lls.sum(dim=1) 318 | 319 | def update_gp_hyperparameters(self, update_time: int, experiments: List[Experiment], set_data=False, 320 | mechanisms: Optional[List[str]] = None): 321 | assert 0 < update_time <= len(experiments), print(f'Cannot update on {update_time}/{len(experiments)} ' 322 | f'experiments!') 323 | if mechanisms is None: 324 | mechanisms = list(self.mechanisms.keys()) 325 | 326 | keys = [key for key in mechanisms if self.mechanism_update_times[key] != update_time] 327 | 328 | if keys: 329 | print(f'Updating {len(keys)} GP\'s hyperparams with first {update_time} experiments.') 330 | max_update_size = 500 331 | num_full_batches = len(keys) // max_update_size 332 | key_batches = [keys[i * max_update_size:(i + 1) * max_update_size] for i in range(num_full_batches)] 333 | if len(keys) % max_update_size > 0: 334 | tmp = keys[max_update_size * num_full_batches:] 335 | key_batches.append(tmp) 336 | 337 | for batch in key_batches: 338 | self.set_data(experiments[:update_time], batch) 339 | self.optimize_gp_hyperparams(experiments[:update_time], batch) 340 | self.clear_prior_mll_cache(batch) 341 | self.clear_posterior_mll_cache(batch) 342 | for key in batch: 343 | self.mechanism_update_times[key] = update_time 344 | 345 | # set training data for mechanisms 346 | if set_data: 347 | self.set_data(experiments, mechanisms) 348 | 349 | # put mechanisms into eval mode 350 | self.eval(mechanisms) 351 | 352 | def optimize_gp_hyperparams(self, experiments: List[Experiment], keys: List[str] = None, num_steps: int = 70, 353 | log_interval: int = 0): 354 | keys = self.mechanisms.keys() if keys is None else keys 355 | params = self.get_parameters(keys) 356 | if not params: 357 | return 358 | 359 | optimizer = torch.optim.RMSprop(params, lr=2e-2) 360 | 361 | losses = [] 362 | self.train(keys) 363 | for i in range(num_steps): 364 | optimizer.zero_grad() 365 | loss = -self.mechanism_mlls(experiments, keys) 366 | loss -= self.mechanism_log_hp_priors(keys) 367 | 368 | loss.backward() 369 | optimizer.step() 370 | losses.append(loss.item()) 371 | 372 | if i > 5 and torch.tensor(losses[-4:-1]).mean() - torch.tensor(losses[-5:-2]).mean() < 1e-3: 373 | # print(f'\nStopping GP parameter update early after improvement stagnates...') 374 | break 375 | 376 | if log_interval < 1: 377 | continue 378 | 379 | print(f'Step {i + 1} of 
{num_steps}, negative MLL is {loss.item()}...', 380 | end='' if i % log_interval == 0 else '\r', flush=True) 381 | 382 | return losses 383 | 384 | def submodel(self, graphs): 385 | mechanisms_keys = {get_mechanism_key(node, get_parents(node, graph)) for graph in graphs for node in graph} 386 | submodel = self.__class__(self.node_labels) 387 | submodel.mechanisms = {key: self.mechanisms[key] for key in mechanisms_keys} 388 | submodel.topological_orders = self.topological_orders 389 | return submodel 390 | 391 | def param_dict(self): 392 | mechanism_param_dict = {key: m.param_dict() for key, m in self.mechanisms.items()} 393 | params = {'node_labels': self.node_labels, 394 | 'linear': self.linear, 395 | 'mechanism_init_times': self.mechanism_init_times, 396 | 'mechanism_update_times': self.mechanism_update_times, 397 | 'topological_orders': self.topological_orders, 398 | 'mechanism_param_dict': mechanism_param_dict} 399 | return params 400 | 401 | def load_param_dict(self, param_dict): 402 | self.node_labels = param_dict['node_labels'] 403 | self.linear = param_dict['linear'] 404 | self.mechanism_init_times = param_dict['mechanism_init_times'] 405 | self.mechanism_update_times = param_dict['mechanism_update_times'] 406 | self.topological_orders = param_dict['topological_orders'] 407 | for key, d in param_dict['mechanism_param_dict'].items(): 408 | self.mechanisms[key] = self.create_mechanism(d['in_size']) 409 | self.mechanisms[key].load_param_dict(d) 410 | -------------------------------------------------------------------------------- /src/models/graph_models.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import math 3 | from typing import List, Tuple, Dict, Any 4 | 5 | import networkx as nx 6 | import torch 7 | import torch.distributions as dist 8 | from torch.autograd import Function 9 | from torch.nn.functional import logsigmoid 10 | 11 | 12 | def generate_all_dags(num_nodes: int = 3, node_labels: List[str] = None) -> List[nx.DiGraph]: 13 | """Generates all directed acyclic graphs with a given number of nodes or node labels. 14 | 15 | Parameters 16 | ---------- 17 | num_nodes : int 18 | The number of nodes. If `node_labels` is given this has no effect. 19 | node_labels : List[str] 20 | List of node labels. If not None, the number of nodes is inferred automatically. 21 | 22 | Returns 23 | ------ 24 | List[nx.DiGraph] 25 | A list of DAGs as Networkx DiGraph objects.. 26 | """ 27 | # check if node labels given and create enumerative mapping 28 | if node_labels is not None: 29 | node_labels = sorted(list(set(node_labels))) 30 | num_nodes = len(node_labels) 31 | node_map = dict(zip(list(range(num_nodes)), node_labels)) 32 | else: 33 | node_map = dict(zip(list(range(num_nodes)), list(range(num_nodes)))) 34 | 35 | # check dag size feasibility 36 | assert num_nodes > 0, f'There is no such thing as a graph with {num_nodes} nodes.' 37 | assert num_nodes < 5, f'There are a lot of DAGs with {num_nodes} nodes...' 
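    # The number of labelled DAGs grows super-exponentially in the number of nodes
    # (1, 3, 25, 543 for 1-4 nodes), which is why exhaustive enumeration is capped above.
    # Quick check of the enumeration (illustrative):
    #
    #   >>> len(generate_all_dags(num_nodes=3))
    #   25
    #   >>> len(generate_all_dags(node_labels=['A', 'B']))
    #   3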
38 | 39 | # generate adjecency lists of all possible, simple graphs with num_nodes nodes 40 | adj_lists = [[]] 41 | for src in range(num_nodes): 42 | for dest in range(num_nodes): 43 | if src != dest: 44 | adj_lists = adj_lists + [[[node_map[src], node_map[dest]]] + adj_list for adj_list in adj_lists] 45 | 46 | # create graphs and keep only DAGs 47 | graphs = [] 48 | for adj_list in adj_lists: 49 | graph = nx.DiGraph() 50 | graph.add_nodes_from(list(node_map.values())) 51 | graph.add_edges_from(adj_list) 52 | if nx.is_directed_acyclic_graph(graph): 53 | graphs.append(graph) 54 | 55 | return graphs 56 | 57 | 58 | def get_graph_key(graph: nx.DiGraph) -> str: 59 | """Generates a unique string representation of a directed graph. Can be used as a dictionary key. 60 | 61 | Parameters 62 | ---------- 63 | graph : nx.DiGraph 64 | The graph for which to generate the string representation. 65 | 66 | Returns 67 | ------ 68 | str 69 | A unique string representation of `graph`. 70 | """ 71 | graph_str = '' 72 | for i, node in enumerate(sorted(graph)): 73 | if i > 0: 74 | graph_str += '|' 75 | graph_str += str(node) + '<-' + ','.join([str(parent) for parent in get_parents(node, graph)]) 76 | 77 | return graph_str 78 | 79 | 80 | def resolve_graph_key(key: str) -> nx.DiGraph: 81 | """Return a NetworkX DiGraph object according to the given graph key. 82 | 83 | Parameters 84 | ---------- 85 | key : str 86 | The string representation of the graph to be generated. 87 | 88 | Returns 89 | ------ 90 | nx.DiGraph 91 | A graph object corresponding to the given graph key. 92 | """ 93 | graph = nx.DiGraph() 94 | mech_strings = key.split('|') 95 | for mstr in mech_strings: 96 | idx = mstr.find('<-') 97 | node = mstr[:idx] 98 | parents = mstr[idx + 2:].split(',') if len(mstr) > idx + 2 else [] 99 | 100 | graph.add_node(node) 101 | for parent in parents: 102 | graph.add_edge(parent, node) 103 | 104 | return graph 105 | 106 | 107 | def get_parents(node: str, graph: nx.DiGraph) -> List[str]: 108 | """Returns a list of parents for a given node in a given graph. 109 | 110 | Parameters 111 | ---------- 112 | node : str 113 | The child node. 114 | graph : nx.DiGraph 115 | The graph inducing the parent set. 116 | 117 | Returns 118 | ------ 119 | List[str] 120 | The list of parents. 121 | """ 122 | return sorted(list(graph.predecessors(node))) 123 | 124 | 125 | def graph_to_adj_mat(graph: nx.DiGraph, node_labels: List[str]) -> torch.Tensor: 126 | """Returns the adjecency matrix of the given graph as tensor. 127 | 128 | Parameters 129 | ---------- 130 | graph : nx.DiGraph 131 | The graph. 132 | node_labels : List[str] 133 | The list of node labels determining the order of the adjacency matrix. 134 | 135 | Returns 136 | ------ 137 | torch.Tensor 138 | The adjacency matrix of the graph. 139 | """ 140 | return torch.tensor(nx.to_numpy_array(graph, nodelist=node_labels)).float() 141 | 142 | 143 | class CategoricalModel: 144 | """ 145 | Class that represents a categorical distribution over graphs. 146 | 147 | Attributes 148 | ---------- 149 | graphs : List[nx.DiGraph] 150 | List of possible DAGs for this model. 151 | node_labels : List[str] 152 | List of node labels. 153 | num_nodes : int 154 | The number of nodes in this model. 155 | num_graphs : int 156 | The number of possible DAGs for this model. 157 | log_probs : Dict[str, torch.Tensor] 158 | Dictionary of graph identifiers and corresponding log probabilities. 
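
    Examples
    --------
    A minimal usage sketch (node labels are placeholders); the distribution is initialised
    uniformly over all DAGs:

    >>> model = CategoricalModel(['X', 'Y'])
    >>> model.num_graphs
    3
    >>> torch.isclose(model.prob(model.graphs[0]), torch.tensor(1. / 3.))
    tensor(True)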
159 | """ 160 | 161 | def __init__(self, node_labels: List[str]): 162 | """ 163 | Parameters 164 | ---------- 165 | node_labels : List[str] 166 | List of node labels. 167 | """ 168 | self.node_labels = sorted(list(set(node_labels))) 169 | self.num_nodes = len(node_labels) 170 | self.graphs = generate_all_dags(node_labels=node_labels) 171 | self.num_graphs = len(self.graphs) 172 | graph_keys = [get_graph_key(graphs) for graphs in self.graphs] 173 | self.log_probs = dict(zip(graph_keys, -torch.log(self.num_graphs * torch.ones(self.num_graphs)))) 174 | 175 | def log_prob(self, graph: nx.DiGraph) -> torch.Tensor: 176 | """ 177 | Returns the log probability for a given graph. 178 | 179 | Parameters 180 | ---------- 181 | graph : nx.DiGraph 182 | The query graph. 183 | 184 | Returns 185 | ---------- 186 | torch.Tensor 187 | The log probability. 188 | """ 189 | assert get_graph_key(graph) in self.log_probs 190 | return self.log_probs[get_graph_key(graph)] 191 | 192 | def prob(self, graph: nx.DiGraph) -> torch.Tensor: 193 | """ 194 | Returns the probability for a given graph. 195 | 196 | Parameters 197 | ---------- 198 | graph : nx.DiGraph 199 | The query graph. 200 | 201 | Returns 202 | ---------- 203 | torch.Tensor 204 | The probability. 205 | """ 206 | assert get_graph_key(graph) in self.log_probs 207 | return self.log_probs[get_graph_key(graph)].exp() 208 | 209 | def set_log_prob(self, log_prob: torch.tensor, graph: nx.DiGraph): 210 | """ 211 | Sets the log probability of a given graph. 212 | 213 | Parameters 214 | ---------- 215 | log_prob: torch.tensor 216 | The new log prabability. 217 | graph : nx.DiGraph 218 | The target graph. 219 | """ 220 | assert log_prob.numel() == 1 221 | assert get_graph_key(graph) in self.log_probs 222 | 223 | self.log_probs[get_graph_key(graph)] = log_prob.squeeze() 224 | 225 | def normalize(self): 226 | """ 227 | Normalizes the distribution over graphs such that the sum over all graphs probabilities equals 1. 228 | """ 229 | logits = torch.stack([p for p in self.log_probs.values()]) 230 | log_evidence = logits.logsumexp(dim=0) 231 | log_probs = logits - log_evidence 232 | self.log_probs = dict(zip(self.log_probs.keys(), log_probs)) 233 | 234 | def entropy(self): 235 | """ 236 | Returns the entropy of the categorical distribution over graphs. 237 | 238 | Returns 239 | ---------- 240 | torch.Tensor 241 | The entropy. 242 | """ 243 | tmp = torch.stack(list(self.log_probs.values())) 244 | return -(tmp.exp() * tmp).sum() 245 | 246 | def sample(self, num_graphs: int) -> List[nx.DiGraph]: 247 | """ 248 | Samples a graph. 249 | 250 | Parameters 251 | ---------- 252 | num_graphs: int 253 | Number of graphs to sample. 254 | 255 | Returns 256 | ---------- 257 | List[nx.DiGraph] 258 | List of sampled graph objects. 259 | """ 260 | probs = torch.stack([self.prob(graph) for graph in self.graphs]) 261 | graph_idc = torch.multinomial(probs, num_graphs, replacement=True) 262 | return [self.graphs[idx] for idx in graph_idc] 263 | 264 | def sort_by_prob(self, descending: bool = True) -> List[nx.DiGraph]: 265 | """ 266 | Returns a list of all graphs sorted by their probabilities. 267 | 268 | Parameters 269 | ---------- 270 | descending: bool 271 | If true, sort in descending order, ascending otherwise. 272 | 273 | Returns 274 | ---------- 275 | List[nx.DiGraph] 276 | List of sorted graph objects. 
277 | """ 278 | 279 | def compare(graph1, graph2): 280 | return (self.log_prob(graph1) - self.log_prob(graph2)).item() 281 | 282 | return sorted(self.graphs, key=functools.cmp_to_key(compare), reverse=descending) 283 | 284 | def edge_probs(self) -> torch.Tensor: 285 | """ 286 | Returns the matrix of edge probabilities. 287 | 288 | Returns 289 | ---------- 290 | torch.Tensor 291 | Matrix of edge probabilities. 292 | """ 293 | edge_probs = torch.zeros(self.num_nodes, self.num_nodes) 294 | for graph in self.graphs: 295 | adj_mat = graph_to_adj_mat(graph, self.node_labels) 296 | edge_probs += self.prob(graph) * adj_mat 297 | 298 | return edge_probs 299 | 300 | def get_mc_graphs(self, mode: str, num_mc_graphs: int = 20) -> Tuple[List[nx.DiGraph], torch.Tensor]: 301 | """ 302 | Returns a set of graphs and corresponding log-weights for Monte Carlo estimation. 303 | 304 | Parameters 305 | ---------- 306 | mode: str 307 | There are three strategies for generating the MC graph set: 308 | 'full': Returns all graphs and their log-probabilities as log-weights. 309 | 'sampling': Returns `num_mc_graphs` samples from the distribution over graphs with uniform weights. 310 | 'n-best': Returns the `num_mc_graphs` graphs with the highest probabilities weighted according to their 311 | re-normalized probabilities. 312 | num_mc_graphs: int 313 | The size of the returned set of graphs (see description above). Has no effect when mode is 'full'. 314 | 315 | Returns 316 | ---------- 317 | List[nx.DiGraph], torch.Tensor 318 | The set of MC graphs and their corresponding log-weights. 319 | """ 320 | if mode not in {'full', 'sampling', 'n-best'}: 321 | print('Invalid sampling mode >' + mode + '<. Doing instead.') 322 | mode = 'full' 323 | 324 | if mode == 'sampling': 325 | graphs = self.sample(num_mc_graphs) 326 | log_weights = -torch.ones(num_mc_graphs) * math.log(num_mc_graphs) 327 | elif mode == 'n-best': 328 | graphs = self.sort_by_prob()[:num_mc_graphs] 329 | log_weights = [self.log_prob(graph) for graph in graphs] 330 | log_weights = torch.log_softmax(torch.stack(log_weights), dim=0) 331 | else: 332 | graphs = self.graphs 333 | log_weights = [self.log_prob(graph) for graph in graphs] 334 | log_weights = torch.stack(log_weights) 335 | 336 | return graphs, log_weights 337 | 338 | def param_dict(self) -> Dict[str, Any]: 339 | """ 340 | Returns the current parameters of the an instance of this class as a dictionary. 341 | 342 | Returns 343 | ---------- 344 | Dict[str, Any] 345 | Parameter dictionary. 346 | """ 347 | params = {'node_labels': self.node_labels, 348 | 'num_nodes': self.num_nodes, 349 | 'graphs': self.graphs, 350 | 'num_graphs': self.num_graphs, 351 | 'log_probs': self.log_probs} 352 | return params 353 | 354 | def load_param_dict(self, param_dict: Dict[str, Any]): 355 | """ 356 | Sets the parameters of this class instance with the parameter values given in `param_dict`. 357 | 358 | Parameters 359 | ---------- 360 | param_dict : Dict[str, Any] 361 | Parameter dictionary. 
362 | """ 363 | self.node_labels = param_dict['node_labels'] 364 | self.num_nodes = param_dict['num_nodes'] 365 | self.graphs = param_dict['graphs'] 366 | self.num_graphs = param_dict['num_graphs'] 367 | self.log_probs = param_dict['log_probs'] 368 | 369 | 370 | class DiBSModel: 371 | def __init__(self, node_labels: List[str], embedding_size: int, num_particles: int, std: float = 1.): 372 | self.node_labels = sorted(list(set(node_labels))) 373 | self.num_nodes = len(self.node_labels) 374 | self.node_id_to_label_dict = dict(zip(list(range(self.num_nodes)), node_labels)) 375 | self.node_label_to_id_dict = dict(zip(node_labels, list(range(self.num_nodes)))) 376 | self.embedding_size = embedding_size 377 | self.num_particles = num_particles 378 | self.normal_prior_std = std 379 | self.particles = self.sample_initial_particles(num_particles) 380 | 381 | def _check_particle_shape(self, z: torch.Tensor): 382 | assert z.dim() == 4 and z.shape[1:] == (self.embedding_size, self.num_nodes, 2), print(z.shape) 383 | 384 | def sample_initial_particles(self, num_particles) -> torch.Tensor: 385 | # sample particles from normal prior 386 | normal = torch.distributions.Normal(0., self.normal_prior_std) 387 | particles = normal.sample((num_particles, self.embedding_size, self.num_nodes, 2)) 388 | particles.requires_grad_(True) 389 | self._check_particle_shape(particles) 390 | return particles 391 | 392 | def edge_logits(self, alpha: float = 1.): 393 | return alpha * torch.einsum('ikj,ikl->ijl', self.particles[..., 0], self.particles[..., 1]) 394 | 395 | def edge_probs(self, alpha: float = 1.): 396 | # compute edge probs 397 | edge_probs = torch.sigmoid(self.edge_logits(alpha)) 398 | 399 | # set probs of self loops to 0 400 | mask = torch.eye(self.num_nodes).repeat(self.num_particles, 1, 1).bool() 401 | edge_probs[mask] = 0. 402 | return edge_probs 403 | 404 | def edge_log_probs(self, alpha: float = 1.): 405 | return logsigmoid(self.edge_logits(alpha)) 406 | 407 | def log_generative_prob(self, adj_mats: torch.Tensor, alpha: float = 1., batch_mode=True): 408 | assert adj_mats.dim() == 4 and adj_mats.shape[2:] == (self.num_nodes, self.num_nodes) 409 | assert adj_mats.shape[0] == self.particles.shape[0] or not batch_mode 410 | logits = self.edge_logits(alpha) 411 | log_edge_probs = logsigmoid(logits) 412 | log_not_edge_probs = log_edge_probs - logits # =logsigmoid(-logits) = log(1-sigmoid(logits)) 413 | 414 | # set probs of self loops to 0 415 | mask = torch.eye(self.num_nodes).repeat(self.num_particles, 1, 1).bool() 416 | log_edge_probs = torch.where(mask, torch.tensor(0.), log_edge_probs) 417 | log_not_edge_probs = torch.where(mask, torch.tensor(0.), log_not_edge_probs) 418 | 419 | if batch_mode: 420 | graph_log_probs = torch.einsum('hijk,hjk->hi', adj_mats, log_edge_probs) + \ 421 | torch.einsum('hijk,hjk->hi', (1. - adj_mats), log_not_edge_probs) 422 | else: 423 | graph_log_probs = torch.einsum('hijk,ljk->lhi', adj_mats, log_edge_probs) + \ 424 | torch.einsum('hijk,ljk->lhi', (1. - adj_mats), log_not_edge_probs) 425 | 426 | return graph_log_probs 427 | 428 | def unnormalized_log_prior(self, alpha: float = 1., beta: float = 1.) 
-> torch.Tensor: 429 | normal = torch.distributions.Normal(0., self.normal_prior_std) 430 | 431 | ec = self.expected_cyclicity(alpha) 432 | log_prior = normal.log_prob(self.particles).sum(dim=(1, 2, 3)) / self.particles[0].numel() - beta * ec 433 | return log_prior.float() 434 | 435 | def expected_cyclicity(self, alpha: float = 1., num_samples: int = 100) -> torch.Tensor: 436 | adj_mats = self.sample_soft_graphs(num_samples, alpha) 437 | scores = AcyclicityScore.apply(adj_mats) 438 | return scores.sum(dim=-1) / num_samples 439 | 440 | def sample_soft_graphs(self, num_samples: int, alpha: float = 1.): 441 | edge_logits = self.edge_logits(alpha) 442 | 443 | transforms = [dist.SigmoidTransform().inv, dist.AffineTransform(loc=0., scale=1.)] 444 | logistic = torch.distributions.TransformedDistribution(dist.Uniform(0, 1), transforms) 445 | reparam_logits = logistic.rsample((num_samples, *edge_logits.shape)) + \ 446 | edge_logits.unsqueeze(0).expand(num_samples, -1, -1, -1) 447 | soft_adj_mats = torch.sigmoid(reparam_logits).permute(1, 0, 2, 3) 448 | 449 | # eliminate self loops 450 | mask = torch.eye(self.num_nodes).repeat(self.num_particles, num_samples, 1, 1).bool() 451 | soft_adj_mats = torch.where(mask, torch.tensor(0.), soft_adj_mats) 452 | return soft_adj_mats 453 | 454 | def sample_graphs(self, num_samples: int, alpha: float = 1., fixed_edges: List[Tuple[int, int]] = None) \ 455 | -> Tuple[List[List[nx.DiGraph]], torch.Tensor]: 456 | # compute bernoulli probs from latent particles 457 | edge_probs = self.edge_probs(alpha) 458 | 459 | # modify probs 460 | if fixed_edges is not None: 461 | for i, j in fixed_edges: 462 | edge_probs[:, i, j] = 1. 463 | 464 | # sample adjacency matrices and generate graph objects 465 | adj_mats = torch.bernoulli(edge_probs.unsqueeze(1).expand(-1, num_samples, -1, -1)) 466 | graphs = [[self.adj_mat_to_graph(adj_mats[pidx, sidx]) for sidx in range(num_samples)] for pidx in range( 467 | self.num_particles)] 468 | return graphs, adj_mats 469 | 470 | def dagify_graphs(self, graphs: List[List[nx.DiGraph]], adj_mats: torch.Tensor): 471 | """Uses a simple heuristic to 'dagify' cyclic graphs in-place. Note: this can be handy for testing and 472 | debugging during developement, but should not be necessary when the DiBS model is trained properly (as it 473 | should then almost always return DAGs when sampling). 474 | 475 | Parameters 476 | ---------- 477 | graphs : List[List[nx.DiGraph]] 478 | Nested lists of graph objects. 479 | adj_mats : torch.Tensor 480 | Tensor of adjacency matrices corresponding to the graph objects in `graphs`. 
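
        Examples
        --------
        Intended call pattern (sketch; `dibs_model` stands for a DiBSModel instance and the
        sampled graphs are repaired in-place):

        >>> graphs, adj_mats = dibs_model.sample_graphs(num_samples=20)
        >>> dibs_model.dagify_graphs(graphs, adj_mats)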
481 | """ 482 | edge_probs = self.edge_probs() 483 | for particle_idx in range(self.num_particles): 484 | num_dagified = 0 485 | for graph_idx, graph in enumerate(graphs[particle_idx]): 486 | # check if the graph is cyclic 487 | if not nx.is_directed_acyclic_graph(graphs[particle_idx][graph_idx]): 488 | edges, _ = self.sort_edges(adj_mats[particle_idx, graph_idx], edge_probs[particle_idx]) 489 | 490 | graph = nx.DiGraph() 491 | graph.add_nodes_from(self.node_labels) 492 | adj_mats[particle_idx, graph_idx] = torch.zeros(self.num_nodes, self.num_nodes) 493 | for edge_idx, edge in enumerate(edges): 494 | source_node = self.node_id_to_label_dict[edge[0]] 495 | sink_node = self.node_id_to_label_dict[edge[1]] 496 | if not nx.has_path(graph, sink_node, source_node): 497 | # if there is no path from the target to the source node, we can safely add the edge to 498 | # the graph without creating a cycle 499 | graph.add_edge(source_node, sink_node) 500 | adj_mats[particle_idx, graph_idx, edge[0], edge[1]] = 1 501 | 502 | graphs[particle_idx][graph_idx] = graph 503 | num_dagified += 1 504 | 505 | if num_dagified > 0: 506 | print(f'Dagified {num_dagified} graphs of the {particle_idx + 1}-th particle!') 507 | 508 | def sort_edges(self, adj_mat: torch.Tensor, edge_weights: torch.Tensor, descending=True): 509 | edges = [(i, j) for i in range(self.num_nodes) for j in range(self.num_nodes) if adj_mat.bool()[i, j]] 510 | weights = edge_weights[adj_mat.bool()] 511 | weights, idc = torch.sort(weights, descending=descending) 512 | edges = [edges[idx] for idx in idc] 513 | return edges, weights 514 | 515 | def adj_mat_to_graph_key(self, adj_mat: torch.Tensor): 516 | assert adj_mat.shape == (self.num_nodes, self.num_nodes) 517 | key = '|'.join( 518 | [self.node_labels[i] + '<-' + ','.join( 519 | [self.node_labels[j] for j in range(self.num_nodes) if adj_mat[j, i] > 0.5]) for i in 520 | range(self.num_nodes)]) 521 | return key 522 | 523 | def adj_mat_to_graph(self, adj_mat: torch.Tensor) -> nx.DiGraph: 524 | assert adj_mat.shape == (self.num_nodes, self.num_nodes) 525 | graph = nx.from_numpy_array(adj_mat.int().numpy(), create_using=nx.DiGraph) 526 | graph = nx.relabel_nodes(graph, self.node_id_to_label_dict) 527 | return graph 528 | 529 | def graph_to_adj_mat(self, graph: nx.DiGraph) -> torch.Tensor: 530 | return graph_to_adj_mat(graph, self.node_labels) 531 | 532 | def get_limit_graphs(self): 533 | # round edge probs to get limit graphs 534 | edge_probs = self.edge_probs() 535 | adj_mats = edge_probs.round().unsqueeze(1) 536 | graphs = [[self.adj_mat_to_graph(adj_mats[i, 0])] for i in range(self.num_particles)] 537 | return graphs, adj_mats 538 | 539 | def particle_similarities(self, bandwidth=1.): 540 | distances = [(self.particles - self.particles[i:i + 1].detach()) ** 2 for i in range(self.num_particles)] 541 | similarities = [(-d.sum(dim=(1, 2, 3)) / bandwidth).exp() for d in distances] 542 | kernel_mat = torch.stack(similarities, dim=1) 543 | return kernel_mat 544 | 545 | def param_dict(self): 546 | params = {'node_id_to_label_dict': self.node_id_to_label_dict, 547 | 'node_label_to_id_dict': self.node_label_to_id_dict, 548 | 'embedding_size': self.num_particles, 549 | 'normal_prior_std': self.normal_prior_std, 550 | 'particles': self.particles} 551 | return params 552 | 553 | def load_param_dict(self, param_dict): 554 | self.node_id_to_label_dict = param_dict['node_id_to_label_dict'] 555 | self.node_label_to_id_dict = param_dict['node_label_to_id_dict'] 556 | self.node_labels = 
sorted(list(self.node_label_to_id_dict.keys())) 557 | self.num_nodes = len(self.node_labels) 558 | self.particles = param_dict['particles'] 559 | self.num_particles = self.particles.shape[0] 560 | self.embedding_size = self.particles.shape[1] 561 | self.normal_prior_std = param_dict['normal_prior_std'] 562 | 563 | 564 | class AcyclicityScore(Function): 565 | @staticmethod 566 | def forward(ctx, adj_mat: torch.Tensor, round_edge_weights=False): 567 | assert adj_mat.dim() >= 3 and adj_mat.shape[-1] == adj_mat.shape[-2], print( 568 | f'Ill-shaped input: {adj_mat.shape}') 569 | num_nodes = adj_mat.shape[-1] 570 | eyes = torch.eye(num_nodes).double().expand_as(adj_mat) 571 | tmp = eyes + (adj_mat.round().double() if round_edge_weights else adj_mat) / num_nodes 572 | 573 | tmp_pow = tmp.matrix_power(num_nodes - 1) 574 | ctx.grad = tmp_pow.transpose(-1, -2) 575 | score = (tmp_pow @ tmp).diagonal(dim1=-2, dim2=-1).sum(dim=-1) - num_nodes 576 | return score 577 | 578 | @staticmethod 579 | @torch.autograd.function.once_differentiable 580 | def backward(ctx, grad_output): 581 | return ctx.grad * grad_output[..., None, None], None 582 | -------------------------------------------------------------------------------- /src/abci_dibs_gp.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Callable 2 | 3 | import torch.optim 4 | from torch.nn.functional import log_softmax 5 | 6 | from src.abci_base import ABCIBase 7 | from src.environments.environment import * 8 | from src.experimental_design.exp_designer_abci_dibs_gp import ExpDesignerABCIDiBSGP 9 | from src.models.gp_model import GaussianProcessModel 10 | from src.models.graph_models import DiBSModel, get_graph_key 11 | from src.utils.metrics import shd, auroc, auprc 12 | 13 | 14 | class ABCIDiBSGP(ABCIBase): 15 | policies = {'observational', 'random', 'random-fixed-value', 'graph-info-gain', 'scm-info-gain', 16 | 'intervention-info-gain'} 17 | 18 | def __init__(self, env: Environment, policy: str, num_particles: int = 5, 19 | num_mc_graphs: int = 40, embedding_size: int = None, num_workers: int = 1, dibs_plus=True, 20 | linear: bool = False): 21 | assert policy in self.policies, print(f'Invalid policy {policy}!') 22 | super().__init__(env, policy, num_workers) 23 | 24 | # store params 25 | embedding_size = env.num_nodes if embedding_size is None else embedding_size 26 | self.num_particles = num_particles 27 | self.num_mc_graphs = num_mc_graphs 28 | self.embedding_size = embedding_size 29 | self.dibs_plus = dibs_plus 30 | self.linear = linear 31 | 32 | # init models 33 | self.graph_model = DiBSModel(env.node_labels, embedding_size, num_particles) 34 | self.mechanism_model = GaussianProcessModel(env.node_labels, linear=linear) 35 | 36 | # init mc graphs 37 | self.mc_graphs = self.mc_adj_mats = None 38 | 39 | def experiment_designer_factory(self): 40 | distributed = self.num_workers > 1 41 | return ExpDesignerABCIDiBSGP(self.env.intervention_bounds, opt_strategy='gp-ucb', distributed=distributed) 42 | 43 | def run(self, num_experiments=10, batch_size=3, update_interval=5, log_interval=1, num_initial_obs_samples=1, 44 | checkpoint_interval: int = 10, outdir: str = None, job_id: str = ''): 45 | 46 | # pre-compute env test data stats 47 | with torch.no_grad(): 48 | num_test_samples = self.env.num_test_samples_per_intervention 49 | env_obs_test_ll = self.env.log_likelihood(self.env.observational_test_data) / num_test_samples 50 | env_intr_test_lls = {} 51 | for node, experiments in 
self.env.interventional_test_data.items(): 52 | env_intr_test_lls[node] = self.env.log_likelihood(experiments) / num_test_samples 53 | 54 | # run experiments 55 | for epoch in range(num_experiments): 56 | print(f'Starting experiment cycle {epoch + 1}/{num_experiments}...') 57 | 58 | # pick intervention according to policy 59 | print(f'Design and perform experiment...', flush=True) 60 | info_gain = None 61 | if self.policy == 'observational' or len(self.experiments) == 0 and num_initial_obs_samples > 0: 62 | interventions = {} 63 | elif self.policy == 'random': 64 | interventions = self.get_random_intervention() 65 | elif self.policy == 'random-fixed-value': 66 | interventions = self.get_random_intervention(0.) 67 | else: 68 | if self.policy in {'graph-info-gain', 'scm-info-gain'}: 69 | # sample mc graphs 70 | outer_mc_graphs, _ = self.sample_mc_graphs(set_data=True, num_graphs=5, only_dags=False) 71 | inner_mc_graphs, _ = self.sample_mc_graphs(set_data=True, num_graphs=30, only_dags=False) 72 | with torch.no_grad(): 73 | log_inner_graph_weights, log_inner_particle_weights = self.compute_importance_weights( 74 | inner_mc_graphs, use_cache=True, log_weights=True) 75 | outer_graph_weights, outer_particle_weights = self.compute_importance_weights(outer_mc_graphs, 76 | use_cache=True) 77 | 78 | graphs = [g for glist in inner_mc_graphs for g in glist] 79 | graphs += [g for glist in outer_mc_graphs for g in glist] 80 | args = {'mechanism_model': self.mechanism_model.submodel(graphs), 81 | 'inner_mc_graphs': inner_mc_graphs, 82 | 'log_inner_graph_weights': log_inner_graph_weights, 83 | 'log_inner_particle_weights': log_inner_particle_weights, 84 | 'outer_mc_graphs': outer_mc_graphs, 85 | 'outer_graph_weights': outer_graph_weights, 86 | 'outer_particle_weights': outer_particle_weights, 87 | 'batch_size': batch_size, 88 | 'num_exp_per_graph': 100, 89 | 'policy': self.policy} 90 | 91 | elif self.policy in {'intervention-info-gain'}: 92 | self.mechanism_model.discard_mechanisms(len(self.experiments), -1) 93 | 94 | # sample mc graphs 95 | outer_mc_graphs, _ = self.sample_mc_graphs(set_data=True, num_graphs=3, only_dags=False) 96 | inner_mc_graphs, _ = self.sample_mc_graphs(set_data=True, num_graphs=9, only_dags=False) 97 | with torch.no_grad(): 98 | # compute importance weights 99 | log_inner_graph_weights, log_inner_particle_weights = self.compute_importance_weights( 100 | inner_mc_graphs, use_cache=True, log_weights=True) 101 | outer_graph_weights, outer_particle_weights = self.compute_importance_weights(outer_mc_graphs, 102 | use_cache=True) 103 | 104 | graphs = [g for glist in inner_mc_graphs for g in glist] 105 | graphs += [g for glist in outer_mc_graphs for g in glist] 106 | args = {'mechanism_model': self.mechanism_model.submodel(graphs), 107 | 'experiments': self.experiments, 108 | 'interventional_queries': self.env.interventional_queries, 109 | 'inner_mc_graphs': inner_mc_graphs, 110 | 'log_inner_graph_weights': log_inner_graph_weights, 111 | 'log_inner_particle_weights': log_inner_particle_weights, 112 | 'outer_mc_graphs': outer_mc_graphs, 113 | 'outer_graph_weights': outer_graph_weights, 114 | 'outer_particle_weights': outer_particle_weights, 115 | 'num_mc_queries': 5, 116 | 'num_batches_per_query': 3, 117 | 'batch_size': batch_size, 118 | 'num_exp_per_graph': 50, 119 | 'policy': self.policy} 120 | else: 121 | assert False, print(f'Invalid policy {self.policy}!') 122 | 123 | if self.num_workers > 1: 124 | interventions, info_gain = self.design_experiment_distributed(args) 125 | else: 126 | 
designer = self.experiment_designer_factory() 127 | designer.init_design_process(args) 128 | interventions, info_gain = designer.get_best_experiment(self.env.intervenable_nodes) 129 | 130 | # record expected information gain of chosen intervention 131 | if info_gain is None: 132 | info_gain = torch.tensor(-1.) 133 | self.info_gain_list.append(info_gain) 134 | 135 | # resample particles 136 | num_experiments_conducted = len(self.experiments) 137 | resampling_interval = 1 if num_experiments_conducted < 6 else (3 if num_experiments_conducted < 10 else 5) 138 | if num_experiments_conducted > 0 and num_experiments_conducted % resampling_interval == 0: 139 | self.resample_particles(use_cache=True) 140 | # clear mechs & caches: many different mechs after resampling 141 | self.mechanism_model.discard_mechanisms(num_experiments_conducted, -1) 142 | 143 | # perform experiment 144 | num_samples = batch_size 145 | if num_experiments_conducted == 0 and num_initial_obs_samples > 0: 146 | num_samples = num_initial_obs_samples 147 | self.experiments.append(self.env.sample(interventions, num_samples)) 148 | 149 | # clear caches 150 | self.mechanism_model.clear_prior_mll_cache() 151 | self.mechanism_model.clear_posterior_mll_cache() 152 | self.mechanism_model.entropy_cache.clear() 153 | 154 | # update latent particles 155 | # num_steps = 100 if len(self.experiments) < 10 else 200 156 | print(f'Updating latent particles via SVGD...', flush=True) 157 | num_steps = 500 158 | losses = self.update_latent_particles(num_steps) 159 | self.loss_list.extend(losses) 160 | 161 | # discard old mechanisms 162 | self.mechanism_model.discard_mechanisms(len(self.experiments), max_age=0) 163 | print(f'There are currently {self.mechanism_model.get_num_mechanisms()} unique mechanisms in our model...') 164 | 165 | print(f'Logging evaluation stats...', flush=True) 166 | self.mc_graphs, self.mc_adj_mats = self.sample_mc_graphs(set_data=True) 167 | 168 | # record graph posterior entropy 169 | self.graph_entropy_list.append(torch.tensor(-1.)) 170 | 171 | # record expected SHD 172 | with torch.no_grad(): 173 | eshd = self.graph_posterior_expectation(lambda g: shd(self.env.graph, g)) 174 | self.eshd_list.append(eshd) 175 | 176 | # record env graph LL 177 | graph_ll = self.compute_graph_log_posterior(self.env.graph) 178 | self.graph_ll_list.append(graph_ll) 179 | 180 | # record observational test LLs 181 | def test_ll(graph): 182 | return self.mechanism_model.mll(self.env.observational_test_data, graph, prior_mode=False, 183 | use_cache=True, mode='joint') 184 | 185 | self.mechanism_model.clear_posterior_mll_cache() 186 | with torch.no_grad(): 187 | ll = self.graph_posterior_expectation(test_ll) 188 | self.observational_test_ll_list.append(ll) 189 | 190 | # record interventional test LLs 191 | for node, experiments in self.env.interventional_test_data.items(): 192 | def test_ll(graph): 193 | return self.mechanism_model.mll(experiments, graph, prior_mode=False, use_cache=True, mode='joint') 194 | 195 | self.mechanism_model.clear_posterior_mll_cache() 196 | with torch.no_grad(): 197 | ll = self.graph_posterior_expectation(test_ll) 198 | self.interventional_test_ll_lists[node].append(ll) 199 | 200 | # record observational KLD 201 | def test_ll(graph): 202 | return self.mechanism_model.mll(self.env.observational_test_data, graph, prior_mode=False, 203 | use_cache=True, mode='independent_samples', reduce=False) 204 | 205 | self.mechanism_model.clear_posterior_mll_cache() 206 | with torch.no_grad(): 207 | ll = 
self.graph_posterior_expectation(test_ll, logspace=True).mean() 208 | self.observational_kld_list.append(env_obs_test_ll - ll) 209 | 210 | # record interventional test KLDs 211 | for node, experiments in self.env.interventional_test_data.items(): 212 | def test_ll(graph): 213 | return self.mechanism_model.mll(experiments, graph, prior_mode=False, use_cache=True, 214 | mode='independent_samples', reduce=False) 215 | 216 | self.mechanism_model.clear_posterior_mll_cache() 217 | with torch.no_grad(): 218 | ll = self.graph_posterior_expectation(test_ll, logspace=True).mean() 219 | self.interventional_kld_lists[node].append(env_intr_test_lls[node] - ll) 220 | 221 | # record AUROC/AUPRC scores 222 | with torch.no_grad(): 223 | posterior_edge_probs = self.compute_posterior_edge_probs(use_cache=True) 224 | true_adj_mat = self.graph_model.graph_to_adj_mat(self.env.graph) 225 | 226 | self.auroc_list.append(auroc(posterior_edge_probs, true_adj_mat)) 227 | self.auprc_list.append(auprc(posterior_edge_probs, true_adj_mat)) 228 | 229 | # record query KLD 230 | if self.env.interventional_queries is not None: 231 | self.mc_graphs, self.mc_adj_mats = self.sample_mc_graphs(set_data=True, num_graphs=5, only_dags=False) 232 | 233 | def test_query_ll(graph): 234 | # check if graph is acyclic 235 | if get_graph_key(graph) not in self.mechanism_model.topological_orders: 236 | print(f'Cannot evaluate test query ll for cyclic graph.') 237 | num_mc_queries = len(self.env.interventional_queries[0].sample_queries) 238 | num_batches_per_query = self.env.interventional_queries[0].sample_queries[0].num_batches 239 | return torch.ones(num_mc_queries, num_batches_per_query) * 1e-8 240 | 241 | query_lls = self.mechanism_model.query_log_probs(self.env.interventional_queries, graph, 200) 242 | return query_lls 243 | 244 | self.mechanism_model.clear_posterior_mll_cache() 245 | with torch.no_grad(): 246 | ll = self.graph_posterior_expectation(test_query_ll, logspace=True).mean() 247 | self.query_kld_list.append(self.env.query_ll - ll) 248 | 249 | if outdir is not None and 0 < epoch < num_experiments - 1 and (epoch + 1) % checkpoint_interval == 0: 250 | model = 'abci-dibs-gp-linear' if self.linear else 'abci-dibs-gp' 251 | outpath = outdir + model + '-' + self.policy + '-' + self.env.name + f'-{job_id}-exp-{epoch + 1}.pth' 252 | self.save(outpath) 253 | 254 | if log_interval > 0 and epoch % log_interval == 0: 255 | print(f'Experiment {epoch + 1}/{num_experiments}, ESHD is {eshd.item()}', flush=True) 256 | 257 | def resample_particles(self, threshold: float = 1e-2, use_cache: bool = False): 258 | mc_graphs, _ = self.sample_mc_graphs() 259 | num_particles = len(mc_graphs) 260 | 261 | max_particles_to_keep = math.ceil(num_particles / 4.) 
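# Resampling heuristic: particles are visited in order of decreasing DiBS+ importance weight; at most
# max_particles_to_keep of them are kept, and only while their weight is at or above `threshold`.
# All remaining particles are re-drawn from the initial particle distribution via sample_initial_particles,
# so that low-weight members of the SVGD ensemble are refreshed rather than carried along.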
262 | with torch.no_grad(): 263 | _, particle_weights = self.compute_importance_weights(mc_graphs, use_cache=use_cache, dibs_plus=True) 264 | 265 | particle_idc = particle_weights.argsort(descending=True).numpy() 266 | num_kept = 0 267 | resampled_particles = [] 268 | for i in particle_idc: 269 | if num_kept >= max_particles_to_keep or particle_weights[i] < threshold: 270 | self.graph_model.particles[i] = self.graph_model.sample_initial_particles(1).squeeze(0) 271 | resampled_particles.append(i) 272 | else: 273 | num_kept += 1 274 | 275 | print(f'Resampling particles {resampled_particles} according to weights {particle_weights.squeeze()} (kept ' 276 | f'{num_kept}/{max_particles_to_keep})') 277 | 278 | def update_latent_particles(self, num_steps: int = 100, log_interval: int = 25): 279 | optimizer = torch.optim.Adam([self.graph_model.particles], lr=1e-1) 280 | # alphas = 1. + 1e-2 * torch.arange(1, num_steps + 1).numpy() 281 | alphas = torch.ones(num_steps).numpy() 282 | betas = 1. + 25e-2 * torch.arange(1, num_steps + 1).numpy() 283 | # betas = 50. * torch.ones(num_steps).numpy() 284 | 285 | losses = [] 286 | for i in range(num_steps): 287 | self.mc_graphs, self.mc_adj_mats = self.sample_mc_graphs(alphas[i]) 288 | 289 | optimizer.zero_grad() 290 | log_posterior_grads, unnormalized_log_posterior = self.estimate_score_function(alphas[i], betas[i]) 291 | 292 | bandwidth = self.graph_model.embedding_size * self.graph_model.num_nodes * 2. 293 | # bandwidth = 10. 294 | particle_similarities = self.graph_model.particle_similarities(bandwidth=bandwidth) 295 | sim_grads = torch.autograd.grad(particle_similarities.sum(), self.graph_model.particles)[0] 296 | particle_grads = torch.einsum('ab,bdef->adef', particle_similarities.detach(), 297 | log_posterior_grads) - sim_grads 298 | 299 | self.graph_model.particles.grad = -particle_grads / self.graph_model.num_particles 300 | 301 | optimizer.step() 302 | losses.append(-unnormalized_log_posterior.item()) 303 | 304 | if i > 50: 305 | change = (torch.tensor(losses[-4:-1]).mean() - torch.tensor(losses[-5:-2]).mean()).abs() 306 | if change < 1e-3: 307 | print(f'Stopping particle updates early in iteration {i}: improvement has stagnated...') 308 | break 309 | 310 | if log_interval > 0 and i % log_interval == 0: 311 | print(f'Step {i + 1} of {num_steps}, negative log posterior is {-unnormalized_log_posterior.item()}...', 312 | flush=True) 313 | 314 | return losses 315 | 316 | def estimate_score_function(self, alpha: float = 1., beta: float = 1.) 
-> Tuple[torch.Tensor, torch.Tensor]: 317 | num_particles, num_mc_graphs = self.mc_adj_mats.shape[0:2] 318 | 319 | # compute log prior p(Z) 320 | log_prior = self.graph_model.unnormalized_log_prior(alpha, beta).sum() 321 | 322 | # compute graph weights with baseline for variance reduction (p(D|G) - b) / p(D|Z) 323 | with torch.no_grad(): 324 | graph_mlls = self.compute_graph_mlls(self.mc_graphs, use_cache=True) 325 | log_normalization = graph_mlls.logsumexp(dim=1) 326 | particle_mlls = log_normalization - math.log(num_mc_graphs) 327 | graph_weights = (graph_mlls - log_normalization.unsqueeze(1)).exp() 328 | 329 | baseline = torch.ones(num_particles, 1) / num_mc_graphs 330 | 331 | # compute log generative probabilities p(G|Z) 332 | log_generative_probs = self.graph_model.log_generative_prob(self.mc_adj_mats, alpha) 333 | tmp = log_prior + ((graph_weights - baseline) * log_generative_probs).sum() 334 | score_func = torch.autograd.grad(tmp, self.graph_model.particles)[0] 335 | 336 | unnormalized_log_posterior = (log_prior + particle_mlls.sum()) / num_particles 337 | return score_func, unnormalized_log_posterior 338 | 339 | def sample_mc_graphs(self, alpha: float = 1., set_data=False, num_graphs: int = None, only_dags=False): 340 | num_graphs = self.num_mc_graphs if num_graphs is None else num_graphs 341 | 342 | with torch.no_grad(): 343 | mc_graphs, mc_adj_mats = self.graph_model.sample_graphs(num_graphs, alpha) 344 | if only_dags: 345 | self.graph_model.dagify_graphs(mc_graphs, mc_adj_mats) 346 | 347 | self.init_graph_mechanisms(mc_graphs, set_data=set_data) 348 | 349 | return mc_graphs, mc_adj_mats 350 | 351 | def init_graph_mechanisms(self, graphs: List[List[nx.DiGraph]], set_data=False): 352 | # initialize mechanisms 353 | num_particles = len(graphs) 354 | time = len(self.experiments) 355 | initialized_mechanisms = set() 356 | for i in range(num_particles): 357 | for graph in graphs[i]: 358 | keys = self.mechanism_model.init_mechanisms(graph, init_time=time) 359 | initialized_mechanisms.update(keys) 360 | initialized_mechanisms = list(initialized_mechanisms) 361 | 362 | # update mechanism hyperparameters 363 | hyperparam_update_interval = 1 if time < 6 else (3 if time < 10 else 5) 364 | if hyperparam_update_interval > 0: 365 | update_time = 1 if time <= 1 else (time // hyperparam_update_interval) * hyperparam_update_interval 366 | 367 | self.mechanism_model.update_gp_hyperparameters(update_time, self.experiments, set_data, 368 | initialized_mechanisms) 369 | 370 | def compute_graph_mlls(self, graphs: List[List[nx.DiGraph]], experiments: List[Experiment] = None, prior_mode=True, 371 | use_cache=True): 372 | num_particles = len(graphs) 373 | num_mc_graphs = len(graphs[0]) 374 | experiments = self.experiments if experiments is None else experiments 375 | graph_mlls = [self.mechanism_model.mll(experiments, graph, prior_mode=prior_mode, use_cache=use_cache) for i in 376 | range(num_particles) for graph in graphs[i]] 377 | graph_mlls = torch.stack(graph_mlls).view(num_particles, num_mc_graphs) 378 | return graph_mlls 379 | 380 | def graph_posterior_expectation(self, func: Callable[[nx.DiGraph], torch.Tensor], use_cache=True, logspace=False): 381 | num_particles, num_mc_graphs = self.mc_adj_mats.shape[0:2] 382 | 383 | # compute function values 384 | func_values = [func(graph) for i in range(num_particles) for graph in self.mc_graphs[i]] 385 | func_output_shape = func_values[0].shape 386 | func_output_dim = len(func_output_shape) 387 | func_values = torch.stack(func_values).view(num_particles, 
num_mc_graphs, *func_output_shape) 388 | 389 | # compute expectation 390 | if logspace: 391 | log_graph_weights, log_particle_weights = self.compute_importance_weights(self.mc_graphs, 392 | use_cache=use_cache, 393 | log_weights=True) 394 | log_graph_weights = log_graph_weights.view(num_particles, num_mc_graphs, *([1] * func_output_dim)) 395 | log_particle_weights = log_particle_weights.view(num_particles, *([1] * func_output_dim)) 396 | expected_value = (log_graph_weights + func_values).logsumexp(dim=1) 397 | expected_value = (log_particle_weights + expected_value).logsumexp(dim=0) 398 | return expected_value 399 | 400 | graph_weights, particle_weights = self.compute_importance_weights(self.mc_graphs, use_cache=use_cache) 401 | graph_weights = graph_weights.view(num_particles, num_mc_graphs, *([1] * func_output_dim)) 402 | particle_weights = particle_weights.view(num_particles, *([1] * func_output_dim)) 403 | 404 | expected_value = particle_weights @ (graph_weights * func_values).sum(dim=1) 405 | return expected_value 406 | 407 | def compute_posterior_edge_probs(self, use_cache=True): 408 | num_particles, num_mc_graphs = self.mc_adj_mats.shape[0:2] 409 | 410 | # compute expectation 411 | graph_weights, particle_weights = self.compute_importance_weights(self.mc_graphs, use_cache=use_cache) 412 | 413 | posterior_edge_probs = torch.zeros(self.env.num_nodes, self.env.num_nodes) 414 | for i in range(self.env.num_nodes): 415 | for j in range(self.env.num_nodes): 416 | particle_edge_probs = torch.zeros(num_particles) 417 | for particle_idx in range(num_particles): 418 | particle_edge_probs[particle_idx] = sum([graph_weights[particle_idx, graph_idx] for graph_idx 419 | in range(num_mc_graphs) if 420 | self.mc_adj_mats[particle_idx, graph_idx, i, j].bool()]) 421 | 422 | posterior_edge_probs[i, j] = particle_weights @ particle_edge_probs 423 | 424 | return posterior_edge_probs 425 | 426 | def compute_graph_log_posterior(self, graph: nx.DiGraph, alpha: float = 1.): 427 | num_particles = len(self.mc_graphs) 428 | num_mc_graphs = len(self.mc_graphs[0]) 429 | 430 | self.init_graph_mechanisms([[graph]]) 431 | with torch.no_grad(): 432 | graph_mlls = self.compute_graph_mlls(self.mc_graphs) 433 | log_prior = self.graph_model.unnormalized_log_prior(beta=50.) 
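# The estimate below mirrors compute_importance_weights: each particle's marginal log-likelihood is
# approximated by the log-mean-exp of its Monte Carlo graph MLLs, and particles are weighted by
# softmax(log prior + particle MLL) when dibs_plus is enabled (uniformly otherwise). The query graph's
# log posterior then combines its generative log-probability under each particle (with the particle MLLs
# divided out inside the logsumexp) with these weights, plus the graph's own marginal log-likelihood on
# the collected experiments.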
434 | log_normalization = graph_mlls.logsumexp(dim=1) 435 | particle_mlls = log_normalization - math.log(num_mc_graphs) 436 | log_particle_weights = log_softmax(log_prior + particle_mlls, dim=0) if self.dibs_plus else \ 437 | -torch.tensor(num_particles).log() 438 | 439 | adj_mat = self.graph_model.graph_to_adj_mat(graph).expand(num_particles, 1, -1, -1) 440 | log_generative_probs = self.graph_model.log_generative_prob(adj_mat, alpha).squeeze() 441 | tmp = (log_generative_probs - particle_mlls + log_particle_weights).logsumexp(dim=0) 442 | 443 | log_graph_posterior = tmp + self.mechanism_model.mll(self.experiments, graph, prior_mode=True) 444 | 445 | return log_graph_posterior 446 | 447 | def compute_importance_weights(self, mc_graphs, use_cache=False, beta=50., log_weights: bool = False, 448 | dibs_plus: bool = None): 449 | if dibs_plus is None: 450 | dibs_plus = self.dibs_plus 451 | 452 | num_particles = len(mc_graphs) 453 | num_mc_graphs = len(mc_graphs[0]) 454 | graph_mlls = self.compute_graph_mlls(mc_graphs, use_cache=use_cache) 455 | log_normalization = graph_mlls.logsumexp(dim=1) 456 | log_graph_weights = (graph_mlls - log_normalization.unsqueeze(1)) 457 | if dibs_plus: 458 | log_particle_prior = self.graph_model.unnormalized_log_prior(beta=beta) 459 | particle_mlls = log_normalization - math.log(num_mc_graphs) 460 | log_particle_weights = log_softmax(log_particle_prior + particle_mlls, dim=0) 461 | else: 462 | log_particle_weights = -torch.tensor(num_particles).log() * torch.ones(num_particles) 463 | 464 | if log_weights: 465 | return log_graph_weights, log_particle_weights 466 | return log_graph_weights.exp(), log_particle_weights.exp() 467 | 468 | def param_dict(self): 469 | params = super().param_dict() 470 | params.update({'num_particles': self.num_particles, 471 | 'num_mc_graphs': self.num_mc_graphs, 472 | 'embedding_size': self.embedding_size, 473 | 'dibs_plus': self.dibs_plus, 474 | 'linear': self.linear, 475 | 'mechanism_model_params': self.mechanism_model.param_dict(), 476 | 'graph_model_params': self.graph_model.param_dict()}) 477 | return params 478 | 479 | def load_param_dict(self, param_dict): 480 | super().load_param_dict(param_dict) 481 | self.num_particles = param_dict['num_particles'] 482 | self.num_mc_graphs = param_dict['num_mc_graphs'] 483 | self.embedding_size = param_dict['embedding_size'] 484 | self.dibs_plus = param_dict['dibs_plus'] 485 | self.linear = param_dict['linear'] 486 | self.mechanism_model.load_param_dict(param_dict['mechanism_model_params']) 487 | self.graph_model.load_param_dict(param_dict['graph_model_params']) 488 | 489 | def save(self, path): 490 | torch.save(self.param_dict(), path) 491 | 492 | @classmethod 493 | def load(cls, path, num_workers: int = 1): 494 | param_dict = torch.load(path) 495 | 496 | env_param_dict = param_dict['env_param_dict'] 497 | env = Environment(env_param_dict['num_nodes'], mechanism_model=None, num_test_samples_per_intervention=0, 498 | num_test_queries=0, graph=env_param_dict['graph']) 499 | env.load_param_dict(env_param_dict) 500 | 501 | abci = ABCIDiBSGP(env, param_dict['policy'], param_dict['num_particles'], 502 | param_dict['num_mc_graphs'], param_dict['embedding_size'], num_workers, 503 | param_dict['dibs_plus'], param_dict['linear']) 504 | abci.load_param_dict(param_dict) 505 | return abci 506 | --------------------------------------------------------------------------------