├── .gitignore ├── llm_cache.pickle ├── text-davinci-002_llm_cache.pickle ├── text-davinci-003_llm_cache.pickle ├── requirements.txt ├── codebooks ├── cancer.csv ├── asia.csv ├── child.csv ├── insurance.csv └── alarm.csv ├── utils ├── metrics.py ├── plotting.py ├── data_generation.py ├── download_datasets.py ├── dag_utils.py └── language_models.py ├── launch.sh ├── algo ├── global_scoring.py └── greedy_search.py ├── sweep.sh ├── models ├── noisy_expert.py ├── oracles.py └── priors.py ├── README.md └── main.py /.gitignore: -------------------------------------------------------------------------------- 1 | **/__pycache__/ 2 | figures/ 3 | _raw_bayesian_nets/ 4 | /wandb/ 5 | .DS_Store -------------------------------------------------------------------------------- /llm_cache.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StephLong614/Causal-disco-LLM-imperfect-experts/HEAD/llm_cache.pickle -------------------------------------------------------------------------------- /text-davinci-002_llm_cache.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StephLong614/Causal-disco-LLM-imperfect-experts/HEAD/text-davinci-002_llm_cache.pickle -------------------------------------------------------------------------------- /text-davinci-003_llm_cache.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StephLong614/Causal-disco-LLM-imperfect-experts/HEAD/text-davinci-003_llm_cache.pickle -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | bnlearn==0.7.14 2 | causaldag==0.1a163 3 | cdt==0.6.0 4 | networkx==3.0 5 | openai==0.27.2 6 | pandas==1.5.3 7 | pymed==0.8.9 8 | seaborn==0.12.2 9 | wandb==0.15.2 -------------------------------------------------------------------------------- /codebooks/cancer.csv: -------------------------------------------------------------------------------- 1 | node,var_name,var_description 2 | 1,pollution,level of pollution 3 | 2,smoker,smoking status 4 | 3,cancer,presence of cancer 5 | 4,xray,xray 6 | 5,dysponea,laboured breathing 7 | -------------------------------------------------------------------------------- /codebooks/asia.csv: -------------------------------------------------------------------------------- 1 | node,var_name,var_description 2 | 1,asia,visited Asia 3 | 2,tub,tuberculosis 4 | 3,smoke,smoking cigarettes 5 | 4,lung,lung cancer 6 | 5,bronc,bronchitis 7 | 6,either,individual has either tuberculosis or lung cancer 8 | 7,xray,positive xray 9 | 8,dysp,"dyspnoae, laboured breathing " 10 | -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | from cdt.metrics import SHD 2 | import networkx as nx 3 | import numpy as np 4 | from utils.dag_utils import list_of_tuples_to_digraph 5 | 6 | def get_mec_shd(true_G, mec): 7 | """ 8 | the graphs need to be ordered to be comparable 9 | """ 10 | target = nx.to_numpy_array(true_G) 11 | pred = np.stack([nx.to_numpy_array(list_of_tuples_to_digraph(dag)) for dag in mec], -1).sum(-1) 12 | pred = np.clip(pred, 0, 1) 13 | shd = SHD(target, pred, double_for_anticausal=False) 14 | return shd, pred 
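`get_mec_shd` above treats a MEC as a list of candidate DAGs (each a list of arc tuples), collapses them into a single union adjacency matrix, and compares that union against the ground-truth graph. A minimal usage sketch, assuming the repository root is on the Python path; the three-node chain and the two-DAG MEC below are illustrative assumptions, not repository data:

```python
import networkx as nx

from utils.metrics import get_mec_shd

# Ground-truth DAG; nodes are inserted in sorted order because get_mec_shd
# compares adjacency matrices positionally ("the graphs need to be ordered").
true_G = nx.DiGraph([("A", "B"), ("B", "C")])

# A MEC is represented as a list of DAGs, each DAG being a list of arc tuples.
mec = [[("A", "B"), ("B", "C")],
       [("B", "A"), ("B", "C")]]

shd, union_adj = get_mec_shd(true_G, mec)
print(shd)        # expected to be 1 here: the union contains an extra reversed A-B arc
print(union_adj)  # 0/1 union adjacency matrix over the MEC
```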
-------------------------------------------------------------------------------- /utils/plotting.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import networkx as nx 3 | import numpy as np 4 | import os 5 | import seaborn as sns 6 | 7 | 8 | def plot_heatmap(g, lbls, dataset='', method=None, name='', base_dir='figures/'): 9 | dir_ = base_dir + dataset + '/' 10 | if method: 11 | dir_ = dir_ + method + '/' 12 | 13 | if not os.path.exists(dir_): 14 | os.makedirs(dir_) 15 | 16 | ax = sns.heatmap(g, square=True) 17 | # ax.set_xticks(range(len(lbls))) 18 | ax.set_xticklabels(lbls, rotation=90) 19 | # ax.set_yticks(range(len(lbls))) 20 | ax.set_yticklabels(lbls, rotation=0) 21 | plt.tight_layout() 22 | plt.savefig(dir_ + name) 23 | plt.close() 24 | 25 | -------------------------------------------------------------------------------- /launch.sh: -------------------------------------------------------------------------------- 1 | for data in child alarm asia cancer insurance; do 2 | for prior in 0 1; do 3 | for epsilon in 0.01 0.05 0.1; do 4 | python main.py --dataset $data --tabular 1 --epsilon $epsilon --algo global_scoring --uniform-prior $prior 5 | done 6 | python main.py --dataset $data --tabular 0 --algo global_scoring --uniform-prior $prior 7 | for tolerance in 0.1 0.25 0.5; do 8 | python main.py --dataset $data --tolerance $tolerance --tabular 0 --algo greedy --uniform-prior $prior 9 | for epsilon in 0.01 0.05 0.1; do 10 | python main.py --dataset $data --tolerance $tolerance --tabular 1 --epsilon $epsilon --algo greedy --uniform-prior $prior 11 | done 12 | done 13 | done 14 | done -------------------------------------------------------------------------------- /algo/global_scoring.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from utils.dag_utils import get_mec 4 | 5 | def global_scoring(observed_arcs, model, cpdag, undirected_edges, **kwargs): 6 | 7 | all_scores = [] 8 | 9 | mec = get_mec(cpdag) 10 | for dag in mec: 11 | score, denom = 0, 0 12 | 13 | # only score edges that are not yet determined 14 | for edge in (set(dag) - cpdag.arcs): 15 | # score += int(edge in observed_arcs) 16 | score += int(model([edge], [edge]) > 0.5) 17 | denom += 1 18 | 19 | all_scores.append(score/denom) 20 | top_indices = np.argwhere(all_scores == np.amax(all_scores)).flatten() 21 | 22 | mec = [mec[i] for i in top_indices] 23 | scores = [all_scores[i] for i in top_indices] 24 | 25 | return mec, dict(), np.mean(scores) -------------------------------------------------------------------------------- /sweep.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | d=$1 # dataset 3 | w=$2 # wandb project 4 | for s in 965079 79707 239916 537973 953100 5 | do 6 | for a in "greedy_conf" 7 | do 8 | for t in 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1. 
9 | do 10 | for e in 0.1 0.3 11 | do 12 | python3 main.py --seed=$s --tol $t --prior mec --epsilon $e --algo $a --tabular --dataset $d --wandb --wandb-project $w 13 | done 14 | 15 | for e in "text-davinci-002" "text-davinci-003" 16 | do 17 | python3 main.py --seed=$s --tol $t --prior mec --algo $a --dataset $d --wandb --llm-engine $e --wandb-project $w 18 | 19 | done 20 | 21 | done 22 | done 23 | 24 | for e in "text-davinci-002" "text-davinci-003" 25 | do 26 | python3 main.py --seed=$s --prior mec --algo "naive" --dataset $d --wandb --llm-engine $e --wandb-project $w 27 | done 28 | done 29 | 30 | python3 main.py --algo "PC" --dataset $d --wandb --wandb-project $w 31 | -------------------------------------------------------------------------------- /codebooks/child.csv: -------------------------------------------------------------------------------- 1 | node,var_name,var_description 2 | 1,BirthAsphyxia,lack of oxygen to the blood during the infant's birth 3 | 2,Disease,infant methemoglobinemia 4 | 3,Age,age of infant at disease presentation 5 | 4,LVH,thickening of the left ventricle 6 | 5,DuctFlow,blood flow across the ductus arteriosus 7 | 6,CardiacMixing,mixing of oxygenated and deoxygenated blood 8 | 7,LungParench,the state of the blood vessels in the lungs 9 | 8,LungFlow,low blood flow in the lungs 10 | 9,Sick,presence of an illness 11 | 10,HypoxiaDistribution,low oxygen areas equally distributed around the body 12 | 11,O2Hypoxia,hypoxia when breathing oxygen 13 | 12,CO2,level of CO2 in the body 14 | 13,ChestXray,having a chest x-ray 15 | 14,Grunting ,grunting in infants 16 | 15,LVHreport,report of having LVH 17 | 16,LowerBodO2,level of oxygen in the lower body 18 | 17,RightQuadO2,level of oxygen in the right up quadricep muscule 19 | 18,CO2Report,a document reporting high level of CO2 levels in blood 20 | 19,XrayReport,lung excessively filled with blood 21 | 20,GruntingReport,report of infant grunting 22 | -------------------------------------------------------------------------------- /codebooks/insurance.csv: -------------------------------------------------------------------------------- 1 | node,var_name,var_description 2 | 1,Age,age 3 | 2,SocioEcon,socioeconomic status 4 | 3,RiskAversion,being risk averse 5 | 4,GoodStudent,being a good student driver 6 | 5,SeniorTrain,received additional driving training 7 | 6,DrivingSkill,driving skill 8 | 7,MedCost,cost of medical treatment 9 | 8,OtherCar,being involved with other cars in the accident 10 | 9,MakeModel,owning a sport car 11 | 10,VehicleYear,year of vehicle 12 | 11,HomeBase,neighbourhood type 13 | 12,AntiTheft,car has anti-theft 14 | 13,DrivHist,driving history 15 | 14,DrivQuality,driving quality 16 | 15,Airbag,airbag 17 | 16,Antilock,anti-lock 18 | 17,RuggedAuto,ruggedness of the car 19 | 18,CarValue,value of the car 20 | 19,Mileage,how much mileage is on the car 21 | 20,Accident,severity of the accident 22 | 21,Cushioning,quality of cushioning in car 23 | 22,Theft,theft occurred on the car 24 | 23,ILiCost,inspection cost 25 | 24,OtherCarCost,cost of the other cars 26 | 25,ThisCarDam,damage to the car 27 | 26,ThisCarCost,costs for the insured car 28 | 27,PropCost,ratio of the cost for the two cars 29 | -------------------------------------------------------------------------------- /utils/data_generation.py: -------------------------------------------------------------------------------- 1 | import bnlearn as bn 2 | import networkx as nx 3 | import numpy as np 4 | import os 5 | import pprint 6 | import sys 7 | 8 | from 
utils.dag_utils import order_graph 9 | 10 | pp = pprint.PrettyPrinter(width=82, compact=True) 11 | 12 | # Utility functions to mute printing done in BNLearn 13 | class HiddenPrints: 14 | def __enter__(self): 15 | self._original_stdout = sys.stdout 16 | sys.stdout = open(os.devnull, 'w') 17 | 18 | def __exit__(self, exc_type, exc_val, exc_tb): 19 | sys.stdout.close() 20 | sys.stdout = self._original_stdout 21 | 22 | def generate_dataset(bn_path, n=1000): 23 | 24 | # Load DAG, probability tables, etc. 25 | with HiddenPrints(): 26 | model = bn.import_DAG(bn_path, verbose=1) 27 | 28 | G = nx.from_pandas_adjacency(model["adjmat"].astype(int), create_using=nx.DiGraph) 29 | 30 | # Sample data 31 | data = bn.sampling(model, n=n, verbose=1) 32 | 33 | # Label nodes in causal graph 34 | nx.relabel_nodes(G, dict(zip(range(len(G.nodes())), data.columns))) 35 | G = order_graph(G) 36 | 37 | return G, data -------------------------------------------------------------------------------- /codebooks/alarm.csv: -------------------------------------------------------------------------------- 1 | node,var_name,var_description 2 | 1,MINVOLSET,the amount of time using a breathing machine 3 | 2,VENTMACH,the intensity level of a breathing machine 4 | 3,disconnect,disconnection 5 | 4,venttube,breathing tube 6 | 5,kinkedtube,kinked chest tube 7 | 6,intubation,intubation 8 | 7,PULMEMBOLUS,sudden blockage in the pulmonary arteries 9 | 8,ventlung,lung ventilation 10 | 9,press,breathing pressure 11 | 10,shunt,shunt - normal and high 12 | 11,PAP,pulmonary artery pressure 13 | 12,FIO2,high concentration of oxygen in the gas mixture 14 | 13,minvol,minute volume 15 | 14,ventalv,alveolar ventilation 16 | 15,ANAPHYLAXIS,anaphylaxis 17 | 16,pvsat,pulmonary artery oxygen saturation 18 | 17,artCO2,arterial CO2 19 | 18,TPR,total peripheral resistance 20 | 19,insuffanesth,insufficient anesthesia 21 | 20,SaO2,oxygen saturation 22 | 21,expCO2,expelled CO2 23 | 22,LVFAILURE,left ventricular failure 24 | 23,hypovolemia,hypovolemia 25 | 24,catechol,catecholamine 26 | 25,HISTORY,previous medical history 27 | 26,lvedvolume,left ventricular end-diastolic volume 28 | 27,strokevolume,stroke volume 29 | 28,errlowoutput,error low output 30 | 29,HR,heart rate 31 | 30,errcauter,error cauter 32 | 31,PCWP,pulmonary capillary wedge pressure 33 | 32,CVP,central venous pressure 34 | 33,CO,cardiac output 35 | 34,HRBP,heart rate blood pressure 36 | 35,HRsat,oxygen saturation 37 | 36,HREKG,Heart rate displayed on EKG monitor 38 | 37,BP,blood pressure 39 | -------------------------------------------------------------------------------- /utils/download_datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | def download_datasets(): 4 | DATA_URL = \ 5 | [ 6 | # Small 7 | "https://www.bnlearn.com/bnrepository/sachs/sachs.bif.gz", 8 | "https://www.bnlearn.com/bnrepository/asia/asia.bif.gz", 9 | "https://www.bnlearn.com/bnrepository/cancer/cancer.bif.gz", 10 | 11 | # Medium 12 | "https://www.bnlearn.com/bnrepository/alarm/alarm.bif.gz", 13 | "https://www.bnlearn.com/bnrepository/barley/barley.bif.gz", 14 | "https://www.bnlearn.com/bnrepository/child/child.bif.gz", 15 | "https://www.bnlearn.com/bnrepository/insurance/insurance.bif.gz", 16 | "https://www.bnlearn.com/bnrepository/mildew/mildew.bif.gz", 17 | "https://www.bnlearn.com/bnrepository/water/water.bif.gz", 18 | 19 | # Large 20 | "https://www.bnlearn.com/bnrepository/hailfinder/hailfinder.bif.gz", 21 | 
"https://www.bnlearn.com/bnrepository/hepar2/hepar2.bif.gz", 22 | "https://www.bnlearn.com/bnrepository/win95pts/win95pts.bif.gz" 23 | ] 24 | DATA_URL = {os.path.basename(u).replace(".bif.gz", ""): u for u in DATA_URL} 25 | 26 | os.makedirs("_raw_bayesian_nets", exist_ok=True) 27 | for name, u in DATA_URL.items(): 28 | os.system(f"wget {u} -q -O _raw_bayesian_nets/{os.path.basename(u)} && gunzip -f _raw_bayesian_nets/{os.path.basename(u)} && echo '{name}: success' || echo '{name}': FAILED!") 29 | 30 | SUPPORTED_DATASETS = {name: f"_raw_bayesian_nets/{name}.bif" for name in DATA_URL} 31 | -------------------------------------------------------------------------------- /utils/dag_utils.py: -------------------------------------------------------------------------------- 1 | from causaldag import DAG 2 | import networkx as nx 3 | import numpy as np 4 | 5 | def get_mec(G_cpdag): 6 | 7 | mec = [] 8 | 9 | for dag in G_cpdag.all_dags(): 10 | g = [edge for edge in dag] 11 | mec.append(g) 12 | 13 | return mec 14 | 15 | def get_undirected_edges(true_G, verbose=False): 16 | 17 | dag = DAG.from_nx(true_G) 18 | edges = dag.arcs - dag.cpdag().arcs 19 | 20 | if verbose: 21 | print("Unoriented edges: ", edges) 22 | 23 | return edges 24 | 25 | def get_decisions_from_mec(mec, undirected_edges): 26 | decisions = [] 27 | 28 | for edge in undirected_edges: 29 | node_i = edge[0] 30 | node_j = edge[1] 31 | i_j = np.sum([((node_i, node_j) in dag) for dag in mec]) 32 | j_i = np.sum([((node_j, node_i) in dag) for dag in mec]) 33 | # if i_j and j_i we don't have to make a decision 34 | if not (i_j and j_i): 35 | if i_j: 36 | decisions.append((node_i, node_j)) 37 | else: 38 | decisions.append((node_j, node_i)) 39 | 40 | return decisions 41 | 42 | def order_graph(graph): 43 | H = nx.DiGraph() 44 | #print(graph.nodes) 45 | H.add_nodes_from(sorted(graph.nodes(data=True))) 46 | H.add_edges_from(graph.edges(data=True)) 47 | return H 48 | 49 | 50 | def list_of_tuples_to_digraph(list_of_tuples): 51 | G = nx.DiGraph() 52 | # Add nodes best_graph 53 | for edge in list_of_tuples: 54 | node_i = edge[0] 55 | node_j = edge[1] 56 | G.add_edge(node_i, node_j) 57 | G = order_graph(G) 58 | return G 59 | 60 | def is_dag_in_mec(G, mec): 61 | 62 | for dag in mec: 63 | ans = True 64 | for edge in dag: 65 | if edge not in G.edges: 66 | ans = False 67 | break 68 | if ans: 69 | return 1. 70 | 71 | return 0. -------------------------------------------------------------------------------- /models/noisy_expert.py: -------------------------------------------------------------------------------- 1 | from itertools import product 2 | import numpy as np 3 | 4 | class NoisyExpert(object): 5 | def __init__(self, prior, likelihoods): 6 | """ 7 | A class for calculating the posterior probability of edge orientations (arcs) in a Markov 8 | Equivalence Class, based on the assumption that when queried for an edge the expert returns 9 | an orientation that depends only on the true orientation of that edge. 10 | Only edges that are undirected in the CPDAG are considered. 11 | 12 | Parameters 13 | ---------- 14 | 15 | prior : models.priors.Prior 16 | Provides the prior probability of any subset of arcs 17 | 18 | likelihoods : dict {arc: arc_likelihood} 19 | Likelihood of each observed arc 20 | """ 21 | 22 | self._prior = prior 23 | self._likelihoods = likelihoods 24 | 25 | def _partition_function(self, obs_arcs): 26 | 27 | ans = 0. 
28 | 29 | for prob, true_arcs in self._prior.enumerate(): 30 | 31 | # p(obs|true) 32 | ans += self.likelihood(obs_arcs, true_arcs) * prob 33 | 34 | return ans 35 | 36 | # likelihood of a set of orientations is factorizable P(O|E) 37 | def likelihood(self, obs_arcs, true_arcs): 38 | # TODO: check lists contain exactly one orientation of each edge 39 | 40 | ans = 1. 41 | for (x1, x2) in obs_arcs: 42 | 43 | if (x1, x2) in true_arcs: 44 | ans *= self._likelihoods[(x1, x2)] 45 | else: 46 | ans *= 1 - self._likelihoods[(x1, x2)] 47 | 48 | return ans 49 | 50 | # posterior probability of subset of possible arcs 51 | def posterior(self, obs_arcs, true_arcs): 52 | 53 | ans = 0. 54 | 55 | # if an edge is not oriented in the decision, need to marginalize over it 56 | margin_edges = list() 57 | for (x1, x2) in obs_arcs: 58 | 59 | if ((x1, x2) not in true_arcs) and ((x2, x1) not in true_arcs): 60 | margin_edges.append(((x1, x2), (x2, x1))) 61 | 62 | if len(margin_edges) > 0: 63 | 64 | # Cartesian product over possible orientations of edges 65 | for arcs in product(*margin_edges): 66 | ans += self.likelihood(obs_arcs, list(arcs) + true_arcs) * self._prior(list(arcs) + true_arcs) 67 | 68 | else: 69 | ans = self.likelihood(obs_arcs, true_arcs) * self._prior(true_arcs) 70 | 71 | return ans / self._partition_function(obs_arcs) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Causal Discovery using Language Models 2 | Stephanie Long, Alexandre Piché, Valentina Zantedeschi, Tibor Schuster, Alexandre Drouin (2023). [Structured Probabilistic Inference & Generative Modeling (SPIGM) workshop](https://spigmworkshop.github.io/). *International Conference in Machine Learning* (ICML). 3 | 4 | > Understanding the causal relationships that underlie a system is a fundamental prerequisite to accurate decision-making. In this work, we explore how expert knowledge can be used to improve the data-driven identification of causal graphs, beyond Markov equivalence classes. In doing so, we consider a setting where we can query an expert about the orientation of causal relationships between variables, but where the expert may provide erroneous information. We propose strategies for amending such expert knowledge based on consistency properties, e.g., acyclicity and conditional independencies in the equivalence class. We then report a case study, on real data, where a large language model is used as an imperfect expert. 
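The `NoisyExpert` model above (models/noisy_expert.py) is easiest to see on a toy problem. With a uniform prior over the two-DAG MEC of a single edge A–B, an expert report of A→B that is correct with probability 0.95 gives posterior P(A→B | report) = (0.95 × 0.5) / (0.95 × 0.5 + 0.05 × 0.5) = 0.95. A minimal sketch, assuming the repository root is on the Python path; the two-variable graph and the 0.95/0.05 reliabilities are illustrative assumptions:

```python
import networkx as nx
from causaldag import DAG

from models.noisy_expert import NoisyExpert
from models.priors import MECPrior

# Toy CPDAG with a single undirected edge A - B (its MEC contains A->B and B->A).
cpdag = DAG.from_nx(nx.DiGraph([("A", "B")])).cpdag()

# The expert reported A->B; the likelihood table says it is right with probability 0.95.
likelihoods = {("A", "B"): 0.95, ("B", "A"): 0.05}
model = NoisyExpert(MECPrior(cpdag), likelihoods)

print(model.posterior([("A", "B")], [("A", "B")]))  # ~0.95
```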
5 | 
6 | [[Paper]](https://arxiv.org/abs/2307.02390)
7 | 
8 | 
9 | ![greedy_conf_main](https://github.com/StephLong614/Causal-disco-LLM-imperfect-experts/assets/17014892/3f13bfb8-e125-4c4b-887b-3c576b1a4e01)
10 | 
11 | # Running experiments
12 | 
13 | To run our greedy algorithm with the S_risk strategy (selecting at each iteration the edge that leads to the lowest risk of excluding the true graph):
14 | ```bash
15 | python3 main.py --llm-engine text-davinci-002 --algo greedy_conf --dataset child --tol 0.1 --seed=965079
16 | ```
17 | 
18 | To run our greedy algorithm with the S_size strategy (selecting at each iteration the edge that leads to the smallest equivalence class):
19 | ```bash
20 | python3 main.py --llm-engine text-davinci-002 --algo greedy_mec --dataset child --tol 0.1 --seed=965079
21 | ```
22 | 
23 | You can use text-davinci-003 instead by specifying "--llm-engine text-davinci-003". To query the synthetic epsilon-expert rather than an LLM, specify "--tabular --epsilon <error-rate>".
24 | Use "--tol" to select a different tolerance level.
25 | 
26 | [OpenAI text-davinci models will be deprecated on Jan 4th 2024](https://platform.openai.com/docs/models/gpt-3-5).
27 | To reproduce our experiments, we cached the calls to the OpenAI API in [text-davinci-002_llm_cache.pickle](https://github.com/StephLong614/Causal-disco-LLM-imperfect-experts/blob/main/text-davinci-002_llm_cache.pickle) and [text-davinci-003_llm_cache.pickle](https://github.com/StephLong614/Causal-disco-LLM-imperfect-experts/blob/main/text-davinci-003_llm_cache.pickle) for the seeds: 965079 79707 239916 537973 953100.
28 | Our code automatically loads these pickles.
29 | 
30 | # Citing this work
31 | Please use the following BibTeX entry to cite this work:
32 | ```
33 | @misc{long2023causal,
34 |       title={Causal Discovery with Language Models as Imperfect Experts}, 
35 |       author={Stephanie Long and Alexandre Piché and Valentina Zantedeschi and Tibor Schuster and Alexandre Drouin},
36 |       year={2023},
37 |       eprint={2307.02390},
38 |       archivePrefix={arXiv},
39 |       primaryClass={cs.AI}
40 | }
41 | ```
42 | 
--------------------------------------------------------------------------------
/models/oracles.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | class BaseOracle(object):
4 |     """
5 |     Base class for an expert that can be queried to obtain edge orientations
6 | 
7 |     """
8 | 
9 |     def __init__(self):
10 |         super().__init__()
11 | 
12 |         self.likelihoods = dict()
13 | 
14 |     def decide(self, x1, x2):
15 |         """
16 |         Get edge orientation from expert
17 | 
18 |         Parameters:
19 |         -----------
20 |         x1: object
21 |             A node in the graph
22 |         x2: object
23 |             Another node in the graph
24 | 
25 |         Returns:
26 |         --------
27 |         orientation: tuple
28 |             The orientation of the edge (source, target) according to the expert.
29 | 
30 |         """
31 |         raise NotImplementedError()
32 | 
33 | class EpsilonOracle(BaseOracle):
34 |     """
35 |     An expert that randomly lies about the true edge orientations with some probability.
36 | 37 | """ 38 | 39 | def __init__(self, arcs, epsilon=0.05, random_state=np.random): 40 | """ 41 | Constructs the expert 42 | 43 | Parameters: 44 | ----------- 45 | arcs: list of tuples 46 | A list of the ground truth edge orientations 47 | epsilon: float, default: 0.05 48 | The probability with which the expert returns an incorrect orientation 49 | random_state: np.random.RandomState, default: np.random 50 | The random state to use for randomness 51 | 52 | """ 53 | self.arcs = arcs 54 | self.epsilon = epsilon 55 | self.random_state = random_state 56 | super().__init__() 57 | 58 | def decide(self, x1, x2): 59 | """ 60 | Get edge orientation from expert 61 | 62 | Parameters: 63 | ----------- 64 | x1: object 65 | A node in the graph 66 | x2: object 67 | Another node in the graph 68 | 69 | Returns: 70 | -------- 71 | orientation: tuple 72 | The orientation of the edge (source, target) according to the expert. 73 | Note that the expert lies with epsilon probability. 74 | 75 | """ 76 | if (x1, x2) not in self.arcs and (x2, x1) not in self.arcs: 77 | raise ValueError(f"Edge {x1}--{x2} not in graph.") 78 | 79 | if (x1, x2) in self.arcs: 80 | true_edge, false_edge = (x1, x2), (x2, x1) 81 | else: 82 | true_edge, false_edge = (x2, x1), (x1, x2) 83 | 84 | self.likelihoods[false_edge] = self.epsilon 85 | self.likelihoods[true_edge] = 1 - self.epsilon 86 | 87 | if self.random_state.rand() < self.epsilon: 88 | return false_edge 89 | 90 | else: 91 | return true_edge 92 | 93 | def decide_all(self): 94 | 95 | observations = [] 96 | for arc in self.arcs: 97 | 98 | # Get decision by expert 99 | observations.append(self.decide(*arc)) 100 | 101 | return observations -------------------------------------------------------------------------------- /algo/greedy_search.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import numpy as np 3 | 4 | from utils.dag_utils import get_decisions_from_mec, get_mec 5 | 6 | def greedy_search_mec_size(observed_arcs, model, cpdag, undirected_edges, tol=0.501): 7 | decisions = [] 8 | p_correct = 1. 9 | mec = get_mec(cpdag) 10 | 11 | while (p_correct > 1 - tol) and (len(mec) > 1): 12 | decision_scores_ = {} 13 | 14 | for decision_potential in observed_arcs: 15 | 16 | potential_new_mec = [dag for dag in mec if decision_potential in dag] 17 | resulting_decisions = get_decisions_from_mec(potential_new_mec, undirected_edges) 18 | dec = { 19 | 'resulting_decisions': resulting_decisions, 20 | 'probability': model(observed_arcs, resulting_decisions), 21 | 'mec_size': len(potential_new_mec) 22 | } 23 | 24 | if (p_correct * dec['probability'] > 1 - tol) and (len(potential_new_mec) > 0): 25 | decision_scores_[decision_potential] = dec 26 | 27 | if len(decision_scores_) > 0: 28 | decision_scores_ = sorted(decision_scores_.items(), key=lambda item: item[1]['mec_size'], reverse=False) 29 | decision_taken = decision_scores_[0] 30 | else: 31 | break 32 | 33 | decision_taken = decision_scores_[0] 34 | decisions = decision_taken[1]['resulting_decisions'] 35 | mec = [dag for dag in mec if decision_taken[0] in dag] 36 | p_correct *= decision_taken[1]['probability'] 37 | 38 | return mec, decisions, p_correct 39 | 40 | def greedy_search_confidence(observed_arcs, model, cpdag, undirected_edges, tol=0.501): 41 | decisions = [] 42 | p_correct = 1. 
43 | mec = get_mec(cpdag) 44 | possible_decisions = copy.copy(observed_arcs) 45 | 46 | while (p_correct > 1 - tol) and (len(mec) > 1): 47 | decision_scores_ = {} 48 | 49 | for decision_potential in possible_decisions: 50 | 51 | potential_new_mec = [dag for dag in mec if decision_potential in dag] 52 | resulting_decisions = get_decisions_from_mec(potential_new_mec, undirected_edges) 53 | dec = { 54 | 'resulting_decisions': resulting_decisions, 55 | 'probability': model(observed_arcs, resulting_decisions), 56 | 'mec_size': len(potential_new_mec) 57 | } 58 | 59 | if (p_correct * dec['probability'] > 1 - tol) and (len(potential_new_mec) > 0): 60 | decision_scores_[decision_potential] = dec 61 | 62 | if len(decision_scores_) > 0: 63 | decision_scores_ = sorted(decision_scores_.items(), key=lambda item: item[1]['probability'], reverse=True) 64 | decision_taken = decision_scores_[0] 65 | else: 66 | break 67 | 68 | decision_taken = decision_scores_[0] 69 | possible_decisions.remove(decision_taken[0]) 70 | 71 | decisions = decision_taken[1]['resulting_decisions'] 72 | mec = [dag for dag in mec if decision_taken[0] in dag] 73 | p_correct *= decision_taken[1]['probability'] 74 | 75 | return mec, decisions, p_correct 76 | 77 | get_cost = lambda p, size: np.log(p) - 0.5 * size 78 | 79 | def greedy_search_bic(observed_arcs, model, cpdag, undirected_edges, **kwargs): 80 | decisions = [] 81 | 82 | p_correct = 1. 83 | past_decision_score = -10000 84 | improvement = 1e-3 85 | mec = get_mec(cpdag) 86 | 87 | while improvement > 0: 88 | decision_scores = {} 89 | 90 | for decision_potential in observed_arcs: 91 | 92 | potential_new_mec = [dag for dag in mec if bool(decision_potential in dag)] 93 | resulting_decisions = get_decisions_from_mec(potential_new_mec, undirected_edges) 94 | dec = { 95 | 'resulting_decisions': resulting_decisions, 96 | 'probability': model(observed_arcs, resulting_decisions), 97 | 'mec_size': len(potential_new_mec) 98 | } 99 | 100 | dec['score'] = get_cost(p=dec['probability'], size=len(potential_new_mec)) 101 | 102 | if (dec['score'] - past_decision_score) > 0: 103 | decision_scores[decision_potential] = dec 104 | 105 | decision_scores_ = sorted(decision_scores.items(), key=lambda item: item[1]['score'], reverse=True) 106 | 107 | if len(decision_scores_) > 0: 108 | decision_taken = decision_scores_[0] 109 | else: 110 | break 111 | 112 | decisions = decision_taken[1]['resulting_decisions'] 113 | improvement = decision_taken[1]['score'] - past_decision_score 114 | past_decision_score = decision_taken[1]['score'] 115 | mec = [dag for dag in mec if bool(decision_taken[0] in dag)] 116 | p_correct *= decision_taken[1]['probability'] 117 | 118 | return mec, decisions, p_correct -------------------------------------------------------------------------------- /utils/language_models.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import openai 3 | import pickle 4 | import random 5 | 6 | from scipy.special import softmax 7 | from scipy.optimize import fsolve 8 | 9 | PROMPT_TEMPLATE = """ 10 | Among these two options which one is the most likely true: 11 | (A) {0} {2} {1} 12 | (B) {1} {2} {0} 13 | The answer is: 14 | """ 15 | 16 | OPTIONS = ['(A)', '(B)'] 17 | 18 | LOCK_TOKEN = ' (' 19 | 20 | VERBS = ["provokes", " triggers","causes", "leads to", "induces", "results in", "brings about", "yields", "generates", "initiates", "produces", "stimulates", "instigates", "fosters", "engenders", "promotes", "catalyzes", "gives rise to", "spurs", 
"sparks"] 21 | 22 | def get_prompt(edge, codebook, verb=None): 23 | 24 | node_i, node_j = edge 25 | long_name_node_i = codebook.loc[codebook['var_name']==node_i, 'var_description'].to_string(index=False) 26 | long_name_node_j = codebook.loc[codebook['var_name']==node_j, 'var_description'].to_string(index=False) 27 | 28 | if 'Series' in long_name_node_i: 29 | print(f"{node_i} is not defined") 30 | if 'Series' in long_name_node_j: 31 | print(f"{node_j} is not defined") 32 | 33 | if verb is None: 34 | verb = random.choice(VERBS) 35 | 36 | options = PROMPT_TEMPLATE.format(long_name_node_i, long_name_node_j, verb) 37 | 38 | return options 39 | 40 | def get_lms_probs(undirected_edges, codebook, tmp_scaling=1, engine='davinci-002'): 41 | """ 42 | return: dictionary of tuple and their likelihood of being wrong by the LM 43 | example {('Age', 'Disease'): 0.05, ...} 44 | """ 45 | 46 | gpt3_decision_probs = {} 47 | decisions = [] 48 | 49 | for edge in undirected_edges: 50 | 51 | log_scores = gpt3_scoring(edge, codebook, options=OPTIONS, lock_token=LOCK_TOKEN, engine=engine) 52 | scores = softmax(log_scores / tmp_scaling) 53 | 54 | 55 | gpt3_decision_probs[(edge[0], edge[1])] = scores[0] 56 | gpt3_decision_probs[(edge[1], edge[0])] = scores[1] 57 | 58 | if scores[0] > scores[1]: 59 | decisions.append((edge[0], edge[1])) 60 | else: 61 | decisions.append((edge[1], edge[0])) 62 | 63 | return gpt3_decision_probs, decisions 64 | 65 | def temperature_scaling(directed_edges, codebook, engine): 66 | err_scores = [] 67 | num_errs = 0 68 | 69 | for edge in directed_edges: 70 | # node_i -> node_j 71 | options = get_prompt(edge, codebook) 72 | 73 | log_scores = gpt3_scoring(options, options=OPTIONS, lock_token=LOCK_TOKEN, engine=engine) 74 | 75 | if log_scores[0] < log_scores[1]: 76 | num_errs += 1 77 | err_scores.append(log_scores[1]) 78 | print(edge) 79 | else: 80 | err_scores.append(log_scores[0]) 81 | 82 | estimated_error = num_errs / len(directed_edges) 83 | err_scores = np.array(err_scores) 84 | 85 | equation = lambda t: np.average(np.exp(err_scores / t) / (np.exp(err_scores / t) + np.exp((1 - err_scores) / t))) - estimated_error 86 | 87 | temperature = fsolve(equation, 1.) 
88 | print(np.average(np.exp(err_scores / temperature) / (np.exp(err_scores / temperature) + np.exp((1 - err_scores) / temperature)))) 89 | 90 | return float(temperature), estimated_error 91 | 92 | 93 | def gpt3_call(engine, edge, codebook, options, max_tokens=128, temperature=0, 94 | logprobs=1, echo=False, cache_file='llm_cache.pickle'): 95 | cache_file = engine + '_llm_cache.pickle' 96 | LLM_CACHE = {} 97 | try: 98 | with open(cache_file, 'rb') as f: 99 | LLM_CACHE = pickle.load(f) 100 | except: 101 | pass 102 | 103 | verbs = random.sample(VERBS, len(VERBS)) 104 | for verb in verbs: 105 | prompt = get_prompt(edge, codebook, verb) 106 | gpt3_prompt_options = [f"{prompt}{o}" for o in options] 107 | 108 | full_query = "" 109 | for p in gpt3_prompt_options: 110 | full_query += p 111 | 112 | id = tuple((engine, full_query, max_tokens, temperature, logprobs, echo)) 113 | if id in LLM_CACHE.keys(): 114 | response = LLM_CACHE[id] 115 | break 116 | 117 | # if ID is not in pickle (with any verb option) 118 | else: 119 | print('no cache hit, api call') 120 | response = openai.Completion.create(engine=engine, 121 | prompt=gpt3_prompt_options, 122 | max_tokens=max_tokens, 123 | temperature=temperature, 124 | logprobs=logprobs, 125 | echo=echo) 126 | LLM_CACHE[id] = response 127 | with open(cache_file, 'wb') as f: 128 | pickle.dump(LLM_CACHE, f) 129 | return response 130 | 131 | 132 | def gpt3_scoring(edge, codebook, options, engine="text-davinci-002", verbose=False, n_tokens_score=9999999999, lock_token=None, ): 133 | verbose and print("Scoring", len(options), "options") 134 | 135 | response = gpt3_call(engine, edge, codebook, options, max_tokens=0, logprobs=1, temperature=0, echo=True, ) 136 | scores = [] 137 | for option, choice in zip(options, response["choices"]): 138 | if lock_token is not None: 139 | n_tokens_score = choice["logprobs"]["tokens"][::-1].index(lock_token) 140 | tokens = choice["logprobs"]["tokens"][-n_tokens_score:] 141 | verbose and print("Tokens:", tokens) 142 | token_logprobs = choice["logprobs"]["token_logprobs"][-n_tokens_score:] 143 | total_logprob = 0 144 | denom = 0 145 | for token, token_logprob in zip(reversed(tokens), reversed(token_logprobs)): 146 | if token_logprob is not None: 147 | denom += 1 148 | total_logprob += token_logprob 149 | scores.append(total_logprob) 150 | return np.array(scores) 151 | 152 | 153 | if __name__ == '__main__': 154 | options = """ 155 | Options: 156 | (A) Cancer causes smoking 157 | (B) Smoking causes cancer 158 | The answer is: 159 | """ 160 | log_scores = gpt3_scoring(options, options=['(A)', '(B)'], lock_token=' (') 161 | scores = softmax(log_scores) 162 | print(scores) -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from causaldag import DAG 4 | 5 | import networkx as nx 6 | import os 7 | import pandas as pd 8 | import wandb 9 | import numpy as np 10 | import random 11 | 12 | from algo.greedy_search import greedy_search_mec_size, greedy_search_confidence, greedy_search_bic 13 | 14 | from models.noisy_expert import NoisyExpert 15 | from models.oracles import EpsilonOracle 16 | 17 | from utils.data_generation import generate_dataset 18 | from utils.dag_utils import get_undirected_edges, is_dag_in_mec, get_mec 19 | from utils.metrics import get_mec_shd 20 | from utils.language_models import get_lms_probs, temperature_scaling 21 | 22 | parser = 
argparse.ArgumentParser(description='Description of your program.') 23 | 24 | # Add arguments 25 | parser.add_argument('--algo', default="greedy_conf", choices=["greedy_mec", "greedy_conf", "greedy_bic", "global_scoring", "PC", "naive"], help='What algorithm to use') 26 | parser.add_argument('--dataset', default="child", type=str, help='What dataset to use') 27 | parser.add_argument('--tabular', default=False, action="store_true", help='Use tabular expert, else use gpt3') 28 | parser.add_argument('--prior', default="mec", choices=["mec", "independent"]) 29 | parser.add_argument('--probability', default="posterior", choices=["posterior", "prior", "likelihood"]) 30 | 31 | parser.add_argument('--wandb-project', default='noisy expert', type=str, help='Name of your wandb project') 32 | parser.add_argument('--llm-engine', default='text-davinci-002') 33 | parser.add_argument('--calibrate', default=False, action="store_true", help='Calibrate gpt3') 34 | 35 | 36 | parser.add_argument('--epsilon', default=0.05, type=float, help='expert error rate') 37 | parser.add_argument('-tol', '--tolerance', default=0.1, type=float, help='algorithm error tolerance') 38 | 39 | parser.add_argument('--seed', type=int, default=20230515, help='random seed') 40 | parser.add_argument('--verbose', default=False, action="store_true", help='For debugging purposes') 41 | parser.add_argument('--wandb', default=False, action="store_true", help='to log on wandb') 42 | 43 | def blindly_follow_expert(observed_arcs, model, cpdag, *args, **kwargs): 44 | return [list(observed_arcs) + list(cpdag.arcs)], observed_arcs, model(observed_arcs, observed_arcs) 45 | 46 | if __name__ == '__main__': 47 | 48 | args = parser.parse_args() 49 | 50 | wandb.init(config=args, 51 | project=args.wandb_project, 52 | mode=None if args.wandb else 'disabled' 53 | ) 54 | 55 | random.seed(args.seed) 56 | np.random.seed(args.seed) 57 | 58 | match args.algo: 59 | case "greedy_mec": 60 | algo = greedy_search_mec_size 61 | case "greedy_conf": 62 | algo = greedy_search_confidence 63 | case "greedy_bic": 64 | algo = greedy_search_bic 65 | args.tolerance = 1. 66 | 67 | case "global_scoring": 68 | from algo.global_scoring import global_scoring 69 | algo = global_scoring 70 | args.tolerance = 1. 71 | 72 | case "PC": 73 | algo = lambda a, b, cpdag, c, tol: (get_mec(cpdag), dict(), 1.) 74 | args.tolerance = 0. 75 | case "naive": 76 | algo = blindly_follow_expert 77 | args.tolerance = 1. 
78 | 79 | match args.prior: 80 | case "mec": 81 | from models.priors import MECPrior 82 | prior_type = MECPrior 83 | 84 | case "independent": 85 | from models.priors import IndependentPrior 86 | prior_type = IndependentPrior 87 | 88 | if not os.path.exists("_raw_bayesian_nets"): 89 | from utils.download_datasets import download_datasets 90 | download_datasets() 91 | 92 | print(args) 93 | 94 | true_G, _ = generate_dataset('_raw_bayesian_nets/' + args.dataset + '.bif') 95 | cpdag = DAG.from_nx(true_G).cpdag() 96 | 97 | undirected_edges = get_undirected_edges(true_G, verbose=args.verbose) 98 | 99 | if args.tabular: 100 | oracle = EpsilonOracle(undirected_edges, epsilon=args.epsilon) 101 | observations = oracle.decide_all() 102 | likelihoods = oracle.likelihoods 103 | 104 | else: 105 | try: 106 | codebook = pd.read_csv('codebooks/' + args.dataset + '.csv') 107 | except: 108 | print('cannot load the codebook') 109 | codebook = None 110 | 111 | if args.calibrate: 112 | tmp_scale, eps = temperature_scaling(cpdag.arcs, codebook, engine=args.llm_engine) 113 | print("LLM has %.3f error rate" % eps) 114 | else: 115 | tmp_scale = 1. 116 | 117 | likelihoods, observations = get_lms_probs(undirected_edges, codebook, tmp_scale, engine=args.llm_engine) 118 | 119 | print("\nTrue Orientations:", undirected_edges) 120 | print("\nOrientations given by the expert:", observations) 121 | print(likelihoods) 122 | prior = prior_type(cpdag) 123 | model = NoisyExpert(prior, likelihoods) 124 | 125 | match args.probability: 126 | case "posterior": 127 | prob_method = model.posterior 128 | 129 | case "likelihood": 130 | prob_method = model.likelihood 131 | 132 | case "prior": 133 | prob_method = lambda _, edges: prior(edges) 134 | 135 | new_mec, decisions, p_correct = algo(observations, prob_method, cpdag, likelihoods, tol=args.tolerance) 136 | 137 | if args.verbose: 138 | print("\nFinal MEC", new_mec) 139 | 140 | shd, learned_adj = get_mec_shd(true_G, new_mec) 141 | 142 | learned_G = nx.from_numpy_array(learned_adj, create_using=nx.DiGraph) 143 | learned_G = nx.relabel_nodes(learned_G, {i: n for i, n in zip(learned_G.nodes, true_G.nodes)}) 144 | 145 | diff = nx.difference(learned_G, true_G) 146 | print("\nFinal wrong orientations:", diff.edges) 147 | 148 | print('\nConfidence true DAG is in final MEC: %.3f' % p_correct) 149 | print("Final MEC's SHD: ", shd) 150 | print('MEC size: ', len(new_mec)) 151 | print('true-still-in-MEC: ', is_dag_in_mec(true_G, new_mec)) 152 | wandb.log({'mec size': len(new_mec), 153 | 'shd': shd, 154 | 'prob-correct': p_correct, 155 | 'true-still-in-MEC': is_dag_in_mec(true_G, new_mec)}) 156 | wandb.finish() -------------------------------------------------------------------------------- /models/priors.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | 3 | import numpy as np 4 | 5 | from itertools import chain, product 6 | 7 | class Prior(object): 8 | 9 | def __init__(self, cpdag): 10 | """ 11 | A class for calculating the prior probability of arcs in a Markov Equivalence Class. 12 | The probability corresponds to a uniform distribution over DAGs in the MEC. 13 | Hence, only edges that are undirected in the CPDAG are considered/useful. 14 | 15 | Parameters 16 | ---------- 17 | cpdag : causaldag.PDAG 18 | A CPDAG object representing the Markov equivalence class of DAGs. 
19 | 20 | """ 21 | 22 | self._cpdag = cpdag 23 | 24 | # Get all orientations for undirected CPDAG edges 25 | self.arcs = list( 26 | chain(*[[(x1, x2), (x2, x1)] for x1, x2 in cpdag.edges]) 27 | ) 28 | 29 | # Assign an ID to each orientation 30 | self._arc_idx = dict( 31 | zip(self.arcs, range(len(self.arcs))) 32 | ) 33 | 34 | # Prior table (n_orientations x n_dags_in_mec) 35 | self._prior = np.zeros( 36 | (len(self.arcs), self.support_size()) 37 | ) 38 | 39 | # Get occurence of each edge in all DAGs of the MEC 40 | for i, dag in enumerate(self._cpdag.all_dags()): 41 | for edge, idx in self._arc_idx.items(): 42 | if edge in dag: 43 | self._prior[idx, i] = 1 44 | 45 | @abstractmethod 46 | def enumerate(self): 47 | """ 48 | All possible complete orientations of edges and their probabilities. 49 | 50 | Returns 51 | ------- 52 | generator 53 | each item is a tuple: (probability, list of arcs) 54 | """ 55 | pass 56 | 57 | @abstractmethod 58 | def support_size(self): 59 | """ 60 | Return the number of possible combinations of arcs (with non-zero probability) 61 | """ 62 | pass 63 | 64 | @abstractmethod 65 | def __call__(self, arcs): 66 | """ 67 | Compute the probability of the given set of arcs in the equivalence class. 68 | 69 | Parameters 70 | ---------- 71 | arcs : list of tuples 72 | A list of arcs, each represented as a tuple (source, target), where source and target 73 | are the nodes in the graph. 74 | 75 | Returns 76 | ------- 77 | float 78 | The probability of the given set of arcs occurring in the equivalence class of DAGs, 79 | computed as the product of the occurrence of each edge in all possible DAGs, divided 80 | by the total number of possible DAGs in the equivalence class. 81 | 82 | """ 83 | pass 84 | 85 | class IndependentPrior(Prior): 86 | 87 | def __init__(self, cpdag, weights="uniform"): 88 | 89 | super(IndependentPrior, self).__init__(cpdag) 90 | 91 | self._arc_weights = np.ones(len(self.arcs)) 92 | 93 | if weights == "uniform": 94 | self._arc_weights /= 2 # exactly 50% chances for an arc of being oriented either way 95 | 96 | elif weights == "occurences": 97 | for arc, idx in self._arc_idx.items(): 98 | self._arc_weights[idx] = self._prior(arc) 99 | 100 | else: 101 | raise NotImplementedError 102 | 103 | def __call__(self, arcs): 104 | """ 105 | Compute the probability of the given set of arcs in the equivalence class, supposing independence. 106 | 107 | Parameters 108 | ---------- 109 | arcs : list of tuples 110 | A list of arcs, each represented as a tuple (source, target), where source and target 111 | are the nodes in the graph. 112 | 113 | Returns 114 | ------- 115 | float 116 | The probability of the given set of arcs as the product of the probability of each arc. 117 | 118 | """ 119 | return np.prod([self._arc_weights[self._arc_idx[e]] for e in arcs]) 120 | 121 | def enumerate(self): 122 | """ 123 | All possible complete orientations of CPDAG's edges. 124 | 125 | Returns 126 | ------- 127 | generator 128 | each item is a list of arcs. 129 | """ 130 | for complete_orientation in product(*[(self.arcs[i], self.arcs[i+1]) for i in range(0, len(self.arcs), 2)]): 131 | 132 | yield self(complete_orientation), complete_orientation 133 | 134 | def support_size(self): 135 | return 2 ** (len(self.arcs) // 2) 136 | 137 | class MECPrior(Prior): 138 | 139 | def __call__(self, arcs): 140 | """ 141 | Compute the probability of the given set of arcs in the equivalence class. 
142 | 143 | Parameters 144 | ---------- 145 | arcs : list of tuples 146 | A list of arcs, each represented as a tuple (source, target), where source and target 147 | are the nodes in the graph. 148 | 149 | Returns 150 | ------- 151 | float 152 | The probability of the given set of arcs occurring in the equivalence class of DAGs, 153 | computed as the product of the occurrence of each edge in all possible DAGs, divided 154 | by the total number of possible DAGs in the equivalence class. 155 | 156 | """ 157 | return ( 158 | np.vstack([self._prior[self._arc_idx[e]] for e in arcs]) 159 | .prod(axis=0) 160 | .sum() 161 | / self._prior.shape[1] 162 | ) 163 | 164 | def enumerate(self): 165 | """ 166 | All possible complete orientations of CPDAG's edges. 167 | 168 | Returns 169 | ------- 170 | generator 171 | each item is a list of arcs corresponding to one DAG in the MEC. 172 | """ 173 | for dag in self._cpdag.all_dags(): 174 | complete_orientation = list() 175 | 176 | for x1, x2 in dag: 177 | if self._cpdag.has_edge(x1, x2): 178 | complete_orientation.append((x1, x2)) 179 | 180 | # uniform distribution 181 | yield 1. / self.support_size(), complete_orientation 182 | 183 | def support_size(self): 184 | return len(self._cpdag.all_dags()) --------------------------------------------------------------------------------
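Putting the pieces together, here is a minimal end-to-end sketch in the spirit of main.py, assuming the repository root is on the Python path. The three-node chain, epsilon = 0.1, and tol = 0.1 are illustrative assumptions rather than values from the paper; mirroring main.py, the oracle's likelihood table is passed as the fourth argument of `greedy_search_confidence`.

```python
import networkx as nx
from causaldag import DAG

from algo.greedy_search import greedy_search_confidence
from models.noisy_expert import NoisyExpert
from models.oracles import EpsilonOracle
from models.priors import MECPrior
from utils.dag_utils import get_undirected_edges

# Toy ground truth: the chain A -> B -> C, whose CPDAG leaves both edges undirected.
true_G = nx.DiGraph([("A", "B"), ("B", "C")])
cpdag = DAG.from_nx(true_G).cpdag()
undirected_edges = get_undirected_edges(true_G)

# Simulated expert that flips each true orientation with probability 0.1.
oracle = EpsilonOracle(undirected_edges, epsilon=0.1)
observations = oracle.decide_all()

# Posterior over orientations, combining the MEC prior with the expert's answers.
model = NoisyExpert(MECPrior(cpdag), oracle.likelihoods)
mec, decisions, p_correct = greedy_search_confidence(
    observations, model.posterior, cpdag, oracle.likelihoods, tol=0.1)

print(len(mec), decisions, p_correct)
```

Here `mec` is the pruned equivalence class, `decisions` are the orientations committed to, and `p_correct` is the estimated probability that the true DAG is still contained in the pruned class.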