├── exps ├── .gitignore ├── cmnist_label_noise_sweep.sh └── cmnist_with_specified_label_noise.sh ├── opt_env ├── .gitignore ├── cmnist_results │ ├── .gitignore │ └── acc_table.py ├── utils │ ├── .gitignore │ ├── model_utils.py │ ├── env_utils.py │ └── opt_utils.py └── irm_cmnist.py ├── InvariantRiskMinimization ├── code │ ├── colored_mnist │ │ ├── .gitignore │ │ ├── optimize_envs.sh │ │ └── main_optenv.py │ ├── experiment_synthetic │ │ ├── .gitignore │ │ ├── synthetic_results.pt │ │ ├── run_sems.sh │ │ ├── sem.py │ │ ├── plot.py │ │ ├── main.py │ │ └── models.py │ └── figure_1 │ │ └── penalties.py ├── CODE_OF_CONDUCT.md ├── README.md ├── CONTRIBUTING.md └── LICENSE ├── .gitignore ├── requirements.txt ├── README.md ├── LICENSE └── notebooks └── sem_results.ipynb /exps/.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints/* 2 | -------------------------------------------------------------------------------- /opt_env/.gitignore: -------------------------------------------------------------------------------- 1 | .idea/* 2 | __pycache__/* 3 | -------------------------------------------------------------------------------- /opt_env/cmnist_results/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/* 2 | -------------------------------------------------------------------------------- /opt_env/utils/.gitignore: -------------------------------------------------------------------------------- 1 | .idea/* 2 | __pycache__/* 3 | -------------------------------------------------------------------------------- /InvariantRiskMinimization/code/colored_mnist/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/* 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/* 2 | 
__pycache__/* 3 | plots/* 4 | results/* 5 | slurm_output/* 6 | old_scripts/* 7 | -------------------------------------------------------------------------------- /InvariantRiskMinimization/code/experiment_synthetic/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/* 2 | slurm_output/* 3 | -------------------------------------------------------------------------------- /InvariantRiskMinimization/code/experiment_synthetic/synthetic_results.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ecreager/eiil/HEAD/InvariantRiskMinimization/code/experiment_synthetic/synthetic_results.pt -------------------------------------------------------------------------------- /InvariantRiskMinimization/CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | Facebook has adopted a Code of Conduct that we expect project participants to adhere to. Please read the [full text](https://code.fb.com/codeofconduct/) so that you can understand what actions will and will not be tolerated. 
#!/bin/bash
# Run main_optenv.py with the --eiil flag and the fixed hyperparameters below.
# main_optenv.py is referenced by relative path, so run this from the directory
# that contains it.

echo "EIIL + IRM"
# The last argument deliberately has no trailing backslash: a dangling
# continuation would swallow whatever line is appended after it.
python -u main_optenv.py \
  --hidden_dim=390 \
  --l2_regularizer_weight=0.00110794568 \
  --lr=0.0004898536566546834 \
  --penalty_anneal_iters=190 \
  --penalty_weight=191257.18613115903 \
  --steps=501 \
  --n_restarts=10 \
  --eiil
#!/bin/bash
# Run the CMNIST experiment once per label-noise level (0.00 to 0.45 in steps
# of 0.05). The per-level script is referenced as ./exps/..., so invoke this
# from the directory that contains exps/.

for label_noise in 0.00 0.05 0.10 0.15 0.20 0.25 0.30 0.35 0.40 0.45
do
  # Levels run sequentially; a failing level does not stop the sweep.
  ./exps/cmnist_with_specified_label_noise.sh "$label_noise"
done
| # Contributing to InvariantRiskMinimization 2 | 3 | We want to make contributing to this project as easy and transparent as 4 | possible. 5 | 6 | ## Pull Requests 7 | We actively welcome your pull requests. 8 | 9 | 1. Fork the repo and create your branch from `master`. 10 | 2. If you've added code that should be tested, add tests. 11 | 3. If you've changed APIs, update the documentation. 12 | 4. Ensure the test suite passes. 13 | 5. Make sure your code lints. 14 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 15 | 16 | ## Contributor License Agreement ("CLA") 17 | In order to accept your pull request, we need you to submit a CLA. You only need 18 | to do this once to work on any of Facebook's open source projects. 19 | 20 | Complete your CLA here: <https://code.facebook.com/cla> 21 | 22 | ## Issues 23 | We use GitHub issues to track public bugs. Please ensure your description is 24 | clear and has sufficient instructions to be able to reproduce the issue. 25 | 26 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 27 | disclosure of security bugs. In those cases, please go through the process 28 | outlined on that page and do not file a public issue. 29 | 30 | ## License 31 | By contributing to InvariantRiskMinimization, you agree that your contributions 32 | will be licensed under the LICENSE file in the root directory of this source 33 | tree. 
34 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | argon2-cffi==20.1.0 2 | attrs==20.1.0 3 | backcall==0.2.0 4 | bleach==3.1.5 5 | certifi==2020.6.20 6 | cffi==1.14.2 7 | cycler==0.10.0 8 | decorator==4.4.2 9 | defusedxml==0.6.0 10 | entrypoints==0.3 11 | future==0.18.2 12 | importlib-metadata==1.7.0 13 | ipykernel==5.3.4 14 | ipython==7.16.1 15 | ipython-genutils==0.2.0 16 | ipywidgets==7.5.1 17 | jedi==0.17.2 18 | Jinja2==2.11.2 19 | joblib==0.16.0 20 | jsonschema==3.2.0 21 | jupyter==1.0.0 22 | jupyter-client==6.1.7 23 | jupyter-console==6.1.0 24 | jupyter-core==4.6.3 25 | jupyter-http-over-ws==0.0.8 26 | kiwisolver==1.2.0 27 | MarkupSafe==1.1.1 28 | matplotlib==3.3.1 29 | mistune==0.8.4 30 | nbconvert==5.6.1 31 | nbformat==5.0.7 32 | notebook==6.1.3 33 | numpy==1.19.1 34 | packaging==20.4 35 | pandas==1.1.1 36 | pandocfilters==1.4.2 37 | parso==0.7.1 38 | pexpect==4.8.0 39 | pickleshare==0.7.5 40 | Pillow==7.2.0 41 | prometheus-client==0.8.0 42 | prompt-toolkit==3.0.6 43 | ptyprocess==0.6.0 44 | pycparser==2.20 45 | Pygments==2.6.1 46 | pyparsing==2.4.7 47 | pyrsistent==0.16.0 48 | python-dateutil==2.8.1 49 | pytz==2020.1 50 | pyzmq==19.0.2 51 | qtconsole==4.7.6 52 | QtPy==1.9.0 53 | scikit-learn==0.23.2 54 | scipy==1.5.2 55 | seaborn==0.10.1 56 | Send2Trash==1.5.0 57 | six==1.15.0 58 | terminado==0.8.3 59 | testpath==0.4.4 60 | threadpoolctl==2.1.0 61 | torch==1.6.0 62 | torchvision==0.7.0 63 | tornado==6.0.4 64 | tqdm==4.48.2 65 | traitlets==4.3.3 66 | wcwidth==0.2.5 67 | webencodings==0.5.1 68 | widgetsnbextension==3.5.1 69 | zipp==3.1.0 70 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Environment Inference for Invariant Learning 2 | This code accompanies the paper [Environment Inference 
for Invariant Learning](https://arxiv.org/abs/2010.07249), which appears at ICML 2021. 3 | 4 | Thanks to my wonderful co-authors [Jörn-Henrik Jacobsen](https://github.com/jhjacobsen/) and [Richard Zemel](https://www.cs.toronto.edu/~zemel/inquiry/home.php). 5 | 6 | The InvariantRiskMinimization subdirectory is modified from https://github.com/facebookresearch/InvariantRiskMinimization, and has its own license. 7 | 8 | ## Reproducing paper results 9 | 10 | ### Synthetic data 11 | To produce results 12 | ``` 13 | cd InvariantRiskMinimization/code/experiment_synthetic/ 14 | ./run_sems.sh 15 | ``` 16 | To analyze results 17 | ``` 18 | notebooks/sem_results.ipynb 19 | ``` 20 | 21 | ### Color MNIST 22 | To produce results 23 | ``` 24 | ./exps/cmnist_label_noise_sweep.sh 25 | ``` 26 | To analyze results 27 | ``` 28 | notebooks/plot_cmnist_label_noise_sweep.ipynb 29 | ``` 30 | As an alternative, `InvariantRiskMinimization/code/colored_mnist/optimize_envs.sh` also runs EIIL+IRM on CMNIST with 25% label noise (the default from the IRM paper). 
class ColorBasedClassifier(nn.Module):
    """Parameter-free classifier that predicts from the dominant channel.

    The input is summed over its last two axes and the arg max over the
    remaining last axis is mapped to a logit of ``-LOG_CONFIDENCE`` (index 0)
    or ``+LOG_CONFIDENCE`` (index 1).
    """
    LOG_CONFIDENCE = .5  # magnitude of every logit this classifier emits

    def forward(self, input):
        """Return a ``(batch, 1)`` float tensor of +/- ``LOG_CONFIDENCE`` logits."""
        # estimate color based on arg max over the two channels
        # NOTE(review): assumes a (batch, 2, H, W) input -- confirm with callers
        input = input.sum((-1, -2))  # add pixels to get per-channel sums
        if len(input) == 0:  # hack to handle size zero batches
            hard_prediction = torch.zeros(0, 1).to(input.device)
        else:
            hard_prediction = torch.argmax(input, -1, keepdim=True)  # in {0, 1}
        hard_prediction = hard_prediction.float()
        hard_prediction = hard_prediction * 2 - 1.  # in {-1, 1}
        return hard_prediction * self.LOG_CONFIDENCE


def load_mlp(results_dir=None, flags=None, basename='params.p'):
    """Build the three-layer MLP and optionally restore saved parameters.

    Args:
        results_dir: directory holding ``basename`` (a saved state dict) and,
            when ``flags`` is omitted, a pickled ``flags.p``. If None, the
            model keeps its freshly initialized parameters.
        flags: object exposing ``grayscale_model`` and ``hidden_dim``. If None,
            it is unpickled from ``results_dir``.
        basename: file name of the saved state dict inside ``results_dir``.

    Returns:
        The MLP in eval mode, moved to CUDA when CUDA is available.

    Raises:
        AssertionError: if both ``flags`` and ``results_dir`` are None.
    """
    if flags is None:
        assert results_dir is not None, "flags and results_dir cannot both be None."
        # NOTE(review): pickle.load can execute arbitrary code; only point this
        # at results directories you trust.
        with open(os.path.join(results_dir, 'flags.p'), 'rb') as flags_file:
            flags = pickle.load(flags_file)

    class MLP(nn.Module):
        def __init__(self):
            super(MLP, self).__init__()
            # the grayscale variant sees one 14x14 plane, the color one sees two
            if flags.grayscale_model:
                lin1 = nn.Linear(14 * 14, flags.hidden_dim)
            else:
                lin1 = nn.Linear(2 * 14 * 14, flags.hidden_dim)
            lin2 = nn.Linear(flags.hidden_dim, flags.hidden_dim)
            lin3 = nn.Linear(flags.hidden_dim, 1)
            for lin in [lin1, lin2, lin3]:
                nn.init.xavier_uniform_(lin.weight)
                nn.init.zeros_(lin.bias)
            self._main = nn.Sequential(lin1, nn.ReLU(True), lin2, nn.ReLU(True), lin3)

        def forward(self, input):
            if flags.grayscale_model:
                # collapse the two channels into one by summation
                out = input.view(input.shape[0], 2, 14 * 14).sum(dim=1)
            else:
                out = input.view(input.shape[0], 2 * 14 * 14)
            out = self._main(out)
            return out

    mlp = MLP()
    if torch.cuda.is_available():
        mlp = mlp.cuda()

    if results_dir is not None:
        # Map the stored tensors onto the device the model actually lives on so
        # that parameters saved from a GPU run can be restored on a CPU host.
        device = next(mlp.parameters()).device
        state_dict = torch.load(os.path.join(results_dir, basename),
                                map_location=device)
        mlp.load_state_dict(state_dict)
        print('Model params loaded from %s.' % results_dir)
    else:
        print('Model built with randomly initialized parameters.')
    mlp.eval()

    return mlp
def load_results(dirname):
    """Return the object pickled in ``<dirname>/metrics.p``."""
    # NOTE(review): pickle.load can execute arbitrary code; only read results
    # directories you trust.
    with open(os.path.join(dirname, 'metrics.p'), 'rb') as metrics_file:
        return pickle.load(metrics_file)


def main(flags):
    """Print a LaTeX accuracy table and save it as ``results.tex``.

    Args:
        flags: namespace with one ``*_results_dir`` attribute per method (erm,
            irm, eiil, eiil_cb, cb, gray) plus ``results_dir``, the directory
            the table is written to (created if missing).
    """
    # load results from disk
    results = dict(
        erm=load_results(flags.erm_results_dir),
        irm=load_results(flags.irm_results_dir),
        eiil=load_results(flags.eiil_results_dir),
        eiil_cb=load_results(flags.eiil_cb_results_dir),
        cb=load_results(flags.cb_results_dir),
        gray=load_results(flags.gray_results_dir),
    )

    def mean_plus_minus_std(x):
        """Format list as its mean plus minus one std dev."""
        return r"""%.1f $\pm$ %.1f""" % (100. * np.mean(x), 100. * np.std(x))

    results = pd.DataFrame.from_dict(results)
    results = results.T  # one row per method, one column per metric
    results_tex = results.to_latex(
        # one formatter per column instead of assuming exactly two metrics
        formatters=[mean_plus_minus_std, ] * len(results.columns),
        escape=False
    )
    print(results_tex)
    os.makedirs(flags.results_dir, exist_ok=True)
    with open(os.path.join(flags.results_dir, 'results.tex'), 'w') as tex_file:
        print(results_tex, file=tex_file)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Build results table for CMNIST experiment.')
    parser.add_argument('--erm_results_dir', type=str, default='/scratch/gobi1/creager/opt_env/cmnist/erm')
    parser.add_argument('--irm_results_dir', type=str, default='/scratch/gobi1/creager/opt_env/cmnist/irm')
    parser.add_argument('--eiil_results_dir', type=str, default='/scratch/gobi1/creager/opt_env/cmnist/eiil')
    parser.add_argument('--eiil_cb_results_dir', type=str, default='/scratch/gobi1/creager/opt_env/cmnist/eiil_cb')
    parser.add_argument('--cb_results_dir', type=str, default='/scratch/gobi1/creager/opt_env/cmnist/cb')
    parser.add_argument('--gray_results_dir', type=str, default='/scratch/gobi1/creager/opt_env/cmnist/gray')
    parser.add_argument('--results_dir', type=str, default='/scratch/gobi1/creager/opt_env/cmnist_results/acc_table',
                        help='where tex tables should be saved')
    flags = parser.parse_args()
    main(flags)
def ls(x, y, reg=.1):
    """Return the coefficients of a ridge regression of ``y`` on ``x``.

    The fit has no intercept and uses ``reg`` as the ridge ``alpha``.
    """
    return Ridge(alpha=reg, fit_intercept=False).fit(x, y).coef_


def sample(n=100000, e=1):
    """Draw ``n`` samples of the chain x -> y -> z.

    ``e`` scales x and the noise added to form y; the noise added to form z
    has unit scale.

    Returns:
        A pair ``(features, y)`` where ``features`` is ``(n, 2)`` with columns
        ``x`` and ``z``, and ``y`` is ``(n, 1)``.
    """
    x = np.random.randn(n, 1) * e
    y = x + np.random.randn(n, 1) * e
    z = y + np.random.randn(n, 1)
    return np.hstack((x, z)), y


def penalty_ls(x1, y1, x2, y2, t=1, reg=.1):
    """Average distance between each environment's ridge solution and ``(1, 0)``.

    The features of both environments are first multiplied by
    ``diag([1, t])``.
    """
    phi = np.diag([1, t])
    w = np.array([1, 0]).reshape(1, 2)
    p1 = np.linalg.norm(ls(x1 @ phi, y1, reg) - w)
    p2 = np.linalg.norm(ls(x2 @ phi, y2, reg) - w)
    return (p1 + p2) / 2


def penalty_g(x1, y1, x2, y2, t=1):
    """Sum over both environments of a squared norm evaluated at ``w = (1, 0)``.

    For each environment the vector is
    ``(phi.T x.T x phi w - phi.T x.T y) / n`` with ``phi = diag([1, t])``.
    """
    phi = np.diag([1, t])
    w = np.array([1, 0]).reshape(2, 1)
    p1 = (phi.T @ x1.T @ x1 @ phi @ w - phi.T @ x1.T @ y1) / x1.shape[0]
    p2 = (phi.T @ x2.T @ x2 @ phi @ w - phi.T @ x2.T @ y2) / x2.shape[0]
    return np.linalg.norm(p1) ** 2 + np.linalg.norm(p2) ** 2


if __name__ == "__main__":
    x1, y1 = sample(e=1)
    x2, y2 = sample(e=2)

    plot_x = np.linspace(-1, 1, 100 + 1)
    plot_y_ls = []
    plot_y_ls_reg = []
    plot_y_1 = []

    for t in plot_x:
        plot_y_ls.append(penalty_ls(x1, y1, x2, y2, t))
        plot_y_ls_reg.append(penalty_ls(x1, y1, x2, y2, t, reg=1000))
        plot_y_1.append(penalty_g(x1, y1, x2, y2, t))

    # matplotlib >= 3.3 expects the LaTeX preamble as a single string
    plt.rcParams.update({'text.latex.preamble': r'\usepackage{amsmath, amsfonts}'})
    plt.rcParams["font.family"] = "Times New Roman"
    plt.rc('text', usetex=True)
    plt.rc('font', size=12)

    plt.figure(figsize=(8, 4))
    plt.plot(plot_x, plot_y_ls, lw=2, label=r'$\mathbb{D}_{\text{dist}}((1, 0), \Phi, e)$')
    plt.plot(plot_x, plot_y_ls_reg, ls="--", lw=2, label=r'$\mathbb{D}_{\text{dist}}$ (heavy regularization)')
    plt.plot(plot_x, plot_y_1, '.', lw=2, label=r'$\mathbb{D}_{\text{lin}}((1, 0), \Phi, e)$')
    plt.ylim(-1, 12)
    plt.xlabel(
        r'$c$, the weight of $\Phi$ on the input with varying correlation', labelpad=10)
    plt.ylabel(r'invariance penalty')
    # the padding arguments are keyword-only in current matplotlib
    plt.tight_layout(pad=0, h_pad=0, w_pad=0)
    plt.legend(prop={'size': 11}, loc="upper right")
    plt.savefig("different_penalties.pdf")
class ChainEquationModel(object):
    """Synthetic chain x -> y -> z with optional hidden variable and scrambling.

    ``dim`` is the width of the returned feature matrix; x, y and z each have
    ``dim // 2`` columns.
    """

    def __init__(self, dim, scramble=False, hetero=True, hidden=False):
        """Set up the weights of the chain.

        Args:
            dim: width of the feature matrix returned by ``__call__``.
            scramble: if true, features are multiplied by a random orthogonal
                matrix; otherwise by the identity.
            hetero: selects the noise pattern used in ``__call__`` (compared
                against 2 and 1 there).
            hidden: if true, random weights connect the hidden variable ``h``
                to x, y and z; otherwise ``h`` feeds x through the identity and
                does not enter y or z.
        """
        self.hetero = hetero
        self.hidden = hidden
        self.dim = dim // 2
        ones = True  # hard-wired: identity weights for x->y and y->z

        if ones:
            self.wxy = torch.eye(self.dim)
            self.wyz = torch.eye(self.dim)
        else:
            self.wxy = torch.randn(self.dim, self.dim) / dim
            self.wyz = torch.randn(self.dim, self.dim) / dim

        if scramble:
            gaussian = torch.randn(dim, dim)
            # torch.qr is deprecated; prefer torch.linalg.qr where it exists and
            # fall back so older torch releases keep working.
            if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'qr'):
                self.scramble, _ = torch.linalg.qr(gaussian)
            else:
                self.scramble, _ = torch.qr(gaussian)
        else:
            self.scramble = torch.eye(dim)

        if hidden:
            self.whx = torch.randn(self.dim, self.dim) / dim
            self.why = torch.randn(self.dim, self.dim) / dim
            self.whz = torch.randn(self.dim, self.dim) / dim
        else:
            self.whx = torch.eye(self.dim, self.dim)
            self.why = torch.zeros(self.dim, self.dim)
            self.whz = torch.zeros(self.dim, self.dim)

    def solution(self):
        """Return the reference weights as a ``(2 * self.dim, 1)`` column.

        The row sums of ``wxy`` fill the x half, zeros fill the z half, and
        the result is multiplied by the transposed scrambling matrix.
        """
        w = torch.cat((self.wxy.sum(1), torch.zeros(self.dim))).view(-1, 1)
        return self.scramble.t() @ w

    def __call__(self, n, env):
        """Sample ``n`` examples with noise scale ``env``.

        Prints the empirical variances of h, x, y, z as a side effect.

        Returns:
            ``(features, target)``: the scrambled ``[x, z]`` matrix and the
            per-row sum of y with shape ``(n, 1)``.
        """
        h = torch.randn(n, self.dim) * env

        if self.hetero == 2:
            # x has a fixed scale of 5; env scales only the noise entering y
            x = torch.randn(n, self.dim) * 5.
            y = x @ self.wxy + torch.randn(n, self.dim) * env
            z = y @ self.wyz + torch.randn(n, self.dim)
        elif self.hetero == 1:
            x = h @ self.whx + torch.randn(n, self.dim) * env
            y = x @ self.wxy + h @ self.why + torch.randn(n, self.dim) * env
            z = y @ self.wyz + h @ self.whz + torch.randn(n, self.dim)
        else:
            x = h @ self.whx + torch.randn(n, self.dim) * env
            y = x @ self.wxy + h @ self.why + torch.randn(n, self.dim)
            z = y @ self.wyz + h @ self.whz + torch.randn(n, self.dim) * env

        # The extra draw for ``e`` consumes random numbers; it is kept so the
        # random stream seen by later calls is unchanged.
        variances = dict(
            h=h.var().item(),
            x=x.var().item(),
            y=y.var().item(),
            z=z.var().item(),
            e=(torch.randn(n, self.dim) * env).var().item()  # any env dependent noise we might add
        )
        from pprint import pprint
        # %s rather than %d so a fractional env is not truncated in the log
        print('in setting %d data in env %s have following variances' % (self.hetero, env))
        pprint(variances)
        return torch.cat((x, z), 1) @ self.scramble, y.sum(1, keepdim=True)
#!/bin/bash
# CMNIST Experiment.
#
# Usage: cmnist_with_specified_label_noise.sh [LABEL_NOISE] [ROOT]
#   LABEL_NOISE  label-noise level passed to every run (default 0.05)
#   ROOT         base directory for results (default below)
# Invokes opt_env.irm_cmnist six times (erm, irm, eiil, eiil_cb, cb, gray) and
# then opt_env.cmnist_results.acc_table on the six result directories.

# Hyperparameters
N_RESTARTS=10
HIDDEN_DIM=390
L2_REGULARIZER_WEIGHT=0.00110794568
LR=0.0004898536566546834
LABEL_NOISE=${1-0.05}
PENALTY_ANNEAL_ITERS=190
PENALTY_WEIGHT=191257.18613115903
STEPS=501
ROOT=${2-/scratch/gobi1/creager/opt_env/cmnist}
# RNUM=$(printf "%05d" $(($RANDOM$RANDOM$RANDOM % 100000)))
TAG=$(date +'%Y-%m-%d')--$LABEL_NOISE
ROOT=$ROOT/label_noise_sweep/$TAG

# run_cmnist NAME PENALTY_ANNEAL_ITERS PENALTY_WEIGHT [EXTRA_FLAGS...]
# Runs opt_env.irm_cmnist with the shared hyperparameters, writing to
# $ROOT/NAME. Expansions are quoted so a ROOT containing spaces survives.
run_cmnist() {
  local name=$1
  local anneal_iters=$2
  local penalty_weight=$3
  shift 3
  python -u -m opt_env.irm_cmnist \
    --results_dir "$ROOT/$name" \
    --n_restarts "$N_RESTARTS" \
    --hidden_dim "$HIDDEN_DIM" \
    --l2_regularizer_weight "$L2_REGULARIZER_WEIGHT" \
    --lr "$LR" \
    --label_noise "$LABEL_NOISE" \
    --penalty_anneal_iters "$anneal_iters" \
    --penalty_weight "$penalty_weight" \
    --steps "$STEPS" \
    "$@"
}

# ERM
run_cmnist erm 0 0.0

# IRM
run_cmnist irm "$PENALTY_ANNEAL_ITERS" "$PENALTY_WEIGHT"

# EIIL
run_cmnist eiil "$PENALTY_ANNEAL_ITERS" "$PENALTY_WEIGHT" --eiil

# EIIL with color-based reference classifier
run_cmnist eiil_cb "$PENALTY_ANNEAL_ITERS" "$PENALTY_WEIGHT" --eiil --color_based

# Evaluate color-based classifier on its own
run_cmnist cb "$PENALTY_ANNEAL_ITERS" "$PENALTY_WEIGHT" --color_based_eval

# Grayscale baseline
run_cmnist gray "$PENALTY_ANNEAL_ITERS" 0.0 --grayscale_model

# Build latex tables
# accuracy
python -u -m opt_env.cmnist_results.acc_table \
  --erm_results_dir "$ROOT/erm" \
  --irm_results_dir "$ROOT/irm" \
  --eiil_results_dir "$ROOT/eiil" \
  --eiil_cb_results_dir "$ROOT/eiil_cb" \
  --cb_results_dir "$ROOT/cb" \
  --gray_results_dir "$ROOT/gray" \
  --results_dir "$ROOT/acc_table"
def get_envs(cuda=True, flags=None):
    """Build two training environments and one test environment from MNIST.

    Args:
        cuda: move images and labels to the GPU when one is available.
        flags: object exposing ``train_env_1__color_noise``,
            ``train_env_2__color_noise``, ``test_env__color_noise`` and
            ``label_noise``. If None, the defaults defined below are used.

    Returns:
        A list of three dicts (train env 1, train env 2, test env). Each holds
        ``images``, ``labels`` and the intermediate tensors
        ``preliminary_labels``, ``final_labels``, ``label_noise``, ``colors``
        and ``color_noise``.
    """

    if flags is None:  # configure data generation like in original IRM paper
        @attr.s
        class DefaultFlags(object):
            """Specify spurious correlations as original IRM paper."""
            train_env_1__color_noise = attr.ib(default=0.2)
            train_env_2__color_noise = attr.ib(default=0.1)
            test_env__color_noise = attr.ib(default=0.9)
            label_noise = attr.ib(default=0.25)
        flags = DefaultFlags()

    def _make_environment(images, labels, e):

        # NOTE: low e indicates a spurious correlation from color to (noisy) label

        def torch_bernoulli(p, size):
            return (torch.rand(size) < p).float()

        def torch_xor(a, b):
            return (a-b).abs()  # Assumes both inputs are either 0 or 1

        samples = dict()
        # 2x subsample for computational convenience
        images = images.reshape((-1, 28, 28))[:, ::2, ::2]
        # Assign a binary label based on the digit; flip label with
        # probability flags.label_noise
        labels = (labels < 5).float()
        samples.update(preliminary_labels=labels)
        label_noise = torch_bernoulli(flags.label_noise, len(labels))
        labels = torch_xor(labels, label_noise)
        samples.update(final_labels=labels)
        samples.update(label_noise=label_noise)
        # Assign a color based on the label; flip the color with probability e
        color_noise = torch_bernoulli(e, len(labels))
        colors = torch_xor(labels, color_noise)
        samples.update(colors=colors)
        samples.update(color_noise=color_noise)
        # Apply the color to the image by zeroing out the other color channel
        images = torch.stack([images, images], dim=1)
        images[torch.tensor(range(len(images))), (1-colors).long(), :, :] *= 0
        images = (images.float() / 255.)
        labels = labels[:, None]
        if cuda and torch.cuda.is_available():
            images = images.cuda()
            labels = labels.cuda()
        samples.update(images=images, labels=labels)
        return samples

    mnist = datasets.MNIST('~/datasets/mnist', train=True, download=True)
    mnist_train = (mnist.data[:50000], mnist.targets[:50000])
    mnist_val = (mnist.data[50000:], mnist.targets[50000:])

    # Shuffle images and targets in place with the same RNG state so that the
    # two stay aligned.
    rng_state = np.random.get_state()
    np.random.shuffle(mnist_train[0].numpy())
    np.random.set_state(rng_state)
    np.random.shuffle(mnist_train[1].numpy())

    envs = [
        _make_environment(mnist_train[0][::2], mnist_train[1][::2], flags.train_env_1__color_noise),
        _make_environment(mnist_train[0][1::2], mnist_train[1][1::2], flags.train_env_2__color_noise),
        _make_environment(mnist_val[0], mnist_val[1], flags.test_env__color_noise)
    ]
    return envs


def get_envs_with_indices():
    """Return IRM envs but with indices and environment indicators.

    Each env dict gains ``idx`` (int32 example indices that keep counting
    across envs) and ``env`` (uint8 tensor filled with the env's position in
    the list).
    """
    envs = get_envs()
    examples_so_far = 0
    for i, env in enumerate(envs):
        num_examples = len(env['images'])
        env['idx'] = torch.arange(examples_so_far,
                                  examples_so_far + num_examples,
                                  dtype=torch.int32)
        examples_so_far += num_examples
        # here "env" is a label indicating which env each example belongs to
        env['env'] = torch.full((num_examples,), i, dtype=torch.uint8)
    return envs


def split_by_noise(env, noise_var='label'):
    """Split an env dict into examples without and with the chosen noise.

    Args:
        env: dict of tensors containing ``label_noise`` / ``color_noise``.
        noise_var: ``'label'`` or ``'color'``; selects which indicator to
            split on.

    Returns:
        ``(clean_env, noisy_env)``. Tensors with more than one element are
        filtered by the indicator; other tensors are copied unchanged. The
        input dict is not modified.
    """
    assert noise_var in ('label', 'color'), 'Unexpected noise variable.'
    noise_name = '%s_noise' % noise_var
    clean_idx = (env[noise_name] == 0.)
    noisy_idx = (env[noise_name] == 1.)
    from copy import deepcopy

    def _select(mask):
        # Boolean indexing already yields a new tensor, so only the entries
        # that are passed through untouched need an explicit copy.
        return {k: (v[mask] if v.numel() > 1 else deepcopy(v))
                for k, v in env.items()}

    return _select(clean_idx), _select(noisy_idx)
def parse_title(title):
    """Turn an underscore-separated setup string into a three-letter code.

    Fields 1, 2 and 3 are each expected to look like ``name=value``. A value
    of ``"1"`` maps to ``P``, ``E``, ``S`` respectively; anything else maps to
    ``F``, ``O``, ``U``.
    """
    result = ""
    fields = title.split("_")

    if fields[1].split("=")[1] == "1":
        result += "P"
    else:
        result += "F"

    if fields[2].split("=")[1] == "1":
        result += "E"
    else:
        result += "O"

    if fields[3].split("=")[1] == "1":
        result += "S"
    else:
        result += "U"

    return result


def plot_bars(results, category, which, sep=1.1):
    """Draw one row of bar charts, one subplot per setup matching ``category``.

    Args:
        results: nested dict ``setup -> model -> list of [causal, noncausal]``.
        category: letter that must occur in a setup's parsed title for it to
            be plotted.
        which: ``"causal"`` plots column 0 in the top row; any other value
            plots column 1, hatched, in the bottom row.
        sep: unused; kept for interface compatibility.
    """
    models = list(results[list(results.keys())[0]].keys())

    if "SEM" in models:
        models.remove("SEM")
    models.sort()

    setups = list(results.keys())
    setups.sort()

    if which == "causal":
        hatch = None
        offset = 0
        idx = 0
    else:
        hatch = "//"
        offset = 4
        idx = 1

    # One bar per model rather than a fixed three, with the single tick
    # centred under the group (position 1 when there are three models).
    positions = np.arange(len(models))
    tick_position = [(len(models) - 1) / 2.]

    counter = 1
    for s, setup in enumerate(setups):
        title = parse_title(setup)

        if category not in title:
            continue

        boxes_means = []
        boxes_colors = []
        boxes_vars = []
        ax = plt.subplot(2, 4, counter + offset)
        counter += 1

        for m, model in enumerate(models):
            errors = np.array(results[setup][model])[:, idx]
            boxes_means.append(np.mean(errors))
            boxes_vars.append(np.std(errors))
            boxes_colors.append("C" + str(m))

        plt.bar(positions,
                boxes_means,
                yerr=np.array(boxes_vars),
                color=boxes_colors,
                hatch=hatch,
                alpha=0.7,
                log=True)

        if which == "causal":
            plt.xticks(tick_position, [title])
        else:
            ax.xaxis.set_ticks_position('top')
            plt.xticks(tick_position, [""])

        # counter was already advanced, so 2 and 6 are the first subplot of
        # the top and bottom row
        if (counter + offset) == 2 or (counter + offset) == 6:
            if which == "causal":
                plt.ylabel("causal error")
            else:
                plt.ylabel("non-causal error")

        if title == "PES" and which != "causal":
            legends = []
            for m, model in enumerate(models):
                legends.append(
                    Patch(facecolor="C" + str(m), alpha=0.7, label=model))
            plt.legend(handles=legends, loc="lower center")

        if title == "POU" and which != "causal":
            plt.minorticks_off()
            ax.set_yticks([0.1, 0.01])


def get_results(all_solutions):
    """Group result lines into ``setup -> model -> [[causal, noncausal], ...]``.

    Each line is split on single spaces: the first word is the setup, the
    second the model, and the last two are the causal and non-causal errors.
    """
    results = {}

    for line in all_solutions:
        words = line.split(" ")
        setup = str(words[0])
        model = str(words[1])
        err_causal = float(words[-2])
        err_noncausal = float(words[-1])

        if setup not in results:
            results[setup] = {}

        if model not in results[setup]:
            results[setup][model] = []

        results[setup][model].append([err_causal, err_noncausal])
    return results


def plot_experiment(all_solutions, category, fname=None):
    """Plot causal and non-causal error bars for one category of setups.

    Args:
        all_solutions: iterable of result lines accepted by ``get_results``.
        category: letter forwarded to ``plot_bars`` to select setups.
        fname: output path; if None the figure is shown instead of saved.
    """
    plt.rcParams["font.family"] = "serif"
    plt.rc('text', usetex=True)
    plt.rc('font', size=10)

    results = get_results(all_solutions)

    plt.figure(figsize=(7, 2))
    plot_bars(results, category, "causal")
    plot_bars(results, category, "noncausal")
    # the padding arguments are keyword-only in current matplotlib
    plt.tight_layout(pad=0, h_pad=0, w_pad=0.5)

    if fname is None:
        plt.show()
    else:
        plt.savefig(fname)


if __name__ == "__main__":
    if len(sys.argv) == 1:
        fname = "synthetic_results.pt"
    else:
        fname = sys.argv[1]
    lines = torch.load(fname)
    plot_experiment(lines, "F", "results_f.pdf")
    plot_experiment(lines, "P", "results_p.pdf")
plot_bars(results, category, "noncausal") 142 | plt.tight_layout(0, 0, 0.5) 143 | 144 | if fname is None: 145 | plt.show() 146 | else: 147 | plt.savefig(fname) 148 | 149 | 150 | if __name__ == "__main__": 151 | if len(sys.argv) == 1: 152 | fname = "synthetic_results.pt" 153 | else: 154 | fname = sys.argv[1] 155 | lines = torch.load(fname) 156 | plot_experiment(lines, "F", "results_f.pdf") 157 | plot_experiment(lines, "P", "results_p.pdf") 158 | -------------------------------------------------------------------------------- /opt_env/utils/opt_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import autograd 4 | from torch import nn 5 | from torch import optim 6 | from tqdm import tqdm 7 | 8 | def nll(logits, y, reduction='mean'): 9 | return nn.functional.binary_cross_entropy_with_logits(logits, y, reduction=reduction) 10 | 11 | def mean_accuracy(logits, y): 12 | preds = (logits > 0.).float() 13 | return ((preds - y).abs() < 1e-2).float().mean() 14 | 15 | def penalty(logits, y): 16 | scale = torch.tensor(1.).cuda().requires_grad_() 17 | loss = nll(logits * scale, y) 18 | grad = autograd.grad(loss, [scale], create_graph=True)[0] 19 | return torch.sum(grad**2) 20 | 21 | def split_data_opt(envs, model, n_steps=10000, n_samples=-1, lr=0.001, 22 | batch_size=None, join=True, no_tqdm=False): 23 | """Learn soft environment assignment.""" 24 | 25 | if join: # assumes first two entries in envs list are the train sets to joined 26 | print('pooling envs') 27 | # pool all training envs (defined as each env in envs[:-1]) 28 | joined_train_envs = dict() 29 | for k in envs[0].keys(): 30 | if envs[0][k].numel() > 1: # omit scalars previously stored during training 31 | joined_values = torch.cat((envs[0][k][:n_samples], 32 | envs[1][k][:n_samples]), 33 | 0) 34 | joined_train_envs[k] = joined_values 35 | print('size of pooled envs: %d' % len(joined_train_envs['images'])) 36 | else: 37 | if 
not isinstance(envs, dict): 38 | raise ValueError(('When join=False, first argument should be a dict' 39 | ' corresponding to the only environment.' 40 | )) 41 | print('splitting data from single env of size %d' % len(envs['images'])) 42 | joined_train_envs = envs 43 | 44 | scale = torch.tensor(1.).cuda().requires_grad_() 45 | if batch_size: 46 | logits = [] 47 | i = 0 48 | num_examples = len(joined_train_envs['images']) 49 | while i < num_examples: 50 | images = joined_train_envs['images'][i:i+64] 51 | images = images.cuda() 52 | logits.append(model(images).detach()) 53 | i += 64 54 | logits = torch.cat(logits) 55 | else: 56 | logits = model(joined_train_envs['images']) 57 | logits = logits.detach() 58 | 59 | loss = nll(logits * scale, joined_train_envs['labels'].cuda(), reduction='none') 60 | 61 | env_w = torch.randn(len(logits)).cuda().requires_grad_() 62 | optimizer = optim.Adam([env_w], lr=lr) 63 | 64 | with tqdm(total=n_steps, position=1, bar_format='{desc}', desc='AED Loss: ', disable=no_tqdm) as desc: 65 | for i in tqdm(range(n_steps), disable=no_tqdm): 66 | # penalty for env a 67 | lossa = (loss.squeeze() * env_w.sigmoid()).mean() 68 | grada = autograd.grad(lossa, [scale], create_graph=True)[0] 69 | penaltya = torch.sum(grada**2) 70 | # penalty for env b 71 | lossb = (loss.squeeze() * (1-env_w.sigmoid())).mean() 72 | gradb = autograd.grad(lossb, [scale], create_graph=True)[0] 73 | penaltyb = torch.sum(gradb**2) 74 | # negate 75 | npenalty = - torch.stack([penaltya, penaltyb]).mean() 76 | # step 77 | optimizer.zero_grad() 78 | npenalty.backward(retain_graph=True) 79 | optimizer.step() 80 | desc.set_description('AED Loss: %.8f' % npenalty.cpu().item()) 81 | 82 | print('Final AED Loss: %.8f' % npenalty.cpu().item()) 83 | 84 | # split envs based on env_w threshold 85 | new_envs = [] 86 | idx0 = (env_w.sigmoid()>.5) 87 | idx1 = (env_w.sigmoid()<=.5) 88 | # train envs 89 | # NOTE: envs include original data indices for qualitative investigation 90 | for _idx in 
(idx0, idx1): 91 | new_env = dict() 92 | for k, v in joined_train_envs.items(): 93 | if k == 'paths': # paths is formatted as a list of str, not ndarray or tensor 94 | v_ = np.array(v) 95 | v_ = v_[_idx.cpu().numpy()] 96 | v_ = list(v_) 97 | new_env[k] = v_ 98 | else: 99 | new_env[k] = v[_idx] 100 | new_envs.append(new_env) 101 | print('size of env0: %d' % len(new_envs[0]['images'])) 102 | print('size of env1: %d' % len(new_envs[1]['images'])) 103 | 104 | if join: #NOTE: assume the user includes test set as part of arguments only if join=True 105 | new_envs.append(envs[-1]) 106 | print('size of env2: %d' % len(new_envs[2]['images'])) 107 | return new_envs 108 | 109 | 110 | def train_irm_batch(model, envs, flags): 111 | """Batch version of the IRM algo for CMNIST expers.""" 112 | def _pretty_print(*values): 113 | col_width = 13 114 | def format_val(v): 115 | if not isinstance(v, str): 116 | v = np.array2string(v, precision=5, floatmode='fixed') 117 | return v.ljust(col_width) 118 | str_values = [format_val(v) for v in values] 119 | print(" ".join(str_values)) 120 | 121 | if flags.color_based_eval: # skip IRM and evaluate color-based model 122 | from opt_env.utils.model_utils import ColorBasedClassifier 123 | model = ColorBasedClassifier() 124 | if not flags.color_based_eval: 125 | optimizer = optim.Adam(model.parameters(), lr=flags.lr) 126 | for step in range(flags.steps): 127 | for env in envs: 128 | logits = model(env['images']) 129 | env['nll'] = nll(logits, env['labels']) 130 | env['acc'] = mean_accuracy(logits, env['labels']) 131 | env['penalty'] = penalty(logits, env['labels']) 132 | 133 | train_nll = torch.stack([envs[0]['nll'], envs[1]['nll']]).mean() 134 | train_acc = torch.stack([envs[0]['acc'], envs[1]['acc']]).mean() 135 | train_penalty = torch.stack([envs[0]['penalty'], envs[1]['penalty']]).mean() 136 | 137 | weight_norm = torch.tensor(0.).cuda() 138 | for w in model.parameters(): 139 | weight_norm += w.norm().pow(2) 140 | loss = train_nll.clone() 141 | 
loss += flags.l2_regularizer_weight * weight_norm 142 | penalty_weight = (flags.penalty_weight 143 | if step >= flags.penalty_anneal_iters else 1.0) 144 | loss += penalty_weight * train_penalty 145 | if penalty_weight > 1.0: 146 | # Rescale the entire loss to keep gradients in a reasonable range 147 | loss /= penalty_weight 148 | 149 | if not flags.color_based_eval: 150 | optimizer.zero_grad() 151 | loss.backward() 152 | optimizer.step() 153 | 154 | test_acc = envs[2]['acc'] 155 | if step % 100 == 0: 156 | _pretty_print( 157 | np.int32(step), 158 | train_nll.detach().cpu().numpy(), 159 | train_acc.detach().cpu().numpy(), 160 | train_penalty.detach().cpu().numpy(), 161 | test_acc.detach().cpu().numpy() 162 | ) 163 | 164 | final_train_acc = train_acc.detach().cpu().numpy() 165 | final_test_acc = test_acc.detach().cpu().numpy() 166 | return model, final_train_acc, final_test_acc 167 | 168 | -------------------------------------------------------------------------------- /InvariantRiskMinimization/code/experiment_synthetic/main.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | import os 9 | import pickle 10 | import sys 11 | 12 | from sem import ChainEquationModel 13 | from models import * 14 | 15 | import argparse 16 | import torch 17 | import numpy 18 | 19 | 20 | def pretty(vector): 21 | vlist = vector.view(-1).tolist() 22 | return "[" + ", ".join("{:+.3f}".format(vi) for vi in vlist) + "]" 23 | 24 | 25 | def errors(w, w_hat): 26 | w = w.view(-1) 27 | w_hat = w_hat.view(-1) 28 | 29 | i_causal = (w != 0).nonzero().view(-1) 30 | i_noncausal = (w == 0).nonzero().view(-1) 31 | 32 | if len(i_causal): 33 | error_causal = (w[i_causal] - w_hat[i_causal]).pow(2).mean() 34 | error_causal = error_causal.item() 35 | else: 36 | error_causal = 0 37 | 38 | if len(i_noncausal): 39 | error_noncausal = (w[i_noncausal] - w_hat[i_noncausal]).pow(2).mean() 40 | error_noncausal = error_noncausal.item() 41 | else: 42 | error_noncausal = 0 43 | 44 | return error_causal, error_noncausal 45 | 46 | 47 | def run_experiment(args): 48 | if args["seed"] >= 0: 49 | torch.manual_seed(args["seed"]) 50 | numpy.random.seed(args["seed"]) 51 | torch.set_num_threads(1) 52 | 53 | if args["setup_sem"] == "chain": 54 | setup_str = "chain_hidden={}_hetero={}_scramble={}".format( 55 | args["setup_hidden"], 56 | args["setup_hetero"], 57 | args["setup_scramble"]) 58 | elif args["setup_sem"] == "icp": 59 | setup_str = "sem_icp" 60 | else: 61 | raise NotImplementedError 62 | 63 | args['results_dir'] = os.path.join(args['results_dir'], setup_str) 64 | if args['eiil_ref_alpha'] >= 0 and args['eiil_ref_alpha'] <= 1: 65 | args['results_dir'] = '{results_dir}_alpha_{eiil_ref_alpha:.1f}'.format(**args) 66 | 67 | if not os.path.exists(args['results_dir']): 68 | os.makedirs(args['results_dir']) 69 | pickle.dump(args, open(os.path.join(args['results_dir'], 'flags.p'), 'wb')) 70 | for f in sys.stdout, open(os.path.join(args['results_dir'], 'flags.txt'), 'w'): 71 | print('Flags:', file=f) 72 | for k,v in sorted(args.items()): 73 | print("\t{}: {}".format(k, v), file=f) 74 | 
print('results will be found here:') 75 | print(args['results_dir']) 76 | 77 | all_methods = { 78 | "ERM": EmpiricalRiskMinimizer, 79 | "ICP": InvariantCausalPrediction, 80 | "IRM": InvariantRiskMinimization, 81 | "EIIL": LearnedEnvInvariantRiskMinimization 82 | } 83 | 84 | if args["methods"] == "all": 85 | methods = all_methods 86 | else: 87 | methods = {m: all_methods[m] for m in args["methods"].split(',')} 88 | 89 | all_sems = [] 90 | all_solutions = [] 91 | all_environments = [] 92 | from collections import defaultdict 93 | all_err_causal = defaultdict(list) 94 | all_err_noncausal = defaultdict(list) 95 | 96 | for rep_i in range(args["n_reps"]): 97 | if args["setup_sem"] == "chain": 98 | sem = ChainEquationModel(args["dim"], 99 | hidden=args["setup_hidden"], 100 | scramble=args["setup_scramble"], 101 | hetero=args["setup_hetero"]) 102 | environments = [sem(args["n_samples"], .2), 103 | sem(args["n_samples"], 2.), 104 | sem(args["n_samples"], 5.)] 105 | else: 106 | raise NotImplementedError 107 | 108 | all_sems.append(sem) 109 | all_environments.append(environments) 110 | 111 | for sem, environments in zip(all_sems, all_environments): 112 | soln = sem.solution() 113 | solutions = [ 114 | "{} {:<5} {} {:.5f} {:.5f}".format(setup_str, 115 | "SEM", 116 | pretty(sem.solution()), 0, 0) 117 | ] 118 | 119 | 120 | for method_name, method_constructor in methods.items(): 121 | method = method_constructor(environments, args) 122 | msolution = method.solution() 123 | err_causal, err_noncausal = errors(sem.solution(), msolution) 124 | all_err_causal[method_name].append(err_causal) 125 | all_err_noncausal[method_name].append(err_noncausal) 126 | 127 | solutions.append("{} {:<5} {} {:.5f} {:.5f}".format(setup_str, 128 | method_name, 129 | pretty(msolution), 130 | err_causal, 131 | err_noncausal)) 132 | 133 | all_solutions += solutions 134 | 135 | # save results 136 | results = dict() 137 | results.update(setup_str=setup_str) 138 | results.update(all_sems=all_sems) 139 | 
results.update(all_solutions=all_solutions) 140 | results.update(all_environments=all_environments) 141 | results.update(all_environments=all_environments) 142 | results.update(all_err_causal=all_err_causal) 143 | results.update(all_err_noncausal=all_err_noncausal) 144 | with open(os.path.join(args['results_dir'], 'results.p'), 'wb') as f: 145 | pickle.dump(results, f) 146 | 147 | return all_solutions 148 | 149 | 150 | if __name__ == '__main__': 151 | parser = argparse.ArgumentParser(description='Invariant regression') 152 | parser.add_argument('--dim', type=int, default=10) 153 | parser.add_argument('--n_samples', type=int, default=1000) 154 | parser.add_argument('--n_reps', type=int, default=1) 155 | parser.add_argument('--skip_reps', type=int, default=0) 156 | parser.add_argument('--seed', type=int, default=0) # Negative is random 157 | parser.add_argument('--print_vectors', type=int, default=1) 158 | parser.add_argument('--n_iterations', type=int, default=10000) 159 | parser.add_argument('--lr', type=float, default=1e-3) 160 | parser.add_argument('--verbose', type=int, default=0) 161 | parser.add_argument('--methods', type=str, default="EIIL,IRM,ERM") 162 | parser.add_argument('--alpha', type=float, default=0.05) 163 | parser.add_argument('--setup_sem', type=str, default="chain") 164 | parser.add_argument('--setup_hidden', type=int, default=0) 165 | parser.add_argument('--setup_hetero', type=int, default=1) 166 | parser.add_argument('--setup_scramble', type=int, default=0) 167 | parser.add_argument('--results_dir', type=str, default="/tmp/experiment_synthetic") 168 | parser.add_argument('--eiil_ref_alpha', type=float, default=-1, 169 | help=('Value between zero and one to hard code the reference ' 170 | 'classifier propensity to use the spurious feature. 
Set ' 171 | 'to value outside zero one interval to disable.')) 172 | args = dict(vars(parser.parse_args())) 173 | 174 | all_solutions = run_experiment(args) 175 | print("\n".join(all_solutions)) 176 | print("\n".join(all_solutions), file=open( 177 | os.path.join(args['results_dir'], 'all_solutions.txt'), 'w') 178 | ) 179 | -------------------------------------------------------------------------------- /opt_env/irm_cmnist.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import pdb 4 | import pickle 5 | import sys 6 | 7 | import numpy as np 8 | import torch 9 | from torchvision import datasets 10 | from torch import nn, optim, autograd 11 | from tqdm import tqdm 12 | 13 | from opt_env.utils.env_utils import get_envs 14 | from opt_env.utils.opt_utils import split_data_opt 15 | from opt_env.utils.opt_utils import train_irm_batch 16 | from opt_env.utils.model_utils import load_mlp 17 | from opt_env.utils.model_utils import ColorBasedClassifier 18 | 19 | def main(flags): 20 | 21 | if not os.path.exists(flags.results_dir): 22 | os.makedirs(flags.results_dir) 23 | 24 | # save this file and command for reproducibility 25 | if flags.results_dir != '.': 26 | with open(__file__, 'r') as f: 27 | this_file = f.readlines() 28 | with open(os.path.join(flags.results_dir, 'irm_cmnist.py'), 'w') as f: 29 | f.write(''.join(this_file)) 30 | cmd = 'python ' + ' '.join(sys.argv) 31 | with open(os.path.join(flags.results_dir, 'command.sh'), 'w') as f: 32 | f.write(cmd) 33 | # save params for later 34 | if not os.path.exists(flags.results_dir): 35 | os.makedirs(flags.results_dir) 36 | pickle.dump(flags, open(os.path.join(flags.results_dir, 'flags.p'), 'wb')) 37 | for f in sys.stdout, open(os.path.join(flags.results_dir, 'flags.txt'), 'w'): 38 | print('Flags:', file=f) 39 | for k,v in sorted(vars(flags).items()): 40 | print("\t{}: {}".format(k, v), file=f) 41 | 42 | print('results will be found here:') 43 | 
print(flags.results_dir) 44 | 45 | final_train_accs = [] 46 | final_test_accs = [] 47 | for restart in range(flags.n_restarts): 48 | print("Restart", restart) 49 | 50 | rng_state = np.random.get_state() 51 | np.random.set_state(rng_state) 52 | 53 | # Build environments 54 | envs = get_envs(flags=flags) 55 | 56 | # Define and instantiate the model 57 | if flags.color_based: # use color-based reference classifier without trainable params 58 | mlp_pre = ColorBasedClassifier() 59 | else: 60 | mlp_pre = load_mlp(results_dir=None, flags=flags).cuda() # reference classifier 61 | mlp = load_mlp(results_dir=None, flags=flags).cuda() # invariant representation learner 62 | mlp_pre.train() 63 | mlp.train() 64 | 65 | # Define loss function helpers 66 | 67 | def nll(logits, y, reduction='mean'): 68 | return nn.functional.binary_cross_entropy_with_logits(logits, y, reduction=reduction) 69 | 70 | def mean_accuracy(logits, y): 71 | preds = (logits > 0.).float() 72 | return ((preds - y).abs() < 1e-2).float().mean() 73 | 74 | def penalty(logits, y): 75 | scale = torch.tensor(1.).cuda().requires_grad_() 76 | loss = nll(logits * scale, y) 77 | grad = autograd.grad(loss, [scale], create_graph=True)[0] 78 | return torch.sum(grad**2) 79 | 80 | # Train loop 81 | 82 | def pretty_print(*values): 83 | col_width = 13 84 | def format_val(v): 85 | if not isinstance(v, str): 86 | v = np.array2string(v, precision=5, floatmode='fixed') 87 | return v.ljust(col_width) 88 | str_values = [format_val(v) for v in values] 89 | print(" ".join(str_values)) 90 | 91 | 92 | pretty_print('step', 'train nll', 'train acc', 'train penalty', 'test acc') 93 | 94 | if flags.eiil: 95 | if flags.color_based: 96 | print('Color-based refernece classifier was specified, to skipping pre-training.') 97 | else: 98 | optimizer_pre = optim.Adam(mlp_pre.parameters(), lr=flags.lr) 99 | for step in range(flags.steps): 100 | for env in envs: 101 | logits = mlp_pre(env['images']) 102 | env['nll'] = nll(logits, env['labels']) 103 | 
env['acc'] = mean_accuracy(logits, env['labels']) 104 | env['penalty'] = penalty(logits, env['labels']) 105 | 106 | train_nll = torch.stack([envs[0]['nll'], envs[1]['nll']]).mean() 107 | train_acc = torch.stack([envs[0]['acc'], envs[1]['acc']]).mean() 108 | train_penalty = torch.stack([envs[0]['penalty'], envs[1]['penalty']]).mean() 109 | 110 | weight_norm = torch.tensor(0.).cuda() 111 | for w in mlp_pre.parameters(): 112 | weight_norm += w.norm().pow(2) 113 | 114 | loss = train_nll.clone() 115 | loss += flags.l2_regularizer_weight * weight_norm 116 | 117 | optimizer_pre.zero_grad() 118 | loss.backward() 119 | optimizer_pre.step() 120 | 121 | test_acc = envs[2]['acc'] 122 | if step % 100 == 0: 123 | pretty_print( 124 | np.int32(step), 125 | train_nll.detach().cpu().numpy(), 126 | train_acc.detach().cpu().numpy(), 127 | train_penalty.detach().cpu().numpy(), 128 | test_acc.detach().cpu().numpy() 129 | ) 130 | torch.save(mlp_pre.state_dict(), 131 | os.path.join(flags.results_dir, 'mlp_pre.%d.p' % restart)) 132 | envs = split_data_opt(envs, mlp_pre) 133 | mlp, final_train_acc, final_test_acc = train_irm_batch(mlp, envs, flags) 134 | final_train_accs.append(final_train_acc) 135 | final_test_accs.append(final_test_acc) 136 | print('Final train acc (mean/std across restarts so far):') 137 | print(np.mean(final_train_accs), np.std(final_train_accs)) 138 | print('Final test acc (mean/std across restarts so far):') 139 | print(np.mean(final_test_accs), np.std(final_test_accs)) 140 | print('done with restart %d' % restart) 141 | torch.save(mlp.state_dict(), 142 | os.path.join(flags.results_dir, 'mlp.%s.p' % restart)) 143 | 144 | print('done with all restarts') 145 | final_train_accs = [t.item() for t in final_train_accs] 146 | final_test_accs = [t.item() for t in final_test_accs] 147 | metrics = {'Train accs': final_train_accs, 148 | 'Test accs': final_test_accs} 149 | with open(os.path.join(flags.results_dir, 'metrics.p'), 'wb') as f: 150 | pickle.dump(metrics, f) 151 | 152 
| print('results are here:') 153 | print(flags.results_dir) 154 | 155 | 156 | if __name__ == '__main__': 157 | parser = argparse.ArgumentParser(description='IRM Colored MNIST') 158 | parser.add_argument('--hidden_dim', type=int, default=256) 159 | parser.add_argument('--l2_regularizer_weight', type=float,default=0.001) 160 | parser.add_argument('--lr', type=float, default=0.001) 161 | parser.add_argument('--n_restarts', type=int, default=1) 162 | parser.add_argument('--penalty_anneal_iters', type=int, default=100) 163 | parser.add_argument('--penalty_weight', type=float, default=10000.0) 164 | parser.add_argument('--steps', type=int, default=5001) 165 | parser.add_argument('--grayscale_model', action='store_true') 166 | parser.add_argument('--eiil', action='store_true') 167 | parser.add_argument('--results_dir', type=str, default='/tmp/opt_env/irm_cmnist', 168 | help='Directory where results should be saved.') 169 | parser.add_argument('--label_noise', type=float, default=0.25) 170 | parser.add_argument('--train_env_1__color_noise', type=float, default=0.2) 171 | parser.add_argument('--train_env_2__color_noise', type=float, default=0.1) 172 | parser.add_argument('--test_env__color_noise', type=float, default=0.9) 173 | parser.add_argument('--color_based', action='store_true') # use color-based reference classifier without trainable params 174 | parser.add_argument('--color_based_eval', action='store_true') # skip IRM phase and evaluate color-based classifier 175 | flags = parser.parse_args() 176 | torch.cuda.set_device(0) 177 | main(flags) 178 | -------------------------------------------------------------------------------- /InvariantRiskMinimization/code/colored_mnist/main_optenv.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import torch 4 | from torchvision import datasets 5 | from torch import nn, optim, autograd 6 | import pdb 7 | from tqdm import tqdm 8 | 9 | parser = 
argparse.ArgumentParser(description='Colored MNIST') 10 | parser.add_argument('--hidden_dim', type=int, default=256) 11 | parser.add_argument('--l2_regularizer_weight', type=float,default=0.001) 12 | parser.add_argument('--lr', type=float, default=0.001) 13 | parser.add_argument('--n_restarts', type=int, default=1) 14 | parser.add_argument('--penalty_anneal_iters', type=int, default=100) 15 | parser.add_argument('--penalty_weight', type=float, default=10000.0) 16 | parser.add_argument('--steps', type=int, default=5001) 17 | parser.add_argument('--grayscale_model', action='store_true') 18 | parser.add_argument('--eiil', action='store_true') 19 | flags = parser.parse_args() 20 | torch.cuda.set_device(0) 21 | 22 | print('Flags:') 23 | for k,v in sorted(vars(flags).items()): 24 | print("\t{}: {}".format(k, v)) 25 | 26 | def split_data_opt(envs, model, n_steps=10000, n_samples=-1): 27 | """Learn soft environment assignment.""" 28 | images = torch.cat((envs[0]['images'][:n_samples],envs[1]['images'][:n_samples]),0) 29 | labels = torch.cat((envs[0]['labels'][:n_samples],envs[1]['labels'][:n_samples]),0) 30 | print('size of pooled envs: '+str(len(images))) 31 | 32 | scale = torch.tensor(1.).cuda().requires_grad_() 33 | logits = model(images) 34 | loss = nll(logits * scale, labels, reduction='none') 35 | 36 | env_w = torch.randn(len(logits)).cuda().requires_grad_() 37 | optimizer = optim.Adam([env_w], lr=0.001) 38 | 39 | print('learning soft environment assignments') 40 | for i in tqdm(range(n_steps)): 41 | # penalty for env a 42 | lossa = (loss.squeeze() * env_w.sigmoid()).mean() 43 | grada = autograd.grad(lossa, [scale], create_graph=True)[0] 44 | penaltya = torch.sum(grada**2) 45 | # penalty for env b 46 | lossb = (loss.squeeze() * (1-env_w.sigmoid())).mean() 47 | gradb = autograd.grad(lossb, [scale], create_graph=True)[0] 48 | penaltyb = torch.sum(gradb**2) 49 | # negate 50 | npenalty = - torch.stack([penaltya, penaltyb]).mean() 51 | 52 | optimizer.zero_grad() 53 | 
npenalty.backward(retain_graph=True) 54 | optimizer.step() 55 | 56 | # split envs based on env_w threshold 57 | new_envs = [] 58 | idx0 = (env_w.sigmoid()>.5) 59 | idx1 = (env_w.sigmoid()<=.5) 60 | # train envs 61 | for idx in (idx0, idx1): 62 | new_envs.append(dict(images=images[idx], labels=labels[idx])) 63 | # test env 64 | new_envs.append(dict(images=envs[-1]['images'], 65 | labels=envs[-1]['labels'])) 66 | print('size of env0: '+str(len(new_envs[0]['images']))) 67 | print('size of env1: '+str(len(new_envs[1]['images']))) 68 | print('size of env2: '+str(len(new_envs[2]['images']))) 69 | return new_envs 70 | 71 | final_train_accs = [] 72 | final_test_accs = [] 73 | for restart in range(flags.n_restarts): 74 | print("Restart", restart) 75 | 76 | # Load MNIST, make train/val splits, and shuffle train set examples 77 | 78 | mnist = datasets.MNIST('~/datasets/mnist', train=True, download=True) 79 | mnist_train = (mnist.data[:50000], mnist.targets[:50000]) 80 | mnist_val = (mnist.data[50000:], mnist.targets[50000:]) 81 | 82 | rng_state = np.random.get_state() 83 | np.random.shuffle(mnist_train[0].numpy()) 84 | np.random.set_state(rng_state) 85 | np.random.shuffle(mnist_train[1].numpy()) 86 | 87 | # Build environments 88 | 89 | def make_environment(images, labels, e): 90 | def torch_bernoulli(p, size): 91 | return (torch.rand(size) < p).float() 92 | def torch_xor(a, b): 93 | return (a-b).abs() # Assumes both inputs are either 0 or 1 94 | # 2x subsample for computational convenience 95 | images = images.reshape((-1, 28, 28))[:, ::2, ::2] 96 | # Assign a binary label based on the digit; flip label with probability 0.25 97 | labels = (labels < 5).float() 98 | labels = torch_xor(labels, torch_bernoulli(0.25, len(labels))) 99 | # Assign a color based on the label; flip the color with probability e 100 | colors = torch_xor(labels, torch_bernoulli(e, len(labels))) 101 | # Apply the color to the image by zeroing out the other color channel 102 | images = torch.stack([images, 
images], dim=1) 103 | images[torch.tensor(range(len(images))), (1-colors).long(), :, :] *= 0 104 | return { 105 | 'images': (images.float() / 255.).cuda(), 106 | 'labels': labels[:, None].cuda() 107 | } 108 | envs = [ 109 | make_environment(mnist_train[0][::2], mnist_train[1][::2], 0.2), 110 | make_environment(mnist_train[0][1::2], mnist_train[1][1::2], 0.1), 111 | make_environment(mnist_val[0], mnist_val[1], 0.9) 112 | ] 113 | 114 | # Define and instantiate the model 115 | 116 | class MLP(nn.Module): 117 | def __init__(self): 118 | super(MLP, self).__init__() 119 | if flags.grayscale_model: 120 | lin1 = nn.Linear(14 * 14, flags.hidden_dim) 121 | else: 122 | lin1 = nn.Linear(2 * 14 * 14, flags.hidden_dim) 123 | lin2 = nn.Linear(flags.hidden_dim, flags.hidden_dim) 124 | lin3 = nn.Linear(flags.hidden_dim, 1) 125 | for lin in [lin1, lin2, lin3]: 126 | nn.init.xavier_uniform_(lin.weight) 127 | nn.init.zeros_(lin.bias) 128 | self._main = nn.Sequential(lin1, nn.ReLU(True), lin2, nn.ReLU(True), lin3) 129 | def forward(self, input): 130 | if flags.grayscale_model: 131 | out = input.view(input.shape[0], 2, 14 * 14).sum(dim=1) 132 | else: 133 | out = input.view(input.shape[0], 2 * 14 * 14) 134 | out = self._main(out) 135 | return out 136 | 137 | mlp_pre = MLP().cuda() 138 | mlp = MLP().cuda() 139 | 140 | # Define loss function helpers 141 | 142 | def nll(logits, y, reduction='mean'): 143 | return nn.functional.binary_cross_entropy_with_logits(logits, y, reduction=reduction) 144 | 145 | def mean_accuracy(logits, y): 146 | preds = (logits > 0.).float() 147 | return ((preds - y).abs() < 1e-2).float().mean() 148 | 149 | def penalty(logits, y): 150 | scale = torch.tensor(1.).cuda().requires_grad_() 151 | loss = nll(logits * scale, y) 152 | grad = autograd.grad(loss, [scale], create_graph=True)[0] 153 | return torch.sum(grad**2) 154 | 155 | # Train loop 156 | 157 | def pretty_print(*values): 158 | col_width = 13 159 | def format_val(v): 160 | if not isinstance(v, str): 161 | v = 
np.array2string(v, precision=5, floatmode='fixed') 162 | return v.ljust(col_width) 163 | str_values = [format_val(v) for v in values] 164 | print(" ".join(str_values)) 165 | 166 | optimizer_pre = optim.Adam(mlp_pre.parameters(), lr=flags.lr) 167 | optimizer = optim.Adam(mlp.parameters(), lr=flags.lr) 168 | 169 | pretty_print('step', 'train nll', 'train acc', 'train penalty', 'test acc') 170 | 171 | if flags.eiil: 172 | # Pre-train reference model 173 | for step in range(flags.steps): 174 | for env in envs: 175 | logits = mlp_pre(env['images']) 176 | env['nll'] = nll(logits, env['labels']) 177 | env['acc'] = mean_accuracy(logits, env['labels']) 178 | env['penalty'] = penalty(logits, env['labels']) 179 | 180 | train_nll = torch.stack([envs[0]['nll'], envs[1]['nll']]).mean() 181 | train_acc = torch.stack([envs[0]['acc'], envs[1]['acc']]).mean() 182 | train_penalty = torch.stack([envs[0]['penalty'], envs[1]['penalty']]).mean() 183 | 184 | weight_norm = torch.tensor(0.).cuda() 185 | for w in mlp_pre.parameters(): 186 | weight_norm += w.norm().pow(2) 187 | 188 | loss = train_nll.clone() 189 | loss += flags.l2_regularizer_weight * weight_norm 190 | # NOTE: IRM penalties not used in pre-training 191 | #penalty_weight = (flags.penalty_weight 192 | # if step >= flags.penalty_anneal_iters else 1.0) 193 | #loss += penalty_weight * train_penalty 194 | #if penalty_weight > 1.0: 195 | # # Rescale the entire loss to keep gradients in a reasonable range 196 | # loss /= penalty_weight 197 | 198 | optimizer_pre.zero_grad() 199 | loss.backward() 200 | optimizer_pre.step() 201 | 202 | test_acc = envs[2]['acc'] 203 | if step % 100 == 0: 204 | pretty_print( 205 | np.int32(step), 206 | train_nll.detach().cpu().numpy(), 207 | train_acc.detach().cpu().numpy(), 208 | train_penalty.detach().cpu().numpy(), 209 | test_acc.detach().cpu().numpy() 210 | ) 211 | 212 | envs = split_data_opt(envs, mlp_pre) 213 | 214 | for step in range(flags.steps): 215 | for env in envs: 216 | logits = 
mlp(env['images']) 217 | env['nll'] = nll(logits, env['labels']) 218 | env['acc'] = mean_accuracy(logits, env['labels']) 219 | env['penalty'] = penalty(logits, env['labels']) 220 | 221 | train_nll = torch.stack([envs[0]['nll'], envs[1]['nll']]).mean() 222 | train_acc = torch.stack([envs[0]['acc'], envs[1]['acc']]).mean() 223 | train_penalty = torch.stack([envs[0]['penalty'], envs[1]['penalty']]).mean() 224 | 225 | weight_norm = torch.tensor(0.).cuda() 226 | for w in mlp.parameters(): 227 | weight_norm += w.norm().pow(2) 228 | 229 | loss = train_nll.clone() 230 | loss += flags.l2_regularizer_weight * weight_norm 231 | penalty_weight = (flags.penalty_weight 232 | if step >= flags.penalty_anneal_iters else 1.0) 233 | loss += penalty_weight * train_penalty 234 | if penalty_weight > 1.0: 235 | # Rescale the entire loss to keep gradients in a reasonable range 236 | loss /= penalty_weight 237 | 238 | optimizer.zero_grad() 239 | loss.backward() 240 | optimizer.step() 241 | 242 | test_acc = envs[2]['acc'] 243 | if step % 100 == 0: 244 | pretty_print( 245 | np.int32(step), 246 | train_nll.detach().cpu().numpy(), 247 | train_acc.detach().cpu().numpy(), 248 | train_penalty.detach().cpu().numpy(), 249 | test_acc.detach().cpu().numpy() 250 | ) 251 | 252 | final_train_accs.append(train_acc.detach().cpu().numpy()) 253 | final_test_accs.append(test_acc.detach().cpu().numpy()) 254 | print('Final train acc (mean/std across restarts so far):') 255 | print(np.mean(final_train_accs), np.std(final_train_accs)) 256 | print('Final test acc (mean/std across restarts so far):') 257 | print(np.mean(final_test_accs), np.std(final_test_accs)) 258 | -------------------------------------------------------------------------------- /InvariantRiskMinimization/code/experiment_synthetic/models.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 
def pretty(vector):
    """Render a tensor as a bracketed, comma-separated list of signed floats.

    The tensor is flattened with ``view(-1)`` and every entry is formatted
    with ``{:+.4f}``, e.g. ``[+1.0000, -2.5000]``.
    """
    vlist = vector.view(-1).tolist()
    return "[" + ", ".join("{:+.4f}".format(vi) for vi in vlist) + "]"


class InvariantRiskMinimization(object):
    """Linear IRM with a learned square representation ``phi``.

    ``environments`` is a sequence of ``(x, y)`` tensor pairs where ``x`` is
    2-D with features on axis 1.  The last pair is held out and only used to
    choose the regularization strength; all others are used for training.
    ``args`` is a dict read for the keys ``"lr"``, ``"n_iterations"`` and
    ``"verbose"``.
    """

    # Candidate values r of the sweep; train() is called with reg = 1 - r.
    _REG_GRID = (0, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1)

    def __init__(self, environments, args):
        x_val = environments[-1][0]
        y_val = environments[-1][1]
        self._select_best_phi(environments[:-1], x_val, y_val, args, "IRM")

    def _select_best_phi(self, train_envs, x_val, y_val, args, label):
        """Train once per grid value and keep the ``phi`` with the lowest
        mean squared error on ``(x_val, y_val)``.

        ``label`` only prefixes the verbose log line.
        """
        best_err = float("inf")
        best_phi = None

        for reg in self._REG_GRID:
            reg = 1. - reg  # change of variables for consistency with old codebase
            self.train(train_envs, args, reg=reg)
            err = (x_val @ self.solution() - y_val).pow(2).mean().item()

            if args["verbose"]:
                print("{} (reg={:.6f}) has {:.3f} validation error.".format(
                    label, reg, err))

            if err < best_err:
                best_err = err
                best_phi = self.phi.clone()

        if best_phi is None:
            # No run produced a comparable error (e.g. every error was NaN).
            # Fall back to the last trained phi instead of failing on an
            # unbound variable.
            best_phi = self.phi.clone()
        self.phi = best_phi

    def train(self, environments, args, reg=0):
        """Fit ``self.phi`` with Adam on ``(1 - reg) * error + reg * penalty``.

        ``error`` is the sum over environments of the MSE of
        ``x @ phi @ w``; ``penalty`` is the sum of the mean squared gradient
        of each environment's error with respect to the fixed all-ones
        vector ``w``.  ``phi`` is re-initialized to the identity and ``w`` to
        ones on every call, so ``reg=0`` is plain error minimization and
        ``reg=1`` optimizes the penalty only.
        """
        print('learning representation with', self, 'and reg', reg)
        dim_x = environments[0][0].size(1)

        self.phi = torch.nn.Parameter(torch.eye(dim_x, dim_x))
        self.w = torch.ones(dim_x, 1)
        # w is never handed to the optimizer; it needs grad only so that the
        # penalty can differentiate the risk with respect to it.
        self.w.requires_grad = True

        opt = torch.optim.Adam([self.phi], lr=args["lr"])
        loss = torch.nn.MSELoss()

        for iteration in range(args["n_iterations"]):
            penalty = 0
            error = 0
            for x_e, y_e in environments:
                error_e = loss(x_e @ self.phi @ self.w, y_e)
                # create_graph=True keeps the gradient differentiable so the
                # penalty itself can be back-propagated into phi.
                penalty += grad(error_e, self.w,
                                create_graph=True)[0].pow(2).mean()
                error += error_e

            opt.zero_grad()
            ((1 - reg) * error + reg * penalty).backward()
            opt.step()

            if args["verbose"] and iteration % 1000 == 0:
                w_str = pretty(self.solution())
                print("{:05d} | {:.5f} | {:.5f} | {:.5f} | {}".format(iteration,
                                                                      reg,
                                                                      error,
                                                                      penalty,
                                                                      w_str))

    def solution(self):
        """Return the effective linear predictor ``phi @ w``."""
        return self.phi @ self.w


class LearnedEnvInvariantRiskMinimization(InvariantRiskMinimization):
    """IRM trained on environments inferred from a reference classifier.

    The training data is pooled, split into two inferred environments by
    ``split`` using a fixed reference classifier, and the usual IRM
    regularization sweep is then run on the inferred environments.

    In addition to the keys used by the parent class, ``args`` is read for
    ``"eiil_ref_alpha"``: a value in ``[0, 1]`` selects a hand-built reference
    classifier, anything else selects the ERM solution.
    ``pretrain`` is accepted for interface compatibility and is not used.
    """

    def __init__(self, environments, args, pretrain=False):
        x_val = environments[-1][0]
        y_val = environments[-1][1]
        dim_x = environments[0][0].size(1)

        if args['eiil_ref_alpha'] >= 0 and args['eiil_ref_alpha'] <= 1:
            print('Using hard-coded reference classifier with alpha={:.2f}'.format(
                args['eiil_ref_alpha']
            ))
            alpha = args['eiil_ref_alpha']
            # Weight (1 - alpha) on the first half of the features and alpha
            # on the rest; sizes follow dim_x instead of a fixed 5 + 5.
            # NOTE(review): assumes the first half of the columns are the
            # causal features - confirm against the data generator.
            n_causal = dim_x // 2
            n_noncausal = dim_x - n_causal
            w_causal = torch.full((1, n_causal), 1. - alpha,
                                  dtype=torch.float32)
            # spurious contribution to prediction
            w_noncausal = torch.full((1, n_noncausal), alpha,
                                     dtype=torch.float32)
            w_ref = torch.cat((w_causal, w_noncausal), 1)
        else:
            print('Using ERM soln as reference classifier.')
            # NOTE(review): ERM is fit on every environment passed in,
            # including the held-out last one - confirm this is intended.
            w_ref = EmpiricalRiskMinimizer(environments, args).solution()

        # Encode the reference classifier as a diagonal phi so that
        # phi @ ones reproduces w_ref.
        self.phi = torch.nn.Parameter(torch.diag(w_ref.squeeze()))
        self.w = torch.ones(dim_x, 1)
        self.w.requires_grad = True
        err = (x_val @ self.solution() - y_val).pow(2).mean().item()

        if args["verbose"]:
            print("EIIL's reference classifier has {:.3f} validation error.".format(
                err))
            print("EIIL's reference classifier has the following solution:\n.",
                  pretty(self.solution()))

        self.phi = self.phi.clone()

        environments = self.split(environments, args)
        if args["verbose"]:
            print("EIIL+ERM ref clf still has the following solution after AED (sanity check):\n.", pretty(self.solution()))

        self._select_best_phi(environments[:-1], x_val, y_val, args, "EIIL+IRM")

    @staticmethod
    def _pool_training_envs(environments, n_samples=-1):
        """Concatenate the training environments (all but the last).

        A negative or ``None`` ``n_samples`` keeps every sample; otherwise
        only the first ``n_samples`` rows of each environment are kept.
        Returns the pooled ``(x, y)`` tensors.
        """
        train_envs = environments[:-1]
        if n_samples is None or n_samples < 0:
            xs = [env[0] for env in train_envs]
            ys = [env[1] for env in train_envs]
        else:
            xs = [env[0][:n_samples] for env in train_envs]
            ys = [env[1][:n_samples] for env in train_envs]
        return torch.cat(xs, 0), torch.cat(ys, 0)

    def split(self, environments, args, n_samples=-1):
        """Learn a soft environment assignment and return the hard split.

        One logit per pooled training sample is optimized with Adam
        (lr=0.001) for ``args["n_iterations"]`` steps to maximize the mean of
        the two per-environment gradient penalties of the current
        ``phi``/``w``.  Samples whose sigmoid weight exceeds 0.5 form the
        first environment, the others the second.

        ``n_samples`` limits how many samples are taken from each training
        environment; the default ``-1`` means "use all of them".

        Returns ``[env_0, env_1, test_env]`` where ``test_env`` is the
        unchanged last entry of ``environments``.
        """
        test_env = environments[-1]
        x, y = self._pool_training_envs(environments, n_samples)
        print('size of pooled envs: '+str(len(x)))

        loss = torch.nn.MSELoss(reduction='none')
        # Per-sample error of the reference classifier; computed once and
        # reused in every iteration (hence retain_graph=True below).
        error = loss(x @ self.phi @ self.w, y)

        env_w = torch.randn(len(error)).requires_grad_()
        optimizer = torch.optim.Adam([env_w], lr=0.001)

        print('learning soft environment assignments')
        with tqdm(total=args['n_iterations'],
                  position=1,
                  bar_format='{desc}',
                  desc='negative penalty: ',
                  ) as desc:
            for _ in tqdm(range(args["n_iterations"])):
                prob_a = env_w.sigmoid()
                # penalty for env a
                error_a = (error.squeeze() * prob_a).mean()
                penalty_a = grad(error_a, self.w, create_graph=True)[0].pow(2).mean()
                # penalty for env b
                error_b = (error.squeeze() * (1 - prob_a)).mean()
                penalty_b = grad(error_b, self.w, create_graph=True)[0].pow(2).mean()
                # negate: the optimizer minimizes, we want the largest penalty
                npenalty = - torch.stack([penalty_a, penalty_b]).mean()
                desc.set_description('negative penalty: '+ str(npenalty))

                optimizer.zero_grad()
                npenalty.backward(retain_graph=True)
                optimizer.step()

        final_prob = env_w.sigmoid()
        idx0 = (final_prob > .5)
        idx1 = (final_prob <= .5)
        envs = []
        envs.append((x[idx0], y[idx0]))
        print('size of env 0: '+str(len(x[idx0])))
        envs.append((x[idx1], y[idx1]))
        print('size of env 1: '+str(len(x[idx1])))
        print('weights: '+str(final_prob))
        envs.append(test_env)
        return envs
class InvariantCausalPrediction(object):
    """Feature selection by testing residual invariance across environments.

    For every non-empty subset of feature columns a no-intercept linear
    regression is fit on the pooled data.  The subset is accepted when, for
    each environment, the residuals inside and outside that environment pass
    ``mean_var_test`` at level ``args["alpha"]`` (after multiplying the
    smallest p-value by the number of environments).  The final coefficients
    come from a regression on the intersection of all accepted subsets and
    are zero for every other feature.
    """

    def __init__(self, environments, args):
        self.coefficients = None
        self.alpha = args["alpha"]

        # Pool the data and remember which environment each row came from.
        feature_blocks, target_blocks, env_labels = [], [], []
        for env_index, (x, y) in enumerate(environments):
            feature_blocks.append(x.numpy())
            target_blocks.append(y.numpy())
            env_labels.append(np.full(x.shape[0], env_index))

        x_all = np.vstack(feature_blocks)
        y_all = np.vstack(target_blocks)
        e_all = np.hstack(env_labels)

        dim = x_all.shape[1]
        n_envs = len(environments)

        accepted_subsets = []
        for subset in self.powerset(range(dim)):
            if not subset:
                continue

            x_s = x_all[:, subset]
            reg = LinearRegression(fit_intercept=False).fit(x_s, y_all)

            p_values = []
            for env_index in range(n_envs):
                rows_in = np.where(e_all == env_index)[0]
                rows_out = np.where(e_all != env_index)[0]

                res_in = (y_all[rows_in] - reg.predict(x_s[rows_in, :])).ravel()
                res_out = (y_all[rows_out] - reg.predict(x_s[rows_out, :])).ravel()

                p_values.append(self.mean_var_test(res_in, res_out))

            # TODO: Jonas uses "min(p_values) * len(environments) - 1"
            p_value = min(p_values) * n_envs

            if not p_value > self.alpha:
                continue

            accepted_subsets.append(set(subset))
            if args["verbose"]:
                print("Accepted subset:", subset)

        if not accepted_subsets:
            self.coefficients = torch.zeros(dim)
            return

        accepted_features = list(set.intersection(*accepted_subsets))
        if args["verbose"]:
            print("Intersection:", accepted_features)

        coefficients = np.zeros(dim)
        if accepted_features:
            x_s = x_all[:, list(accepted_features)]
            reg = LinearRegression(fit_intercept=False).fit(x_s, y_all)
            coefficients[list(accepted_features)] = reg.coef_

        self.coefficients = torch.Tensor(coefficients)

    def mean_var_test(self, x, y):
        """Return a combined p-value for equal mean and equal variance.

        The mean is compared with an unequal-variance ``ttest_ind``; the
        variance with a two-sided test on the ratio of sample variances
        (``ddof=1``) under an F distribution.  The result is twice the
        smaller of the two p-values.
        """
        pvalue_mean = ttest_ind(x, y, equal_var=False).pvalue

        variance_ratio = np.var(x, ddof=1) / np.var(y, ddof=1)
        dof_x = x.shape[0] - 1
        dof_y = y.shape[0] - 1
        upper_tail = 1 - fdist.cdf(variance_ratio, dof_x, dof_y)

        pvalue_var = 2 * min(upper_tail, 1 - upper_tail)

        return 2 * min(pvalue_mean, pvalue_var)

    def powerset(self, s):
        """Lazily yield every subset of ``s`` as a tuple, smallest first
        (starting with the empty tuple)."""
        subset_sizes = range(len(s) + 1)
        return chain.from_iterable(
            combinations(s, size) for size in subset_sizes)

    def solution(self):
        """Return the coefficient tensor computed by the constructor."""
        return self.coefficients


class EmpiricalRiskMinimizer(object):
    """No-intercept least squares fit on the data of all environments pooled."""

    def __init__(self, environments, args):
        x_all = torch.cat([x for (x, _) in environments]).numpy()
        y_all = torch.cat([y for (_, y) in environments]).numpy()

        fitted = LinearRegression(fit_intercept=False).fit(x_all, y_all)
        self.w = torch.Tensor(fitted.coef_)

        if not args['verbose']:
            return

        print('Done training ERM.')
        residual = x_all.dot(self.solution().T) - y_all
        err = np.mean(residual ** 2.).item()
        print("ERM has {:.3f} train error.".format(err))
        print("ERM has the following solution:\n ", pretty(self.solution()))

    def solution(self):
        """Return the fitted coefficients as a tensor."""
        return self.w
Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. 
Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. More_considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution-NonCommercial 4.0 International Public 58 | License 59 | 60 | By exercising the Licensed Rights (defined below), You accept and agree 61 | to be bound by the terms and conditions of this Creative Commons 62 | Attribution-NonCommercial 4.0 International Public License ("Public 63 | License"). To the extent this Public License may be interpreted as a 64 | contract, You are granted the Licensed Rights in consideration of Your 65 | acceptance of these terms and conditions, and the Licensor grants You 66 | such rights in consideration of benefits the Licensor receives from 67 | making the Licensed Material available under these terms and 68 | conditions. 69 | 70 | Section 1 -- Definitions. 71 | 72 | a. Adapted Material means material subject to Copyright and Similar 73 | Rights that is derived from or based upon the Licensed Material 74 | and in which the Licensed Material is translated, altered, 75 | arranged, transformed, or otherwise modified in a manner requiring 76 | permission under the Copyright and Similar Rights held by the 77 | Licensor. 
For purposes of this Public License, where the Licensed 78 | Material is a musical work, performance, or sound recording, 79 | Adapted Material is always produced where the Licensed Material is 80 | synched in timed relation with a moving image. 81 | 82 | b. Adapter's License means the license You apply to Your Copyright 83 | and Similar Rights in Your contributions to Adapted Material in 84 | accordance with the terms and conditions of this Public License. 85 | 86 | c. Copyright and Similar Rights means copyright and/or similar rights 87 | closely related to copyright including, without limitation, 88 | performance, broadcast, sound recording, and Sui Generis Database 89 | Rights, without regard to how the rights are labeled or 90 | categorized. For purposes of this Public License, the rights 91 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 92 | Rights. 93 | d. Effective Technological Measures means those measures that, in the 94 | absence of proper authority, may not be circumvented under laws 95 | fulfilling obligations under Article 11 of the WIPO Copyright 96 | Treaty adopted on December 20, 1996, and/or similar international 97 | agreements. 98 | 99 | e. Exceptions and Limitations means fair use, fair dealing, and/or 100 | any other exception or limitation to Copyright and Similar Rights 101 | that applies to Your use of the Licensed Material. 102 | 103 | f. Licensed Material means the artistic or literary work, database, 104 | or other material to which the Licensor applied this Public 105 | License. 106 | 107 | g. Licensed Rights means the rights granted to You subject to the 108 | terms and conditions of this Public License, which are limited to 109 | all Copyright and Similar Rights that apply to Your use of the 110 | Licensed Material and that the Licensor has authority to license. 111 | 112 | h. Licensor means the individual(s) or entity(ies) granting rights 113 | under this Public License. 114 | 115 | i. 
NonCommercial means not primarily intended for or directed towards 116 | commercial advantage or monetary compensation. For purposes of 117 | this Public License, the exchange of the Licensed Material for 118 | other material subject to Copyright and Similar Rights by digital 119 | file-sharing or similar means is NonCommercial provided there is 120 | no payment of monetary compensation in connection with the 121 | exchange. 122 | 123 | j. Share means to provide material to the public by any means or 124 | process that requires permission under the Licensed Rights, such 125 | as reproduction, public display, public performance, distribution, 126 | dissemination, communication, or importation, and to make material 127 | available to the public including in ways that members of the 128 | public may access the material from a place and at a time 129 | individually chosen by them. 130 | 131 | k. Sui Generis Database Rights means rights other than copyright 132 | resulting from Directive 96/9/EC of the European Parliament and of 133 | the Council of 11 March 1996 on the legal protection of databases, 134 | as amended and/or succeeded, as well as other essentially 135 | equivalent rights anywhere in the world. 136 | 137 | l. You means the individual or entity exercising the Licensed Rights 138 | under this Public License. Your has a corresponding meaning. 139 | 140 | Section 2 -- Scope. 141 | 142 | a. License grant. 143 | 144 | 1. Subject to the terms and conditions of this Public License, 145 | the Licensor hereby grants You a worldwide, royalty-free, 146 | non-sublicensable, non-exclusive, irrevocable license to 147 | exercise the Licensed Rights in the Licensed Material to: 148 | 149 | a. reproduce and Share the Licensed Material, in whole or 150 | in part, for NonCommercial purposes only; and 151 | 152 | b. produce, reproduce, and Share Adapted Material for 153 | NonCommercial purposes only. 154 | 155 | 2. Exceptions and Limitations. 
For the avoidance of doubt, where 156 | Exceptions and Limitations apply to Your use, this Public 157 | License does not apply, and You do not need to comply with 158 | its terms and conditions. 159 | 160 | 3. Term. The term of this Public License is specified in Section 161 | 6(a). 162 | 163 | 4. Media and formats; technical modifications allowed. The 164 | Licensor authorizes You to exercise the Licensed Rights in 165 | all media and formats whether now known or hereafter created, 166 | and to make technical modifications necessary to do so. The 167 | Licensor waives and/or agrees not to assert any right or 168 | authority to forbid You from making technical modifications 169 | necessary to exercise the Licensed Rights, including 170 | technical modifications necessary to circumvent Effective 171 | Technological Measures. For purposes of this Public License, 172 | simply making modifications authorized by this Section 2(a) 173 | (4) never produces Adapted Material. 174 | 175 | 5. Downstream recipients. 176 | 177 | a. Offer from the Licensor -- Licensed Material. Every 178 | recipient of the Licensed Material automatically 179 | receives an offer from the Licensor to exercise the 180 | Licensed Rights under the terms and conditions of this 181 | Public License. 182 | 183 | b. No downstream restrictions. You may not offer or impose 184 | any additional or different terms or conditions on, or 185 | apply any Effective Technological Measures to, the 186 | Licensed Material if doing so restricts exercise of the 187 | Licensed Rights by any recipient of the Licensed 188 | Material. 189 | 190 | 6. No endorsement. Nothing in this Public License constitutes or 191 | may be construed as permission to assert or imply that You 192 | are, or that Your use of the Licensed Material is, connected 193 | with, or sponsored, endorsed, or granted official status by, 194 | the Licensor or others designated to receive attribution as 195 | provided in Section 3(a)(1)(A)(i). 
196 | 197 | b. Other rights. 198 | 199 | 1. Moral rights, such as the right of integrity, are not 200 | licensed under this Public License, nor are publicity, 201 | privacy, and/or other similar personality rights; however, to 202 | the extent possible, the Licensor waives and/or agrees not to 203 | assert any such rights held by the Licensor to the limited 204 | extent necessary to allow You to exercise the Licensed 205 | Rights, but not otherwise. 206 | 207 | 2. Patent and trademark rights are not licensed under this 208 | Public License. 209 | 210 | 3. To the extent possible, the Licensor waives any right to 211 | collect royalties from You for the exercise of the Licensed 212 | Rights, whether directly or through a collecting society 213 | under any voluntary or waivable statutory or compulsory 214 | licensing scheme. In all other cases the Licensor expressly 215 | reserves any right to collect such royalties, including when 216 | the Licensed Material is used other than for NonCommercial 217 | purposes. 218 | 219 | Section 3 -- License Conditions. 220 | 221 | Your exercise of the Licensed Rights is expressly made subject to the 222 | following conditions. 223 | 224 | a. Attribution. 225 | 226 | 1. If You Share the Licensed Material (including in modified 227 | form), You must: 228 | 229 | a. retain the following if it is supplied by the Licensor 230 | with the Licensed Material: 231 | 232 | i. identification of the creator(s) of the Licensed 233 | Material and any others designated to receive 234 | attribution, in any reasonable manner requested by 235 | the Licensor (including by pseudonym if 236 | designated); 237 | 238 | ii. a copyright notice; 239 | 240 | iii. a notice that refers to this Public License; 241 | 242 | iv. a notice that refers to the disclaimer of 243 | warranties; 244 | 245 | v. a URI or hyperlink to the Licensed Material to the 246 | extent reasonably practicable; 247 | 248 | b. 
indicate if You modified the Licensed Material and 249 | retain an indication of any previous modifications; and 250 | 251 | c. indicate the Licensed Material is licensed under this 252 | Public License, and include the text of, or the URI or 253 | hyperlink to, this Public License. 254 | 255 | 2. You may satisfy the conditions in Section 3(a)(1) in any 256 | reasonable manner based on the medium, means, and context in 257 | which You Share the Licensed Material. For example, it may be 258 | reasonable to satisfy the conditions by providing a URI or 259 | hyperlink to a resource that includes the required 260 | information. 261 | 262 | 3. If requested by the Licensor, You must remove any of the 263 | information required by Section 3(a)(1)(A) to the extent 264 | reasonably practicable. 265 | 266 | 4. If You Share Adapted Material You produce, the Adapter's 267 | License You apply must not prevent recipients of the Adapted 268 | Material from complying with this Public License. 269 | 270 | Section 4 -- Sui Generis Database Rights. 271 | 272 | Where the Licensed Rights include Sui Generis Database Rights that 273 | apply to Your use of the Licensed Material: 274 | 275 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 276 | to extract, reuse, reproduce, and Share all or a substantial 277 | portion of the contents of the database for NonCommercial purposes 278 | only; 279 | 280 | b. if You include all or a substantial portion of the database 281 | contents in a database in which You have Sui Generis Database 282 | Rights, then the database in which You have Sui Generis Database 283 | Rights (but not its individual contents) is Adapted Material; and 284 | 285 | c. You must comply with the conditions in Section 3(a) if You Share 286 | all or a substantial portion of the contents of the database. 
287 | 288 | For the avoidance of doubt, this Section 4 supplements and does not 289 | replace Your obligations under this Public License where the Licensed 290 | Rights include other Copyright and Similar Rights. 291 | 292 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 293 | 294 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 295 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 296 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 297 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 298 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 299 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 300 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 301 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 302 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 303 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 304 | 305 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 306 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 307 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 308 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 309 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 310 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 311 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 312 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 313 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 314 | 315 | c. The disclaimer of warranties and limitation of liability provided 316 | above shall be interpreted in a manner that, to the extent 317 | possible, most closely approximates an absolute disclaimer and 318 | waiver of all liability. 319 | 320 | Section 6 -- Term and Termination. 321 | 322 | a. 
This Public License applies for the term of the Copyright and 323 | Similar Rights licensed here. However, if You fail to comply with 324 | this Public License, then Your rights under this Public License 325 | terminate automatically. 326 | 327 | b. Where Your right to use the Licensed Material has terminated under 328 | Section 6(a), it reinstates: 329 | 330 | 1. automatically as of the date the violation is cured, provided 331 | it is cured within 30 days of Your discovery of the 332 | violation; or 333 | 334 | 2. upon express reinstatement by the Licensor. 335 | 336 | For the avoidance of doubt, this Section 6(b) does not affect any 337 | right the Licensor may have to seek remedies for Your violations 338 | of this Public License. 339 | 340 | c. For the avoidance of doubt, the Licensor may also offer the 341 | Licensed Material under separate terms or conditions or stop 342 | distributing the Licensed Material at any time; however, doing so 343 | will not terminate this Public License. 344 | 345 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 346 | License. 347 | 348 | Section 7 -- Other Terms and Conditions. 349 | 350 | a. The Licensor shall not be bound by any additional or different 351 | terms or conditions communicated by You unless expressly agreed. 352 | 353 | b. Any arrangements, understandings, or agreements regarding the 354 | Licensed Material not stated herein are separate from and 355 | independent of the terms and conditions of this Public License. 356 | 357 | Section 8 -- Interpretation. 358 | 359 | a. For the avoidance of doubt, this Public License does not, and 360 | shall not be interpreted to, reduce, limit, restrict, or impose 361 | conditions on any use of the Licensed Material that could lawfully 362 | be made without permission under this Public License. 363 | 364 | b. 
To the extent possible, if any provision of this Public License is 365 | deemed unenforceable, it shall be automatically reformed to the 366 | minimum extent necessary to make it enforceable. If the provision 367 | cannot be reformed, it shall be severed from this Public License 368 | without affecting the enforceability of the remaining terms and 369 | conditions. 370 | 371 | c. No term or condition of this Public License will be waived and no 372 | failure to comply consented to unless expressly agreed to by the 373 | Licensor. 374 | 375 | d. Nothing in this Public License constitutes or may be interpreted 376 | as a limitation upon, or waiver of, any privileges and immunities 377 | that apply to the Licensor or You, including from the legal 378 | processes of any jurisdiction or authority. 379 | 380 | ======================================================================= 381 | 382 | Creative Commons is not a party to its public 383 | licenses. Notwithstanding, Creative Commons may elect to apply one of 384 | its public licenses to material it publishes and in those instances 385 | will be considered the “Licensor.” The text of the Creative Commons 386 | public licenses is dedicated to the public domain under the CC0 Public 387 | Domain Dedication. Except for the limited purpose of indicating that 388 | material is shared under a Creative Commons public license or as 389 | otherwise permitted by the Creative Commons policies published at 390 | creativecommons.org/policies, Creative Commons does not authorize the 391 | use of the trademark "Creative Commons" or any other trademark or logo 392 | of Creative Commons without its prior written consent including, 393 | without limitation, in connection with any unauthorized modifications 394 | to any of its public licenses or any other arrangements, 395 | understandings, or agreements concerning use of licensed material. For 396 | the avoidance of doubt, this paragraph does not form part of the 397 | public licenses. 
398 | 399 | Creative Commons may be contacted at creativecommons.org. 400 | 401 | -------------------------------------------------------------------------------- /InvariantRiskMinimization/LICENSE: -------------------------------------------------------------------------------- 1 | Attribution-NonCommercial 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 
30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. More_considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution-NonCommercial 4.0 International Public 58 | License 59 | 60 | By exercising the Licensed Rights (defined below), You accept and agree 61 | to be bound by the terms and conditions of this Creative Commons 62 | Attribution-NonCommercial 4.0 International Public License ("Public 63 | License"). 
To the extent this Public License may be interpreted as a 64 | contract, You are granted the Licensed Rights in consideration of Your 65 | acceptance of these terms and conditions, and the Licensor grants You 66 | such rights in consideration of benefits the Licensor receives from 67 | making the Licensed Material available under these terms and 68 | conditions. 69 | 70 | Section 1 -- Definitions. 71 | 72 | a. Adapted Material means material subject to Copyright and Similar 73 | Rights that is derived from or based upon the Licensed Material 74 | and in which the Licensed Material is translated, altered, 75 | arranged, transformed, or otherwise modified in a manner requiring 76 | permission under the Copyright and Similar Rights held by the 77 | Licensor. For purposes of this Public License, where the Licensed 78 | Material is a musical work, performance, or sound recording, 79 | Adapted Material is always produced where the Licensed Material is 80 | synched in timed relation with a moving image. 81 | 82 | b. Adapter's License means the license You apply to Your Copyright 83 | and Similar Rights in Your contributions to Adapted Material in 84 | accordance with the terms and conditions of this Public License. 85 | 86 | c. Copyright and Similar Rights means copyright and/or similar rights 87 | closely related to copyright including, without limitation, 88 | performance, broadcast, sound recording, and Sui Generis Database 89 | Rights, without regard to how the rights are labeled or 90 | categorized. For purposes of this Public License, the rights 91 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 92 | Rights. 93 | d. Effective Technological Measures means those measures that, in the 94 | absence of proper authority, may not be circumvented under laws 95 | fulfilling obligations under Article 11 of the WIPO Copyright 96 | Treaty adopted on December 20, 1996, and/or similar international 97 | agreements. 98 | 99 | e. 
Exceptions and Limitations means fair use, fair dealing, and/or 100 | any other exception or limitation to Copyright and Similar Rights 101 | that applies to Your use of the Licensed Material. 102 | 103 | f. Licensed Material means the artistic or literary work, database, 104 | or other material to which the Licensor applied this Public 105 | License. 106 | 107 | g. Licensed Rights means the rights granted to You subject to the 108 | terms and conditions of this Public License, which are limited to 109 | all Copyright and Similar Rights that apply to Your use of the 110 | Licensed Material and that the Licensor has authority to license. 111 | 112 | h. Licensor means the individual(s) or entity(ies) granting rights 113 | under this Public License. 114 | 115 | i. NonCommercial means not primarily intended for or directed towards 116 | commercial advantage or monetary compensation. For purposes of 117 | this Public License, the exchange of the Licensed Material for 118 | other material subject to Copyright and Similar Rights by digital 119 | file-sharing or similar means is NonCommercial provided there is 120 | no payment of monetary compensation in connection with the 121 | exchange. 122 | 123 | j. Share means to provide material to the public by any means or 124 | process that requires permission under the Licensed Rights, such 125 | as reproduction, public display, public performance, distribution, 126 | dissemination, communication, or importation, and to make material 127 | available to the public including in ways that members of the 128 | public may access the material from a place and at a time 129 | individually chosen by them. 130 | 131 | k. Sui Generis Database Rights means rights other than copyright 132 | resulting from Directive 96/9/EC of the European Parliament and of 133 | the Council of 11 March 1996 on the legal protection of databases, 134 | as amended and/or succeeded, as well as other essentially 135 | equivalent rights anywhere in the world. 
136 | 137 | l. You means the individual or entity exercising the Licensed Rights 138 | under this Public License. Your has a corresponding meaning. 139 | 140 | Section 2 -- Scope. 141 | 142 | a. License grant. 143 | 144 | 1. Subject to the terms and conditions of this Public License, 145 | the Licensor hereby grants You a worldwide, royalty-free, 146 | non-sublicensable, non-exclusive, irrevocable license to 147 | exercise the Licensed Rights in the Licensed Material to: 148 | 149 | a. reproduce and Share the Licensed Material, in whole or 150 | in part, for NonCommercial purposes only; and 151 | 152 | b. produce, reproduce, and Share Adapted Material for 153 | NonCommercial purposes only. 154 | 155 | 2. Exceptions and Limitations. For the avoidance of doubt, where 156 | Exceptions and Limitations apply to Your use, this Public 157 | License does not apply, and You do not need to comply with 158 | its terms and conditions. 159 | 160 | 3. Term. The term of this Public License is specified in Section 161 | 6(a). 162 | 163 | 4. Media and formats; technical modifications allowed. The 164 | Licensor authorizes You to exercise the Licensed Rights in 165 | all media and formats whether now known or hereafter created, 166 | and to make technical modifications necessary to do so. The 167 | Licensor waives and/or agrees not to assert any right or 168 | authority to forbid You from making technical modifications 169 | necessary to exercise the Licensed Rights, including 170 | technical modifications necessary to circumvent Effective 171 | Technological Measures. For purposes of this Public License, 172 | simply making modifications authorized by this Section 2(a) 173 | (4) never produces Adapted Material. 174 | 175 | 5. Downstream recipients. 176 | 177 | a. Offer from the Licensor -- Licensed Material. 
Every 178 | recipient of the Licensed Material automatically 179 | receives an offer from the Licensor to exercise the 180 | Licensed Rights under the terms and conditions of this 181 | Public License. 182 | 183 | b. No downstream restrictions. You may not offer or impose 184 | any additional or different terms or conditions on, or 185 | apply any Effective Technological Measures to, the 186 | Licensed Material if doing so restricts exercise of the 187 | Licensed Rights by any recipient of the Licensed 188 | Material. 189 | 190 | 6. No endorsement. Nothing in this Public License constitutes or 191 | may be construed as permission to assert or imply that You 192 | are, or that Your use of the Licensed Material is, connected 193 | with, or sponsored, endorsed, or granted official status by, 194 | the Licensor or others designated to receive attribution as 195 | provided in Section 3(a)(1)(A)(i). 196 | 197 | b. Other rights. 198 | 199 | 1. Moral rights, such as the right of integrity, are not 200 | licensed under this Public License, nor are publicity, 201 | privacy, and/or other similar personality rights; however, to 202 | the extent possible, the Licensor waives and/or agrees not to 203 | assert any such rights held by the Licensor to the limited 204 | extent necessary to allow You to exercise the Licensed 205 | Rights, but not otherwise. 206 | 207 | 2. Patent and trademark rights are not licensed under this 208 | Public License. 209 | 210 | 3. To the extent possible, the Licensor waives any right to 211 | collect royalties from You for the exercise of the Licensed 212 | Rights, whether directly or through a collecting society 213 | under any voluntary or waivable statutory or compulsory 214 | licensing scheme. In all other cases the Licensor expressly 215 | reserves any right to collect such royalties, including when 216 | the Licensed Material is used other than for NonCommercial 217 | purposes. 218 | 219 | Section 3 -- License Conditions. 
220 | 221 | Your exercise of the Licensed Rights is expressly made subject to the 222 | following conditions. 223 | 224 | a. Attribution. 225 | 226 | 1. If You Share the Licensed Material (including in modified 227 | form), You must: 228 | 229 | a. retain the following if it is supplied by the Licensor 230 | with the Licensed Material: 231 | 232 | i. identification of the creator(s) of the Licensed 233 | Material and any others designated to receive 234 | attribution, in any reasonable manner requested by 235 | the Licensor (including by pseudonym if 236 | designated); 237 | 238 | ii. a copyright notice; 239 | 240 | iii. a notice that refers to this Public License; 241 | 242 | iv. a notice that refers to the disclaimer of 243 | warranties; 244 | 245 | v. a URI or hyperlink to the Licensed Material to the 246 | extent reasonably practicable; 247 | 248 | b. indicate if You modified the Licensed Material and 249 | retain an indication of any previous modifications; and 250 | 251 | c. indicate the Licensed Material is licensed under this 252 | Public License, and include the text of, or the URI or 253 | hyperlink to, this Public License. 254 | 255 | 2. You may satisfy the conditions in Section 3(a)(1) in any 256 | reasonable manner based on the medium, means, and context in 257 | which You Share the Licensed Material. For example, it may be 258 | reasonable to satisfy the conditions by providing a URI or 259 | hyperlink to a resource that includes the required 260 | information. 261 | 262 | 3. If requested by the Licensor, You must remove any of the 263 | information required by Section 3(a)(1)(A) to the extent 264 | reasonably practicable. 265 | 266 | 4. If You Share Adapted Material You produce, the Adapter's 267 | License You apply must not prevent recipients of the Adapted 268 | Material from complying with this Public License. 269 | 270 | Section 4 -- Sui Generis Database Rights. 
271 | 272 | Where the Licensed Rights include Sui Generis Database Rights that 273 | apply to Your use of the Licensed Material: 274 | 275 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 276 | to extract, reuse, reproduce, and Share all or a substantial 277 | portion of the contents of the database for NonCommercial purposes 278 | only; 279 | 280 | b. if You include all or a substantial portion of the database 281 | contents in a database in which You have Sui Generis Database 282 | Rights, then the database in which You have Sui Generis Database 283 | Rights (but not its individual contents) is Adapted Material; and 284 | 285 | c. You must comply with the conditions in Section 3(a) if You Share 286 | all or a substantial portion of the contents of the database. 287 | 288 | For the avoidance of doubt, this Section 4 supplements and does not 289 | replace Your obligations under this Public License where the Licensed 290 | Rights include other Copyright and Similar Rights. 291 | 292 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 293 | 294 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 295 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 296 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 297 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 298 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 299 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 300 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 301 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 302 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 303 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 304 | 305 | b. 
TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 306 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 307 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 308 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 309 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 310 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 311 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 312 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 313 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 314 | 315 | c. The disclaimer of warranties and limitation of liability provided 316 | above shall be interpreted in a manner that, to the extent 317 | possible, most closely approximates an absolute disclaimer and 318 | waiver of all liability. 319 | 320 | Section 6 -- Term and Termination. 321 | 322 | a. This Public License applies for the term of the Copyright and 323 | Similar Rights licensed here. However, if You fail to comply with 324 | this Public License, then Your rights under this Public License 325 | terminate automatically. 326 | 327 | b. Where Your right to use the Licensed Material has terminated under 328 | Section 6(a), it reinstates: 329 | 330 | 1. automatically as of the date the violation is cured, provided 331 | it is cured within 30 days of Your discovery of the 332 | violation; or 333 | 334 | 2. upon express reinstatement by the Licensor. 335 | 336 | For the avoidance of doubt, this Section 6(b) does not affect any 337 | right the Licensor may have to seek remedies for Your violations 338 | of this Public License. 339 | 340 | c. For the avoidance of doubt, the Licensor may also offer the 341 | Licensed Material under separate terms or conditions or stop 342 | distributing the Licensed Material at any time; however, doing so 343 | will not terminate this Public License. 344 | 345 | d. 
Sections 1, 5, 6, 7, and 8 survive termination of this Public 346 | License. 347 | 348 | Section 7 -- Other Terms and Conditions. 349 | 350 | a. The Licensor shall not be bound by any additional or different 351 | terms or conditions communicated by You unless expressly agreed. 352 | 353 | b. Any arrangements, understandings, or agreements regarding the 354 | Licensed Material not stated herein are separate from and 355 | independent of the terms and conditions of this Public License. 356 | 357 | Section 8 -- Interpretation. 358 | 359 | a. For the avoidance of doubt, this Public License does not, and 360 | shall not be interpreted to, reduce, limit, restrict, or impose 361 | conditions on any use of the Licensed Material that could lawfully 362 | be made without permission under this Public License. 363 | 364 | b. To the extent possible, if any provision of this Public License is 365 | deemed unenforceable, it shall be automatically reformed to the 366 | minimum extent necessary to make it enforceable. If the provision 367 | cannot be reformed, it shall be severed from this Public License 368 | without affecting the enforceability of the remaining terms and 369 | conditions. 370 | 371 | c. No term or condition of this Public License will be waived and no 372 | failure to comply consented to unless expressly agreed to by the 373 | Licensor. 374 | 375 | d. Nothing in this Public License constitutes or may be interpreted 376 | as a limitation upon, or waiver of, any privileges and immunities 377 | that apply to the Licensor or You, including from the legal 378 | processes of any jurisdiction or authority. 379 | 380 | ======================================================================= 381 | 382 | Creative Commons is not a party to its public 383 | licenses. 
Notwithstanding, Creative Commons may elect to apply one of 384 | its public licenses to material it publishes and in those instances 385 | will be considered the “Licensor.” The text of the Creative Commons 386 | public licenses is dedicated to the public domain under the CC0 Public 387 | Domain Dedication. Except for the limited purpose of indicating that 388 | material is shared under a Creative Commons public license or as 389 | otherwise permitted by the Creative Commons policies published at 390 | creativecommons.org/policies, Creative Commons does not authorize the 391 | use of the trademark "Creative Commons" or any other trademark or logo 392 | of Creative Commons without its prior written consent including, 393 | without limitation, in connection with any unauthorized modifications 394 | to any of its public licenses or any other arrangements, 395 | understandings, or agreements concerning use of licensed material. For 396 | the avoidance of doubt, this paragraph does not form part of the 397 | public licenses. 398 | 399 | Creative Commons may be contacted at creativecommons.org. 
400 | -------------------------------------------------------------------------------- /notebooks/sem_results.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "sem_results.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [], 9 | "include_colab_link": true 10 | }, 11 | "kernelspec": { 12 | "name": "python3", 13 | "display_name": "Python 3" 14 | } 15 | }, 16 | "cells": [ 17 | { 18 | "cell_type": "markdown", 19 | "metadata": { 20 | "id": "view-in-github", 21 | "colab_type": "text" 22 | }, 23 | "source": [ 24 | "\"Open" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "metadata": { 30 | "id": "Uhz1a5KgB5hg", 31 | "colab_type": "code", 32 | "colab": { 33 | "base_uri": "https://localhost:8080/", 34 | "height": 1000 35 | }, 36 | "outputId": "19285df6-b2bc-42b9-d38e-2213aa76552a" 37 | }, 38 | "source": [ 39 | "\"\"\"Load results from disk. See the run_sems.sh experiment producing results.\"\"\"\n", 40 | "from glob import glob\n", 41 | "import os\n", 42 | "import pickle\n", 43 | "from pprint import pprint\n", 44 | "\n", 45 | "import matplotlib.pyplot as plt\n", 46 | "\n", 47 | "def load_results(dirname):\n", 48 | " return pickle.load(open(os.path.join(dirname, 'metrics.p'), 'rb'))\n", 49 | " \n", 50 | "def load_flags(dirname):\n", 51 | " return pickle.load(open(os.path.join(dirname, 'flags.p'), 'rb'))\n", 52 | "\n", 53 | "# ROOT = '/scratch/gobi1/creager/opt_env/run_sems_alpha_sweep/43878915'\n", 54 | "ROOT = '/PATH/TO/SWEEP/DIR'\n", 55 | "\n", 56 | "# baselines include HandCrafted->ERM, ICP, IRM, and ERM->EIIL->IRM\n", 57 | "BASELINE_GLOB_PATTERN = '{}/chain_hidden=0_hetero=2_scramble=0/all_solutions.txt'.format(ROOT) # TODO(): change\n", 58 | "\n", 59 | "ALPHAS = (\n", 60 | " 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0\n", 61 | ") # NOTE: UDL workshop paper missed two x-tick marks on this plot\n", 62 | "from collections 
import OrderedDict\n", 63 | "ALPHA_SWEEP_GLOB_PATTERNS = OrderedDict()\n", 64 | "for alpha in ALPHAS:\n", 65 | " ALPHA_SWEEP_GLOB_PATTERNS[alpha] = \\\n", 66 | " '{}/chain_hidden=0_hetero=2_scramble=0_alpha_{}/all_solutions.txt'.format(ROOT, str(alpha)) # TODO(): change\n", 67 | "\n", 68 | "def read_lines(results_glob):\n", 69 | " all_lines = []\n", 70 | " for fn in glob(results_glob):\n", 71 | " lines = open(fn, 'r').read().splitlines()\n", 72 | " all_lines.extend(lines)\n", 73 | " return all_lines\n", 74 | "\n", 75 | "def get_results(all_solutions):\n", 76 | " results = {}\n", 77 | "\n", 78 | " for line in all_solutions:\n", 79 | " words = line.split(\" \")\n", 80 | " setup = str(words[0])\n", 81 | " model = str(words[1])\n", 82 | " err_causal = float(words[-2])\n", 83 | " err_noncausal = float(words[-1])\n", 84 | "\n", 85 | " if setup not in results:\n", 86 | " results[setup] = {}\n", 87 | "\n", 88 | " if model not in results[setup]:\n", 89 | " results[setup][model] = []\n", 90 | "\n", 91 | " results[setup][model].append([err_causal, err_noncausal])\n", 92 | " return results\n", 93 | "\n", 94 | "results_lines = dict()\n", 95 | "results_lines['baselines'] = read_lines(BASELINE_GLOB_PATTERN)\n", 96 | "results_lines['alpha_sweep'] = OrderedDict()\n", 97 | "for a, agp in ALPHA_SWEEP_GLOB_PATTERNS.items():\n", 98 | " results_lines['alpha_sweep'][a] = read_lines(agp)\n", 99 | "\n", 100 | "SETTING = 'chain_hidden=0_hetero=2_scramble=0' # apply filter to yield only this data setting\n", 101 | "# results = {k: get_results(rl)[SETTING] for k, rl in results_lines.items()}\n", 102 | "results = dict()\n", 103 | "results['baselines'] = get_results(results_lines['baselines'])[SETTING]\n", 104 | "results['alpha_sweep'] = OrderedDict()\n", 105 | "for a in ALPHAS:\n", 106 | " results['alpha_sweep'][a] = get_results(results_lines['alpha_sweep'][a])[SETTING]\n", 107 | "\n", 108 | "from pprint import pprint\n", 109 | "pprint(results)\n", 110 | "for a in ALPHAS:\n", 111 | " print(a, 
results['alpha_sweep'][a])" 112 | ], 113 | "execution_count": null, 114 | "outputs": [ 115 | { 116 | "output_type": "stream", 117 | "text": [ 118 | "{'alpha_sweep': OrderedDict([(0.1,\n", 119 | " {'EIIL': [[1.20987, 1.08879],\n", 120 | " [1.03908, 0.92657],\n", 121 | " [1.27709, 1.12844],\n", 122 | " [1.56064, 1.35345],\n", 123 | " [0.58332, 0.56332]],\n", 124 | " 'SEM': [[0.0, 0.0],\n", 125 | " [0.0, 0.0],\n", 126 | " [0.0, 0.0],\n", 127 | " [0.0, 0.0],\n", 128 | " [0.0, 0.0]]}),\n", 129 | " (0.2,\n", 130 | " {'EIIL': [[0.58675, 0.56536],\n", 131 | " [0.42208, 0.41769],\n", 132 | " [0.47998, 0.46702],\n", 133 | " [0.49813, 0.49398],\n", 134 | " [0.47181, 0.47596]],\n", 135 | " 'SEM': [[0.0, 0.0],\n", 136 | " [0.0, 0.0],\n", 137 | " [0.0, 0.0],\n", 138 | " [0.0, 0.0],\n", 139 | " [0.0, 0.0]]}),\n", 140 | " (0.3,\n", 141 | " {'EIIL': [[1.13695, 1.00236],\n", 142 | " [0.57304, 0.54475],\n", 143 | " [0.46865, 0.45732],\n", 144 | " [0.4946, 0.48862],\n", 145 | " [0.4725, 0.47559]],\n", 146 | " 'SEM': [[0.0, 0.0],\n", 147 | " [0.0, 0.0],\n", 148 | " [0.0, 0.0],\n", 149 | " [0.0, 0.0],\n", 150 | " [0.0, 0.0]]}),\n", 151 | " (0.4,\n", 152 | " {'EIIL': [[1.07029, 0.94224],\n", 153 | " [0.61972, 0.56385],\n", 154 | " [0.46864, 0.45667],\n", 155 | " [0.50413, 0.49556],\n", 156 | " [0.72768, 0.6986]],\n", 157 | " 'SEM': [[0.0, 0.0],\n", 158 | " [0.0, 0.0],\n", 159 | " [0.0, 0.0],\n", 160 | " [0.0, 0.0],\n", 161 | " [0.0, 0.0]]}),\n", 162 | " (0.5,\n", 163 | " {'EIIL': [[1.20125, 1.06908],\n", 164 | " [0.66164, 0.58418],\n", 165 | " [0.47249, 0.46022],\n", 166 | " [0.49825, 0.48423],\n", 167 | " [1.28127, 1.16425]],\n", 168 | " 'SEM': [[0.0, 0.0],\n", 169 | " [0.0, 0.0],\n", 170 | " [0.0, 0.0],\n", 171 | " [0.0, 0.0],\n", 172 | " [0.0, 0.0]]}),\n", 173 | " (0.6,\n", 174 | " {'EIIL': [[1.2228, 1.05965],\n", 175 | " [1.44376, 1.32061],\n", 176 | " [0.46628, 0.45464],\n", 177 | " [0.47986, 0.47615],\n", 178 | " [0.83696, 0.75114]],\n", 179 | " 'SEM': [[0.0, 0.0],\n", 180 | " 
[0.0, 0.0],\n", 181 | " [0.0, 0.0],\n", 182 | " [0.0, 0.0],\n", 183 | " [0.0, 0.0]]}),\n", 184 | " (0.7,\n", 185 | " {'EIIL': [[0.56202, 0.57157],\n", 186 | " [0.12392, 0.14817],\n", 187 | " [0.45657, 0.44559],\n", 188 | " [0.46264, 0.45731],\n", 189 | " [0.46125, 0.46519]],\n", 190 | " 'SEM': [[0.0, 0.0],\n", 191 | " [0.0, 0.0],\n", 192 | " [0.0, 0.0],\n", 193 | " [0.0, 0.0],\n", 194 | " [0.0, 0.0]]}),\n", 195 | " (0.8,\n", 196 | " {'EIIL': [[0.61086, 0.49274],\n", 197 | " [0.06665, 0.06252],\n", 198 | " [0.50138, 0.483],\n", 199 | " [0.70804, 0.63456],\n", 200 | " [0.56308, 0.55354]],\n", 201 | " 'SEM': [[0.0, 0.0],\n", 202 | " [0.0, 0.0],\n", 203 | " [0.0, 0.0],\n", 204 | " [0.0, 0.0],\n", 205 | " [0.0, 0.0]]}),\n", 206 | " (0.9,\n", 207 | " {'EIIL': [[0.18227, 0.1929],\n", 208 | " [0.03049, 0.01854],\n", 209 | " [0.51175, 0.49186],\n", 210 | " [0.0655, 0.06096],\n", 211 | " [0.08848, 0.10323]],\n", 212 | " 'SEM': [[0.0, 0.0],\n", 213 | " [0.0, 0.0],\n", 214 | " [0.0, 0.0],\n", 215 | " [0.0, 0.0],\n", 216 | " [0.0, 0.0]]}),\n", 217 | " (1.0,\n", 218 | " {'EIIL': [[0.08395, 0.0869],\n", 219 | " [0.03107, 0.03012],\n", 220 | " [0.4985, 0.48285],\n", 221 | " [0.03681, 0.03995],\n", 222 | " [0.07054, 0.08129]],\n", 223 | " 'SEM': [[0.0, 0.0],\n", 224 | " [0.0, 0.0],\n", 225 | " [0.0, 0.0],\n", 226 | " [0.0, 0.0],\n", 227 | " [0.0, 0.0]]})]),\n", 228 | " 'baselines': {'EIIL': [[0.09301, 0.09893],\n", 229 | " [0.02745, 0.0234],\n", 230 | " [0.51578, 0.49595],\n", 231 | " [0.06457, 0.05908],\n", 232 | " [0.03721, 0.04549]],\n", 233 | " 'ERM': [[0.85369, 0.84856],\n", 234 | " [0.82197, 0.81206],\n", 235 | " [0.82215, 0.81623],\n", 236 | " [0.82626, 0.82619],\n", 237 | " [0.81081, 0.81841]],\n", 238 | " 'ICP': [[1.0, 0.95135],\n", 239 | " [1.0, 0.0],\n", 240 | " [1.0, 0.93675],\n", 241 | " [1.0, 0.943],\n", 242 | " [1.0, 0.95139]],\n", 243 | " 'IRM': [[0.79471, 0.74555],\n", 244 | " [0.56653, 0.55343],\n", 245 | " [0.65777, 0.62865],\n", 246 | " [0.66595, 0.65153],\n", 
247 | " [0.64409, 0.64043]],\n", 248 | " 'SEM': [[0.0, 0.0],\n", 249 | " [0.0, 0.0],\n", 250 | " [0.0, 0.0],\n", 251 | " [0.0, 0.0],\n", 252 | " [0.0, 0.0]]}}\n", 253 | "0.1 {'SEM': [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], 'EIIL': [[1.20987, 1.08879], [1.03908, 0.92657], [1.27709, 1.12844], [1.56064, 1.35345], [0.58332, 0.56332]]}\n", 254 | "0.2 {'SEM': [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], 'EIIL': [[0.58675, 0.56536], [0.42208, 0.41769], [0.47998, 0.46702], [0.49813, 0.49398], [0.47181, 0.47596]]}\n", 255 | "0.3 {'SEM': [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], 'EIIL': [[1.13695, 1.00236], [0.57304, 0.54475], [0.46865, 0.45732], [0.4946, 0.48862], [0.4725, 0.47559]]}\n", 256 | "0.4 {'SEM': [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], 'EIIL': [[1.07029, 0.94224], [0.61972, 0.56385], [0.46864, 0.45667], [0.50413, 0.49556], [0.72768, 0.6986]]}\n", 257 | "0.5 {'SEM': [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], 'EIIL': [[1.20125, 1.06908], [0.66164, 0.58418], [0.47249, 0.46022], [0.49825, 0.48423], [1.28127, 1.16425]]}\n", 258 | "0.6 {'SEM': [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], 'EIIL': [[1.2228, 1.05965], [1.44376, 1.32061], [0.46628, 0.45464], [0.47986, 0.47615], [0.83696, 0.75114]]}\n", 259 | "0.7 {'SEM': [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], 'EIIL': [[0.56202, 0.57157], [0.12392, 0.14817], [0.45657, 0.44559], [0.46264, 0.45731], [0.46125, 0.46519]]}\n", 260 | "0.8 {'SEM': [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], 'EIIL': [[0.61086, 0.49274], [0.06665, 0.06252], [0.50138, 0.483], [0.70804, 0.63456], [0.56308, 0.55354]]}\n", 261 | "0.9 {'SEM': [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], 'EIIL': [[0.18227, 0.1929], [0.03049, 0.01854], [0.51175, 0.49186], [0.0655, 0.06096], [0.08848, 0.10323]]}\n", 262 | "1.0 {'SEM': [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], 'EIIL': 
[[0.08395, 0.0869], [0.03107, 0.03012], [0.4985, 0.48285], [0.03681, 0.03995], [0.07054, 0.08129]]}\n" 263 | ], 264 | "name": "stdout" 265 | } 266 | ] 267 | }, 268 | { 269 | "cell_type": "code", 270 | "metadata": { 271 | "id": "ZgDWX6D63qqd", 272 | "colab_type": "code", 273 | "colab": { 274 | "base_uri": "https://localhost:8080/", 275 | "height": 301 276 | }, 277 | "outputId": "e4915737-eb2f-4895-bb77-58bfa4d2eb75" 278 | }, 279 | "source": [ 280 | "\"\"\"Produce table from baseline errors.\"\"\"\n", 281 | "from collections import defaultdict\n", 282 | "import numpy as np\n", 283 | "import pandas as pd\n", 284 | "\n", 285 | "LATEX_SYMBOLS = dict(\n", 286 | " ERM=r'\\textsc{ERM}',\n", 287 | " ICP=r'\\textsc{ICP}',\n", 288 | " IRM=r'\\textsc{IRM}($e_{\\textsc{HC}}$)',\n", 289 | " EIIL=r'\\textsc{IRM}($e_{\\textsc{EIIL}}| \\tilde \\Phi = \\Phi_{ERM}$)',\n", 290 | " alpha=r'\\textsc{IRM}($e_{\\textsc{EIIL}}| \\tilde \\Phi = \\Phi_{\\alpha-Spurious}$)',\n", 291 | ")\n", 292 | "\n", 293 | "baseline_errors = defaultdict(dict)\n", 294 | "for method, v in results['baselines'].items():\n", 295 | " if method == 'SEM':\n", 296 | " continue\n", 297 | " results_ = np.stack(v)\n", 298 | " mean = np.mean(results_, 0)\n", 299 | " std = np.std(results_, 0)\n", 300 | " baseline_errors['Causal err.'][LATEX_SYMBOLS[method]] = (mean[0], std[0])\n", 301 | " baseline_errors['Noncausal err.'][LATEX_SYMBOLS[method]] = (mean[1], std[1])\n", 302 | "\n", 303 | "def f(xy):\n", 304 | " \"\"\"Format mean plus/minus std dev.\"\"\"\n", 305 | " x, y = xy\n", 306 | " return r\"\"\"%.3f $\\pm$ %.3f\"\"\" % (x, y)\n", 307 | "\n", 308 | "re = pd.DataFrame.from_dict(baseline_errors)\n", 309 | "results_tex = re.to_latex(\n", 310 | " formatters=[f, ] * len(re.T),\n", 311 | " escape=False,\n", 312 | " caption=(\n", 313 | " 'Results on synthetic data. 
MSE on causal and non-causal portions of the '\n", 314 | " 'ground truth regression coefficient are reported, plus/minus a standard '\n", 315 | " 'deviation across ten runs. '\n", 316 | " )\n", 317 | ")\n", 318 | "print(results_tex)" 319 | ], 320 | "execution_count": null, 321 | "outputs": [ 322 | { 323 | "output_type": "stream", 324 | "text": [ 325 | "\\begin{table}\n", 326 | "\\centering\n", 327 | "\\caption{Results on synthetic data. MSE on causal and non-causal portions of the ground truth regression coefficient are reported, plus/minus a standard deviation across ten runs. }\n", 328 | "\\begin{tabular}{lll}\n", 329 | "\\toprule\n", 330 | "{} & Causal err. & Noncausal err. \\\\\n", 331 | "\\midrule\n", 332 | "\\textsc{IRM}($e_{\\textsc{EIIL}}| \\tilde \\Phi = ... & 0.148 $\\pm$ 0.186 & 0.145 $\\pm$ 0.177 \\\\\n", 333 | "\\textsc{ERM} & 0.827 $\\pm$ 0.014 & 0.824 $\\pm$ 0.013 \\\\\n", 334 | "\\textsc{ICP} & 1.000 $\\pm$ 0.000 & 0.756 $\\pm$ 0.378 \\\\\n", 335 | "\\textsc{IRM}($e_{\\textsc{HC}}$) & 0.666 $\\pm$ 0.073 & 0.644 $\\pm$ 0.061 \\\\\n", 336 | "\\bottomrule\n", 337 | "\\end{tabular}\n", 338 | "\\end{table}\n", 339 | "\n" 340 | ], 341 | "name": "stdout" 342 | } 343 | ] 344 | }, 345 | { 346 | "cell_type": "code", 347 | "metadata": { 348 | "id": "3VfKtY2HJ5xw", 349 | "colab_type": "code", 350 | "colab": { 351 | "base_uri": "https://localhost:8080/", 352 | "height": 1000 353 | }, 354 | "outputId": "ac275104-516c-4b90-98ed-07186d795b4f" 355 | }, 356 | "source": [ 357 | "\"\"\"Aggregate error measurements.\"\"\"\n", 358 | "\n", 359 | "alpha_sweep_errors = dict()\n", 360 | "for k in ('Causal err.', 'Noncausal err.'):\n", 361 | " alpha_sweep_errors[k] = defaultdict(list)\n", 362 | "\n", 363 | "for a, v in results['alpha_sweep'].items():\n", 364 | " v = v['EIIL']\n", 365 | " results_ = np.stack(v)\n", 366 | " mean = np.mean(results_, 0)\n", 367 | " std = np.std(results_, 0)\n", 368 | " alpha_sweep_errors['Causal err.']['mean'].append(mean[0])\n", 369 | " 
alpha_sweep_errors['Causal err.']['upper'].append(mean[0] + std[0])\n", 370 | " alpha_sweep_errors['Causal err.']['lower'].append(mean[0] - std[0])\n", 371 | " alpha_sweep_errors['Noncausal err.']['mean'].append(mean[1])\n", 372 | " alpha_sweep_errors['Noncausal err.']['upper'].append(mean[1] + std[1])\n", 373 | " alpha_sweep_errors['Noncausal err.']['lower'].append(mean[1] - std[1])\n", 374 | "\n", 375 | "pprint(alpha_sweep_errors)\n" 376 | ], 377 | "execution_count": null, 378 | "outputs": [ 379 | { 380 | "output_type": "stream", 381 | "text": [ 382 | "{'Causal err.': defaultdict(<class 'list'>,\n", 383 | " {'lower': [0.8113512859470844,\n", 384 | " 0.4379805852737823,\n", 385 | " 0.3724721939410729,\n", 386 | " 0.4618428616988267,\n", 387 | " 0.4744283322088388,\n", 388 | " 0.4980206186699855,\n", 389 | " 0.263303767349623,\n", 390 | " 0.26781109357941757,\n", 391 | " 0.0003069626691261096,\n", 392 | " -0.03410407140531893],\n", 393 | " 'mean': [1.134,\n", 394 | " 0.49175,\n", 395 | " 0.629148,\n", 396 | " 0.678092,\n", 397 | " 0.8229799999999999,\n", 398 | " 0.889932,\n", 399 | " 0.41328,\n", 400 | " 0.49000199999999994,\n", 401 | " 0.175698,\n", 402 | " 0.144174],\n", 403 | " 'upper': [1.4566487140529154,\n", 404 | " 0.5455194147262178,\n", 405 | " 0.8858238060589272,\n", 406 | " 0.8943411383011733,\n", 407 | " 1.1715316677911611,\n", 408 | " 1.2818433813300145,\n", 409 | " 0.5632562326503769,\n", 410 | " 0.7121929064205823,\n", 411 | " 0.3510890373308739,\n", 412 | " 0.32245207140531895]}),\n", 413 | " 'Noncausal err.': defaultdict(<class 'list'>,\n", 414 | " {'lower': [0.7495668451801467,\n", 415 | " 0.4361083092255358,\n", 416 | " 0.38733498796713106,\n", 417 | " 0.45549572562106355,\n", 418 | " 0.4505679074295096,\n", 419 | " 0.47651567949122525,\n", 420 | " 0.27549248922476793,\n", 421 | " 0.24641143916402125,\n", 422 | " 0.004178671109291238,\n", 423 | " -0.026545143842133745],\n", 424 | " 'mean': [1.012114,\n", 425 | " 0.484002,\n", 426 | " 0.593728,\n", 427 | " 0.631384,\n", 428
# Notebook cell: plot the alpha-sweep errors against the constant baselines,
# one figure per metric ('Causal err.' and 'Noncausal err.').
#
# NOTE(review): ALPHAS, LATEX_SYMBOLS, alpha_sweep_errors and baseline_errors
# are not defined in this cell; they are expected to come from earlier cells.
from itertools import cycle

from matplotlib import rc
rc('font',**{'family':'sans-serif','sans-serif':['Helvetica']})
rc('text', usetex=True)

import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.colors import colorConverter as cc
import numpy as np
import seaborn as sns

TITLE_FONTSIZE = 16
AXIS_FONTSIZE = 16
LEGEND_FONTSIZE = 10
FIGSIZE = (6, 4)
SHADING_OPACITY = .5  # how opaque should the uncertainty fills be

bg = np.array([1, 1, 1])  # background of the legend is white


class LegendObject(object):
    """Legend handler that draws a filled swatch with a contrasting border.

    Instances are registered in the ``handler_map`` passed to ``plt.legend``;
    matplotlib then calls :meth:`legend_artist` to draw the legend handle.

    Args:
        facecolor: fill colour of the swatch.
        edgecolor: colour of the swatch border (and of the dash, if any).
        dashed: when True, a vertical bar in ``edgecolor`` is drawn through
            the middle fifth of the swatch.
    """

    def __init__(self, facecolor='red', edgecolor='white', dashed=False):
        self.facecolor = facecolor
        self.edgecolor = edgecolor
        self.dashed = dashed

    def legend_artist(self, legend, orig_handle, fontsize, handlebox):
        """Draw the swatch into ``handlebox`` and return the main patch."""
        x0, y0 = handlebox.xdescent, handlebox.ydescent
        width, height = handlebox.width, handlebox.height
        # Rectangle filled with the line colour whose edges are the faded
        # colour. The transform is passed explicitly so both patches are
        # configured the same way.
        patch = mpatches.Rectangle(
            [x0, y0], width, height, facecolor=self.facecolor,
            edgecolor=self.edgecolor, lw=3,
            transform=handlebox.get_transform())
        handlebox.add_artist(patch)

        # If we're creating the legend for a dashed line, manually add the
        # dash in to our rectangle.
        if self.dashed:
            patch1 = mpatches.Rectangle(
                [x0 + 2*width/5, y0], width/5, height, facecolor=self.edgecolor,
                transform=handlebox.get_transform())
            handlebox.add_artist(patch1)

        return patch


def _fade(color, opacity=SHADING_OPACITY, background=bg):
    """Blend ``color`` with ``background``; ``opacity`` is the weight of ``color``.

    At ``opacity == 0.5`` this equals the previous ``(rgb + bg) * 0.5``
    formula, but the result stays a valid RGB triple for any opacity in [0, 1].
    """
    rgb = np.asarray(cc.to_rgb(color), dtype=float)
    return opacity * rgb + (1. - opacity) * np.asarray(background, dtype=float)


def plot_mean_and_CI(x_axis_vec, mean, lb, ub, color_mean=None, color_shading=None,
                     alpha=SHADING_OPACITY):
    """Plot ``mean`` against ``x_axis_vec`` with a shaded band from ``lb`` to ``ub``.

    Args:
        x_axis_vec: x coordinates shared by the line and the band.
        mean: y values of the central line.
        lb: lower edge of the shaded band.
        ub: upper edge of the shaded band.
        color_mean: colour of the central line.
        color_shading: colour of the shaded band.
        alpha: transparency of the shaded band (defaults to SHADING_OPACITY).
    """
    # plot the shaded range of the confidence intervals
    plt.fill_between(x_axis_vec, ub, lb, color=color_shading, alpha=alpha)
    # plot the mean on top
    plt.plot(x_axis_vec, mean, color=color_mean)


def _plot_series(name, mean, lower, upper, color, handler_map, legend_names):
    """Plot one series and register its legend entry.

    Legend handles are the integer positions in ``legend_names``;
    ``handler_map`` maps each position to the LegendObject that draws it.
    """
    color_faded = _fade(color)
    handler_map[len(legend_names)] = LegendObject(color, color_faded)
    legend_names.append(name)
    plot_mean_and_CI(ALPHAS, mean, lower, upper,
                     color_mean=color, color_shading=color_faded)


# plot the data
for metric in ('Causal err.', 'Noncausal err.'):
    # cycle() instead of iter(): never raises StopIteration, whatever the
    # number of baselines.
    colors = cycle('mgkb')
    handler_map = dict()
    legend_names = []
    plt.figure(figsize=FIGSIZE)

    # The alpha sweep itself.
    _plot_series(LATEX_SYMBOLS['alpha'],
                 np.asarray(alpha_sweep_errors[metric]['mean']),
                 np.asarray(alpha_sweep_errors[metric]['lower']),
                 np.asarray(alpha_sweep_errors[metric]['upper']),
                 next(colors), handler_map, legend_names)

    # Baselines do not depend on alpha, so each is drawn as a horizontal band
    # at v[0] +/- v[1] (presumably mean and std -- confirm against the cell
    # that builds baseline_errors).
    flat = np.ones(len(ALPHAS))
    for m, v in baseline_errors[metric].items():
        if m == LATEX_SYMBOLS['ICP']:  # don't plot this baseline b/c it doesn't work anyways
            continue
        _plot_series(m,
                     v[0] * flat,
                     (v[0] - v[1]) * flat,
                     (v[0] + v[1]) * flat,
                     next(colors), handler_map, legend_names)

    plt.legend(range(len(legend_names)), legend_names, handler_map=handler_map,
               loc='upper right', fontsize=LEGEND_FONTSIZE)
    plt.xlabel(r'$\alpha$', fontsize=AXIS_FONTSIZE)
    if metric == 'Causal err.':
        # only the first figure carries the y-axis label
        plt.ylabel('MSE', fontsize=AXIS_FONTSIZE)
    # plt.title(metric, fontsize=TITLE_FONTSIZE)
    plt.tight_layout()
    plt.grid()
    plt.show()
"text": [ 576 | "findfont: Font family ['sans-serif'] not found. Falling back to DejaVu Sans.\n" 577 | ], 578 | "name": "stderr" 579 | }, 580 | { 581 | "output_type": "display_data", 582 | "data": { 583 | "text/plain": [ 584 | "
" 585 | ], 586 | "image/png": "\n" 587 | }, 588 | "metadata": { 589 | "tags": [], 590 | "needs_background": "light" 591 | } 592 | }, 593 | { 594 | "output_type": "stream", 595 | "text": [ 596 | "{0: <__main__.LegendObject object at 0x7f5a0dfa3fd0>}\n", 597 | "{0: <__main__.LegendObject object at 0x7f5a0dfa3fd0>, 1: <__main__.LegendObject object at 0x7f5a0cebfa58>}\n", 598 | "{0: <__main__.LegendObject object at 0x7f5a0dfa3fd0>, 1: <__main__.LegendObject object at 0x7f5a0cebfa58>, 2: <__main__.LegendObject object at 0x7f5a0cecd080>}\n", 599 | "{0: <__main__.LegendObject object at 0x7f5a0dfa3fd0>, 1: <__main__.LegendObject object at 0x7f5a0cebfa58>, 2: <__main__.LegendObject object at 0x7f5a0cecd080>, 3: <__main__.LegendObject object at 0x7f5a0cecd668>}\n" 600 | ], 601 | "name": "stdout" 602 | }, 603 | { 604 | "output_type": "display_data", 605 | "data": { 606 | "text/plain": [ 607 | "
" 608 | ], 609 | "image/png": "\n" 610 | }, 611 | "metadata": { 612 | "tags": [], 613 | "needs_background": "light" 614 | } 615 | } 616 | ] 617 | } 618 | ] 619 | } --------------------------------------------------------------------------------