├── aux
│   ├── __init__.py
│   └── helper.py
├── models
│   ├── __init__.py
│   ├── ncm.py
│   ├── parameterizedSCM.py
│   ├── scm.py
│   └── tncm.py
├── .gitignore
├── media
│   ├── thumbnail.png
│   ├── Figure-App-NCM-Distr.pdf
│   ├── Fig-Schematic-Diff-SPNs.pdf
│   └── Figure-App-LTNCM-Distr.pdf
├── scripts
│   └── animate-plt.py
├── README.md
├── scm-dgp.py
├── exp2.py
├── ncm-est.py
├── exp1.py
└── tncm-est.py
/aux/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *DS_Store*
2 | datasets
3 | experiments
4 | *cache*
5 | 
--------------------------------------------------------------------------------
/media/thumbnail.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zecevic-matej/Not-All-Causal-Inference-is-the-Same/HEAD/media/thumbnail.png
--------------------------------------------------------------------------------
/media/Figure-App-NCM-Distr.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zecevic-matej/Not-All-Causal-Inference-is-the-Same/HEAD/media/Figure-App-NCM-Distr.pdf
--------------------------------------------------------------------------------
/media/Fig-Schematic-Diff-SPNs.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zecevic-matej/Not-All-Causal-Inference-is-the-Same/HEAD/media/Fig-Schematic-Diff-SPNs.pdf
--------------------------------------------------------------------------------
/media/Figure-App-LTNCM-Distr.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zecevic-matej/Not-All-Causal-Inference-is-the-Same/HEAD/media/Figure-App-LTNCM-Distr.pdf
--------------------------------------------------------------------------------
/scripts/animate-plt.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import numpy as np
3 | 
4 | plt.clf()
5 | plt.close()
6 | for i in range(10):
7 |     fig, axs = plt.subplots(2, len(ncm.V), sharex=True, sharey=True, figsize=(12, 5))  # assumes an `ncm` object (e.g. from ncm-est.py) is in scope
8 |     for a in axs.flatten():
9 |         a.bar([0,1],np.random.rand(2))
10 |         a.set_xlim(0,1)
11 |         a.set_ylim(0,1)
12 |     #plt.draw()
13 |     plt.pause(0.001)
14 |     #input("Press [enter] to continue.")
15 | plt.show()
16 | for i in range(100):
17 |     plt.close()
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ### Not All Causal Inference is the Same [Accepted to TMLR 2023]
2 | 
3 | Official code repository for reproducing the empirical section of the paper, i.e., Section 7.1: "‘Bonus:’ An Easy Solution to Speeding Up Mechanism Inference in SCM".
4 | 
5 | ![Thumbnail of Figure 3 from Paper](media/thumbnail.png)
6 | 
7 | ---
8 | 
9 | **Code Structure:**
10 | 
11 | * `aux` contains helper functions
12 | * `models` contains base modules such as neural nets and sum-product networks, as well as the actual (T)NCM implementations
13 | * `expX.py` reproduces an experiment as found in the paper (e.g. 
`exp1.py` reproduces the results from Figure 6 of the main paper) -------------------------------------------------------------------------------- /aux/helper.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | 3 | class Logger(object): 4 | def __init__(self, file, live_writing=False): 5 | self.terminal = sys.stdout 6 | if not os.path.isfile(file): 7 | with open(file, "w") as f: 8 | f.write('Start of File.\n\n') 9 | self.file = file 10 | self.live_writing = live_writing 11 | self.log = open(file, "a") 12 | self.no_writing = False 13 | 14 | def write(self, message): 15 | self.terminal.write(message) 16 | if not self.no_writing: 17 | if self.live_writing: 18 | self.log = open(self.file, 'a') 19 | self.log.write(message) 20 | self.log.close() 21 | else: 22 | if "End of File." in message: 23 | self.log.write(message) 24 | self.log.close() 25 | self.no_writing = True 26 | else: 27 | self.log.write(message) 28 | def flush(self): 29 | pass -------------------------------------------------------------------------------- /scm-dgp.py: -------------------------------------------------------------------------------- 1 | from models.scm import * 2 | import os 3 | 4 | np.random.seed(0) 5 | params = [np.round(np.random.uniform(0.1, 0.9, 4),decimals=1) for _ in range(5)] 6 | print(f'Random Parameterizations: {params}') 7 | scms = [BackdoorSCM] 8 | base_dir = './datasets/SCMs/' 9 | if not os.path.exists(base_dir): 10 | os.makedirs(base_dir) 11 | for scm in scms: 12 | for selected in range(len(params)): 13 | scm1 = scm(U_params=params[selected]) 14 | if selected == 0: 15 | print(f'>>> Starting with {scm1.name}\n ' 16 | f'- Adjacency:\n{scm1.adj}') 17 | n_samples = 10000 18 | scm1.sample(n_samples) 19 | ate, l2x1, l2x0 = scm1.ate(n_samples) 20 | base_name = f'{scm1.name}_{params[selected]}p_N{int(n_samples)}' 21 | scm1.l1.to_csv(os.path.join(base_dir, base_name + '_L1' + '.csv')) 22 | l2x1.to_csv(os.path.join(base_dir, base_name + '_doX1' + '.csv')) 23 | l2x0.to_csv(os.path.join(base_dir, base_name + '_doX0' + '.csv')) 24 | pd.DataFrame(np.round(np.array([ate])[:,np.newaxis],decimals=2)).to_csv(os.path.join(base_dir, base_name + '_ATE' + '.csv')) 25 | pd.DataFrame(scm1.adj).to_csv(os.path.join(base_dir, base_name + '_adj' + '.csv')) -------------------------------------------------------------------------------- /models/ncm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .parameterizedSCM import ParameterizedSCM 3 | 4 | 5 | class MLP(torch.nn.Module): 6 | def __init__(self, input_size, output_size, hidden_sizes): 7 | super().__init__() 8 | 9 | # define a simple MLP neural net 10 | self.net = [] 11 | hs = [input_size] + hidden_sizes + [output_size] 12 | for h0, h1 in zip(hs, hs[1:]): 13 | self.net.extend([ 14 | torch.nn.Linear(h0, h1), 15 | torch.nn.ReLU(), 16 | ]) 17 | self.net.pop() # pop the last ReLU for the output layer 18 | #print(f'Layer Sizes: {str([l.weight.shape if isinstance(l, torch.nn.Linear) else None for l in self.net])}') 19 | self.net = torch.nn.Sequential(*self.net) 20 | 21 | def forward(self, input): 22 | return self.net(input) 23 | 24 | 25 | class NCM(ParameterizedSCM): 26 | """ 27 | Classical NCM, own implementation following the Xia et al. 2021 paper. 
28 | """ 29 | 30 | def __init__(self, adj, scale=False): 31 | super(NCM, self).__init__(adj) 32 | for V in self.graph: 33 | V_name = self.i2n(V) 34 | pa_V = self.graph[V] 35 | #print(f'Variable {V_name} := ') 36 | hs = [10,10,10]#[2*scale for _ in range(3)] #[10*int(len(pa_V)+1) for _ in range(3)] if scale else [10,10,10] 37 | self.S.update({V_name: 38 | MLP(len(pa_V)+1, 1, hs) 39 | }) 40 | 41 | def params(self): 42 | return [item for sublist in [list(f.parameters()) for f in self.S.values()] for item in sublist] 43 | 44 | def forward(self, v, doX, Xi, samples, debug=False): 45 | 46 | # eine Perle der Codegeschichte - consistency check pure torch 47 | Consistency = torch.ones((samples, 1)) 48 | Consistency = torch.where(doX >= 0, 49 | torch.where(torch.tile(v[:,0].unsqueeze(1),(samples,1)) == doX, 1., 0.), Consistency) 50 | 51 | pVs = [] 52 | for V in self.topology: 53 | 54 | pa_V = self.graph[V] 55 | 56 | V_arg = torch.cat((*[torch.tile(v[:,pa].unsqueeze(1),(samples,1)) for pa in pa_V],self.U[self.i2n(V)](samples)),axis=1) 57 | 58 | pV = torch.sigmoid(self.S[self.i2n(V)](V_arg)) 59 | pV = pV * v[:, V].unsqueeze(1) + (torch.ones_like(v[:, V]) - v[:, V]).unsqueeze(1) * (1 - pV) 60 | 61 | # the intervention checking might do an extra run for the intervened node, not rly important tho 62 | # furthermore, it is not validated yet - and rn we mainly check for interventions on X, where we change 63 | # X,Y relationships 64 | pV = torch.where(torch.tensor(Xi) == torch.tensor(V), torch.where(doX >= 0, torch.ones((samples, 1)), pV), pV) 65 | 66 | pVs.append(pV) 67 | 68 | pV = torch.cat((Consistency, *pVs),axis=1) 69 | 70 | agg = lambda t: torch.mean(torch.prod(t,axis=1)) 71 | 72 | if debug:#all(doX != -1*torch.ones_like(doX)):#debug: 73 | import pdb; pdb.set_trace() 74 | 75 | ret = agg(pV) 76 | 77 | if torch.isnan(ret): 78 | import pdb; pdb.set_trace() 79 | 80 | return ret -------------------------------------------------------------------------------- /models/parameterizedSCM.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import itertools 4 | from abc import ABCMeta, abstractmethod 5 | 6 | 7 | class ParameterizedSCM(): 8 | __metaclass__ = ABCMeta 9 | def __init__(self, adj): 10 | alpha = 'XYZWABCDEFGHIJKLMNOPQRSTUV' 11 | self.i2n = lambda x: alpha[x] 12 | self.V = [] 13 | self.S = {} 14 | self.U = {} 15 | self.graph = {} 16 | for V in range(len(adj)): 17 | pa_V = list(np.where(adj[:,V])[0]) # assumes binary adjacency with row causes column 18 | self.graph.update({V: pa_V}) 19 | V_name = self.i2n(V) 20 | self.V.append(V_name) 21 | U_V = lambda bs: torch.rand(bs,1) # uniform [0,1) 22 | # torch.bernoulli(.25 * torch.ones(bs,1)) # bernoulli {0,1} 23 | self.U.update({V_name: U_V}) 24 | self.topologicalSort() 25 | 26 | def print_graph(self): 27 | print('The NCM models the following graph:') 28 | for k in self.graph: 29 | print(f'{[self.i2n(x) for x in self.graph[k]]} --> {self.i2n(k)}') 30 | 31 | def indices_to_names(self, indices): 32 | return [self.i2n(x) for x in indices] 33 | 34 | # A recursive function used by topologicalSort 35 | def topologicalSortUtil(self, v, visited, stack): 36 | 37 | # Mark the current node as visited. 
38 | visited[v] = True 39 | 40 | # Recur for all the vertices adjacent to this vertex 41 | for i in self.graph[v]: 42 | if visited[i] == False: 43 | self.topologicalSortUtil(i, visited, stack) 44 | 45 | # Push current vertex to stack which stores result 46 | stack.insert(0, v) 47 | 48 | # The function to do Topological Sort. It uses recursive 49 | # topologicalSortUtil() 50 | def topologicalSort(self): 51 | # Mark all the vertices as not visited 52 | visited = [False] * len(self.V) 53 | stack = [] 54 | 55 | # Call the recursive helper function to store Topological 56 | # Sort starting from all vertices one by one 57 | for i in range(len(self.V)): 58 | if visited[i] == False: 59 | self.topologicalSortUtil(i, visited, stack) 60 | 61 | self.topology = list(reversed(stack)) 62 | 63 | def compute_marginals(self, samples, doX=-1, Xi=-1, debug=False): 64 | pred_marginals = {} 65 | N = len(self.V) 66 | for ind_d in range(N): 67 | vals = [] 68 | for val in [0, 1]: 69 | domains = [[0, 1]] * (N - 1) 70 | domains.insert(ind_d, [val]) 71 | combinations = np.stack([x for x in itertools.product(*domains)]) 72 | p_comb = [] 73 | for ind, c in enumerate(combinations): 74 | # print(f'{ind}:\t{c}') 75 | c = torch.tensor(c,dtype=torch.float).unsqueeze(0) 76 | pC = self.forward(c, torch.tensor([doX]*samples).unsqueeze(1), Xi, samples, debug) 77 | # print(f'mean(p(c)) = {pC}') 78 | p_comb.append(pC) 79 | # print(f'Sum = {sum(p_comb)}\t [{p_comb}]') 80 | vals.append(sum(p_comb).item()) 81 | pred_marginals.update({ind_d: vals}) 82 | # print(f'Marginals =\n\t{pred_marginals}') 83 | if debug: 84 | import pdb; pdb.set_trace() 85 | return pred_marginals 86 | 87 | @abstractmethod 88 | def params(self): 89 | pass 90 | 91 | @abstractmethod 92 | def forward(self, v, doX, Xi, samples, debug=False): 93 | pass -------------------------------------------------------------------------------- /exp2.py: -------------------------------------------------------------------------------- 1 | # ATE Estimation and JSD 2 | 3 | from glob import glob 4 | import pickle 5 | import numpy as np 6 | import pandas as pd 7 | from scipy.spatial import distance 8 | import matplotlib.pyplot as plt 9 | import seaborn as sns 10 | sns.set_theme(style="whitegrid") 11 | 12 | class D(dict): 13 | def __missing__(self, key): 14 | self[key] = D() 15 | return self[key] 16 | 17 | p_ncm = "./experiments/NCM/20211012-210344/*/*/Marginals.pkl" 18 | p_tncm = "./experiments/TNCM/20211012-210317_TNCM/*/*/Marginals.pkl" 19 | p_gt_ate = "./datasets/SCMs/*_ATE.csv" 20 | p_gt_l1 = "./datasets/SCMs/*_L1.csv" 21 | p_gt_l2_dX0 = "./datasets/SCMs/*_doX0.csv" 22 | p_gt_l2_dX1 = "./datasets/SCMs/*_doX1.csv" 23 | 24 | d = D() 25 | for ind, p in enumerate(glob(p_ncm) + glob(p_tncm)): 26 | with open(p, "rb") as f: 27 | M = pickle.load(f) 28 | ncm_type = "TNCM" if "TNCM" in p else "NCM" 29 | scm_type = p.split("/")[-3].split("SCM")[0] 30 | params = p.split("/")[-3].split("_")[1].split("p")[0] 31 | seed = p.split("/")[-2].split("-")[1] 32 | d[ncm_type][scm_type][params][seed] = M 33 | print(f' {ind+1}/{len(p_ncm)+len(p_tncm)} ', end="\r", flush=True) 34 | d_gt_ate = D() 35 | for p in glob(p_gt_ate): 36 | scm_type = p.split("SCM")[1].split("s/")[1] 37 | params = p.split("p")[0].split("_")[1] 38 | d_gt_ate[scm_type][params] = pd.read_csv(p).iloc[0,1] 39 | 40 | def get_gt_distr(p_gt): 41 | d_gt = D() 42 | for p in glob(p_gt): 43 | scm_type = p.split("SCM")[1].split("s/")[1] 44 | params = p.split("p")[0].split("_")[1] 45 | d_gt[scm_type][params] = np.array(pd.read_csv(p))[:,1:] 46 | 
return d_gt 47 | d_gt_l1 = get_gt_distr(p_gt_l1) 48 | d_gt_l2_dX0 = get_gt_distr(p_gt_l2_dX0) 49 | d_gt_l2_dX1 = get_gt_distr(p_gt_l2_dX1) 50 | 51 | 52 | 53 | # violin plots for ATE 54 | #tips = sns.load_dataset("tips") 55 | #ax = sns.violinplot(x=tips["total_bill"]) 56 | fig, axs = plt.subplots(1,2,figsize=(10,5)) 57 | d_ate = D() 58 | for ind, ncm_type in enumerate(d.keys()): 59 | df = pd.DataFrame(columns=["SCM", "ATE Error"]) 60 | err_means = [] 61 | for scm_type in d[ncm_type].keys(): 62 | errs = [] 63 | for params in d[ncm_type][scm_type].keys(): 64 | tncm_pY1dX0 = np.mean([d[ncm_type][scm_type][params][seed]['L2_doX0'][ncm_type][1][1] for seed in d[ncm_type][scm_type][params].keys()]) 65 | tncm_pY1dX1 = np.mean([d[ncm_type][scm_type][params][seed]['L2_doX1'][ncm_type][1][1] for seed in d[ncm_type][scm_type][params].keys()]) 66 | tncm_ate = tncm_pY1dX1 - tncm_pY1dX0 67 | gt_ate = d_gt_ate[scm_type][params] 68 | err_ate = abs(gt_ate - tncm_ate) 69 | errs.append([scm_type, err_ate]) 70 | err_means.append(np.mean([t[1] for t in errs])) 71 | df = df.append(pd.DataFrame(errs, columns=["SCM", "ATE Error"])) 72 | d_ate[ncm_type] = df 73 | 74 | sns.violinplot(x="SCM", y="ATE Error", data=df, cut=0, ax=axs[ind]) 75 | maxes = np.array([np.max(g) for g in err_means]) 76 | minofmax = np.min(maxes) 77 | secondmaxofmax = np.partition(maxes.flatten(), -2)[-2] 78 | axs[0].set_ylim(0., .1) 79 | axs[1].set_ylim(0., .1) 80 | plt.show() 81 | 82 | 83 | # JSD table 84 | #distance.jensenshannon([1.0, 0.0, 0.0], [0.0, 1.0, 0.0], 2.0) 85 | d_distr = D() 86 | distr_cols = ["SCM", "JSD L1", "JSD L2 do(X=0)", "JSD L2 do(X=1)"] 87 | for ind, ncm_type in enumerate(d.keys()): 88 | df = pd.DataFrame(columns=distr_cols) 89 | for scm_type in d[ncm_type].keys(): 90 | errs = [] 91 | for params in d[ncm_type][scm_type].keys(): 92 | l1 = [d[ncm_type][scm_type][params][seed]['L1'][ncm_type] for seed in d[ncm_type][scm_type][params].keys()] 93 | l1 = np.hstack((np.mean(np.array(pd.DataFrame(l1).applymap(lambda x: x[0])),axis=0),np.mean(np.array(pd.DataFrame(l1).applymap(lambda x: x[1])),axis=0))) # prob vector with x1=0, x2=0...x1=1,...xn=1 94 | l2_dX0 = [d[ncm_type][scm_type][params][seed]['L2_doX0'][ncm_type] for seed in d[ncm_type][scm_type][params].keys()] 95 | l2_dX0 = np.hstack((np.mean(np.array(pd.DataFrame(l2_dX0).applymap(lambda x: x[0])),axis=0),np.mean(np.array(pd.DataFrame(l2_dX0).applymap(lambda x: x[1])),axis=0))) 96 | l2_dX1 = [d[ncm_type][scm_type][params][seed]['L2_doX1'][ncm_type] for seed in d[ncm_type][scm_type][params].keys()] 97 | l2_dX1 = np.hstack((np.mean(np.array(pd.DataFrame(l2_dX1).applymap(lambda x: x[0])),axis=0),np.mean(np.array(pd.DataFrame(l2_dX1).applymap(lambda x: x[1])),axis=0))) 98 | 99 | gt_l1 = np.mean(d_gt_l1[scm_type][params],axis=0) 100 | gt_l1 = np.hstack((1-gt_l1, gt_l1)) 101 | gt_l2_dX0 = np.mean(d_gt_l2_dX0[scm_type][params],axis=0) 102 | gt_l2_dX0 = np.hstack((1-gt_l2_dX0, gt_l2_dX0)) 103 | gt_l2_dX1 = np.mean(d_gt_l2_dX1[scm_type][params],axis=0) 104 | gt_l2_dX1 = np.hstack((1-gt_l2_dX1, gt_l2_dX1)) 105 | 106 | err_distr_l1 = distance.jensenshannon(l1, gt_l1, 2.0) 107 | err_distr_l2_dX0 = distance.jensenshannon(l2_dX0, gt_l2_dX0, 2.0) 108 | err_distr_l2_dX1 = distance.jensenshannon(l2_dX1, gt_l2_dX1, 2.0) 109 | errs.append([scm_type, err_distr_l1, err_distr_l2_dX0, err_distr_l2_dX1]) 110 | df = df.append(pd.DataFrame(errs, columns=distr_cols)) 111 | d_distr[ncm_type] = df 112 | for ncm_type in d.keys(): 113 | for scm_type in d[ncm_type].keys(): 114 | print(f"NCM: {ncm_type}\t 
SCM: {scm_type}\t\nMeans:\n{np.mean(d_distr[ncm_type].loc[d_distr[ncm_type]['SCM'] == scm_type])}") -------------------------------------------------------------------------------- /models/scm.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from abc import ABCMeta, abstractmethod 4 | 5 | class SCM(): 6 | __metaclass__ = ABCMeta 7 | def __init__(self, U_params=None): 8 | if U_params is None: 9 | U_params = np.round(np.random.uniform(0.1, 0.9, 4),decimals=1) 10 | print(f'U params, [U_X, U_Y, U_Z, U_W]={U_params}') 11 | self.U_X = lambda s: np.random.binomial(1,U_params[0],s) 12 | self.U_Y = lambda s: np.random.binomial(1,U_params[1],s) 13 | self.U_Z = lambda s: np.random.binomial(1,U_params[2],s) 14 | self.U_W = lambda s: np.random.binomial(1,U_params[3],s) 15 | self.X = self.Y = self.Z = self.W = None 16 | self.x = self.y = self.z = self.w = None 17 | self.l1 = self.l2 = None 18 | self.U_params = U_params 19 | 20 | @abstractmethod 21 | def topological_computation(self, n_samples, doX=None): 22 | ''' To override, return df with x,y,z,w''' 23 | pass 24 | 25 | def sample(self, n_samples, doX=None): 26 | df = self.topological_computation(n_samples, doX) 27 | if doX is not None: 28 | self.l2 = df 29 | print(f'Generated Interventional Data (L2) with {n_samples} samples where do(X={doX}).') 30 | else: 31 | self.l1 = df 32 | print(f'Generated Observational Data (L1) with {n_samples} samples.') 33 | 34 | def ate(self, n_samples, debug=False): 35 | # p(Y=1|do(X=1)) - p(Y=1|do(X=0)) 36 | self.sample(n_samples, doX=np.ones_like(n_samples)) 37 | l2x1 = self.l2 38 | self.sample(n_samples, doX=np.zeros_like(n_samples)) 39 | l2x0 = self.l2 40 | w = 1/n_samples 41 | ate = l2x1['y'].sum()*w - l2x0['y'].sum()*w 42 | print(f'Computed empirical ATE p(Y=1|do(X=1)) - p(Y=1|do(X=0)) = {ate:.2f} for {n_samples} samples.') 43 | self.l2 = None 44 | if debug: 45 | import pdb; pdb.set_trace() 46 | return ate, l2x1, l2x0 47 | 48 | class ChainSCM(SCM): 49 | # X -> Y -> Z -> W 50 | def __init__(self, U_params=None): 51 | super(ChainSCM, self).__init__(U_params) 52 | self.name = "ChainSCM" 53 | self.X = lambda u: u 54 | self.Y = lambda u, x: np.logical_and(u, x).astype(np.int64) 55 | self.Z = lambda u, y: np.logical_and(u, y).astype(np.int64) 56 | self.W = lambda u, z: np.logical_and(u, z).astype(np.int64) 57 | self.adj = np.array([[0, 1, 0, 0],[0, 0, 1, 0],[0, 0, 0, 1],[0, 0, 0, 0]],dtype=float) 58 | 59 | def topological_computation(self, n_samples, doX=None): 60 | if doX is not None: 61 | self.x = doX 62 | else: 63 | self.x = self.X(self.U_X(n_samples)) 64 | self.y = self.Y(self.U_Y(n_samples), self.x) 65 | self.z = self.Z(self.U_Z(n_samples), self.y) 66 | self.w = self.W(self.U_W(n_samples), self.z) 67 | df = pd.DataFrame.from_dict({'x': self.x, 'y': self.y, 'z': self.z, 'w': self.w}) 68 | return df 69 | 70 | class ColliderSCM(SCM): 71 | # Y -> Z <- X <- W 72 | def __init__(self, U_params=None): 73 | super(ColliderSCM, self).__init__(U_params) 74 | self.name = "ColliderSCM" 75 | self.X = lambda u, w: np.logical_and(u, w).astype(np.int64) 76 | self.Y = lambda u: u 77 | self.Z = lambda u, y, x: np.logical_or(np.logical_and(u, y), x).astype(np.int64) 78 | self.W = lambda u: u 79 | self.adj = np.array([[0, 0, 1, 0],[0, 0, 1, 0],[0, 0, 0, 0],[1, 0, 0, 0]],dtype=float) 80 | 81 | def topological_computation(self, n_samples, doX=None): 82 | self.w = self.W(self.U_W(n_samples)) 83 | if doX is not None: 84 | self.x = doX 85 | else: 86 | self.x 
= self.X(self.U_X(n_samples), self.w) 87 | self.y = self.Y(self.U_Y(n_samples)) 88 | self.z = self.Z(self.U_Z(n_samples), self.y, self.x) 89 | df = pd.DataFrame.from_dict({'x': self.x, 'y': self.y, 'z': self.z, 'w': self.w}) 90 | return df 91 | 92 | class ConfounderSCM(SCM): 93 | # Y <- Z -> X -> {W,Y} 94 | def __init__(self, U_params=None): 95 | super(ConfounderSCM, self).__init__(U_params) 96 | self.name = "ConfounderSCM" 97 | self.X = lambda u, z: np.logical_xor(u, z).astype(np.int64) # this right one violates positivity assumption, since Z=1 implies X=1 ALWAYS! np.logical_or(u, z).astype(np.int64) 98 | self.Y = lambda u, z, x: np.logical_xor(np.logical_and(u, x), np.logical_and(u, z)).astype(np.int64) 99 | self.Z = lambda u: u 100 | self.W = lambda u, x: np.logical_and(u, x).astype(np.int64) 101 | self.adj = np.array([[0, 1, 0, 1],[0, 0, 0, 0],[1, 1, 0, 0],[0, 0, 0, 0]],dtype=float) 102 | 103 | def topological_computation(self, n_samples, doX=None): 104 | self.z = self.Z(self.U_Z(n_samples)) 105 | if doX is not None: 106 | self.x = doX 107 | else: 108 | self.x = self.X(self.U_X(n_samples), self.z) 109 | self.y = self.Y(self.U_Y(n_samples), self.z, self.x) 110 | self.w = self.W(self.U_W(n_samples), self.x) 111 | df = pd.DataFrame.from_dict({'x': self.x, 'y': self.y, 'z': self.z, 'w': self.w}) 112 | return df 113 | 114 | class BackdoorSCM(SCM): 115 | # Y <- X <- Z -> W -> Y 116 | def __init__(self, U_params=None): 117 | super(BackdoorSCM, self).__init__(U_params) 118 | self.name = "BackdoorSCM" 119 | self.Z = lambda u: u 120 | self.X = lambda u, z: np.logical_xor(u, z).astype(np.int64) 121 | self.W = lambda u, z: np.logical_and(u, z).astype(np.int64) 122 | self.Y = lambda u, w, x: np.logical_and(np.logical_and(u, w), x).astype(np.int64) 123 | self.adj = np.array([[0, 1, 0, 0],[0, 0, 0, 0],[1, 0, 0, 1],[0, 1, 0, 0]],dtype=float) 124 | 125 | def topological_computation(self, n_samples, doX=None): 126 | self.z = self.Z(self.U_Z(n_samples)) 127 | if doX is not None: 128 | self.x = doX 129 | else: 130 | self.x = self.X(self.U_X(n_samples), self.z) 131 | self.w = self.W(self.U_W(n_samples), self.z) 132 | self.y = self.Y(self.U_Y(n_samples), self.w, self.x) 133 | df = pd.DataFrame.from_dict({'x': self.x, 'y': self.y, 'z': self.z, 'w': self.w}) 134 | return df -------------------------------------------------------------------------------- /models/tncm.py: -------------------------------------------------------------------------------- 1 | import torch, itertools 2 | import numpy as np 3 | from spn.algorithms.layerwise.layers import Product, Sum 4 | from spn.experiments.RandomSPNs_layerwise.distributions import RatNormal 5 | from spn.algorithms.layerwise.utils import SamplingContext 6 | from .parameterizedSCM import ParameterizedSCM 7 | #from models.EinsumNetwork import Graph, EinsumNetwork 8 | 9 | class SPN(torch.nn.Module): 10 | # layered sum-product network 11 | def __init__(self, D, C, K, C2=None): 12 | super().__init__() 13 | 14 | assert D%K == 0 15 | self.C2 = C2 16 | 17 | # Normal leaf layer, output shape: [N=?, D=2, C=5, R=1] 18 | self.leaf1 = RatNormal(in_features=D, out_channels=C) 19 | 20 | # Product layer, output shape: [N=?, D=1, C=5, R=1] 21 | self.p1 = Product(in_features=D, cardinality=K) 22 | 23 | # Sum layer, root node, output shape: [N=?, D=1, C=1, R=1] 24 | self.s1 = Sum(in_channels=C, in_features=int(D/K), out_channels=1) 25 | 26 | if C2 is not None: 27 | self.s1 = Sum(in_channels=C, in_features=int(D/K), out_channels=C2) 28 | self.leaf2 = RatNormal(in_features=D, 
out_channels=C) 29 | self.p2 = Product(in_features=D, cardinality=K) 30 | self.s2 = Sum(in_channels=C, in_features=int(D/K), out_channels=C2) 31 | self.pc = Product(in_features=2, cardinality=2) 32 | self.sc = Sum(in_channels=C2, in_features=1, out_channels=1) 33 | 34 | #print(f'Layer Sizes: {str([self.leaf, self.p, self.s])}') 35 | 36 | def forward(self, x): 37 | # Forward bottom up 38 | if self.C2 is None: 39 | x = self.leaf1(x) 40 | x = self.p1(x) 41 | xc = self.s1(x) 42 | else: 43 | x1 = self.leaf1(x) 44 | x2 = self.leaf2(x) 45 | x1 = self.p1(x1) 46 | x2 = self.p2(x2) 47 | x1 = self.s1(x1) 48 | x2 = self.s2(x2) 49 | xc = torch.cat((x1,x2), axis=1) 50 | xc = self.pc(xc) 51 | xc = self.sc(xc) 52 | 53 | return xc 54 | 55 | def sample(self, n=100): 56 | # Sample top down 57 | ctx = self.s.sample(n=n, context=SamplingContext(n=n)) 58 | ctx = self.p.sample(context=ctx) 59 | samples = self.leaf.sample(context=ctx) 60 | return samples 61 | 62 | class TNCM(ParameterizedSCM): 63 | """ 64 | Tractable NCM, based on SPN for tractable (fast/efficient) inference. 65 | """ 66 | 67 | def __init__(self, adj, spn_type="EinSum", scale=False): 68 | super(TNCM, self).__init__(adj) 69 | self.spn_type = spn_type 70 | for V in self.graph: 71 | V_name = self.i2n(V) 72 | pa_V = self.graph[V] 73 | #print(f'Variable {V_name} := ') 74 | 75 | # if self.spn_type == "EinSum": 76 | # einet = EinsumNetwork.EinsumNetwork( 77 | # Graph.random_binary_trees(num_var=len(pa_V) + 2, depth=1, 78 | # num_repetitions=3), 79 | # EinsumNetwork.Args( 80 | # num_classes=1, 81 | # num_input_distributions=5, 82 | # exponential_family=EinsumNetwork.CategoricalArray, 83 | # exponential_family_args={'K': 2}, 84 | # num_sums=2, 85 | # num_var=len(pa_V) + 2, 86 | # online_em_frequency=1, 87 | # online_em_stepsize=0.05) 88 | # ) 89 | # einet.initialize() 90 | # model = einet#lambda x: EinsumNetwork.log_likelihoods(einet.forward(x)) 91 | # else: 92 | C = 30#len(pa_V)+1+scale #int((len(pa_V)+1)+10) if scale else 30 93 | model = SPN(D=len(pa_V)+1, C=C, K=len(pa_V)+1) 94 | 95 | self.S.update({V_name: model}) 96 | 97 | def params(self): 98 | return [item for sublist in [list(f.parameters()) for f in self.S.values()] for item in sublist] 99 | 100 | def forward(self, v, doX, Xi, samples, debug=False): 101 | 102 | # eine Perle der Codegeschichte - consistency check pure torch 103 | Consistency = torch.ones((samples, 1)) 104 | Consistency = torch.where(doX >= 0, 105 | torch.where(v[:,0] == doX, 1., 0.), Consistency) 106 | 107 | pVs = [] 108 | for V in self.topology: 109 | 110 | pa_V = self.graph[V] 111 | 112 | V_arg = torch.cat((*[torch.tile(v[:,pa].unsqueeze(1),(samples,1)) for pa in pa_V],self.U[self.i2n(V)](samples)),axis=1) 113 | 114 | pV = torch.minimum(torch.exp(self.S[self.i2n(V)](V_arg)).reshape(samples, 1), torch.ones(samples, 1)) 115 | pV = pV * v[:, V].unsqueeze(1) + (torch.ones_like(v[:, V]) - v[:, V]).unsqueeze(1) * (1 - pV) 116 | 117 | # debug, if values are not bounded 118 | # if any(pV > 1): 119 | # import pdb; pdb.set_trace() 120 | 121 | # the intervention checking might do an extra run for the intervened node, not rly important tho 122 | # furthermore, it is not validated yet - and rn we mainly check for interventions on X, where we change 123 | # X,Y relationships 124 | pV = torch.where(torch.tensor(Xi) == torch.tensor(V), torch.where(doX >= 0, torch.ones((samples, 1)), pV), pV) 125 | 126 | pVs.append(pV) 127 | 128 | pV = torch.cat((Consistency, *pVs),axis=1) 129 | 130 | agg = lambda t: torch.mean(torch.prod(t,axis=1)) 131 | 132 | 
if debug:#all(doX != -1*torch.ones_like(doX)):#debug: 133 | import pdb; pdb.set_trace() 134 | 135 | ret = agg(pV) 136 | 137 | if torch.isnan(ret): 138 | import pdb; pdb.set_trace() 139 | 140 | return ret 141 | 142 | def compute_marginals(self, samples, doX=-1, Xi=-1, debug=False): 143 | pred_marginals = {} 144 | N = len(self.V) 145 | for ind_d in range(N): 146 | vals = [] 147 | for val in [0, 1]: 148 | domains = [[0, 1]] * (N - 1) 149 | domains.insert(ind_d, [val]) 150 | combinations = np.stack([x for x in itertools.product(*domains)]) 151 | p_comb = [] 152 | for ind, c in enumerate(combinations): 153 | # print(f'{ind}:\t{c}') 154 | c = torch.tensor(c,dtype=torch.float).unsqueeze(0) 155 | #c = torch.tile(c, (samples, 1)) 156 | pC = self.forward(c, torch.tensor([doX]*samples).unsqueeze(1), Xi, samples, debug) 157 | # print(f'mean(p(c)) = {pC}') 158 | p_comb.append(pC) 159 | # print(f'Sum = {sum(p_comb)}\t [{p_comb}]') 160 | vals.append(sum(p_comb).item()) 161 | pred_marginals.update({ind_d: vals}) 162 | # print(f'Marginals =\n\t{pred_marginals}') 163 | if debug: 164 | import pdb; pdb.set_trace() 165 | return pred_marginals -------------------------------------------------------------------------------- /ncm-est.py: -------------------------------------------------------------------------------- 1 | from models.ncm import NCM 2 | import torch 3 | import numpy as np 4 | import pandas as pd 5 | import time, os, sys 6 | import itertools 7 | import matplotlib.pyplot as plt 8 | import random 9 | from aux.helper import Logger 10 | from datetime import datetime 11 | import pickle 12 | 13 | class Config(): 14 | def __init__(self): 15 | self.max_epochs = 1#3000#0 16 | self.batch_size = 1 17 | self.loss_running_window = 1000 18 | self.seeds = [0, 4, 304, 606] 19 | self.samples = 1000 20 | self.loss_int_weight = 0. 
21 | self.loss_int_weight_decay = False 22 | 23 | self.load_model = False 24 | self.save_viz = True 25 | self.animate = False 26 | self.dir_exp = "./experiments/NCM/" 27 | 28 | 29 | cfg = Config() 30 | 31 | def cycle(iterable): 32 | while True: 33 | for x in iterable: 34 | yield x 35 | 36 | def compute_gt_marginals(meta, distr='L1'): 37 | gt_marginals = {} 38 | for i in range(meta[distr].shape[1]): 39 | p1 = (meta[distr][:,i].sum() / meta[distr].shape[0]).item() 40 | gt_marginals.update({i: (1-p1, p1)}) 41 | return gt_marginals 42 | 43 | def load_dataset(seed, p, data_dir="./datasets/SCMs/"): 44 | desc_m, U_params_str, N_str = p.split("_") 45 | U_params = [float(x) for x in U_params_str[1:-2].split(" ")] 46 | N = int(N_str[1:]) 47 | p = os.path.join(data_dir, p) 48 | 49 | adj = np.array(pd.read_csv(p+'_adj.csv'))[:,1:] 50 | ATE = np.array(pd.read_csv(p+'_ATE.csv'))[:,1:][0][0] 51 | L1 = torch.tensor(np.array(pd.read_csv(p+'_L1.csv'))[:,1:],dtype=torch.float) 52 | doX0 = torch.tensor(np.array(pd.read_csv(p+'_doX0.csv'))[:,1:],dtype=torch.float) 53 | doX1 = torch.tensor(np.array(pd.read_csv(p+'_doX1.csv'))[:,1:],dtype=torch.float) 54 | 55 | n_train = int(len(L1) - 0.2 * len(L1)) 56 | n_other = int(len(L1) - n_train) 57 | 58 | def seed_worker(worker_id): 59 | worker_seed = torch.initial_seed() % 2 ** 32 60 | np.random.seed(worker_seed) 61 | random.seed(worker_seed) 62 | 63 | g = torch.Generator() 64 | g.manual_seed(seed) 65 | train_data = torch.utils.data.DataLoader(L1, batch_size=cfg.batch_size, shuffle=True, worker_init_fn=seed_worker, generator=g) 66 | 67 | meta = {'Model': desc_m, 'N': N, 'U_Params': U_params, 'adj': adj, 'ATE': ATE, 'L1': L1, 'L2_doX0': doX0, 'L2_doX1': doX1} 68 | return meta, train_data 69 | 70 | def plot_marginals(pred, gt, seed, figsize=(13,9), running_losses=None, animate=False, save=None): 71 | plt.clf() 72 | N = len(gt) 73 | s = 3 if running_losses is not None else 2 74 | fig, axs = plt.subplots(s,N, figsize=figsize) 75 | for ind, a in enumerate(axs.flatten()): 76 | if running_losses is not None: 77 | if ind >= 2*N: 78 | for t in running_losses: 79 | a.plot(range(len(t[0])), t[0], label=t[1]) 80 | a.legend() 81 | break 82 | if ind >= N: 83 | marginals = gt 84 | color = 'black' 85 | else: 86 | marginals = pred 87 | color = 'blue' 88 | a.bar([0,1], marginals[ind % N], facecolor=color) 89 | a.set_title(f'Var:{ind} (p0={marginals[ind % N][0]:.2f},p1={marginals[ind % N][1]:.2f})') 90 | a.set_xlim(0,1) 91 | a.set_xticks([0,1]) 92 | a.set_ylim(0,1) 93 | plt.suptitle(f'NCM (Seed {seed})') 94 | plt.tight_layout() 95 | if animate: 96 | plt.pause(0.001) 97 | #input("Press [enter] to continue.") 98 | elif save is not None: 99 | plt.savefig(os.path.join(save[0], f'Viz-seed-{seed}{save[1]}.png')) 100 | else: 101 | plt.show() 102 | 103 | 104 | 105 | ps = [ 106 | 107 | "ChainSCM_[0.1 0.8 0.7 0.8]p_N10000", 108 | "ChainSCM_[0.4 0.6 0.5 0.8]p_N10000", 109 | "ChainSCM_[0.5 0.7 0.6 0.5]p_N10000", 110 | 111 | "BackdoorSCM_[0.1 0.8 0.7 0.8]p_N10000", 112 | "BackdoorSCM_[0.4 0.6 0.5 0.8]p_N10000", 113 | "BackdoorSCM_[0.5 0.7 0.6 0.5]p_N10000", 114 | 115 | "ColliderSCM_[0.1 0.8 0.7 0.8]p_N10000", 116 | "ColliderSCM_[0.4 0.6 0.5 0.8]p_N10000", 117 | "ColliderSCM_[0.5 0.7 0.6 0.5]p_N10000", 118 | 119 | "ConfounderSCM_[0.1 0.8 0.7 0.8]p_N10000", 120 | "ConfounderSCM_[0.4 0.6 0.5 0.8]p_N10000", 121 | "ConfounderSCM_[0.5 0.7 0.6 0.5]p_N10000", 122 | 123 | ] 124 | cfg.dir_exp = os.path.join(cfg.dir_exp, datetime.now().strftime("%Y%m%d-%H%M%S")) 125 | for p in ps: 126 | 127 | scm_model = 
p.split("SCM_")[0] 128 | dir_exp = os.path.join(cfg.dir_exp, p) 129 | 130 | for seed in cfg.seeds: 131 | np.random.seed(seed) 132 | torch.manual_seed(seed) 133 | random.seed(seed) 134 | 135 | meta, train_data = load_dataset(seed, p) 136 | 137 | ncm = NCM(adj=meta['adj']) 138 | 139 | ncm.print_graph() 140 | print(f'Topological Order: {ncm.indices_to_names(ncm.topology)}') 141 | 142 | gt_marginals = compute_gt_marginals(meta, 'L1') 143 | 144 | if not cfg.load_model: 145 | 146 | dir_seed = os.path.join(dir_exp, f"seed-{seed}") 147 | dir_seed_models = os.path.join(dir_seed, 'models') 148 | txt_meta = os.path.join(dir_seed, f'Experiment-Log-seed-{seed}.txt') 149 | if not os.path.exists(dir_seed_models): 150 | os.makedirs(dir_seed_models) 151 | sys.stdout = Logger(txt_meta) 152 | 153 | print(f'Config = {cfg.__dict__}\n') 154 | print(f'Data Set Size = {len(train_data)}') 155 | print(f'Meta Data; Model {meta["Model"]}\t U_params {meta["U_Params"]}\t GT ATE {meta["ATE"]}') 156 | 157 | optimizer = torch.optim.Adam(ncm.params()) 158 | 159 | train_ds = cycle(train_data) 160 | best_loss_tr = 100000 161 | loss_int_weight = cfg.loss_int_weight 162 | 163 | plt.clf() 164 | plt.close() 165 | animate = cfg.animate 166 | 167 | loss_per_step = [] 168 | loss_components_per_step = [] 169 | running_losses_tr = [] 170 | running_losses_va = [] 171 | running_losses_components_tr = [] 172 | running_losses_components_va = [] 173 | for step in range(cfg.max_epochs): # epochs now 174 | t0 = time.time() 175 | 176 | for i in range(len(train_data)): 177 | 178 | datapoint = next(train_ds) 179 | for V in ncm.V: 180 | ncm.S[V].zero_grad() 181 | 182 | pV_est = ncm.forward(datapoint, samples=cfg.samples, doX=torch.tensor([-1.]*cfg.samples).unsqueeze(1), Xi=-1) # -1 means no intervention, clean up!! 183 | # this is the special case where y:=v and thus we only need to count how often we have the v in the data 184 | reps = len(torch.where((train_data.dataset == datapoint).all(axis=1))[0]) 185 | pV_int_est = ncm.forward(datapoint, samples=cfg.samples, doX=torch.tensor([0.] 
* cfg.samples).unsqueeze(1), Xi=0) # TODO: clean the doX interface 186 | pV_int_mass = reps * pV_int_est 187 | pV_int_mass = torch.where(pV_int_mass == 0., torch.tensor(0.)+1e-6, pV_int_mass) 188 | 189 | nll_l1 = -1 * torch.log(pV_est) 190 | nll_l2 = -1 * torch.log(pV_int_mass) 191 | if torch.isnan(nll_l1) or torch.isnan(nll_l2): 192 | import pdb; pdb.set_trace() 193 | loss = nll_l1 + loss_int_weight * nll_l2 194 | #print(f"Reps: {reps}\t Weight: {loss_int_weight}") 195 | loss.backward() 196 | optimizer.step() 197 | 198 | if cfg.loss_int_weight_decay: 199 | loss_int_weight = 1/np.log(step + 3) # 1/log(3) = .9 200 | # not introducing stopping, since it probably won't matter 201 | if loss_int_weight < .001: 202 | loss_int_weight = .001 203 | 204 | loss_per_step.append(loss.item()) 205 | loss_components_per_step.append((nll_l1.item(), nll_l2.item())) 206 | 207 | if i % cfg.loss_running_window == 0: 208 | t1 = time.time() 209 | 210 | running_loss_tr = np.mean(loss_per_step[-cfg.loss_running_window:]) 211 | running_loss_components_tr = np.mean(np.array(loss_components_per_step)[-cfg.loss_running_window:,:],axis=0) 212 | running_losses_components_tr.append((running_loss_components_tr[0],running_loss_components_tr[1])) 213 | 214 | if running_loss_tr < best_loss_tr: 215 | # uncomment this for non-valid # best_valid_elbo = train_elbo 216 | best_loss_tr = running_loss_tr 217 | #suffix = '' 218 | suffix = f' [Saved Model.]' 219 | states = { 220 | "MLPs": [s.state_dict() for s in ncm.S.values()], 221 | } 222 | best_model_name = f'NCM-seed-{seed}-ep-{step}-i-{i}' 223 | torch.save(states, os.path.join(dir_seed_models, best_model_name)) 224 | else: 225 | suffix = "" 226 | 227 | print( 228 | f"Epoch {step:<10d}/{cfg.max_epochs} I {i}/{len(train_data)}\t" 229 | f"Train NLL: {running_loss_tr:<5.3f} (L1 {running_loss_components_tr[0]:<5.3f}, L2 {running_loss_components_tr[1]:<5.3f})\t" 230 | f"" + suffix 231 | ) 232 | t0 = t1 233 | 234 | checkpoint = torch.load(os.path.join(dir_seed_models,best_model_name)) 235 | for ind, s in enumerate(ncm.S.values()): 236 | s.load_state_dict(checkpoint["MLPs"][ind]) 237 | print(f'Loaded best model.') 238 | 239 | marginals = {'Model': os.path.join(dir_seed_models,best_model_name)} 240 | 241 | pred_marginals = ncm.compute_marginals(samples=cfg.samples) 242 | plot_marginals(pred_marginals, gt_marginals, seed, save=(dir_seed, '-best') if cfg.save_viz or not cfg.load_model else None) 243 | marginals.update({'L1': {'GT': gt_marginals, 'NCM': pred_marginals}}) 244 | print(f'Computed Marginals for L1 distribution.') 245 | 246 | # intervention 247 | pred_marginals = ncm.compute_marginals(samples=cfg.samples, doX=1., Xi=0) 248 | plot_marginals(pred_marginals, compute_gt_marginals(meta, 'L2_doX1'), seed, save=(dir_seed, '-doX1') if cfg.save_viz or not cfg.load_model else None) 249 | marginals.update({'L2_doX1': {'GT': gt_marginals, 'NCM': pred_marginals}}) 250 | print(f'Computed Marginals for do(X=1).') 251 | 252 | # intervention 253 | pred_marginals = ncm.compute_marginals(samples=cfg.samples, doX=0., Xi=0) 254 | plot_marginals(pred_marginals, compute_gt_marginals(meta, 'L2_doX0'), seed, save=(dir_seed, '-doX0') if cfg.save_viz or not cfg.load_model else None) 255 | marginals.update({'L2_doX0': {'GT': gt_marginals, 'NCM': pred_marginals}}) 256 | print(f'Computed Marginals for do(X=0).') 257 | 258 | with open(os.path.join(dir_seed, "Marginals.pkl"), "wb") as f: 259 | pickle.dump(marginals, f) 260 | print('Saved Marginals.') 
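# A minimal sketch (not part of the original script) of how the marginals pickled above can be
# turned into an ATE estimate, mirroring what exp2.py later does with the saved 'Marginals.pkl'
# files. Variables follow the 'XYZW...' naming of ParameterizedSCM, so index 1 is Y, and each
# marginal entry is a (p0, p1) pair.
ate_est = marginals['L2_doX1']['NCM'][1][1] - marginals['L2_doX0']['NCM'][1][1]
print(f'Estimated ATE p(Y=1|do(X=1)) - p(Y=1|do(X=0)) = {ate_est:.2f} (GT: {meta["ATE"]:.2f})')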
-------------------------------------------------------------------------------- /exp1.py: -------------------------------------------------------------------------------- 1 | # Experiment on Tractability 2 | 3 | from models.ncm import NCM 4 | from models.tncm import TNCM 5 | from models.ncm import MLP 6 | from models.tncm import SPN 7 | import numpy as np 8 | import torch 9 | import time, itertools 10 | import matplotlib.pyplot as plt 11 | import seaborn as sns 12 | 13 | 14 | def create_chain_adj(N): 15 | adj = np.zeros((N,N)) 16 | for i in range(N-1): 17 | adj[i,i+1] = 1. 18 | return adj 19 | 20 | def create_chain_adj_with_increasing_complexity(N): 21 | adj = np.zeros((N,N)) 22 | for i in range(N-1): 23 | adj[:i+1,i+1] = 1. 24 | return adj 25 | 26 | def normalize(ys): 27 | ymin = np.min(ys) 28 | ymax = np.max(ys) 29 | return [(y-ymin)/(ymax - ymin) for y in ys] 30 | 31 | def timed_marginal_inference(ncm, samples=1, doX=-1, Xi=-1, debug=False, restrict=False): 32 | t0 = time.time() 33 | N = len(ncm.V) 34 | ind_d = N-1 35 | vals = [] 36 | for val in [0, 1]: 37 | domains = [[0, 1]] * (N - 1) 38 | domains.insert(ind_d, [val]) 39 | combinations = np.stack([x for x in itertools.product(*domains)]) 40 | p_comb = [] 41 | for ind, c in enumerate(combinations): 42 | c = torch.tensor(c, dtype=torch.float).unsqueeze(0) 43 | pC = ncm.forward(c, torch.tensor([doX] * samples).unsqueeze(1), Xi, samples, debug) 44 | p_comb.append(pC) 45 | if restrict: 46 | break 47 | vals.append(sum(p_comb).item()) 48 | t = time.time() - t0 49 | return t, vals 50 | 51 | 52 | seed_list = [0]#, 4, 304]#, 606] 53 | 54 | exp1_data = {'NCM': {}, 'TNCM': {}} 55 | N_start = 3 56 | # N_end = 12 57 | # for N in range(N_start, N_end): 58 | # exp1_data['NCM'].update({N: {}}) 59 | # exp1_data['TNCM'].update({N: {}}) 60 | # for seed in seed_list: 61 | # adj = create_chain_adj(N) 62 | # ncm = NCM(adj=adj) 63 | # t_ncm, _ = timed_marginal_inference(ncm) 64 | # exp1_data['NCM'][N].update({seed: t_ncm}) 65 | # tncm = TNCM(adj=adj, spn_type="Regular") 66 | # t_tncm, _ = timed_marginal_inference(tncm) 67 | # exp1_data['TNCM'][N].update({seed: t_tncm}) 68 | # print(f'Exp1 {N+1}/{N_end} ', end="\r", flush=True) 69 | # 70 | # # visualize intractability of parameterized SCM 71 | # sns.reset_orig() 72 | # fig = plt.figure(figsize=(10,7)) 73 | # for ncm_type in ['NCM', 'TNCM']: 74 | # y = [np.mean(list(exp1_data[ncm_type][i].values())) for i in range(N_start, N_end)] 75 | # yerr = [np.std(list(exp1_data[ncm_type][i].values())) for i in range(N_start, N_end)] 76 | # x = range(N_start, N_end) 77 | # #plt.plot(x, y, label=ncm_type) 78 | # plt.errorbar(x, y, yerr=yerr, label=ncm_type) 79 | # plt.gca().set_yscale('log') 80 | # plt.legend() 81 | # plt.grid(True, which="both", ls="-") 82 | # plt.show() 83 | 84 | 85 | 86 | exp2_data = {'NCM': {}, 'TNCM': {}} 87 | # adj = create_chain_adj(N_start) 88 | # N_end = 300#1000 89 | # for N in range(N_start, N_end): 90 | # exp2_data['NCM'].update({N: {}}) 91 | # exp2_data['TNCM'].update({N: {}}) 92 | # for seed in seed_list: 93 | # #adj = create_chain_adj_with_increasing_complexity(N) 94 | # ncm = NCM(adj=adj, scale=N)#True) 95 | # t_ncm, _ = timed_marginal_inference(ncm, restrict=True) 96 | # exp2_data['NCM'][N].update({seed: t_ncm}) 97 | # tncm = TNCM(adj=adj, spn_type="Regular", scale=N)#True) 98 | # t_tncm, _ = timed_marginal_inference(tncm, restrict=True) 99 | # exp2_data['TNCM'][N].update({seed: t_tncm}) 100 | # print(f'Exp2 {N+1}/{N_end} ', end="\r", flush=True) 101 | 102 | # visualize the relativitity 
of the intractability 103 | # sns.set_theme() 104 | # fig = plt.figure(figsize=(10,7)) 105 | # ys = [] 106 | # for ncm_type in ['NCM', 'TNCM']: 107 | # y = np.array([np.mean(list(exp2_data[ncm_type][i].values())) for i in range(N_start, N_end)]) 108 | # ys.append(y) 109 | # yerr = np.array([np.std(list(exp2_data[ncm_type][i].values())) for i in range(N_start, N_end)]) 110 | # x = range(N_start, N_end) 111 | # p = plt.plot(x, y, label=ncm_type) 112 | # #plt.errorbar(x, y, yerr=yerr, label=ncm_type) 113 | # pc = p[0].get_color() 114 | # plt.fill_between(x, y - yerr, y + yerr, alpha=0.5, facecolor=pc) 115 | # lin = lambda x: 0.00001 * np.array(x) 116 | # y = lin(x) 117 | # ys.append(y) 118 | # plt.plot(x, lin(x), color="black") 119 | # maxes = np.array([np.max(g) for g in ys]) 120 | # minofmax = np.min(maxes) 121 | # secondmaxofmax = np.partition(maxes.flatten(), -2)[-2] 122 | # cand = secondmaxofmax 123 | # plt.gca().set_ylim(0 - 0.05 * cand, cand + 0.05 * cand) 124 | # plt.gca().set_xlim(0 - 0.05 * len(x), len(x) + 0.05 * len(x)) 125 | # #plt.gca().set_aspect('equal', adjustable='datalim') 126 | # plt.gca().set_aspect(1.0/plt.gca().get_data_ratio(), adjustable='box') 127 | # plt.legend() 128 | # plt.grid(True, which="both", ls="-") 129 | # plt.show() 130 | 131 | 132 | 133 | 134 | 135 | 136 | # visualize the quadratic nature of MLP 137 | N_start = 3 138 | N_end = 500 139 | window = 30 140 | seed_list = [0, 4, 304]#, 606] 141 | exp2_data.update({'MLP': {}}) 142 | ts = [] 143 | for N in range(N_start, N_end): 144 | exp2_data['MLP'].update({N: {}}) 145 | for seed in seed_list: 146 | adj = create_chain_adj_with_increasing_complexity(N) 147 | #ncm = NCM(adj=adj, scale=True) 148 | mlp = MLP(len(adj), 1, [10*len(adj) for _ in range(3)]) 149 | t0 = time.time() 150 | #ncm.forward(torch.ones(1,len(adj)), torch.tensor([-1] * 1).unsqueeze(1), -1, 1, False) 151 | mlp(torch.ones(1,len(adj))) 152 | t1 = time.time() - t0 153 | exp2_data['MLP'][N].update({seed: t1}) 154 | ts.append(t1) 155 | if len(ts) > window: 156 | avg_time = np.mean(ts[-window:]) 157 | else: 158 | avg_time = -1 159 | print(f'Exp MLP {N+1}/{N_end} {avg_time:.4f}', end="\r", flush=True) 160 | 161 | if False: 162 | sns.set_theme() 163 | ncm_type = "MLP" 164 | fig = plt.figure() 165 | ax = fig.add_subplot(111) 166 | y = [np.mean(list(exp2_data[ncm_type][i].values())) for i in range(N_start, N_end)] 167 | yerr = [np.std(list(exp2_data[ncm_type][i].values())) for i in range(N_start, N_end)] 168 | x = range(N_start, N_end) 169 | plt.plot(x, y, label=ncm_type) 170 | #lin = lambda x: np.min(y)*x 171 | #plt.plot(x, lin(x)) 172 | plt.legend() 173 | ax.set_aspect(1.0/ax.get_data_ratio(), adjustable='box') 174 | plt.show() 175 | 176 | # visualize the linear nature of SPN 177 | from models.EinsumNetwork import Graph, EinsumNetwork 178 | exp2_data.update({'SPN': {}}) 179 | ts = [] 180 | for N in range(N_start, N_end): 181 | exp2_data['SPN'].update({N: {}}) 182 | for seed in seed_list: 183 | adj = create_chain_adj_with_increasing_complexity(N) 184 | spn = SPN(D=len(adj), C=10+len(adj), K=len(adj)) # note the plus here opposed to times since the C internally squares! 
so with plus we have linear scaling of network size 185 | # spn = EinsumNetwork.EinsumNetwork( 186 | # Graph.random_binary_trees(num_var=len(adj), depth=1, 187 | # num_repetitions=3), 188 | # EinsumNetwork.Args( 189 | # num_classes=1, 190 | # num_input_distributions=5*len(adj), 191 | # exponential_family=EinsumNetwork.CategoricalArray, 192 | # exponential_family_args={'K': 2}, 193 | # num_sums=2*len(adj), 194 | # num_var=len(adj), 195 | # online_em_frequency=1, 196 | # online_em_stepsize=0.05) 197 | # ) 198 | # spn.initialize() 199 | t0 = time.time() 200 | spn(torch.ones(1,len(adj))) 201 | t1 = time.time() - t0 202 | exp2_data['SPN'][N].update({seed: t1}) 203 | ts.append(t1) 204 | if len(ts) > window: 205 | avg_time = np.mean(ts[-window:]) 206 | else: 207 | avg_time = -1 208 | print(f'Exp SPN {N+1}/{N_end} {avg_time:.4f}', end="\r", flush=True) 209 | 210 | if False: 211 | sns.set_theme() 212 | ncm_type = "SPN" 213 | fig = plt.figure() 214 | ax = fig.add_subplot(111) 215 | y = [np.mean(list(exp2_data[ncm_type][i].values())) for i in range(N_start, N_end)] 216 | yerr = [np.std(list(exp2_data[ncm_type][i].values())) for i in range(N_start, N_end)] 217 | x = range(N_start, N_end) 218 | plt.plot(x, y, label=ncm_type) 219 | #plt.errorbar(x, y, yerr=yerr, label=ncm_type) 220 | #lin = lambda x: np.min(y)*x 221 | #plt.plot(x, lin(x)) 222 | plt.legend() 223 | ax.set_aspect(1.0/ax.get_data_ratio(), adjustable='box') 224 | plt.show() 225 | 226 | 227 | types = ['MLP', 'SPN', 'Lin'] 228 | colors = ['blue', 'orange', 'black'] 229 | for type in types[:-1]: 230 | y = [np.mean(list(exp2_data[type][i].values())) for i in range(N_start, N_end)] 231 | yerr = [np.std(list(exp2_data[type][i].values())) for i in range(N_start, N_end)] 232 | exp2_data[type].update({'Mean': y, 'Std': yerr}) 233 | sns.set_theme() 234 | fig,axs = plt.subplots(1,4, figsize=(18,6)) 235 | lin = lambda x: 0.00001 * np.array(x) #np.mean([exp2_data['MLP']['Mean'][0]/N_start,exp2_data['SPN']['Mean'][0]/N_start]) * x 236 | for ind, a in enumerate(axs): 237 | x = range(N_start, N_end) 238 | if ind == len(axs)-1: 239 | for i in range(3): 240 | #x = normalize(x) 241 | y = exp2_data[types[i]]['Mean'] if i != 2 else lin(x) 242 | std = np.array(exp2_data[types[i]]['Std']) if i != 2 else None 243 | # = normalize(y) 244 | p = a.plot(x, y, label=types[i], color=colors[i]) 245 | if i != 2: 246 | pc = p[0].get_color() 247 | a.fill_between(x, y-std, y+std, alpha=0.5, facecolor=pc) 248 | a.set_xlim(0, len(x)+.05*len(x)) 249 | maxes = np.array([np.max(g) for g in [exp2_data[t]['Mean'] for t in types[:-1]] + lin(x)]) 250 | minofmax = np.min(maxes) 251 | secondmaxofmax = np.partition(maxes.flatten(), -2)[-2] 252 | cand = secondmaxofmax 253 | #a.set_ylim(0 - .05*cand, cand+.05*cand) 254 | #a.set_xlim(0 - .05*len(x), len(x)+.05*len(x)) 255 | #a.set_aspect(1.0/a.get_data_ratio(), adjustable='box') 256 | a.set_aspect(1.0/a.get_data_ratio(), adjustable='box') 257 | elif ind == len(axs)-2: 258 | #x = normalize(x) 259 | a.plot(x, lin(x), label=types[ind], color=colors[ind]) 260 | #a.set_ylim(0 - .05*len(x), len(x)+.05*len(x)) 261 | #a.set_xlim(0 - .05*len(x), len(x)+.05*len(x)) 262 | #a.set_aspect(1.0/a.get_data_ratio(), adjustable='box') 263 | a.set_aspect(1.0/a.get_data_ratio(), adjustable='box') 264 | else: 265 | y = exp2_data[types[ind]]['Mean'] 266 | std = np.array(exp2_data[types[ind]]['Std']) 267 | #y = normalize(y) 268 | #x = normalize(x) 269 | p = a.plot(x, y, label=types[ind], color=colors[ind]) 270 | pc = p[0].get_color() 271 | a.fill_between(x, y-std, 
y+std, alpha=0.5, facecolor=pc) 272 | cand = np.max(y) 273 | #a.set_aspect('equal', adjustable='datalim') 274 | a.set_aspect(1.0/a.get_data_ratio(), adjustable='box') 275 | a.legend() 276 | #plt.tight_layout() 277 | plt.show() 278 | 279 | ind = 0 280 | a = plt.gca() 281 | y = exp2_data[types[ind]]['Mean'] 282 | std = np.array(exp2_data[types[ind]]['Std']) 283 | p = a.plot(x, y, label=types[ind], color=colors[ind]) 284 | pc = p[0].get_color() 285 | a.fill_between(x, y - std, y + std, alpha=0.5, facecolor=pc) 286 | cand = np.max(y) 287 | a.set_aspect(1.0 / a.get_data_ratio(), adjustable='box') 288 | #a.set_aspect('equal', adjustable='box') 289 | plt.show() 290 | 291 | 292 | # visualize base time complexity comparison 293 | # base_speed = 0.1 294 | # lin = lambda x: base_speed * x 295 | # qua = lambda x: base_speed * x**2 296 | # cub = lambda x: base_speed * x**3 297 | # exp = lambda x: base_speed * 2 ** x 298 | # n = 10 299 | # x = np.arange(1,n) 300 | # names = ['Linear', 'Quadratic', 'Cubic', 'Exponential'] 301 | # for ind, f in enumerate([lin, qua, cub, exp]): 302 | # plt.plot(x, f(x), label=names[ind]) 303 | # #plt.gca().set_aspect('equal', adjustable='datalim') 304 | # plt.xlim(0,n) 305 | # plt.legend() 306 | # plt.show() -------------------------------------------------------------------------------- /tncm-est.py: -------------------------------------------------------------------------------- 1 | from models.tncm import TNCM 2 | import torch 3 | import numpy as np 4 | import pandas as pd 5 | import time, os, sys 6 | import itertools 7 | import matplotlib.pyplot as plt 8 | import random 9 | from aux.helper import Logger 10 | from datetime import datetime 11 | import pickle 12 | 13 | class Config(): 14 | def __init__(self): 15 | self.max_epochs = 3 16 | self.batch_size = 1 17 | self.loss_running_window = 1000 18 | self.seeds = [0, 4, 304, 606] 19 | self.samples = 200 20 | self.loss_int_weight = 0. 
21 | self.loss_int_weight_decay = False 22 | 23 | self.load_model = None 24 | self.save_viz = True 25 | self.animate = False 26 | self.dir_exp = "./experiments/TNCM/" 27 | 28 | 29 | cfg = Config() 30 | 31 | def cycle(iterable): 32 | while True: 33 | for x in iterable: 34 | yield x 35 | 36 | def compute_gt_marginals(meta, distr='L1'): 37 | gt_marginals = {} 38 | for i in range(meta[distr].shape[1]): 39 | p1 = (meta[distr][:,i].sum() / meta[distr].shape[0]).item() 40 | gt_marginals.update({i: (1-p1, p1)}) 41 | return gt_marginals 42 | 43 | def load_dataset(seed, p, data_dir="./datasets/SCMs/"): 44 | desc_m, U_params_str, N_str = p.split("_") 45 | U_params = [float(x) for x in U_params_str[1:-2].split(" ")] 46 | N = int(N_str[1:]) 47 | p = os.path.join(data_dir, p) 48 | 49 | adj = np.array(pd.read_csv(p+'_adj.csv'))[:,1:] 50 | ATE = np.array(pd.read_csv(p+'_ATE.csv'))[:,1:][0][0] 51 | L1 = torch.tensor(np.array(pd.read_csv(p+'_L1.csv'))[:,1:],dtype=torch.float) 52 | doX0 = torch.tensor(np.array(pd.read_csv(p+'_doX0.csv'))[:,1:],dtype=torch.float) 53 | doX1 = torch.tensor(np.array(pd.read_csv(p+'_doX1.csv'))[:,1:],dtype=torch.float) 54 | 55 | n_train = int(len(L1) - 0.2 * len(L1)) 56 | n_other = int(len(L1) - n_train) 57 | 58 | def seed_worker(worker_id): 59 | worker_seed = torch.initial_seed() % 2 ** 32 60 | np.random.seed(worker_seed) 61 | random.seed(worker_seed) 62 | 63 | g = torch.Generator() 64 | g.manual_seed(seed) 65 | train_data = torch.utils.data.DataLoader(L1, batch_size=cfg.batch_size, shuffle=True, worker_init_fn=seed_worker, generator=g) 66 | 67 | meta = {'Model': desc_m, 'N': N, 'U_Params': U_params, 'adj': adj, 'ATE': ATE, 'L1': L1, 'L2_doX0': doX0, 'L2_doX1': doX1} 68 | return meta, train_data 69 | 70 | def plot_marginals(pred, gt, seed, figsize=(13,9), running_losses=None, animate=False, save=None): 71 | plt.clf() 72 | N = len(gt) 73 | s = 3 if running_losses is not None else 2 74 | fig, axs = plt.subplots(s,N, figsize=figsize) 75 | for ind, a in enumerate(axs.flatten()): 76 | if running_losses is not None: 77 | if ind >= 2*N: 78 | for t in running_losses: 79 | a.plot(range(len(t[0])), t[0], label=t[1]) 80 | a.legend() 81 | break 82 | if ind >= N: 83 | marginals = gt 84 | color = 'black' 85 | else: 86 | marginals = pred 87 | color = 'blue' 88 | a.bar([0,1], marginals[ind % N], facecolor=color) 89 | a.set_title(f'Var:{ind} (p0={marginals[ind % N][0]:.2f},p1={marginals[ind % N][1]:.2f})') 90 | a.set_xlim(0,1) 91 | a.set_xticks([0,1]) 92 | a.set_ylim(0,1) 93 | plt.suptitle(f'NCM (Seed {seed})') 94 | plt.tight_layout() 95 | if animate: 96 | plt.pause(0.001) 97 | #input("Press [enter] to continue.") 98 | elif save is not None: 99 | plt.savefig(os.path.join(save[0], f'Viz-seed-{seed}{save[1]}.png')) 100 | else: 101 | plt.show() 102 | 103 | spn_type = "Regular" #"EinSum" 104 | 105 | ps = [ 106 | 107 | "ChainSCM_[0.1 0.8 0.7 0.8]p_N10000", 108 | "ChainSCM_[0.4 0.6 0.5 0.8]p_N10000", 109 | "ChainSCM_[0.5 0.7 0.6 0.5]p_N10000", 110 | 111 | "BackdoorSCM_[0.1 0.8 0.7 0.8]p_N10000", 112 | "BackdoorSCM_[0.4 0.6 0.5 0.8]p_N10000", 113 | "BackdoorSCM_[0.5 0.7 0.6 0.5]p_N10000", 114 | 115 | "ColliderSCM_[0.1 0.8 0.7 0.8]p_N10000", 116 | "ColliderSCM_[0.4 0.6 0.5 0.8]p_N10000", 117 | "ColliderSCM_[0.5 0.7 0.6 0.5]p_N10000", 118 | 119 | "ConfounderSCM_[0.1 0.8 0.7 0.8]p_N10000", 120 | "ConfounderSCM_[0.4 0.6 0.5 0.8]p_N10000", 121 | "ConfounderSCM_[0.5 0.7 0.6 0.5]p_N10000", 122 | 123 | ] 124 | cfg.dir_exp = os.path.join(cfg.dir_exp, datetime.now().strftime("%Y%m%d-%H%M%S") + "_TNCM") 125 | for 
p in ps: 126 | 127 | scm_model = p.split("SCM_")[0] 128 | dir_exp = os.path.join(cfg.dir_exp, p) 129 | 130 | for seed in cfg.seeds: 131 | np.random.seed(seed) 132 | torch.manual_seed(seed) 133 | random.seed(seed) 134 | 135 | meta, train_data = load_dataset(seed, p) 136 | 137 | ncm = TNCM(adj=meta['adj'], spn_type=spn_type) 138 | 139 | ncm.print_graph() 140 | print(f'Topological Order: {ncm.indices_to_names(ncm.topology)}') 141 | 142 | gt_marginals = compute_gt_marginals(meta, 'L1') 143 | 144 | if cfg.load_model is None: 145 | 146 | dir_seed = os.path.join(dir_exp, f"seed-{seed}") 147 | dir_seed_models = os.path.join(dir_seed, 'models') 148 | txt_meta = os.path.join(dir_seed, f'Experiment-Log-seed-{seed}.txt') 149 | if not os.path.exists(dir_seed_models): 150 | os.makedirs(dir_seed_models) 151 | sys.stdout = Logger(txt_meta) 152 | 153 | print(f'Config = {cfg.__dict__}\n') 154 | print(f'Data Set Size = {len(train_data)}') 155 | print(f'Meta Data; Model {meta["Model"]}\t U_params {meta["U_Params"]}\t GT ATE {meta["ATE"]}') 156 | 157 | #optimizer = torch.optim.SGD(ncm.params(), lr=0.2, weight_decay=0.0) 158 | if spn_type != "EinSum": 159 | optimizer = torch.optim.Adam(ncm.params()) 160 | 161 | train_ds = cycle(train_data) 162 | best_loss_tr = -100000. if spn_type == "EinSum" else 100000 # TODO: make EiNet also minimize #100000 163 | loss_int_weight = cfg.loss_int_weight 164 | 165 | plt.clf() 166 | plt.close() 167 | animate = cfg.animate 168 | 169 | loss_per_step = [] 170 | loss_components_per_step = [] 171 | running_losses_tr = [] 172 | running_losses_va = [] 173 | running_losses_components_tr = [] 174 | running_losses_components_va = [] 175 | for step in range(cfg.max_epochs): # epochs now 176 | t0 = time.time() 177 | 178 | for i in range(len(train_data)): 179 | 180 | batch = next(train_ds) 181 | 182 | if spn_type != "EinSum": 183 | for V in ncm.V: 184 | ncm.S[V].zero_grad() 185 | 186 | pV_est = ncm.forward(batch, samples=cfg.samples, doX=torch.tensor([-1.]*cfg.samples).unsqueeze(1), Xi=-1) # -1 means no intervention, clean up!! 187 | # this is the special case where y:=v and thus we only need to count how often we have the v in the data 188 | # reps = len(torch.where((train_data.dataset == datapoint).all(axis=1))[0]) 189 | # pV_int_est = ncm.forward(datapoint, samples=cfg.samples, doX=torch.tensor([0.] 
* cfg.samples).unsqueeze(1), Xi=0) # TODO: clean the doX interface 190 | # pV_int_mass = reps * pV_int_est 191 | # pV_int_mass = torch.where(pV_int_mass == 0., torch.tensor(0.)+1e-6, pV_int_mass) 192 | 193 | if spn_type == "EinSum": 194 | nll_l1 = torch.log(pV_est) # TODO: correct this, it is not nll but ll, but the other code forces it to be named like this for the moment 195 | nll_l2 = -1 * torch.zeros(1) 196 | loss = nll_l1 197 | else: 198 | nll_l1 = -1 * torch.log(pV_est) 199 | nll_l2 = -1 * torch.zeros(1)#* torch.log(pV_int_mass) 200 | if torch.isnan(nll_l1) or torch.isnan(nll_l2): 201 | import pdb; pdb.set_trace() 202 | loss = nll_l1 + loss_int_weight * nll_l2 203 | #print(f"Reps: {reps}\t Weight: {loss_int_weight}") 204 | loss.backward() 205 | 206 | if spn_type == "EinSum": 207 | for V in ncm.V: 208 | ncm.S[V].em_process_batch() 209 | else: 210 | optimizer.step() 211 | 212 | if cfg.loss_int_weight_decay: 213 | loss_int_weight = 1/np.log(step + 3) # 1/log(3) = .9 214 | # not introducing stopping, since it probably won't matter 215 | if loss_int_weight < .001: 216 | loss_int_weight = .001 217 | 218 | loss_per_step.append(loss.item()) 219 | loss_components_per_step.append((nll_l1.item(), nll_l2.item())) 220 | 221 | if i!= 0 and int(i * cfg.batch_size) % cfg.loss_running_window == 0: 222 | t1 = time.time() 223 | 224 | running_loss_tr = np.mean(loss_per_step[-cfg.loss_running_window:]) 225 | running_loss_components_tr = np.mean(np.array(loss_components_per_step)[-cfg.loss_running_window:,:],axis=0) 226 | running_losses_components_tr.append((running_loss_components_tr[0],running_loss_components_tr[1])) 227 | 228 | if spn_type == "EinSum": 229 | condition = lambda x,y: x > y 230 | else: 231 | condition = lambda x,y: x < y 232 | 233 | if condition(running_loss_tr,best_loss_tr): # TODO: make EiNet also minimize 234 | # uncomment this for non-valid # best_valid_elbo = train_elbo 235 | best_loss_tr = running_loss_tr 236 | #suffix = '' 237 | suffix = f' [Saved Model.]' 238 | states = { 239 | "SPNs": [s.state_dict() for s in ncm.S.values()], 240 | } 241 | best_model_name = f'NCM-seed-{seed}-ep-{step}-i-{i}' 242 | torch.save(states, os.path.join(dir_seed_models, best_model_name)) 243 | else: 244 | suffix = "" 245 | 246 | print( 247 | f"Epoch {step:<10d}/{cfg.max_epochs} I {i}/{len(train_data)}\t" 248 | f"Train NLL: {running_loss_tr:<5.3f} (L1 {running_loss_components_tr[0]:<5.3f}, L2 {running_loss_components_tr[1]:<5.3f})\t" 249 | f"" + suffix 250 | ) 251 | t0 = t1 252 | 253 | if spn_type == "EinSum": 254 | for V in ncm.V: 255 | ncm.S[V].em_update() 256 | 257 | checkpoint = torch.load(os.path.join(dir_seed_models,best_model_name) if cfg.load_model is None else cfg.load_model) 258 | for ind, s in enumerate(ncm.S.values()): 259 | s.load_state_dict(checkpoint["SPNs"][ind]) 260 | print(f'Loaded best model.') 261 | 262 | marginals = {'Model': os.path.join(dir_seed_models,best_model_name)} 263 | 264 | pred_marginals = ncm.compute_marginals(samples=cfg.samples) 265 | plot_marginals(pred_marginals, gt_marginals, seed, save=(dir_seed, '-best') if cfg.save_viz or cfg.load_model is None else None) 266 | marginals.update({'L1': {'GT': gt_marginals, 'TNCM': pred_marginals}}) 267 | print(f'Computed Marginals for L1 distribution.') 268 | 269 | # intervention 270 | pred_marginals = ncm.compute_marginals(samples=cfg.samples, doX=1., Xi=0) 271 | plot_marginals(pred_marginals, compute_gt_marginals(meta, 'L2_doX1'), seed, save=(dir_seed, '-doX1') if cfg.save_viz or cfg.load_model is None else None) 272 | 
marginals.update({'L2_doX1': {'GT': compute_gt_marginals(meta, 'L2_doX1'), 'TNCM': pred_marginals}})  # store the ground truth of the matching do(X=1) distribution
273 | print(f'Computed Marginals for do(X=1).')
274 | 
275 | # intervention
276 | pred_marginals = ncm.compute_marginals(samples=cfg.samples, doX=0., Xi=0)
277 | plot_marginals(pred_marginals, compute_gt_marginals(meta, 'L2_doX0'), seed, save=(dir_seed, '-doX0') if cfg.save_viz or cfg.load_model is None else None)
278 | marginals.update({'L2_doX0': {'GT': compute_gt_marginals(meta, 'L2_doX0'), 'TNCM': pred_marginals}})  # store the ground truth of the matching do(X=0) distribution
279 | print(f'Computed Marginals for do(X=0).')
280 | 
281 | with open(os.path.join(dir_seed, "Marginals.pkl"), "wb") as f:
282 | pickle.dump(marginals, f)
283 | print('Saved Marginals.')
284 | 
285 | if cfg.load_model is not None:
286 | break
--------------------------------------------------------------------------------
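Reproduction order (a sketch inferred from the default paths hard-coded in the scripts above, not a file of the repository itself): `scm-dgp.py` first writes the ground-truth SCM datasets to `./datasets/SCMs/`; `ncm-est.py` and `tncm-est.py` then train on those CSVs and save `Marginals.pkl` files under `./experiments/{NCM,TNCM}/<timestamp>/`; finally `exp1.py` and `exp2.py` produce the figures and tables. Note that `exp2.py` hard-codes experiment timestamps in `p_ncm` / `p_tncm`, which need to be updated to the newly created directories.

# e.g., a minimal driver, assuming a Python environment with torch, SPFlow (`spn`), numpy,
# pandas, scipy, matplotlib and seaborn installed:
import subprocess
for script in ["scm-dgp.py", "ncm-est.py", "tncm-est.py", "exp1.py", "exp2.py"]:
    subprocess.run(["python", script], check=True)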