├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── data └── results_processed.tar.gz ├── reproduce_plots.sh ├── reproduce_results.sh ├── requirements.txt └── scripts ├── __pycache__ ├── collect_results.cpython-38.pyc ├── datasets.cpython-37.pyc ├── datasets.cpython-38.pyc ├── main.cpython-37.pyc ├── main.cpython-38.pyc ├── models.cpython-38.pyc ├── plot_results.cpython-38.pyc └── utils.cpython-38.pyc ├── collect_results.py ├── datasets.py ├── main.py ├── models.py ├── plot_results.py ├── sweep.py └── utils.py /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | Facebook has adopted a Code of Conduct that we expect project participants to adhere to. 4 | Please read the [full text](https://code.fb.com/codeofconduct/) 5 | so that you can understand what actions will and will not be tolerated. 6 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to InvarianceUnitTests 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `master`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## License 30 | By contributing to InvarianceUnitTests, you agree that your contributions 31 | will be licensed under the LICENSE file in the root directory of this source 32 | tree. 33 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License 2 | 3 | Copyright (c) Facebook, Inc. and its affiliates. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 10 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Linear unit-tests for invariance discovery - Code 2 | 3 | Official code for the paper [Linear unit-tests for invariance discovery](https://arxiv.org/abs/2102.10867), presented as a spotlight talk at the [NeurIPS 2020 Workshop Causal Discovery & Causality-Inspired Machine Learning](https://www.cmu.edu/dietrich/causality/neurips20ws/). 4 | 5 | ### Installing requirements 6 | 7 | ```bash 8 | conda create -n invariance python=3.8 9 | conda activate invariance 10 | python3.8 -m pip install -U -r requirements.txt 11 | ``` 12 | 13 | ### Running a single experiment 14 | 15 | ```bash 16 | python3.8 scripts/main.py \ 17 | --model ERM --dataset Example1 --n_envs 3 \ 18 | --num_iterations 10000 --dim_inv 5 --dim_spu 5 \ 19 | --hparams '{"lr":1e-3, "wd":1e-4}' --output_dir results/ 20 | ``` 21 | 22 | ### Running the experiments and printing results 23 | 24 | ```bash 25 | python3.8 scripts/sweep.py --num_iterations 10000 --num_data_seeds 1 --num_model_seeds 1 --output_dir results/ 26 | python3.8 scripts/collect_results.py results/COMMIT 27 | ``` 28 | 29 | ### Reproducing the figures 30 | 31 | ```bash 32 | bash reproduce_plots.sh 33 | ``` 34 | 35 | ### Reproducing the results (requires a cluster) 36 | 37 | Be careful: this script launches 630,000 jobs for the hyper-parameter search. 38 | 39 | ```bash 40 | bash reproduce_results.sh test 41 | ``` 42 | 43 | ### Deactivating and removing the env 44 | 45 | ```bash 46 | conda deactivate 47 | conda remove --name invariance --all 48 | ``` 49 | 50 | ## License 51 | 52 | This source code is released under the MIT license, included [here](LICENSE). 53 | 54 | ## Reference 55 | 56 | If you make use of our suite of tasks in your research, please cite the following in your manuscript: 57 | 58 | ``` 59 | @article{aubin2021linear, 60 | title={Linear unit-tests for invariance discovery}, 61 | author={Aubin, Benjamin and S{\l}owik, Agnieszka and Arjovsky, Martin and Bottou, Leon and Lopez-Paz, David}, 62 | journal={arXiv preprint arXiv:2102.10867}, 63 | year={2021} 64 | } 65 | ``` 66 | -------------------------------------------------------------------------------- /data/results_processed.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/InvarianceUnitTests/80bcd8b71827eae8e2983d405298cec2157d1676/data/results_processed.tar.gz -------------------------------------------------------------------------------- /reproduce_plots.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | #!/bin/bash 4 | 5 | # Untar data/results_processed.tar.gz 6 | if [ !
-d "results_processed" ]; then 7 | tar -xvzf data/results_processed.tar.gz 8 | echo "unzip" 9 | fi 10 | 11 | ## Plot figures 1, 2 ## 12 | echo "Plot figures 1 & 2" 13 | python3.8 scripts/plot_results.py -dirname results/ -commit e717c2ff36 --load 14 | -------------------------------------------------------------------------------- /reproduce_results.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | #!/bin/bash 4 | 5 | if [ $1 == 'test' ] 6 | then 7 | num_data_seeds=1 8 | num_model_seeds=1 9 | num_iterations=1 10 | dir_results='test_results' 11 | else 12 | num_data_seeds=50 13 | num_model_seeds=20 14 | num_iterations=10000 15 | dir_results='test_results' 16 | fi 17 | 18 | ### Default experiment ### 19 | dim_inv=5 20 | dim_spu=5 21 | n_envs=3 22 | echo "Default experiment: dim_inv=${dim_inv} dim_spu=${dim_spu} n_envs=${n_envs}" 23 | 24 | python3.8 scripts/sweep.py \ 25 | --models ERM IRMv1 ANDMask IGA Oracle \ 26 | --num_iterations $num_iterations \ 27 | --datasets Example1 Example1s Example2 Example2s Example3 Example3s \ 28 | --dim_inv $dim_inv --dim_spu $dim_spu \ 29 | --n_envs $n_envs \ 30 | --num_data_seeds $num_data_seeds --num_model_seeds $num_model_seeds \ 31 | --output_dir ${dir_results}/default/sweep_linear_nenvs=${n_envs}_dinv=${dim_inv}_dspu=${dim_spu} \ 32 | --cluster \ 33 | --jobs_cluster 200 34 | 35 | ### Varying the number of environments ### 36 | dim_inv=5 37 | dim_spu=5 38 | echo "Varying number of environments: n_envs" 39 | 40 | for n_envs in 2 3 4 5 6 7 8 9 10 41 | do 42 | echo "dim_inv=${dim_inv} dim_spu=${dim_spu} n_envs=${n_envs}" 43 | python3.8 scripts/sweep.py \ 44 | --models ERM IRMv1 ANDMask IGA Oracle \ 45 | --num_iterations $num_iterations \ 46 | --datasets Example1 Example1s Example2 Example2s Example3 Example3s \ 47 | --dim_inv $dim_inv --dim_spu $dim_spu \ 48 | --n_envs $n_envs \ 49 | --num_data_seeds $num_data_seeds --num_model_seeds $num_model_seeds \ 50 | --output_dir ${dir_results}/nenvs/sweep_linear_nenvs=${n_envs}_dinv=${dim_inv}_dspu=${dim_spu} \ 51 | --cluster \ 52 | --jobs_cluster 200 53 | done 54 | 55 | ### Varying the spurious dimensions ### 56 | dim_inv=5 57 | n_envs=3 58 | echo "Varying spurious dimensions: dim_spu" 59 | for dim_spu in 0 1 2 3 4 5 6 7 8 9 10 60 | do 61 | echo "dim_inv=${dim_inv} dim_spu=${dim_spu} n_envs=${n_envs}" 62 | python3.8 scripts/sweep.py \ 63 | --models ERM IRMv1 ANDMask IGA Oracle \ 64 | --num_iterations $num_iterations \ 65 | --datasets Example1 Example1s Example2 Example2s Example3 Example3s \ 66 | --dim_inv $dim_inv --dim_spu $dim_spu \ 67 | --n_envs $n_envs \ 68 | --num_data_seeds $num_data_seeds --num_model_seeds $num_model_seeds \ 69 | --output_dir ${dir_results}/dimspu/sweep_linear_nenvs=${n_envs}_dinv=${dim_inv}_dspu=${dim_spu} \ 70 | --cluster \ 71 | --jobs_cluster 200 72 | done 73 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | torch 3 | gitpython 4 | submitit 5 | pandas 6 | matplotlib 7 | python-git -------------------------------------------------------------------------------- /scripts/__pycache__/collect_results.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/facebookresearch/InvarianceUnitTests/80bcd8b71827eae8e2983d405298cec2157d1676/scripts/__pycache__/collect_results.cpython-38.pyc -------------------------------------------------------------------------------- /scripts/__pycache__/datasets.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/InvarianceUnitTests/80bcd8b71827eae8e2983d405298cec2157d1676/scripts/__pycache__/datasets.cpython-37.pyc -------------------------------------------------------------------------------- /scripts/__pycache__/datasets.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/InvarianceUnitTests/80bcd8b71827eae8e2983d405298cec2157d1676/scripts/__pycache__/datasets.cpython-38.pyc -------------------------------------------------------------------------------- /scripts/__pycache__/main.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/InvarianceUnitTests/80bcd8b71827eae8e2983d405298cec2157d1676/scripts/__pycache__/main.cpython-37.pyc -------------------------------------------------------------------------------- /scripts/__pycache__/main.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/InvarianceUnitTests/80bcd8b71827eae8e2983d405298cec2157d1676/scripts/__pycache__/main.cpython-38.pyc -------------------------------------------------------------------------------- /scripts/__pycache__/models.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/InvarianceUnitTests/80bcd8b71827eae8e2983d405298cec2157d1676/scripts/__pycache__/models.cpython-38.pyc -------------------------------------------------------------------------------- /scripts/__pycache__/plot_results.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/InvarianceUnitTests/80bcd8b71827eae8e2983d405298cec2157d1676/scripts/__pycache__/plot_results.cpython-38.pyc -------------------------------------------------------------------------------- /scripts/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/InvarianceUnitTests/80bcd8b71827eae8e2983d405298cec2157d1676/scripts/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /scripts/collect_results.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | 3 | import pandas as pd 4 | import glob 5 | import os 6 | import json 7 | import argparse 8 | import matplotlib.pyplot as plt 9 | import plot_results 10 | 11 | def print_row(row, col_width=15, latex=False): 12 | sep = " & " if latex else " " 13 | end_ = "\\\\" if latex else "" 14 | print(sep.join([x.ljust(col_width) for x in row]), end_) 15 | 16 | 17 | def print_table(table, col_width=15, latex=False): 18 | col_names = sorted(table[next(iter(table))].keys()) 19 | 20 | print("\n") 21 | if latex: 22 | print("\\documentclass{article}") 23 | print("\\usepackage{booktabs}") 24 | print("\\usepackage{adjustbox}") 25 | print("\\begin{document}") 26 | print("\\begin{table}") 27 | print("\\begin{center}") 28 | print("\\adjustbox{max width=\\textwidth}{%") 29 | print("\\begin{tabular}{l" + "c" * len(col_names) + "}") 30 | print("\\toprule") 31 | 32 | print_row([""] + col_names, col_width, latex) 33 | 34 | if latex: 35 | print("\\midrule") 36 | 37 | for row_k, row_v in sorted(table.items()): 38 | row_values = [row_k] 39 | for col_k, col_v in sorted(row_v.items()): 40 | row_values.append(col_v) 41 | print_row(row_values, col_width, latex) 42 | 43 | if latex: 44 | print("\\bottomrule") 45 | print("\\end{tabular}}") 46 | print("\\end{center}") 47 | print("\\label{main_results}") 48 | print("\\caption{Main results.}") 49 | print("\\end{table}") 50 | print("\\end{document}") 51 | 52 | 53 | def print_table_hparams(table, col_width=15, latex=False): 54 | print("\n") 55 | for dataset in table.keys(): 56 | print(dataset, "\n") 57 | for model in table[dataset].keys(): 58 | print(model, table[dataset][model]) 59 | print("\n") 60 | 61 | 62 | def build_table(dirname, models=None, n_envs=None, num_dim=None, latex=False, standard_error=False): 63 | records = [] 64 | for fname in glob.glob(os.path.join(dirname, "*.jsonl")): 65 | with open(fname, "r") as f: 66 | if os.path.getsize(fname) != 0: 67 | records.append(f.readline().strip()) 68 | 69 | df = pd.read_json("\n".join(records), lines=True) 70 | if models is not None: 71 | df = df.query(f"model in {models}") 72 | if n_envs is not None: 73 | df = df.query(f"n_envs=={n_envs}") 74 | if num_dim is not None: 75 | df = df.query(f"num_dim=={num_dim}") 76 | 77 | print(f'{len(df)} records.') 78 | pm = "$\\pm$" if latex else "+-" 79 | 80 | table = {} 81 | table_avg = {} 82 | table_val = {} 83 | table_val_avg = { 84 | "data" : {}, 85 | "n_envs": 0, 86 | "dim_inv": 0, 87 | "dim_spu": 0 88 | } 89 | table_hparams = {} 90 | 91 | for dataset in df["dataset"].unique(): 92 | # filtered by dataset 93 | df_d = df[df["dataset"] == dataset] 94 | envs = sorted(list(set( 95 | [c[-1] for c in df_d.filter(regex="error_").columns]))) 96 | if n_envs: 97 | envs = envs[:n_envs] 98 | 99 | table_hparams[dataset] = {} 100 | table_val[dataset] = {} 101 | for key in ["n_envs", "dim_inv", "dim_spu"]: 102 | table_val_avg[key] = int(df[key].iloc[0]) 103 | table_val_avg["data"][dataset] = {} 104 | 105 | for model in df["model"].unique(): 106 | # filtered by model 107 | df_d_m = df_d[df_d["model"] == model] 108 | 109 | best_model_seed = df_d_m.groupby("model_seed").mean().filter( 110 | regex='error_validation').sum(1).idxmin() 111 | 112 | # filtered by hparams 113 | df_d_m_s = df_d_m[df_d_m["model_seed"] == best_model_seed].filter( 114 | regex="error_test") 115 | 116 | # store the best hparams 117 | df_d_m_s_h = df_d_m[df_d_m["model_seed"] == best_model_seed].filter( 118 | regex="hparams") 119 | table_hparams[dataset][model] = json.dumps( 120 | df_d_m_s_h['hparams'].iloc[0]) 121 | 
122 | table_val[dataset][model] = {} 123 | for env in range(len(envs)): 124 | errors = df_d_m_s[["error_test_E" + str(env)]] 125 | std = float(errors.std(ddof=0)) 126 | se = std / (len(errors) ** 0.5) 127 | fmt_str = "{:.2f} {} {:.2f}".format( 128 | float(errors.mean()), pm, std) 129 | if standard_error: 130 | fmt_str += " {} {:.1f}".format( 131 | '/', se) 132 | 133 | dataset_env = dataset + ".E" + str(env) 134 | if dataset_env not in table: 135 | table[dataset_env] = {} 136 | 137 | table[dataset_env][model] = fmt_str 138 | table_val[dataset][model][env] = { 139 | "mean": float(errors.mean()), 140 | "std": float(errors.std(ddof=0)) 141 | } 142 | 143 | # Avg 144 | if dataset not in table_avg: 145 | table_avg[dataset] = {} 146 | table_test_errors = df_d_m_s[["error_test_E" + 147 | str(env) for env in range(len(envs))]] 148 | mean = table_test_errors.mean(axis=0).mean(axis=0) 149 | std = table_test_errors.std(axis=0, ddof=0).mean(axis=0) 150 | table_avg[dataset][model] = f"{float(mean):.2f} {pm} {float(std):.2f}" 151 | table_val_avg["data"][dataset][model] = { 152 | "mean": float(mean), 153 | "std": float(std), 154 | "hparams": table_hparams[dataset][model] 155 | } 156 | 157 | return table, table_avg, table_hparams, table_val, table_val_avg, df 158 | 159 | 160 | if __name__ == "__main__": 161 | parser = argparse.ArgumentParser() 162 | parser.add_argument("dirname") 163 | parser.add_argument("--latex", action="store_true") 164 | parser.add_argument('--models', nargs='+', default=None) 165 | parser.add_argument('--num_dim', type=int, default=None) 166 | parser.add_argument('--n_envs', type=int, default=None) 167 | args = parser.parse_args() 168 | 169 | table, table_avg, table_hparams, table_val, table_val_avg, df = build_table( 170 | args.dirname, args.models, args.n_envs, args.num_dim, args.latex) 171 | 172 | # Print table and averaged table 173 | print_table(table, latex=args.latex) 174 | print_table(table_avg, latex=args.latex) 175 | 176 | # Print best hparams 177 | print_table_hparams(table_hparams) 178 | 179 | # Plot results 180 | commit = args.dirname.split('/')[-2] 181 | plot_results.plot_table( 182 | table=table_val, 183 | dirname=args.dirname, 184 | file_name='results_' + commit) 185 | plot_results.plot_table_avg( 186 | table=table_val_avg, 187 | dirname=args.dirname, 188 | file_name='results_avg_' + commit) 189 | -------------------------------------------------------------------------------- /scripts/datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates.
All Rights Reserved 2 | 3 | import numpy as np 4 | import torch 5 | import math 6 | 7 | 8 | class Example1: 9 | """ 10 | Cause and effect of a target with heteroskedastic noise 11 | """ 12 | 13 | def __init__(self, dim_inv, dim_spu, n_envs): 14 | self.scramble = torch.eye(dim_inv + dim_spu) 15 | self.dim_inv = dim_inv 16 | self.dim_spu = dim_spu 17 | self.dim = dim_inv + dim_spu 18 | 19 | self.task = "regression" 20 | self.envs = {} 21 | 22 | if n_envs >= 2: 23 | self.envs = {'E0': 0.1, 'E1': 1.5} 24 | if n_envs >= 3: 25 | self.envs["E2"] = 2 26 | if n_envs > 3: 27 | for env in range(3, n_envs): 28 | var = 10 ** torch.zeros(1).uniform_(-2, 1).item() 29 | self.envs["E" + str(env)] = var 30 | 31 | self.wxy = torch.randn(self.dim_inv, self.dim_inv) / self.dim_inv 32 | self.wyz = torch.randn(self.dim_inv, self.dim_spu) / self.dim_spu 33 | 34 | def sample(self, n=1000, env="E0", split="train"): 35 | sdv = self.envs[env] 36 | x = torch.randn(n, self.dim_inv) * sdv 37 | y = x @ self.wxy + torch.randn(n, self.dim_inv) * sdv 38 | z = y @ self.wyz + torch.randn(n, self.dim_spu) 39 | 40 | if split == "test": 41 | z = z[torch.randperm(len(z))] 42 | 43 | inputs = torch.cat((x, z), -1) @ self.scramble 44 | outputs = y.sum(1, keepdim=True) 45 | 46 | return inputs, outputs 47 | 48 | 49 | class Example2: 50 | """ 51 | Cows and camels 52 | """ 53 | 54 | def __init__(self, dim_inv, dim_spu, n_envs): 55 | self.scramble = torch.eye(dim_inv + dim_spu) 56 | self.dim_inv = dim_inv 57 | self.dim_spu = dim_spu 58 | self.dim = dim_inv + dim_spu 59 | 60 | self.task = "classification" 61 | self.envs = {} 62 | 63 | if n_envs >= 2: 64 | self.envs = { 65 | 'E0': {"p": 0.95, "s": 0.3}, 66 | 'E1': {"p": 0.97, "s": 0.5} 67 | } 68 | if n_envs >= 3: 69 | self.envs["E2"] = {"p": 0.99, "s": 0.7} 70 | if n_envs > 3: 71 | for env in range(3, n_envs): 72 | self.envs["E" + str(env)] = { 73 | "p": torch.zeros(1).uniform_(0.9, 1).item(), 74 | "s": torch.zeros(1).uniform_(0.3, 0.7).item() 75 | } 76 | 77 | # foreground is 100x noisier than background 78 | self.snr_fg = 1e-2 79 | self.snr_bg = 1 80 | 81 | # foreground (fg) denotes animal (cow / camel) 82 | cow = torch.ones(1, self.dim_inv) 83 | self.avg_fg = torch.cat((cow, cow, -cow, -cow)) 84 | 85 | # background (bg) denotes context (grass / sand) 86 | grass = torch.ones(1, self.dim_spu) 87 | self.avg_bg = torch.cat((grass, -grass, -grass, grass)) 88 | 89 | def sample(self, n=1000, env="E0", split="train"): 90 | p = self.envs[env]["p"] 91 | s = self.envs[env]["s"] 92 | w = torch.Tensor([p, 1 - p] * 2) * torch.Tensor([s] * 2 + [1 - s] * 2) 93 | i = torch.multinomial(w, n, True) 94 | x = torch.cat(( 95 | (torch.randn(n, self.dim_inv) / 96 | math.sqrt(10) + self.avg_fg[i]) * self.snr_fg, 97 | (torch.randn(n, self.dim_spu) / 98 | math.sqrt(10) + self.avg_bg[i]) * self.snr_bg), -1) 99 | 100 | if split == "test": 101 | x[:, self.dim_inv:] = x[torch.randperm(len(x)), self.dim_inv:] 102 | 103 | inputs = x @ self.scramble 104 | outputs = x[:, :self.dim_inv].sum(1, keepdim=True).gt(0).float() 105 | 106 | return inputs, outputs 107 | 108 | 109 | class Example3: 110 | """ 111 | Small invariant margin versus large spurious margin 112 | """ 113 | 114 | def __init__(self, dim_inv, dim_spu, n_envs): 115 | self.scramble = torch.eye(dim_inv + dim_spu) 116 | self.dim_inv = dim_inv 117 | self.dim_spu = dim_spu 118 | self.dim = dim_inv + dim_spu 119 | 120 | self.task = "classification" 121 | self.envs = {} 122 | 123 | for env in range(n_envs): 124 | self.envs["E" + str(env)] = torch.randn(1, dim_spu) 125 
| 126 | def sample(self, n=1000, env="E0", split="train"): 127 | m = n // 2 128 | sep = .1 129 | 130 | invariant_0 = torch.randn(m, self.dim_inv) * .1 + \ 131 | torch.Tensor([[sep] * self.dim_inv]) 132 | invariant_1 = torch.randn(m, self.dim_inv) * .1 - \ 133 | torch.Tensor([[sep] * self.dim_inv]) 134 | 135 | shortcuts_0 = torch.randn(m, self.dim_spu) * .1 + self.envs[env] 136 | shortcuts_1 = torch.randn(m, self.dim_spu) * .1 - self.envs[env] 137 | 138 | x = torch.cat((torch.cat((invariant_0, shortcuts_0), -1), 139 | torch.cat((invariant_1, shortcuts_1), -1))) 140 | 141 | if split == "test": 142 | x[:, self.dim_inv:] = x[torch.randperm(len(x)), self.dim_inv:] 143 | 144 | inputs = x @ self.scramble 145 | outputs = torch.cat((torch.zeros(m, 1), torch.ones(m, 1))) 146 | 147 | return inputs, outputs 148 | 149 | 150 | class Example1s(Example1): 151 | def __init__(self, dim_inv, dim_spu, n_envs): 152 | super().__init__(dim_inv, dim_spu, n_envs) 153 | 154 | self.scramble, _ = torch.qr(torch.randn(self.dim, self.dim)) 155 | 156 | 157 | class Example2s(Example2): 158 | def __init__(self, dim_inv, dim_spu, n_envs): 159 | super().__init__(dim_inv, dim_spu, n_envs) 160 | 161 | self.scramble, _ = torch.qr(torch.randn(self.dim, self.dim)) 162 | 163 | 164 | class Example3s(Example3): 165 | def __init__(self, dim_inv, dim_spu, n_envs): 166 | super().__init__(dim_inv, dim_spu, n_envs) 167 | 168 | self.scramble, _ = torch.qr(torch.randn(self.dim, self.dim)) 169 | 170 | 171 | DATASETS = { 172 | "Example1": Example1, 173 | "Example2": Example2, 174 | "Example3": Example3, 175 | "Example1s": Example1s, 176 | "Example2s": Example2s, 177 | "Example3s": Example3s 178 | } 179 | -------------------------------------------------------------------------------- /scripts/main.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | 3 | import argparse 4 | import hashlib 5 | import pprint 6 | import json 7 | import git 8 | import os 9 | import datasets 10 | import models 11 | import utils 12 | 13 | 14 | def run_experiment(args): 15 | # build directory name 16 | commit = git.Repo(search_parent_directories=True).head.object.hexsha[:10] 17 | results_dirname = os.path.join(args["output_dir"], commit + "/") 18 | os.makedirs(results_dirname, exist_ok=True) 19 | 20 | # build file name 21 | md5_fname = hashlib.md5(str(args).encode('utf-8')).hexdigest() 22 | results_fname = os.path.join(results_dirname, md5_fname + ".jsonl") 23 | results_file = open(results_fname, "w") 24 | 25 | utils.set_seed(args["data_seed"]) 26 | dataset = datasets.DATASETS[args["dataset"]]( 27 | dim_inv=args["dim_inv"], 28 | dim_spu=args["dim_spu"], 29 | n_envs=args["n_envs"] 30 | ) 31 | 32 | # Oracle trained on test mode (scrambled) 33 | train_split = "train" if args["model"] != "Oracle" else "test" 34 | 35 | # sample the envs 36 | envs = {} 37 | for key_split, split in zip(("train", "validation", "test"), 38 | (train_split, train_split, "test")): 39 | envs[key_split] = {"keys": [], "envs": []} 40 | for env in dataset.envs: 41 | envs[key_split]["envs"].append(dataset.sample( 42 | n=args["num_samples"], 43 | env=env, 44 | split=split) 45 | ) 46 | envs[key_split]["keys"].append(env) 47 | 48 | # offsetting model seed to avoid overlap with data_seed 49 | utils.set_seed(args["model_seed"] + 1000) 50 | 51 | # selecting model 52 | args["num_dim"] = args["dim_inv"] + args["dim_spu"] 53 | model = models.MODELS[args["model"]]( 54 | in_features=args["num_dim"], 55 | out_features=1, 56 | task=dataset.task, 57 | hparams=args["hparams"] 58 | ) 59 | 60 | # update this field for printing purposes 61 | args["hparams"] = model.hparams 62 | 63 | # fit the dataset 64 | model.fit( 65 | envs=envs, 66 | num_iterations=args["num_iterations"], 67 | callback=args["callback"]) 68 | 69 | # compute the train, validation and test errors 70 | for split in ("train", "validation", "test"): 71 | key = "error_" + split 72 | for k_env, env in zip(envs[split]["keys"], envs[split]["envs"]): 73 | args[key + "_" + 74 | k_env] = utils.compute_error(model, *env) 75 | 76 | # write results 77 | results_file.write(json.dumps(args)) 78 | results_file.close() 79 | return args 80 | 81 | 82 | if __name__ == "__main__": 83 | parser = argparse.ArgumentParser(description='Synthetic invariances') 84 | parser.add_argument('--model', type=str, default="ERM") 85 | parser.add_argument('--num_iterations', type=int, default=10000) 86 | parser.add_argument('--hparams', type=str, default="default") 87 | parser.add_argument('--dataset', type=str, default="Example1") 88 | parser.add_argument('--dim_inv', type=int, default=5) 89 | parser.add_argument('--dim_spu', type=int, default=5) 90 | parser.add_argument('--n_envs', type=int, default=3) 91 | parser.add_argument('--num_samples', type=int, default=10000) 92 | parser.add_argument('--data_seed', type=int, default=0) 93 | parser.add_argument('--model_seed', type=int, default=0) 94 | parser.add_argument('--output_dir', type=str, default="results") 95 | parser.add_argument('--callback', action='store_true') 96 | args = parser.parse_args() 97 | 98 | pprint.pprint(run_experiment(vars(args))) 99 | -------------------------------------------------------------------------------- /scripts/models.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | 3 | import torch 4 | import json 5 | import random 6 | import numpy as np 7 | import utils 8 | from torch.autograd import grad 9 | 10 | 11 | class Model(torch.nn.Module): 12 | def __init__(self, in_features, out_features, task, hparams="default"): 13 | super().__init__() 14 | self.in_features = in_features 15 | self.out_features = out_features 16 | self.task = task 17 | 18 | # network architecture 19 | self.network = torch.nn.Linear(in_features, out_features) 20 | 21 | # loss 22 | if self.task == "regression": 23 | self.loss = torch.nn.MSELoss() 24 | else: 25 | self.loss = torch.nn.BCEWithLogitsLoss() 26 | 27 | # hyper-parameters 28 | if hparams == "default": 29 | self.hparams = {k: v[0] for k, v in self.HPARAMS.items()} 30 | elif hparams == "random": 31 | self.hparams = {k: v[1] for k, v in self.HPARAMS.items()} 32 | else: 33 | self.hparams = json.loads(hparams) 34 | 35 | # callbacks 36 | self.callbacks = {} 37 | for key in ["errors"]: 38 | self.callbacks[key] = { 39 | "train": [], 40 | "validation": [], 41 | "test": [] 42 | } 43 | 44 | 45 | class ERM(Model): 46 | def __init__(self, in_features, out_features, task, hparams="default"): 47 | self.HPARAMS = {} 48 | self.HPARAMS["lr"] = (1e-3, 10**random.uniform(-4, -2)) 49 | self.HPARAMS['wd'] = (0., 10**random.uniform(-6, -2)) 50 | 51 | super().__init__(in_features, out_features, task, hparams) 52 | 53 | self.optimizer = torch.optim.Adam( 54 | self.network.parameters(), 55 | lr=self.hparams["lr"], 56 | weight_decay=self.hparams["wd"]) 57 | 58 | def fit(self, envs, num_iterations, callback=False): 59 | x = torch.cat([xe for xe, ye in envs["train"]["envs"]]) 60 | y = torch.cat([ye for xe, ye in envs["train"]["envs"]]) 61 | 62 | for epoch in range(num_iterations): 63 | self.optimizer.zero_grad() 64 | self.loss(self.network(x), y).backward() 65 | self.optimizer.step() 66 | 67 | if callback: 68 | # compute errors 69 | utils.compute_errors(self, envs) 70 | 71 | def predict(self, x): 72 | return self.network(x) 73 | 74 | 75 | class IRM(Model): 76 | """ 77 | Abstract class for IRM 78 | """ 79 | 80 | def __init__( 81 | self, in_features, out_features, task, hparams="default", version=1): 82 | self.HPARAMS = {} 83 | self.HPARAMS["lr"] = (1e-3, 10**random.uniform(-4, -2)) 84 | self.HPARAMS['wd'] = (0., 10**random.uniform(-6, -2)) 85 | self.HPARAMS['irm_lambda'] = (0.9, 1 - 10**random.uniform(-3, -.3)) 86 | 87 | super().__init__(in_features, out_features, task, hparams) 88 | self.version = version 89 | 90 | self.network = self.IRMLayer(self.network) 91 | self.net_parameters, self.net_dummies = self.find_parameters( 92 | self.network) 93 | 94 | self.optimizer = torch.optim.Adam( 95 | self.net_parameters, 96 | lr=self.hparams["lr"], 97 | weight_decay=self.hparams["wd"]) 98 | 99 | def find_parameters(self, network): 100 | """ 101 | Alternative to network.parameters() to separate real parameters 102 | from dummies. 103 | """ 104 | parameters = [] 105 | dummies = [] 106 | 107 | for name, param in network.named_parameters(): 108 | if "dummy" in name: 109 | dummies.append(param) 110 | else: 111 | parameters.append(param) 112 | return parameters, dummies 113 | 114 | class IRMLayer(torch.nn.Module): 115 | """ 116 | Add a "multiply by one and sum zero" dummy operation to 117 | any layer. Then you can take gradients with respect to these 118 | dummies. Often applied to Linear and Conv2d layers.
119 | """ 120 | 121 | def __init__(self, layer): 122 | super().__init__() 123 | self.layer = layer 124 | self.dummy_mul = torch.nn.Parameter(torch.Tensor([1.0])) 125 | self.dummy_sum = torch.nn.Parameter(torch.Tensor([0.0])) 126 | 127 | def forward(self, x): 128 | return self.layer(x) * self.dummy_mul + self.dummy_sum 129 | 130 | def fit(self, envs, num_iterations, callback=False): 131 | for epoch in range(num_iterations): 132 | losses_env = [] 133 | gradients_env = [] 134 | for x, y in envs["train"]["envs"]: 135 | losses_env.append(self.loss(self.network(x), y)) 136 | gradients_env.append(grad( 137 | losses_env[-1], self.net_dummies, create_graph=True)) 138 | 139 | # Average loss across envs 140 | losses_avg = sum(losses_env) / len(losses_env) 141 | gradients_avg = grad( 142 | losses_avg, self.net_dummies, create_graph=True) 143 | 144 | penalty = 0 145 | for gradients_this_env in gradients_env: 146 | for g_env, g_avg in zip(gradients_this_env, gradients_avg): 147 | if self.version == 1: 148 | penalty += g_env.pow(2).sum() 149 | else: 150 | raise NotImplementedError 151 | 152 | obj = (1 - self.hparams["irm_lambda"]) * losses_avg 153 | obj += self.hparams["irm_lambda"] * penalty 154 | 155 | self.optimizer.zero_grad() 156 | obj.backward() 157 | self.optimizer.step() 158 | 159 | if callback: 160 | # compute errors 161 | utils.compute_errors(self, envs) 162 | 163 | def predict(self, x): 164 | return self.network(x) 165 | 166 | 167 | class IRMv1(IRM): 168 | """ 169 | IRMv1 with penalty \sum_e \| \nabla_{w|w=1} \mR_e (\Phi \circ \vec{w}) \|_2^2 170 | From https://arxiv.org/abs/1907.02893v1 171 | """ 172 | 173 | def __init__(self, in_features, out_features, task, hparams="default"): 174 | super().__init__(in_features, out_features, task, hparams, version=1) 175 | 176 | 177 | class AndMask(Model): 178 | """ 179 | AndMask: Masks the gradients of features for which 180 | the gradient signs across envs disagree more than 'tau' 181 | From https://arxiv.org/abs/2009.00329 182 | """ 183 | 184 | def __init__(self, in_features, out_features, task, hparams="default"): 185 | self.HPARAMS = {} 186 | self.HPARAMS["lr"] = (1e-3, 10**random.uniform(-4, 0)) 187 | self.HPARAMS['wd'] = (0., 10**random.uniform(-5, 0)) 188 | self.HPARAMS["tau"] = (0.9, random.uniform(0.8, 1)) 189 | super().__init__(in_features, out_features, task, hparams) 190 | 191 | def fit(self, envs, num_iterations, callback=False): 192 | for epoch in range(num_iterations): 193 | losses = [self.loss(self.network(x), y) 194 | for x, y in envs["train"]["envs"]] 195 | self.mask_step( 196 | losses, list(self.parameters()), 197 | tau=self.hparams["tau"], 198 | wd=self.hparams["wd"], 199 | lr=self.hparams["lr"] 200 | ) 201 | 202 | if callback: 203 | # compute errors 204 | utils.compute_errors(self, envs) 205 | 206 | def predict(self, x): 207 | return self.network(x) 208 | 209 | def mask_step(self, losses, parameters, tau=0.9, wd=0.1, lr=1e-3): 210 | with torch.no_grad(): 211 | gradients = [] 212 | for loss in losses: 213 | gradients.append(list(torch.autograd.grad(loss, parameters))) 214 | gradients[-1][0] = gradients[-1][0] / gradients[-1][0].norm() 215 | 216 | for ge_all, parameter in zip(zip(*gradients), parameters): 217 | # environment-wise gradients (num_environments x num_parameters) 218 | ge_cat = torch.cat(ge_all) 219 | 220 | # treat scalar parameters also as matrices 221 | if ge_cat.dim() == 1: 222 | ge_cat = ge_cat.view(len(losses), -1) 223 | 224 | # creates a mask with zeros on weak features 225 | mask = (torch.abs(torch.sign(ge_cat).sum(0))
226 | > len(losses) * tau).int() 227 | 228 | # mean gradient (1 x num_parameters) 229 | g_mean = ge_cat.mean(0, keepdim=True) 230 | 231 | # apply the mask 232 | g_masked = mask * g_mean 233 | 234 | # update 235 | parameter.data = parameter.data - lr * g_masked \ 236 | - lr * wd * parameter.data 237 | 238 | 239 | class IGA(Model): 240 | """ 241 | Inter-environmental Gradient Alignment 242 | From https://arxiv.org/abs/2008.01883v2 243 | """ 244 | 245 | def __init__(self, in_features, out_features, task, hparams="default"): 246 | self.HPARAMS = {} 247 | self.HPARAMS["lr"] = (1e-3, 10**random.uniform(-4, -2)) 248 | self.HPARAMS['wd'] = (0., 10**random.uniform(-6, -2)) 249 | self.HPARAMS['penalty'] = (1000, 10**random.uniform(1, 5)) 250 | super().__init__(in_features, out_features, task, hparams) 251 | 252 | self.optimizer = torch.optim.Adam( 253 | self.parameters(), 254 | lr=self.hparams["lr"], 255 | weight_decay=self.hparams["wd"]) 256 | 257 | def fit(self, envs, num_iterations, callback=False): 258 | for epoch in range(num_iterations): 259 | losses = [self.loss(self.network(x), y) 260 | for x, y in envs["train"]["envs"]] 261 | gradients = [ 262 | grad(loss, self.parameters(), create_graph=True) 263 | for loss in losses 264 | ] 265 | # average loss and gradients 266 | avg_loss = sum(losses) / len(losses) 267 | avg_gradient = grad(avg_loss, self.parameters(), create_graph=True) 268 | 269 | # compute trace penalty 270 | penalty_value = 0 271 | for gradient in gradients: 272 | for gradient_i, avg_grad_i in zip(gradient, avg_gradient): 273 | penalty_value += (gradient_i - avg_grad_i).pow(2).sum() 274 | 275 | self.optimizer.zero_grad() 276 | (avg_loss + self.hparams['penalty'] * penalty_value).backward() 277 | self.optimizer.step() 278 | 279 | if callback: 280 | # compute errors 281 | utils.compute_errors(self, envs) 282 | 283 | def predict(self, x): 284 | return self.network(x) 285 | 286 | 287 | MODELS = { 288 | "ERM": ERM, 289 | "IRMv1": IRMv1, 290 | "ANDMask": AndMask, 291 | "IGA": IGA, 292 | "Oracle": ERM 293 | } 294 | -------------------------------------------------------------------------------- /scripts/plot_results.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | 3 | import pandas as pd 4 | import glob 5 | import os 6 | import json 7 | import argparse 8 | import matplotlib.pyplot as plt 9 | import collect_results 10 | import numpy as np 11 | import torch 12 | 13 | plt.rc('text', usetex=True) 14 | plt.rc('text.latex', preamble=r'\usepackage{times,amsmath}') 15 | plt.rc('font', family='serif') 16 | plt.rc('font', size=14) 17 | 18 | 19 | def plot_table(table, dirname, file_name, save=True, block=False, fontsize=12): 20 | fig, axs = plt.subplots(1, 6, figsize=(13, 2.1)) 21 | axs = axs.flat 22 | width = None 23 | for id_d, dataset in zip(range(len(table.keys())), sorted(table.keys())): 24 | models = table[dataset] 25 | envs = models[list(models.keys())[0]].keys() 26 | if not width: 27 | width = 1 / (len(envs) + 1) 28 | legends = [] 29 | for id_e, env in zip(range(len(envs)), envs): 30 | labels = sorted(models.keys()) 31 | pos = np.arange(len(labels)) 32 | model_means = [models[model][env]['mean'] 33 | for model in sorted(models.keys())] 34 | model_stds = [models[model][env]['std'] 35 | for model in sorted(models.keys())] 36 | l = axs[id_d].bar(pos + id_e * width, model_means, 37 | width=width, color=f'C{id_e}', label=f'E{env}', 38 | align='center', ecolor=f'black', capsize=3, yerr=model_stds, 39 | ) 40 | legends.append(l) 41 | 42 | axs[id_d].set_title(dataset) 43 | axs[id_d].set_xticks(pos + width * (len(envs) / 2 - 0.5)) 44 | axs[id_d].set_xticklabels(labels, fontsize=7) 45 | axs[id_d].set_ylim(bottom=0) 46 | 47 | 48 | axs[0].set_ylabel('Test error') 49 | plt.tight_layout(pad=0) 50 | plt.subplots_adjust(wspace=0.3, hspace=0.3) 51 | plt.legend(handles=legends, 52 | ncol=6, 53 | loc="lower center", 54 | bbox_to_anchor=(-2.8, -0.4)) 55 | 56 | if save: 57 | fig_dirname = "figs/" 58 | os.makedirs(fig_dirname, exist_ok=True) 59 | models = '_'.join(sorted(models.keys())) 60 | plt.savefig(fig_dirname + file_name + '_' + models +'.pdf', 61 | format='pdf', bbox_inches='tight') 62 | 63 | if block: 64 | plt.show(block=False) 65 | input('Press to close') 66 | plt.close('all') 67 | 68 | 69 | def plot_table_avg(table, dirname, file_name, save=True, block=False, fontsize=12): 70 | table = table["data"] 71 | 72 | fig, axs = plt.subplots(1, 6, figsize=(13, 2.1)) 73 | axs = axs.flat 74 | width = 0.5 75 | for id_d, dataset in zip(range(len(table.keys())), sorted(table.keys())): 76 | models = table[dataset] 77 | labels = sorted(models.keys()) 78 | pos = np.arange(len(labels)) 79 | model_means = [models[model]['mean'] 80 | for model in sorted(models.keys())] 81 | model_stds = [models[model]['std'] 82 | for model in sorted(models.keys())] 83 | legends = [] 84 | for id_m in range(len(pos)): 85 | l, = axs[id_d].bar(pos[id_m], model_means[id_m], 86 | width=width, color=f'C{id_m}', 87 | align='center', ecolor='black', 88 | capsize=7, yerr=model_stds[id_m], linewidth=0.1 89 | ) 90 | legends.append(labels[id_m]) 91 | 92 | axs[id_d].set_title(dataset) 93 | axs[id_d].set_xticks(pos) 94 | axs[id_d].set_ylim(bottom=0) 95 | 96 | axs[0].set_ylabel('Test error') 97 | plt.tight_layout(pad=0) 98 | plt.subplots_adjust(wspace=0.3, hspace=0.3) 99 | plt.legend(legends, 100 | ncol=6, 101 | loc="lower center", 102 | bbox_to_anchor=(-2.8, -0.5)) 103 | 104 | if save: 105 | fig_dirname = "figs/" 106 | os.makedirs(fig_dirname, exist_ok=True) 107 | models = '_'.join(sorted(models.keys())) 108 | plt.savefig(fig_dirname + file_name + '_' + models +'.pdf', 109 | format='pdf', bbox_inches='tight') 110 | if block: 111 | plt.show(block=False) 112 | input('Press to close') 113 | 
plt.close('all') 114 | 115 | 116 | def plot_nenvs_dimspu(df_nenvs, df_dimspu, dirname, file_name='', save=True, block=False): 117 | fig, axs = plt.subplots(2, 6, figsize=(13, 5)) 118 | axs = axs.flat 119 | counter = 0 120 | # Top: nenvs 121 | datasets = df_nenvs["dataset"].unique() 122 | for id_d, dataset in zip(range(len(datasets)), sorted(datasets)): 123 | df_d = df_nenvs[df_nenvs["dataset"] == dataset] 124 | models = df_d["model"].unique() 125 | legends = [] 126 | for id_m, model in zip(range(len(models)), sorted(models)): 127 | df_d_m = df_d[df_d["model"] == model].sort_values(by="n_envs") 128 | legend, = axs[id_d].plot(df_d_m["n_envs"]/5, df_d_m["mean"], 129 | color=f'C{id_m}', 130 | label=model, 131 | linewidth=2) 132 | top = (df_d_m["mean"]+df_d_m["std"]/2).to_numpy() 133 | bottom = (df_d_m["mean"]-df_d_m["std"]/2).to_numpy() 134 | xs = np.arange(2, 11) / 5 135 | axs[id_d].fill_between(xs, bottom, top, facecolor=f'C{id_m}', alpha=0.2) 136 | legends.append(legend) 137 | 138 | axs[id_d].set_xlabel(r'$\delta_{\rm env}$') 139 | axs[id_d].set_title(dataset) 140 | axs[id_d].set_ylim(bottom=-0.005) 141 | axs[id_d].set_xlim(left=0.4, right=2) 142 | counter += 1 143 | 144 | # Bottom: dimspu 145 | datasets = df_dimspu["dataset"].unique() 146 | for id_d, dataset in zip(range(counter, counter+len(datasets)), sorted(datasets)): 147 | df_d = df_dimspu[df_dimspu["dataset"] == dataset] 148 | models = df_d["model"].unique() 149 | legends = [] 150 | for id_m, model in zip(range(len(models)), sorted(models)): 151 | df_d_m = df_d[df_d["model"] == model].sort_values(by="dim_spu") 152 | legend, = axs[id_d].plot(df_d_m["dim_spu"]/5, df_d_m["mean"], 153 | color=f'C{id_m}', 154 | label=model, 155 | linewidth=2) 156 | top = (df_d_m["mean"]+df_d_m["std"]/2).to_numpy() 157 | bottom = (df_d_m["mean"]-df_d_m["std"]/2).to_numpy() 158 | xs = np.arange(0, 11) / 5 159 | axs[id_d].fill_between(xs, bottom, top, facecolor=f'C{id_m}', alpha=0.2) 160 | legends.append(legend) 161 | 162 | axs[id_d].set_xlabel(r'$\delta_{\rm spu}$') 163 | axs[id_d].set_title(dataset) 164 | axs[id_d].set_ylim(bottom=-0.005) 165 | axs[id_d].set_xlim(left=0, right=2) 166 | 167 | 168 | axs[0].set_ylabel("Test error") 169 | axs[6].set_ylabel("Test error") 170 | plt.tight_layout(pad=0) 171 | plt.legend(handles=legends, 172 | ncol=6, 173 | loc="lower center", 174 | bbox_to_anchor=(-2.8, -0.7)) 175 | 176 | if save: 177 | fig_dirname = "figs/" + dirname 178 | os.makedirs(fig_dirname, exist_ok=True) 179 | models = '_'.join(models) 180 | plt.savefig(fig_dirname + file_name + '.pdf', 181 | format='pdf', bbox_inches='tight') 182 | if block: 183 | plt.show(block=False) 184 | input('Press to close') 185 | plt.close('all') 186 | 187 | 188 | def build_df(dirname): 189 | df = pd.DataFrame(columns=['n_envs', 'dim_inv', 'dim_spu', 'dataset', 'model', 'mean', 'std']) 190 | for filename in glob.glob(os.path.join(dirname, "*.jsonl")): 191 | with open(filename) as f: 192 | dic = json.load(f) 193 | n_envs = dic["n_envs"] 194 | dim_inv = dic["dim_inv"] 195 | dim_spu = dic["dim_spu"] 196 | for dataset in dic["data"].keys(): 197 | single_dic = {} 198 | for model in dic["data"][dataset].keys(): 199 | mean = dic["data"][dataset][model]["mean"] 200 | std = dic["data"][dataset][model]["std"] 201 | single_dic = dict( 202 | n_envs=n_envs, 203 | dim_inv=dim_inv, 204 | dim_spu=dim_spu, 205 | dataset=dataset, 206 | model=model, 207 | mean=mean, 208 | std=std 209 | ) 210 | # print(single_dic) 211 | df = df.append(single_dic, ignore_index=True) 212 | 213 | return df 214 | 215 | 
216 | def process_results(dirname, commit, save_dirname): 217 | subdirs = [os.path.join(dirname, subdir, commit + '/') for subdir in os.listdir(dirname) if os.path.isdir(os.path.join(dirname, subdir))] 218 | for subdir in subdirs: 219 | print(subdir) 220 | table, table_avg, table_hparams, table_val, table_val_avg, df = collect_results.build_table(subdir) 221 | 222 | # plot table_val 223 | plot_table( 224 | table=table_val, 225 | dirname=subdir, 226 | file_name='_'.join(subdir.split('/')[-3:-1]), 227 | save=True, block=False) 228 | # save table_val 229 | save_dirname_single = save_dirname + "single/" 230 | os.makedirs(save_dirname_single, exist_ok=True) 231 | results_filename = os.path.join(save_dirname_single, 'single_' + '_'.join(subdir.split('/')[-4:-1]) + ".jsonl") 232 | results_file = open(results_filename, "w") 233 | results_file.write(json.dumps(table_val)) 234 | results_file.close() 235 | 236 | # plot table_val_avg 237 | plot_table_avg( 238 | table=table_val_avg, 239 | dirname=subdir, 240 | file_name='avg_' + '_'.join(subdir.split('/')[-3:-1]), 241 | save=True, block=False) 242 | # save table_val_avg 243 | save_dirname_avg = save_dirname + "avg/" 244 | os.makedirs(save_dirname_avg, exist_ok=True) 245 | results_filename = os.path.join(save_dirname_avg, 'avg_' + '_'.join(subdir.split('/')[-4:-1]) + ".jsonl") 246 | results_file = open(results_filename, "w") 247 | results_file.write(json.dumps(table_val_avg)) 248 | results_file.close() 249 | 250 | 251 | if __name__ == "__main__": 252 | parser = argparse.ArgumentParser() 253 | parser.add_argument("-dirname") 254 | parser.add_argument("-commit") 255 | parser.add_argument('--load', action='store_true') 256 | args = parser.parse_args() 257 | 258 | dirname_nenvs = "results_processed/nenvs/" + args.commit + "/" 259 | dirname_dimspu = "results_processed/dimspu/" + args.commit + "/" 260 | 261 | # construct averaged data 262 | if not args.load: 263 | process_results(dirname=args.dirname + "nenvs/", commit=args.commit, save_dirname=dirname_nenvs) 264 | process_results(dirname=args.dirname + "dimspu/", commit=args.commit, save_dirname=dirname_dimspu) 265 | 266 | # plot results for different number of envs 267 | df_nenvs = build_df(dirname_nenvs + "avg/") 268 | df_dimspu = build_df(dirname_dimspu + "avg/") 269 | 270 | plot_nenvs_dimspu( 271 | df_nenvs=df_nenvs, 272 | df_dimspu=df_dimspu, 273 | dirname= args.dirname.split('/')[-1], 274 | file_name= 'results_nenvs_dimspu_' + args.commit, 275 | save=True, block=False) 276 | 277 | dirname = dirname_nenvs + "avg/" 278 | file_name = "avg_nenvs_final_sweep_linear_nenvs=3_dinv=5_dspu=5_e717c2ff36" 279 | results_filename = os.path.join(dirname, file_name + ".jsonl") 280 | table_avg = json.load(open(results_filename, "r")) 281 | plot_table_avg( 282 | table=table_avg, 283 | dirname='', 284 | file_name=file_name, 285 | save=True, block=False) 286 | 287 | dirname = dirname_nenvs + "single/" 288 | file_name = "single_nenvs_final_sweep_linear_nenvs=3_dinv=5_dspu=5_e717c2ff36" 289 | results_filename = os.path.join(dirname, file_name + ".jsonl") 290 | table = json.load(open(results_filename, "r")) 291 | plot_table( 292 | table=table, 293 | dirname='', 294 | file_name=file_name, 295 | save=True, block=False) 296 | 297 | 298 | 299 | -------------------------------------------------------------------------------- /scripts/sweep.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | 3 | import main 4 | import random 5 | import models 6 | import datasets 7 | import argparse 8 | import getpass 9 | 10 | 11 | if __name__ == "__main__": 12 | parser = argparse.ArgumentParser(description='Synthetic invariances') 13 | parser.add_argument('--models', nargs='+', default=[]) 14 | parser.add_argument('--num_iterations', type=int, default=10000) 15 | parser.add_argument('--hparams', type=str, default="default") 16 | parser.add_argument('--datasets', nargs='+', default=[]) 17 | parser.add_argument('--dim_inv', type=int, default=5) 18 | parser.add_argument('--dim_spu', type=int, default=5) 19 | parser.add_argument('--n_envs', type=int, default=3) 20 | parser.add_argument('--num_samples', type=int, default=10000) 21 | parser.add_argument('--num_data_seeds', type=int, default=50) 22 | parser.add_argument('--num_model_seeds', type=int, default=20) 23 | parser.add_argument('--output_dir', type=str, default="results") 24 | parser.add_argument('--callback', action='store_true') 25 | parser.add_argument('--cluster', action="store_true") 26 | parser.add_argument('--jobs_cluster', type=int, default=512) 27 | args = vars(parser.parse_args()) 28 | 29 | try: 30 | import submitit 31 | except: 32 | args["cluster"] = False 33 | pass 34 | 35 | all_jobs = [] 36 | if len(args["models"]): 37 | model_lists = args["models"] 38 | else: 39 | model_lists = models.MODELS.keys() 40 | if len(args["datasets"]): 41 | dataset_lists = args["datasets"] 42 | else: 43 | dataset_lists = datasets.DATASETS.keys() 44 | 45 | for model in model_lists: 46 | for dataset in dataset_lists: 47 | for data_seed in range(args["num_data_seeds"]): 48 | for model_seed in range(args["num_model_seeds"]): 49 | train_args = { 50 | "model": model, 51 | "num_iterations": args["num_iterations"], 52 | "hparams": "random" if model_seed else "default", 53 | "dataset": dataset, 54 | "dim_inv": args["dim_inv"], 55 | "dim_spu": args["dim_spu"], 56 | "n_envs": args["n_envs"], 57 | "num_samples": args["num_samples"], 58 | "data_seed": data_seed, 59 | "model_seed": model_seed, 60 | "output_dir": args["output_dir"], 61 | "callback": args["callback"] 62 | } 63 | 64 | all_jobs.append(train_args) 65 | 66 | random.shuffle(all_jobs) 67 | 68 | print("Launching {} jobs...".format(len(all_jobs))) 69 | 70 | if args["cluster"]: 71 | executor = submitit.SlurmExecutor( 72 | folder=f"/checkpoint/{getpass.getuser()}/submitit/") 73 | executor.update_parameters( 74 | time=3*24*60, 75 | gpus_per_node=0, 76 | array_parallelism=args["jobs_cluster"], 77 | cpus_per_task=1, 78 | comment="", 79 | partition="learnfair") 80 | 81 | executor.map_array(main.run_experiment, all_jobs) 82 | else: 83 | for job in all_jobs: 84 | print(main.run_experiment(job)) 85 | -------------------------------------------------------------------------------- /scripts/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | 3 | import numpy as np 4 | import random 5 | import torch 6 | 7 | 8 | def set_seed(seed): 9 | torch.manual_seed(seed) 10 | np.random.seed(seed) 11 | random.seed(seed) 12 | 13 | 14 | def compute_error(algorithm, x, y): 15 | with torch.no_grad(): 16 | if len(y.unique()) == 2: 17 | return algorithm.predict(x).gt(0).ne(y).float().mean().item() 18 | else: 19 | return (algorithm.predict(x) - y).pow(2).mean().item() 20 | 21 | 22 | def compute_errors(model, envs): 23 | for split in envs.keys(): 24 | if not bool(model.callbacks["errors"][split]): 25 | model.callbacks["errors"][split] = { 26 | key: [] for key in envs[split]["keys"]} 27 | 28 | for k, env in zip(envs[split]["keys"], envs[split]["envs"]): 29 | model.callbacks["errors"][split][k].append( 30 | compute_error(model, *env)) 31 | --------------------------------------------------------------------------------
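The snippet below is a minimal end-to-end sketch of how the pieces above fit together. It is not part of the repository; it simply mirrors the flow of `scripts/main.py`, and it assumes it is run from inside `scripts/` so that `datasets.py`, `models.py` and `utils.py` import as modules. The short 500-iteration budget is only to keep the example fast; the repository default is 10000.

```python
# Minimal sketch (not repository code): train linear ERM on Example1 and print
# the per-environment test errors, following the same steps as scripts/main.py.
import datasets
import models
import utils

utils.set_seed(0)

# Example1: regression task with 5 invariant + 5 spurious dimensions, 3 environments.
dataset = datasets.DATASETS["Example1"](dim_inv=5, dim_spu=5, n_envs=3)

# Sample train/validation/test data for every environment, as main.py does.
envs = {}
for key_split, split in zip(("train", "validation", "test"),
                            ("train", "train", "test")):
    envs[key_split] = {"keys": [], "envs": []}
    for env in dataset.envs:
        envs[key_split]["envs"].append(dataset.sample(n=1000, env=env, split=split))
        envs[key_split]["keys"].append(env)

# Linear ERM model with the default hyper-parameters (lr=1e-3, wd=0).
model = models.MODELS["ERM"](in_features=dataset.dim, out_features=1,
                             task=dataset.task, hparams="default")
model.fit(envs=envs, num_iterations=500)

# Mean squared error per test environment (Example1 is a regression task).
for key, env in zip(envs["test"]["keys"], envs["test"]["envs"]):
    print(key, utils.compute_error(model, *env))
```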
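Each run of `scripts/main.py` writes a single JSON line (the run arguments plus one `error_<split>_E<env>` entry per split and environment) into `<output_dir>/<git commit>/<md5>.jsonl`, and `collect_results.build_table` aggregates those files. Below is a hypothetical, stripped-down loader in the same spirit, useful for ad-hoc inspection; the `load_results` name and the example query are assumptions for illustration, not an API of the repository.

```python
# Hypothetical helper (not repository code): gather the per-run .jsonl files
# produced by scripts/main.py into one pandas DataFrame, mirroring how
# collect_results.build_table reads them.
import glob
import os

import pandas as pd


def load_results(dirname):
    records = []
    for fname in glob.glob(os.path.join(dirname, "*.jsonl")):
        if os.path.getsize(fname) != 0:
            with open(fname, "r") as f:
                records.append(f.readline().strip())
    # Each record holds the run arguments plus keys such as "error_test_E0".
    return pd.read_json("\n".join(records), lines=True)


# Example: df = load_results("results/COMMIT")  # COMMIT = git hash used by main.py
#          print(df.groupby(["dataset", "model"])["error_test_E0"].mean())
```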