├── __init__.py
├── categorical
│   ├── __init__.py
│   ├── utils.py
│   ├── experiment_distance.py
│   ├── script_experiments.py
│   ├── experiment_loops.py
│   ├── plot_old.py
│   ├── plot_sweep.py
│   └── models.py
├── normal_pkg
│   ├── __init__.py
│   ├── plot_adaptation.py
│   ├── plot_distances.py
│   ├── distances.py
│   ├── test_normal.py
│   ├── proximal_optimizer.py
│   ├── adaptation.py
│   ├── normal.py
│   └── normal_distance.py
├── environment.yaml
├── README.md
├── LICENSE
├── .gitignore
├── main.py
├── averaging_manager.py
├── simplex.py
├── 3_failure_space.ipynb
├── 4_dirichlet.ipynb
└── 8_IRM.ipynb

/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/categorical/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/normal_pkg/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/environment.yaml:
--------------------------------------------------------------------------------
1 | name: causaladaptation
2 | channels:
3 |   - pytorch
4 |   - defaults
5 | dependencies:
6 |   - numpy=1.18.1
7 |   - matplotlib=3.1.3
8 |   - scipy=1.4.1
9 |   - pytorch=1.5.0
10 |   - tqdm
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # An Analysis of the Adaptation Speed of Causal Models
2 | Code to reproduce the results of the paper.
3 | Dependencies are listed in `environment.yaml`.
4 | 
5 | ### Categorical
6 | To get the categorical adaptation plots, run
7 | ```
8 | python main.py categorical adaptation
9 | python main.py categorical plot
10 | ```
11 | 
12 | ### Normal
13 | To get scatter plots of anticausal vs. causal distances, run
14 | ```
15 | python main.py normal distance
16 | ```
17 | 
18 | To get the adaptation results, run
19 | ```
20 | python main.py normal adaptation
21 | python main.py normal plot
22 | ```
23 | 
--------------------------------------------------------------------------------
/categorical/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | 
4 | def kullback_leibler(p1, p2):
5 |     return np.sum(p1 * np.log(p1 / p2), axis=-1)
6 | 
7 | 
8 | def entropy(p):
9 |     return -np.sum(p * np.log(p), axis=-1)
10 | 
11 | 
12 | def logsumexp(s):
13 |     smax = np.amax(s, axis=-1)
14 |     return smax + np.log(
15 |         np.sum(np.exp(s - np.expand_dims(smax, axis=-1)), axis=-1))
16 | 
17 | 
18 | def logit2proba(s):
19 |     return np.exp(s - np.expand_dims(logsumexp(s), axis=-1))
20 | 
21 | 
22 | def proba2logit(p):
23 |     s = np.log(p)
24 |     s -= np.mean(s, axis=-1, keepdims=True)
25 |     return s
26 | 
27 | 
28 | def test_proba2logit():
29 |     p = np.random.dirichlet(np.ones(50), size=300)
30 |     s = proba2logit(p)
31 |     assert np.allclose(0, np.sum(s, axis=-1))
32 | 
33 |     q = logit2proba(s)
34 |     assert np.allclose(1, np.sum(q, axis=-1)), q
35 |     assert np.allclose(p, q), p - q
36 | 
37 | 
38 | if __name__ == "__main__":
39 |     test_proba2logit()
40 | 
--------------------------------------------------------------------------------
/normal_pkg/plot_adaptation.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import matplotlib
4 | from categorical.plot_sweep import two_plots
5 | from normal_pkg.adaptation import CholeskyModule
6 | 
7 | CholeskyModule  # referenced so the import is kept: it is needed to unpickle the saved models
8 | 
9 | def learning_curves(results_dir='normal_results'):
10 |     for k in [10]:
11 |         # Optimize hyperparameters for nsteps
12 |         nsteps = 100
13 |         init = 'natural'
14 |         for intervention in ['cause', 'effect', 'mechanism']:
15 |             plotname = f'{intervention}_{init}_k={k}'
16 |             filepath = os.path.join(results_dir, plotname + '.pkl')
17 |             if os.path.isfile(filepath):
18 |                 with open(filepath, 'rb') as fin:
19 |                     results = pickle.load(fin)
20 |                 two_plots(results, nsteps, plotname=plotname, dirname='normal_adaptation',
21 |                           figsize=(3, 3))
22 | 
23 | 
24 | if __name__ == "__main__":
25 |     matplotlib.use('pgf')
26 |     matplotlib.rcParams['mathtext.fontset'] = 'cm'
27 |     matplotlib.rcParams['pdf.fonttype'] = 42
28 |     learning_curves()
29 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2020 Remi Le Priol
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /normal_pkg/plot_distances.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import matplotlib 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | 7 | 8 | def abline(ax, slope, intercept): 9 | """Plot a line from slope and intercept""" 10 | x_vals = np.array(ax.get_xlim()) 11 | y_vals = intercept + slope * x_vals 12 | ax.plot(x_vals, y_vals, '--', color='grey') 13 | 14 | 15 | def all_distances(results_dir='normal_results', plotdir='plots/normal_distances/'): 16 | with open(os.path.join(results_dir, 'distances_100.pkl'), 'rb') as fin: 17 | results = pickle.load(fin) 18 | 19 | os.makedirs(plotdir, exist_ok=True) 20 | 21 | for exp in results: 22 | for unit in ['nat', 'cho']: 23 | name = "dist{unit}_{intervention}_{init}_k={dim}.pdf".format(unit=unit, **exp) 24 | savefile = os.path.join(plotdir, name) 25 | print("Saving in plots ", savefile) 26 | scatter_distances(unit, exp) 27 | plt.tight_layout() 28 | plt.savefig(savefile) 29 | plt.close() 30 | 31 | 32 | def scatter_distances(unit, exp, alpha=.5): 33 | """Draw anticausal vs causal scatter plot""" 34 | dist = exp['distances'] 35 | fig, ax = plt.subplots(figsize=(3, 3)) 36 | ax.scatter(dist['causal_' + unit], dist['anti_' + unit], linewidth=0, alpha=alpha) 37 | ax.set_xlabel(r'$|| \theta^{(0)}_{\rightarrow} - \theta^*_{\rightarrow} ||$') 38 | ax.set_ylabel(r'$|| \theta^{(0)}_{\leftarrow} - \theta^*_{\leftarrow} ||$') 39 | abline(ax, 1, 0) 40 | ax.grid() 41 | ax.axis('equal') 42 | 43 | return fig 44 | 45 | 46 | if __name__ == "__main__": 47 | matplotlib.use('pgf') 48 | matplotlib.rcParams['mathtext.fontset'] = 'cm' 49 | matplotlib.rcParams['pdf.fonttype'] = 42 50 | all_distances() 51 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # personal 2 | hessian/ 3 | 4 | # IDE stuff 5 | .idea/ 6 | 7 | # google 8 | *.gslides 9 | 10 | # Byte-compiled / optimized / DLL files 11 | __pycache__/ 12 | *.py[cod] 13 | *$py.class 14 | 15 | # C extensions 16 | *.so 17 | 18 | # Distribution / packaging 19 | .Python 20 | build/ 21 | develop-eggs/ 22 | dist/ 23 | downloads/ 24 | eggs/ 25 | .eggs/ 26 | lib/ 27 | lib64/ 28 | parts/ 29 | sdist/ 30 | var/ 31 | wheels/ 32 | *.egg-info/ 33 | .installed.cfg 34 | *.egg 35 | MANIFEST 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # celery beat schedule file 88 | celerybeat-schedule 89 | 90 | # SageMath parsed files 91 | *.sage.py 92 | 93 | # Environments 94 | .env 95 | .venv 96 | env/ 97 | venv/ 98 | ENV/ 99 | env.bak/ 100 | venv.bak/ 101 | 102 | # Spyder project settings 103 | .spyderproject 104 | .spyproject 105 | 106 | # Rope project settings 107 | .ropeproject 108 | 109 | # mkdocs documentation 110 | /site 111 | 112 | # mypy 113 | .mypy_cache/ 114 | 115 | # Data and temporary folders 116 | data/ 117 | tmp/ 118 | 119 | # Data and figure folders 120 | **/data/ 121 | **/figures/ -------------------------------------------------------------------------------- /normal_pkg/distances.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from collections import defaultdict 4 | 5 | import numpy as np 6 | 7 | from normal_pkg import normal 8 | 9 | 10 | def intervention_distances(k, n, intervention='cause', init='natural', interpolation=1): 11 | """Sample n conditional Gaussians between cause and effect of dimension k 12 | and evaluate the distance after intervention between causal and anticausal models.""" 13 | 14 | ans = defaultdict(list) 15 | for i in range(n): 16 | # sample mechanisms 17 | reference = normal.sample(k, init) 18 | 19 | try: 20 | interp = interpolation[i] 21 | except TypeError: 22 | interp = interpolation 23 | 24 | transfer = reference.intervention(intervention, interp) 25 | 26 | ans['causal_nat'] += [reference.distance(transfer)] 27 | revref = reference.reverse() 28 | revtrans = transfer.reverse() 29 | ans['anti_nat'] += [revref.distance(revtrans)] 30 | ans['joint_nat'] += [reference.to_joint().distance(transfer.to_joint())] 31 | ans['causal_cho'] += [reference.to_cholesky().distance(transfer.to_cholesky())] 32 | ans['anti_cho'] += [revref.to_cholesky().distance(revtrans.to_cholesky())] 33 | 34 | return ans 35 | 36 | 37 | def record_distances(savedir = 'normal_results'): 38 | n = 100 39 | # kk = [1, 2, 3, 10] 40 | # kk = [20, 30, 40] 41 | kk = [10] 42 | 43 | results = [] 44 | for intervention in ['cause', 'effect', 'mechanism']: 45 | for init in ['natural']: # , 'cholesky']: 46 | for k in kk: 47 | np.random.seed(1) 48 | exp = { 49 | 'intervention': intervention, 50 | 'init': init, 51 | 'dim': k, 52 | } 53 | print('Recording distances for ', exp) 54 | exp = {**exp, 'distances': intervention_distances(k, n, intervention, init)} 55 | results.append(exp) 56 | 57 | os.makedirs(savedir, exist_ok=True) 58 | savefile = os.path.join(savedir, f'distances_{n}.pkl') 59 | print("Saving results in ", savefile) 60 | with open(savefile, 'wb') as fout: 61 | pickle.dump(results, fout) 62 | 63 | 64 | if __name__ == '__main__': 65 | a = intervention_distances(3, 4) 66 | record_distances() 67 | 
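A minimal consumer sketch for the pickle written by `record_distances` above (hypothetical, not a file of the repository); the dictionary keys follow those filled in by `intervention_distances`:

```python
# Hypothetical summary script for normal_results/distances_100.pkl (sketch only).
import pickle

import numpy as np

with open('normal_results/distances_100.pkl', 'rb') as fin:
    results = pickle.load(fin)

for exp in results:
    dist = exp['distances']
    # A ratio above 1 means the anticausal parameters moved farther
    # than the causal ones under the intervention.
    ratio = np.array(dist['anti_nat']) / np.array(dist['causal_nat'])
    print(f"{exp['intervention']} (k={exp['dim']}): "
          f"median anti/causal distance ratio = {np.median(ratio):.2f}")
```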
--------------------------------------------------------------------------------
/normal_pkg/test_normal.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | 
3 | import numpy as np
4 | import scipy.stats  # explicit submodule import, required for scipy.stats.ortho_group below
5 | 
6 | from normal_pkg import normal
7 | 
8 | np.random.seed(1)
9 | 
10 | 
11 | class TestNormals(unittest.TestCase):
12 | 
13 |     def setUp(self):
14 |         self.nat = normal.sample_natural(dim=3, mode='conjugate')
15 |         self.cho = normal.sample_cholesky(dim=3)
16 | 
17 |     # change of representation
18 |     def test_nat2mean2nat(self):
19 |         self.assertAlmostEqual(0, self.nat.to_mean().to_natural().distance(self.nat))
20 | 
21 |     def test_nat2joint2nat(self):
22 |         self.assertAlmostEqual(0, self.nat.to_joint().to_conditional().distance(self.nat))
23 | 
24 |     def test_nat2joint2mean2cond2nat(self):
25 |         self.assertAlmostEqual(0, self.nat.to_joint().to_mean().to_conditional().to_natural()
26 |                                .distance(self.nat))
27 | 
28 |     def test_nat2mean2joint2nat2cond(self):
29 |         self.assertAlmostEqual(0, self.nat.to_mean().to_joint().to_natural().to_conditional()
30 |                                .distance(self.nat))
31 | 
32 |     def test_nat2cho2nat(self):
33 |         self.assertAlmostEqual(0, self.nat.to_cholesky().to_natural().distance(self.nat))
34 | 
35 |     def test_cho2nat2cho(self):
36 |         self.assertAlmostEqual(0, self.cho.to_natural().to_cholesky().distance(self.cho))
37 | 
38 |     def test_nat2joint2cho2nat2cond(self):
39 |         self.assertAlmostEqual(0, self.nat.to_joint().to_cholesky().to_natural()
40 |                                .to_conditional().distance(self.nat))
41 | 
42 |     # change of direction
43 |     def test_reversereverse(self):
44 |         self.assertAlmostEqual(0, self.nat.reverse().reverse().distance(self.nat))
45 | 
46 |     # misc
47 |     def test_interventions(self):
48 |         self.nat.intervention(on='cause')
49 |         self.nat.intervention(on='effect')
50 | 
51 |     def test_meanjoint(self):
52 |         meanjoint = self.nat.to_joint().to_mean()
53 |         meanjoint.sample(5)
54 |         encoder = scipy.stats.ortho_group.rvs(meanjoint.mean.shape[0])
55 |         meanjoint.encode(encoder)
56 | 
57 |     def test_meancond(self):
58 |         self.nat.to_mean().sample(5)
59 | 
60 |     def test_natjoint(self):
61 |         natjoint = self.nat.to_joint()
62 |         natjoint.logpartition  # property access, smoke test only
63 |         natjoint.negativeloglikelihood(np.random.randn(10, natjoint.eta.shape[0]))
64 | 
65 | 
66 | if __name__ == '__main__':
67 |     unittest.main()
68 | 
--------------------------------------------------------------------------------
/categorical/experiment_distance.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | 
4 | import numpy as np
5 | 
6 | from categorical.models import sample_joint
7 | 
8 | 
9 | def intervention_distances(k, n, concentration, intervention, dense_init=True):
10 |     """Sample n mechanisms of order k and for each of them sample an intervention on the desired
11 |     mechanism. Return the distance between the original distribution and the intervened
12 |     distribution in the causal parameter space and in the anticausal parameter space.
13 |     """
14 |     # causal parameters
15 |     causal = sample_joint(k, n, concentration, dense_init)
16 |     transfer = causal.intervention(on=intervention, concentration=concentration, dense=dense_init)
17 |     cpd, csd = causal.sqdistance(transfer)
18 | 
19 |     # anticausal parameters
20 |     anticausal = causal.reverse()
21 |     antitransfer = transfer.reverse()
22 |     apd, asd = anticausal.sqdistance(antitransfer)
23 | 
24 |     return np.array([[cpd, csd], [apd, asd]])
25 | 
26 | 
27 | def test_intervention_distances():
28 |     print('test experiment')
29 |     for intervention in ['cause', 'effect', 'mechanism', 'gmechanism', 'independent', 'geometric',
30 |                          'weightedgeo']:
31 |         for dense_init in [True, False]:
32 |             intervention_distances(2, 3, 1, intervention, dense_init)
33 | 
34 | 
35 | def all_distances(savedir='categorical_results'):
36 |     n = 300
37 |     kk = np.arange(2, 100, 8)
38 | 
39 |     results = []
40 |     for intervention in ['cause', 'effect', 'mechanism', 'gmechanism']:
41 |         for dense_init in [True, False]:
42 |             for concentration in [1]:
43 |                 distances = []
44 |                 for k in kk:
45 |                     distances.append(intervention_distances(
46 |                         k, n, concentration, intervention, dense_init
47 |                     ))
48 |                 exp = {
49 |                     'intervention': intervention,
50 |                     'dense_init': dense_init,
51 |                     'concentration': concentration,
52 |                     'dimensions': kk,
53 |                     'distances': np.array(distances)
54 |                 }
55 |                 results.append(exp)
56 | 
57 |     os.makedirs(savedir, exist_ok=True)
58 | 
59 |     with open(os.path.join(
60 |             savedir, f'categorical_distances_{n}.pkl'), 'wb') as fout:
61 |         pickle.dump(results, fout)
62 | 
63 | 
64 | if __name__ == "__main__":
65 |     test_intervention_distances()
66 |     all_distances()
67 | 
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | 
3 | import categorical
4 | from categorical import plot_sweep, script_experiments
5 | from normal_pkg import adaptation, distances, plot_adaptation, plot_distances
6 | 
7 | if __name__ == "__main__":
8 |     parser = argparse.ArgumentParser(
9 |         description='Evaluate the speed of adaptation of cause-effect models.')
10 |     parser.add_argument('distribution', type=str, choices=['categorical', 'normal'])
11 |     parser.add_argument('action', type=str, choices=['distance', 'adaptation', 'plot'])
12 | 
13 |     args = parser.parse_args()
14 | 
15 |     if args.distribution == 'categorical':
16 |         results_dir = 'categorical_results'
17 |         if args.action == 'distance':
18 |             # categorical.script_experiments.all_distances(savedir=results_dir)
19 |             raise DeprecationWarning("The categorical distance experiment does not work.")
20 | 
21 |         elif args.action == 'adaptation':
22 |             for init_dense in [True, False]:
23 |                 for k in [20]:
24 |                     for intervention in ['cause', 'effect', 'singlecond']:
25 |                         categorical.script_experiments.parameter_sweep(
26 |                             intervention, k, init_dense, savedir=results_dir)
27 | 
28 |         elif args.action == 'plot':
29 |             for dense in [True, False]:
30 |                 categorical.plot_sweep.all_plot(dense=dense, input_dir=results_dir)
31 | 
32 |     elif args.distribution == 'normal':
33 |         results_dir = 'normal_results'
34 |         if args.action == 'distance':
35 |             print("Measuring distances")
36 |             distances.record_distances(savedir=results_dir)
37 |             plot_distances.all_distances()
38 | 
39 |         elif args.action == 'adaptation':
40 |             print("\n\n Simulating adaptation to interventions")
41 |             base_hparams = {'n': 100, 'T': 400, 'batch_size': 1, 'use_prox': True,
42 |                             'log_interval': 10, 'intervention_scale': 1,
43 |                             'init': 'natural',
'preccond_scale': 10} 44 | print("Hyperparameters ", base_hparams) 45 | lrlr = [.001, .003, .01, .03, .1] 46 | for k in [10]: 47 | base_hparams['k'] = k 48 | for intervention in ['cause', 'effect']: 49 | hparams = {**base_hparams, 'k': k, 'intervention': intervention} 50 | adaptation.sweep_lr(lrlr, hparams, savedir=results_dir) 51 | 52 | elif args.action == 'plot': 53 | plot_adaptation.learning_curves(results_dir=results_dir) 54 | -------------------------------------------------------------------------------- /averaging_manager.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | import torch 4 | from torch import nn, optim 5 | 6 | 7 | class AveragedModel: 8 | 9 | def __init__(self, model, optimizer): 10 | self.model = model 11 | self.optimizer = optimizer 12 | self.current_parameters = {} 13 | 14 | def __enter__(self): 15 | """Assign average value to self.model parameters""" 16 | for p in self.model.parameters(): 17 | self.current_parameters[p] = p.data 18 | if p in self.optimizer.state: 19 | p.data = self.optimizer.state[p]['ax'] 20 | return self.model 21 | 22 | def __exit__(self, exc_type, exc_val, exc_tb): 23 | for p in self.model.parameters(): 24 | if p in self.current_parameters: 25 | p.data = self.current_parameters[p] 26 | 27 | 28 | def test_AveragedModel(d=10): 29 | torch.manual_seed(1) 30 | model = nn.Linear(d, 1, bias=False) 31 | optimizer = optim.ASGD(model.parameters(), lr=.05, 32 | lambd=0, alpha=0, t0=0, weight_decay=0) 33 | 34 | print(next(model.parameters())) 35 | xeval = torch.randn(100, d) 36 | yeval = xeval.mean(dim=1, keepdim=True) 37 | 38 | trajectory = defaultdict(list) 39 | for t in range(100): 40 | # evaluate 41 | with torch.no_grad(): 42 | trajectory['error'] += [float(torch.mean((model(xeval) - yeval) ** 2))] 43 | trajectory['pdist'] += [float(torch.norm((next(model.parameters()).data 44 | - 1 / d * torch.ones(d)) ** 2))] 45 | with AveragedModel(model, optimizer): 46 | trajectory['aerror'] += [float(torch.mean((model(xeval) - yeval) ** 2))] 47 | trajectory['apdist'] += [float(torch.norm((next(model.parameters()).data 48 | - 1 / d * torch.ones(d)) ** 2))] 49 | 50 | # train 51 | optimizer.zero_grad() 52 | x = torch.randn(1, d) 53 | y = torch.mean(x, dim=1, keepdim=True) 54 | loss = torch.mean((y - model(x)) ** 2) 55 | loss.backward() 56 | optimizer.step() 57 | 58 | print(next(model.parameters()).data) 59 | 60 | import matplotlib.pyplot as plt 61 | # plt.scatter(error, aerror, c=np.arange(len(error))) 62 | plt.plot(trajectory['error'], alpha=.5, label='MSE') 63 | plt.plot(trajectory['aerror'], alpha=.5, label='average MSE') 64 | plt.plot(trajectory['pdist'], alpha=.5, label='pdist') 65 | plt.plot(trajectory['apdist'], alpha=.5, label='average pdist') 66 | plt.yscale('log') 67 | plt.legend() 68 | plt.grid() 69 | plt.show() 70 | 71 | 72 | if __name__ == "__main__": 73 | test_AveragedModel() 74 | -------------------------------------------------------------------------------- /normal_pkg/proximal_optimizer.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch import optim 5 | 6 | 7 | class PerturbedProximalGradient(optim.ASGD): 8 | 9 | def __init__(self, params, use_prox, **kwargs): 10 | super(PerturbedProximalGradient, self).__init__(params, **kwargs) 11 | self.use_prox = use_prox 12 | # note that even if I do not use prox, 13 | # I still need to project on lower triangular matrix 14 | 15 | def step(self): 
16 | """Mostly copied from ASGD, but there was no other way to do both 17 | projection, proximal step and averaging. 18 | """ 19 | for group in self.param_groups: 20 | for p in group['params']: 21 | if p.grad is None: 22 | continue 23 | grad = p.grad.data 24 | if grad.is_sparse: 25 | raise RuntimeError('ASGD does not support sparse gradients') 26 | state = self.state[p] 27 | 28 | # State initialization 29 | if len(state) == 0: 30 | state['step'] = 0 31 | state['eta'] = group['lr'] 32 | state['mu'] = 1 33 | state['ax'] = torch.zeros_like(p.data) 34 | 35 | state['step'] += 1 36 | 37 | if group['weight_decay'] != 0: 38 | grad = grad.add(group['weight_decay'], p.data) 39 | 40 | # decay term 41 | p.data.mul_(1 - group['lambd'] * state['eta']) 42 | 43 | # update parameter 44 | p.data.add_(-state['eta'], grad) 45 | 46 | # NEW 47 | istriangular = getattr(p, 'triangular', False) 48 | if istriangular: 49 | # project back onto lower triangular matrices 50 | p.data = torch.tril(p.data) 51 | 52 | # proximal update on diagonal parameters with - log loss 53 | di = torch.diag(p.data) 54 | if self.use_prox: 55 | diff = .5 * (torch.sqrt(di ** 2 + 4 * state['eta']) - di) 56 | else: 57 | diff = torch.clamp(di, min=1e-3) - di 58 | p.data.add_(torch.diag(diff)) 59 | # END NEW 60 | 61 | # averaging 62 | if state['mu'] != 1: 63 | state['ax'].add_(p.data.sub(state['ax']).mul(state['mu'])) 64 | else: 65 | state['ax'].copy_(p.data) 66 | 67 | # update eta and mu 68 | state['eta'] = (group['lr'] / 69 | math.pow((1 + group['lambd'] * group['lr'] * state['step']), 70 | group['alpha'])) 71 | state['mu'] = 1 / max(1, state['step'] - group['t0']) 72 | -------------------------------------------------------------------------------- /categorical/script_experiments.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | import numpy as np 5 | 6 | from categorical.experiment_loops import experiment_guess, experiment_optimize 7 | 8 | 9 | def optimize_distances(k=10): 10 | results = [] 11 | base_experiment = { 12 | 'n': 100, 'k': k, 'T': 1500, 'n0': 10, 13 | 'batch_size': 1, 'scheduler_exponent': 0, 14 | 'concentration': 1, 'intervention': 'cause' 15 | } 16 | # for lr in [0.01, 0.05, 0.1]: 17 | for lr in [0.01, 0.1, .5]: 18 | trajectory = experiment_optimize( 19 | lr=lr, **base_experiment) 20 | experiment = {**base_experiment, 'lr': lr, **trajectory} 21 | results.append(experiment) 22 | 23 | savedir = 'results' 24 | os.makedirs(savedir, exist_ok=True) 25 | savefile = os.path.join(savedir, f'categorical_optimize_k={k}.pkl') 26 | if os.path.exists(savefile): 27 | with open(savefile, 'rb') as fin: 28 | previous_results = pickle.load(fin) 29 | else: 30 | previous_results = [] 31 | 32 | with open(savefile, 'wb') as fout: 33 | pickle.dump(previous_results + results, fout) 34 | 35 | 36 | def parameter_sweep(intervention, k, init, seed=17, guess=False, savedir='categorical_results'): 37 | print(f'intervention on {intervention} with k={k}') 38 | results = [] 39 | base_experiment = { 40 | 'n': 100, 'k': k, 'T': 1500, 41 | 'batch_size': 1, 42 | 'intervention': intervention, 43 | 'is_init_dense': init, 44 | 'concentration': 1, 45 | 'use_map': True 46 | } 47 | for exponent in [0]: 48 | for lr, n0 in zip([.03, .1, .3, 1, 3, 9, 30], 49 | [0.3, 1, 3, 10, 30, 90, 200]): 50 | np.random.seed(seed) 51 | parameters = {'n0': n0, 'lr': lr, 'scheduler_exponent': exponent, **base_experiment} 52 | if guess: 53 | trajectory = experiment_guess(**parameters) 54 | else: 55 | trajectory = 
experiment_optimize(**parameters) 56 | results.append({ 57 | 'hyperparameters': parameters, 58 | 'trajectory': trajectory, 59 | 'guess': guess 60 | }) 61 | 62 | os.makedirs(savedir, exist_ok=True) 63 | 64 | savefile = f'{intervention}_k={k}.pkl' 65 | if base_experiment['is_init_dense']: 66 | savefile = 'denseinit_' + savefile 67 | else: 68 | savefile = 'sparseinit_' + savefile 69 | if guess: 70 | savefile = 'guess_' + savefile 71 | else: 72 | savefile = 'sweep2_' + savefile 73 | 74 | savepath = os.path.join(savedir, savefile) 75 | with open(savepath, 'wb') as fout: 76 | pickle.dump(results, fout) 77 | 78 | 79 | if __name__ == "__main__": 80 | guess = False 81 | for init_dense in [True, False]: 82 | for k in [20]: 83 | # parameter_sweep('cause', k, init_dense, guess=guess) 84 | # parameter_sweep('effect', k, init_dense, guess=guess) 85 | # parameter_sweep('singlecond', k, init_dense) 86 | parameter_sweep('gmechanism', k, init_dense) 87 | -------------------------------------------------------------------------------- /simplex.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import matplotlib.tri as tri 3 | import numpy as np 4 | 5 | 6 | diracs = np.eye(3) 7 | uniform = np.ones(3)/3 8 | centered_diracs = diracs - uniform 9 | corners2d = np.array([[0, 2/np.sqrt(6)], 10 | [-np.sqrt(2)/2, -1/np.sqrt(6)], 11 | [np.sqrt(2)/2, -1/np.sqrt(6)],]) 12 | corners2d /= np.sqrt(2) 13 | 14 | 15 | class Simplex: 16 | """Draw any function on a 2d simplex""" 17 | 18 | def __init__(self,corners=corners2d, resolution=6): 19 | self.corners = corners 20 | # change of basis matrices 21 | self.d3tod2 = np.linalg.solve(centered_diracs, corners) 22 | # 3*3 and 3*2 ouputs 3*2 23 | self.d2tod3 = np.linalg.solve(corners[:2], centered_diracs[:2]) 24 | # 2*2 and 2*3 outputs 2*3 25 | 26 | # triangular mesh 27 | self.triangle = tri.Triangulation(corners[:, 0], corners[:, 1]) 28 | self.refiner = tri.UniformTriRefiner(self.triangle) 29 | self.trimesh = self.refiner.refine_triangulation(subdiv=resolution) 30 | 31 | # coordinates 32 | self.cartesians = np.array([self.trimesh.x, self.trimesh.y]).T 33 | self.barycentrics = self.cart2bary(self.cartesians) 34 | print(f"Simplex with resolution {resolution}: {self.cartesians.shape[0]} points.") 35 | 36 | 37 | def bary2cart(self,barys): 38 | """Take an N*3 array of barycenter coordinates 39 | and return an N*2 array of 2D cartesian coordinates.""" 40 | return np.dot(barys - uniform, self.d3tod2) 41 | 42 | def cart2bary(self,cart): 43 | """Take an N*2 array of 2D cartesian coordinates 44 | and return an N*3 array of barycenter coordinates.""" 45 | return np.dot(cart, self.d2tod3) + uniform 46 | 47 | def plotoptions(self): 48 | plt.axis('equal') 49 | plt.axis('off') 50 | plt.tight_layout() 51 | 52 | 53 | def show_borders(self, color='blue'): 54 | # plot contours 55 | self.plotoptions() 56 | plt.triplot(self.triangle,color=color) 57 | 58 | # plot lines between center of triangle and midline of each edge 59 | center = np.mean(self.corners,axis=0) 60 | middles = 0.5*(self.corners + np.roll(self.corners,1,axis=0)) 61 | for m in middles: 62 | plt.plot([m[0], center[0]],[m[1],center[1]], color=color) 63 | 64 | 65 | def show_mesh(self): 66 | self.plotoptions() 67 | plt.triplot(self.trimesh,linewidth=0.5) 68 | 69 | def show_func(self, baryfunc): 70 | values = baryfunc(self.barycentrics) 71 | self.show_borders() 72 | clevels = plt.tricontour(self.trimesh, values) 73 | plt.clabel(clevels) 74 | 75 | def 
constraint_line(self,eps,color='r'): 76 | """Plot the line of equation $q_1-q_2=eps$""" 77 | extremes = np.array([[(1+eps)/2, (1-eps)/2, 0], 78 | [eps,0,1-eps]]) 79 | cart = self.bary2cart(extremes) 80 | plt.plot(cart[:,0],cart[:,1],color=color) 81 | 82 | def scatter(self, barypoints, color='red'): 83 | carts = self.bary2cart(barypoints) 84 | if len(carts.shape)==1: 85 | carts = carts[np.newaxis,:] 86 | plt.scatter(carts[:,0],carts[:,1], 87 | s=50, color=color,marker='*',zorder=3) 88 | 89 | def show_func3d(self, baryfunc): 90 | values = baryfunc(self.barycentrics) 91 | 92 | from mpl_toolkits.mplot3d import Axes3D 93 | 94 | fig = plt.figure(figsize=(5,5)) 95 | ax = fig.gca(projection='3d') 96 | ax.triplot(self.triangle) 97 | # Plot the surface. 98 | surf = ax.plot_trisurf(self.cartesians[:,0], self.cartesians[:,1], values, 99 | cmap='viridis', linewidth=3, antialiased=False) 100 | # Add a color bar which maps values to colors. 101 | fig.colorbar(surf, shrink=0.5, aspect=5) 102 | plt.show() -------------------------------------------------------------------------------- /categorical/experiment_loops.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | import numpy as np 4 | import torch 5 | import tqdm 6 | from torch import optim 7 | 8 | from averaging_manager import AveragedModel 9 | from categorical.models import CategoricalModule, Counter, JointMAP, JointModule, sample_joint 10 | 11 | 12 | def experiment_optimize(k, n, T, lr, intervention, 13 | concentration=1, 14 | is_init_dense=False, 15 | batch_size=10, scheduler_exponent=0, n0=10, 16 | log_interval=10, use_map=False): 17 | """Measure optimization speed and parameters distance. 18 | 19 | Hypothesis: initial distance to optimum is correlated to optimization speed with SGD. 20 | 21 | Sample n mechanisms of order k and for each of them sample an 22 | intervention on the desired mechanism. Use SGD to update a causal 23 | and an anticausal model for T steps. At each step, measure KL 24 | and distance in scores for causal and anticausal directions. 
25 | """ 26 | causalstatic = sample_joint(k, n, concentration, is_init_dense) 27 | transferstatic = causalstatic.intervention( 28 | on=intervention, 29 | concentration=concentration, 30 | dense=is_init_dense 31 | ) 32 | # MODULES 33 | causal = causalstatic.to_module() 34 | transfer = transferstatic.to_module() 35 | 36 | anticausal = causalstatic.reverse().to_module() 37 | antitransfer = transferstatic.reverse().to_module() 38 | 39 | joint = JointModule(causal.to_joint().detach().view(n, -1)) 40 | jointtransfer = JointModule(transfer.to_joint().detach().view(n, -1)) 41 | 42 | # Optimizers 43 | optkwargs = {'lr': lr, 'lambd': 0, 'alpha': 0, 't0': 0, 44 | 'weight_decay': 0} 45 | causaloptimizer = optim.ASGD(causal.parameters(), **optkwargs) 46 | antioptimizer = optim.ASGD(anticausal.parameters(), **optkwargs) 47 | jointoptimizer = optim.ASGD(joint.parameters(), **optkwargs) 48 | optimizers = [causaloptimizer, antioptimizer, jointoptimizer] 49 | 50 | # MAP 51 | counter = Counter(np.zeros([n, k, k])) 52 | smooth_MLE = JointMAP(np.ones([n, k, k]), counter) 53 | joint_MAP = JointMAP(n0 * causalstatic.to_joint(return_probas=True), counter) 54 | 55 | steps = [] 56 | ans = defaultdict(list) 57 | for t in tqdm.tqdm(range(T)): 58 | 59 | # EVALUATION 60 | if t % log_interval == 0: 61 | steps.append(t) 62 | 63 | with torch.no_grad(): 64 | 65 | for model, optimizer, target, name in zip( 66 | [causal, anticausal, joint], 67 | optimizers, 68 | [transfer, antitransfer, jointtransfer], 69 | ['causal', 'anti', 'joint'] 70 | ): 71 | # SGD 72 | ans[f'kl_{name}'].append(target.kullback_leibler(model)) 73 | ans[f'scoredist_{name}'].append(target.scoredist(model)) 74 | 75 | # ASGD 76 | with AveragedModel(model, optimizer) as m: 77 | ans[f'kl_{name}_average'].append( 78 | target.kullback_leibler(m)) 79 | ans[f'scoredist_{name}_average'].append( 80 | target.scoredist(m)) 81 | 82 | # MAP 83 | if use_map: 84 | ans['kl_MAP_uniform'].append( 85 | transfer.kullback_leibler(smooth_MLE)) 86 | ans['kl_MAP_source'].append( 87 | transfer.kullback_leibler(joint_MAP)) 88 | 89 | # UPDATE 90 | for opt in optimizers: 91 | opt.lr = lr / t ** scheduler_exponent 92 | opt.zero_grad() 93 | 94 | if batch_size == 'full': 95 | causalloss = transfer.kullback_leibler(causal).sum() 96 | antiloss = antitransfer.kullback_leibler(anticausal).sum() 97 | jointloss = jointtransfer.kullback_leibler(joint).sum() 98 | else: 99 | aa, bb = transferstatic.sample(m=batch_size) 100 | if use_map: 101 | counter.update(aa, bb) 102 | taa, tbb = torch.from_numpy(aa), torch.from_numpy(bb) 103 | causalloss = - causal(taa, tbb).sum() / batch_size 104 | antiloss = - anticausal(taa, tbb).sum() / batch_size 105 | jointloss = - joint(taa, tbb).sum() / batch_size 106 | 107 | for loss, opt in zip([causalloss, antiloss, jointloss], optimizers): 108 | loss.backward() 109 | opt.step() 110 | 111 | for key, item in ans.items(): 112 | ans[key] = torch.stack(item).numpy() 113 | 114 | return {'steps': np.array(steps), **ans} 115 | 116 | 117 | def test_experiment_optimize(): 118 | for intervention in ['cause', 'effect', 'gmechanism']: 119 | experiment_optimize( 120 | k=2, n=3, T=6, lr=.1, batch_size=4, log_interval=1, 121 | intervention=intervention, use_map=True 122 | ) 123 | 124 | 125 | def experiment_guess( 126 | k, n, T, lr, intervention, 127 | concentration=1, is_init_dense=False, 128 | batch_size=10, scheduler_exponent=0, log_interval=10 129 | ): 130 | """Measure optimization speed after guessing intervention. 
131 | 132 | Sample n mechanisms of order k and for each of them sample an 133 | intervention on the desired mechanism. Initialize a causal and 134 | an anticausal model to the reference distribution. Duplicate them 135 | and initialize one module of each duplicata to the uniform. 136 | Run the optimization and record KL and distance. Also record the 137 | accuracy of guessing the intervention based on the lowest KL. 138 | """ 139 | causalstatic = sample_joint(k, n, concentration, is_init_dense) 140 | transferstatic = causalstatic.intervention( 141 | on=intervention, 142 | concentration=concentration, 143 | dense=is_init_dense 144 | ) 145 | causal = causalstatic.to_module() 146 | transfer = transferstatic.to_module() 147 | 148 | anticausal = causalstatic.reverse().to_module() 149 | antitransfer = transferstatic.reverse().to_module() 150 | 151 | joint = JointModule(causal.to_joint().detach().view(n, -1)) 152 | jointtransfer = JointModule(transfer.to_joint().detach().view(n, -1)) 153 | 154 | # TODO put all models and their transfer into a dict 155 | models = [causal, anticausal, joint] 156 | names = ['causal', 'anti', 'joint'] 157 | targets = [transfer, antitransfer, jointtransfer] 158 | 159 | # step 1 : duplicate models with intervention guessing 160 | causalguessA = CategoricalModule(torch.zeros([n, k]), causal.sba, is_btoa=False) 161 | causalguessB = CategoricalModule(causal.sa, torch.zeros([n, k, k]), is_btoa=False) 162 | 163 | antiguessA = CategoricalModule(anticausal.sa, torch.zeros([n, k, k]), is_btoa=True) 164 | antiguessB = CategoricalModule(torch.zeros([n, k]), anticausal.sba, is_btoa=True) 165 | 166 | models += [causalguessA, causalguessB, antiguessA, antiguessB] 167 | names += ['CausalGuessA', 'CausalGuessB', 'AntiGuessA', 'AntiGuessB'] 168 | targets += [transfer, transfer, antitransfer, antitransfer] 169 | 170 | optkwargs = {'lr': lr, 'lambd': 0, 'alpha': 0, 't0': 0, 'weight_decay': 0} 171 | optimizer = optim.ASGD([p for m in models for p in m.parameters()], **optkwargs) 172 | 173 | # intervention guess 174 | marginalA = JointModule(causal.sa) 175 | marginalB = JointModule(anticausal.sa) 176 | 177 | steps = [] 178 | ans = defaultdict(list) 179 | for step in tqdm.tqdm(range(T)): 180 | 181 | # EVALUATION 182 | if step % log_interval == 0: 183 | steps.append(step) 184 | with torch.no_grad(): 185 | 186 | for model, target, name in zip(models, targets, names): 187 | # SGD 188 | ans[f'kl_{name}'].append(target.kullback_leibler(model)) 189 | ans[f'scoredist_{name}'].append(target.scoredist(model)) 190 | 191 | # ASGD 192 | with AveragedModel(model, optimizer): 193 | ans[f'kl_{name}_average'].append(target.kullback_leibler(model)) 194 | ans[f'scoredist_{name}_average'].append(target.scoredist(model)) 195 | 196 | # UPDATE 197 | optimizer.lr = lr / step ** scheduler_exponent 198 | optimizer.zero_grad() 199 | 200 | if batch_size == 'full': 201 | loss = sum([t.kullback_leibler(m).sum() for m, t in zip(models, targets)]) 202 | else: 203 | aa, bb = transferstatic.sample(m=batch_size, return_tensor=True) 204 | loss = sum([- m(aa, bb).sum() for m in models]) / batch_size 205 | 206 | # step 2, estimate likelihood of samples aa and bb for each marginals 207 | # of the reference model and take the lowest likelihood as a guess 208 | # for the intervention. 
Take the average over all examples seen until now.
209 |         with torch.no_grad():
210 |             ans['loglikelihoodA'].append(marginalA(torch.zeros_like(aa), aa).mean(dim=1))
211 |             ans['loglikelihoodB'].append(marginalB(torch.zeros_like(bb), bb).mean(dim=1))
212 | 
213 |         loss.backward()
214 |         optimizer.step()
215 | 
216 |     for key, item in ans.items():
217 |         ans[key] = torch.stack(item).numpy()
218 | 
219 |     return {'steps': np.array(steps), **ans}
220 | 
221 | 
222 | def test_experiment_guess():
223 |     for bs in ['full', 4]:
224 |         for intervention in ['cause', 'effect']:
225 |             experiment_guess(
226 |                 k=5, n=10, T=6, lr=.1, batch_size=bs, log_interval=1,
227 |                 intervention=intervention
228 |             )
229 | 
230 | 
231 | if __name__ == "__main__":
232 |     test_experiment_optimize()
233 |     test_experiment_guess()
234 | 
--------------------------------------------------------------------------------
/categorical/plot_old.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | 
4 | import matplotlib.pyplot as plt
5 | import numpy as np
6 | 
7 | 
8 | def dist_plot():
9 |     with open('categorical_results/categorical_distances_300.pkl', 'rb') as fin:
10 |         results = pickle.load(fin)
11 | 
12 |     plotdir = 'plots/distances'
13 |     os.makedirs(plotdir, exist_ok=True)
14 | 
15 |     for exp in results:
16 |         name = bigplotname(exp)
17 |         print(name)
18 |         bigplot(exp)
19 |         plt.savefig(os.path.join(plotdir, name))
20 |         plt.close()
21 | 
22 | 
23 | def bigplot(exp, confidence=(5, 95)):
24 |     """Draw 4 scatter plots and 2 line plots"""
25 |     proba, score = tuple(np.swapaxes(exp['distances'], 0, 2))
26 |     causal_proba, anti_proba = tuple(proba)
27 |     causal_score, anti_score = tuple(score)
28 | 
29 |     dimensions = exp['dimensions']
30 | 
31 |     fig, axs = plt.subplots(nrows=3, ncols=2, figsize=(16, 20))
32 | 
33 |     csname = r'$\|\theta_\rightarrow^0 - \theta_\rightarrow^*\| $'
34 |     cpname = r'$\|p_\rightarrow^0 - p_\rightarrow^*\| $'
35 |     asname = r'$\|\theta_\leftarrow^0 - \theta_\leftarrow^*\| $'
36 |     apname = r'$\|p_\leftarrow^0 - p_\leftarrow^*\| $'
37 | 
38 |     axs[0, 0].set_title('causal score vs proba')
39 |     axs[0, 1].set_title('anticausal score vs proba')
40 |     axs[1, 0].set_title('proba anticausal vs causal')
41 |     axs[1, 1].set_title('score anticausal vs causal')
42 | 
43 |     for k, cpd, apd, csd, asd in zip(
44 |             dimensions,
45 |             causal_proba, anti_proba,
46 |             causal_score, anti_score):
47 |         axs[0, 0].scatter(cpd, csd, label=k, alpha=.2)
48 |         axs[0, 0].set_xlabel(cpname)
49 |         axs[0, 0].set_ylabel(csname)
50 |         axs[0, 1].scatter(apd, asd, label=k, alpha=.2)
51 |         axs[0, 1].set_xlabel(apname)
52 |         axs[0, 1].set_ylabel(asname)
53 |         axs[1, 0].scatter(cpd, apd, label=k, alpha=.2)
54 |         axs[1, 0].set_xlabel(cpname)
55 |         axs[1, 0].set_ylabel(apname)
56 |         axs[1, 1].scatter(csd, asd, label=k, alpha=.2)
57 |         axs[1, 1].set_xlabel(csname)
58 |         axs[1, 1].set_ylabel(asname)
59 | 
60 |     abline(axs[1, 0], 1, 0)
61 |     abline(axs[1, 1], 1, 0)
62 | 
63 |     axs[2, 0].set_title('proba ratio anticausal / causal ')
64 |     axs[2, 1].set_title('score ratio anticausal / causal ')
65 |     ratio_proba = anti_proba / causal_proba
66 |     ratio_score = anti_score / causal_score
67 | 
68 |     for ax, ratio in zip([axs[2, 0], axs[2, 1]], [ratio_proba, ratio_score]):
69 |         ax.plot(dimensions, np.mean(ratio, axis=-1), label='mean')
70 |         ax.fill_between(dimensions,
71 |                         np.percentile(ratio, confidence[0], axis=-1),
72 |                         np.percentile(ratio, confidence[1], axis=-1),
73 |                         alpha=.4, label='confidence {} %'.format(
74 |                             confidence[1] - confidence[0]))
75 | 
abline(ax, 0, 1) 76 | ax.set_xlabel('k') 77 | 78 | for ax in axs.flatten(): 79 | ax.legend() 80 | 81 | return fig, axs 82 | 83 | 84 | def bigplotname(exp): 85 | return "{}_{}.pdf".format( 86 | exp['intervention'], 87 | 'dense' if exp['dense_init'] else 'sparse', 88 | ) 89 | 90 | 91 | def abline(ax, slope, intercept): 92 | """Plot a line from slope and intercept""" 93 | x_vals = np.array(ax.get_xlim()) 94 | y_vals = intercept + slope * x_vals 95 | ax.plot(x_vals, y_vals, '--', color='grey') 96 | 97 | 98 | COLORS = { 99 | 'causal': 'blue', 100 | 'anti': 'red', 101 | 'joint': 'palegreen', 102 | 'causal_average': 'darkblue', 103 | 'anti_average': 'darkred', 104 | 'joint_average': 'darkgreen', 105 | 'MAP_uniform': 'yellow', 106 | 'MAP_source': 'gold' 107 | } 108 | 109 | 110 | def optim_plot(): 111 | with open('results/categorical_optimize_k=10.pkl', 'rb') as fin: 112 | results = pickle.load(fin) 113 | 114 | plotdir = 'plots/categorical_optim' 115 | os.makedirs(plotdir, exist_ok=True) 116 | 117 | for exp in results: 118 | print(longplotname(exp)) 119 | longplot(exp, statistics=True, plot_bound=False) 120 | plt.savefig(os.path.join(plotdir, 'average_' + longplotname(exp))) 121 | # curve_plot(exp, statistics=False) 122 | # plt.savefig(os.path.join(plotdir, 'curves_' + longplotname(exp))) 123 | optim_scatter(exp) 124 | plt.savefig(os.path.join(plotdir, 'scatter_' + longplotname(exp))) 125 | plt.close() 126 | 127 | 128 | def longplot(exp, confidence=(5, 95), statistics=True, 129 | plot_bound=False): 130 | """Draw mean trajectory plot with percentiles""" 131 | fig, axs = plt.subplots(nrows=2, ncols=1, sharex=True, figsize=(10, 8)) 132 | for ax, metric in zip(axs, ['kl', 'scoredist']): 133 | ax.grid(True) 134 | ax.set_yscale('log') 135 | for model_family in COLORS.keys(): 136 | key = metric + '_' + model_family 137 | if key in exp: 138 | values = exp[key] 139 | 140 | if statistics: # plot mean and percentile statistics 141 | ax.plot( 142 | exp['steps'], 143 | values.mean(axis=1), 144 | label=model_family, 145 | color=COLORS[model_family] 146 | ) 147 | ax.fill_between( 148 | exp['steps'], 149 | np.percentile(values, confidence[0], axis=1), 150 | np.percentile(values, confidence[1], axis=1), 151 | alpha=.4, 152 | color=COLORS[model_family] 153 | ) 154 | else: 155 | ax.plot( 156 | exp['steps'], 157 | values, 158 | alpha=.1, 159 | color=COLORS[model_family], 160 | label=model_family 161 | ) 162 | 163 | axs[0].set_ylabel('KL(transfer, model)') 164 | if statistics: 165 | axs[0].legend() 166 | axs[1].set_ylabel('||transfer - model ||^2') 167 | 168 | if plot_bound: 169 | make_bound(exp, axs[0]) 170 | 171 | 172 | def make_bound(exp, ax): 173 | for model_family in ['causal', 'anti']: 174 | initial_distance = exp['scoredist_' + model_family][0].mean() 175 | steps = np.array(exp['steps'])[1:] 176 | if exp['batch_size'] == 'full' and exp['scheduler_exponent'] == 0: 177 | constant = rate_constant_fullbatch( 178 | smoothness=1 / 2, 179 | initial_lr=exp['lr'], 180 | initial_distance=initial_distance 181 | ) 182 | bound = constant / (steps - 1) 183 | elif np.isclose(exp['scheduler_exponent'], 2 / 3): 184 | constant = rate_constant_twothird( 185 | smoothness=1 / 2, 186 | initial_lr=exp['lr'], 187 | initial_distance=initial_distance, 188 | variance=1 / exp['batch_size'] 189 | ) 190 | bound = constant / (steps ** (1 / 3) - 1) 191 | else: 192 | print('Convergence bound only available for ' 193 | 'gradient descent with constant learning rate ' 194 | 'or stochastic gradient descent with learning rate scheduling exponent 
2/3.') 195 | return 196 | 197 | initial_kl = exp['kl_' + model_family][0].mean() 198 | print(f"{model_family} initial distance = {initial_distance:.2f} \t" 199 | f" bound constant = {constant:.2f} \t" 200 | f"ratio = {constant / initial_distance:.2f} \t" 201 | f"ratio constant / initial kl = {constant / initial_kl:.2f}") 202 | 203 | ax.plot(steps, bound, 204 | color=COLORS[model_family], 205 | linestyle='--') 206 | 207 | 208 | def longplotname(exp): 209 | return ( 210 | '{intervention}_k={k}_bs={batch_size}_rate={scheduler_exponent:.1f}' 211 | '_lr={lr}_concentration={concentration}_T={T}.pdf').format(**exp) 212 | 213 | 214 | def rate_constant_twothird(smoothness, initial_lr, initial_distance, variance): 215 | return (np.exp(3 * (2 * smoothness * initial_lr) ** 2) 216 | * (1 + 4 * (smoothness * initial_lr) ** (3 / 2)) 217 | / (3 * initial_lr) 218 | * (initial_distance + variance / smoothness ** 2)) 219 | 220 | 221 | def rate_constant_fullbatch(smoothness, initial_lr, initial_distance): 222 | return - initial_distance / ( 223 | initial_lr * (smoothness / 2 * initial_lr - 1)) 224 | 225 | 226 | def optim_scatter(exp, end_step=1000): 227 | fig, axs = plt.subplots(nrows=2, ncols=1, sharex=True, figsize=(10, 8)) 228 | 229 | index = np.searchsorted(exp['steps'], end_step) 230 | initial_distances = {} 231 | ends = {} 232 | integrals = {} 233 | for model_family in COLORS.keys(): 234 | key = 'scoredist_' + model_family 235 | if key in exp: 236 | initial_distances[model_family] = exp[key][0] 237 | values = exp['kl_' + model_family] 238 | ends[model_family] = values[index] 239 | integrals[model_family] = values[:index].mean(axis=0) 240 | 241 | axs[0].scatter( 242 | initial_distances[model_family], 243 | ends[model_family], 244 | alpha=.3, 245 | color=COLORS[model_family], 246 | label=model_family 247 | ) 248 | axs[1].scatter( 249 | initial_distances[model_family], 250 | integrals[model_family], 251 | alpha=.3, 252 | color=COLORS[model_family], 253 | label=model_family 254 | ) 255 | 256 | for ax, metrics in zip(axs, [ends, integrals]): 257 | ax.grid(True) 258 | ax.set_yscale('log') 259 | ax.set_xscale('log') 260 | # plot edge between identical problems 261 | # ax.plot( 262 | # np.stack(( 263 | # initial_distances['causal'], 264 | # initial_distances['anti'] 265 | # ), axis=1).T, 266 | # np.stack(( 267 | # metrics['causal'], 268 | # metrics['anti'] 269 | # ), axis=1).T, 270 | # alpha=.1, color='black') 271 | 272 | axs[0].set_ylabel(f'KL(transfer, model) at step {end_step}') 273 | axs[0].legend() 274 | axs[1].set_ylabel(f'Average KL(transfer, model) from step 0 to {end_step}') 275 | axs[1].set_xlabel( 276 | 'Parameter squared distance from initialization to optimum.') 277 | 278 | 279 | if __name__ == "__main__": 280 | dist_plot() 281 | -------------------------------------------------------------------------------- /normal_pkg/adaptation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from collections import defaultdict 4 | 5 | import numpy as np 6 | import torch 7 | import tqdm 8 | from torch import nn 9 | 10 | from averaging_manager import AveragedModel 11 | from normal_pkg import normal 12 | from normal_pkg.proximal_optimizer import PerturbedProximalGradient 13 | 14 | 15 | def pamper(a): 16 | if torch.is_tensor(a): 17 | b = a.clone().detach() 18 | elif isinstance(a, np.ndarray): 19 | b = torch.from_numpy(a) 20 | else: 21 | b = torch.tensor(a) 22 | return b.to(torch.float32) 23 | 24 | 25 | def cholesky_numpy2module(cho, BtoA): 26 
| return CholeskyModule(cho.za, cho.la, cho.linear, cho.bias, cho.lcond, BtoA) 27 | 28 | 29 | class CholeskyModule(nn.Module): 30 | 31 | def __init__(self, za, la, linear, bias, lcond, BtoA=False): 32 | super().__init__() 33 | dim = za.shape[0] 34 | self.za = nn.Parameter(pamper(za)) 35 | self.la = nn.Parameter(pamper(la)) 36 | self.linear = nn.Linear(dim, dim, bias=True) 37 | self.linear.weight.data = pamper(linear) 38 | self.linear.bias.data = pamper(bias) 39 | self.lcond = nn.Parameter(pamper(lcond)) 40 | self.BtoA = BtoA 41 | 42 | self.la.triangular = True 43 | self.lcond.triangular = True 44 | 45 | def forward(self, a, b, test_via_joint=False, nograd_logdet=False): 46 | """Compute only the quadratic part of the loss to get its gradient. 47 | I will use a proximal operator to optimize the log-partition logdet-barrier. 48 | """ 49 | batch_size = a.shape[0] 50 | if self.BtoA: 51 | a, b = b, a 52 | 53 | # use conditional parametrization 54 | marginal = .5 * torch.sum((a @ self.la - self.za) ** 2) / batch_size 55 | zcond = self.linear(a) 56 | conditional = .5 * torch.sum((b @ self.lcond - zcond) ** 2) / batch_size 57 | 58 | logdet = - torch.sum(torch.log(torch.diag(self.la))) 59 | logdet += - torch.sum(torch.log(torch.diag(self.lcond))) 60 | if nograd_logdet: 61 | logdet.detach_() 62 | 63 | loss1 = marginal + conditional + logdet 64 | 65 | # use joint parametrization 66 | # CAREFUL the joint parametrization inverts the roles 67 | # of cause and effect for simplicity 68 | if test_via_joint: 69 | x = torch.cat([b, a], 1) 70 | z, L = self.joint_parameters() 71 | quadratic = .5 * torch.sum((x @ L - z) ** 2) / batch_size 72 | logdet = - torch.sum(torch.log(torch.diag(L))) 73 | loss2 = quadratic + logdet 74 | 75 | assert torch.isclose(loss1, loss2), (loss1, loss2) 76 | 77 | return loss1 78 | 79 | def joint_parameters(self): 80 | """Return joint cholesky, with order of X and Y inverted.""" 81 | zeta = torch.cat([self.linear.bias, self.za]) 82 | L = torch.cat([ 83 | torch.cat([self.lcond, torch.zeros_like(self.lcond)], 1), 84 | torch.cat([- self.linear.weight.t(), self.la], 1) 85 | ], 0) 86 | return zeta, L 87 | 88 | def dist(self, other): 89 | return ( 90 | torch.sum((self.za - other.za) ** 2) 91 | + torch.sum((self.la - other.la) ** 2) 92 | + torch.sum((self.linear.weight - other.linear.weight) ** 2) 93 | + torch.sum((self.linear.bias - other.linear.bias) ** 2) 94 | + torch.sum((self.lcond - other.lcond) ** 2) 95 | ) 96 | 97 | def __repr__(self): 98 | return ( 99 | f'CholeskyModule(' 100 | f'\n \t za={self.za.data},' 101 | f'\n \t la={self.la.data},' 102 | f'\n \t linear={self.linear.weight.data},' 103 | f'\n \t bias={self.linear.bias.data},' 104 | f'\n \t lcond={self.lcond.data})' 105 | ) 106 | 107 | 108 | def cholesky_kl(p0: CholeskyModule, p1: CholeskyModule, decompose=False, nograd_logdet=False): 109 | z0, L0 = p0.joint_parameters() 110 | z1, L1 = p1.joint_parameters() 111 | V, _ = torch.triangular_solve(L1, L0, upper=False) 112 | diff = V @ z0 - z1 113 | vecnorm = .5 * torch.sum(diff ** 2) 114 | matnorm = .5 * (torch.sum(V ** 2) - z0.shape[0]) 115 | logdet = - torch.sum(torch.log(torch.diag(V))) 116 | if nograd_logdet: 117 | logdet.detach_() 118 | 119 | matdivergence = matnorm + logdet 120 | total = vecnorm + matdivergence 121 | if decompose: 122 | return {'vector': vecnorm.item(), 'matrix': matdivergence.item(), 123 | 'total': total.item(), 'v/t': vecnorm.item() / total.item(), 124 | 'm/t': matdivergence.item() / total.item()} 125 | else: 126 | return total 127 | 128 | 129 | class 
AdaptationExperiment: 130 | """Sample one distribution, adapt and record adaptation speed.""" 131 | 132 | def __init__( 133 | self, T, log_interval, # recording 134 | k, intervention, init, preccond_scale, intervention_scale, # distributions 135 | lr=.1, batch_size=10, scheduler_exponent=0, use_prox=False, # optimizer 136 | ): 137 | self.k = k 138 | self.intervention = intervention 139 | self.init = init 140 | self.lr = lr 141 | self.scheduler_exponent = scheduler_exponent 142 | self.log_interval = log_interval 143 | self.use_prox = use_prox 144 | 145 | reference = normal.sample(k, init, scale=preccond_scale) 146 | transfer = reference.intervention(intervention, intervention_scale) 147 | 148 | self.deterministic = True if batch_size == 0 else False 149 | if not self.deterministic: 150 | meanjoint = transfer.to_joint().to_mean() 151 | sampler = torch.distributions.MultivariateNormal( 152 | pamper(meanjoint.mean), covariance_matrix=pamper(meanjoint.cov) 153 | ) 154 | data_size = torch.tensor([T, batch_size]) 155 | self.dataset = sampler.sample(data_size) 156 | 157 | self.models = { 158 | 'causal': cholesky_numpy2module(reference.to_cholesky(), BtoA=False), 159 | 'anti': cholesky_numpy2module(reference.reverse().to_cholesky(), BtoA=True) 160 | } 161 | 162 | self.targets = { 163 | 'causal': cholesky_numpy2module(transfer.to_cholesky(), BtoA=False), 164 | 'anti': cholesky_numpy2module(transfer.reverse().to_cholesky(), BtoA=True) 165 | } 166 | 167 | optkwargs = {'lr': lr, 'lambd': 0, 'alpha': 0, 't0': 0, 'weight_decay': 0} 168 | self.optimizer = PerturbedProximalGradient( 169 | [p for m in self.models.values() for p in m.parameters()], self.use_prox, **optkwargs) 170 | 171 | self.trajectory = defaultdict(list) 172 | self.step = 0 173 | 174 | def evaluate(self): 175 | self.trajectory['steps'].append(self.step) 176 | with torch.no_grad(): 177 | for name, model in self.models.items(): 178 | target = self.targets[name] 179 | # SGD 180 | kl = cholesky_kl(target, model).item() 181 | dist = target.dist(model).item() 182 | self.trajectory[f'kl_{name}'].append(kl) 183 | self.trajectory[f'scoredist_{name}'].append(dist) 184 | # ASGD 185 | with AveragedModel(model, self.optimizer): 186 | kl = cholesky_kl(target, model).item() 187 | dist = target.dist(model).item() 188 | self.trajectory[f'kl_{name}_average'].append(kl) 189 | self.trajectory[f'scoredist_{name}_average'].append(dist) 190 | 191 | def iterate(self): 192 | self.step += 1 193 | self.optimizer.lr = self.lr / self.step ** self.scheduler_exponent 194 | self.optimizer.zero_grad() 195 | 196 | if self.deterministic: 197 | loss = sum([cholesky_kl(self.targets[name], model, nograd_logdet=self.use_prox) 198 | for name, model in self.models.items()]) 199 | else: 200 | samples = self.dataset[self.step - 1] 201 | aa, bb = samples[:, :self.k], samples[:, self.k:] 202 | loss = sum( 203 | [model(aa, bb, nograd_logdet=self.use_prox) for model in self.models.values()]) 204 | loss.backward() 205 | self.optimizer.step() 206 | self.trajectory['loss'].append(loss.item()) 207 | 208 | def run(self, T): 209 | for t in range(T): 210 | if t % self.log_interval == 0: 211 | self.evaluate() 212 | self.iterate() 213 | # print(self.__repr__()) 214 | 215 | def __repr__(self): 216 | return f'AdaptationExperiment step={self.step} \n' \ 217 | + '\n'.join([f'{name} \t {model}' for name, model in self.models.items()]) 218 | 219 | 220 | def batch_adaptation(n, T, **parameters): 221 | trajectories = defaultdict(list) 222 | models = [] 223 | for _ in tqdm.tqdm(range(n)): 224 | exp = 
AdaptationExperiment(T=T, **parameters) 225 | exp.run(T) 226 | for key, item in exp.trajectory.items(): 227 | trajectories[key].append(item) 228 | models += [(exp.models, exp.targets)] 229 | 230 | for key, item in trajectories.items(): 231 | trajectories[key] = np.array(item).T 232 | trajectories['steps'] = trajectories['steps'][:, 0] 233 | 234 | return trajectories, models 235 | 236 | 237 | def sweep_lr(lrlr, base_experiment, seed=1, savedir='normal_results'): 238 | results = [] 239 | print(base_experiment) 240 | for lr in lrlr: 241 | np.random.seed(seed) 242 | torch.manual_seed(seed) 243 | parameters = {'lr': lr, 'scheduler_exponent': 0, **base_experiment} 244 | trajectory, models = batch_adaptation(**parameters) 245 | results.append({ 246 | 'hyperparameters': parameters, 247 | 'trajectory': trajectory, 248 | 'models': models 249 | }) 250 | 251 | os.makedirs(savedir, exist_ok=True) 252 | savefile = '{intervention}_{init}_k={k}.pkl'.format(**base_experiment) 253 | savepath = os.path.join(savedir, savefile) 254 | with open(savepath, 'wb') as fout: 255 | print("Saving results in ", savepath) 256 | pickle.dump(results, fout) 257 | 258 | 259 | def test_AdaptationExperiment(): 260 | batch_adaptation(T=100, k=3, n=1, lr=.1, batch_size=1, log_interval=10, 261 | intervention='cause', init='natural') 262 | 263 | 264 | if __name__ == "__main__": 265 | # test_AdaptationExperiment() 266 | 267 | base = {'n': 100, 'T': 400, 'batch_size': 1, 'use_prox': True, 268 | 'log_interval': 10, 'intervention_scale': 1, 269 | 'init': 'natural', 'preccond_scale': 10} 270 | lrlr = [.03] 271 | for k in [10]: 272 | base['k'] = k 273 | sweep_lr(lrlr, {**base, 'intervention': 'cause'}) 274 | sweep_lr(lrlr, {**base, 'intervention': 'effect'}) 275 | sweep_lr(lrlr, {**base, 'intervention': 'mechanism'}) 276 | -------------------------------------------------------------------------------- /3_failure_space.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import matplotlib.pyplot as plt\n", 11 | "%matplotlib inline\n", 12 | "%load_ext autoreload\n", 13 | "%autoreload 2" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 2, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "def ratio_func(x):\n", 23 | " return np.abs(1/(1-x))" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": 15, 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "n = 10000\n", 33 | "samples_up = np.random.randn(n)\n", 34 | "samples_down = 1*np.random.randn(n)\n", 35 | "samples = samples_up / samples_down" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": 16, 41 | "metadata": {}, 42 | "outputs": [ 43 | { 44 | "name": "stdout", 45 | "output_type": "stream", 46 | "text": [ 47 | "0.2394\n" 48 | ] 49 | } 50 | ], 51 | "source": [ 52 | "print(np.mean(ratio_func(samples)>np.sqrt(2)))" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 17, 58 | "metadata": {}, 59 | "outputs": [ 60 | { 61 | "data": { 62 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXQAAAD8CAYAAABn919SAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4wLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvpW3flQAAGHFJREFUeJzt3XuQnXV9x/H393nO2UuymxCShXAJ\nQpCb44XAFnQU2iq1iAhexg7OqMxUm6GtVVu1yjhj1fYPbUfrpVaNSrXWWzvAFJ2xipdILYJuEAMY\nJAEJhsBmCSG3ze7Zc55v/3ies3t2s5eT7Hn27P6ez2ty5rmf881zzn72t7/ncszdERGRpS9qdwEi\nItIaCnQRkUAo0EVEAqFAFxEJhAJdRCQQCnQRkUCUmlnJzB4FDgI1oOru/XkWJSIix66pQM/8obs/\nlVslIiIyL+pyEREJhDVzpaiZ/RbYBzjweXffNM06G4GNAMuXL7/4/PPPb3GpIvN3aKTKb/ceBuA5\np6wgjqzNFYlM2LJly1Pu3ne82zcb6Ke6+24zOwm4Hfgrd79jpvX7+/t9YGDgeGsSyc0dDw3x5pt+\nDsCvPvByVi4rt7kikQlmtmU+xyib6nJx993ZcA9wK3DJ8b6giIjkY85AN7PlZtZbHwdeDtyfd2Ei\nefBJ47oxnYSlmbNcTgZuNbP6+l939//JtSoRETlmcwa6uz8CvGABahERkXnQaYtSKI0nAeirACQ0\nCnQRkUAo0EVEAqFAFxEJhAJdCsVnGBcJgQJdRCQQCnQRkUAo0KVYGvpZmrmPkchSokAXEQmEAl1E\nJBAKdCksdbhIaBToUii6w6KETIEuIhIIBbqISCAU6FIo7tOPi4RAgS4iEggFuohIIBToUlg640VC\no0CXQlG/uYRMgS4iEggFuhSXWusSGAW6FIoyXEKmQBcRCYQCXUQkEAp0KSx1v0hoFOhSKPqWIgmZ\nAl1EJBAKdCksNdYlNAp0KRRluIRMgS4iEggFuhSWbs4loVGgS6Go31xCpkAXEQmEAl1EJBAKdCks\ndb9IaJoOdDOLzeyXZvadPAsSyZdSXMJ1LC30dwDb8ipERETmp6lAN7PTgVcCX8y3HJGFo7a6hKbZ\nFvongL8FkplWMLONZjZgZgNDQ0MtKU6k1dRvLiGbM9DN7Gpgj7tvmW09d9/k7v3u3t/X19eyAkVE\npDnNtNBfDFxjZo8C3wReamb/kWtVIgtAt9KV0MwZ6O5+o7uf7u5nAtcBP3L3N+ZemYiIHBOdhy6F\noja5hKx0LCu7+2Zgcy6ViIjIvKiFLoWlLnQJjQJdCkUhLiFToIuIBEKBLiISCAW6iEggFOhSKPra\nOQmZAl0KSwdIJTQKdBGRQCjQpVDUKpeQKdBFRAKhQJfC0gFSCY0CXUQkEAp0KRS1ySVkCnQpLB0g\nldAo0EVEAqFAl0LR185JyBToUliKdgmNAl1EJBAKdBGRQCjQRUQCoUCXwtIBUgmNAl1EJBAKdCkU\nNcolZAp0KZTGG3Ip2yU0CnQplCSZGFdrXUKjQJdCacxwHRSV0CjQpVASV5eLhEuBLsXSkOKJWugS\nGAW6FMqkFrryXAKjQJdCacxwtdAlNAp0KRS10CVkCnQplMYQV6BLaBToUig+6SwXJbqERYEuhTK5\nD71tZYjkYs5AN7MuM/u5mf3KzB4wsw8tRGEieZjc5aJEl7CUmlhnFHipux8yszLwUzP7rrvflXNt\nIi3XeFBULXQJzZyB7mkz5lA2Wc4e+lGQJWlyo1wfYwlLU33oZhab2b3AHuB2d797mnU2mtmAmQ0M\nDQ21uk6RllALXULWVKC7e83dLwROBy4xs+dOs84md+939/6+vr5W1ynScupCl9Ac01ku7v4MsBm4\nMpdqRHI2uYWuRJewNHOWS5+ZnZCNdwNXAA/mXZhIHnRhkYSsmbNcTgG+YmYx6S+A/3T37+Rblkg+\nEp22KAFr5iyXrcCGBahFJHf6CjoJma4UlUJpbJSrD11Co0CXQnHdbVECpkCXQknUQpeAKdClUCad\n5dK+MkRyoUCXQpn8BReKdAmLAl0KpTHClecSGgW6FIrrXi4SMAW6FIruhy4hU6BLoehuixIyBboU\nis8yJbLUKdClUNRCl5Ap0KVYdLdFCZgCXQpF90OXkCnQpVB0paiETIEuhTL5wiJFuoRFgS6Fkuhu\nixIwBboUiu6HLiFToEuh6H7oEjIFuhRKY4arhS6hUaBLoUzqQ29jHSJ5UKBLoejmXBIyBboUSqIr\nRSVgCnQpGN3LRcKlQJdCSRIwS8ddvegSGAW6FIrjxFmiq8tFQqNAl0JJHKKoHuhKdAmLAl0KxZ2J\nFnqbaxFpNQW6FIq7E2ct9ERHRSUwCnQpFAcimxgXCYkCXQolaWyhK9ElMAp0KRR3xgNdB0UlNAp0\nKZTEnUinLUqgFOhSKGkfev0sFyW6hEWBLoXi6kOXgCnQpVDcIYomxkVCMmegm9k6M/uxmW0zswfM\n7B0LUZhIHhovLNIXXEhoSk2sUwXe5e73mFkvsMXMbnf3X+dcm0jLJe7jl/6LhGbOFrq7P+Hu92Tj\nB4FtwGl5FyaSB6ehha5OdAnMMfWhm9mZwAbg7mmWbTSzATMbGBoaak11Ii2WJE45Tj/2NXW5SGCa\nDnQz6wFuBt7p7gemLnf3Te7e7+79fX19raxRpGWqiVOO1UKXMDUV6GZWJg3zr7n7LfmWJJKf+qX/\npcjUQpfgNHOWiwFfAra5+8fzL0kkP9VaGuhRZFTVQpfANNNCfzHwJuClZnZv9rgq57pEclFraKGr\ny0VCM+dpi+7+U0DneUkQaolTLkfEpha6hEdXikqh1BInjiLiWC10CY8CXQqlljixoRa6BEmBLoWS\nttDTg6K69F9Co0CXQqkHeikyqjUFuoRFgS6FUj/LJTKdhy7hUaBLoSTZQdGSDopKgBToUihVHRSV\ngCnQpVDGT1vUQVEJkAJdCiUNdIh1UFQCpECXQkkPiqqFLmFSoEuhTGqhqw9dAqNAl0KpJU4pa6HX\nFOgSGAW6FEotcSIzYlOgS3gU6FIojV0uCnQJjQJdCqXxtEUFuoRGgS6Fkp7lkrXQdZaLBEaBLoXh\n7mqhS9AU6FIY9fyOLfuSaAW6BMY8hz87e3t7/eKLL27584rMh1vEzkvfxQmP/S+VnrWMdZ7Aafd9\nud1liYz7yU9+ssXd+493e7XQpTDcYgDMa+A1iPTxl7DM+SXRx+O8885j8+bNeTy1yHF7ZrjChR++\nnbe/7S/Yums/W3buY/MnN7e7LJFxZjav7dVEkcKo1BIAynFEOTbGsmmRUCjQpTDGsrsrdsQR5Tii\nUlWgS1gU6FIY9QAvlywNdLXQJTAKdCmMehdLRxzTWYrU5SLBUaBLYYy30OO0hT6mL7iQwCjQpTDq\nLfJyKe1DryWui4skKAp0KYxJB0VLls1Tt4uEQ4EuhTHe
h16K6IjTj74OjEpIFOhSGI3noXeUskDX\nqYsSEAW6FMbUg6KgLhcJiwJdCmPitMVoItCrOigq4VCgS2FMtNAnulxGq7V2liTSUgp0KYyRsTTQ\nuztilpXTOy8eGVOgSziaCnQzu8nM9pjZ/XkXJJKXenh3lWOWdaSBPlxRoEs4mm2hfxm4Msc6RHI3\nMh7oEcs60ztHD1eq7SxJpKWaCnR3vwN4OudaRHI1MlYjsvSgqFroEqKWfcGFmW0ENgKcccYZrXpa\nkZY5UqnRXY4xs4lAH1WgL3ofXNnuCpaMlgW6u28CNgH09/frXDBZdI6M1ejKDoYu72hBl4uCRhaZ\nXL6CTmQxagz07qyFfrixy0UBLUucTluUwhgdS8aDvLMUEUfGEfWhS0CaaqGb2TeAPwDWmNku4O/c\n/Ut5FiYyL9O0tocr76HbV8IHr8GAZckXOXzHp+FnX134+kRy0FSgu/sb8i5EJG8HfRk9dmR8ehkj\nDNPVxopEWktdLlIYh+iml+Hx6WU2yrB3trEikdZSoEthHPRuelALXcKlQJfCOEQ3vQ1dLssZYRi1\n0CUcCnQpBPc00Btb6N02yrCrhS7hUKBLIQzTSY2YXpvoQ+9hhIN0t7EqkdZSoEsh7PX0NMbVHBif\nd6Id4GnvbVdJIi2nQJdC2Esa3GtsItBX2wGeoZeq68dAwqBPshTCeAvd9o/Pq7fWn0atdAmDAl0K\nYa+vAOBEOzg+b00W7vVlIkudAl0K4SnS0J7ch56Ge731LrLUKdClEPb6CpZzhG6rjM+rh/te1EKX\nMCjQpRD2+kpWNxwQhYkul6fU5SKBUKBLITzpq+jjmUnzVjBMBxUG/cQ2VSXSWgp0KYTH/GSeZXsm\nzYvMWWdD7PST2lSVSGsp0CV4I17mSVZxRjR41LJn2SA7/eQ2VCXSegp0Cd4u78OJeJZNH+iP+cm4\nvgVXAqBAl+DVW+AzBfowXQyhUxdl6VOgS/B+46cDcLY9cdSys203ANuT0xe0JpE8KNAlePcl6znD\nBllph49a9tzoUQC2+voFrkqk9RToErz7/CyeZ49Mu2yVHWKd7eH+5KwFrkqk9RToErRBP4FdfhIv\niKYPdIDn28NsSc7RgVFZ8hToErQ7as8H4CXRfTOuc1l0H0+ymu1+2kKVJZILBboEbXPyAvrYxwX2\n2IzrXB5vBeBHyYaFKkskFwp0CdZB7+aHyUW8PN6C2czrnWpPc6Ft55baZep2kSVNgS7Buq32Ikbo\n5HXxHXOue138Yx7ydfzSn70AlYnkQ4EuQRrzmE21q3mBPcwG2zHn+lfHd9HDMJ+rvmoBqhPJhwJd\ngvRvtSvZ6Wt5e+mWWbtb6npshBtK3+b7ye9xd3J+/gWK5ECBLsF5aPAg/1x9HVdEW3hZ/Mumt3tL\n/F1OY4h3j93Afl+WY4Ui+VCgS1Ce3D/CW78yQA9H+IfyTce0bbdV+FTHv7DbV3PD2F9zxDtyqlIk\nHwp0CcYDu/fz+s/fydOHK2zq+Dhrbd8xP8fF0XY+Vv4cdycXcH3lvezR943KEqJAlyXvSKXGp3+4\nndf8651Uqglfe+ulbIgePu7ne3X8f3yi/Bm2+nquGv0IN9cuI/EmOuJF2qzU7gJEjtfQwVG+9YvH\n+OpdOxk8MMpVz1vL31/7XFb3dM77ua+Jf8b59hjvGbuBd439OV+wV3J9/D2uje9kmY22oHqR1lOg\ny5JRS5wdew5xx0ND/GDbIAM791FLnJc8ew2fum4Dl65f3dLXOzd6nFs7PsC3kxfx2eqruLH6Z3y4\n+iYuj7ZyRXQPl0YPss72NHUWjchCUKDLouPuDB0a5bdDh3nkqcM8MnSIB3YfYOuu/RwarQJw/tpe\nbvj99bz2otM5u68nt1oic66N7+Sa6E4G/Dxuq72I22sX873kEgDWsJ8N0XbOs99xdrSbZ9tuzrIn\n6LGR3GoSmYl5Dtc69/f3+8DAQMufV5Ymd2e0mrD/yBgHjoxxYGQsG6/y1KFR9hwcZfDACHsOjDJ4\ncITB/SMcrtTGt++gwrn2OBdGO7gw2sGlto110VNt/P/Ag76OLcm53JOcw73+bHb6ydSIx9fp5TCn\n2NOstac51fZyEvtYZYdYZYdYSTpcxUFOsMP0MkxkuueAgH3owBZ37z/e7ZtqoZvZlcAngRj4ort/\n5HhfUBaWu1NNnLFawljNqWbDdDo5almlllCtOdUkoVJNhxPLnZGxGiPVGiOVGkfGskclmWZejdFq\nwsGRNLgrtWTGGjtLESet6OTk3i4uWLuCy8/p48zVy1jf18NZa5Zz6idPJV5EgWcGF9jvuCD6HW/k\nhwBUPGanr2WHn8qjvpZBX8VuX82TfiLbkjN4ipX4LOcgdDPCckZZZiMsY4TljLDMRtMhI3RbhQ7G\n6GQsHdoYnVQn5ll9WZVOKnRYlTJVytSIqVGiRkySTtuU6YblJWrqQlrC5gx0M4uBzwB/BOwCfmFm\nt7n7r2faZt/hCt/4+WO4g+PZEHDH0wHeOE46DZC4N8yb2J5snemW1Z87mTKvvg4NrzHTc0+87uR1\nqE/P8tzJNP+XbFMSd2qJk2T1TYzX50OSODWfGJ+6zcRyJ0lmeM76dlO2yetmU2bQVYrp7ojpLsd0\nlSO6yul4T2eJNT2ddGXjK7pLrOwus6KrnA67y6zoKrGiu8ya5Z2s6C5hs6XIIgrzmXRYjXPscc7h\n8WmX19w4wHL2eQ/76GW/L2cfvezzHg7RzbB3cZguDmfDYbo46N0MsorD3sUIHVQoM0qZCuVc/y/x\neMinAV/KpmMSIhzDiS0dj0gwnAgnJsGoz0+XNQ6n227iOSe2i8efc/K6ZqRDHAOi9KdtfNqmGWeG\n+fVxstecaZkBZtNvP/W1o1mWNVsX3Dav966ZFvolwA53fwTAzL4JXAvMGOi7njnCjbfMfP/pvJhB\nZJa9CWAY2b/xaRufTterL48imzTfsoU2y/YAUTT788aREVn6SMfT14qzeVEE5XJ09Drj40b0wM3j\nH/TIEiIS4vEPfJL9UCQT65AQWfrDYzhlS1tqpazFVqZKiRplaxinOsv8bFur0UWFLip0Mpbuo1r2\nUJfxrGJzVmVdLfDkvJ4rcaNCKXtkIe/1sC8xSgcVT5fXiKkSjQ+rlKh5RJWYGumw/qhP1zxmrHGa\ndDqNXaPmEzGbYNSYPJ2Mj08dptuOHbXu9Nt69twJhvtE/DXGeDIlfqcun4jOuZZb/dfGeMy3R/6B\nfhrwu4bpXcClU1cys43AxmxydOdHr75/XpUtjDVA+zpjm7cU6lwKNYLqbDXV2VrnzWfjZgJ9ur+F\nj/ob2N03AZsAzGxgPh37C0V1ts5SqBFUZ6upztYys3mdTdLM3xW7gHUN06cDu+fzoiIi0nrNBPov\ngHPM7Cwz6wCuY74dPSIi0nJzdrm4e9XM3gZ8j/S0xZvc/YE5NtvUiuIWgOpsnaVQI6jOVlOdrTWv\nOnO5sEhERBa
e7rYoIhIIBbqISCDmFehm9noze8DMEjPrn7LsRjPbYWa/MbM/bph/ZTZvh5m9bz6v\nf5w1f8vM7s0ej5rZvdn8M83sSMOyzy10bVPq/KCZPd5Qz1UNy6bdt22q85/M7EEz22pmt5rZCdn8\nRbU/s5ra+tmbjpmtM7Mfm9m27GfpHdn8Gd//Ntb6qJndl9UzkM070cxuN7Pt2XBVm2s8r2Gf3Wtm\nB8zsnYthf5rZTWa2x8zub5g37f6z1Keyz+pWM7uoqRdJL4c/vgdwAemJ8JuB/ob5zwF+BXQCZwEP\nkx5QjbPx9UBHts5z5lPDPOv/GPCBbPxM4P521TJNbR8E3j3N/Gn3bRvrfDlQysY/Cnx0ke7PRfXZ\na6jrFOCibLwXeCh7j6d9/9tc66PAminz/hF4Xzb+vvr7vxge2Xv+JPCsxbA/gcuBixp/Lmbaf8BV\nwHdJrwN6IXB3M68xrxa6u29z999Ms+ha4JvuPuruvwV2kN5CYPw2Au5eAeq3EVhwll67/yfAN9rx\n+vMw075tC3f/vrtXs8m7SK9TWIwWzWevkbs/4e73ZOMHgW2kV2cvFdcCX8nGvwK8uo21TPUy4GF3\n39nuQgDc/Q7g6SmzZ9p/1wL/7qm7gBPM7JS5XiOvPvTpbhdw2izz2+EyYNDdtzfMO8vMfmlmPzGz\ny9pUV6O3ZX9u3dTwp+xi2odT/Slpq6JuMe3PxbzfgLSbCtgA3J3Nmu79bycHvm9mW7JbfQCc7O5P\nQPrLCTipbdUd7TomN9gW2/6EmfffcX1e5wx0M/uBmd0/zWO21s1Mtwto6jYC89VkzW9g8pv9BHCG\nu28A/gb4upmtaHVtx1DnZ4GzgQuz2j5W32yap8r13NNm9qeZvR+oAl/LZi34/pzDgu+3Y2FmPcDN\nwDvd/QAzv//t9GJ3vwh4BfCXZnZ5uwuaiaUXQV4D/Fc2azHuz9kc1+e1mQuLrjiOYma7XUDutxGY\nq2YzKwGvBS5u2GYUGM3Gt5jZw8C5QG7f1NHsvjWzLwDfySYX/FYMTezP64GrgZd51gHYjv05h0V7\nCwszK5OG+dfc/RYAdx9sWN74/reNu+/OhnvM7FbSbqxBMzvF3Z/IugT2tLXICa8A7qnvx8W4PzMz\n7b/j+rzm1eVyG3CdmXWa2VnAOcDPWTy3EbgCeNDdd9VnmFmfpfd+x8zWZzU/0oba6vU09pe9Bqgf\nGZ9p37aFpV9+8l7gGncfbpi/qPYni+ezN0l2LOdLwDZ3/3jD/Jne/7Yws+Vm1lsfJz0Yfj/pPrw+\nW+164L/bU+FRJv0Fvtj2Z4OZ9t9twJuzs11eCOyvd83Map5HbV9D+ptkFBgEvtew7P2kZxX8BnhF\nw/yrSI/kPwy8v01Hm78M3DBl3uuAB0jPfrgHeFU7amuo56vAfcDW7M09Za5926Y6d5D29d2bPT63\nGPdnVlPbP3vT1PQS0j+ltzbsw6tme//bVOf67L38Vfa+vj+bvxr4IbA9G564CPbpMmAvsLJhXtv3\nJ+kvmCeAsSw33zLT/iPtcvlM9lm9j4azCGd76NJ/EZFA6EpREZFAKNBFRAKhQBcRCYQCXUQkEAp0\nEZFAKNBFRAKhQBcRCcT/A76y1SRCdylpAAAAAElFTkSuQmCC\n", 63 | "text/plain": [ 64 | "" 65 | ] 66 | }, 67 | "metadata": {}, 68 | "output_type": "display_data" 69 | } 70 | ], 71 | "source": [ 72 | "bounds = (-100, 100)\n", 73 | "xx = np.linspace(*bounds, 1000)\n", 74 | "plt.plot(xx, ratio_func(xx))\n", 75 | "plt.hlines(np.sqrt(2),*bounds)\n", 76 | "plt.ylim(0, 5)\n", 77 | "plt.xlim(*bounds)\n", 78 | "#plt.hist(samples,bins=1000, density=True, cumulative=True,)\n", 79 | "pass" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": null, 85 | "metadata": {}, 86 | "outputs": [], 87 | "source": [] 88 | } 89 | ], 90 | "metadata": { 91 | "kernelspec": { 92 | "display_name": "Python 3", 93 | "language": "python", 94 | "name": "python3" 95 | }, 96 | "language_info": { 97 | "codemirror_mode": { 98 | "name": "ipython", 99 | "version": 3 100 | }, 101 | "file_extension": ".py", 102 | "mimetype": "text/x-python", 103 | "name": "python", 104 | "nbconvert_exporter": "python", 105 | "pygments_lexer": "ipython3", 106 | "version": "3.7.4" 107 | } 108 | }, 109 | "nbformat": 4, 110 | "nbformat_minor": 2 111 | } 112 | -------------------------------------------------------------------------------- /4_dirichlet.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 54, 6 | "metadata": { 7 | "ExecuteTime": { 8 | "end_time": "2019-06-27T22:25:22.891701Z", 9 | "start_time": "2019-06-27T22:25:22.889254Z" 10 | } 11 | }, 12 | "outputs": [ 13 | { 14 | "name": "stdout", 15 | "output_type": "stream", 16 | "text": [ 17 | "The autoreload extension is already loaded. 
To reload it, use:\n", 18 | " %reload_ext autoreload\n" 19 | ] 20 | } 21 | ], 22 | "source": [ 23 | "import numpy as np\n", 24 | "np.set_printoptions(precision=1)\n", 25 | "%load_ext autoreload\n", 26 | "%autoreload 2" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 55, 32 | "metadata": { 33 | "ExecuteTime": { 34 | "end_time": "2019-06-27T22:34:51.558322Z", 35 | "start_time": "2019-06-27T22:34:51.554350Z" 36 | } 37 | }, 38 | "outputs": [], 39 | "source": [ 40 | "k =10\n", 41 | "ma = np.random.dirichlet(np.ones(k), size=1)\n", 42 | "cba = np.random.dirichlet(np.ones(k), size=k)" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": 56, 48 | "metadata": { 49 | "ExecuteTime": { 50 | "end_time": "2019-06-27T22:34:51.693177Z", 51 | "start_time": "2019-06-27T22:34:51.689697Z" 52 | } 53 | }, 54 | "outputs": [], 55 | "source": [ 56 | "mb = np.sum(cba * ma[:, np.newaxis], axis=0)" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": 57, 62 | "metadata": { 63 | "ExecuteTime": { 64 | "end_time": "2019-06-27T22:34:51.804391Z", 65 | "start_time": "2019-06-27T22:34:51.801243Z" 66 | } 67 | }, 68 | "outputs": [], 69 | "source": [ 70 | "def entropy(x):\n", 71 | " return -np.sum(x * np.log(x)) - np.log(x.shape[0])" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": 58, 77 | "metadata": { 78 | "ExecuteTime": { 79 | "end_time": "2019-06-27T22:34:52.180176Z", 80 | "start_time": "2019-06-27T22:34:52.176851Z" 81 | } 82 | }, 83 | "outputs": [ 84 | { 85 | "name": "stdout", 86 | "output_type": "stream", 87 | "text": [ 88 | "1.7467355755 1.717802664 -4.4408920985e-16\n" 89 | ] 90 | } 91 | ], 92 | "source": [ 93 | "print(entropy(ma), entropy(mb), entropy(np.ones(k)/k))" 94 | ] 95 | }, 96 | { 97 | "cell_type": "markdown", 98 | "metadata": {}, 99 | "source": [ 100 | "# Distances " 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": 63, 106 | "metadata": {}, 107 | "outputs": [], 108 | "source": [ 109 | "import categorical_distance as categ" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": 204, 115 | "metadata": {}, 116 | "outputs": [], 117 | "source": [ 118 | "k =100\n", 119 | "causal = categ.sample_joint(k=k, n=5, concentration=1, symmetric=True)\n", 120 | "anti = causal.reverse()\n", 121 | "intcausal = causal.intervention(on='effect',concentration=1,fromjoint=False)\n", 122 | "intanti = intcausal.reverse()" 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": 205, 128 | "metadata": {}, 129 | "outputs": [ 130 | { 131 | "name": "stdout", 132 | "output_type": "stream", 133 | "text": [ 134 | "[-17231.9 -11897. 
-14001.7 -13912.9 -13771.9]\n", 135 | "0\n" 136 | ] 137 | } 138 | ], 139 | "source": [ 140 | "dicausal = causal.scoredist(intcausal)\n", 141 | "dianti = anti.scoredist(intanti)\n", 142 | "delta = dianti - dicausal \n", 143 | "print(delta) # the anticausal model is advantaged\n", 144 | "print((delta>0).sum())" 145 | ] 146 | }, 147 | { 148 | "cell_type": "markdown", 149 | "metadata": {}, 150 | "source": [ 151 | " Now Let's look at the contribution of each term" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": 206, 157 | "metadata": {}, 158 | "outputs": [ 159 | { 160 | "name": "stdout", 161 | "output_type": "stream", 162 | "text": [ 163 | "[ 169.7 119.7 140.5 144.2 139.6]\n", 164 | "[ 174.6 120.9 141.9 141.1 139.7]\n", 165 | "[ 0.6 0.7 0.5 0.5 0.6] alpha\n" 166 | ] 167 | } 168 | ], 169 | "source": [ 170 | "# marginal difference\n", 171 | "dimargi = (np.sum((anti.sa - intanti.sa)**2, axis=1))\n", 172 | "# causal mean deviation\n", 173 | "dicaumean = (np.sum((intanti.sa - np.mean(causal.sba, axis=1))**2,axis=1))\n", 174 | "# anticausal mean deviation\n", 175 | "dianmean = (np.sum((causal.sa - np.mean(anti.sba, axis=1))**2, axis=1))\n", 176 | "print(dimargi)\n", 177 | "print(dicaumean)\n", 178 | "print(dianmean, 'alpha')" 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "execution_count": 207, 184 | "metadata": {}, 185 | "outputs": [ 186 | { 187 | "name": "stdout", 188 | "output_type": "stream", 189 | "text": [ 190 | "difference -14163.1047877\n", 191 | "estimate -14163.1047877\n" 192 | ] 193 | } 194 | ], 195 | "source": [ 196 | "estimate = dimargi - k*(dicaumean - dianmean) # oooh sign error\n", 197 | "print('difference ', delta.mean())\n", 198 | "print('estimate ', estimate.mean())" 199 | ] 200 | }, 201 | { 202 | "cell_type": "markdown", 203 | "metadata": {}, 204 | "source": [ 205 | "# variance of logsumexp" 206 | ] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "execution_count": 196, 211 | "metadata": {}, 212 | "outputs": [], 213 | "source": [ 214 | "ss = np.random.randn(99,100)" 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": 197, 220 | "metadata": {}, 221 | "outputs": [], 222 | "source": [ 223 | "lge = categ.logsumexp(ss)" 224 | ] 225 | }, 226 | { 227 | "cell_type": "code", 228 | "execution_count": 198, 229 | "metadata": {}, 230 | "outputs": [ 231 | { 232 | "data": { 233 | "text/plain": [ 234 | "0.017512454027028036" 235 | ] 236 | }, 237 | "execution_count": 198, 238 | "metadata": {}, 239 | "output_type": "execute_result" 240 | } 241 | ], 242 | "source": [ 243 | "np.var(lge)" 244 | ] 245 | }, 246 | { 247 | "cell_type": "code", 248 | "execution_count": 199, 249 | "metadata": {}, 250 | "outputs": [ 251 | { 252 | "data": { 253 | "text/plain": [ 254 | "0.24195816381631272" 255 | ] 256 | }, 257 | "execution_count": 199, 258 | "metadata": {}, 259 | "output_type": "execute_result" 260 | } 261 | ], 262 | "source": [ 263 | "np.var(np.amax(ss,axis=0))" 264 | ] 265 | }, 266 | { 267 | "cell_type": "code", 268 | "execution_count": 200, 269 | "metadata": {}, 270 | "outputs": [ 271 | { 272 | "data": { 273 | "text/plain": [ 274 | "0.010927555382498953" 275 | ] 276 | }, 277 | "execution_count": 200, 278 | "metadata": {}, 279 | "output_type": "execute_result" 280 | } 281 | ], 282 | "source": [ 283 | "np.var(np.mean(ss, axis=0))" 284 | ] 285 | }, 286 | { 287 | "cell_type": "code", 288 | "execution_count": 201, 289 | "metadata": {}, 290 | "outputs": [ 291 | { 292 | "data": { 293 | "text/plain": [ 294 | "0.8400310407532311" 295 | ] 296 | }, 
297 | "execution_count": 201, 298 | "metadata": {}, 299 | "output_type": "execute_result" 300 | } 301 | ], 302 | "source": [ 303 | "np.var(ss[0])" 304 | ] 305 | }, 306 | { 307 | "cell_type": "code", 308 | "execution_count": 202, 309 | "metadata": {}, 310 | "outputs": [ 311 | { 312 | "data": { 313 | "text/plain": [ 314 | "0.8466488010108999" 315 | ] 316 | }, 317 | "execution_count": 202, 318 | "metadata": {}, 319 | "output_type": "execute_result" 320 | } 321 | ], 322 | "source": [ 323 | "np.var(ss[0] - np.mean(ss[1:],axis=0))" 324 | ] 325 | }, 326 | { 327 | "cell_type": "code", 328 | "execution_count": 211, 329 | "metadata": {}, 330 | "outputs": [ 331 | { 332 | "name": "stdout", 333 | "output_type": "stream", 334 | "text": [ 335 | "[-0. -0.1 0. -0.1 0. -0.1 -0.1 -0. -0.2 -0.1 0.1 -0.2 -0.3 -0. 0.\n", 336 | " -0.1 0. -0.1 -0.1 -0.1 0. 0. 0.1 -0.1 0. -0. 0. -0.1 0.1 -0.\n", 337 | " -0.1 -0. 0.2 -0. 0.1 0.2 -0.2 -0. 0. 0.1 0.1 0. 0.1 -0. 0.3\n", 338 | " -0.1 -0. 0.1 0.2 0. 0.1 0.1 0.2 0. -0.1 0.1 -0.1 0.1 -0.3 0.\n", 339 | " 0.2 -0.1 -0.1 0.2 0.1 -0. 0.1 0.1 0.1 0.1 0. -0. -0.2 -0.1 0.1\n", 340 | " 0. -0.1 -0. -0. -0.1 -0.1 -0.1 -0.1 -0.2 -0.1 0.1 -0. -0.1 0. -0.1\n", 341 | " 0. -0.1 0. 0.1 0.1 -0.1 0. 0.1 0. 0.1]\n", 342 | "[-0. -0.1 -0.1 -0.1 -0.1 -0.1 -0.1 0.2 -0. 0.1 -0. 0. -0.1 0.2 0.\n", 343 | " -0.1 -0.1 -0. -0.1 -0.1 -0.1 0. 0. -0. -0. -0.1 -0.1 0.2 -0.1 -0.1\n", 344 | " -0. -0. 0. -0.1 0. 0. 0.1 0.1 0. -0.1 0.1 0.2 -0. 0. 0.\n", 345 | " -0.2 0.1 -0.1 0.1 0. 0.1 -0. -0.1 0.1 0. -0. 0.1 0.2 0. 0.1\n", 346 | " 0. 0. 0.1 0. 0. 0.1 -0. -0.3 -0. -0.1 0.1 0.3 -0. 0.2 0.\n", 347 | " 0.2 -0.1 0. -0.1 0. -0. -0.2 -0.1 -0.1 -0.2 0.1 0.2 0.1 -0.1 0.1\n", 348 | " -0.1 -0.3 0. -0.3 0.2 0.1 -0. -0.1 0. 0. ]\n" 349 | ] 350 | } 351 | ], 352 | "source": [ 353 | "print(np.mean(ss, axis=0))\n", 354 | "print(np.round(causal.sba.mean(1)[0],decimals=1))" 355 | ] 356 | }, 357 | { 358 | "cell_type": "code", 359 | "execution_count": 214, 360 | "metadata": {}, 361 | "outputs": [ 362 | { 363 | "data": { 364 | "text/plain": [ 365 | "0.015800798719678929" 366 | ] 367 | }, 368 | "execution_count": 214, 369 | "metadata": {}, 370 | "output_type": "execute_result" 371 | } 372 | ], 373 | "source": [ 374 | "np.var(causal.sba.mean(1))" 375 | ] 376 | }, 377 | { 378 | "cell_type": "code", 379 | "execution_count": null, 380 | "metadata": {}, 381 | "outputs": [], 382 | "source": [] 383 | } 384 | ], 385 | "metadata": { 386 | "kernelspec": { 387 | "display_name": "deep (3.6)", 388 | "language": "python", 389 | "name": "deep" 390 | }, 391 | "language_info": { 392 | "codemirror_mode": { 393 | "name": "ipython", 394 | "version": 3 395 | }, 396 | "file_extension": ".py", 397 | "mimetype": "text/x-python", 398 | "name": "python", 399 | "nbconvert_exporter": "python", 400 | "pygments_lexer": "ipython3", 401 | "version": "3.6.9" 402 | }, 403 | "varInspector": { 404 | "cols": { 405 | "lenName": 16, 406 | "lenType": 16, 407 | "lenVar": 40 408 | }, 409 | "kernels_config": { 410 | "python": { 411 | "delete_cmd_postfix": "", 412 | "delete_cmd_prefix": "del ", 413 | "library": "var_list.py", 414 | "varRefreshCmd": "print(var_dic_list())" 415 | }, 416 | "r": { 417 | "delete_cmd_postfix": ") ", 418 | "delete_cmd_prefix": "rm(", 419 | "library": "var_list.r", 420 | "varRefreshCmd": "cat(var_dic_list()) " 421 | } 422 | }, 423 | "types_to_exclude": [ 424 | "module", 425 | "function", 426 | "builtin_function_or_method", 427 | "instance", 428 | "_Feature" 429 | ], 430 | "window_display": false 431 | } 432 | }, 433 | "nbformat": 4, 434 | 
"nbformat_minor": 2 435 | } 436 | -------------------------------------------------------------------------------- /normal_pkg/normal.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy 3 | import scipy.linalg 4 | import scipy.stats 5 | 6 | 7 | class MeanConditionalNormal: 8 | 9 | def __init__(self, mua, cova, linear, bias, covcond): 10 | self.mua = mua 11 | self.cova = cova 12 | self.linear = linear 13 | self.bias = bias 14 | self.covcond = covcond 15 | 16 | def to_natural(self): 17 | # exactly the same formulas as natural to mean 18 | # those parametrizations are symmetric 19 | preca = np.linalg.inv(self.cova) 20 | preccond = np.linalg.inv(self.covcond) 21 | etaa = preca @ self.mua 22 | linear = preccond @ self.linear 23 | bias = preccond @ self.bias 24 | return NaturalConditionalNormal(etaa, preca, linear, bias, preccond) 25 | 26 | def to_joint(self): 27 | mub = np.dot(self.linear, self.mua) + self.bias 28 | mean = np.concatenate([self.mua, mub], axis=0) 29 | 30 | d = self.mua.shape[0] 31 | crosscov = np.dot(self.cova, self.linear.T) 32 | cov = np.zeros([2 * d, 2 * d]) 33 | cov[:d, :d] = self.cova 34 | cov[:d, d:] = crosscov 35 | cov[d:, :d] = crosscov.T 36 | cov[d:, d:] = self.covcond + np.linalg.multi_dot([ 37 | self.linear, self.cova, self.linear.T]) 38 | 39 | return MeanJointNormal(mean, cov) 40 | 41 | def sample(self, n): 42 | aa = np.random.multivariate_normal(self.mua, self.cova, size=n) 43 | bb = np.dot(aa, self.linear.T) \ 44 | + np.random.multivariate_normal(self.bias, self.covcond, size=n) 45 | return aa, bb 46 | 47 | 48 | class MeanJointNormal: 49 | 50 | def __init__(self, mean, cov): 51 | self.mean = mean 52 | self.cov = cov 53 | 54 | def to_natural(self): 55 | precision = np.linalg.inv(self.cov) 56 | return NaturalJointNormal(precision @ self.mean, precision) 57 | 58 | def to_conditional(self): 59 | d = self.mean.shape[0] // 2 60 | 61 | # parameters of marginal on A 62 | mua = self.mean[:d] 63 | cova = self.cov[:d, :d] 64 | preca = np.linalg.inv(cova) 65 | 66 | # intermediate values required for calculus 67 | mub = self.mean[d:] 68 | covb = self.cov[d:, d:] 69 | crosscov = self.cov[:d, d:] 70 | 71 | # parameters of conditional 72 | linear = np.dot(crosscov.T, preca) 73 | bias = mub - np.dot(linear, mua) 74 | covcond = covb - np.linalg.multi_dot([linear, cova, linear.T]) 75 | 76 | return MeanConditionalNormal(mua, cova, linear, bias, covcond) 77 | 78 | def sample(self, n): 79 | return np.random.multivariate_normal(self.mean, self.cov, size=n) 80 | 81 | def encode(self, encoder): 82 | mu = np.dot(encoder, self.mean) 83 | cov = np.linalg.multi_dot([encoder, self.cov, encoder.T]) 84 | return MeanJointNormal(mu, cov) 85 | 86 | 87 | class NaturalJointNormal: 88 | 89 | def __init__(self, eta, precision): 90 | self.eta = eta 91 | self.precision = precision 92 | 93 | def to_mean(self): 94 | cov = np.linalg.inv(self.precision) 95 | return MeanJointNormal(cov @ self.eta, cov) 96 | 97 | def to_cholesky(self): 98 | L = np.linalg.cholesky(self.precision) 99 | zeta = scipy.linalg.solve_triangular(L, self.eta, lower=True) 100 | return CholeskyJointNormal(zeta, L) 101 | 102 | def to_conditional(self): 103 | d = self.eta.shape[0] // 2 104 | # conditional parameters 105 | preccond = self.precision[d:, d:] 106 | linear = - self.precision[d:, :d] 107 | bias = self.eta[d:] 108 | 109 | # marginal parameters 110 | tmp = linear.T @ np.linalg.inv(preccond) 111 | preca = self.precision[:d, :d] - tmp @ linear 112 | etaa = 
self.eta[:d] + tmp @ bias
113 |         return NaturalConditionalNormal(etaa, preca, linear, bias, preccond)
114 | 
115 |     def reverse(self):
116 |         d = self.eta.shape[0] // 2
117 |         eta = np.roll(self.eta, d)
118 |         precision = np.roll(self.precision, shift=[d, d], axis=[0, 1])
119 |         return NaturalJointNormal(eta, precision)
120 | 
121 |     @property
122 |     def logpartition(self):
123 |         s, logdet = np.linalg.slogdet(self.precision)
124 |         assert s == 1
125 |         return .5 * (self.eta.T @ np.linalg.solve(self.precision, self.eta) - logdet)
126 | 
127 |     def negativeloglikelihood(self, x):
128 |         """Return the NLL of each point in x, up to the additive normalization constant.
129 |         x is a n*2dim array where each row is a datapoint.
130 |         """
131 |         dataterm = -x @ self.eta + .5 * np.sum((x @ self.precision) * x, axis=1)
132 |         return dataterm + self.logpartition
133 | 
134 |     def distance(self, other):
135 |         return np.sqrt(
136 |             np.sum((self.eta - other.eta) ** 2)
137 |             + np.sum((self.precision - other.precision) ** 2)
138 |         )
139 | 
140 | 
141 | class CholeskyJointNormal:
142 | 
143 |     def __init__(self, zeta, L):
144 |         self.zeta = zeta
145 |         self.L = L
146 | 
147 |     def to_natural(self):
148 |         return NaturalJointNormal(
149 |             eta=self.L @ self.zeta,
150 |             precision=self.L @ self.L.T
151 |         )
152 | 
153 |     def kullback_leibler(self, other):
154 |         V = scipy.linalg.solve_triangular(self.L, other.L, lower=True).T
155 |         return (.5 * np.sum((V @ self.zeta - other.zeta) ** 2)
156 |                 + .5 * np.sum(V ** 2) - np.sum(np.log(np.diag(V))))
157 | 
158 | 
159 | class NaturalConditionalNormal:
160 |     """Joint Gaussian distribution between a cause variable A and an effect variable B.
161 | 
162 |     B is a linear transformation of A plus Gaussian noise.
163 |     The relevant parameters to describe the joint distribution are the parameters of A,
164 |     and the parameters of B given A.
165 |     """
166 | 
167 |     def __init__(self, etaa, preca, linear, bias, preccond):
168 |         # marginal
169 |         self.etaa = etaa
170 |         self.preca = preca
171 |         # conditional
172 |         self.linear = linear
173 |         self.bias = bias
174 |         self.preccond = preccond
175 | 
176 |     def to_joint(self):
177 |         tmp = np.linalg.solve(self.preccond, self.linear).T
178 |         eta = np.concatenate([self.etaa - tmp @ self.bias, self.bias], axis=0)
179 | 
180 |         d = self.etaa.shape[0]
181 |         precision = np.zeros([2 * d, 2 * d])
182 |         precision[:d, :d] = self.preca + tmp @ self.linear
183 |         precision[:d, d:] = - self.linear.T
184 |         precision[d:, :d] = - self.linear
185 |         precision[d:, d:] = self.preccond
186 |         return NaturalJointNormal(eta, precision)
187 | 
188 |     def to_mean(self):
189 |         cova = np.linalg.inv(self.preca)
190 |         covcond = np.linalg.inv(self.preccond)
191 |         mua = cova @ self.etaa
192 |         linear = covcond @ self.linear
193 |         bias = covcond @ self.bias
194 |         return MeanConditionalNormal(mua, cova, linear, bias, covcond)
195 | 
196 |     def to_cholesky(self):
197 |         la = np.linalg.cholesky(self.preca)
198 |         lcond = np.linalg.cholesky(self.preccond)
199 |         return CholeskyConditionalNormal(
200 |             za=scipy.linalg.solve_triangular(la, self.etaa, lower=True),
201 |             la=la,
202 |             linear=scipy.linalg.solve_triangular(lcond, self.linear, lower=True),
203 |             bias=scipy.linalg.solve_triangular(lcond, self.bias, lower=True),
204 |             lcond=lcond
205 |         )
206 | 
207 |     def intervention(self, on, interpolation):
208 |         """Sample natural parameters of a marginal distribution.
209 |         Substitute them in the cause or effect marginals.
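
        A usage sketch (illustrative values; `sample` is defined at the bottom of this file):

            ref = sample(dim=3, mode='natural')
            transfer = ref.intervention(on='cause', interpolation=1.)
            # a cause intervention leaves the mechanism B|A untouched
            assert np.allclose(transfer.linear, ref.linear)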
210 | """ 211 | dim = self.etaa.shape[0] 212 | prec = wishart(dim) 213 | eta = np.random.multivariate_normal(np.zeros(dim), prec / 2 / dim) 214 | if on == 'cause': 215 | eta = (1 - interpolation) * self.etaa + interpolation * eta 216 | prec = (1 - interpolation) * self.preca + interpolation * prec 217 | return NaturalConditionalNormal(eta, prec, self.linear, self.bias, self.preccond) 218 | elif on == 'effect': 219 | # linear = (1 - interpolation) * self.linear 220 | linear = 0 * self.linear 221 | rev = self.reverse() 222 | bias = (1 - interpolation) * rev.etaa + interpolation * eta 223 | prec = (1 - interpolation) * rev.preca + interpolation * prec 224 | return NaturalConditionalNormal(self.etaa, self.preca, linear, bias, prec) 225 | elif on == 'mechanism': 226 | linear = (self.preccond @ np.random.randn(dim, dim) / np.sqrt(dim) * .95) 227 | linear = (1 - interpolation) * self.linear + interpolation * linear 228 | bias = (1 - interpolation) * self.bias + interpolation * eta 229 | return NaturalConditionalNormal(self.etaa, self.preca, linear, bias, self.preccond) 230 | 231 | def reverse(self): 232 | """Return the ConditionalGaussian from B to A.""" 233 | return self.to_joint().reverse().to_conditional() 234 | 235 | def distance(self, other): 236 | """Return Euclidean distance between self and other in natural parameter space.""" 237 | return np.sqrt( 238 | np.sum((self.etaa - other.etaa) ** 2) 239 | + np.sum((self.preca - other.preca) ** 2) 240 | + np.sum((self.linear - other.linear) ** 2) 241 | + np.sum((self.bias - other.bias) ** 2) 242 | + np.sum((self.preccond - other.preccond) ** 2) 243 | ) 244 | 245 | @property 246 | def logpartition(self): 247 | return self.to_joint().logpartition 248 | 249 | 250 | class CholeskyConditionalNormal: 251 | 252 | def __init__(self, za, la, linear, bias, lcond): 253 | self.za = za 254 | self.la = la 255 | self.linear = linear 256 | self.bias = bias 257 | self.lcond = lcond 258 | 259 | def to_natural(self): 260 | return NaturalConditionalNormal( 261 | etaa=np.dot(self.la, self.za), 262 | preca=np.dot(self.la, self.la.T), 263 | linear=np.dot(self.lcond, self.linear), 264 | bias=np.dot(self.lcond, self.bias), 265 | preccond=np.dot(self.lcond, self.lcond.T) 266 | ) 267 | 268 | def distance(self, other): 269 | return np.sqrt( 270 | np.sum((self.za - other.za) ** 2) 271 | + np.sum((self.la - other.la) ** 2) 272 | + np.sum((self.linear - other.linear) ** 2) 273 | + np.sum((self.bias - other.bias) ** 2) 274 | + np.sum((self.lcond - other.lcond) ** 2) 275 | ) 276 | 277 | 278 | # _____ _ 279 | # | __ \ (_) 280 | # | |__) | __ _ ___ _ __ ___ 281 | # | ___/ '__| |/ _ \| '__/ __| 282 | # | | | | | | (_) | | \__ \ 283 | # |_| |_| |_|\___/|_| |___/ 284 | def wishart(dim, scale=1): 285 | ans = scipy.stats.wishart(df=2 * dim + 2, scale=np.eye(dim) / dim * scale).rvs() 286 | if dim == 1: 287 | ans = np.array([[ans]]) 288 | return ans 289 | 290 | 291 | def sample_natural(dim, mode='conjugate', scale=10): 292 | """Sample natural parameters of a ConditionalGaussian of dimension dim.""" 293 | 294 | if mode == 'naive': 295 | # parameters of marginal on A 296 | etaa = np.random.randn(dim) 297 | preca = wishart(dim) 298 | 299 | # parameters of conditional 300 | linear = np.random.randn(dim, dim)/ np.sqrt(dim) * .95 301 | bias = np.random.randn(dim) 302 | preccond = wishart(dim, scale) 303 | 304 | elif mode == 'conjugate': 305 | n0 = 2 * dim + 2 306 | preca = wishart(dim) 307 | preccond = wishart(dim, scale) 308 | 309 | etaa = np.random.multivariate_normal(np.zeros(dim), preca / 
n0)
310 |         bias = np.random.multivariate_normal(np.zeros(dim), preccond / n0)
311 | 
312 |         linear = preccond @ np.random.randn(dim, dim) / np.sqrt(dim) * .95
313 | 
314 |     return NaturalConditionalNormal(etaa, preca, linear, bias, preccond)
315 | 
316 | 
317 | def sample_triangular(dim):
318 |     t = np.tril(np.random.randn(dim, dim), -1)
319 |     diag = np.sqrt(np.random.gamma(shape=2, scale=2, size=dim))
320 |     return t + np.diag(diag)
321 | 
322 | 
323 | def sample_cholesky(dim):
324 |     """Sample cholesky parameters of a ConditionalGaussian of dimension dim."""
325 |     # parameters of marginal on A
326 |     zetaa = np.random.randn(dim)
327 |     lowera = sample_triangular(dim)
328 | 
329 |     # parameters of conditional
330 |     linear = np.random.randn(dim, dim)
331 |     bias = np.random.randn(dim)
332 |     lowercond = sample_triangular(dim)
333 | 
334 |     return CholeskyConditionalNormal(zetaa, lowera, linear, bias, lowercond)
335 | 
336 | 
337 | def sample(dim, mode, **kwargs):
338 |     if mode == 'natural':
339 |         return sample_natural(dim, mode='conjugate', **kwargs)
340 |     elif mode == 'naive':
341 |         return sample_natural(dim, mode=mode, **kwargs)
342 |     elif mode == 'cholesky':
343 |         return sample_cholesky(dim).to_natural()
344 | 
--------------------------------------------------------------------------------
/normal_pkg/normal_distance.py:
--------------------------------------------------------------------------------
 1 | import numpy as np
 2 | import scipy
 3 | import scipy.linalg
 4 | import scipy.stats
 5 | 
 6 | 
 7 | def check_symmetric(a, tol=1e-8):
 8 |     if np.allclose(a, a.T, atol=tol):
 9 |         return True
10 |     else:
11 |         print(a - a.T)
12 |         return False
13 | 
14 | 
15 | class ConditionalGaussian():
16 |     """Joint Gaussian distribution between a cause variable A and an effect variable B.
17 | 
18 |     B is a linear transformation of A plus Gaussian noise.
19 |     The relevant parameters to describe the joint distribution are the parameters of A,
20 |     and the parameters of B given A.
21 |     For both distributions we store both mean and natural parameters because we wish to compute
22 |     distances in both of these parametrizations and save compute time.
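    For a Gaussian N(mu, Sigma), the natural parameters are eta = inv(Sigma) @ mu
    and the precision matrix inv(Sigma); the attributes etaa/preca and
    natbias/natlinear/condprec store these for the marginal and the conditional.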
23 | """ 24 | addfreedom = 100 25 | 26 | def __init__(self, dim, mua, cova, linear, bias, condcov, 27 | etaa=None, preca=None, natlinear=None, natbias=None, condprec=None): 28 | self.dim = dim 29 | self.eye = np.eye(dim) 30 | 31 | # mean parameters 32 | self.mua = mua 33 | self.cova = cova 34 | 35 | self.linear = linear 36 | self.bias = bias 37 | self.condcov = condcov 38 | 39 | # natural parameters 40 | self.preca = np.linalg.inv(self.cova) if preca is None else preca 41 | self.etaa = np.dot(self.preca, self.mua) if etaa is None else etaa 42 | 43 | self.condprec = np.linalg.inv(self.condcov) if condprec is None else condprec 44 | self.natlinear = np.dot(self.condprec, self.linear) if natlinear is None else natlinear 45 | self.natbias = np.dot(self.condprec, self.bias) if natbias is None else natbias 46 | 47 | def is_consistent(self): 48 | assert check_symmetric(self.cova), self.cova 49 | assert check_symmetric(self.preca), self.preca 50 | assert check_symmetric(self.condcov), self.condcov 51 | assert check_symmetric(self.condprec), self.condprec 52 | 53 | assert np.allclose(self.preca, np.linalg.inv(self.cova)) 54 | assert np.allclose(self.etaa, np.dot(self.preca, self.mua)) 55 | assert np.allclose(self.condprec, np.linalg.inv(self.condcov)) 56 | assert np.allclose(self.natlinear, np.dot(self.condprec, self.linear)) 57 | assert np.allclose(self.natbias, np.dot(self.condprec, self.bias)) 58 | 59 | def joint_parameters(self): 60 | mub = np.dot(self.linear, self.mua) + self.bias 61 | mean = np.concatenate([self.mua, mub], axis=0) 62 | 63 | crosscov = np.dot(self.cova, self.linear.T) 64 | covariance = np.zeros([2 * self.dim, 2 * self.dim]) 65 | covariance[:self.dim, :self.dim] = self.cova 66 | covariance[:self.dim, self.dim:] = crosscov 67 | covariance[self.dim:, :self.dim] = crosscov.T 68 | covariance[self.dim:, self.dim:] = self.condcov + np.linalg.multi_dot([ 69 | self.linear, self.cova, self.linear.T]) 70 | 71 | return mean, covariance 72 | 73 | @classmethod 74 | def from_joint(cls, mean, covariance): 75 | dim = int(mean.shape[0] / 2) 76 | 77 | # parameters of marginal on A 78 | mua = mean[:dim] 79 | cova = covariance[:dim, :dim] 80 | preca = np.linalg.inv(cova) 81 | 82 | # intermediate values required for calculus 83 | mub = mean[dim:] 84 | covb = covariance[dim:, dim:] 85 | crosscov = covariance[:dim, dim:] 86 | 87 | # parameters of conditional 88 | linear = np.dot(crosscov.T, preca) 89 | bias = mub - np.dot(linear, mua) 90 | covcond = covb - np.linalg.multi_dot([linear, cova, linear.T]) 91 | 92 | return cls(dim, mua, cova, linear, bias, covcond, preca=preca) 93 | 94 | @classmethod 95 | def random(cls, dim, symmetric=False): 96 | """Return a random ConditionalGaussian where each variable has dimension dim. 97 | 98 | If symmetric is False, then sample each parameters independently. 99 | Else ensure that cause and effect distributions are sampled similarly 100 | by sampling the joint and then computing the conditional parameters. 
101 | 102 | """ 103 | if not symmetric: 104 | covariance_distribution = scipy.stats.invwishart( 105 | df=dim + cls.addfreedom, scale=np.eye(dim)) 106 | 107 | # parameters of marginal on A 108 | mua = np.random.randn(dim) 109 | cova = covariance_distribution.rvs() 110 | 111 | # parameters of conditional 112 | linear = np.random.randn(dim, dim) 113 | bias = np.random.randn(dim) 114 | covcond = covariance_distribution.rvs() 115 | 116 | return cls(dim, mua, cova, linear, bias, covcond) 117 | 118 | else: 119 | mean = np.random.randn(2 * dim) 120 | covariance = scipy.stats.invwishart( 121 | df=2 * (dim + cls.addfreedom), scale=np.eye(2 * dim)).rvs() 122 | 123 | return cls.from_joint(mean, covariance) 124 | 125 | def intervene_on_cause(self): 126 | """Random intervention on the mean parameters of the cause A""" 127 | mua = np.random.randn(self.dim) 128 | cova = scipy.stats.invwishart( 129 | df=self.dim + ConditionalGaussian.addfreedom, scale=self.eye).rvs() 130 | 131 | return ConditionalGaussian( 132 | self.dim, mua, cova, self.linear, self.bias, self.condcov, 133 | natlinear=self.natlinear, natbias=self.natbias, condprec=self.condprec 134 | ) 135 | 136 | def intervene_on_effect(self): 137 | """Random intervention on the mean parameters of the effect B""" 138 | mub = np.random.randn(self.dim) 139 | covb = scipy.stats.invwishart( 140 | df=self.dim + ConditionalGaussian.addfreedom, scale=self.eye).rvs() 141 | 142 | return ConditionalGaussian( 143 | self.dim, self.mua, self.cova, etaa=self.etaa, preca=self.preca, 144 | linear=np.zeros_like(self.linear), natlinear=np.zeros_like(self.natlinear), 145 | bias=mub, condcov=covb 146 | ) 147 | 148 | def reverse(self, viajoint=False): 149 | """Return the ConditionalGaussian from B to A.""" 150 | if not viajoint: 151 | mub = np.dot(self.linear, self.mua) + self.bias 152 | covb = self.condcov + np.linalg.multi_dot([self.linear, self.cova, self.linear.T]) 153 | 154 | natlinear = self.natlinear.T 155 | natbias = self.etaa - np.dot(natlinear, self.bias) 156 | preccond = self.preca + np.linalg.multi_dot( 157 | [self.linear.T, self.condprec, self.linear]) 158 | 159 | covcond = np.linalg.inv(preccond) 160 | linear = np.dot(covcond, natlinear) 161 | bias = np.dot(covcond, natbias) 162 | 163 | return ConditionalGaussian(self.dim, mua=mub, cova=covb, 164 | linear=linear, bias=bias, condcov=covcond, 165 | natlinear=natlinear, natbias=natbias, condprec=preccond) 166 | else: 167 | m, cov = self.joint_parameters() 168 | rm = np.roll(m, self.dim) 169 | rcov = np.roll(cov, shift=[self.dim, self.dim], axis=[0, 1]) 170 | return ConditionalGaussian.from_joint(rm, rcov) 171 | 172 | def squared_distances(self, other): 173 | """Return squared distance both in mean and natural parameter space.""" 174 | meandist = ( 175 | np.sum((self.mua - other.mua) ** 2) 176 | + np.sum((self.cova - other.cova) ** 2) 177 | + np.sum((self.linear - other.linear) ** 2) 178 | + np.sum((self.bias - other.bias) ** 2) 179 | + np.sum((self.condcov - other.condcov) ** 2) 180 | ) 181 | 182 | natdist = ( 183 | np.sum((self.etaa - other.etaa) ** 2) 184 | + np.sum((self.preca - other.preca) ** 2) 185 | + np.sum((self.natlinear - other.natlinear) ** 2) 186 | + np.sum((self.natbias - other.natbias) ** 2) 187 | + np.sum((self.condprec - other.condprec) ** 2) 188 | ) 189 | 190 | return meandist, natdist 191 | 192 | def sample(self, n): 193 | aa = np.random.multivariate_normal(self.mua, self.cova, size=n) 194 | bb = np.dot(aa, self.linear.T) \ 195 | + np.random.multivariate_normal(self.bias, self.condcov, size=n) 
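        # each row of bb is drawn from N(linear @ a + bias, condcov), given the matching row a of aa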
196 |         return aa, bb
197 | 
198 |     def logdensity(self, aa, bb, param='mean'):
199 |         if param == 'mean':
200 |             adiff, bdiff = aa - self.mua, bb - np.dot(aa, self.linear.T) - self.bias
201 |             # completed to return the unnormalized log-density (log-determinant constants omitted)
202 |             return -.5 * (np.sum(np.dot(adiff, self.preca) * adiff, axis=-1) + np.sum(np.dot(bdiff, self.condprec) * bdiff, axis=-1))
203 | 
204 |     def encode(self, matrix):
205 |         mean, covariance = self.joint_parameters()
206 |         newmean = np.dot(matrix, mean)
207 |         newcovariance = np.linalg.multi_dot([
208 |             matrix, covariance, matrix.T
209 |         ])
210 | 
211 |         return ConditionalGaussian.from_joint(newmean, newcovariance)
212 | 
213 | 
214 | def test_ConditionalGaussian(dim=2):
215 |     ConditionalGaussian.random(dim, symmetric=False)
216 |     a = ConditionalGaussian.random(dim, symmetric=True)
217 |     b = a.intervene_on_cause()
218 |     a.intervene_on_effect()
219 |     a.reverse()
220 |     a.squared_distances(b)
221 |     a.sample(10)
222 |     transform = scipy.stats.ortho_group.rvs(2 * dim)
223 |     c = a.encode(transform)
224 |     d = b.encode(transform)
225 | 
226 |     a.is_consistent()
227 |     b.is_consistent()
228 |     c.is_consistent()
229 |     d.is_consistent()
230 | 
231 |     #print("causal distance after intervention", a.squared_distances(b))
232 |     #print("transformed after intervention", c.squared_distances(d))
233 | 
234 |     dist = a.reverse().reverse().squared_distances(a)
235 |     assert np.allclose(dist, 0), print("reverse reverse", dist)
236 |     dist = c.reverse().reverse().squared_distances(c)
237 |     assert np.allclose(dist, 0), print("reverse reverse", dist)
238 |     dist = a.reverse(viajoint=True).squared_distances(a.reverse(viajoint=False))
239 |     assert np.allclose(dist, 0), print("reverse vs joint reverse", dist)
240 | 
241 | 
242 | test_ConditionalGaussian()
243 | 
244 | 
245 | def gaussian_distances(k, n, intervention='cause', symmetric=False):
246 |     """Sample n conditional Gaussians between cause and effect of dimension k
247 |     and evaluate the distance after intervention between causal and anticausal models."""
248 | 
249 |     ans = np.zeros([n, 6])
250 |     for i in range(n):
251 |         # sample mechanisms
252 |         original = ConditionalGaussian.random(k, symmetric)
253 |         revorig = original.reverse()
254 | 
255 |         if intervention == 'cause':
256 |             transfer = original.intervene_on_cause()
257 |         else:  # intervention on effect
258 |             transfer = original.intervene_on_effect()
259 |         revtrans = transfer.reverse()
260 | 
261 |         meandist, natdist = original.squared_distances(transfer)
262 |         revmeandist, revnatdist = revorig.squared_distances(revtrans)
263 | 
264 |         ans[i] = np.array([meandist, revmeandist, revmeandist / meandist,
265 |                            natdist, revnatdist, revnatdist / natdist])
266 | 
267 |     return ans
268 | 
269 | 
270 | gaussian_distances(3, 4)
271 | 
272 | 
273 | def transform_distances(k, n, m, intervention='cause', transformation='orthonormal',
274 |                         noiserange=None):
275 |     """Evaluate the distance induced by interventions and orthonormal transformations.
276 | 
277 |     Sample n conditional Gaussians between cause and effect of dimension k.
278 |     For each of these reference distributions, sample an intervention.
279 |     Sample m orthonormal transformations of dimension 2k.
280 |     Get the n*(m+2) transformed distributions for reference and transfer distributions:
281 |     +2 because we want to see the values of causal and anticausal models.
282 |     Evaluate the distance after intervention in transformed and non-transformed spaces.
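    The returned array has shape (n, m+2, 2), holding the squared distances in mean
    and natural parameter space for each sampled problem and each representation.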
283 | """ 284 | 285 | if transformation == 'orthonormal': 286 | transformers = scipy.stats.ortho_group.rvs(dim=2 * k, size=m) 287 | else: # tranformation=='small' 288 | # small orthonormal deviations around the identity 289 | noise = np.random.randn(2 * k, 2 * k) 290 | antisymmetric = noise - noise.T 291 | if noiserange is None: 292 | noiserange = np.linspace(0.1, 1, m) 293 | else: 294 | m = len(noiserange) 295 | transformers = [scipy.linalg.expm(eps * antisymmetric) for eps in noiserange] 296 | 297 | ans = np.zeros([n, m + 2, 2]) 298 | for i in range(n): 299 | # sample mechanisms 300 | original = ConditionalGaussian.random(k, symmetric=True) 301 | alloriginals = [original, original.reverse()] + [original.encode(t) for t in transformers] 302 | 303 | if intervention == 'cause': 304 | transfer = original.intervene_on_cause() 305 | else: # intervention on effect 306 | transfer = original.intervene_on_effect() 307 | alltransfers = [transfer, transfer.reverse()] + [transfer.encode(t) for t in transformers] 308 | 309 | distances = [d.squared_distances(dt) 310 | for d, dt in zip(alloriginals, alltransfers)] 311 | 312 | ans[i] = np.array(distances) 313 | 314 | return ans 315 | 316 | 317 | transform_distances(3, 4, 5) 318 | transform_distances(3, 4, 5, transformation='small') 319 | transform_distances(3, 4, 5, intervention='effect', transformation='small') 320 | -------------------------------------------------------------------------------- /categorical/plot_sweep.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from collections import defaultdict 4 | 5 | import matplotlib 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | import scipy.stats 9 | 10 | 11 | def add_capitals(dico): 12 | return {**dico, **{key[0].capitalize() + key[1:]: item for key, item in dico.items()}} 13 | 14 | 15 | COLORS = { 16 | 'causal': 'blue', 17 | 'anti': 'red', 18 | 'joint': 'green', 19 | 'causal_average': 'darkblue', 20 | 'anti_average': 'darkred', 21 | 'joint_average': 'darkgreen', 22 | 'MAP_uniform': 'yellow', 23 | 'MAP_source': 'gold', 24 | # guess 25 | 'CausalGuessX': 'skyblue', 26 | 'CausalGuessY': 'darkcyan', 27 | 'AntiGuessX': 'salmon', 28 | 'AntiGuessY': 'chocolate', 29 | } 30 | MARKERS = {key: 'o' for key in COLORS} 31 | MARKERS['causal'] = '^' 32 | MARKERS['anti'] = 'v' 33 | 34 | COLORS = add_capitals(COLORS) 35 | MARKERS = add_capitals(MARKERS) 36 | 37 | 38 | def value_at_step(trajectory, nsteps=1000): 39 | """Return the KL and the integral KL up to nsteps.""" 40 | steps = trajectory['steps'] 41 | index = np.searchsorted(steps, nsteps) - 1 42 | 43 | ans = {} 44 | # ans['end_step'] = steps[index] 45 | for key, item in trajectory.items(): 46 | if key.startswith('kl_'): 47 | ans[key[3:]] = item[index].mean() 48 | # ans['endkl_' + key[3:]] = item[index].mean() 49 | # ans['intkl_' + key[3:]] = item[:index].mean() 50 | 51 | return ans 52 | 53 | 54 | def get_best(results, nsteps): 55 | """Store per model each parameter and kl values 56 | then for each model return the argmax parameters and curves 57 | for kl and integral kl 58 | """ 59 | 60 | by_model = {} 61 | # dictionary where each key is a model, 62 | # and each value is a list of this model's hyperparameter 63 | # and outcome at step nsteps 64 | for exp in results: 65 | trajectory = exp['trajectory'] 66 | for model, metric in value_at_step(trajectory, nsteps).items(): 67 | if model not in by_model: 68 | by_model[model] = [] 69 | toadd = { 70 | 'hyperparameters': 
exp['hyperparameters'], 71 | **exp['hyperparameters'], 72 | 'value': metric, 73 | 'kl': trajectory['kl_' + model], 74 | 'steps': trajectory['steps'] 75 | } 76 | if 'scoredist_' + model in trajectory: 77 | toadd['scoredist'] = trajectory['scoredist_' + model] 78 | by_model[model] += [toadd] 79 | 80 | # select only the best hyperparameters for this model. 81 | for model, metrics in by_model.items(): 82 | dalist = sorted(metrics, key=lambda x: x['value']) 83 | # Ensure that the optimal configuration does not diverge as optimization goes on. 84 | for duh in dalist: 85 | if duh['kl'][0].mean() * 2 > duh['kl'][-1].mean(): 86 | break 87 | by_model[model] = duh 88 | 89 | # print the outcome 90 | for model, item in by_model.items(): 91 | if 'MAP' in model: 92 | print(model, ('\t n0={n0:.0f},' 93 | '\t kl={value:.3f}').format(**item)) 94 | else: 95 | print(model, ('\t alpha={scheduler_exponent},' 96 | '\t lr={lr:.1e},' 97 | '\t kl={value:.3f}').format(**item)) 98 | 99 | return by_model 100 | 101 | 102 | def curve_plot(bestof, nsteps, figsize, logscale=False, endstep=400, confidence=(5, 95)): 103 | """Draw mean trajectory plot with percentiles""" 104 | fig, ax = plt.subplots(nrows=1, ncols=1, figsize=figsize) 105 | for model, item in sorted(bestof.items()): 106 | xx = item['steps'] 107 | values = item['kl'] 108 | 109 | # truncate plot for k-invariance 110 | end_id = np.searchsorted(xx, endstep) + 1 111 | xx = xx[:end_id] 112 | values = values[:end_id] 113 | 114 | # plot mean and percentile statistics 115 | ax.plot(xx, values.mean(axis=1), label=model, 116 | marker=MARKERS[model], markevery=len(xx) // 6, markeredgewidth=0, 117 | color=COLORS[model], alpha=.9) 118 | ax.fill_between( 119 | xx, 120 | np.percentile(values, confidence[0], axis=1), 121 | np.percentile(values, confidence[1], axis=1), 122 | alpha=.4, 123 | color=COLORS[model] 124 | ) 125 | 126 | ax.axvline(nsteps, linestyle='--', color='black') 127 | ax.grid(True) 128 | if logscale: 129 | ax.set_yscale('log') 130 | ax.set_ylabel(r'$\mathrm{KL}(\mathbf{p}^*, \mathbf{p}^{(t)})$') 131 | ax.set_xlabel('number of samples t') 132 | ax.legend() 133 | 134 | return fig, ax 135 | 136 | 137 | def scatter_plot(bestof, nsteps, figsize, logscale=False): 138 | fig, ax = plt.subplots(nrows=1, ncols=1, figsize=figsize) 139 | alldist = [] 140 | allkl = [] 141 | for model, item in sorted(bestof.items()): 142 | if 'scoredist' not in item: 143 | continue 144 | index = min(np.searchsorted(item['steps'], nsteps), len(item['steps']) - 1) 145 | 146 | initial_distances = item['scoredist'][0] 147 | end_kl = item['kl'][index] 148 | ax.scatter( 149 | initial_distances, 150 | end_kl, 151 | alpha=.3, 152 | color=COLORS[model], 153 | marker=MARKERS[model], 154 | linewidth=0, 155 | label=model if False else None 156 | ) 157 | alldist += list(initial_distances) 158 | allkl += list(end_kl) 159 | 160 | # linear regression 161 | slope, intercept, rval, pval, _ = scipy.stats.linregress(alldist, allkl) 162 | x_vals = np.array(ax.get_xlim()) 163 | y_vals = intercept + slope * x_vals 164 | ax.plot( 165 | x_vals, y_vals, '--', color='black', alpha=.8, 166 | label=f'y=ax+b, r2={rval ** 2:.2f}' 167 | f',\na={slope:.1e}, b={intercept:.2f}' 168 | ) 169 | 170 | # look 171 | ax.legend() 172 | ax.grid(True) 173 | if logscale: 174 | ax.set_yscale('log') 175 | ax.set_xscale('log') 176 | ax.set_xlim(min(alldist), max(alldist)) 177 | else: 178 | ax.ticklabel_format(axis='both', style='sci', scilimits=(0, 0), useMathText=True) 179 | 180 | ax.set_ylabel(r'$\mathrm{KL}(\mathbf{p}^*, 
\mathbf{p}^{(t)}); T=$' + str(nsteps)) 181 | ax.set_xlabel(r'$||\theta^{(0)} - \theta^* ||^2$') 182 | return fig, ax 183 | 184 | 185 | def two_plots(results, nsteps, plotname, dirname, verbose=False, figsize=(6, 3)): 186 | print(dirname, plotname) 187 | bestof = get_best(results, nsteps) 188 | # remove the models I don't want to compare 189 | # eg remove SGD, MAP. Keep ASGD and rename them to remove average. 190 | selected = { 191 | key[0].capitalize() + key[1:-len('_average')].replace('A', 'X').replace('B', 'Y'): item 192 | for key, item in bestof.items() 193 | if key.endswith('_average')} 194 | for key in ['MAP_uniform', 'MAP_source']: 195 | # selected[key] = bestof[key] 196 | pass 197 | if dirname.startswith('guess'): 198 | selected.pop('Joint', None) 199 | 200 | curves, ax1 = curve_plot(selected, nsteps, figsize, logscale=False) 201 | # initstring = 'denseinit' if results[0]["is_init_dense"] else 'sparseinit' 202 | # curves.suptitle(f'Average KL tuned for {nsteps} samples with {confidence} percentiles, ' 203 | # f'{initstring}, k={results[0]["k"]}') 204 | scatter, ax2 = scatter_plot(selected, nsteps, figsize, 205 | logscale=(dirname == 'guess_sparseinit')) 206 | 207 | if verbose: 208 | for ax in [ax1, ax2]: 209 | info = str(next(iter(selected.values()))['hyperparameters']) 210 | txt = ax.text(0.5, 1, info, ha='center', va='top', 211 | wrap=True, transform=ax.transAxes, 212 | # bbox=dict(boxstyle='square') 213 | ) 214 | txt._get_wrap_line_width = lambda: 400. # wrap to 600 screen pixels 215 | 216 | # small adjustments for intervention guessing 217 | if dirname.startswith('guess'): 218 | curves.axes[0].set_ylim(0, 1.5) 219 | for fig in [curves, scatter]: 220 | fig.axes[0].set_xlabel('') 221 | fig.axes[0].set_ylabel('') 222 | 223 | for style, fig in {'curves': curves, 'scatter': scatter}.items(): 224 | for figpath in [ 225 | os.path.join('plots', dirname, f'{style}_{plotname}.pdf')]: 226 | print("Saving ", figpath) 227 | os.makedirs(os.path.dirname(figpath), exist_ok=True) 228 | # os.path.join('plots/sweep/png', f'{style}_{plotname}.png')]: 229 | fig.savefig(figpath, bbox_inches='tight') 230 | plt.close(curves) 231 | plt.close(scatter) 232 | print() 233 | 234 | 235 | def plot_marginal_likelihoods(results, intervention, k, dirname): 236 | exp = results[0] 237 | values = {} 238 | for whom in ['A', 'B']: 239 | values[whom] = exp['loglikelihood' + whom][:100].cumsum(0) 240 | xx = np.arange(1, values[whom].shape[0] + 1) 241 | values[whom] /= xx[:, np.newaxis] 242 | 243 | if intervention == 'cause': 244 | right, wrong = 'A', 'B' 245 | else: 246 | right, wrong = 'B', 'A' 247 | 248 | plt.plot(values[wrong] - values[right], alpha=.2) 249 | plt.hlines(0, 0, values['B'].shape[0]) 250 | plt.grid() 251 | plt.ylim(-1, 1) 252 | figpath = os.path.join('plots', dirname, 'guessing', f'guess_{intervention}_k={k}.pdf') 253 | os.makedirs(os.path.dirname(figpath), exist_ok=True) 254 | plt.savefig(figpath, bbox_inches='tight') 255 | plt.close() 256 | 257 | 258 | def merge_results(results1, results2, bs=5): 259 | """Combine results from intervention on cause and effect. 260 | Also report statistics about pooled results. 
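    Combined concatenates the runs of the two intervention types along the replication axis.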
261 | 
262 |     Pooled records the average over 10 cause and 10 effect interventions;
263 |     the goal is to have tighter percentile curves
264 |     which are representative of the algorithm's performance.
265 |     """
266 |     combined = []
267 |     pooled = []
268 |     for e1, e2 in zip(results1, results2):
269 |         h1, h2 = e1['hyperparameters'], e2['hyperparameters']
270 |         assert h1['lr'] == h2['lr']
271 |         t1, t2 = e1['trajectory'], e2['trajectory']
272 |         combined_trajs = {'steps': t1['steps']}
273 |         pooled_trajs = combined_trajs.copy()
274 |         for key in t1.keys():
275 |             if key.startswith(('scoredist', 'kl')):
276 |                 combined_trajs[key] = np.concatenate((t1[key], t2[key]), axis=1)
277 |                 meantraj = (t1[key] + t2[key]) / 2
278 |                 pooled_trajs[key] = np.array([
279 |                     meantraj[:, bs * i:bs * (i + 1)].mean(axis=1)
280 |                     for i in range(meantraj.shape[1] // bs)
281 |                 ]).T
282 |         combined += [{'hyperparameters': h1, 'trajectory': combined_trajs}]
283 |         pooled += [{'hyperparameters': h2, 'trajectory': pooled_trajs}]
284 |     return combined, pooled
285 | 
286 | 
287 | def all_plot(guess=False, dense=True,
288 |              input_dir='categorical_results', output_dir='camera_ready',
289 |              figsize=(3.6, 2.2)):
290 |     basefile = '_'.join(['guess' if guess else 'sweep2',
291 |                          'denseinit' if dense else 'sparseinit'])
292 |     print(basefile, '\n---------------------')
293 | 
294 |     prior_string = 'dense' if dense else 'sparse'
295 | 
296 |     for k in [20]:  # [10, 20, 50]:
297 |         # Optimize hyperparameters for nsteps such that curves are k-invariant
298 |         nsteps = k ** 2 // 4
299 |         allresults = defaultdict(list)
300 |         for intervention in ['cause', 'effect']:
301 |             # 'singlecond', 'gmechanism', 'independent', 'geometric', 'weightedgeo']:
302 |             plotname = f'{prior_string}_{intervention}_k={k}'
303 |             file = f'{basefile}_{intervention}_k={k}.pkl'
304 |             filepath = os.path.join(input_dir, file)
305 |             print(os.path.abspath(filepath))
306 |             if os.path.isfile(filepath):
307 |                 with open(filepath, 'rb') as fin:
308 |                     results = pickle.load(fin)
309 |                 print(1)
310 |                 two_plots(results, nsteps,
311 |                           plotname=plotname,
312 |                           dirname=output_dir,
313 |                           figsize=figsize)
314 |                 allresults[intervention] = results
315 |             # if guess:
316 |             #     plot_marginal_likelihoods(results, intervention, k, basefile)
317 | 
318 |     # if not guess and 'cause' in allresults and 'effect' in allresults:
319 |     #     combined, pooled = merge_results(allresults['cause'], allresults['effect'])
320 |     #     if len(combined) > 0:
321 |     #         for key, item in {'combined': combined, 'pooled': pooled}.items():
322 |     #             two_plots(item, nsteps,
323 |     #                       plotname=f'{prior_string}_{key}_k={k}',
324 |     #                       dirname=output_dir,
325 |     #                       figsize=figsize)
326 | 
327 | 
328 | if __name__ == '__main__':
329 |     np.set_printoptions(precision=2)
330 |     matplotlib.use('pgf')
331 |     matplotlib.rcParams['mathtext.fontset'] = 'cm'
332 |     matplotlib.rcParams['pdf.fonttype'] = 42
333 | 
334 |     # all_plot(guess=True, dense=True)
335 |     # all_plot(guess=True, dense=False)
336 |     all_plot(guess=False, dense=True)
337 |     all_plot(guess=False, dense=False)
338 | 
--------------------------------------------------------------------------------
/categorical/models.py:
--------------------------------------------------------------------------------
 1 | import numpy as np
 2 | import torch
 3 | import torch.nn.functional as F
 4 | from scipy import stats
 5 | from torch import nn, optim
 6 | 
 7 | from categorical.utils import kullback_leibler, logit2proba, logsumexp, proba2logit
 8 | 
 9 | 
10 | def joint2conditional(joint):
11 |     marginal = np.sum(joint, axis=-1)
12 | 
conditional = joint / np.expand_dims(marginal, axis=-1) 13 | 14 | return CategoricalStatic(marginal, conditional) 15 | 16 | 17 | def jointlogit2conditional(joint, is_btoa): 18 | sa = logsumexp(joint) 19 | sa -= sa.mean(axis=1, keepdims=True) 20 | sba = joint - sa[:, :, np.newaxis] 21 | sba -= sba.mean(axis=2, keepdims=True) 22 | 23 | return CategoricalStatic(sa, sba, from_probas=False, is_btoa=is_btoa) 24 | 25 | 26 | def sample_joint(k, n, concentration=1, dense=False, logits=True): 27 | """Sample n causal mechanisms of categorical variables of dimension K. 28 | 29 | The concentration argument specifies the concentration of the resulting cause marginal. 30 | """ 31 | if logits: 32 | sa = stats.loggamma.rvs(concentration, size=(n, k)) 33 | sa -= sa.mean(axis=1, keepdims=True) 34 | 35 | conditional_concentration = concentration if dense else concentration / k 36 | if conditional_concentration > 0.1: 37 | sba = stats.loggamma.rvs(conditional_concentration, size=(n, k, k)) 38 | else: 39 | # A loggamma with small shape parameter is well approximated 40 | # by a negative exponential with parameter scale = 1/ shape 41 | sba = - stats.expon.rvs(scale=1 / conditional_concentration, size=(n, k, k)) 42 | sba -= sba.mean(axis=2, keepdims=True) 43 | return CategoricalStatic(sa, sba, from_probas=False) 44 | else: 45 | pa = np.random.dirichlet(concentration * np.ones(k), size=n) 46 | condconcentration = concentration if dense else concentration / k 47 | pba = np.random.dirichlet(condconcentration * np.ones(k), size=[n, k]) 48 | return CategoricalStatic(pa, pba, from_probas=True) 49 | 50 | 51 | class CategoricalStatic: 52 | """Represent n categorical distributions of variables (a,b) of dimension k each.""" 53 | 54 | def __init__(self, marginal, conditional, from_probas=True, is_btoa=False): 55 | """The distribution is represented by a marginal p(a) and a conditional p(b|a) 56 | 57 | marginal is n*k array. 58 | conditional is n*k*k array. Each element conditional[i,j,k] is p_i(b=k |a=j) 59 | """ 60 | self.n, self.k = marginal.shape 61 | self.BtoA = is_btoa 62 | 63 | if not conditional.shape == (self.n, self.k, self.k): 64 | raise ValueError( 65 | f'Marginal shape {marginal.shape} and conditional ' 66 | f'shape {conditional.shape} do not match.') 67 | 68 | if from_probas: 69 | self.marginal = marginal 70 | self.conditional = conditional 71 | self.sa = proba2logit(marginal) 72 | self.sba = proba2logit(conditional) 73 | else: 74 | self.marginal = logit2proba(marginal) 75 | self.conditional = logit2proba(conditional) 76 | self.sa = marginal 77 | self.sba = conditional 78 | 79 | def to_joint(self, return_probas=True): 80 | if return_probas: 81 | return self.conditional * self.marginal[:, :, np.newaxis] 82 | else: # return logits 83 | joint = self.sba \ 84 | + (self.sa - logsumexp(self.sba))[:, :, np.newaxis] 85 | return joint - np.mean(joint, axis=(1, 2), keepdims=True) 86 | 87 | def reverse(self): 88 | """Return conditional from b to a. 89 | Compute marginal pb and conditional pab such that pab*pb = pba*pa. 
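        In probability space this is Bayes' rule:
            p(a|b) = p(b|a) p(a) / p(b),  with  p(b) = sum_a p(b|a) p(a);
        here it is carried out on logits for numerical stability.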
90 | """ 91 | joint = self.to_joint(return_probas=False) 92 | joint = np.swapaxes(joint, 1, 2) # invert variables 93 | return jointlogit2conditional(joint, not self.BtoA) 94 | 95 | def probadist(self, other): 96 | pd = np.sum((self.marginal - other.marginal) ** 2, axis=1) 97 | pd += np.sum((self.conditional - other.conditional) ** 2, axis=(1, 2)) 98 | return pd 99 | 100 | def scoredist(self, other): 101 | sd = np.sum((self.sa - other.sa) ** 2, axis=1) 102 | sd += np.sum((self.sba - other.sba) ** 2, axis=(1, 2)) 103 | return sd 104 | 105 | def sqdistance(self, other): 106 | """Return the squared euclidean distance between self and other""" 107 | return self.probadist(other), self.scoredist(other) 108 | 109 | def kullback_leibler(self, other): 110 | p0 = self.to_joint().reshape(self.n, self.k ** 2) 111 | p1 = other.to_joint().reshape(self.n, self.k ** 2) 112 | return kullback_leibler(p0, p1) 113 | 114 | def intervention(self, on, concentration=1, dense=True): 115 | # sample new marginal 116 | if on == 'independent': 117 | # make cause and effect independent, 118 | # but without changing the effect marginal. 119 | newmarginal = self.reverse().marginal 120 | elif on == 'geometric': 121 | newmarginal = logit2proba(self.sba.mean(axis=1)) 122 | elif on == 'weightedgeo': 123 | newmarginal = logit2proba(np.sum(self.sba * self.marginal[:, :, None], axis=1)) 124 | else: 125 | newmarginal = np.random.dirichlet(concentration * np.ones(self.k), size=self.n) 126 | 127 | # TODO use logits of the marginal for stability certainty 128 | # replace the cause or the effect by this marginal 129 | if on == 'cause': 130 | return CategoricalStatic(newmarginal, self.conditional) 131 | elif on in ['effect', 'independent', 'geometric', 'weightedgeo']: 132 | # intervention on effect 133 | newconditional = np.repeat(newmarginal[:, None, :], self.k, axis=1) 134 | return CategoricalStatic(self.marginal, newconditional) 135 | elif on == 'mechanism': 136 | # sample a new mechanism from the same prior 137 | sba = sample_joint(self.k, self.n, concentration, dense, logits=True).sba 138 | return CategoricalStatic(self.sa, sba, from_probas=False) 139 | elif on == 'gmechanism': 140 | # sample from a gaussian centered on each conditional 141 | sba = np.random.normal(self.sba, self.sba.std()) 142 | sba -= sba.mean(axis=2, keepdims=True) 143 | return CategoricalStatic(self.sa, sba, from_probas=False) 144 | elif on == 'singlecond': 145 | newscores = stats.loggamma.rvs(concentration, size=(self.n, self.k)) 146 | newscores -= newscores.mean(1, keepdims=True) 147 | # if 'simple': 148 | # a0 = 0 149 | # elif 'max': 150 | a0 = np.argmax(self.sa, axis=1) 151 | sba = self.sba.copy() 152 | sba[np.arange(self.n), a0] = newscores 153 | return CategoricalStatic(self.sa, sba, from_probas=False) 154 | else: 155 | raise ValueError(f'Intervention on {on} is not supported.') 156 | 157 | def sample(self, m, return_tensor=False): 158 | """For each of the n distributions, return m samples. 
157 | def sample(self, m, return_tensor=False): 158 | """For each of the n distributions, return m samples (a pair of n*m arrays of a and b values). """ 159 | flatjoints = self.to_joint().reshape((self.n, self.k ** 2)) 160 | samples = np.array( 161 | [np.random.choice(self.k ** 2, size=m, p=p) for p in flatjoints]) 162 | a = samples // self.k 163 | b = samples % self.k 164 | if not return_tensor: 165 | return a, b 166 | else: 167 | return torch.from_numpy(a), torch.from_numpy(b) 168 | 169 | def to_module(self): 170 | return CategoricalModule(self.sa, self.sba, is_btoa=self.BtoA) 171 | 172 | def __repr__(self): 173 | return (f"n={self.n} categorical of dimension k={self.k}\n" 174 | f"{self.marginal}\n" 175 | f"{self.conditional}") 176 | 177 | 178 | def test_ConditionalStatic(): 179 | print('test categorical static') 180 | 181 | # test the reversion formula on a known example 182 | pa = np.array([[.5, .5]]) 183 | pba = np.array([[[.5, .5], [1 / 3, 2 / 3]]]) 184 | anspb = np.array([[5 / 12, 7 / 12]]) 185 | anspab = np.array([[[3 / 5, 2 / 5], [3 / 7, 4 / 7]]]) 186 | 187 | test = CategoricalStatic(pa, pba).reverse() 188 | answer = CategoricalStatic(anspb, anspab) 189 | 190 | probadist, scoredist = test.sqdistance(answer) 191 | assert probadist < 1e-4, probadist 192 | assert scoredist < 1e-4, scoredist 193 | 194 | # ensure that reverse is reversible 195 | distrib = sample_joint(3, 17, 1, True) 196 | assert np.allclose(0, distrib.reverse().reverse().sqdistance(distrib)) 197 | 198 | distrib.kullback_leibler(distrib.reverse()) 199 | n = 10000 200 | a, b = distrib.sample(n) 201 | c = a * distrib.k + b 202 | val, approx = np.unique(c[0], return_counts=True) 203 | approx = approx.astype(float) / n 204 | joint = distrib.to_joint()[0].flatten() 205 | assert np.allclose(joint, approx, atol=1e-2, rtol=1e-1), (joint, approx) 206 | 207 | 208 | class CategoricalModule(nn.Module): 209 | """Represent n categorical joint distributions p(a)p(b|a) as a PyTorch module.""" 210 | 211 | def __init__(self, sa, sba, is_btoa=False): 212 | super(CategoricalModule, self).__init__() 213 | self.n, self.k = tuple(sa.shape) 214 | 215 | sa = sa.clone().detach() if torch.is_tensor(sa) else torch.tensor(sa) 216 | sba = sba.clone().detach() if torch.is_tensor(sba) else torch.tensor(sba) 217 | self.sa = nn.Parameter(sa.to(torch.float32)) 218 | self.sba = nn.Parameter(sba.to(torch.float32)) 219 | self.BtoA = is_btoa 220 | 221 | def forward(self, a, b): 222 | """ 223 | :param a: n*m collection of m classes in {0, ..., k-1} observed 224 | for each of the n models 225 | :param b: n*m like a 226 | :return: the log-probability of observing a,b, 227 | where model 1 explains first row of a,b, 228 | model 2 explains row 2 and so forth. 
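            The result is an n*m tensor of log-probabilities.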
229 | """ 230 | batch_size = a.shape[1] 231 | if self.BtoA: 232 | a, b = b, a 233 | rows = torch.arange(0, self.n).unsqueeze(1).repeat(1, batch_size) 234 | return self.to_joint()[rows.view(-1), a.view(-1), b.view(-1)].view(self.n, batch_size) 235 | 236 | def to_joint(self): 237 | return F.log_softmax(self.sba, dim=2) \ 238 | + F.log_softmax(self.sa, dim=1).unsqueeze(dim=2) 239 | 240 | def to_static(self): 241 | return CategoricalStatic( 242 | logit2proba(self.sa.detach().numpy()), 243 | logit2proba(self.sba.detach().numpy()) 244 | ) 245 | 246 | def kullback_leibler(self, other): 247 | joint = self.to_joint() 248 | return torch.sum((joint - other.to_joint()) * torch.exp(joint), 249 | dim=(1, 2)) 250 | 251 | def scoredist(self, other): 252 | return torch.sum((self.sa - other.sa) ** 2, dim=1) \ 253 | + torch.sum((self.sba - other.sba) ** 2, dim=(1, 2)) 254 | 255 | def __repr__(self): 256 | return f"CategoricalModule(joint={self.to_joint().detach()})" 257 | 258 | 259 | def test_CategoricalModule(n=7, k=5): 260 | print('test categorical module') 261 | references = sample_joint(k, n, 1) 262 | intervened = references.intervention(on='cause', concentration=1) 263 | 264 | modules = references.to_module() 265 | 266 | # test that reverse is numerically stable 267 | kls = references.reverse().reverse().to_module().kullback_leibler(modules) 268 | assert torch.allclose(torch.zeros(n), kls), kls 269 | 270 | # test optimization 271 | optimizer = optim.SGD(modules.parameters(), lr=1) 272 | aa, bb = intervened.sample(13, return_tensor=True) 273 | negativeloglikelihoods = -modules(aa, bb).mean() 274 | optimizer.zero_grad() 275 | negativeloglikelihoods.backward() 276 | optimizer.step() 277 | 278 | imodules = intervened.to_module() 279 | imodules.kullback_leibler(modules) 280 | imodules.scoredist(modules) 281 | 282 | 283 | class JointModule(nn.Module): 284 | 285 | def __init__(self, logits): 286 | super(JointModule, self).__init__() 287 | self.n, k2 = logits.shape # logits is flat 288 | 289 | self.k = int(np.sqrt(k2)) 290 | # if self.k ** 2 != k2: 291 | # raise ValueError('Logits matrix can not be reshaped to square.') 292 | 293 | # normalize to sum to 0 294 | logits = logits - logits.mean(dim=1, keepdim=True) 295 | self.logits = nn.Parameter(logits) 296 | 297 | @property 298 | def logpartition(self): 299 | return torch.logsumexp(self.logits, dim=1) 300 | 301 | def forward(self, a, b): 302 | batch_size = a.shape[1] 303 | rows = torch.arange(0, self.n).unsqueeze(1).repeat(1, batch_size).view(-1) 304 | index = (a * self.k + b).view(-1) 305 | return F.log_softmax(self.logits, dim=1)[rows, index].view(self.n, batch_size) 306 | 307 | def kullback_leibler(self, other): 308 | a = self.logpartition 309 | kl = torch.sum((self.logits - other.logits) * torch.exp(self.logits - a[:, None]), dim=1) 310 | return kl - a + other.logpartition 311 | 312 | def scoredist(self, other): 313 | return torch.sum((self.logits - other.logits) ** 2, dim=1) 314 | 315 | def __repr__(self): 316 | return f"CategoricalJoint(logits={self.logits.detach()})" 317 | 318 | 319 | class Counter: 320 | 321 | def __init__(self, counts): 322 | self.counts = counts 323 | self.n, self.k, self.k2 = counts.shape 324 | 325 | @property 326 | def total(self): 327 | return self.counts.sum(axis=(1, 2), keepdims=True) 328 | 329 | # @jit 330 | def update(self, a: np.ndarray, b: np.ndarray): 331 | for aaa, bbb in zip(a.T, b.T): 332 | self.counts[np.arange(self.n), aaa, bbb] += 1 333 | 334 | 335 | def test_Counter(): 336 | c = Counter(np.zeros([1, 2, 2])) 337 | 
c.update(np.array([[0, 0, 0, 1]]), np.array([[0, 0, 1, 1]])) 338 | assert c.total == 4 339 | assert np.allclose(c.counts / c.total, [[.5, .25], [0, .25]]) 340 | 341 | 342 | class JointMAP: 343 | 344 | def __init__(self, prior, counter): 345 | self.prior = prior 346 | self.n0 = self.prior.sum(axis=(1, 2), keepdims=True) 347 | self.counter = counter 348 | 349 | @property 350 | def frequencies(self): 351 | return ((self.prior + self.counter.counts) / 352 | (self.n0 + self.counter.total)) 353 | 354 | def to_joint(self): 355 | return np.log(self.frequencies) 356 | 357 | 358 | if __name__ == "__main__": 359 | print("running self-tests for categorical/models.py") 360 | test_ConditionalStatic() 361 | test_CategoricalModule() 362 | test_Counter() 363 | 
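# Added usage sketch (not part of the original file): the adaptation loop the
# experiments build on, written with the classes above. Sample reference
# mechanisms, intervene, then fit the reference model to the intervened
# distribution by SGD on the negative log-likelihood. It reuses this file's
# `optim` import; the dimension, learning rate and step count below are
# illustrative, not the tuned values from the experiments.
#
#     reference = sample_joint(k=10, n=5, concentration=1, dense=True)
#     transfer = reference.intervention(on='cause')
#     model = reference.to_module()
#     optimizer = optim.SGD(model.parameters(), lr=0.1)
#     for step in range(100):
#         a, b = transfer.sample(32, return_tensor=True)  # 32 samples per model
#         nll = -model(a, b).mean()  # average negative log-likelihood
#         optimizer.zero_grad()
#         nll.backward()
#         optimizer.step()
#     # KL from the intervened distribution should shrink as the model adapts:
#     print(transfer.to_module().kullback_leibler(model).mean().item())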
"iVBORw0KGgoAAAANSUhEUgAAAk4AAAEuCAYAAAB4RZ0yAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nOzdeXzcV33v/9eZ0UiyZFtjyWvieBnHCYEsjiQnBBIwIBEgZSlITlva8gNiCdKWpY/WTii99La9DTa3/bHcFqRQoAsUO4LeshZLAZWGJYmtmCxOiCPZju3Eli15bEnWNjPn/jHfGc+MZqTRMtIs7+fjoYfm+53vcs58NaPPnPP5nmOstYiIiIjI1FwLXQARERGRXKHASURERCRNCpxERERE0qTASURERCRNCpxERERE0qTASURERCRNRQtdgHQtX77cbtiwIaPnGBoaory8PKPnyDaFWGcozHoXYp2hMOtdiHWGwqx3IdYZMl/vgwcPnrPWrkj2XM4EThs2bODAgQMZPUdnZyfbtm3L6DmyTSHWGQqz3oVYZyjMehdinaEw612IdYbM19sYczzVc+qqExEREUmTAicRERGRNClwEhEREUlTzuQ4iYhIfgiFQpw8eZKhoaE5PW5FRQXPPvvsnB4z2xVinWHu6l1eXs7atWtxudJvR1LgJCIi8+rcuXMYY7j22mun9Q9rKgMDAyxZsmTOjpcLCrHOMDf1DoVCnDp1inPnzrFy5cq091NXnYiIzCu/38+qVavmNGgSmS6Xy8WqVau4cOHC9PbLUHlERESSCgaDeDyehS6GCB6Ph0AgMK19FDg5DvUeYv+F/RzqPbTQRRERyXvGmIUuQsb19PTQ1dU16bLf75/0GG1tbXR0dNDa2orf78fv99PR0ZGxMiea7/PNt5n8HSpwIhw07di/g+/5v8eO/TsUPImIFIiOjg4aGxtpbGxkz549NDc3s2fPnikDmnSPXV1dnXQ5ct6enp6U+/f09NDT00NdXR1NTU3s2rULr9cbDaDSLUNNTQ2NjY1xQVu6pnu+QqDACThw5gBjwTEslvHQOAfOZHaEchERyQ51dXXU19dTX1/Pzp07aWlpwev1sm/fvlkFC21tbdTW1qZcrquri1tOdQyv1xtdjrT8NDQ08MADD6RVDr/fj8/nw+fz0d/fP50qRE3nfIVAgRNQu6qWYncxLlx4XB5qV03+xywiIvmrsrKSpqamWU3ztXfv3rjWpsTldPT19VFZWRldjg18JmupgnDAVFNTQ11dHT6fj61bt+Lz+aipqZl0v66uLvbs2UNHR8eEbkUJU+AEbFm5hQff/CB3ee/iwTc/yJaVWxa6SCIissBiW5z8fj+7du2K5hxlg8laxCJljG2x8vl8E/KsEo+3Y8cOdu7cSV1d3YRgSd11YRrHybFl5Rb8FX4FTSIi82zDfd/P6PGPffquae/j9/vjgo6IhoaGCet6enpoa2tj586dMypfKlVVVXGtTLGtT5GutxUrViTdt66ujgceeCAu2Onp6cHn86Vs+dq3bx8+ny+uSzDxfMlek0KjwGkSh3oPceDMAWpX1SqgEhEpEP39/ezbt4+mpia6urrYsWMHtbW1eL3eaN5RT08P3d3dVFVV0dfXR1dXF9XV1dTV1c34vInBWkNDAy0tLdHnYgMev98fDaSSBXler5eDBw/S1tYWbTmKrEt1vsrKSu6+++5oHWKfjz1foVPglELkTrux4BjF7mJ14YmIZMhMWoSSmclo0h0dHbS3t1NZWcmePXvo7u5m06ZNNDU1AdDS0sLBgwfp6emJBjFANGhqamqiv7+ftra2uKApMchIXG5ra+PAgQPRwKS6upqNGzdy9OjR6Dqfz8emTZui+UYPPvhgdP9I68/AwMCE/RJFAqfEMiTu19DQwJ49e+jq6qK/vz8ueV2tTZcpcEohcqddiFD0TjsFTiIi+aWurm7KVqJkuT27d++mo6ODHTt2sHv37gnP19fXR1uhki03NDRM6PaLbQ2KiARwiWX0+XyT7heR7DyT7ZequzH2fIVOyeEpRO60cxu37rQTESlQzc3NNDY2TgiOdu3axUMPPRS9W629vZ22trbo8w0NDXF35SUuJ9PT05NWq05bWxv333//tPebq/MVOrU4pRC50045TiIihau6upr29vYJ6xNbgFJtE9vKlLg81TGTieQdxQY8M82rmun5Cp0Cp0lsWblFAZOIiMxIYvfWXHR3eb3eWSWgZ/v5coG66qbhUO8hvvzUlzUli4iISIFSi1OadJediIiIqMUpTcnushMREZHCosApTbrLTkRERBQ4pSlyl90f3vyH6qYTEZFpiQws6ff7qa+vX+jizNp06tHT0zPrOs/0GH6/n+bm5lmdO5ECp2nYsnIL99xwj4ImEZE80dHRQWNjI42NjezZs4fm5mb27NkzpxPaRqZnSTUMQb6LjHM1Xa2trbM+htfrpb6+Pu5Ys6XAaZZ0p52ISO6qq6ujvr6e+vp6du7cSUtLC16vl3379s1Z8LRr1y527do1J8cqJLFT3MxG7Jx/c0GB0yxE7rT7QtcX2LF/h4InEZE8UFlZSVNT05Qjfaerq6srbgyn/v5+mpubqampiRttvLm5mfr6ehobG4FwN1OkNSyyrrGxMTr3XE9PD+985ztpbGykq6srui6ybaxI11pzc3O09SXxfInrdu3axaFDh6JdZH6/n5qamqTHji1n4rliu/Xa2tqi2xlj6OjoSFrPXbt2ResS2Sa2q+6jH/1oNOCNdIHW1NSwa9euCa8rhK9p5HWbLQ1HMAuaz05EJH/Ftjj5/X4eeOABtm7dOq1BISMjb8fq6emJThzc3NxMQ0MDra2teL1e2tvbaW1tpa2tjYaGBh566CEgHNB0dHTQ3NxMS0sLu3fvpqWlhfe///286lWv4oEHHuChhx6ipaUlZU5PR0cH3d3d+Hy+pOeLBHft7e3U19enPc2K1+uNK2dnZ2fcuWJfx8jcea2trVRWVkZfx8R6RuYCjKyPPUZraysbN27kK1/5Cn6/nze96U08/PDD9PT0sHv37rjXNcLn89HT0zMng5AqcJqFyJ1246Fx3WknIjIbf1GR+rnf+CzUvj/8+MBX4XsfS7rZEoC/uHB5RcvroPmnMypOsoAHSDphbk9PD21tbSknyK2srIxbrq2tja7v7+8HwhPuRlqi+vv7o60rra2tdHd3c+DAAerr62loaIh2+3V0dPDJT36SJUuW4Pf78fv9dHV1JZ10GMLTx0QCh2Tnq66ujpanv78fr9ebdndlbDlvu+22uHMl6urqoqWlJW6S4cR6Tqa9vZ2PfOQjANFr5Pf7k76uEdOpy1QUOM2C5rMTEck//f397Nu3j6amJrq6utixYwe1tbV4vV7a2trwer3RhO+qqir6+vqic9AltkRFtp1KTU0Nfr8/LviKdHPt3r07Lkeqrq6O1tbWaKAAl/OoknXTRcQGcMnO19PTE+02i21tigQhqeqRrJyJwWKsHTt2RFuSJqtnKvX19XR2dnLHHXdEA8
Z05tKbq/n2FDjNUrL57A71HlIwJSIyHbEtRZOpff/l1qcEAwMD4VaniDRamzo6Omhvb6eyspI9e/bQ3d3Npk2baGpqAoi2jPT09MQlGEeCpqamJvr7+2lra5vVnG5NTU3R/CKv10tzczO1tbU0NjZGW2K2bt0arlZzM5s2bYprsamrq4t24830fD6fjwMHDuDz+WhpaaG/v5+mpiYqKytpbm5OGXgklvO6665Led7m5mb8fn80QJqsnrW1tdF8q+3bt8eV/QMf+EA03yo2CEulq6tr7oYlsNbmxE9NTY3NtJ/85CezPsYTZ56wtf9Sa2/82o229l9q7RNnnph9wTJoLuqciwqx3oVYZ2sLs97ZXufDhw9n5LgXL16c82M2NTXZ8+fP2+7ubrtz50770EMP2fb2dmutte3t7bahocF2d3fb3bt3pzxGZJtMiNT5/PnzdufOnbM6VlNTU7Ru58+ft9XV1bMuX6ZM91pPVpdkf4/AAZsiHlGL0xxTwriISP5obm6msbERn88X1+Kya9cu/H4/W7dujY4x5PP5kuZA3X///dFk7kxobW2lvb2dBx98cFbHibRYRVpw7r777rko3oJra2ub00EwFTjNMSWMi4jkj+rq6qQDLyZ2y002OGN1dTVer5eOjo5Zdeel0tTUFO1anI3q6uqMBXcLxe/3097ePqf1UuA0x5QwLiIiiVLd6SaZ5fV65zwYVOCUAckSxkVERCT3aeTweaKpWURERHKfWpzmQWRqlrHgGMXuYh5884NqkRIREclBanGaB8nutBMREZHco8BpHkTutHMbt+60ExERyWHqqpsHutNORCS/RaYrqa6uTrlcWVk55bQfsVO6REbLPnDgQEaGMUjG7/fP6/lyUUZanIwxTc7P7ph1DcaYOmPMzsnW5astK7dwzw33xAVNShgXEVlYHR0dNDY20tjYyJ49e2hubmbPnj3TnhC2o6MjGiQlLkfOMdWcdT09PfT09FBXV0dTUxO7du2KTk6bbnk6OjqoqamhsbGRrq6uadUBmPb5CtGcB07GmDqgw1rbCvicwKgawFrbAfiNMdXJ1s11WbJZJGH8C11fYMf+HQqeREQWQF1dHfX19dTX17Nz505aWlrwer3s27cv7eChra0tbsLdxOW6urq45cmOE9si1dHRAUBDQwMPPPBAWmXx+/34fD58Pl90ct7pms75ClEmWpx8QKSNr8dZvhvwx6yrS7GuYChhXEQkO1VWVtLU1MSBA+l9Lu/duzeutSlxOV19fX1UVlZGl2MDn6laq/x+PzU1NdTV1eHz+aJTwUQmwk2lq6uLPXv20NHREddCNdX5Ctmc5zg5LU0R1cBeoAaIDX2rAG+SdQVDU7OIiEzfod5D85Yv6vf7o61Afr8fr9e7oLk/kTIkE2mdin3e5/PR09NDV1dX0kDO7/ezY8cODh48CIRbvGK3m+x8hSxjyeFO11uXtbbLGDPTYzQBTQCrVq2is7Nz7gqYxODgYMbPEeveFfdyZOQIm0s34z/sp/Pw/J07Yr7rnC0Ksd6FWGcozHpne50rKioYGBiY9n5P9T3FRx75COPBcTxuD5+//fPcUHVD9PlgMDij446MjABE9x0eHubEiRMUFxdz4cIFiouLufPOO+O2iRUIBOLWJy4DjI+PMzQ0NGn5Fi9ezEsvvRTdxuv1Rh+vXbuWF198EbfbHbdPpM633noroVCIEydOMDY2xvDwME8++SQbNmxg8+bNSc/7z//8z6xbt47vfOc7ANx5551Tni9bzPRaJzMyMjKt90sm76qrs9buch77gUj7oxfocx4nWxfltF61AtTW1tpt27ZlrLAAnZ2dZPocsbYxf+dKZb7rnC0Ksd6FWGcozHpne52fffZZlixZMu39Dh87zHhwnBAhAqEAhwcO85oNr4k+PzAwMKPjlpaWAkT3HR4e5oc//CFNTU20tbVRVlbGj370o2jL0969e/H7/ezatYu6ujqKiorizpu4DODxeCgvL4+uT9aa8973vpeWlhaWLFmC3++ntrY2uv2lS5dYt25d9LnIvpE6L1myhCeeeIK2tjZOnjxJcXExa9as4YknnogeP/GcV1xxBe9973t5xzveMeH52PNlo5le62RKS0u5+eab094+Y3fVWWv3OI/rCHfX+ZynfUBHinUFT3faiYgkl4kx8To6Omhvb+fgwYPRu+r8fn90KIBk6uvraWlpob29HSAuLynZcltbGwcOHGDv3r3RPKKNGzdOSD73+Xxs2rSJjo4OWltbefDBB6PP9ff3R4OaZPvGitydl1iOxP0aGhqiXXmRrr5k55N4c97i5ARKu40xuwi3KDU63XW1znN+a22Xs+2EdYVMU7OIiKSWiTHx6urqpp235PP54oKS+vr6uDyixOWGhgYaGhrijhHJK0rU1NQULVfiOafaN9W5Jttv587kowHFnk/izXmLk7W2w1q7zFq7yfnd4axvdZ5rjdl2wrpCpjvtREQml2xMvIXW0NAQdwde4nIyPT09abfotLW1cf/9989o35mcM/F8Ek8jh2cR3WknIpJdUrXeAOzeHR3jmbq6urhWpsTlROm2ckXyjmIDnpne2ZfOfsnOJ/EUOGURTc0iIpKbEru25qqra76HQFjoIRdygQKnLLNl5ZYJAdN8jlsiIiIiqSlwynJKGBcREckeGRmOQOaOEsZFJB9Zaxe6CCIz+jtU4JTlMjFuiYjIQiotLaWvr0/Bkywoay19fX3RAVDTpa66LKeEcRHJN2vXruXkyZOcPXt2To87MjIy7X+Cua4Q6wxzV+/S0lLWrl07rX0UOOUAJYyLSD7xeDxs3Lhxzo/b2dk5rakz8kEh1hkWtt4KnHKQEsZFREQWhnKccpASxkVERBaGAqccpIRxERGRhaGuuhykhHEREZGFocApRyVLGBcREZHMUlddHjnUe4gvP/VlDvUeWuiiiIiI5CW1OOUJ3WknIiKSeWpxyhO6005ERCTzFDjlCd1pJyIiknnqqssTutNOREQk8xQ45RFNzSIiIpJZCpzymBLGRURE5pZynPKYEsZFRETmlgKnPKaEcRERkbmlrro8poRxERGRuaXAKc9NlTAuIiIi6VPgVGASE8bvXXEv29i20MUSERHJCcpxKjCJCeNHRo4sdJFERERyhgKnApOYML65dPNCF0lERCRnqKuuwCQmjPsP+xe6SCIiIjlDgVMBik0Y7zzcCWiEcRERkXQocBKNMC4iIpIm5TiJRhgXERFJkwIn0QjjIiIiaVJXnWiEcRERkTQpcBJg4gjjShYXERGZSIGTTKBkcRERkeSU4yQTKFlcREQkOQVOMoGSxUVERJJLq6vOGPNGa+2PM10YyQ6pksWV9yQiIoUu3RynGmPMfcB+oM1aeyxzRZJskCxZXHlPIiJS6NLqqrPWfsZa+2bgW8AeY8yPjDHvnmwfY0x1wvJu53dTzLoGY0ydMWbn9Isu80l5TyIiImkGTsaYDcaYB4BPA48D24Gjxpi9KbavAx5KWN1kjOkGepxtqgGstR2APzHQkuyivCcREZH0u+r2AF9Ky
HN6whjTkWxja22HMaYnYfUOa21bzPLdQLvzuAeoA7rSLI/MMw2SKSIikn7gtMNaeyGyYIzZYK09Zq19cBrn8jktUdXW2j2AF+iPeb5qGseSBZCY9wRKGBcRkcKSbuDUSriFKKIFuHM6J3KCJYwx9U4AJTlOCeMiIlJojLU29ZPGvIdwwFRHOLfJRJ5zksUn27fdWlvvPG4C+q21bU4iuB/YBLQ73XoNgC8SXMUcowloAli1alXNN7/5zRlUMX2Dg4MsXrw4o+fINrOp8/4L+/me/3tYLC5c3OW9izdXTPpnkTV0rQtHIda7EOsMhVnvQqwzZL7eb3jDGw5aa5Mm807a4mSt/RbwLWPMp621982iDAdwksIJB0wtzrpIoXzAhHwpa20r4dYuamtr7bZt22ZRhKl1dnaS6XNkm9nU2dvrpX1/O+OhcTwuD9tv254zLU661oWjEOtdiHWGwqx3IdYZFrbeKQMnY8ynARuz/EDs89ba+yfZtwGoNcY0WGvbrLVdxpgmY0w/0G2t7XK2q3W67fyRdZI7lDAuIiKFZrIWp6RDDaTDuXuuLWFda5LtJqyT3KKEcRERKSQpAydr7RPzWRDJD0oYFxGRfJbuAJj3GGMOGGP6jDEvGGOOZLpgkps0wriIiOSztAInoNHJLn/QWns18HAGyyQ5TCOMi4hIPkt3HKfI4Jd9zhx1b8pQeSTHpUoYV96TiIjkg3QDp10QnuzXGPOnhOeqE0kqMWFceU8iIpIv0u2q6zPGvNFpbeoBNmawTJJnlPckIiL5It0Wpx8TnpC3f6oNRRJF8p4iA2Uq70lERHJVuoHTgckGvBSZjAbKFBGRfJFu4OQzxjxOeJoUAKy1H85MkSQfaaBMERHJB9NKDgeWAheJmYpFZCaUMC4iIrko3eTwjcD9wCecEcXVbSezooRxERHJRekGTs3W2u3AUWd5WYbKIwVCA2WKiEguSnsATGPMPcAyZ0gCfwbLJAVAA2WKiEguSitwstZudwa+PA9UOq1PIrOigTJFRCTXpNvihLX2M5ksiEiyvCcFTiIikk3SynEyxmwwxnzaGPMjY8wXjTFLM10wKTzKexIRkWyXbotTO+EhCR4AtgIPO79F5ozynkREJNulGzg9bK39tvO4wxjTlKkCSWFT3pOIiGSzdAOnWmPMXsJz1VUSHkn8i6ARxCWzlPckIiLZJN3AaUdGSyGSgiYIFhGRbJLucARPZLogIslogmAREckmaQ9HILJQNEGwiIhkCwVOknOUMC4iIgsl3bnqAND4TZINNEGwiIgslHQHwHy3MWYf8JCzvDejpRKZhAbKFBGRhZJuV12ztfZOY8yXnOVlmSqQyFQ0UKaIiCyUdAOnC8aYe4Blxph3A/4MlklkShooU0REFkJaXXXW2u2EW5nOA5XOskjWUN6TiIjMh7RanIwxP7LW3hmzvNdae3fmiiUyPRooU0RE5kO6XXUmYVk5TpJVlPckIiLzId3AqccY8ydAB1CPcpwkCynvSUREMi3dHKcPAReAZuC8cpwkFyjvSURE5lq64zgtBboJtzj1O3fWiWQ1jfckIiJzLd2uuh8D7UB/BssiMqcmy3vaf2E/3l6vuu5ERGRa0g2cDlhr789oSUQyIFXe02hwlPb97cp7EhGRaUk3cPIZYx4Hokki1toPZ6ZIIpkTyXuy2GjekwInERFJV7qB066MlkJknkTynsaCY8p7EhGRaUsrcLLWPpHpgojMh0je075f7GP7bds13pOIiExLuiOH3wN8CNhIeNoVa63dnMmCiWTKlpVb8Ff444ImjfckIiLpSGs4AqDRWlsLPGitvRp4eKodjDHVCcsNxpg6Y8zOydaJzDeN9yQiIulKN3C64Pzuc8ZwetNkGxtj6oCHYparAay1HYDfGFOdbN10Cy8yFzTek4iIpGtayeHW2s8YY/4UaJxsY2tthzGmJ2bV3YTHgQLoAeqAqiTrutIsj8icSTbek3KeREQkmZSBkzHmizFDDjQbY2zkKcKB0KFpnMdL/OCZVSnWiSyI2PGelPMkIiKpTNbi1Bbz+GfAyQyXRSQrJMt5UuAkIiIwSeBkrY1NAP9d4B5r7cAMz+MHKp3HXqDPeZxsXZQxpgloAli1ahWdnZ0zPH16BgcHM36ObFOIdYbJ6+0edeM2brDgwoX7JTedfZ0cHT3KkZEjbC7dzMaSjfNb4Dmga104CrHOUJj1LsQ6w8LWO90cp/PAMWPMvsiKaY4cvheIZNz6CE8WTIp1UdbaVqAVoLa21m7btm0ap5y+zs5OMn2ObFOIdYbJ672Nbdzce/OEnKd/2P8POd19p2tdOAqxzlCY9S7EOsPC1jvdwKnF+UmLMaYBqDXGNFhr26y1XcaYWuduO7+1tsvZbsI6kWyQOMeduu9ERAQyNHK4tbaN+BypSOtR4nYT1olko8iQBeOhcQ1ZICJSwDRyuEgakg1ZAJqqRUSk0KTbVddora01xnzaWnufMeZLGS2VSBZK7L7TsAUiIoUnIyOHixQCTdUiIlJ40g2coiOHA5uA7RkrkUiO0FQtIiKFJ92uuo3AUYgGTyIFT3lPIiKFJ93AqcYYcx/QDbRYa6cz3YpI3lLek4hIYUmrq85a+xlr7ZuBPUC9MeZIZoslkpuU9yQikt/SHY5gKeGJfRsJT59yXyYLJZKrUo33pO47EZH8kG5X3ZeBbzqtTiKSQrK8J3XfiYjkj3QDpx2E85zeHVlhrf12Zookkts0XYuISP5KN3DqcH76M1gWkbyk7jsRkfyRbuB00Fp7f0ZLIpKn1H0nIpI/0g2cfMaYx4HoLULW2g9npkgi+UfddyIi+SHdwGlXRkshUmBSdd+JiEh2Sxk4GWO+GNOqdDdgEzZ5ImOlEslzGnVcRCQ3Tdbi1BbzeG+mCyJSaDTquIhI7kkZOFlrH455rNYlkQxT3pOISPZLa8oVEcm8SN6T27gnDFvw5ae+zKFeTREpIrLQ0k0OF5EM07AFIiLZT4GTSBbRsAUiItlNXXUiWUzddyIi2UUtTiJZTN13IiLZRYGTSJZT952ISPZQV51IjlH3nYjIwlGLk0iOUfediMjCUeAkkoPUfScisjAUOInkgVSTBsfOfSciIrOnwEkkD6TTfXfvinvZxraFLqqISE5T4CSSJ6bqvjsycmQBSycikh90V51Inkq8+25z6WYgP+++Gw0E+WzH8zx3+uKk250bHOXv9v+aMxdH5qlkIpJv1OIkkqcSu+/8h/15e/fd/mfO8NmOIzx/ZoB/eG9Nyu32Pn6Cz//4BUaDIe5/63XzWEIRyRdqcRLJY1tWbuGeG+6JBkfJ7r7LB0fODADQe3F00u3ODow62w9mvEwikp8UOIkUkGSDZ+ZD11332SEA+i+NTbpd/9CYs70CJxGZGXXViRSQxO47IC+67iKB0PmhyQOn805gdaL/EqOBICVF7oyXTUTyi1qcRApMbPddPnTdBUOWo+fCLU7+4XGCIZty20iLU8jC8b5L81I+EckvCpxEClg+zHv3kn+Y0UAIAGvhwvB4ym1jW6S6e9VdJyLTp646
kQKWD/PevZCQr9Q/NEZleXHSbWNzoJTnJCIzocBJpMDl+rx3iS1H51MkiA+PBRkZD13ez0koFxGZDnXViUicXOu+SwyA+lMkiCcGVGpxEpGZUIuTiMTJte67SAC0emkppy+OpLyzLhJQRbbr7h3EWosxZt7KKiK5Ty1OIjJBLg2c2eMETls3VgKpx3KKtDhtWllOZXkxQ2NBzkwxYKaISKJ5C5yMMbud300x6xqMMXXGmJ3zVQ4Rmb5s7b7zXxrj3OAYZcVurluzBEg9llOkxWlZWTGbVpQD6q4Tkembz666JmNMA9AMYIypBrDWdhhjfMaYamtt1zyWR0TSlK3dd5H8po3Ly1leXgLA+UvJhyOIBFSV5cWUFxfx+LHzdJ8d5LVXL5+fwopIXpjPwGmHtbYtZvluoN153APUAQqcRLJUunffHeo9FBdgZVKkxWjTisUsc7tFHlgAACAASURBVIYgSNni5ARUy8qKKS8JjxiusZxEZLrmM3DyGWPqgGpr7R7AC/THPF81j2URkVmKdN+Nh8bj5r2bz1ao2MCpstwDTJLjFNPitHbZImd/DUkgItMzb4GTEyxhjKl3AqgpOflQTQCrVq2is7MzcwUEBgcHM36ObFOIdYbCrHcm6nzvins5MnKEzaWb8R/2s//CfkaDo1gsY8Ex9v1iH/4K/5yeM9Zjz44AMHr2OEdGTwDw0rkLcfWM1PvXx8Lbvnz8BUrPh9M7D5/sy8u/g0L8+4bCrHch1hkWtt7zEjg5AVC/01XXB/gAP1DpbOJ11sex1rYCrQC1tbV227ZtGS1nZ2cnmT5HtinEOkNh1jsTdd5G/PG8vV7a97dHW6G237Y9o913f3mwExji7dtuYU1FKff9dzvDIXdcPSP1bj3yS6CP22u38GpfJX/+8x/RPxJi6223U16SXyOzFOLfNxRmvQuxzrCw9Z6vT4sDhPOYADYBLc66WmedD+iYp7KISIbMZxL5eDDEi32XMCacHF7sduEycHEkwHgwhMcdf9Nw9K66cg9Fbhcblpfx/JlBjp4b4vorK2ZdHhEpDPMyHIFzt9x25666bmttV+QOOqfbzq876kTyQ7pjQM12KIPjfZcIhCxrly2i1OPG5TJ4y8IJ4v4kd9ZFxnFa5myzacViQEMSiMj0zGeOU2s660Qkv2QqiTw2MTxiWZmH/qExzl8aY8WSkuh6ay3nhy7fVRe7n+6sE5Hp0MjhIpJRke67P7z5D6MB0ly0QiULnCqdIQkS56sbGgsyFgyxyONmUXF4KIJNKyODYOrOOhFJX35lRIpIVkocA2ouWqG6e8MBT3yLU/KxnGKHIohQV52IzIQCJxGZd8mSyL/81JeTDqiZyuUWp/LoumiLU8JYTrGJ4RE+J3DqOTdEMGRxuzTZr4hMTYGTiCyIdFqhgKRDGVhrLwdOK2NanFKMHp6YGA6wuKSI1UtLOX1xhFPnh1lXVZaBWopIvlHgJCJZYTpDGZwdHGVgJEDFIg9VMd1vlWWRHKf4u+oigVNsVx2E85xOXxyh++ygAicRSYuSw0Uka6Q7lMH+Fx6juOonXLHqDMZc7mKLtjhN6KqLv6MuQnlOIjJdCpxEJGtFuu/cxh2XRP63T32c4hX7eWnRZ+PuwIvOV5dGcjgocBKR6VNXnYhkrVRJ5AEbwBiLJRBNIj/Ue4hHzv43rkVuzl+KHwk8kiy+LFXg1KshCUQkPQqcRCSrJUsiN9ZNCJt0KIOydW56z/8BcHt0n2iLU0JXnW9FZCwntTiJ5IpDvYfYf2E/3l7vnM59mS511YlITtmycgtl/X/A2Nk381ev/sKEATUxAQb4NXD5A/bkpWeB+OEIAFYvLaWs2E3f0Bj+hLwoEck+kS9J3/N/jx37d8x4yqbZUOAkIjlleCzImbOrCZ1/I2/edCsQnwuFLWJ4YAOPvdQV/YA9Vvx3uBYdn5Dj5HKZmFYnddeJZJvE2QQiX5IsNu6GkfmkrjoRySlHegewFtYvL8PjDn/3i82F+uJ/wuDwGn526tHoBywEKSrrobKseMK4UJtWLObpUxd57vRFatYvW9jKiUhUsuFIIl+SxoJjceO9zSe1OIlITvnGoy8CcKuvKm59ZCiD5UXXAOBbfBPF7mJcuLDWTeCSjxeHnmXH/h18oesL0Wb+WzeGj/Nvj72ItXZ+KyMiUalalxJnE3jwzQ9yl/euGU0OPhfU4iQiOePMxRG+3XUKY+Ce2zcm3cZbFs5jWuG5hgff/CD/+shevv3UVZSFNnHo3MEJH8Q3XlWNd81POdy/np+9cB23b14+n1USESZvXUqcTWDLyi34K/wLEjSBAicRySFfeeQoY8EQb7thdXSuuUSx89W95uot/Lqkn7bhYZZVeiZ8EFcUV/AHP24m5B2jbKmb//3TCm7f/N6k07yIyNxJfI8la12654Z7JgxHkg0UOIlITrhwaZx//eVxAD70+k0pt0ucr25gPNz9VllWPGFcqMuJpiEwlsP9T9D29Hp2H/r4hGleRGRuTLd1KdvefwqcRCQn/OujxxkaC/Laq6u4ca035XaJ89UNjoUDp0hAlfhBHPmwBjeXLvn4WtfDjNmJeRVqhRKZmVxuXUpGgZOIZL2R8SBf/dlRAD78+qsn3TZxvrqBscstToliW6A2Lb6R5uf7ePYoeH0egjaQdIBNtUKJpJYYJOV661IyCpxEJOs9dPAk5wbHuP7Kpbz26qpJt02cr24w3PA0YbqViNgP63dXP8k3Hw9RW3I/W687HzfNS7K7e9QKJXJZsiAp11uXklHgJCJZLRAM0frTbiDc2mSMmXT7ZWXxLU6RrrrEwS+TaXqdj70HTvDjXy3if955F6srSgGSfkNO1QqlYEoKRTpdcLneupSMAicRyWrff+plTvQPs6GqjLdcv3rK7aN31SUkhy9L0lWXyLdiMW+9fjU/eOo0X/nZUT7xtuuA1JMNJ/6TANSlJ3lppl1wyd47uU6Bk4hkrVDI8qX/6gGg6XWbcLsmb22CmBanocQWJ0/KfWJ96PWb+MFTp/n6L4/zB9uupsIZFyrZZMOJ/yRSDdinVijJZbPtgsvl1qVkFDiJSFay1vJX3z/Msy9fZMWSEt5dfWVa+8WO4wQwOI0WJ4Ab13p57dVV/OyFPu79xkH+8X1bKfW4J2yX6pu0uvQk1xVqF1y6FDiJSFb63MNH+OrPjlHsdvF3229KGrwkU1bsprjIxch4iOGx4OW76tLIcYr4X++6gYYv/YKfvdDHR/7tCf7hvdUUuSfOUJX4T2K2XXoKpmS+qQtu+hQ4iUjW+cdHjvLZjiO4DHz+t7dwx+YVae9rjKGyrJjTF0fovzQ2YRyndGxYXs6/fPAW7m75BfsPn2HXt57iMw034kqjq3CmXXqg/CjJrHSCpELugkuXAicRySr7Dpzgr753GIDd77mRt1y/ZtrHWFYeDpzODYxGhyPwLkovxyniujVL+er7b+F3v/wo3+o6yZLSIj719ldOeVdfonS79JQfJZmUbpBUyF1w6VLgJCJZ4z+ffpn7vvUkAP/jN15JY+1VMzpOJBH8WN8
QFqhY5Ena1TaVmvXLaP39Gj74tQN87efHWFpaxMfrr5lR8DRVlx4oP0rmzkzzlAq5Cy5dCpxEZMENjQb42/3P87WfHyVk4WN1m/nA7RtnfLxIInj32SFgevlNie7YvILP//bN3Pv1g3z+xy/w3OkB/vKd10fHeJqpTOZHSeFJvP6zyVNS69LkFDiJyILq/HUvf/bvT3PKP4zbZfjYm67mo2/aPKtjXg6cBgHwlk2vmy7RW65fzed+62bu//ZT7D98hl9097Hrra/gd25Zl1beU7rmKj/q3hX3so1tapnKU4lBUmLL5Ns3vV15ShmkwElEFkT/0Bh/9b3D/PsTpwB41RVL2f2eG7n+yopZHzuSCN7dGw6cks1TN11vv+kKajcs48//79N0PNvLJ//v0/zHoVM88O4buXrl4lkfP5mZ5kcdGTmStJsPmHAsBVfZI9m1mCqh+94V9xI8E4y7/gajPKUMUuAkIvPqhd4B/unnx/l210mGxoKUFLn44/pr+ODtGyfPQ+o/CqEgFJWAZxGUVUGKXKNKp4Xp6LlwV92Ud9QN+2FsEAKj4R9jYOV1EzZbU7GIB3+/lh88dZpPfecZHj92njs/+1Pe8qrVvO81G9i6Ydm085+mMpP8qM2lmycEU9/t/i7f6f7OhEBKOVQLI5073GDi9UkWJG9ftT3u+r9909t5+6a36xpmiAInEcm4YMjyk+d6+adfHOO/j5yLrn/9NSv4y3e+ivVV5TA6CC92wctPwpmn4cJJeN93LwdHD/1/8PKhywe96lZ45z/A8qsnnC8SKI0GQsAkOU6X+uEHfwJPfyvhABvhozHn+uF9UF4Fq2/EXFnLXTeu4farl/PAD5/loYMn+f5TL/P9p17mlWuW8r7XrOedW65Me9ypmZgqmPIf9uNd5Y37Z2qxSbv4ZjPGlAKsiWbSapTqDjeYeH0Su283l26eNFdJ5p4CJxHJiJHxID974Rwdz57h4Wd76R0YBaDU4+I3b17L+16znleYk/CLT8KJx6H3GbCh+IMMnIalznAEFWth5AIEx2D4PJx4FL70WnjTp+DWD4HrcmtVYqCUdNTw4z+Hfe+DoV5weaB8BRQVQ1Fp+FwR48PwWEt82Sp9VKy9hU9fVcvHb30D//JMkH977EUOv3yRXd96ir/+3rO87toV1F23kjdcuxLvHHQVTiU2mOo83DnhnynAd7u/O6H7ZqZjTCVbl88B1kwDIpi61WiyYQCmSuj2H/YD6oabTwqcRGROBIIhnjs9QNeL5/np8+d45IWzjIyHMIS40pzj95a+xHtW93LNDbdQtvWt4Z2OnYcDXwk/dhXB6hvhympYdT2sviHcHRfxW1+//Hj4PPzn/fCrf4Mf3Q+rXgm+bdGnEwOlpPPULVkDY0Ow/rXwzr+HyhR38Vkbfv700+EWr1Nd0N8T/nnym6za/s/8yZ3v5A/feDWPdn6Prl/9iu/3reI/nxzl+0++jNtlqFm/jG3XrqB2fSU3rq3IaGtUrMR/pslaJWY6xhSk11qVbN10AqzprNt/YT/eXu+kgc1Mjj+bgCjZ6zSdO9ymSujuPNw5R38tki4FTiIybaGQ5cT5Szz78gDPvHSBg8fPc+iEn0tjQQDe5voln3Q9Q83il9hkX6Q4OARjwItA6WnY+t7wga7YAnX/E666BdZsgeKy9AqwaBn85pfgurfD0Z/GBU0wSYtTKHS5ZapyI3xwP6x8ZVxr1QTFZbDldy4vB8fhzDNw8nE48RhcGW4ZKPW4ef3Qj3j94Nf5eAkEXMWcdK+ja2QNh19cx0+ObWKPfQUet+GVV1RQs24ZN11VwStWL8W3ohzPDMaZmq5krRIzHWMq2bq5DrCmu240OEr7/vak2+3cupM9j++Z0fFnExAle52mMwyAWpKyjwInEUlpeCzI8f4hjp27xPG+IY72XuTsy8cYP3eUlcHTrDVn8Zkz1JmX+fDYx1hZtY7qdcv4iP8IG15+GALOgcpXwurr4Ypq2PDayycoLofbPzbzAr7irvBPxMgFAJaVxd/lVlleHA6a/uk3YMMd4XN6FoXLNF1uTzjgu2IL3LIj/rl1t4XLcPpJivwvsiH0AhvcL/BuNzy/5BY+WnQLvz59kRdOvEzT6b/g+KOr+aVdwctmJa5lG/Be4WPjymWsX17Ohqoy1leVUzHNEc/nQrrB1HwEWNNZZ7Ept+t4sWPGx59tQKRhAPKLAieRAjUyHuTswChnB0fpvTDMuXNnGew7xYj/ZYIXz+AaOsPzwxX8MHQrAJvNSX5QfD8eEwQX4Z8YP3jvSrw3vCG88Pz7oe8OWPUqWPkqWJz+XHMz9kIH/PuH4ZXvYNFdf8sij5vh8XAL2LLyYjj0r3D8Z+G7817zR5kpQ/XvhX8gHED1PhtunTrzNNesuI4f3noHAyPjvHDoEW7+z8fi9x2A0HOGs89VcM/Yn/CU9QFQt+h5Xlk+gFm8kpJlV1BeuYbKFWtYVVHGiiUlrFhSwuKSzH+Uz6a1aqYB1nTXjQXHUm5Xt66OrjNdMzr+bAMiBUn5RYGTSI6z1nJpLMjASIALw+NcGB5n4KKfQN9Rxob8jA32ExrqZ+DcSc49/k2Kxi7w6fHtnB4Jd1991vN/eKvrMUpMYMKxO4u28Oulb2R9VRnXVXjxPBlkbNEKXMs2UFS1AbzrYNkGqNqMd/UNl3e85k7gznmpf9SSK+BSXzhnqvp9VJYXc8o/DECV6xJ0/EV4uzf/FZRkZtylOKUVsO7V4Z/YYpZ6uPmGG2FRK5w/Bv7jBPuPEew7RtHQy6zCz81Xr2N8YAnH+y7xtkAH7x58BAaB0+FjBKyL8yzml6FX8kfjH6Gs2M3qxUXcy15CJV4ujLnoO/ksRYurKFm8jEVLvBQvu4KKxUuoKPOwpLSIxcVFczp4Z8RcB1jTWbfvF/vYftv2lNttXrZ5xsdXQCQRCpxE5om1ltFAiJGxACPjQUYClpFAkLGBfkIXTjE+eonAyCUCo0MERi8RGh0kMD7Oo967GBoNMDgWoP7MP+IdfZmi4DDFwSFKQpdYFLpEuRlhX2Abnwu+B4DXup7i68UPpCzL50ffiMd9JSsWl+ClmJLRACOuckZKqgiUrcAsXkVxxWpet6GaH1dvi1QA3n6aYs+ieXi1ZmDVK+GWJnj0i/DDnSwr+wSn/GCApY9+JhxUrX8tXP+ehS5peGiDm+6OLrqdH4LjMHCav1yyBtxFWGsZeOQY/UcrCQ304r7US8nIOcqCF1nBRZZ7xinFxaWxIP7+8zSU7oNLzkFfiD/lvWMf4QehcBD3u+523u/+Ty65yhh1lTHqLifgLiNYtIiR4kp+smYH5SVFlBW7ufFiJyVFBndxOe7Scjwl5XhKy/CUllO0ZAXF5V5KPW7nx0VpkTutgCzdQGQ66/wV/gkB2mTL01knErGggZMxpgHwA9XW2j0LWRbJLaGQJRAMEQyMEwwFCI6PEwyMEyheSiBkCYYsoQunsOPDBAMBgoEAoeAooWCA4PgYJ06d5SfP9TIeDGFGzlNx5jFCwXFsYAwb+R0ah+A4T1e9hYuuCs
aDIa4518GagacxoTFMcBwTGscVGsMVGucl12q+UvYBxgIh7PgIXxj6E4rsOMWM4bEBihmjlDG8Zpw/HvsQ3w69DoD3u3/Ipzz/krSeo7aI5tFXRpf/qPi/eIXrRPxGTpfZVSWDXL14MRWLPFzt3sjpcxsY9ywmUFyBXbSM88MhVl11NcVLKvn2lrtYWrUm/A9upBbcxZR6FjHp7GvGhPOCstm2++DpNnjxF9xV+QhPcxPVnhdxPf5lMC54656Ug2ZmBbcHvJcnNjbGsPSOD8EdH4rfLjAGw/3cFgry7NIrGBwNcO7cWU4d/GPGB/s5f/o43uIQ7lE/nvEBioODrKq8kmuDS7gwPM660fNsMi+HjxVyfsbDiy/bSu596a3RUz1e8r9YYS4mLe7/P/6eaLD+JtdB/sHzOQYpZhQPYxQTMB7GTDFB4+HPF/8FY8VeiotcNF7ax/rAUUKuYqzbg3V5sO5irKuYvvKr+fWqt+Jxuyi3l3jV2e+DuxiXuwhTVIxxF+Mq8mDcHgaXb8GWL6fI5aJs+GX6Tz7Hs11u3EUeXG4P7qIiXEUeiopKoWINRS6D22UoCgzhdhfhLirCXeShyF2EyzDng5dKflqwwMkYUw1gre0wxviMMdXW2q6FKMu5wVFePnWci6cO89zjgfA3a2uxWKzz+OKqW8KrsSw++ySu8QGsBQiFt7UWbIiRsjUMVVyDxeIevUhF7+NYZ/wXay3GBrHWYi2cW/lqxjwVhKxlWf+vWDT4YnisGGudfcLnHvVUcGLlG8PlCVk2v7gXnMfhbULR/U4su4W+xdcQspaqi8+xrv9nMc/jjEUT3uenV32I48fH+PmlZ9n68jdYPHYWY0Ngg2BDGBvCWjhWdgNd3nqCIVg2eoo7z341/DwhZ/vwb0OIby77EKfdawhZy50X27hp+LHodi4ub3fCtZb/XfZxgiGLDQX4ytAfYQhv47JBXIRwO8ufsb/Pf9g7CIQs22nnU+6v4SZEsbFx1zFoDZtGL9+y/oPi+3ml63jSa/6rwJt4/1PLALjJvMB/lPyPlH8fDxxezjN2AwCfLnqY24s6kx8z5ONXfe8GwBBic+mxy08mfB5XLYJ1JWWUelyUh67gxPA6Aq4Sgq5SAu5SQkWLCBUtwnrK+LPrrqW8tJjyEjfjvX/KC2aU4tLFlJRXsGhxBYsWe/GULaVh0TIaSpbEnOXuuHN2dnZSs23bxIKXzn6Kk6yxyAt1fwH/8Qf8zsUv8/fs5pPur4X/7m9pmlkyeDYqKoYlq4Hwn9aSUg9L1l4Baz8FwLHOTm5OuNafil24dDPBgfsYHrjApUE/o0MXGB8ZYHx4gJGgi0+vuYHB0QDDY0FO/fqNnB3zYwLDFAWHcQdHKQqNUhwaoWjpSja4yhgZD1E1HqLEBighocvX+Zh6vneIC84b4SOex7nV/XTSqv0geAtffPYaANaaXh4pSf2d+vfHdvHT0E0A7Cz6JvcWfWdCSxvASbuc20c/H11+oqSJJWYwuhyyhkD404fPhrbzNd5BkcuwzXTxP/hS5JOJkAl/KtnwpxkfW7yHYdcSXC7Dh4dbuCZwBGsMFjchE97OGhfPl9zAd72/i8tARcjP7/d9DmsMGDfW2Q5jsMbNf1X9Fr2LfLiM4fqB/+aawcewxhUO/An/tsYw4vHyyyvfx8kXx3hs5Dle/dI/4QmNhb8YGONsb8Dl4iVvLb0VN+Ey4B0+zvq+Ry4/b8LnNs5+PWvfhXWX4jKw5uwjLBo9F/O8K/p7uGwN55fX4jJQFBhm5ZmfxmwXHm7DOMceWL6FQGkVGCi72MOioZPR58AV/mUMoaJShlbUhIsPLDn7BIYgBqdOGIwLAqUrOHspYcy3ebSQLU53A+3O4x6gDliQwOlHz5zmme98ib/x/CMcSb7NhpFvRB9/r/gTXO86lnS7bwTeyCcC9wBwg+nhuyWfTHne3xj9a552EkD/puhBfqfoJ0m3ezK0kd8biyTXWo6V/k3KY/7H+Af5t2D4D+p33D/mbs8/ptz2t7vrAANHe3h78Xe4IUWdenovsi9wXbROnyzZn/KYf37+bTxtw39Wbys6wg1Fh5JuNxYa47mLA9E6rSs9lfKYJjDMpWA4yTfkJpyc7AhYF0HcBI2LoHGzcrEHt9tNkdvgH13JS6FRgsZN0LgJURT+4DNuxjwref2aFXjchtXBIL/qey3WVYR1eQi5POA8xu3hXWtv4K1la/C4XSw738jB4Zrwt15P+NuvKSrGXVxKUdlyvnXlbZQUuSgpcnH6/H48JYvwlJRRXFpKcUkZruJF4C7hz1wu/ixai9cDn0hZ/xvilt6Xcjtx3PQ7cOCrVJw6wCeKvoF1FUFpFbwh9WtccMoqcZdVsngVJMv2uil24U3/lPIwf+T8AGDfCIH7sIFhRocvMT42wtjoMOMjw4yPXeIby65nLORiLBCi9NT9PDV4mlBwDBsYJTQ+hnUee0rX8adV1zIeDFE0UsmBF9+NCQYwNgDBcVx2HBMKYEIBNl3hw+VZQSBoKb24ll8P+PC4CX/5skHcNhwO9bsquaKilEDIEghZxoMeRqwHNyE8JojLWFwEgSA2FGTY+byxriFWFPsvVzj+uxo9ZweJtMVVeXp4hfu5pK/TiUtF/FfvWQDWmTN8tuSRlK/p587ezM9C4fzDjUW/4Nai7ybd7nhoJe9/4Y7wwtFumku+RoW5lHTbT4//Fl8KlgBwl+uX/H3x55NuB7Cjaz0Xnb+Kf/X8H25xP5N0u+8Hb+Ej4x+L1umnJR9PeczfHbufR0LhT7L7iv6ND6Wo04uhFbxj7HPR5SdL7mFpkjp9I/AGfrT6wzSmPGNmLWTg5AX6Y5arUm2YaVXlxSypXM3hwavDzcGAdSLc8PvEcMuqZeAyGOD0heswgYrwNwZnW+tE76HSa3n90hUYA6sD4xzqf7VzHKfJwbicx4bq9etZX7wmHF1fqOHAcGRQPOd4ToR9vng1v7XqKueLhOHRk++KHivybSFyzGsrb6F5qQ9jDFcMjfDL/nFwuYl+s+Dyt5E/XX8tx44d5epNmzh3+gM8On4+vI3LhTHhb0rGuFi5ZBO7l9+AMYbS8fUc6P0bjMsFxo1xuZ1vGW6M282frryFUIkXY2DJwDKeGX4/xuXCuNy4XC6MqwjjdlPqKecHVa/A7TK4XXDi/E9wud24nOfdRR5criJcbje7SpbyidKycDM79YR4AJe7CIyLImPi/ojj71N6Y8pr/lJnJx/YdkvMmnek3LY2bmlTyu0mWHVr+tvK3HG54G176Pv6PfzAfysPld/Ev3/wVeGxnyRzjAFPKcZTSumiZZN3+/p+Y9JD1cct3ZZyuxvjlm6ls/NOtiVpVV0D/DxuzbH4DULBcDd+MMDHcPFHxkMgZAmN3k7/yI7wc4EAwWAQGwoQCoWwoSBtFZsI4iIYsnj6P8fzIxcI2RDEbGNDIVaVVPJV73WEQhYzNsQTp78AoSDWhrChcCt/ZNvfWvEa3lmykqC1VPYHefTC9U4vSCjcE2FDmFCI0aJy7rviFbzQ3c3GjT6eOv5Bi
kKj0d6HcG9I+PFVFbfRvMSHBVYMjvHLc3djnN6Sy9uHf961xseYKSFkLf6zr+Px0XXRngoTGT3fWoZLX8Fdy9ZgrWXJeDFdZ+/ARHpAwHlsMVh8a9bh9qzAAkUDV/PkcI3T++D8h3XOf95VyS1XVoY7RbB091/HIjvslA+c/7YEytaxfNHCdasqORx4y/VreMv1O+nsvCXpmw5gX9zS5G/k341b866U21ZPspTozXFLqb8Bxt+/8wogdSLsq4HOzlNse/0m4KOTnv+yq4B0uztWprkdsHLy+l82PyMuS467sob21/87j/z7M7yx1BW++08kFZcbl8uNy1NC3KhZiyqANLuyr9ya/vlu8qW54Tom+1L3OqDTnmDbtquBv065Xfz/heuA30y5bfzXvZtSbBX+QtkQt6Y++YYTjnJLiq3Cbo9b+nHSba4lnHqwUIy1duqtMnFiY3YD7U6OUwPgS0wQN8Y0AU0Aq1atqvnmN7+Z0TINDg6yePE83KacRQqxzlCY9S60Og8HLA8fH+eGijHWLy+cekPhXeuIQqx3IdYZMl/vN7zhDQettbXJnlvIFqe9XO4F8QEdiRtYa1uBVoDa2lqbqjVornR2dqZsccpXhVhnKMx6F2Kd30ph1rsQ6wyFWe9CrDMsbL0zPzlSCpE76IwxdYB/oe6oExEREUnXguY4OS1KIiIiIjlhwVqcRERERHKNAicRERGRNClwEhERPY7LNwAACLVJREFUEUmTAicRERGRNClwEhEREUmTAicRERGRNClwEhEREUnTgk25Ml3GmLPA8QyfZjlwLsPnyDaFWGcozHoXYp2hMOtdiHWGwqx3IdYZMl/v9dbaFcmeyJnAaT4YYw6kmpsmXxVinaEw612IdYbCrHch1hkKs96FWGdY2Hqrq05EREQkTQqcRERERNKkwCleIc6dV4h1hsKsdyHWGQqz3oVYZyjMehdinWEB660cJxEREck6xphqa21XiucaAD9Qba3dk2pdJhR8i5MxpnqS5xqMMXXGmJ2TrRORhTPVezIf38dp1LnJ+dkds2535Ln5KudcS6PeE+qYz9faGFNtjLHGmG7np8VZnw/Xug54KMVz1QDW2g7A77wOE9ZlqmwFHThl84XJlAJ+E+oDd+LzOf/Pdar3ZD6+j9Oocx3QYa1tBXzOMkCTMaYb6JnXAs+RNK9bXB3z/VoDldZaY63dBDQCkfdyTl9riNY5VfnvJtyyhLNNXYp1GVHQgVM2X5hMKNQ3oT5w8/qf61Tvybx7HzN1+X0x63qcZYAd1tpNzt9ELkrnuiXWMa+vdcK1rLXWRt63uX6tp+IF+mOWq1Ksy4iCDpymsKAXJkMK9U2oD9z8/ec61XsyH9/Hk5bfWtvqBMQA1cAB57Evl1tQSe+6JdYxr691hPPFZ1/Mqly/1llNgVNhKdQ3oT5wC+efqzicVsauSHKttXaPExBXxbQw5pVCqGMK9dbayBelQngd/ECl89gL9KVYlxFFmTpwNkiRp9GT5rfpVBdhXi7MAquPfY1i7lioN8bU5VBrRNoS67jQ5Zkvyf65Outz5VpP9WGZj+/jdP9B1Flrd0H0s7DfWtvmbO9LsU82m7TeKeo4b/9MMyTd8ke74vPkWidljPE6AeJeIDJquA+IfE4lWzfn8jpwivlGnbZsuTAzNUWwmLdvwtnUO1c/cOfoWuf6P9ek79Ncfx9PYao6Y4xpigmC6wi3KEa63jcBLfNa4rkxVb2T1fFAsn1ySDrXOvF9mg/XOjK0QK0xpsH5TAJ4GKix1nYZY2qdv21/5ItfsnWZkNeB01Sy+cLM1BTBYt6+CWdZ75z8wJ2Da53z/1xTvU/J8ffxZKaqs7N+tzFmF+HgudHZp8kY0w9051qdIe1rPaGO+XytYzbtSdgnp681gPM/uS1hXU3M4wmffzNpLJkJDYBZYJxWhR7AF/kjM8YcjPxBOoHTLmttc8I+/c4+GRtULJPSqPeEOibbJ5dMVmdzeSiOfi7/c+3Ih2stIpJJCpxERERE0qS76kRERETSpMBJREREJE0KnERERETSpMBJREREJE0KnERERETSpMBJREREJE0KnERERETSpMBJREREJE0KnERERETSpMBJZI4YY6qNMbszdOw6Y0zSuePm8rzOHFcPOfM4Lpjp1mmy12c+pDp/Jv8mpjr3Qsi16yYyEwqcRGZhvgIMa21H7PyB6Zhh2eqttY0xk15POKYx5rzz2zuD46c0m9dyJq/PXJYh9vzzHXTOVd3nS+zrk2tlFwEFTiKzdfdCF2ASMylbymAoJlDqB/yEJweeS9nwWs5FGbKhHtlMr4/kNAVOIkk4XQgNk7UeOF0SdU7XViSoqDbGtBhjDkbWOc+3JzuWMabd+e01xnQ7j6uNMbtj94vtAolZ32KM2ZnsvCnKlnjuyHEecpZ3ArXJulqcYzwUaYmy1nYAu40x1WkctzqmvAeTlWe2r2XMazbhXM66g866bmOMz1nXEqm38xpP+ppNca0iv9OqR8JrFanbTufvzhdTz6ZIWZ11PzTG+BK2T1n3hNcs9u8lsW4tzjbtCfvEXsNIGd4R8zi2rInXJK4eya5zir/rtP9uRBaCAieRBM4HdLMTJGyN/KNKZK3dBRxwurb8zupKp+uhg3AQshNot9bWk/ybdrsxpg7YDnQ5564DypLt5/xzehxodM61J9l5U5Qtto47gb3O8fcaY5qcY/U4+ybaDvQkrHs8sU7JjpusfIkHn6PXkmT7Oet6nHW7gKRdQ1O9ZqS+VnunU4+EY+51jgfhbtIOa22Ps399TFm3AzuAv41ZV+8EsCnrPsnfS5RzjbqttfXW2vpJrmGkDCdjHjeQ4pokq0eq13imfzciC0GBk8hETUA7hD/orbWJAcNkDji/+wh3e20C6p3WDW+Sb81tQD1QAzxA+B9SPVAcux+w1Nk+tiz9k5x3KlsJ/zMC6HLOOZl9QGIAuZWYoGGK4063fMn2meq1TLUfXH6tOoAJrWRpSnqtrLVdU+yXsu5OcN7oBOfRa+u0Ju3m8mu+z1rrdwIln7N9exrnSvX3Eqvm/7Vzr7cJw1AYhj9LHaDqCHSDlo7ACIgRGIERKkYonaCXTSgjtCNUiAXSHz4mIdiJUwL08j5/SqXY+CblxD5B5ZxJ6Tl8rgQ74fPOnKhcp6l+pPS5boCjInAC4t7aL8mykn+SntoT/c5OhgVlA0lruwHfyucP7ZSTtKkUG0l6lHTI21pL+d0S2d9l08XW7nE4jrGdl1kkaOhUb0eNY9ki5GMN5YOJatnrnAoa5upQa/mdm+3Rocqdv1iws5Sf+2gCf0TbelmpnLNQf+4cJtdpRj+qjrlugF4ROAH7XuXzMG4sQNjmt8RYHkbqOG8haVLN3Yj4lPRun6/kjz6ayg3lb+Dz1Pe2tc2ObCbOuZX8rsneEU6kzPboSf7pf+9m+J16c9prdeeMZcrAyszlA74P+aOsB9V2Q5raoMhcde1HxJP80XAIQkP+WOo1/YWkyw47oY3rxcb1LoxrlzlsmZNkP+rjc+i6AU7JFUVx7jYAP55z
bp7I/Tl1O15kN34L6sa8zt3M+QT26V8ZJ+vPIPWTEbVrWS9Azy7O3QDgl6jn8pzLvaSZcy78z48H/iOWND0qimKcWYT1AvSMHScAAIBM5DgBAABkInACAADIROAEAACQicAJAAAgE4ETAABAJgInAACATF86psfc9qr95gAAAABJRU5ErkJggg==\n", 59 | "text/plain": [ 60 | "
" 61 | ] 62 | }, 63 | "metadata": { 64 | "needs_background": "light" 65 | }, 66 | "output_type": "display_data" 67 | } 68 | ], 69 | "source": [ 70 | "x1, y1 = sample(e=1)\n", 71 | "x2, y2 = sample(e=2)\n", 72 | "\n", 73 | "tt = np.linspace(-1, 1, 100 + 1)\n", 74 | "plot_y_ls = []\n", 75 | "plot_y_ls_reg = []\n", 76 | "plot_y_1 = []\n", 77 | "\n", 78 | "for t in tt:\n", 79 | " phi = np.diag([1, t])\n", 80 | " w = np.array([1, 0])\n", 81 | " # phi = np.array([[1],[t]])\n", 82 | " # w = np.ones(1)\n", 83 | " plot_y_ls.append(penalty_ls(x1, y1, phi, w) + penalty_ls(x2,y2,phi, w))\n", 84 | " plot_y_ls_reg.append(penalty_ls(x1, y1, phi, w, reg=1000) + penalty_ls(x2,y2,phi, w, reg=1000))\n", 85 | " plot_y_1.append(penalty_g(x1, y1, phi, w) + penalty_g(x2, y2, phi, w))\n", 86 | "\n", 87 | "plt.figure(figsize=(8, 4))\n", 88 | "plt.plot(tt, plot_y_ls, lw=2, label=r'$\\mathbb{D}_{\\text{dist}}((1, 0), \\Phi, e)$')\n", 89 | "plt.plot(tt, plot_y_ls_reg, ls=\"--\", lw=2, label=r'$\\mathbb{D}_{\\text{dist}}$ (heavy regularization)')\n", 90 | "plt.plot(tt, plot_y_1, '.', lw=2, label=r'$\\mathbb{D}_{\\text{lin}}((1, 0), \\Phi, e)$')\n", 91 | "# plt.ylim(-1, 12)\n", 92 | "plt.grid()\n", 93 | "plt.xlabel(r'$c$, the weight of $\\Phi$ on the input with varying correlation', labelpad=10)\n", 94 | "plt.ylabel(r'invariance penalty')\n", 95 | "plt.tight_layout(0, 0, 0)\n", 96 | "plt.legend(prop={'size': 11}, loc=\"upper right\")\n", 97 | "plt.show()\n" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": 37, 103 | "metadata": {}, 104 | "outputs": [], 105 | "source": [ 106 | "\n", 107 | "# SOMETHING is wrong. check dimensionality of w and phi\n", 108 | "\n", 109 | "# Need to work tomorrow on getting my report done for the predoc." 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": null, 115 | "metadata": {}, 116 | "outputs": [], 117 | "source": [] 118 | } 119 | ], 120 | "metadata": { 121 | "kernelspec": { 122 | "display_name": "deep (3.6)", 123 | "language": "python", 124 | "name": "deep" 125 | }, 126 | "language_info": { 127 | "codemirror_mode": { 128 | "name": "ipython", 129 | "version": 3 130 | }, 131 | "file_extension": ".py", 132 | "mimetype": "text/x-python", 133 | "name": "python", 134 | "nbconvert_exporter": "python", 135 | "pygments_lexer": "ipython3", 136 | "version": "3.6.9" 137 | } 138 | }, 139 | "nbformat": 4, 140 | "nbformat_minor": 2 141 | } 142 | --------------------------------------------------------------------------------