├── .gitignore ├── LICENSE ├── README.md ├── figures ├── isr-mean.png └── real_datasets.png ├── linear_unit_tests ├── README.md ├── collect_results.py ├── config.py ├── datasets.py ├── isr.py ├── launch_exp.py ├── main.py ├── models.py ├── plot_results.py ├── sweep.py └── utils.py ├── real_datasets ├── README.md ├── configs │ ├── README.md │ ├── __init__.py │ ├── eval_config.py │ ├── model_config.py │ ├── parse_config.py │ └── train_config.py ├── data │ ├── __init__.py │ ├── celebA_dataset.py │ ├── confounder_dataset.py │ ├── confounder_utils.py │ ├── cub_dataset.py │ ├── dro_dataset.py │ ├── label_shift_utils.py │ ├── multinli_dataset.py │ ├── torchvision_datasets.py │ └── utils.py ├── dataset_metadata │ └── multinli │ │ ├── metadata_preset.csv │ │ └── metadata_random.csv ├── dataset_scripts │ ├── dataset_utils.py │ ├── generate_multinli.py │ └── generate_waterbirds.py ├── eval.py ├── isr.py ├── launch_parse.py ├── launch_train.py ├── parse_features.py ├── run_expt.py ├── train.py ├── utils │ ├── eval_utils.py │ ├── loss_utils.py │ └── train_utils.py └── utils_glue.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/ 2 | .idea/ 3 | .DS_Store/ 4 | logs*/ 5 | old*/ 6 | results*/ 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | figs/ 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | pip-wheel-metadata/ 30 | share/python-wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .nox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | *.py,cover 57 | .hypothesis/ 58 | .pytest_cache/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | db.sqlite3-journal 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 101 | __pypackages__/ 102 | 103 | # Celery stuff 104 | celerybeat-schedule 105 | celerybeat.pid 106 | 107 | # SageMath parsed files 108 | *.sage.py 109 | 110 | # Environments 111 | .env 112 | .venv 113 | env/ 114 | venv/ 115 | ENV/ 116 | env.bak/ 117 | venv.bak/ 118 | 119 | # Spyder project settings 120 | .spyderproject 121 | .spyproject 122 | 123 | # Rope project settings 124 | .ropeproject 125 | 126 | # mkdocs documentation 127 | /site 128 | 129 | # mypy 130 | .mypy_cache/ 131 | .dmypy.json 132 | dmypy.json 133 | 134 | # Pyre type checker 135 | .pyre/ 136 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Haoxiang Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Invariant-feature Subspace Recovery (ISR) 2 | Code for the paper **[Provable Domain Generalization via Invariant-Feature Subspace Recovery](https://arxiv.org/abs/2201.12919)** (ICML 2022) by Haoxiang Wang, Haozhe Si, Bo Li, and Han Zhao from UIUC. 3 | 4 | If you find this repo useful for your research, please consider citing our paper: 5 | 6 | ``` 7 | @inproceedings{ISR, 8 | title = {Provable Domain Generalization via Invariant-Feature Subspace Recovery}, 9 | author = {Wang, Haoxiang and Si, Haozhe and Li, Bo and Zhao, Han}, 10 | booktitle = {International Conference on Machine Learning}, 11 | pages = {23018--23033}, 12 | year = {2022}, 13 | publisher = {PMLR}, 14 | url = {https://proceedings.mlr.press/v162/wang22x.html}, 15 | } 16 | ``` 17 | ![ISR-Mean illustration](figures/isr-mean.png) 18 | 19 | ## Installation 20 | This repo was tested with Ubuntu 20.04, Python 3.8/3.9 ([Anaconda](https://www.anaconda.com/products/individual) version), and PyTorch 1.9/1.10/1.11 with CUDA 11. The experiments on the real datasets were tested on a single GPU with 16GB of memory, but 11GB may also suffice. 21 | 22 | Our code is divided into two parts, `linear_unit_tests/` and `real_datasets/`. 23 | 24 | + `linear_unit_tests/`: This part is built on the [released code](https://github.com/facebookresearch/InvarianceUnitTests) of [Linear Unit-Tests](https://arxiv.org/abs/2102.10867). 
Please install the necessary packages following their [installation requirements](https://github.com/facebookresearch/InvarianceUnitTests). 25 | + `real_datasets/`: This part is built on the [released code](https://github.com/kohpangwei/group_DRO) of [GroupDRO](https://arxiv.org/abs/1911.08731). Please also follow their [installation requirements](https://github.com/kohpangwei/group_DRO#prerequisites). 26 | 27 | ## Datasets 28 | 29 | The synthetic datasets in `linear_unit_tests/` are generated by the code, and the three real datasets (Waterbirds, CelebA and MultiNLI) used in `real_datasets/` should be downloaded in advance following [these instructions](https://github.com/kohpangwei/group_DRO#datasets-and-code). 30 | 31 | ![Two image datasets and one text dataset.](figures/real_datasets.png "Real Datasets") 32 | 33 | ## Code 34 | 35 | ### `linear_unit_tests/` 36 | 37 | Run `python launch_exp.py` to reproduce the experiments in the paper, and use `python plot_results.py` to plot the results. The experiments run on CPU (in parallel on all CPU cores by default). 38 | 39 | ### `real_datasets/` 40 | 41 | Run experiments on the three real-world datasets: 42 | 43 | + `"CUB"`: The [Waterbirds](https://github.com/kohpangwei/group_DRO#waterbirds) dataset (a bird image dataset), formed from [Caltech-UCSD Birds 200](http://www.vision.caltech.edu/visipedia/CUB-200.html) + [Places](http://places2.csail.mit.edu/). 44 | + `"CelebA"`: The [CelebA](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) dataset (a face image dataset). 45 | + `"MultiNLI"`: The [MultiNLI](https://www.nyu.edu/projects/bowman/multinli/) dataset (a text dataset). 46 | 47 | Please see `real_datasets/README.md` for detailed instructions on running experiments. 48 | 49 | Notably, we implement our ISR algorithms in a *sklearn*-style classifier class, which can be used as follows: 50 | 51 | ```python 52 | from isr import ISRClassifier 53 | classifier = ISRClassifier(version="mean", # "mean": ISR-Mean. "cov": ISR-Cov. 54 | d_spu=1, # the number of spurious features to remove 55 | ) 56 | # xs: training samples 57 | # ys: class labels 58 | # es: environment labels 59 | classifier.fit(xs, ys, es, 60 | chosen_class=0, # need to condition on a class 61 | ) 62 | predictions = classifier.predict(test_xs) # test_xs: test samples 63 | ``` 64 | 65 | 66 | 67 | ## Acknowledgement 68 | In this repo, we adopt some code from the following codebases, and we sincerely thank their authors: 69 | + [facebookresearch/InvarianceUnitTests](https://github.com/facebookresearch/InvarianceUnitTests): Our `/linear_unit_tests` is built upon this repo. 70 | + [kohpangwei/group_DRO](https://github.com/kohpangwei/group_DRO): Our `/real_datasets` is built on this repo. 
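## Toy End-to-End Example

To make the `ISRClassifier` snippet above concrete, here is a self-contained toy sketch. The synthetic data generator and all numbers below are illustrative assumptions, not part of the repo; only the classifier interface (`version`, `d_spu`, `fit(xs, ys, es, chosen_class=...)`, `predict`) follows `real_datasets/isr.py` as shown in the Code section.

```python
import numpy as np
from isr import ISRClassifier  # assumes real_datasets/ is on PYTHONPATH

rng = np.random.default_rng(0)

def make_env(n, spurious_corr):
    # Toy data (illustrative assumption): 1 invariant feature plus 1 spurious
    # feature whose correlation with the label differs across environments.
    y = rng.integers(0, 2, size=n)                 # binary class labels
    inv = y + 0.3 * rng.standard_normal(n)         # invariant feature
    agree = rng.random(n) < spurious_corr          # env-dependent correlation
    spu = np.where(agree, y, 1 - y) + 0.3 * rng.standard_normal(n)
    return np.stack([inv, spu], axis=1), y

# Two training environments with different spurious correlations
(x0, y0), (x1, y1) = make_env(2000, 0.9), make_env(2000, 0.7)
xs = np.concatenate([x0, x1])
ys = np.concatenate([y0, y1])
es = np.concatenate([np.zeros(len(y0), dtype=int), np.ones(len(y1), dtype=int)])

clf = ISRClassifier(version="mean", d_spu=1)  # ISR-Mean, remove 1 spurious dimension
clf.fit(xs, ys, es, chosen_class=0)

# Test environment where the spurious correlation is (nearly) reversed
test_x, test_y = make_env(2000, 0.1)
preds = np.asarray(clf.predict(test_x)).reshape(-1)  # assumes sklearn-style hard labels
print(f"test accuracy: {(preds == test_y).mean():.3f}")
```

Two training environments are paired with `d_spu=1` here because ISR-Mean identifies the spurious subspace from `d_spu + 1` environments, as analyzed in the paper.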
-------------------------------------------------------------------------------- /figures/isr-mean.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uiuctml/ISR/728487c5f96c5864acf1b0dc6665954f5aef8997/figures/isr-mean.png -------------------------------------------------------------------------------- /figures/real_datasets.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uiuctml/ISR/728487c5f96c5864acf1b0dc6665954f5aef8997/figures/real_datasets.png -------------------------------------------------------------------------------- /linear_unit_tests/README.md: -------------------------------------------------------------------------------- 1 | # Experiments on Linear Unit Tests 2 | 3 | This folder is built on the [released code](https://github.com/facebookresearch/InvarianceUnitTests) 4 | of [Linear Unit-Tests](https://arxiv.org/abs/2102.10867), a work of Facebook Research. 5 | 6 | ## Preparation 7 | 8 | Before runing the experiments, please put a folder path (for saving the experiment results) in `config.py` (i.e. define the `RESULT_FOLDER` variable). 9 | 10 | ## Run Experiments 11 | 12 | Run the script 13 | 14 | ```shell 15 | python launch_exp.py 16 | ``` 17 | 18 | which trains and evaluates ISR-Mean/Cov and baseline algorithms on 6 linear benchmarks as reported in our paper. Notice that "Example3_Modified" and "Example3s_Modified" are variants we implemented (reported as Example-3' and Example-3s' in our paper). 19 | 20 | ## Plot Results 21 | 22 | Run the following script to plot the results: 23 | 24 | ```shell 25 | python plot_results.py 26 | ``` -------------------------------------------------------------------------------- /linear_unit_tests/collect_results.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import glob 3 | import os 4 | import json 5 | import argparse 6 | import plot_results 7 | 8 | 9 | def print_row(row, col_width=15, latex=False): 10 | sep = " & " if latex else " " 11 | end_ = "\\\\" if latex else "" 12 | print(sep.join([x.ljust(col_width) for x in row]), end_) 13 | 14 | 15 | def print_table(table, col_width=15, latex=False): 16 | col_names = sorted(table[next(iter(table))].keys()) 17 | 18 | print("\n") 19 | if latex: 20 | print("\\documentclass{article}") 21 | print("\\usepackage{booktabs}") 22 | print("\\usepackage{adjustbox}") 23 | print("\\begin{document}") 24 | print("\\begin{table}") 25 | print("\\begin{center}") 26 | print("\\adjustbox{max width=\\textwidth}{%") 27 | print("\\begin{tabular}{l" + "c" * len(col_names) + "}") 28 | print("\\toprule") 29 | 30 | print_row([""] + col_names, col_width, latex) 31 | 32 | if latex: 33 | print("\\midrule") 34 | 35 | for row_k, row_v in sorted(table.items()): 36 | row_values = [row_k] 37 | for col_k, col_v in sorted(row_v.items()): 38 | row_values.append(col_v) 39 | print_row(row_values, col_width, latex) 40 | 41 | if latex: 42 | print("\\bottomrule") 43 | print("\\end{tabular}}") 44 | print("\\end{center}") 45 | print("\\label{main_results}") 46 | print("\\caption{Main results.}") 47 | print("\\end{table}") 48 | print("\\end{document}") 49 | 50 | 51 | def print_table_hparams(table, col_width=15, latex=False): 52 | print("\n") 53 | for dataset in table.keys(): 54 | print(dataset, "\n") 55 | for model in table[dataset].keys(): 56 | print(model, table[dataset][model]) 57 | print("\n") 58 | 59 | 60 | def build_table(dirname, 
models=None, n_envs=None, num_dim=None, latex=False, standard_error=False): 61 | records = [] 62 | for fname in glob.glob(os.path.join(dirname, "*.jsonl")): 63 | with open(fname, "r") as f: 64 | if os.path.getsize(fname) != 0: 65 | records.append(f.readline().strip()) 66 | 67 | df = pd.read_json("\n".join(records), lines=True) 68 | if models is not None: 69 | df = df.query(f"model in {models}") 70 | if n_envs is not None: 71 | df = df.query(f"n_envs=={n_envs}") 72 | if num_dim is not None: 73 | df = df.query(f"num_dim=={num_dim}") 74 | 75 | print(f'{len(df)} records.') 76 | pm = "$\\pm$" if latex else "+-" 77 | 78 | table = {} 79 | table_avg = {} 80 | table_val = {} 81 | table_val_avg = { 82 | "data": {}, 83 | "n_envs": 0, 84 | "dim_inv": 0, 85 | "dim_spu": 0 86 | } 87 | table_hparams = {} 88 | 89 | for dataset in df["dataset"].unique(): 90 | # filtered by dataset 91 | df_d = df[df["dataset"] == dataset] 92 | envs = sorted(list(set( 93 | [c[-1] for c in df_d.filter(regex="error_").columns]))) 94 | if n_envs: 95 | envs = envs[:n_envs] 96 | 97 | table_hparams[dataset] = {} 98 | table_val[dataset] = {} 99 | for key in ["n_envs", "dim_inv", "dim_spu"]: 100 | table_val_avg[key] = int(df[key].iloc[0]) 101 | table_val_avg["data"][dataset] = {} 102 | 103 | for model in df["model"].unique(): 104 | # filtered by model 105 | df_d_m = df_d[df_d["model"] == model] 106 | 107 | best_model_seed = df_d_m.groupby("model_seed").mean().filter( 108 | regex='error_validation').sum(1).idxmin() 109 | 110 | # filtered by hparams 111 | df_d_m_s = df_d_m[df_d_m["model_seed"] == best_model_seed].filter( 112 | regex="error_test") 113 | 114 | # store the best hparams 115 | df_d_m_s_h = df_d_m[df_d_m["model_seed"] == best_model_seed].filter( 116 | regex="hparams") 117 | table_hparams[dataset][model] = json.dumps( 118 | df_d_m_s_h['hparams'].iloc[0]) 119 | 120 | table_val[dataset][model] = {} 121 | for env in range(len(envs)): 122 | errors = df_d_m_s[["error_test_E" + str(env)]] 123 | std = float(errors.std(ddof=0)) 124 | se = std / len(errors) 125 | fmt_str = "{:.2f} {} {:.2f}".format( 126 | float(errors.mean()), pm, std) 127 | if standard_error: 128 | fmt_str += " {} {:.1f}".format( 129 | float('/', se)) 130 | 131 | dataset_env = dataset + ".E" + str(env) 132 | if dataset_env not in table: 133 | table[dataset_env] = {} 134 | 135 | table[dataset_env][model] = fmt_str 136 | table_val[dataset][model][env] = { 137 | "mean": float(errors.mean()), 138 | "std": float(errors.std(ddof=0)) 139 | } 140 | 141 | # Avg 142 | if dataset not in table_avg: 143 | table_avg[dataset] = {} 144 | table_test_errors = df_d_m_s[["error_test_E" + 145 | str(env) for env in range(len(envs))]] 146 | mean = table_test_errors.mean(axis=0).mean(axis=0) 147 | std = table_test_errors.std(axis=0, ddof=0).mean(axis=0) 148 | table_avg[dataset][model] = f"{float(mean):.2f} {pm} {float(std):.2f}" 149 | table_val_avg["data"][dataset][model] = { 150 | "mean": float(mean), 151 | "std": float(std), 152 | "hparams": table_hparams[dataset][model] 153 | } 154 | 155 | return table, table_avg, table_hparams, table_val, table_val_avg, df 156 | 157 | 158 | if __name__ == "__main__": 159 | parser = argparse.ArgumentParser() 160 | parser.add_argument("dirname") 161 | parser.add_argument("--latex", action="store_true") 162 | parser.add_argument('--models', nargs='+', default=None) 163 | parser.add_argument('--num_dim', type=int, default=None) 164 | parser.add_argument('--n_envs', type=int, default=None) 165 | args = parser.parse_args() 166 | 167 | table, table_avg, 
table_hparams, table_val, table_val_avg, df = build_table( 168 | args.dirname, args.models, args.n_envs, args.num_dim, args.latex) 169 | 170 | # Print table and averaged table 171 | print_table(table, latex=args.latex) 172 | print_table(table_avg, latex=args.latex) 173 | 174 | # Print best hparams 175 | print_table_hparams(table_hparams) 176 | 177 | # Plot results 178 | exp_name = args.dirname.split('/')[-2] 179 | plot_results.plot_table( 180 | table=table_val, 181 | dirname=args.dirname, 182 | file_name='results_' + exp_name) 183 | plot_results.plot_table_avg( 184 | table=table_val_avg, 185 | dirname=args.dirname, 186 | file_name='results_avg_' + exp_name) 187 | -------------------------------------------------------------------------------- /linear_unit_tests/config.py: -------------------------------------------------------------------------------- 1 | RESULT_FOLDER = '/data/common/ISR/linear_unit_test/test_results/' 2 | -------------------------------------------------------------------------------- /linear_unit_tests/datasets.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import math 4 | 5 | 6 | class Example1: 7 | """ 8 | Cause and effect of a target with heteroskedastic noise 9 | """ 10 | 11 | def __init__(self, dim_inv, dim_spu, n_envs, inv_std=0.1): 12 | self.scramble = torch.eye(dim_inv + dim_spu) 13 | self.dim_inv = dim_inv 14 | self.dim_spu = dim_spu 15 | self.dim = dim_inv + dim_spu 16 | 17 | self.task = "regression" 18 | self.envs = {} 19 | 20 | if n_envs >= 2: 21 | self.envs = {'E0': 0.1, 'E1': 1.5} 22 | if n_envs >= 3: 23 | self.envs["E2"] = 2 24 | if n_envs > 3: 25 | assert inv_std <= 1, "inv_std must be <= 1" 26 | for env in range(3, n_envs): 27 | var = 10 ** torch.zeros(1).uniform_(math.log10(inv_std * 1.1), 1).item() 28 | self.envs["E" + str(env)] = var 29 | 30 | self.wxy = torch.randn(self.dim_inv, self.dim_inv) / self.dim_inv 31 | self.wyz = torch.randn(self.dim_inv, self.dim_spu) / self.dim_spu 32 | self.inv_std = inv_std 33 | 34 | def sample(self, n=1000, env="E0", split="train"): 35 | sdv = self.envs[env] 36 | x = torch.randn(n, self.dim_inv) * self.inv_std 37 | y = x @ self.wxy + torch.randn(n, self.dim_inv) * sdv 38 | z = y @ self.wyz + torch.randn(n, self.dim_spu) 39 | 40 | if split == "test": 41 | z = z[torch.randperm(len(z))] 42 | 43 | inputs = torch.cat((x, z), -1) @ self.scramble 44 | outputs = y.sum(1, keepdim=True) 45 | 46 | return inputs, outputs 47 | 48 | 49 | class Example2: 50 | """ 51 | Cows and camels 52 | """ 53 | 54 | def __init__(self, dim_inv, dim_spu, n_envs): 55 | self.scramble = torch.eye(dim_inv + dim_spu) 56 | self.dim_inv = dim_inv 57 | self.dim_spu = dim_spu 58 | self.dim = dim_inv + dim_spu 59 | 60 | self.task = "classification" 61 | self.envs = {} 62 | 63 | if n_envs >= 2: 64 | self.envs = { 65 | 'E0': {"p": 0.95, "s": 0.3}, 66 | 'E1': {"p": 0.97, "s": 0.5} 67 | } 68 | if n_envs >= 3: 69 | self.envs["E2"] = {"p": 0.99, "s": 0.7} 70 | if n_envs > 3: 71 | for env in range(3, n_envs): 72 | self.envs["E" + str(env)] = { 73 | "p": torch.zeros(1).uniform_(0.9, 1).item(), 74 | "s": torch.zeros(1).uniform_(0.3, 0.7).item() 75 | } 76 | 77 | # foreground is 100x noisier than background 78 | self.snr_fg = 1e-2 79 | self.snr_bg = 1 80 | 81 | # foreground (fg) denotes animal (cow / camel) 82 | cow = torch.ones(1, self.dim_inv) 83 | self.avg_fg = torch.cat((cow, cow, -cow, -cow)) 84 | 85 | # background (bg) denotes context (grass / sand) 86 | grass = torch.ones(1, 
self.dim_spu) 87 | self.avg_bg = torch.cat((grass, -grass, -grass, grass)) 88 | 89 | def sample(self, n=1000, env="E0", split="train"): 90 | p = self.envs[env]["p"] 91 | s = self.envs[env]["s"] 92 | w = torch.Tensor([p, 1 - p] * 2) * torch.Tensor([s] * 2 + [1 - s] * 2) 93 | i = torch.multinomial(w, n, True) 94 | x = torch.cat(( 95 | (torch.randn(n, self.dim_inv) / 96 | math.sqrt(10) + self.avg_fg[i]) * self.snr_fg, 97 | (torch.randn(n, self.dim_spu) / 98 | math.sqrt(10) + self.avg_bg[i]) * self.snr_bg), -1) 99 | 100 | if split == "test": 101 | x[:, self.dim_inv:] = x[torch.randperm(len(x)), self.dim_inv:] 102 | 103 | inputs = x @ self.scramble 104 | outputs = x[:, :self.dim_inv].sum(1, keepdim=True).gt(0).float() 105 | 106 | return inputs, outputs 107 | 108 | 109 | class Example3: 110 | """ 111 | Small invariant margin versus large spurious margin 112 | """ 113 | 114 | def __init__(self, dim_inv, dim_spu, n_envs, rand_env_std=False, env_std_prior=0.2, inv_std=0.1): 115 | self.scramble = torch.eye(dim_inv + dim_spu) 116 | self.dim_inv = dim_inv 117 | self.dim_spu = dim_spu 118 | self.dim = dim_inv + dim_spu 119 | 120 | self.task = "classification" 121 | self.envs = {} 122 | 123 | for env in range(n_envs): 124 | self.envs["E" + str(env)] = torch.randn(1, dim_spu) 125 | self.rand_env_std = rand_env_std 126 | self.env_std_prior = env_std_prior 127 | self.inv_std = inv_std 128 | 129 | def sample(self, n=1000, env="E0", split="train", ): 130 | m = n // 2 131 | sep = .1 132 | 133 | invariant_0 = torch.randn(m, self.dim_inv) * self.inv_std + \ 134 | torch.Tensor([[sep] * self.dim_inv]) 135 | invariant_1 = torch.randn(m, self.dim_inv) * self.inv_std - \ 136 | torch.Tensor([[sep] * self.dim_inv]) 137 | 138 | if self.rand_env_std: 139 | std_env = self.env_std_prior * np.random.uniform(.5, 1.5) 140 | else: 141 | std_env = .1 # original 142 | shortcuts_0 = torch.randn(m, self.dim_spu) * std_env + self.envs[env] 143 | shortcuts_1 = torch.randn(m, self.dim_spu) * std_env - self.envs[env] 144 | 145 | x = torch.cat((torch.cat((invariant_0, shortcuts_0), -1), 146 | torch.cat((invariant_1, shortcuts_1), -1))) 147 | 148 | if split == "test": 149 | x[:, self.dim_inv:] = x[torch.randperm(len(x)), self.dim_inv:] 150 | 151 | inputs = x @ self.scramble 152 | outputs = torch.cat((torch.zeros(m, 1), torch.ones(m, 1))) 153 | 154 | return inputs, outputs 155 | 156 | 157 | class Example1s(Example1): 158 | def __init__(self, dim_inv, dim_spu, n_envs, orthonormal=True, **kwargs): 159 | super().__init__(dim_inv, dim_spu, n_envs, **kwargs) 160 | 161 | if orthonormal: 162 | self.scramble, _ = torch.linalg.qr(torch.randn(self.dim, self.dim)) 163 | else: 164 | self.scramble = 10 * torch.randn(self.dim, self.dim) 165 | 166 | 167 | class Example2s(Example2): 168 | def __init__(self, dim_inv, dim_spu, n_envs, orthonormal=True): 169 | super().__init__(dim_inv, dim_spu, n_envs) 170 | if orthonormal: 171 | self.scramble, _ = torch.linalg.qr(torch.randn(self.dim, self.dim)) 172 | else: 173 | self.scramble = 10 * torch.randn(self.dim, self.dim) 174 | 175 | 176 | class Example3s(Example3): 177 | def __init__(self, dim_inv, dim_spu, n_envs, orthonormal=True, **kwargs): 178 | super().__init__(dim_inv, dim_spu, n_envs, **kwargs) 179 | if orthonormal: 180 | self.scramble, _ = torch.linalg.qr(torch.randn(self.dim, self.dim)) 181 | else: 182 | self.scramble = 10 * torch.randn(self.dim, self.dim) 183 | 184 | 185 | DATASETS = { 186 | "Example1": Example1, 187 | "Example2": Example2, 188 | "Example3": Example3, 189 | "Example1s": Example1s, 
190 | "Example2s": Example2s, 191 | "Example3s": Example3s 192 | } 193 | -------------------------------------------------------------------------------- /linear_unit_tests/isr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from sklearn.linear_model import LogisticRegression, Ridge 4 | import scipy 5 | import random 6 | import json 7 | 8 | 9 | class ISR(): 10 | def __init__(self, dim_inv, fit_method='cov', l2_reg=0.01, 11 | verbose=False, regression=False, spu_proj=False, 12 | logistic_regression=False, 13 | hparams=None, num_iterations=1000, 14 | ): 15 | self.dim_inv = dim_inv 16 | self.fit_method = fit_method 17 | self.l2_reg = l2_reg 18 | self.verbose = verbose 19 | self.regression = regression 20 | self.spu_proj = spu_proj 21 | self.logistic_regression = logistic_regression 22 | self.hparams = hparams 23 | self.num_iterations = num_iterations 24 | assert not regression, 'Regression is not supported yet for ISR' 25 | 26 | def extract_class_data(self, envs, label, ): 27 | data = [] 28 | for env_idx, (env_data, env_label) in enumerate(envs): 29 | class_data = env_data[env_label == label] if label >= 0 else env_data 30 | data.append(np.array(class_data)) 31 | 32 | return data 33 | 34 | def cal_Mu(self, envs_data, env_idxes=None): 35 | # calculate the mean of each env 36 | n_env = len(envs_data) 37 | n_dim = envs_data[0].shape[-1] 38 | Mu = np.zeros((n_env, n_dim)) 39 | env_idxes = np.arange(n_env) if env_idxes is None else env_idxes 40 | for env_idx in env_idxes: 41 | env_data = envs_data[env_idx] 42 | Mu[env_idx] = np.mean(env_data, axis=0) 43 | return Mu 44 | 45 | def cal_Cov(self, envs_data, n_selected_envs=None, env_idxes=None): 46 | # calculate the covariance of each env 47 | n_env = len(envs_data) 48 | n_dim = envs_data[0].shape[-1] 49 | if env_idxes is None: 50 | if n_selected_envs is not None: 51 | env_idxes = np.random.choice(n_env, size=n_selected_envs, replace=False) 52 | else: 53 | env_idxes = np.arange(n_env) 54 | 55 | envs_data = [envs_data[i] for i in env_idxes] 56 | n_env = len(envs_data) 57 | covs = np.zeros((n_env, n_dim, n_dim)) 58 | for env_idx, env_data in enumerate(envs_data): 59 | covs[env_idx] = np.cov(env_data.T) 60 | return covs 61 | 62 | def fit(self, envs, fit_method=None, n_env=None, env_idxes=None, 63 | extracted_class=1, fit_clf=True, regression=None, spu_proj=None, 64 | return_proj_mat=False, dim_spu=None): 65 | regression = regression if regression is not None else self.regression 66 | extracted_class = -1 if regression else extracted_class 67 | label_space = [-1] if regression else [0, 1] 68 | spu_proj = self.spu_proj if spu_proj is None else spu_proj 69 | data = self.extract_class_data(envs, label=extracted_class) 70 | n_env = len(envs) if n_env is None else n_env 71 | n_dim = data[0].shape[-1] 72 | self.dim = n_dim 73 | self.dim_spu = n_dim - self.dim_inv if dim_spu is None else dim_spu 74 | 75 | if fit_method is None: 76 | fit_method = self.fit_method 77 | if fit_method == 'mean': 78 | Mu = self.cal_Mu(data, env_idxes=env_idxes) 79 | P = self.fit_mean(Mu) 80 | elif fit_method == 'cov': 81 | Cov = self.cal_Cov(data, env_idxes=env_idxes) 82 | P = self.fit_cov(Cov) 83 | elif fit_method == 'cov-flag': 84 | cov_projs = [] 85 | for i in range(n_env): 86 | for label in label_space: 87 | for j in range(i + 1, n_env): 88 | proj_mat = self.fit(envs, fit_method='cov', n_env=None, env_idxes=[i, j], 89 | extracted_class=label, return_proj_mat=True, fit_clf=False) 90 | 
cov_projs.append(proj_mat) 91 | concat_projs = np.concatenate([proj for proj in cov_projs], axis=1) 92 | P = np.linalg.svd(concat_projs, full_matrices=True)[0] # compute the flag-mean 93 | 94 | 95 | elif fit_method == 'mean-flag': 96 | # should focus on the spurious dims, and find the rest dimensions at the end 97 | mean_projs = [] 98 | for label in [0, 1]: 99 | proj_mat = self.fit(envs, fit_method='mean', n_env=None, extracted_class=label, return_proj_mat=True, 100 | fit_clf=False, spu_proj=True, dim_spu=min(n_env - 1, n_dim - self.dim_inv)) 101 | mean_projs.append(proj_mat) 102 | # print("proj_mat.shape", proj_mat.shape) 103 | concat_projs = np.concatenate([proj for proj in mean_projs], axis=1) 104 | # print("concat_projs.shape", concat_projs.shape) 105 | P = np.linalg.svd(concat_projs, full_matrices=True)[0] # compute the flag-mean 106 | # print("P.shape", P.shape) 107 | P = P[:, np.flip(np.arange(P.shape[1]))] 108 | # print("P.shape", P.shape) 109 | self.dim_spu = n_dim - self.dim_inv 110 | 111 | else: 112 | raise ValueError(f'fit_method = {fit_method} is not supported') 113 | 114 | self.P = P # projection matrix 115 | if spu_proj: 116 | proj_mat = P[:, -self.dim_spu:] 117 | else: 118 | proj_mat = P[:, :self.dim_inv] 119 | self.proj_mat = proj_mat 120 | 121 | if fit_clf: 122 | self.clf = self.fit_subspace_clf(envs, proj_mat, spu_proj=spu_proj) 123 | if return_proj_mat: 124 | return proj_mat 125 | 126 | def fit_cov(self, covs): 127 | 128 | d = len(covs[0]) 129 | E = len(covs) 130 | 131 | pos_coefs = np.ones(E // 2) 132 | neg_coefs = -1 * np.ones(E - E // 2) 133 | coefs = np.concatenate([pos_coefs, neg_coefs]) 134 | coefs -= np.mean(coefs) 135 | np.random.shuffle(coefs) 136 | 137 | Cov = np.zeros((d, d)) 138 | for i in range(len(coefs)): 139 | Cov += coefs[i] * covs[i] 140 | eigenvals, P = np.linalg.eigh(Cov) 141 | inv_order = np.argsort(np.abs(eigenvals)) 142 | P[:, :] = P[:, inv_order] 143 | k = self.dim_inv 144 | self.Cov = Cov 145 | self.P = P 146 | return P 147 | 148 | def fit_mean(self, Mu): 149 | B, n_env, n_dim = Mu, len(Mu), len(Mu[0]) 150 | E_b = np.mean(B, axis=0) 151 | B_zm = B - E_b 152 | Cov = B_zm.T @ B_zm / (len(B_zm) - 1) 153 | _, P = np.linalg.eigh(Cov) 154 | k = self.dim_inv 155 | self.Mu = Mu 156 | self.P = P 157 | 158 | return P 159 | 160 | def fit_subspace_clf(self, envs, proj_mat, spu_proj=False): 161 | features = [] 162 | labels = [] 163 | for feature, label in envs: 164 | features.append(feature) 165 | labels.append(label) 166 | zs = np.concatenate(features, axis=0) 167 | ys = np.concatenate(labels, axis=0) 168 | if self.logistic_regression: 169 | if self.regression: 170 | clf = Ridge(alpha=self.l2_reg, max_iter=self.num_iterations) 171 | else: 172 | clf = LogisticRegression(C=1 / self.l2_reg, max_iter=self.num_iterations) 173 | else: 174 | task = 'regression' if self.regression else '' 175 | clf = ERM(self.dim_inv, 1, task, regression=self.regression, hparams=self.hparams, 176 | num_iterations=self.num_iterations) 177 | self.hparams = clf.hparams 178 | 179 | if spu_proj: 180 | print(proj_mat.shape) 181 | proj_mat = scipy.linalg.null_space(proj_mat.T) 182 | self.proj_mat = proj_mat 183 | zs_proj = zs @ (proj_mat) 184 | clf.fit(zs_proj, ys) 185 | return clf 186 | 187 | def predict(self, x): 188 | return self.clf.predict(x @ self.proj_mat) 189 | 190 | def score(self, x, y): 191 | return self.clf.score(x @ self.proj_mat, y) 192 | 193 | 194 | class Model(torch.nn.Module): 195 | def __init__(self, in_features, out_features, task, hparams="default", 
num_iterations=10000): 196 | super().__init__() 197 | self.in_features = in_features 198 | self.out_features = out_features 199 | self.task = task 200 | self.num_iterations = num_iterations 201 | 202 | # network architecture 203 | self.network = torch.nn.Linear(in_features, out_features) 204 | 205 | # loss 206 | if self.task == "regression": 207 | self.loss = torch.nn.MSELoss() 208 | else: 209 | self.loss = torch.nn.BCEWithLogitsLoss() 210 | 211 | # hyper-parameters 212 | if hparams == "default": 213 | self.hparams = {k: v[0] for k, v in self.HPARAMS.items()} 214 | elif hparams == "random": 215 | self.hparams = {k: v[1] for k, v in self.HPARAMS.items()} 216 | else: 217 | self.hparams = json.loads(hparams) 218 | 219 | # callbacks 220 | self.callbacks = {} 221 | for key in ["errors"]: 222 | self.callbacks[key] = { 223 | "train": [], 224 | "validation": [], 225 | "test": [] 226 | } 227 | 228 | 229 | class ERM(Model): 230 | def __init__(self, in_features, out_features, task, hparams="default", regression=False, num_iterations=10000): 231 | self.regression = regression 232 | self.HPARAMS = {} 233 | self.HPARAMS["lr"] = (1e-3, 10 ** random.uniform(-4, -2)) 234 | self.HPARAMS['wd'] = (0., 10 ** random.uniform(-6, -2)) 235 | self.num_iterations = num_iterations 236 | 237 | super().__init__(in_features, out_features, task, hparams) 238 | 239 | self.optimizer = torch.optim.Adam( 240 | self.network.parameters(), 241 | lr=self.hparams["lr"], 242 | weight_decay=self.hparams["wd"]) 243 | 244 | def fit(self, x, y): 245 | x = torch.Tensor(x) 246 | y = torch.Tensor(y) 247 | 248 | for epoch in range(self.num_iterations): 249 | self.optimizer.zero_grad() 250 | loss = self.loss(self.network(x).squeeze(), y) 251 | loss.backward() 252 | self.optimizer.step() 253 | 254 | def predict(self, x): 255 | return self.network(x.float()) 256 | 257 | def score(self, x, y): 258 | return 1 - self.network(x.float()).gt(0).float().squeeze(1).ne(y).float().mean().item() 259 | -------------------------------------------------------------------------------- /linear_unit_tests/launch_exp.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import numpy as np 4 | from config import * 5 | 6 | num_data_seeds = 50 # we used 50 in the paper 7 | num_model_seeds = 1 8 | num_iterations = 10000 # we used 10000 in the paper 9 | dim_inv = 5 10 | dim_spu = 5 11 | list_n_envs = [2, 3, 4, 5, 6, 7, 8, 9, 10] 12 | n_threads = 60 # -1 means to use all available CPU cores/threads; otherwise, use the specified number of threads 13 | datasets = ["Example2", "Example2s", "Example3", "Example3s", "Example3_Modified", "Example3s_Modified"] 14 | models = ["ISR_mean", "ISR_cov-flag", "ERM", "IRMv1", "IGA", "Oracle"] 15 | for n_envs in list_n_envs: 16 | print('n_envs: {}'.format(n_envs)) 17 | command = f"python sweep.py \ 18 | --models {' '.join(models)}\ 19 | --num_iterations {num_iterations} \ 20 | --datasets {' '.join(datasets)} \ 21 | --dim_inv {dim_inv} --dim_spu {dim_spu} \ 22 | --n_envs {n_envs} \ 23 | --num_data_seeds {num_data_seeds} --num_model_seeds {num_model_seeds} \ 24 | --output_dir {RESULT_FOLDER}/nenvs/sweep_linear_nenvs={n_envs}_dinv={dim_inv}_dspu={dim_spu} \ 25 | --n_threads {n_threads}" 26 | print('Command:', command) 27 | os.system(command) 28 | -------------------------------------------------------------------------------- /linear_unit_tests/main.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. 
and its affiliates. All Rights Reserved 2 | 3 | import argparse 4 | import hashlib 5 | import pprint 6 | import json 7 | import os 8 | import datasets 9 | import models 10 | import utils 11 | from glob import glob 12 | 13 | def run_experiment(args): 14 | # build directory name 15 | 16 | results_dirname = args['result_dir'] 17 | 18 | # build file name 19 | md5_fname = hashlib.md5(str(args).encode('utf-8')).hexdigest() 20 | results_fname = os.path.join(results_dirname, md5_fname + ".jsonl") 21 | results_file = open(results_fname, "w") 22 | 23 | utils.set_seed(args["data_seed"]) 24 | try: 25 | dataset = datasets.DATASETS[args["dataset"]]( 26 | dim_inv=args["dim_inv"], 27 | dim_spu=args["dim_spu"], 28 | n_envs=args["n_envs"] 29 | ) 30 | except: 31 | dataset_name, _ = args["dataset"].split('_') 32 | dataset = datasets.DATASETS[dataset_name]( 33 | dim_inv=args["dim_inv"], 34 | dim_spu=args["dim_spu"], 35 | n_envs=args["n_envs"], 36 | rand_env_std=True 37 | ) 38 | 39 | # Oracle trained on test mode (scrambled) 40 | train_split = "train" if args["model"] != "Oracle" else "test" 41 | 42 | # sample the envs 43 | if "ISR" in args["model"]: 44 | envs = {} 45 | for key_split, split in zip(("train", "validation", "test"), 46 | (train_split, train_split, "test")): 47 | envs[key_split] = [] 48 | for env in dataset.envs: 49 | data = dataset.sample( 50 | n=args["num_samples"], 51 | env=env, 52 | split=split) 53 | envs[key_split].append((data[0], data[1].flatten())) 54 | else: 55 | envs = {} 56 | for key_split, split in zip(("train", "validation", "test"), 57 | (train_split, train_split, "test")): 58 | envs[key_split] = {"keys": [], "envs": []} 59 | for env in dataset.envs: 60 | envs[key_split]["envs"].append(dataset.sample( 61 | n=args["num_samples"], 62 | env=env, 63 | split=split) 64 | ) 65 | envs[key_split]["keys"].append(env) 66 | 67 | # offsetting model seed to avoid overlap with data_seed 68 | utils.set_seed(args["model_seed"] + 1000) 69 | 70 | if "Example1" in args["dataset"]: 71 | regression = True 72 | else: 73 | regression = False 74 | # selecting model 75 | args["num_dim"] = args["dim_inv"] + args["dim_spu"] 76 | if "ISR" not in args["model"]: 77 | if args["model"] in ['ERM', 'Oracle']: 78 | model = models.MODELS[args["model"]]( 79 | in_features=args["num_dim"], 80 | out_features=1, 81 | task=dataset.task, 82 | hparams=args["hparams"], 83 | regression=regression 84 | ) 85 | else: 86 | model = models.MODELS[args["model"]]( 87 | in_features=args["num_dim"], 88 | out_features=1, 89 | task=dataset.task, 90 | hparams=args["hparams"], 91 | ) 92 | # update this field for printing purposes 93 | args["hparams"] = model.hparams 94 | else: 95 | model_n, model_m = args["model"].split('_') 96 | model = models.MODELS[model_n]( 97 | dim_inv=max(1, args["dim_inv"]), 98 | fit_method=model_m, 99 | regression=regression, 100 | hparams=args["hparams"], 101 | num_iterations=args["num_iterations"] 102 | ) 103 | 104 | # fit the dataset 105 | if "ISR" in args["model"]: 106 | model.fit(envs['train']) 107 | args["hparams"] = model.hparams 108 | else: 109 | model.fit( 110 | envs=envs, 111 | num_iterations=args["num_iterations"], 112 | callback=args["callback"]) 113 | 114 | # compute the train, validation and test errors 115 | for split in ("train", "validation", "test"): 116 | key = "error_" + split 117 | if "ISR" in args["model"]: 118 | for k_env, env in enumerate(envs[split]): 119 | env = env 120 | args[key + "_E" + 121 | str(k_env)] = utils.compute_error(model, *env) 122 | else: 123 | for k_env, env in 
zip(envs[split]["keys"], envs[split]["envs"]): 124 | env = env 125 | args[key + "_" + 126 | k_env] = utils.compute_error(model, *env) 127 | 128 | # write results 129 | results_file.write(json.dumps(args)) 130 | results_file.close() 131 | return args 132 | 133 | 134 | if __name__ == "__main__": 135 | parser = argparse.ArgumentParser(description='Synthetic invariances') 136 | parser.add_argument('--model', type=str, default="ERM") 137 | parser.add_argument('--num_iterations', type=int, default=10000) 138 | parser.add_argument('--hparams', type=str, default="default") 139 | parser.add_argument('--dataset', type=str, default="Example1") 140 | parser.add_argument('--dim_inv', type=int, default=5) 141 | parser.add_argument('--dim_spu', type=int, default=5) 142 | parser.add_argument('--n_envs', type=int, default=3) 143 | parser.add_argument('--num_samples', type=int, default=10000) 144 | parser.add_argument('--data_seed', type=int, default=0) 145 | parser.add_argument('--model_seed', type=int, default=0) 146 | parser.add_argument('--output_dir', type=str, default="results") 147 | parser.add_argument('--callback', action='store_true') 148 | parser.add_argument('--exp_name', type=str, default="default") 149 | parser.add_argument('--result_dir', type=str, default=None) 150 | args = parser.parse_args() 151 | 152 | pprint.pprint(run_experiment(vars(args))) 153 | -------------------------------------------------------------------------------- /linear_unit_tests/models.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | from torch.autograd import grad 4 | from isr import ISR 5 | import utils 6 | import torch 7 | import random 8 | import json 9 | 10 | 11 | class Model(torch.nn.Module): 12 | def __init__(self, in_features, out_features, task, hparams="default"): 13 | super().__init__() 14 | self.in_features = in_features 15 | self.out_features = out_features 16 | self.task = task 17 | 18 | # network architecture 19 | self.network = torch.nn.Linear(in_features, out_features) 20 | 21 | # loss 22 | if self.task == "regression": 23 | self.loss = torch.nn.MSELoss() 24 | else: 25 | self.loss = torch.nn.BCEWithLogitsLoss() 26 | 27 | # hyper-parameters 28 | if hparams == "default": 29 | self.hparams = {k: v[0] for k, v in self.HPARAMS.items()} 30 | elif hparams == "random": 31 | self.hparams = {k: v[1] for k, v in self.HPARAMS.items()} 32 | else: 33 | self.hparams = json.loads(hparams) 34 | 35 | # callbacks 36 | self.callbacks = {} 37 | for key in ["errors"]: 38 | self.callbacks[key] = { 39 | "train": [], 40 | "validation": [], 41 | "test": [] 42 | } 43 | 44 | 45 | class ERM(Model): 46 | def __init__(self, in_features, out_features, task, hparams="default", regression=False): 47 | self.regression = regression 48 | self.HPARAMS = {} 49 | self.HPARAMS["lr"] = (1e-3, 10 ** random.uniform(-4, -2)) 50 | self.HPARAMS['wd'] = (0., 10 ** random.uniform(-6, -2)) 51 | 52 | super().__init__(in_features, out_features, task, hparams) 53 | 54 | self.optimizer = torch.optim.Adam( 55 | self.network.parameters(), 56 | lr=self.hparams["lr"], 57 | weight_decay=self.hparams["wd"]) 58 | 59 | def fit(self, envs, num_iterations, callback=False): 60 | x = torch.cat([xe for xe, ye in envs["train"]["envs"]]) 61 | y = torch.cat([ye for xe, ye in envs["train"]["envs"]]) 62 | 63 | for epoch in range(num_iterations): 64 | self.optimizer.zero_grad() 65 | self.loss(self.network(x), y).backward() 66 | self.optimizer.step() 67 | 68 
| if callback: 69 | # compute errors 70 | utils.compute_errors(self, envs) 71 | 72 | def predict(self, x): 73 | return self.network(x) 74 | 75 | 76 | class IRM(Model): 77 | """ 78 | Abstract class for IRM 79 | """ 80 | 81 | def __init__( 82 | self, in_features, out_features, task, hparams="default", version=1): 83 | self.HPARAMS = {} 84 | self.HPARAMS["lr"] = (1e-3, 10 ** random.uniform(-4, -2)) 85 | self.HPARAMS['wd'] = (0., 10 ** random.uniform(-6, -2)) 86 | self.HPARAMS['irm_lambda'] = (0.9, 1 - 10 ** random.uniform(-3, -.3)) 87 | 88 | super().__init__(in_features, out_features, task, hparams) 89 | self.version = version 90 | 91 | self.network = self.IRMLayer(self.network) 92 | self.net_parameters, self.net_dummies = self.find_parameters( 93 | self.network) 94 | 95 | self.optimizer = torch.optim.Adam( 96 | self.net_parameters, 97 | lr=self.hparams["lr"], 98 | weight_decay=self.hparams["wd"]) 99 | 100 | def find_parameters(self, network): 101 | """ 102 | Alternative to network.parameters() to separate real parameters 103 | from dummmies. 104 | """ 105 | parameters = [] 106 | dummies = [] 107 | 108 | for name, param in network.named_parameters(): 109 | if "dummy" in name: 110 | dummies.append(param) 111 | else: 112 | parameters.append(param) 113 | return parameters, dummies 114 | 115 | class IRMLayer(torch.nn.Module): 116 | """ 117 | Add a "multiply by one and sum zero" dummy operation to 118 | any layer. Then you can take gradients with respect these 119 | dummies. Often applied to Linear and Conv2d layers. 120 | """ 121 | 122 | def __init__(self, layer): 123 | super().__init__() 124 | self.layer = layer 125 | self.dummy_mul = torch.nn.Parameter(torch.Tensor([1.0])) 126 | self.dummy_sum = torch.nn.Parameter(torch.Tensor([0.0])) 127 | 128 | def forward(self, x): 129 | return self.layer(x) * self.dummy_mul + self.dummy_sum 130 | 131 | def fit(self, envs, num_iterations, callback=False): 132 | for epoch in range(num_iterations): 133 | losses_env = [] 134 | gradients_env = [] 135 | for x, y in envs["train"]["envs"]: 136 | losses_env.append(self.loss(self.network(x), y)) 137 | gradients_env.append(grad( 138 | losses_env[-1], self.net_dummies, create_graph=True)) 139 | 140 | # Average loss across envs 141 | losses_avg = sum(losses_env) / len(losses_env) 142 | gradients_avg = grad( 143 | losses_avg, self.net_dummies, create_graph=True) 144 | 145 | penalty = 0 146 | for gradients_this_env in gradients_env: 147 | for g_env, g_avg in zip(gradients_this_env, gradients_avg): 148 | if self.version == 1: 149 | penalty += g_env.pow(2).sum() 150 | else: 151 | raise NotImplementedError 152 | 153 | obj = (1 - self.hparams["irm_lambda"]) * losses_avg 154 | obj += self.hparams["irm_lambda"] * penalty 155 | 156 | self.optimizer.zero_grad() 157 | obj.backward() 158 | self.optimizer.step() 159 | 160 | if callback: 161 | # compute errors 162 | utils.compute_errors(self, envs) 163 | 164 | def predict(self, x): 165 | return self.network(x) 166 | 167 | 168 | class IRMv1(IRM): 169 | """ 170 | IRMv1 with penalty \sum_e \| \nabla_{w|w=1} \mR_e (\Phi \circ \vec{w}) \|_2^2 171 | From https://arxiv.org/abs/1907.02893v1 172 | """ 173 | 174 | def __init__(self, in_features, out_features, task, hparams="default"): 175 | super().__init__(in_features, out_features, task, hparams, version=1) 176 | 177 | 178 | class AndMask(Model): 179 | """ 180 | AndMask: Masks the grqdients features for which 181 | the gradients signs across envs disagree more than 'tau' 182 | From https://arxiv.org/abs/2009.00329 183 | """ 184 | 185 | def 
__init__(self, in_features, out_features, task, hparams="default"): 186 | self.HPARAMS = {} 187 | self.HPARAMS["lr"] = (1e-3, 10 ** random.uniform(-4, 0)) 188 | self.HPARAMS['wd'] = (0., 10 ** random.uniform(-5, 0)) 189 | self.HPARAMS["tau"] = (0.9, random.uniform(0.8, 1)) 190 | super().__init__(in_features, out_features, task, hparams) 191 | 192 | def fit(self, envs, num_iterations, callback=False): 193 | for epoch in range(num_iterations): 194 | losses = [self.loss(self.network(x), y) 195 | for x, y in envs["train"]["envs"]] 196 | self.mask_step( 197 | losses, list(self.parameters()), 198 | tau=self.hparams["tau"], 199 | wd=self.hparams["wd"], 200 | lr=self.hparams["lr"] 201 | ) 202 | 203 | if callback: 204 | # compute errors 205 | utils.compute_errors(self, envs) 206 | 207 | def predict(self, x): 208 | return self.network(x) 209 | 210 | def mask_step(self, losses, parameters, tau=0.9, wd=0.1, lr=1e-3): 211 | with torch.no_grad(): 212 | gradients = [] 213 | for loss in losses: 214 | gradients.append(list(torch.autograd.grad(loss, parameters))) 215 | gradients[-1][0] = gradients[-1][0] / gradients[-1][0].norm() 216 | 217 | for ge_all, parameter in zip(zip(*gradients), parameters): 218 | # environment-wise gradients (num_environments x num_parameters) 219 | ge_cat = torch.cat(ge_all) 220 | 221 | # treat scalar parameters also as matrices 222 | if ge_cat.dim() == 1: 223 | ge_cat = ge_cat.view(len(losses), -1) 224 | 225 | # creates a mask with zeros on weak features 226 | mask = (torch.abs(torch.sign(ge_cat).sum(0)) 227 | > len(losses) * tau).int() 228 | 229 | # mean gradient (1 x num_parameters) 230 | g_mean = ge_cat.mean(0, keepdim=True) 231 | 232 | # apply the mask 233 | g_masked = mask * g_mean 234 | 235 | # update 236 | parameter.data = parameter.data - lr * g_masked \ 237 | - lr * wd * parameter.data 238 | 239 | 240 | class IGA(Model): 241 | """ 242 | Inter-environmental Gradient Alignment 243 | From https://arxiv.org/abs/2008.01883v2 244 | """ 245 | 246 | def __init__(self, in_features, out_features, task, hparams="default"): 247 | self.HPARAMS = {} 248 | self.HPARAMS["lr"] = (1e-3, 10 ** random.uniform(-4, -2)) 249 | self.HPARAMS['wd'] = (0., 10 ** random.uniform(-6, -2)) 250 | self.HPARAMS['penalty'] = (1000, 10 ** random.uniform(1, 5)) 251 | super().__init__(in_features, out_features, task, hparams) 252 | 253 | self.optimizer = torch.optim.Adam( 254 | self.parameters(), 255 | lr=self.hparams["lr"], 256 | weight_decay=self.hparams["wd"]) 257 | 258 | def fit(self, envs, num_iterations, callback=False): 259 | for epoch in range(num_iterations): 260 | losses = [self.loss(self.network(x), y) 261 | for x, y in envs["train"]["envs"]] 262 | gradients = [ 263 | grad(loss, self.parameters(), create_graph=True) 264 | for loss in losses 265 | ] 266 | # average loss and gradients 267 | avg_loss = sum(losses) / len(losses) 268 | avg_gradient = grad(avg_loss, self.parameters(), create_graph=True) 269 | 270 | # compute trace penalty 271 | penalty_value = 0 272 | for gradient in gradients: 273 | for gradient_i, avg_grad_i in zip(gradient, avg_gradient): 274 | penalty_value += (gradient_i - avg_grad_i).pow(2).sum() 275 | 276 | self.optimizer.zero_grad() 277 | (avg_loss + self.hparams['penalty'] * penalty_value).backward() 278 | self.optimizer.step() 279 | 280 | if callback: 281 | # compute errors 282 | utils.compute_errors(self, envs) 283 | 284 | def predict(self, x): 285 | return self.network(x) 286 | 287 | 288 | MODELS = { 289 | "ERM": ERM, 290 | "IRMv1": IRMv1, 291 | "ANDMask": AndMask, 292 | 
"IGA": IGA, 293 | "Oracle": ERM, 294 | "ISR": ISR 295 | } 296 | -------------------------------------------------------------------------------- /linear_unit_tests/plot_results.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | 6 | from collect_results import * 7 | from config import * 8 | 9 | warnings.filterwarnings("ignore") # ignore warnings for df.append 10 | 11 | 12 | def plot_nenvs(df_data, dirname, file_name='', save=True, block=False): 13 | plt.rc('xtick', labelsize=13) 14 | plt.rc('ytick', labelsize=13) 15 | 16 | fig, axs = plt.subplots(1, 6, figsize=(18, 3.1)) 17 | 18 | axs = axs.flat 19 | # Top: nenvs 20 | datasets = df_data["dataset"].unique() 21 | for id_d, dataset in zip(range(len(datasets)), datasets): 22 | df_d = df_data[df_data["dataset"] == dataset] 23 | models = df_d["model"].unique() 24 | legends = [] 25 | for id_m, model in zip(range(len(models)), models): 26 | if id_m == 0: 27 | marker = 'o' 28 | elif id_m == 1: 29 | marker = 's' 30 | else: 31 | marker = '' 32 | df_d_m = df_d[df_d["model"] == model].sort_values(by="n_envs") 33 | legend, = axs[id_d].plot(df_d_m["n_envs"], df_d_m["mean"], 34 | color=f'C{id_m}', 35 | label=model, marker=marker, 36 | linewidth=2) 37 | top = (df_d_m["mean"] + df_d_m["std"] / 2).to_numpy().astype(float) 38 | bottom = (df_d_m["mean"] - df_d_m["std"] / 2).to_numpy().astype(float) 39 | xs = np.arange(2, 11).astype(float) 40 | axs[id_d].fill_between(xs, bottom, top, facecolor=f'C{id_m}', alpha=0.2) 41 | legends.append(legend) 42 | 43 | axs[id_d].set_xlabel('n_env', fontsize=13, labelpad=-1) 44 | axs[id_d].set_title(dataset, fontsize=13) 45 | axs[id_d].set_ylim(bottom=-0.05, top=0.55) 46 | axs[id_d].set_yticks([0.0, 0.1, 0.2, 0.3, 0.4, 0.5]) 47 | axs[id_d].set_xticks([2, 4, 6, 8, 10]) 48 | if id_d != 0: 49 | axs[id_d].set_yticklabels([]) 50 | axs[id_d].set_xlim(left=1.5, right=10.5) 51 | fig.tight_layout() 52 | axs[0].set_ylabel("Mean Error", fontsize=14) 53 | 54 | plt.legend(handles=legends, 55 | ncol=6, 56 | loc="lower center", 57 | bbox_to_anchor=(-2.3, -0.5), prop={'size': 13}) 58 | 59 | if save: 60 | fig_dirname = "figs/" + dirname + '/' 61 | os.makedirs(fig_dirname, exist_ok=True) 62 | models = '_'.join(models) 63 | plt.savefig(fig_dirname + file_name + '.pdf', 64 | format='pdf', bbox_inches='tight') 65 | if block: 66 | plt.show(block=False) 67 | input('Press to close') 68 | plt.close('all') 69 | 70 | 71 | def build_df(dirname): 72 | print(dirname) 73 | df = pd.DataFrame(columns=['n_envs', 'dim_inv', 'dim_spu', 'dataset', 'model', 'mean', 'std']) 74 | for filename in glob.glob(os.path.join(dirname, "*.jsonl")): 75 | with open(filename) as f: 76 | dic = json.load(f) 77 | n_envs = dic["n_envs"] 78 | dim_inv = dic["dim_inv"] 79 | dim_spu = dic["dim_spu"] 80 | for dataset in dic["data"].keys(): 81 | single_dic = {} 82 | for model in dic["data"][dataset].keys(): 83 | mean = dic["data"][dataset][model]["mean"] 84 | std = dic["data"][dataset][model]["std"] 85 | single_dic = dict( 86 | n_envs=n_envs, 87 | dim_inv=dim_inv, 88 | dim_spu=dim_spu, 89 | dataset=dataset, 90 | model=model, 91 | mean=mean, 92 | std=std 93 | ) 94 | # print(single_dic) 95 | df = df.append(single_dic, ignore_index=True) 96 | return df 97 | 98 | 99 | def process_results(dirname, exp_name, save_dirname): 100 | subdirs = [os.path.join(dirname, subdir, exp_name + '/') for subdir in os.listdir(dirname) if 101 | os.path.isdir(os.path.join(dirname, subdir))] 102 
| for subdir in subdirs: 103 | print(subdir) 104 | table, table_avg, table_hparams, table_val, table_val_avg, df = build_table(subdir) 105 | 106 | # save table_val_avg 107 | save_dirname_avg = save_dirname + "avg/" 108 | os.makedirs(save_dirname_avg, exist_ok=True) 109 | results_filename = os.path.join(save_dirname_avg, 'avg_' + '_'.join(subdir.split('/')[-4:-1]) + ".jsonl") 110 | results_file = open(results_filename, "w") 111 | results_file.write(json.dumps(table_val_avg)) 112 | results_file.close() 113 | 114 | 115 | if __name__ == "__main__": 116 | parser = argparse.ArgumentParser() 117 | parser.add_argument("--dirname", type=str, default=RESULT_FOLDER) 118 | parser.add_argument("--exp_name", type=str, default="default") 119 | parser.add_argument('--load', action='store_true') 120 | args = parser.parse_args() 121 | 122 | dirname_nenvs = "results_processed/nenvs/" + args.exp_name + "/" 123 | 124 | # construct averaged data 125 | if not args.load: 126 | process_results(dirname=args.dirname + "nenvs/", exp_name=args.exp_name, save_dirname=dirname_nenvs) 127 | 128 | # plot results for different number of envs 129 | df_nenvs = build_df(dirname_nenvs + "avg/") 130 | 131 | plot_nenvs(df_nenvs, dirname=args.dirname.split('/')[-1], 132 | file_name='results_nenvs_dimspu_' + args.exp_name, 133 | save=True, block=False) 134 | -------------------------------------------------------------------------------- /linear_unit_tests/sweep.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | from glob import glob 5 | 6 | from joblib import Parallel, delayed 7 | from tqdm.auto import tqdm 8 | 9 | import datasets 10 | import main 11 | import models 12 | 13 | if __name__ == "__main__": 14 | parser = argparse.ArgumentParser(description='Synthetic invariances') 15 | parser.add_argument('--models', nargs='+', default=[]) 16 | parser.add_argument('--num_iterations', type=int, default=10000) 17 | parser.add_argument('--hparams', type=str, default="default") 18 | parser.add_argument('--datasets', nargs='+', default=[]) 19 | parser.add_argument('--dim_inv', type=int, default=5) 20 | parser.add_argument('--dim_spu', type=int, default=5) 21 | parser.add_argument('--n_envs', type=int, default=3) 22 | parser.add_argument('--num_samples', type=int, default=10000) 23 | parser.add_argument('--num_data_seeds', type=int, default=50) 24 | parser.add_argument('--num_model_seeds', type=int, default=20) 25 | parser.add_argument('--output_dir', type=str, default="results") 26 | parser.add_argument('--callback', action='store_true') 27 | parser.add_argument('--n_threads', type=int, default=-1) 28 | parser.add_argument('--exp_name', type=str, default="default") 29 | args = vars(parser.parse_args()) 30 | 31 | all_jobs = [] 32 | if len(args["models"]) > 0: 33 | model_lists = args["models"] 34 | else: 35 | model_lists = models.MODELS.keys() 36 | if len(args["datasets"]): 37 | dataset_lists = args["datasets"] 38 | else: 39 | dataset_lists = datasets.DATASETS.keys() 40 | 41 | results_dirname = os.path.join(args["output_dir"], args["exp_name"] + "/") 42 | os.makedirs(results_dirname, exist_ok=True) 43 | for f in glob(f'{results_dirname}/*'): 44 | # remove all previous experiment results 45 | os.remove(f) 46 | 47 | for model in model_lists: 48 | for dataset in dataset_lists: 49 | for data_seed in range(args["num_data_seeds"]): 50 | for model_seed in range(args["num_model_seeds"]): 51 | train_args = { 52 | "model": model, 53 | "num_iterations": 
args["num_iterations"], 54 | "hparams": "random" if model_seed else "default", 55 | "dataset": dataset, 56 | "dim_inv": args["dim_inv"], 57 | "dim_spu": args["dim_spu"], 58 | "n_envs": args["n_envs"], 59 | "num_samples": args["num_samples"], 60 | "data_seed": data_seed, 61 | "model_seed": model_seed, 62 | "output_dir": args["output_dir"], 63 | "callback": args["callback"], 64 | "exp_name": args["exp_name"], 65 | "result_dir": results_dirname, 66 | } 67 | 68 | all_jobs.append(train_args) 69 | 70 | random.shuffle(all_jobs) 71 | 72 | print("Launching {} jobs...".format(len(all_jobs))) 73 | 74 | iterator = tqdm(all_jobs, desc="Jobs") 75 | if args["n_threads"] == 1: 76 | for job in iterator: 77 | main.run_experiment(job) 78 | else: 79 | Parallel(n_jobs=args["n_threads"])(delayed(main.run_experiment)(job) for job in iterator) 80 | -------------------------------------------------------------------------------- /linear_unit_tests/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | import random 4 | 5 | import numpy as np 6 | import torch 7 | 8 | 9 | def set_seed(seed): 10 | torch.manual_seed(seed) 11 | np.random.seed(seed) 12 | random.seed(seed) 13 | 14 | 15 | def compute_error(algorithm, x, y): 16 | if hasattr(algorithm, 'score'): 17 | if len(y.unique()) == 2: 18 | return 1 - algorithm.score(x, y) 19 | else: 20 | return np.mean((algorithm.predict(x) - y.squeeze().numpy()) ** 2) 21 | with torch.no_grad(): 22 | if len(y.unique()) == 2: 23 | return algorithm.predict(x).gt(0).ne(y).float().mean().item() 24 | else: 25 | return (algorithm.predict(x) - y).pow(2).mean().item() 26 | 27 | 28 | def compute_errors(model, envs): 29 | for split in envs.keys(): 30 | if not bool(model.callbacks["errors"][split]): 31 | model.callbacks["errors"][split] = { 32 | key: [] for key in envs[split]["keys"]} 33 | 34 | for k, env in zip(envs[split]["keys"], envs[split]["envs"]): 35 | model.callbacks["errors"][split][k].append( 36 | compute_error(model, *env)) 37 | -------------------------------------------------------------------------------- /real_datasets/README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Datasets 4 | 5 | In this code repo, we use the following keywords to represent three real-world datasets: 6 | 7 | + `"CUB"`: The [Waterbirds](https://github.com/kohpangwei/group_DRO#waterbirds) dataset (a bird image dataset), formed from [Caltech-UCSD Birds 200](http://www.vision.caltech.edu/visipedia/CUB-200.html) + [Places](http://places2.csail.mit.edu/). 8 | + `"CelebA"`: The [CelebA](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) dataset (a face image dataset). 9 | + `"MultiNLI"`: The [MultiNLI](https://www.nyu.edu/projects/bowman/multinli/) dataset (a text dataset). 10 | 11 | ![Two image datasets and one text dataset.](../figures/real_datasets.png "Real Datasets") 12 | 13 | ## Preparation 14 | 15 | + Please download the datasets into a folder and define the folder path as `DATA_FOLDER` in `configs/__init__.py`. 16 | - We expect the folder structure to look like: 17 | 18 | ``` 19 | DATA_FOLDER 20 | └───waterbirds 21 | └───celebA_v1.0 22 | └───multinli 23 | ``` 24 | 25 | + Please define `LOG_FOLDER` in `configs/__init__.py`, which is the folder where all training/evaluation results are saved. 26 | 27 | ## Procedures 28 | 29 | Running ISR-Mean/Cov on the real datasets involves three steps. 30 | 31 | #### 1. 
Train a neural network with ERM/Reweight/GroupDRO. 32 | 33 | You can launch the training experiments with `python launch_train.py`. 34 | 35 | #### 2. Parse learned features on the training/validation/test data. 36 | 37 | Parse learned features with `python launch_parse.py`, which saves the parsed features locally. 38 | 39 | #### 3. Post-process learned features & classifier with ISR-Mean/Cov. 40 | 41 | Run `eval.py` to evaluate with the vanilla trained classifier or the ISR classifiers. The evaluation results are saved in CSV files for analysis. [To-Do] You can also run `launch_eval.py` (which automatically loads the hyperparameters we used) to reproduce the experiment results reported in the paper. 42 | 43 | ## To-Do 44 | - [ ] Implement `launch_eval.py` and fill in optimal hyperparameters (To reproduce Table 1 of our [paper](https://arxiv.org/pdf/2201.12919.pdf)) 45 | - [ ] Partial environment labels (To reproduce Figure 6 of our [paper](https://arxiv.org/pdf/2201.12919.pdf)) 46 | - [ ] ISR classifiers on CLIP backbones (To reproduce Table 2 of our [paper](https://arxiv.org/pdf/2201.12919.pdf)) 47 | - [ ] Provide trained models in a Dropbox folder -------------------------------------------------------------------------------- /real_datasets/configs/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uiuctml/ISR/728487c5f96c5864acf1b0dc6665954f5aef8997/real_datasets/configs/README.md -------------------------------------------------------------------------------- /real_datasets/configs/__init__.py: -------------------------------------------------------------------------------- 1 | DATA_FOLDER = '/data/common/ISR/datasets/' 2 | LOG_FOLDER = '/data/common/ISR/logs/' 3 | 4 | from .model_config import model_attributes 5 | from .train_config import get_train_command 6 | from .parse_config import get_parse_command 7 | -------------------------------------------------------------------------------- /real_datasets/configs/eval_config.py: -------------------------------------------------------------------------------- 1 | #TODO: to fill in the configs 2 | ISR_configs = dict( 3 | CUB={}, 4 | MultiNLI={}, 5 | CelebA={}, 6 | ) 7 | 8 | ISR_CLIP_configs = dict( 9 | CUB={} 10 | ) 11 | 12 | def get_ISR_config(dataset:str,algo:str,ISR_version:str,use_CLIP=False): 13 | #TODO: to implement this function 14 | return -------------------------------------------------------------------------------- /real_datasets/configs/model_config.py: -------------------------------------------------------------------------------- 1 | model_attributes = { 2 | 'bert': { 3 | 'feature_type': 'text' 4 | }, 5 | 'inception_v3': { 6 | 'feature_type': 'image', 7 | 'target_resolution': (299, 299), 8 | 'flatten': False 9 | }, 10 | 'wideresnet50': { 11 | 'feature_type': 'image', 12 | 'target_resolution': (224, 224), 13 | 'flatten': False 14 | }, 15 | 'resnet50': { 16 | 'feature_type': 'image', 17 | 'target_resolution': (224, 224), 18 | 'flatten': False 19 | }, 20 | 'resnet34': { 21 | 'feature_type': 'image', 22 | 'target_resolution': None, 23 | 'flatten': False 24 | }, 25 | 'raw_logistic_regression': { 26 | 'feature_type': 'image', 27 | 'target_resolution': None, 28 | 'flatten': True, 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /real_datasets/configs/parse_config.py: -------------------------------------------------------------------------------- 1 | PARSE_COMMANDS = dict( 2 | MultiNLI=['-s', 'confounder', '-d', 
'MultiNLI', '-t', 'gold_label_random', 3 | '-c', 'sentence2_has_negation', '--batch_size', '32', '--model', 'bert', 4 | '--n_epochs', '3', ], 5 | CelebA=['-d', 'CelebA', '-t', 'Blond_Hair', '-c', 'Male', '--model', 'resnet50', 6 | '--weight_decay', '0.01', '--lr', '0.0001', 7 | "--batch_size", '128', '--n_epochs', '50'], 8 | CUB=['-d', 'CUB', '-t', 'waterbird_complete95', '-c', 'forest2water2', 9 | '--model', 'resnet50', '--weight_decay', '0.1', '--lr', '0.0001', 10 | '--batch_size', '128', '--n_epochs', '300'] 11 | ) 12 | 13 | 14 | def get_parse_command(dataset, algos, train_log_seeds, model_selects, 15 | log_dir: str, gpu_idx=None, parse_script='parse_features.py', 16 | ): 17 | prefix = f'CUDA_VISIBLE_DEVICES={gpu_idx} ' if gpu_idx is not None else '' 18 | main_args = ' '.join(PARSE_COMMANDS[dataset]) 19 | parse_algos = ' '.join(algos) 20 | parse_seeds = ' '.join(map(str, train_log_seeds)) 21 | parse_model_selects = ' '.join(model_selects) 22 | parse_args = f" --parse_dir {log_dir} --parse_algos {parse_algos} --parse_seeds {parse_seeds} --parse_model_selects {parse_model_selects}" 23 | command = f'{prefix} python {parse_script} {main_args} {parse_args}' 24 | return command 25 | -------------------------------------------------------------------------------- /real_datasets/configs/train_config.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | TRAIN_COMMANDS = dict( 5 | CelebA= 6 | { 7 | "ERM": "-s confounder -d CelebA -t Blond_Hair -c Male --model resnet50 --weight_decay 0.01 --lr 0.0001 " 8 | "--batch_size 128 --n_epochs 50", # ERM 9 | "groupDRO": "-s confounder -d CelebA -t Blond_Hair -c Male --model resnet50 --weight_decay 0.1 --lr 1e-05 " 10 | "--batch_size 128 --n_epochs 50 --reweight_groups --robust --alpha 0.01 --gamma 0.1 --generalization_adjustment 3", 11 | # groupDRO 12 | "reweight": "-s confounder -d CelebA -t Blond_Hair -c Male --model resnet50 --weight_decay 0.1 --lr 1e-05 " 13 | "--batch_size 128 --n_epochs 50 --reweight_groups" # reweight, 14 | } 15 | , 16 | CUB={ 17 | "ERM": '-s confounder -d CUB -t waterbird_complete95 -c forest2water2 --model resnet50 --weight_decay 0.1 --lr 0.0001 ' 18 | '--batch_size 128 --n_epochs 300', # ERM 19 | "reweight": '-s confounder -d CUB -t waterbird_complete95 -c forest2water2 --model resnet50 --weight_decay 1 --lr 1e-05 ' 20 | '--batch_size 128 --n_epochs 300 --reweight_groups', # reweight 21 | "groupDRO": ' -s confounder -d CUB -t waterbird_complete95 -c forest2water2 --model resnet50 --weight_decay 1 --lr 1e-05 ' 22 | '--batch_size 128 --n_epochs 300 --reweight_groups --robust --alpha 0.01 --gamma 0.1 --generalization_adjustment 2', 23 | # groupDRO 24 | }, 25 | MultiNLI={ 26 | "ERM": '-s confounder -d MultiNLI -t gold_label_random -c sentence2_has_negation --model bert --weight_decay 0 --lr ' 27 | f'2e-05 --batch_size 32 --n_epochs 3', # ERM 28 | "groupDRO": '-s confounder -d MultiNLI -t gold_label_random -c sentence2_has_negation --model bert --weight_decay 0 --lr ' 29 | f'2e-05 --batch_size 32 --n_epochs 3 --reweight_groups --robust --alpha 0.01 --gamma 0.1 ' 30 | '--generalization_adjustment 0', # groupDRO 31 | "reweight": '-s confounder -d MultiNLI -t gold_label_random -c sentence2_has_negation --model bert --weight_decay 0 --lr ' 32 | f'2e-05 --batch_size 32 --n_epochs 3 --reweight_groups', # reweight 33 | } 34 | ) 35 | 36 | 37 | def get_train_command(dataset: str, algo: str, gpu_idx: int = None, train_script: str = 'run_expt.py', 38 | algo_suffix: str = 
'',seed:int=None,save_best:bool=True,save_last:bool=True): 39 | prefix = f'CUDA_VISIBLE_DEVICES={gpu_idx}' if gpu_idx is not None else '' 40 | suffix = f' --algo_suffix {algo_suffix}' if algo_suffix else '' 41 | if save_best: 42 | suffix += ' --save_best' 43 | if save_last: 44 | suffix += ' --save_last' 45 | seed = 0 if seed is None else seed 46 | args_command = TRAIN_COMMANDS[dataset][algo] 47 | command = f"{prefix} python {train_script} {args_command} --seed {seed} {suffix}" 48 | return command -------------------------------------------------------------------------------- /real_datasets/data/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .confounder_utils import prepare_confounder_data 3 | from .label_shift_utils import prepare_label_shift_data 4 | from configs import DATA_FOLDER 5 | root_dir = DATA_FOLDER 6 | 7 | dataset_attributes = { 8 | 'CelebA': { 9 | 'root_dir': 'celebA_v1.0' 10 | }, 11 | 'CUB': { 12 | 'root_dir': 'waterbirds' 13 | }, 14 | 'CIFAR10': { 15 | 'root_dir': 'CIFAR10/data' 16 | }, 17 | 'MultiNLI': { 18 | 'root_dir': 'multinli' 19 | } 20 | } 21 | 22 | for dataset in dataset_attributes: 23 | dataset_attributes[dataset]['root_dir'] = os.path.join(root_dir, dataset_attributes[dataset]['root_dir']) 24 | 25 | shift_types = ['confounder', 'label_shift_step'] 26 | 27 | 28 | def prepare_data(args, train, return_full_dataset=False, train_transform=None, eval_transform=None): 29 | # Set root_dir to defaults if necessary 30 | if args.root_dir is None: 31 | args.root_dir = dataset_attributes[args.dataset]['root_dir'] 32 | if args.shift_type == 'confounder': 33 | return prepare_confounder_data(args, train, return_full_dataset, train_transform=train_transform, 34 | eval_transform=eval_transform) 35 | elif args.shift_type.startswith('label_shift'): 36 | assert not return_full_dataset 37 | return prepare_label_shift_data(args, train) 38 | 39 | 40 | def log_data(data, logger): 41 | logger.write('Training Data...\n') 42 | for group_idx in range(data['train_data'].n_groups): 43 | logger.write( 44 | f' {data["train_data"].group_str(group_idx)}: n = {data["train_data"].group_counts()[group_idx]:.0f}\n') 45 | logger.write('Validation Data...\n') 46 | for group_idx in range(data['val_data'].n_groups): 47 | logger.write( 48 | f' {data["val_data"].group_str(group_idx)}: n = {data["val_data"].group_counts()[group_idx]:.0f}\n') 49 | if data['test_data'] is not None: 50 | logger.write('Test Data...\n') 51 | for group_idx in range(data['test_data'].n_groups): 52 | logger.write( 53 | f' {data["test_data"].group_str(group_idx)}: n = {data["test_data"].group_counts()[group_idx]:.0f}\n') 54 | -------------------------------------------------------------------------------- /real_datasets/data/celebA_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import torch 6 | import torchvision.transforms as transforms 7 | 8 | from data.confounder_dataset import ConfounderDataset 9 | from configs import model_attributes 10 | 11 | 12 | class CelebADataset(ConfounderDataset): 13 | """ 14 | CelebA dataset (already cropped and centered). 15 | Note: idx and filenames are off by one. 
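    (e.g., dataset index 0 corresponds to image file 000001.jpg).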
16 | """ 17 | 18 | def __init__(self, root_dir, target_name, confounder_names, 19 | model_type, augment_data, train_transform=None, eval_transform=None, ): 20 | self.root_dir = root_dir 21 | self.target_name = target_name 22 | self.confounder_names = confounder_names 23 | self.augment_data = augment_data 24 | self.model_type = model_type 25 | 26 | # Read in attributes 27 | self.attrs_df = pd.read_csv( 28 | os.path.join(root_dir, 'data', 'list_attr_celeba.csv')) 29 | 30 | # Split out filenames and attribute names 31 | self.data_dir = os.path.join(self.root_dir, 'data', 'img_align_celeba') 32 | self.filename_array = self.attrs_df['image_id'].values 33 | self.attrs_df = self.attrs_df.drop(labels='image_id', axis='columns') 34 | self.attr_names = self.attrs_df.columns.copy() 35 | 36 | # Then cast attributes to numpy array and set them to 0 and 1 37 | # (originally, they're -1 and 1) 38 | self.attrs_df = self.attrs_df.values 39 | self.attrs_df[self.attrs_df == -1] = 0 40 | 41 | # Get the y values 42 | target_idx = self.attr_idx(self.target_name) 43 | self.y_array = self.attrs_df[:, target_idx] 44 | self.n_classes = 2 45 | 46 | # Map the confounder attributes to a number 0,...,2^|confounder_idx|-1 47 | self.confounder_idx = [self.attr_idx(a) for a in self.confounder_names] 48 | self.n_confounders = len(self.confounder_idx) 49 | confounders = self.attrs_df[:, self.confounder_idx] 50 | confounder_id = confounders @ np.power(2, np.arange(len(self.confounder_idx))) 51 | self.confounder_array = confounder_id 52 | 53 | # Map to groups 54 | self.n_groups = self.n_classes * pow(2, len(self.confounder_idx)) 55 | self.group_array = (self.y_array * (self.n_groups / 2) + self.confounder_array).astype('int') 56 | 57 | # Read in train/val/test splits 58 | self.split_df = pd.read_csv( 59 | os.path.join(root_dir, 'data', 'list_eval_partition.csv')) 60 | self.split_array = self.split_df['partition'].values 61 | self.split_dict = { 62 | 'train': 0, 63 | 'val': 1, 64 | 'test': 2 65 | } 66 | 67 | if model_attributes[self.model_type]['feature_type'] == 'precomputed': 68 | self.features_mat = torch.from_numpy(np.load( 69 | os.path.join(root_dir, 'features', model_attributes[self.model_type]['feature_filename']))).float() 70 | self.train_transform = train_transform 71 | self.eval_transform = eval_transform 72 | else: 73 | self.features_mat = None 74 | self.train_transform = get_transform_celebA(self.model_type, train=True, augment_data=augment_data) \ 75 | if train_transform is None else train_transform 76 | self.eval_transform = get_transform_celebA(self.model_type, train=False, augment_data=augment_data) \ 77 | if eval_transform is None else eval_transform 78 | 79 | def attr_idx(self, attr_name): 80 | return self.attr_names.get_loc(attr_name) 81 | 82 | 83 | def get_transform_celebA(model_type, train, augment_data): 84 | orig_w = 178 85 | orig_h = 218 86 | orig_min_dim = min(orig_w, orig_h) 87 | if model_attributes[model_type]['target_resolution'] is not None: 88 | target_resolution = model_attributes[model_type]['target_resolution'] 89 | else: 90 | target_resolution = (orig_w, orig_h) 91 | 92 | if (not train) or (not augment_data): 93 | transform = transforms.Compose([ 94 | transforms.CenterCrop(orig_min_dim), 95 | transforms.Resize(target_resolution), 96 | transforms.ToTensor(), 97 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 98 | ]) 99 | else: 100 | # Orig aspect ratio is 0.81, so we don't squish it in that direction any more 101 | transform = transforms.Compose([ 102 | 
transforms.RandomResizedCrop( 103 | target_resolution, 104 | scale=(0.7, 1.0), 105 | ratio=(1.0, 1.3333333333333333), 106 | interpolation=2), 107 | transforms.RandomHorizontalFlip(), 108 | transforms.ToTensor(), 109 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 110 | ]) 111 | return transform 112 | -------------------------------------------------------------------------------- /real_datasets/data/confounder_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | from PIL import Image 5 | from torch.utils.data import Dataset, Subset 6 | 7 | from configs import model_attributes 8 | 9 | 10 | class ConfounderDataset(Dataset): 11 | def __init__(self, root_dir, 12 | target_name, confounder_names, 13 | model_type=None, augment_data=None): 14 | raise NotImplementedError 15 | 16 | def __len__(self): 17 | return len(self.filename_array) 18 | 19 | def __getitem__(self, idx): 20 | y = self.y_array[idx] 21 | g = self.group_array[idx] 22 | 23 | if model_attributes[self.model_type]['feature_type'] == 'precomputed': 24 | x = self.features_mat[idx, :] 25 | else: 26 | img_filename = os.path.join( 27 | self.data_dir, 28 | self.filename_array[idx]) 29 | img = Image.open(img_filename).convert('RGB') 30 | # Figure out split and transform accordingly 31 | if self.split_array[idx] == self.split_dict['train'] and self.train_transform: 32 | img = self.train_transform(img) 33 | elif (self.split_array[idx] in [self.split_dict['val'], self.split_dict['test']] and 34 | self.eval_transform): 35 | img = self.eval_transform(img) 36 | # Flatten if needed 37 | if model_attributes[self.model_type]['flatten']: 38 | assert img.dim() == 3 39 | img = img.view(-1) 40 | x = img 41 | 42 | return x, y, g 43 | 44 | def get_splits(self, splits, train_frac=1.0): 45 | subsets = {} 46 | for split in splits: 47 | assert split in ('train', 'val', 'test'), split + ' is not a valid split' 48 | mask = self.split_array == self.split_dict[split] 49 | num_split = np.sum(mask) 50 | indices = np.where(mask)[0] 51 | if train_frac < 1 and split == 'train': 52 | num_to_retain = int(np.round(float(len(indices)) * train_frac)) 53 | indices = np.sort(np.random.permutation(indices)[:num_to_retain]) 54 | subsets[split] = Subset(self, indices) 55 | return subsets 56 | 57 | def group_str(self, group_idx): 58 | y = group_idx // (self.n_groups / self.n_classes) 59 | c = group_idx % (self.n_groups // self.n_classes) 60 | 61 | group_name = f'{self.target_name} = {int(y)}' 62 | bin_str = format(int(c), f'0{self.n_confounders}b')[::-1] 63 | for attr_idx, attr_name in enumerate(self.confounder_names): 64 | group_name += f', {attr_name} = {bin_str[attr_idx]}' 65 | return group_name 66 | -------------------------------------------------------------------------------- /real_datasets/data/confounder_utils.py: -------------------------------------------------------------------------------- 1 | from data.celebA_dataset import CelebADataset 2 | from data.cub_dataset import CUBDataset 3 | from data.dro_dataset import DRODataset 4 | from data.multinli_dataset import MultiNLIDataset 5 | 6 | ################ 7 | ### SETTINGS ### 8 | ################ 9 | 10 | confounder_settings = { 11 | 'CelebA': { 12 | 'constructor': CelebADataset 13 | }, 14 | 'CUB': { 15 | 'constructor': CUBDataset 16 | }, 17 | 'MultiNLI': { 18 | 'constructor': MultiNLIDataset 19 | } 20 | } 21 | 22 | 23 | ######################## 24 | ### DATA PREPARATION ### 25 | ######################## 26 | def 
prepare_confounder_data(args, train, return_full_dataset=False, train_transform=None, eval_transform=None): 27 | full_dataset = confounder_settings[args.dataset]['constructor']( 28 | root_dir=args.root_dir, 29 | target_name=args.target_name, 30 | confounder_names=args.confounder_names, 31 | model_type=args.model, 32 | augment_data=args.augment_data, 33 | train_transform=train_transform, 34 | eval_transform=eval_transform 35 | ) 36 | if return_full_dataset: 37 | return DRODataset( 38 | full_dataset, 39 | process_item_fn=None, 40 | n_groups=full_dataset.n_groups, 41 | n_classes=full_dataset.n_classes, 42 | group_str_fn=full_dataset.group_str) 43 | if train: 44 | splits = ['train', 'val', 'test'] 45 | else: 46 | splits = ['test'] 47 | subsets = full_dataset.get_splits(splits, train_frac=args.fraction) 48 | dro_subsets = [DRODataset(subsets[split], process_item_fn=None, n_groups=full_dataset.n_groups, 49 | n_classes=full_dataset.n_classes, group_str_fn=full_dataset.group_str) \ 50 | for split in splits] 51 | return dro_subsets 52 | -------------------------------------------------------------------------------- /real_datasets/data/cub_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import torch 6 | import torchvision.transforms as transforms 7 | 8 | from data.confounder_dataset import ConfounderDataset 9 | from configs import model_attributes 10 | 11 | 12 | class CUBDataset(ConfounderDataset): 13 | """ 14 | CUB dataset (already cropped and centered). 15 | Note: metadata_df is one-indexed. 16 | """ 17 | 18 | def __init__(self, root_dir, 19 | target_name, confounder_names, 20 | augment_data=False, 21 | model_type=None, 22 | train_transform=None, 23 | eval_transform=None, ): 24 | self.root_dir = root_dir 25 | self.target_name = target_name 26 | self.confounder_names = confounder_names 27 | self.model_type = model_type 28 | self.augment_data = augment_data 29 | 30 | self.data_dir = os.path.join( 31 | self.root_dir, 32 | 'data', 33 | '_'.join([self.target_name] + self.confounder_names)) 34 | 35 | if not os.path.exists(self.data_dir): 36 | raise ValueError( 37 | f'{self.data_dir} does not exist yet. 
Please generate the dataset first.') 38 | 39 | # Read in metadata 40 | self.metadata_df = pd.read_csv( 41 | os.path.join(self.data_dir, 'metadata.csv')) 42 | 43 | # Get the y values 44 | self.y_array = self.metadata_df['y'].values 45 | self.n_classes = 2 46 | 47 | # We only support one confounder for CUB for now 48 | self.confounder_array = self.metadata_df['place'].values 49 | self.n_confounders = 1 50 | # Map to groups 51 | self.n_groups = pow(2, 2) 52 | self.group_array = (self.y_array * (self.n_groups / 2) + self.confounder_array).astype('int') 53 | 54 | # Extract filenames and splits 55 | self.filename_array = self.metadata_df['img_filename'].values 56 | self.split_array = self.metadata_df['split'].values 57 | self.split_dict = { 58 | 'train': 0, 59 | 'val': 1, 60 | 'test': 2 61 | } 62 | 63 | # Set transform 64 | if model_attributes[self.model_type]['feature_type'] == 'precomputed': 65 | self.features_mat = torch.from_numpy(np.load( 66 | os.path.join(root_dir, 'features', model_attributes[self.model_type]['feature_filename']))).float() 67 | self.train_transform = train_transform 68 | self.eval_transform = eval_transform 69 | else: 70 | self.features_mat = None 71 | self.train_transform = get_transform_cub( 72 | self.model_type, 73 | train=True, 74 | augment_data=augment_data) if train_transform is None else train_transform 75 | self.eval_transform = get_transform_cub( 76 | self.model_type, 77 | train=False, 78 | augment_data=augment_data) if eval_transform is None else eval_transform 79 | 80 | 81 | def get_transform_cub(model_type, train, augment_data): 82 | scale = 256.0 / 224.0 83 | target_resolution = model_attributes[model_type]['target_resolution'] 84 | assert target_resolution is not None 85 | 86 | if (not train) or (not augment_data): 87 | # Resizes the image to a slightly larger square then crops the center. 
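        # e.g., for resnet50 (target_resolution = (224, 224)) this resizes to 256x256
        # and then takes a 224x224 center crop.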
88 | transform = transforms.Compose([ 89 | transforms.Resize((int(target_resolution[0] * scale), int(target_resolution[1] * scale))), 90 | transforms.CenterCrop(target_resolution), 91 | transforms.ToTensor(), 92 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 93 | ]) 94 | else: 95 | transform = transforms.Compose([ 96 | transforms.RandomResizedCrop( 97 | target_resolution, 98 | scale=(0.7, 1.0), 99 | ratio=(0.75, 1.3333333333333333), 100 | interpolation=2), 101 | transforms.RandomHorizontalFlip(), 102 | transforms.ToTensor(), 103 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 104 | ]) 105 | return transform 106 | -------------------------------------------------------------------------------- /real_datasets/data/dro_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset, DataLoader, Subset 3 | from torch.utils.data.sampler import WeightedRandomSampler 4 | 5 | 6 | class DRODataset(Dataset): 7 | def __init__(self, dataset, process_item_fn, n_groups, n_classes, group_str_fn): 8 | self.dataset = dataset 9 | self.process_item = process_item_fn 10 | self.n_groups = n_groups 11 | self.n_classes = n_classes 12 | self.group_str = group_str_fn 13 | if isinstance(dataset, Subset): 14 | full_dataset = dataset.dataset 15 | assert not isinstance(full_dataset, Subset) 16 | indices = dataset.indices # subset indices 17 | group_array = full_dataset.group_array[indices] 18 | y_array = full_dataset.y_array[indices] 19 | else: 20 | group_array = dataset.group_array 21 | y_array = dataset.y_array 22 | self._group_array = torch.LongTensor(group_array) 23 | self._y_array = torch.LongTensor(y_array) 24 | self._group_counts = (torch.arange(self.n_groups).unsqueeze(1) == self._group_array).sum(1).float() 25 | self._y_counts = (torch.arange(self.n_classes).unsqueeze(1) == self._y_array).sum(1).float() 26 | 27 | def __getitem__(self, idx): 28 | if self.process_item is None: 29 | return self.dataset[idx] 30 | else: 31 | return self.process_item(self.dataset[idx]) 32 | 33 | def __len__(self): 34 | return len(self.dataset) 35 | 36 | def group_counts(self): 37 | return self._group_counts 38 | 39 | def class_counts(self): 40 | return self._y_counts 41 | 42 | def input_size(self): 43 | for x, y, g in self: 44 | return x.size() 45 | 46 | def get_loader(self, train, reweight_groups, **kwargs): 47 | if not train: # Validation or testing 48 | assert reweight_groups is None 49 | shuffle = False 50 | sampler = None 51 | elif not reweight_groups: # Training but not reweighting 52 | shuffle = True 53 | sampler = None 54 | else: # Training and reweighting 55 | # When the --robust flag is not set, reweighting changes the loss function 56 | # from the normal ERM (average loss over each training example) 57 | # to a reweighted ERM (weighted average where each (y,c) group has equal weight) . 
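            # (The WeightedRandomSampler below draws each example of group g with weight
            # len(self) / group_count[g], so all groups are sampled equally often in expectation.)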
58 | # When the --robust flag is set, reweighting does not change the loss function 59 | # since the minibatch is only used for mean gradient estimation for each group separately 60 | group_weights = len(self) / self._group_counts 61 | weights = group_weights[self._group_array] 62 | 63 | # Replacement needs to be set to True, otherwise we'll run out of minority samples 64 | sampler = WeightedRandomSampler(weights, len(self), replacement=True) 65 | shuffle = False 66 | 67 | loader = DataLoader( 68 | self, 69 | shuffle=shuffle, 70 | sampler=sampler, 71 | **kwargs) 72 | return loader 73 | -------------------------------------------------------------------------------- /real_datasets/data/label_shift_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from data.dro_dataset import DRODataset 3 | from data.utils import subsample, Subset 4 | from data.torchvision_datasets import load_CIFAR10 5 | import numpy as np 6 | 7 | 8 | ######################## 9 | ### DATA PREPARATION ### 10 | ######################## 11 | 12 | def prepare_label_shift_data(args, train): 13 | settings = label_shift_settings[args.dataset] 14 | data = settings['load_fn'](args, train) 15 | n_classes = settings['n_classes'] 16 | if train: 17 | train_data, val_data = data 18 | if args.fraction < 1: 19 | train_data = subsample(train_data, args.fraction) 20 | train_data = apply_label_shift(train_data, n_classes, args.shift_type, args.minority_fraction, 21 | args.imbalance_ratio) 22 | data = [train_data, val_data] 23 | dro_data = [DRODataset(subset, process_item_fn=settings['process_fn'], n_groups=n_classes, 24 | n_classes=n_classes, group_str_fn=settings['group_str_fn']) \ 25 | for subset in data] 26 | return dro_data 27 | 28 | 29 | ############## 30 | ### SHIFTS ### 31 | ############## 32 | 33 | def apply_label_shift(dataset, n_classes, shift_type, minority_frac, imbalance_ratio): 34 | assert shift_type.startswith('label_shift') 35 | if shift_type == 'label_shift_step': 36 | return step_shift(dataset, n_classes, minority_frac, imbalance_ratio) 37 | 38 | 39 | def step_shift(dataset, n_classes, minority_frac, imbalance_ratio): 40 | # get y info 41 | y_array = [] 42 | for x, y in dataset: 43 | y_array.append(y) 44 | y_array = torch.LongTensor(y_array) 45 | y_counts = ((torch.arange(n_classes).unsqueeze(1) == y_array).sum(1)).float() 46 | # figure out sample size for each class 47 | is_major = (torch.arange(n_classes) < (1 - minority_frac) * n_classes).float() 48 | major_count = int(torch.min(is_major * y_counts + (1 - is_major) * y_counts * imbalance_ratio).item()) 49 | minor_count = int(np.floor(major_count / imbalance_ratio)) 50 | print(y_counts, major_count, minor_count) 51 | # subsample 52 | sampled_indices = [] 53 | for y in np.arange(n_classes): 54 | indices, = np.where(y_array == y) 55 | np.random.shuffle(indices) 56 | if is_major[y]: 57 | sample_size = major_count 58 | else: 59 | sample_size = minor_count 60 | sampled_indices.append(indices[:sample_size]) 61 | sampled_indices = torch.from_numpy(np.concatenate(sampled_indices)) 62 | return Subset(dataset, sampled_indices) 63 | 64 | 65 | ################### 66 | ### PROCESS FNS ### 67 | ################### 68 | 69 | def xy_to_xyy(data): 70 | x, y = data 71 | return x, y, y 72 | 73 | 74 | ##################### 75 | ### GROUP STR FNS ### 76 | ##################### 77 | 78 | def group_str_CIFAR10(group_idx): 79 | classes = ['plane', 'car', 'bird', 'cat', 80 | 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'] 81 | return 
f'Y = {group_idx} ({classes[group_idx]})' 82 | 83 | 84 | ################ 85 | ### SETTINGS ### 86 | ################ 87 | 88 | label_shift_settings = { 89 | 'CIFAR10': { 90 | 'load_fn': load_CIFAR10, 91 | 'group_str_fn': group_str_CIFAR10, 92 | 'process_fn': xy_to_xyy, 93 | 'n_classes': 10 94 | } 95 | } 96 | -------------------------------------------------------------------------------- /real_datasets/data/multinli_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import torch 6 | 7 | from data.confounder_dataset import ConfounderDataset 8 | 9 | 10 | class MultiNLIDataset(ConfounderDataset): 11 | """ 12 | MultiNLI dataset. 13 | label_dict = { 14 | 'contradiction': 0, 15 | 'entailment': 1, 16 | 'neutral': 2 17 | } 18 | # Negation words taken from https://arxiv.org/pdf/1803.02324.pdf 19 | negation_words = ['nobody', 'no', 'never', 'nothing'] 20 | """ 21 | 22 | def __init__(self, root_dir, 23 | target_name, confounder_names, 24 | augment_data=False, 25 | model_type=None, 26 | train_transform=None, 27 | eval_transform=None, 28 | ): 29 | self.root_dir = root_dir 30 | self.target_name = target_name 31 | self.confounder_names = confounder_names 32 | self.model_type = model_type 33 | self.augment_data = augment_data 34 | 35 | assert len(confounder_names) == 1 36 | assert confounder_names[0] == 'sentence2_has_negation' 37 | assert target_name in ['gold_label_preset', 'gold_label_random'] 38 | assert augment_data == False 39 | assert model_type == 'bert' 40 | 41 | self.data_dir = os.path.join( 42 | self.root_dir, 43 | 'data') 44 | self.glue_dir = os.path.join( 45 | self.root_dir, 46 | 'glue_data', 47 | 'MNLI') 48 | if not os.path.exists(self.data_dir): 49 | raise ValueError( 50 | f'{self.data_dir} does not exist yet. Please generate the dataset first.') 51 | if not os.path.exists(self.glue_dir): 52 | raise ValueError( 53 | f'{self.glue_dir} does not exist yet. 
Please generate the dataset first.') 54 | 55 | # Read in metadata 56 | type_of_split = target_name.split('_')[-1] 57 | self.metadata_df = pd.read_csv( 58 | os.path.join( 59 | self.data_dir, 60 | f'metadata_{type_of_split}.csv'), 61 | index_col=0) 62 | 63 | # Get the y values 64 | # gold_label is hardcoded 65 | self.y_array = self.metadata_df['gold_label'].values 66 | self.n_classes = len(np.unique(self.y_array)) 67 | 68 | self.confounder_array = self.metadata_df[confounder_names[0]].values 69 | self.n_confounders = len(confounder_names) 70 | 71 | # Map to groups 72 | self.n_groups = len(np.unique(self.confounder_array)) * self.n_classes 73 | self.group_array = (self.y_array * (self.n_groups / self.n_classes) + self.confounder_array).astype('int') 74 | 75 | # Extract splits 76 | self.split_array = self.metadata_df['split'].values 77 | self.split_dict = { 78 | 'train': 0, 79 | 'val': 1, 80 | 'test': 2 81 | } 82 | 83 | # Load features 84 | self.features_array = [] 85 | for feature_file in [ 86 | 'cached_train_bert-base-uncased_128_mnli', 87 | 'cached_dev_bert-base-uncased_128_mnli', 88 | 'cached_dev_bert-base-uncased_128_mnli-mm' 89 | ]: 90 | features = torch.load( 91 | os.path.join( 92 | self.glue_dir, 93 | feature_file)) 94 | 95 | self.features_array += features 96 | 97 | self.all_input_ids = torch.tensor([f.input_ids for f in self.features_array], dtype=torch.long) 98 | self.all_input_masks = torch.tensor([f.input_mask for f in self.features_array], dtype=torch.long) 99 | self.all_segment_ids = torch.tensor([f.segment_ids for f in self.features_array], dtype=torch.long) 100 | self.all_label_ids = torch.tensor([f.label_id for f in self.features_array], dtype=torch.long) 101 | 102 | self.x_array = torch.stack(( 103 | self.all_input_ids, 104 | self.all_input_masks, 105 | self.all_segment_ids), dim=2) 106 | 107 | assert np.all(np.array(self.all_label_ids) == self.y_array) 108 | 109 | def __len__(self): 110 | return len(self.y_array) 111 | 112 | def __getitem__(self, idx): 113 | y = self.y_array[idx] 114 | g = self.group_array[idx] 115 | x = self.x_array[idx, ...] 
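        # x stacks (input_ids, input_mask, segment_ids) along the last dimension,
        # i.e., one (max_seq_length, 3) tensor per example of the cached BERT features.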
116 | return x, y, g 117 | 118 | def group_str(self, group_idx): 119 | y = group_idx // (self.n_groups / self.n_classes) 120 | c = group_idx % (self.n_groups // self.n_classes) 121 | 122 | attr_name = self.confounder_names[0] 123 | group_name = f'{self.target_name} = {int(y)}, {attr_name} = {int(c)}' 124 | return group_name 125 | -------------------------------------------------------------------------------- /real_datasets/data/torchvision_datasets.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | import torchvision.transforms as transforms 3 | 4 | from data.utils import train_val_split 5 | from configs import model_attributes 6 | 7 | 8 | ### CIFAR10 ### 9 | def load_CIFAR10(args, train): 10 | transform = get_transform_CIFAR10(args, train) 11 | dataset = torchvision.datasets.CIFAR10(args.root_dir, train, transform=transform, download=True) 12 | if train: 13 | subsets = train_val_split(dataset, args.val_fraction) 14 | else: 15 | subsets = [dataset, ] 16 | return subsets 17 | 18 | 19 | def get_transform_CIFAR10(args, train): 20 | transform_list = [] 21 | # resize if needed 22 | target_resolution = model_attributes[args.model]['target_resolution'] 23 | if target_resolution is not None: 24 | transform_list.append(transforms.Resize(target_resolution)) 25 | transform_list += [transforms.ToTensor(), 26 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))] 27 | composed_transform = transforms.Compose(transform_list) 28 | return composed_transform 29 | -------------------------------------------------------------------------------- /real_datasets/data/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.utils.data import Subset 3 | 4 | 5 | # Train val split 6 | def train_val_split(dataset, val_frac): 7 | # split into train and val 8 | indices = np.arange(len(dataset)) 9 | np.random.shuffle(indices) 10 | val_size = int(np.round(len(dataset) * val_frac)) 11 | train_indices, val_indices = indices[val_size:], indices[:val_size] 12 | train_data, val_data = Subset(dataset, train_indices), Subset(dataset, val_indices) 13 | return train_data, val_data 14 | 15 | 16 | # Subsample a fraction for smaller training data 17 | def subsample(dataset, fraction): 18 | indices = np.arange(len(dataset)) 19 | num_to_retain = int(np.round(float(len(dataset)) * fraction)) 20 | np.random.shuffle(indices) 21 | return Subset(dataset, indices[:num_to_retain]) 22 | -------------------------------------------------------------------------------- /real_datasets/dataset_scripts/dataset_utils.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import numpy as np 3 | 4 | def crop_and_resize(source_img, target_img): 5 | """ 6 | Make source_img exactly the same as target_img by expanding/shrinking and 7 | cropping appropriately. 8 | 9 | If source_img's dimensions are strictly greater than or equal to the 10 | corresponding target img dimensions, we crop left/right or top/bottom 11 | depending on aspect ratio, then shrink down. 
12 | 13 | If any of source img's dimensions are smaller than target img's dimensions, 14 | we expand the source img and then crop accordingly 15 | 16 | Modified from 17 | https://stackoverflow.com/questions/4744372/reducing-the-width-height-of-an-image-to-fit-a-given-aspect-ratio-how-python 18 | """ 19 | source_width = source_img.size[0] 20 | source_height = source_img.size[1] 21 | 22 | target_width = target_img.size[0] 23 | target_height = target_img.size[1] 24 | 25 | # Check if source does not completely cover target 26 | if (source_width < target_width) or (source_height < target_height): 27 | # Try matching width 28 | width_resize = (target_width, int((target_width / source_width) * source_height)) 29 | if (width_resize[0] >= target_width) and (width_resize[1] >= target_height): 30 | source_resized = source_img.resize(width_resize, Image.ANTIALIAS) 31 | else: 32 | height_resize = (int((target_height / source_height) * source_width), target_height) 33 | assert (height_resize[0] >= target_width) and (height_resize[1] >= target_height) 34 | source_resized = source_img.resize(height_resize, Image.ANTIALIAS) 35 | # Rerun the cropping 36 | return crop_and_resize(source_resized, target_img) 37 | 38 | source_aspect = source_width / source_height 39 | target_aspect = target_width / target_height 40 | 41 | if source_aspect > target_aspect: 42 | # Crop left/right 43 | new_source_width = int(target_aspect * source_height) 44 | offset = (source_width - new_source_width) // 2 45 | resize = (offset, 0, source_width - offset, source_height) 46 | else: 47 | # Crop top/bottom 48 | new_source_height = int(source_width / target_aspect) 49 | offset = (source_height - new_source_height) // 2 50 | resize = (0, offset, source_width, source_height - offset) 51 | 52 | source_resized = source_img.crop(resize).resize((target_width, target_height), Image.ANTIALIAS) 53 | return source_resized 54 | 55 | 56 | def combine_and_mask(img_new, mask, img_black): 57 | """ 58 | Combine img_new, mask, and image_black based on the mask 59 | 60 | img_new: new (unmasked image) 61 | mask: binary mask of bird image 62 | img_black: already-masked bird image (bird only) 63 | """ 64 | # Warp new img to match black img 65 | img_resized = crop_and_resize(img_new, img_black) 66 | img_resized_np = np.asarray(img_resized) 67 | 68 | # Mask new img 69 | img_masked_np = np.around(img_resized_np * (1 - mask)).astype(np.uint8) 70 | 71 | # Combine 72 | img_combined_np = np.asarray(img_black) + img_masked_np 73 | img_combined = Image.fromarray(img_combined_np) 74 | 75 | return img_combined 76 | -------------------------------------------------------------------------------- /real_datasets/dataset_scripts/generate_multinli.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import numpy as np 4 | import random 5 | import string 6 | import pandas as pd 7 | from configs import DATA_FOLDER 8 | ################ Paths and other configs - Set these ################################# 9 | 10 | data_dir = DATA_FOLDER + '/multinli/data' 11 | # glue_dir = '/u/scr/nlp/dro/multinli/glue_data/MNLI' 12 | 13 | type_of_split = 'random' 14 | assert type_of_split in ['preset', 'random'] 15 | # If 'preset', use the official train/val/test MultiNLI split 16 | # If 'random', randomly split 50%/20%/30% of the data to train/val/test 17 | 18 | ###################################################################################### 19 | 20 | def tokenize(s): 21 | s = s.translate(str.maketrans('', '', 
string.punctuation)) 22 | s = s.lower() 23 | s = s.split(' ') 24 | return s 25 | 26 | ### Read in data and assign train/val/test splits 27 | train_df = pd.read_json( 28 | os.path.join( 29 | data_dir, 30 | 'multinli_1.0_train.jsonl'), 31 | lines=True) 32 | 33 | val_df = pd.read_json( 34 | os.path.join( 35 | data_dir, 36 | 'multinli_1.0_dev_matched.jsonl'), 37 | lines=True) 38 | 39 | test_df = pd.read_json( 40 | os.path.join( 41 | data_dir, 42 | 'multinli_1.0_dev_mismatched.jsonl'), 43 | lines=True) 44 | 45 | split_dict = { 46 | 'train': 0, 47 | 'val': 1, 48 | 'test': 2 49 | } 50 | 51 | if type_of_split == 'preset': 52 | train_df['split'] = split_dict['train'] 53 | val_df['split'] = split_dict['val'] 54 | test_df['split'] = split_dict['test'] 55 | df = pd.concat([train_df, val_df, test_df], ignore_index=True) 56 | 57 | elif type_of_split == 'random': 58 | val_frac = 0.2 59 | test_frac = 0.3 60 | 61 | df = pd.concat([train_df, val_df, test_df], ignore_index=True) 62 | n = len(df) 63 | n_val = int(val_frac * n) 64 | n_test = int(test_frac * n) 65 | n_train = n - n_val - n_test 66 | splits = np.array([split_dict['train']] * n_train + [split_dict['val']] * n_val + [split_dict['test']] * n_test) 67 | np.random.shuffle(splits) 68 | df['split'] = splits 69 | 70 | ### Assign labels 71 | df = df.loc[df['gold_label'] != '-', :] 72 | print(f'Total number of examples: {len(df)}') 73 | for k, v in split_dict.items(): 74 | print(k, np.mean(df['split'] == v)) 75 | 76 | label_dict = { 77 | 'contradiction': 0, 78 | 'entailment': 1, 79 | 'neutral': 2 80 | } 81 | for k, v in label_dict.items(): 82 | idx = df.loc[:, 'gold_label'] == k 83 | df.loc[idx, 'gold_label'] = v 84 | 85 | ### Assign spurious attribute (negation words) 86 | negation_words = ['nobody', 'no', 'never', 'nothing'] # Taken from https://arxiv.org/pdf/1803.02324.pdf 87 | 88 | df['sentence2_has_negation'] = [False] * len(df) 89 | 90 | for negation_word in negation_words: 91 | df['sentence2_has_negation'] |= [negation_word in tokenize(sentence) for sentence in df['sentence2']] 92 | 93 | df['sentence2_has_negation'] = df['sentence2_has_negation'].astype(int) 94 | 95 | ## Write to disk 96 | df = df[['gold_label', 'sentence2_has_negation', 'split']] 97 | df.to_csv(os.path.join(data_dir, f'metadata_{type_of_split}.csv')) 98 | -------------------------------------------------------------------------------- /real_datasets/dataset_scripts/generate_waterbirds.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import random 4 | import pandas as pd 5 | from PIL import Image 6 | from tqdm import tqdm 7 | from dataset_utils import crop_and_resize, combine_and_mask 8 | 9 | ################ Paths and other configs - Set these ################################# 10 | cub_dir = '/u/scr/nlp/CUB_200_2011' 11 | places_dir = '/u/scr/nlp/places365' 12 | output_dir = '/u/scr/nlp/dro/cub/data' 13 | dataset_name = 'waterbird_complete95_forest2water2' 14 | 15 | target_places = [ 16 | ['bamboo_forest', 'forest/broadleaf'], # Land backgrounds 17 | ['ocean', 'lake/natural']] # Water backgrounds 18 | 19 | val_frac = 0.2 # What fraction of the training data to use as validation 20 | confounder_strength = 0.95 # Determines relative size of majority vs. 
minority groups 21 | ###################################################################################### 22 | 23 | images_path = os.path.join(cub_dir, 'images.txt') 24 | 25 | df = pd.read_csv( 26 | images_path, 27 | sep=" ", 28 | header=None, 29 | names=['img_id', 'img_filename'], 30 | index_col='img_id') 31 | 32 | ### Set up labels of waterbirds vs. landbirds 33 | # We consider water birds = seabirds and waterfowl. 34 | species = np.unique([img_filename.split('/')[0].split('.')[1].lower() for img_filename in df['img_filename']]) 35 | water_birds_list = [ 36 | 'Albatross', # Seabirds 37 | 'Auklet', 38 | 'Cormorant', 39 | 'Frigatebird', 40 | 'Fulmar', 41 | 'Gull', 42 | 'Jaeger', 43 | 'Kittiwake', 44 | 'Pelican', 45 | 'Puffin', 46 | 'Tern', 47 | 'Gadwall', # Waterfowl 48 | 'Grebe', 49 | 'Mallard', 50 | 'Merganser', 51 | 'Guillemot', 52 | 'Pacific_Loon' 53 | ] 54 | 55 | water_birds = {} 56 | for species_name in species: 57 | water_birds[species_name] = 0 58 | for water_bird in water_birds_list: 59 | if water_bird.lower() in species_name: 60 | water_birds[species_name] = 1 61 | species_list = [img_filename.split('/')[0].split('.')[1].lower() for img_filename in df['img_filename']] 62 | df['y'] = [water_birds[species] for species in species_list] 63 | 64 | ### Assign train/tesst/valid splits 65 | # In the original CUB dataset split, split = 0 is test and split = 1 is train 66 | # We want to change it to 67 | # split = 0 is train, 68 | # split = 1 is val, 69 | # split = 2 is test 70 | 71 | train_test_df = pd.read_csv( 72 | os.path.join(cub_dir, 'train_test_split.txt'), 73 | sep=" ", 74 | header=None, 75 | names=['img_id', 'split'], 76 | index_col='img_id') 77 | 78 | df = df.join(train_test_df, on='img_id') 79 | test_ids = df.loc[df['split'] == 0].index 80 | train_ids = np.array(df.loc[df['split'] == 1].index) 81 | val_ids = np.random.choice( 82 | train_ids, 83 | size=int(np.round(val_frac * len(train_ids))), 84 | replace=False) 85 | 86 | df.loc[train_ids, 'split'] = 0 87 | df.loc[val_ids, 'split'] = 1 88 | df.loc[test_ids, 'split'] = 2 89 | 90 | ### Assign confounders (place categories) 91 | 92 | # Confounders are set up as the following: 93 | # Y = 0, C = 0: confounder_strength 94 | # Y = 0, C = 1: 1 - confounder_strength 95 | # Y = 1, C = 0: 1 - confounder_strength 96 | # Y = 1, C = 1: confounder_strength 97 | 98 | df['place'] = 0 99 | train_ids = np.array(df.loc[df['split'] == 0].index) 100 | val_ids = np.array(df.loc[df['split'] == 1].index) 101 | test_ids = np.array(df.loc[df['split'] == 2].index) 102 | for split_idx, ids in enumerate([train_ids, val_ids, test_ids]): 103 | for y in (0, 1): 104 | if split_idx == 0: # train 105 | if y == 0: 106 | pos_fraction = 1 - confounder_strength 107 | else: 108 | pos_fraction = confounder_strength 109 | else: 110 | pos_fraction = 0.5 111 | subset_df = df.loc[ids, :] 112 | y_ids = np.array((subset_df.loc[subset_df['y'] == y]).index) 113 | pos_place_ids = np.random.choice( 114 | y_ids, 115 | size=int(np.round(pos_fraction * len(y_ids))), 116 | replace=False) 117 | df.loc[pos_place_ids, 'place'] = 1 118 | 119 | for split, split_label in [(0, 'train'), (1, 'val'), (2, 'test')]: 120 | print(f"{split_label}:") 121 | split_df = df.loc[df['split'] == split, :] 122 | print(f"waterbirds are {np.mean(split_df['y']):.3f} of the examples") 123 | print(f"y = 0, c = 0: {np.mean(split_df.loc[split_df['y'] == 0, 'place'] == 0):.3f}, n = {np.sum((split_df['y'] == 0) & (split_df['place'] == 0))}") 124 | print(f"y = 0, c = 1: {np.mean(split_df.loc[split_df['y'] == 0, 
'place'] == 1):.3f}, n = {np.sum((split_df['y'] == 0) & (split_df['place'] == 1))}") 125 | print(f"y = 1, c = 0: {np.mean(split_df.loc[split_df['y'] == 1, 'place'] == 0):.3f}, n = {np.sum((split_df['y'] == 1) & (split_df['place'] == 0))}") 126 | print(f"y = 1, c = 1: {np.mean(split_df.loc[split_df['y'] == 1, 'place'] == 1):.3f}, n = {np.sum((split_df['y'] == 1) & (split_df['place'] == 1))}") 127 | 128 | ### Assign places to train, val, and test set 129 | place_ids_df = pd.read_csv( 130 | os.path.join(places_dir, 'categories_places365.txt'), 131 | sep=" ", 132 | header=None, 133 | names=['place_name', 'place_id'], 134 | index_col='place_id') 135 | 136 | target_place_ids = [] 137 | 138 | for idx, target_places in enumerate(target_places): 139 | place_filenames = [] 140 | 141 | for target_place in target_places: 142 | target_place_full = f'/{target_place[0]}/{target_place}' 143 | assert (np.sum(place_ids_df['place_name'] == target_place_full) == 1) 144 | target_place_ids.append(place_ids_df.index[place_ids_df['place_name'] == target_place_full][0]) 145 | print(f'train category {idx} {target_place_full} has id {target_place_ids[idx]}') 146 | 147 | # Read place filenames associated with target_place 148 | place_filenames += [ 149 | f'/{target_place[0]}/{target_place}/{filename}' for filename in os.listdir( 150 | os.path.join(places_dir, 'data_large', target_place[0], target_place)) 151 | if filename.endswith('.jpg')] 152 | 153 | random.shuffle(place_filenames) 154 | 155 | # Assign each filename to an image 156 | indices = (df.loc[:, 'place'] == idx) 157 | assert len(place_filenames) >= np.sum(indices),\ 158 | f"Not enough places ({len(place_filenames)}) to fit the dataset ({np.sum(df.loc[:, 'place'] == idx)})" 159 | df.loc[indices, 'place_filename'] = place_filenames[:np.sum(indices)] 160 | 161 | ### Write dataset to disk 162 | output_subfolder = os.path.join(output_dir, dataset_name) 163 | os.makedirs(output_subfolder, exist_ok=True) 164 | 165 | df.to_csv(os.path.join(output_subfolder, 'metadata.csv')) 166 | 167 | for i in tqdm(df.index): 168 | # Load bird image and segmentation 169 | img_path = os.path.join(cub_dir, 'images', df.loc[i, 'img_filename']) 170 | seg_path = os.path.join(cub_dir, 'segmentations', df.loc[i, 'img_filename'].replace('.jpg','.png')) 171 | img_np = np.asarray(Image.open(img_path).convert('RGB')) 172 | seg_np = np.asarray(Image.open(seg_path).convert('RGB')) / 255 173 | 174 | # Load place background 175 | # Skip front / 176 | place_path = os.path.join(places_dir, 'data_large', df.loc[i, 'place_filename'][1:]) 177 | place = Image.open(place_path).convert('RGB') 178 | 179 | img_black = Image.fromarray(np.around(img_np * seg_np).astype(np.uint8)) 180 | combined_img = combine_and_mask(place, seg_np, img_black) 181 | 182 | output_path = os.path.join(output_subfolder, df.loc[i, 'img_filename']) 183 | os.makedirs('/'.join(output_path.split('/')[:-1]), exist_ok=True) 184 | 185 | combined_img.save(output_path) 186 | -------------------------------------------------------------------------------- /real_datasets/eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import pickle 4 | import warnings 5 | from itertools import product 6 | from pathlib import Path 7 | 8 | import numpy as np 9 | import pandas as pd 10 | from tqdm.auto import tqdm 11 | 12 | from configs import DATA_FOLDER 13 | from isr import ISRClassifier, check_clf 14 | from utils.eval_utils import extract_data, save_df, measure_group_accs, 
load_data, group2env 15 | 16 | warnings.filterwarnings('ignore') # filter out Pandas append warnings 17 | 18 | 19 | def eval_ISR(args, train_data=None, val_data=None, test_data=None, log_dir=None): 20 | if (train_data is None) or (val_data is None) or (test_data is None) or (log_dir is None): 21 | train_data, val_data, test_data, log_dir = load_data(args) 22 | train_gs = train_data['group'] 23 | n_train = len(train_gs) 24 | groups, counts = np.unique(train_data['group'], return_counts=True, axis=0) 25 | n_groups = len(groups) 26 | n_classes = len(np.unique(train_data['label'])) 27 | # we do this because the original group is defined by (class * attribute) 28 | n_spu_attr = n_groups // n_classes 29 | assert n_spu_attr >= 2 30 | assert n_groups % n_classes == 0 31 | 32 | zs, ys, gs, preds = extract_data(train_data) 33 | 34 | test_zs, test_ys, test_gs, test_preds = extract_data( 35 | test_data, ) 36 | val_zs, val_ys, val_gs, val_preds = extract_data( 37 | val_data, ) 38 | 39 | if args.algo == 'ERM' or args.no_reweight: 40 | # no_reweight: do not use reweightning in the ISR classifier even if the args.algo is 'reweight' or 'groupDRO' 41 | sample_weight = None 42 | else: 43 | sample_weight = np.ones(n_train) 44 | for group, count in zip(groups, counts): 45 | sample_weight[train_gs == group] = n_train / n_groups / count 46 | if args.verbose: 47 | print('Computed non-uniform sample weight') 48 | 49 | df = pd.DataFrame( 50 | columns=['dataset', 'algo', 'seed', 'ckpt', 'split', 'method', 'clf_type', 'C', 'pca_dim', 'd_spu', 'ISR_class', 51 | 'ISR_scale', 'env_label_ratio'] + 52 | [f'acc-{g}' for g in groups] + ['worst_group', 'avg_acc', 'worst_acc', ]) 53 | base_row = {'dataset': args.dataset, 'algo': args.algo, 54 | 'seed': args.seed, 'ckpt': args.model_select, } 55 | 56 | # Need to convert group labels to env labels (i.e., spurious-attribute labels) 57 | es, val_es, test_es = group2env(gs, n_spu_attr), group2env(val_gs, n_spu_attr), group2env(test_gs, n_spu_attr) 58 | 59 | # eval_groups = np.array([0] + list(range(n_groups))) 60 | method = f'ISR-{args.ISR_version.capitalize()}' 61 | if args.no_reweight and (not args.use_orig_clf) and args.algo != 'ERM': 62 | method += '_noRW' 63 | if args.use_orig_clf: 64 | ckpt = pickle.load(open(log_dir + f'/{args.model_select}_clf.p', 'rb')) 65 | orig_clf = check_clf(ckpt, n_classes=n_classes) 66 | # Record original val accuracy: 67 | for (split, eval_zs, eval_ys, eval_gs) in [('val', val_zs, val_ys, val_gs), 68 | ('test', test_zs, test_ys, test_gs)]: 69 | eval_group_accs, eval_worst_acc, eval_worst_group = measure_group_accs(orig_clf, eval_zs, eval_ys, eval_gs, 70 | include_avg_acc=True) 71 | row = {**base_row, 'split': split, 'method': 'orig', **eval_group_accs, 'clf_type': 'orig', 72 | 'worst_acc': eval_worst_acc, 'worst_group': eval_worst_group} 73 | df = df.append(row, ignore_index=True) 74 | args.n_components = -1 75 | given_clf = orig_clf 76 | clf_type = 'orig' 77 | else: 78 | given_clf = None 79 | clf_type = 'logistic' 80 | 81 | if args.env_label_ratio < 1: 82 | rng = np.random.default_rng() 83 | # take a subset of training data 84 | idxes = rng.choice(len(zs), size=int( 85 | len(zs) * args.env_label_ratio), replace=False) 86 | zs, ys, gs, es = zs[idxes], ys[idxes], gs[idxes], es[idxes] 87 | 88 | np.random.seed(args.seed) 89 | # Start ISR 90 | ISR_classes = np.arange( 91 | n_classes) if args.ISR_class is None else [args.ISR_class] 92 | 93 | clf_kwargs = dict(C=args.C, max_iter=args.max_iter, random_state=args.seed) 94 | if args.ISR_version == 'mean': 
args.d_spu = n_spu_attr - 1 95 | 96 | isr_clf = ISRClassifier(version=args.ISR_version, pca_dim=args.n_components, d_spu=args.d_spu, 97 | clf_type='LogisticRegression', clf_kwargs=clf_kwargs, ) 98 | 99 | isr_clf.fit_data(zs, ys, es, n_classes=n_classes, n_envs=n_spu_attr) 100 | 101 | for ISR_class, ISR_scale in tqdm(list(product(ISR_classes, args.ISR_scales)), desc='ISR iter', leave=False): 102 | 103 | isr_clf.set_params(chosen_class=ISR_class, spu_scale=ISR_scale) 104 | 105 | if args.ISR_version == 'mean': 106 | isr_clf.fit_isr_mean(chosen_class=ISR_class, ) 107 | elif args.ISR_version == 'cov': 108 | isr_clf.fit_isr_cov(chosen_class=ISR_class, ) 109 | else: 110 | raise ValueError('Unknown ISR version') 111 | 112 | isr_clf.fit_clf(zs, ys, given_clf=given_clf, sample_weight=sample_weight) 113 | for (split, eval_zs, eval_ys, eval_gs) in [('val', val_zs, val_ys, val_gs), 114 | ('test', test_zs, test_ys, test_gs)]: 115 | group_accs, worst_acc, worst_group = measure_group_accs( 116 | isr_clf, eval_zs, eval_ys, eval_gs, include_avg_acc=True) 117 | row = {**base_row, 'split': split, 'method': method, 'clf_type': clf_type, 'ISR_class': ISR_class, 118 | 'ISR_scale': ISR_scale, 'd_spu': args.d_spu, **group_accs, 'worst_group': worst_group, 119 | 'worst_acc': worst_acc, 'env_label_ratio': args.env_label_ratio} 120 | if not args.use_orig_clf: 121 | row.update({'C': args.C, 'pca_dim': args.n_components, }) 122 | df = df.append(row, ignore_index=True) 123 | 124 | if args.verbose: 125 | print('Evaluation result') 126 | print(df) 127 | if not args.no_save: 128 | Path(args.save_dir).mkdir(parents=True, 129 | exist_ok=True) # make dir if not exists 130 | save_df(df, os.path.join(args.save_dir, 131 | f'{args.dataset}_results{args.file_suffix}.csv'), subset=None, verbose=args.verbose) 132 | return df 133 | 134 | 135 | def parse_args(args: list = None, specs: dict = None): 136 | argparser = argparse.ArgumentParser() 137 | argparser.add_argument('--root_dir', type=str, 138 | default=DATA_FOLDER) 139 | argparser.add_argument('--algo', type=str, default='ERM', 140 | choices=['ERM', 'groupDRO', 'reweight']) 141 | argparser.add_argument( 142 | '--dataset', type=str, default='CelebA', choices=['CelebA', 'MultiNLI', 'CUB']) 143 | argparser.add_argument('--model_select', type=str, 144 | default='best', choices=['best', 'best_avg_acc', 'last']) 145 | 146 | argparser.add_argument('--seed', type=int, default=0) 147 | argparser.add_argument('--n_components', type=int, default=100) 148 | argparser.add_argument('--C', type=float, default=1) 149 | argparser.add_argument('--ISR_version', type=str, default='mean', choices=['mean', 'cov']) 150 | argparser.add_argument('--ISR_class', type=int, default=None, 151 | help='None means enumerating over all classes.') 152 | argparser.add_argument('--ISR_scales', type=float, 153 | nargs='+', default=[0, 0.5]) 154 | argparser.add_argument('--d_spu', type=int, default=-1) 155 | argparser.add_argument('--save_dir', type=str, default='logs/') 156 | argparser.add_argument('--no_save', default=False, action='store_true') 157 | argparser.add_argument('--verbose', default=False, action='store_true') 158 | 159 | argparser.add_argument('--use_orig_clf', default=False, 160 | action='store_true', help='Original Classifier only') 161 | argparser.add_argument('--env_label_ratio', default=1, 162 | type=float, help='ratio of env label') 163 | argparser.add_argument('--feature_file_prefix', default='', 164 | type=str, help='Prefix of the feature files to load') 165 | 
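    # Example invocation (hypothetical settings, assuming features were already parsed):
    #   python eval.py --dataset CUB --algo ERM --seed 0 --ISR_version mean --ISR_scales 0 0.5 --verbose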
argparser.add_argument('--max_iter', default=1000, type=int, 166 | help='Max iterations for the logistic solver') 167 | argparser.add_argument('--file_suffix', default='', type=str, ) 168 | argparser.add_argument('--no_reweight', default=False, action='store_true', 169 | help='No reweighting for ISR classifier on reweight/groupDRO features') 170 | config = argparser.parse_args(args=args) 171 | config.__dict__.update(specs) 172 | return config 173 | 174 | 175 | if __name__ == '__main__': 176 | args = parse_args() 177 | eval_ISR(args) 178 | -------------------------------------------------------------------------------- /real_datasets/isr.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy 3 | import torch 4 | from sklearn import linear_model 5 | from sklearn.decomposition import PCA 6 | from sklearn.linear_model import LogisticRegression, RidgeClassifier, SGDClassifier 7 | 8 | 9 | def to_numpy(tensor): 10 | if isinstance(tensor, torch.Tensor): 11 | return tensor.detach().cpu().numpy() 12 | elif isinstance(tensor, np.ndarray): 13 | return tensor 14 | else: 15 | raise ValueError(f"Unknown type: {type(tensor)}") 16 | 17 | 18 | def feature_transform(Z: np.ndarray, u: np.ndarray, d_spu: int = 1, scale: float = 0) -> np.ndarray: 19 | scales = np.ones(Z.shape[1]) 20 | scales[:d_spu] = scale 21 | # print(Z.shape, u.shape, scales.shape) 22 | Z = Z @ u @ np.diag(scales) 23 | return Z 24 | 25 | 26 | def check_labels(labels) -> int: 27 | classes = np.unique(labels) 28 | n_classes = len(classes) 29 | assert np.all(classes == np.arange(n_classes)), f"Labels must be 0, 1, 2, ..., {n_classes - 1}" 30 | return n_classes 31 | 32 | 33 | def estimate_means(zs, ys, gs, n_envs, n_classes) -> dict: 34 | Zs, Ys = {}, {} 35 | for e in range(n_envs): 36 | Zs[e] = zs[gs == e] 37 | Ys[e] = ys[gs == e] 38 | Mus = {} 39 | for label in range(n_classes): 40 | means = {} 41 | for e in range(n_envs): 42 | means[e] = np.mean(Zs[e][Ys[e] == label], axis=0) 43 | Mus[label] = np.vstack(list(means.values())) 44 | return Mus 45 | 46 | 47 | def estimate_covs(zs, ys, gs, n_envs, n_classes) -> dict: 48 | Zs, Ys = {}, {} 49 | for e in range(n_envs): 50 | Zs[e] = zs[gs == e] 51 | Ys[e] = ys[gs == e] 52 | Covs = {label: {} for label in range(n_classes)} 53 | for label in range(n_classes): 54 | for e in range(n_envs): 55 | Covs[label][e] = np.cov(Zs[e][Ys[e] == label].T) 56 | return Covs 57 | 58 | 59 | def check_clf(clf, n_classes): 60 | if isinstance(clf, LogisticRegression) or isinstance(clf, RidgeClassifier) or isinstance(clf, SGDClassifier): 61 | if n_classes == 2: 62 | assert 1 <= clf.coef_.shape[0] <= 2, f"The output dim of a binary classifier must be 1 or 2" 63 | else: 64 | assert clf.coef_.shape[0] == n_classes, f"The output dimension of the classifier must be {n_classes}." 
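        # sklearn linear classifiers are returned as-is; torch.nn.Linear layers and
        # {'weight', 'bias'} dicts are converted into an equivalent LogisticRegression below.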
65 | return clf 66 | elif isinstance(clf, torch.nn.Linear): 67 | weight = clf.weight.detach().data.cpu().numpy() 68 | bias = clf.bias.detach().data.cpu().numpy() 69 | elif isinstance(clf, dict): 70 | weight, bias = to_numpy(clf['weight']), to_numpy(clf['bias']) 71 | else: 72 | raise ValueError(f"Unknown classifier type: {type(clf)}") 73 | 74 | assert weight.shape[0] == len( 75 | bias), f"The output dimension of weight should match bias: {weight.shape[0]} vs {len(bias)}" 76 | sklearn_clf = LogisticRegression() 77 | sklearn_clf.n_classes = n_classes 78 | sklearn_clf.classes_ = np.arange(n_classes) 79 | sklearn_clf.coef_ = weight 80 | sklearn_clf.intercept_ = bias 81 | assert sklearn_clf.coef_.shape[0] == n_classes, f"The output dimension of the classifier must be {n_classes}." 82 | 83 | return sklearn_clf 84 | 85 | 86 | class ISRClassifier: 87 | default_clf_kwargs = dict(C=1, max_iter=1000, random_state=0) 88 | 89 | def __init__(self, version: str = 'mean', pca_dim: int = -1, d_spu: int = -1, spu_scale: float = 0, 90 | chosen_class=None, clf_type: str = 'LogisticRegression', clf_kwargs: dict = None, 91 | ): 92 | self.version = version 93 | 94 | self.pca_dim = pca_dim 95 | self.d_spu = d_spu 96 | self.spu_scale = spu_scale 97 | 98 | self.clf_kwargs = ISRClassifier.default_clf_kwargs if clf_kwargs is None else clf_kwargs 99 | self.clf_type = clf_type 100 | self.chosen_class = chosen_class 101 | self.Us = {} # stores computed projection matrices 102 | assert self.clf_type in ['LogisticRegression', 'RidgeClassifier', 'SGDClassifier'], \ 103 | f"Unknown classifier type: {self.clf_type}" 104 | 105 | def set_params(self, **params): 106 | for name, val in params.items(): 107 | setattr(self, name, val) 108 | 109 | def fit(self, features, labels, envs, chosen_class: int = None, d_spu: int = None, given_clf=None, 110 | spu_scale: float = None): 111 | 112 | # estimate the stats (mean & cov) and fit a PCA if requested 113 | self.fit_data(features, labels, envs) 114 | 115 | if chosen_class is None: 116 | assert self.chosen_class is not None, "chosen_class must be specified if not given in the constructor" 117 | chosen_class = self.chosen_class 118 | 119 | if self.version == 'mean': 120 | self.fit_isr_mean(chosen_class=chosen_class, d_spu=d_spu) 121 | elif self.version == 'cov': 122 | self.fit_isr_cov(chosen_class=chosen_class, d_spu=d_spu) 123 | else: 124 | raise ValueError(f"Unknown ISR version: {self.version}") 125 | 126 | self.fit_clf(features, labels, given_clf=given_clf, spu_scale=spu_scale) 127 | return self 128 | 129 | def fit_data(self, features, labels, envs, n_classes=None, n_envs=None): 130 | # estimate the mean and covariance of each class per environment 131 | self.n_classes = check_labels(labels) 132 | self.n_envs = check_labels(envs) 133 | if n_classes is not None: assert self.n_classes == n_classes 134 | if n_envs is not None: assert self.n_envs == n_envs 135 | 136 | # fit a PCA if requested 137 | if self.pca_dim > 0: 138 | self.pca = PCA(n_components=self.pca_dim).fit(features) 139 | features = self.pca.transform(features) 140 | else: 141 | self.pca = None 142 | self.means = estimate_means(features, labels, envs, self.n_envs, self.n_classes) 143 | self.covs = estimate_covs(features, labels, envs, self.n_envs, self.n_classes) 144 | return features 145 | 146 | def fit_isr_mean(self, chosen_class: int, d_spu: int = None): 147 | d_spu = self.d_spu if d_spu is None else d_spu 148 | assert d_spu < self.n_envs 149 | assert 0 <= chosen_class < self.n_classes 150 | # We project features into a 
subspace, and d_spu is the dimension of the subspace
151 |         # We derive theoretically in the paper that the projection dimension of ISR-Mean
152 |         # is at most n_envs-1
153 |         if d_spu <= 0: self.d_spu = self.n_envs - 1
154 | 
155 |         key = ('mean', chosen_class, self.d_spu)
156 |         if key in self.Us:
157 |             return self.Us[key]
158 | 
159 |         # The per-environment means of the chosen class were estimated in fit_data
160 | 
161 |         # This PCA is just a helper function to obtain the projection matrix
162 |         helper_pca = PCA(n_components=self.d_spu).fit(self.means[chosen_class])
163 |         # The projection matrix has dimension (orig_dim, d_spu)
164 |         # The QR decomposition below just pads the projection matrix with columns (the dimensions orthogonal
165 |         # to the projection subspace) that make the matrix a full-rank square matrix.
166 |         U_proj = helper_pca.components_.T
167 | 
168 |         self.U = np.linalg.qr(U_proj, mode='complete')[0].real
169 |         # The first d_spu dimensions of U correspond to spurious features, which we will
170 |         # discard or reduce. The remaining dimensions span the invariant feature subspace that
171 |         # the algorithm identifies (not necessarily the true invariant features).
172 | 
173 |         # If we want to discard the spurious features, we can simply reduce the first d_spu
174 |         # dimensions of U to zeros. However, this may sometimes hurt the performance of the algorithm,
175 |         # so we can use the following strategy instead: rescale the first d_spu dimensions by a
176 |         # factor between 0 and 1. This rescale factor is spu_scale, which is chosen by the user.
177 |         # print('\neig(U):', np.real(np.linalg.eigvals(self.U)))
178 |         # print('singular vals:', s)
179 |         self.Us[key] = self.U
180 |         return self.U
181 | 
182 |     def fit_isr_cov(self, chosen_class: int, d_spu: int = None):
183 |         self.d_spu = d_spu if d_spu is not None else self.d_spu
184 |         assert self.d_spu > 0, "d_spu must be provided for ISR-Cov"
185 |         # TODO: implement ISR-Cov for n_envs > 2
186 |         assert self.n_envs == 2, "ISR-Cov is only implemented for binary env so far"
187 | 
188 |         key = ('cov', chosen_class, self.d_spu)
189 |         if key in self.Us:
190 |             return self.Us[key]
191 | 
192 |         env_pair = [0, 1]
193 |         cov_0 = self.covs[chosen_class][env_pair[0]]
194 |         cov_1 = self.covs[chosen_class][env_pair[1]]
195 |         cov_diff = cov_1 - cov_0
196 |         D = cov_diff.shape[0]
197 | 
198 |         # square cov_diff so that the resulting matrix has non-negative eigenvalues;
199 |         # the largest d_spu eigenvalues correspond to the spurious feature subspace,
200 |         # so we only need to compute the eigenvectors of these d_spu dimensions (saves computation cost)
201 | 
202 |         cov_sqr = cov_diff @ cov_diff
203 |         w, U_proj = scipy.linalg.eigh(cov_sqr, subset_by_index=[D - self.d_spu, D - 1])
204 |         assert w.min() >= 0
205 |         order = np.flip(np.argsort(w).flatten())
206 |         U_proj = U_proj[:, order]
207 | 
208 |         # call SVD to fill the remaining columns with the (D - d_spu)-dim subspace orthogonal to the
209 |         # spurious feature subspace
210 | 
211 |         self.U = np.linalg.svd(U_proj, full_matrices=True)[0]
212 | 
213 |         self.Us[key] = self.U
214 | 
215 |         return self.U
216 | 
217 |     def fit_clf(self, features=None, labels=None, given_clf=None, sample_weight=None):
218 |         if given_clf is None:
219 |             assert features is not None and labels is not None
220 |             self.clf = getattr(linear_model, self.clf_type)(**self.clf_kwargs)
221 |             features = self.transform(features, )
222 |             self.clf.fit(features, labels, sample_weight=sample_weight)
223 |         else:
224 |             self.clf = check_clf(given_clf, n_classes=self.n_classes)
225 |             self.clf.coef_ = self.clf.coef_ @ self.U
226 |         return self.clf
227 | 
228 |     def transform(self, features, ):
229 |         if self.pca is not None:
230 |             features = self.pca.transform(features)
231 |         new_zs = feature_transform(features, u=self.U,
232 |                                    d_spu=self.d_spu, scale=self.spu_scale)
233 |         return new_zs
234 | 
235 |     def predict(self, features):
236 |         zs = self.transform(features)
237 |         return self.clf.predict(zs)
238 | 
239 |     def score(self, features, labels):
240 |         zs = self.transform(features)
241 |         return self.clf.score(zs, labels)
242 | 
243 |     def fit_transform(self, features, labels, envs, chosen_class, given_clf=None):
244 |         self.fit(features, labels, envs, chosen_class=chosen_class, given_clf=given_clf)
245 |         return self.transform(features)
246 | 
--------------------------------------------------------------------------------
/real_datasets/launch_parse.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | import numpy as np
4 | 
5 | from configs import LOG_FOLDER, get_parse_command
6 | 
7 | dataset = 'CUB'
8 | gpu_idx = 0 # could be None if you want to use cpu
9 | 
10 | # Suppose we already trained the models for seeds 0, 1, 2, 3, 4,
11 | # then we can parse these trained models by choosing train_log_seeds = np.arange(5)
12 | train_log_seeds = np.arange(10)
13 | 
14 | # The training algorithms we want to parse
15 | algos = ['ERM', 'reweight', 'groupDRO']
16 | 
17 | # load a checkpoint with a model selection rule:
18 | # best: take the model at the epoch of largest worst-group validation accuracy
19 | # best_avg_acc: take the model at the epoch of largest average-group validation accuracy
20 | # last: take the trained model at the last epoch
21 | model_selects = ['best', ]
22 | 
23 | log_dir = LOG_FOLDER
24 | 
25 | command = get_parse_command(dataset=dataset, algos=algos, model_selects=model_selects,
26 |                             train_log_seeds=train_log_seeds, log_dir=log_dir, gpu_idx=gpu_idx,
27 |                             parse_script='parse_features.py')
28 | print('Command:', command)
29 | os.system(command)
30 | 
--------------------------------------------------------------------------------
/real_datasets/launch_train.py:
--------------------------------------------------------------------------------
1 | import os
2 | from itertools import product
3 | from tqdm import tqdm
4 | from configs import get_train_command
5 | 
6 | gpu_idx = 0 # could be None if you want to use cpu
7 | 
8 | algos = ['ERM','reweight','groupDRO']
9 | dataset = 'MultiNLI' # could be 'CUB' (i.e., Waterbirds), 'CelebA' or 'MultiNLI'
10 | 
11 | # can add some suffix to the algo name to flag the version,
12 | # e.g., with algo_suffix = "-my_version", the algo name becomes "ERM-my_version"
13 | algo_suffix = ""
14 | 
15 | seeds = range(10)
16 | for seed, algo in tqdm(list(product(seeds, algos)), desc='Experiments'):
17 |     command = get_train_command(dataset=dataset, algo=algo, gpu_idx=gpu_idx, seed=seed,
18 |                                 save_best=True, save_last=True)
19 |     print('Command:', command)
20 |     os.system(command)
21 | 
--------------------------------------------------------------------------------
/real_datasets/parse_features.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import pickle
4 | from itertools import product
5 | 
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | import torchvision
10 | from tqdm.auto import tqdm
11 | 
12 | from configs.model_config import model_attributes
13 | from data import dataset_attributes, shift_types, prepare_data
14 | from utils.train_utils import check_args
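# parse_features.py loads trained checkpoints, runs the frozen featurizer over the train/val/test
# splits, and pickles the features, labels, groups, and predictions, together with the final
# classifier head, so that ISR can be evaluated on them later (see eval.py).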
15 | from utils.train_utils import set_seed 16 | 17 | parser = argparse.ArgumentParser() 18 | 19 | # Settings 20 | parser.add_argument('-d', '--dataset', 21 | choices=dataset_attributes.keys(), required=True) 22 | parser.add_argument('-s', '--shift_type', 23 | choices=shift_types, default='confounder') 24 | # Confounders 25 | parser.add_argument('-t', '--target_name') 26 | parser.add_argument('-c', '--confounder_names', nargs='+') 27 | # Resume? 28 | parser.add_argument('--resume', default=False, action='store_true') 29 | # Label shifts 30 | parser.add_argument('--minority_fraction', type=float) 31 | parser.add_argument('--imbalance_ratio', type=float) 32 | # Data 33 | parser.add_argument('--fraction', type=float, default=1.0) 34 | parser.add_argument('--root_dir', default=None) 35 | parser.add_argument('--reweight_groups', action='store_true', default=False) 36 | parser.add_argument('--augment_data', action='store_true', default=False) 37 | parser.add_argument('--val_fraction', type=float, default=0.1) 38 | # Objective 39 | parser.add_argument('--robust', default=False, action='store_true') 40 | parser.add_argument('--alpha', type=float, default=0.2) 41 | parser.add_argument('--generalization_adjustment', default="0.0") 42 | parser.add_argument('--automatic_adjustment', 43 | default=False, action='store_true') 44 | parser.add_argument('--robust_step_size', default=0.01, type=float) 45 | parser.add_argument('--use_normalized_loss', 46 | default=False, action='store_true') 47 | parser.add_argument('--btl', default=False, action='store_true') 48 | parser.add_argument('--hinge', default=False, action='store_true') 49 | 50 | # Model 51 | parser.add_argument( 52 | '--model', 53 | choices=model_attributes.keys(), 54 | default='resnet50') 55 | parser.add_argument('--train_from_scratch', action='store_true', default=False) 56 | 57 | # Optimization 58 | parser.add_argument('--n_epochs', type=int, default=4) 59 | parser.add_argument('--batch_size', type=int, default=32) 60 | parser.add_argument('--lr', type=float, default=0.001) 61 | parser.add_argument('--scheduler', action='store_true', default=False) 62 | parser.add_argument('--weight_decay', type=float, default=5e-5) 63 | parser.add_argument('--gamma', type=float, default=0.1) 64 | parser.add_argument('--minimum_variational_weight', type=float, default=0) 65 | # Misc 66 | parser.add_argument('--seed', type=int, default=0) 67 | parser.add_argument('--show_progress', default=False, action='store_true') 68 | parser.add_argument('--log_dir', default='/data/common/inv-feature/logs/') 69 | parser.add_argument('--log_every', default=1e8, type=int) 70 | parser.add_argument('--save_step', type=int, default=1e8) 71 | parser.add_argument('--save_best', action='store_true', default=False) 72 | parser.add_argument('--save_last', action='store_true', default=False) 73 | 74 | parser.add_argument('--parse_algos', nargs='+', 75 | default=['ERM', 'groupDRO', 'reweight']) 76 | parser.add_argument('--parse_model_selects', nargs='+', 77 | default=['best', 'best_avg_acc', 'last'], 78 | help='best is based on worst-group validation accuracy.') 79 | parser.add_argument('--parse_seeds', nargs='+', 80 | default=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) 81 | parser.add_argument( 82 | '--parse_dir', default='/data/common/inv-feature/logs/', type=str) 83 | 84 | args = parser.parse_args() 85 | check_args(args) 86 | if args.model == 'bert': 87 | args.max_grad_norm = 1.0 88 | args.adam_epsilon = 1e-8 89 | args.warmup_steps = 0 90 | 91 | if args.robust: 92 | algo = 'groupDRO' 93 | elif 
args.reweight_groups: 94 | algo = 'reweight' 95 | else: 96 | algo = 'ERM' 97 | 98 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 99 | set_seed(args.seed) 100 | # Data 101 | # Test data for label_shift_step is not implemented yet 102 | test_data = None 103 | test_loader = None 104 | if args.shift_type == 'confounder': 105 | train_data, val_data, test_data = prepare_data(args, train=True) 106 | elif args.shift_type == 'label_shift_step': 107 | train_data, val_data = prepare_data(args, train=True) 108 | 109 | loader_kwargs = {'batch_size': args.batch_size, 110 | 'num_workers': 4, 'pin_memory': True} 111 | train_loader = train_data.get_loader( 112 | train=True, reweight_groups=args.reweight_groups, **loader_kwargs) 113 | val_loader = val_data.get_loader( 114 | train=False, reweight_groups=None, **loader_kwargs) 115 | if test_data is not None: 116 | test_loader = test_data.get_loader( 117 | train=False, reweight_groups=None, **loader_kwargs) 118 | 119 | data = {} 120 | data['train_loader'] = train_loader 121 | data['val_loader'] = val_loader 122 | data['test_loader'] = test_loader 123 | data['train_data'] = train_data 124 | data['val_data'] = val_data 125 | data['test_data'] = test_data 126 | n_classes = train_data.n_classes 127 | 128 | # Initialize model 129 | pretrained = not args.train_from_scratch 130 | 131 | if model_attributes[args.model]['feature_type'] in ('precomputed', 'raw_flattened'): 132 | assert pretrained 133 | # Load precomputed features 134 | d = train_data.input_size()[0] 135 | model = nn.Linear(d, n_classes) 136 | model.has_aux_logits = False 137 | elif args.model == 'resnet50': 138 | model = torchvision.models.resnet50(pretrained=pretrained) 139 | d = model.fc.in_features 140 | model.fc = nn.Linear(d, n_classes) 141 | elif args.model == 'resnet34': 142 | model = torchvision.models.resnet34(pretrained=pretrained) 143 | d = model.fc.in_features 144 | model.fc = nn.Linear(d, n_classes) 145 | elif args.model == 'wideresnet50': 146 | model = torchvision.models.wide_resnet50_2(pretrained=pretrained) 147 | d = model.fc.in_features 148 | model.fc = nn.Linear(d, n_classes) 149 | elif args.model == 'bert': 150 | assert args.dataset == 'MultiNLI' 151 | 152 | from transformers import BertConfig, BertForSequenceClassification 153 | 154 | config_class = BertConfig 155 | model_class = BertForSequenceClassification 156 | 157 | config = config_class.from_pretrained( 158 | 'bert-base-uncased', 159 | num_labels=3, 160 | finetuning_task='mnli') 161 | model = model_class.from_pretrained( 162 | 'bert-base-uncased', 163 | from_tf=False, 164 | config=config) 165 | else: 166 | raise ValueError('Model not recognized.') 167 | 168 | model = model.to(device) 169 | 170 | if not args.model.startswith('bert'): 171 | encoder = torch.nn.Sequential( 172 | *(list(model.children())[:-1] + [torch.nn.Flatten()])) 173 | output_layer = model.fc 174 | 175 | 176 | def process_batch(model, x, y=None, g=None, bert=True): 177 | if bert: 178 | input_ids = x[:, :, 0] 179 | input_masks = x[:, :, 1] 180 | segment_ids = x[:, :, 2] 181 | outputs = model.bert( 182 | input_ids=input_ids, 183 | attention_mask=input_masks, 184 | token_type_ids=segment_ids, 185 | ) 186 | pooled_output = outputs[1] 187 | logits = model.classifier(pooled_output) 188 | result = {'feature': pooled_output.detach().cpu().numpy(), 189 | 'pred': np.argmax(logits.detach().cpu().numpy(), axis=1), 190 | } 191 | else: 192 | features = encoder(x) 193 | logits = output_layer(features) 194 | result = {'feature': 
features.detach().cpu().numpy(),
195 |                   'pred': np.argmax(logits.detach().cpu().numpy(), axis=1),
196 |                   }
197 |     if y is not None:
198 |         result['label'] = y.detach().cpu().numpy()
199 |     if g is not None:
200 |         result['group'] = g.detach().cpu().numpy()
201 |     return result
202 | 
203 | 
204 | for algo, model_select, seed in tqdm(list(product(args.parse_algos, args.parse_model_selects, args.parse_seeds)),
205 |                                      desc='Iter'):
206 |     print('Current iter:', algo, model_select, seed)
207 |     save_dir = f'{args.parse_dir}/{args.dataset}/{algo}/s{seed}/'
208 |     if not os.path.exists(save_dir):
209 |         continue
210 |     model.load_state_dict(torch.load(save_dir + f'/{model_select}_model.pth',
211 |                                      map_location='cpu').state_dict())
212 | 
213 |     model.eval()
214 | 
215 |     # save the last linear layer (classifier head)
216 |     if 'bert' in type(model).__name__.lower():
217 |         weight = model.classifier.weight.detach().cpu().numpy()
218 |         bias = model.classifier.bias.detach().cpu().numpy()
219 |     elif 'resnet' in type(model).__name__.lower():
220 |         weight = model.fc.weight.detach().cpu().numpy()
221 |         bias = model.fc.bias.detach().cpu().numpy()
222 |     else:
223 |         raise ValueError(f'Unknown model type: {type(model)}')
224 |     pickle.dump({'weight': weight, 'bias': bias}, open(
225 |         save_dir + f'/{model_select}_clf.p', 'wb'))
226 | 
227 |     # save parsed features
228 |     for split, loader in zip(['train', 'val', 'test'], [train_loader, val_loader, test_loader]):
229 |         results = []
230 |         fname = model_select + '_' + f'{split}_data.p'
231 |         if os.path.exists(save_dir + '/' + fname):
232 |             continue
233 |         with torch.set_grad_enabled(False):
234 |             for batch_idx, batch in enumerate(tqdm(loader)):
235 |                 batch = tuple(t.to(device) for t in batch)
236 |                 x = batch[0]
237 |                 y = batch[1]
238 |                 g = batch[2]
239 |                 if args.model.startswith("bert"):
240 |                     result = process_batch(model, x, y, g, bert=True)
241 |                 else:
242 |                     result = process_batch(model, x, y, g, bert=False)
243 |                 results.append(result)
244 |         parsed_data = {}
245 |         for key in results[0].keys():
246 |             parsed_data[key] = np.concatenate(
247 |                 [result[key] for result in results])
248 | 
249 |         pickle.dump(parsed_data, open(save_dir + '/' + fname, 'wb'))
250 | 
251 |         del results
252 |         del parsed_data
253 | 
--------------------------------------------------------------------------------
/real_datasets/run_expt.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | 
4 | import pandas as pd
5 | import torch
6 | import torch.nn as nn
7 | import torchvision
8 | 
9 | import configs
10 | from configs.model_config import model_attributes
11 | from data import dataset_attributes, shift_types, prepare_data, log_data
12 | from train import train
13 | from utils.train_utils import set_seed, Logger, CSVBatchLogger, log_args
14 | 
15 | 
16 | def main():
17 |     parser = argparse.ArgumentParser()
18 | 
19 |     # Settings
20 |     parser.add_argument('-d', '--dataset', choices=dataset_attributes.keys(), required=True)
21 |     parser.add_argument('-s', '--shift_type', choices=shift_types, required=True)
22 |     # Confounders
23 |     parser.add_argument('-t', '--target_name')
24 |     parser.add_argument('-c', '--confounder_names', nargs='+')
25 |     # Resume?
26 | parser.add_argument('--resume', default=False, action='store_true') 27 | # Label shifts 28 | parser.add_argument('--minority_fraction', type=float) 29 | parser.add_argument('--imbalance_ratio', type=float) 30 | # Data 31 | parser.add_argument('--fraction', type=float, default=1.0) 32 | parser.add_argument('--root_dir', default=None) 33 | parser.add_argument('--reweight_groups', action='store_true', default=False) 34 | parser.add_argument('--augment_data', action='store_true', default=False) 35 | parser.add_argument('--val_fraction', type=float, default=0.1) 36 | # Objective 37 | parser.add_argument('--robust', default=False, action='store_true') 38 | parser.add_argument('--alpha', type=float, default=0.2) 39 | parser.add_argument('--generalization_adjustment', default="0.0") 40 | parser.add_argument('--automatic_adjustment', default=False, action='store_true') 41 | parser.add_argument('--robust_step_size', default=0.01, type=float) 42 | parser.add_argument('--use_normalized_loss', default=False, action='store_true') 43 | parser.add_argument('--btl', default=False, action='store_true') 44 | parser.add_argument('--hinge', default=False, action='store_true') 45 | 46 | # Model 47 | parser.add_argument( 48 | '--model', 49 | choices=model_attributes.keys(), 50 | default='resnet50') 51 | parser.add_argument('--train_from_scratch', action='store_true', default=False) 52 | 53 | # Optimization 54 | parser.add_argument('--n_epochs', type=int, default=4) 55 | parser.add_argument('--batch_size', type=int, default=32) 56 | parser.add_argument('--lr', type=float, default=0.001) 57 | parser.add_argument('--scheduler', action='store_true', default=False) 58 | parser.add_argument('--weight_decay', type=float, default=5e-5) 59 | parser.add_argument('--gamma', type=float, default=0.1) 60 | parser.add_argument('--minimum_variational_weight', type=float, default=0) 61 | # Misc 62 | parser.add_argument('--seed', type=int, default=0) 63 | parser.add_argument('--show_progress', default=False, action='store_true') 64 | parser.add_argument('--log_dir', default=configs.LOG_FOLDER) 65 | parser.add_argument('--log_every', default=1e8, type=int) 66 | parser.add_argument('--save_step', type=int, default=1e8) 67 | parser.add_argument('--save_best', action='store_true', default=False) 68 | parser.add_argument('--save_last', action='store_true', default=False) 69 | parser.add_argument('--algo_suffix', type=str, default='', help='The suffix of log folder name') 70 | args = parser.parse_args() 71 | check_args(args) 72 | 73 | # BERT-specific configs copied over from run_glue.py 74 | if args.model == 'bert': 75 | args.max_grad_norm = 1.0 76 | args.adam_epsilon = 1e-8 77 | args.warmup_steps = 0 78 | 79 | if args.robust: 80 | algo = 'groupDRO' 81 | elif args.reweight_groups: 82 | algo = 'reweight' 83 | else: 84 | algo = 'ERM' 85 | 86 | args.log_dir = os.path.join(args.log_dir, args.dataset, algo + args.algo_suffix, f's{args.seed}') 87 | 88 | if os.path.exists(args.log_dir) and args.resume: 89 | resume = True 90 | mode = 'a' 91 | else: 92 | resume = False 93 | mode = 'w' 94 | 95 | ## Initialize logs 96 | if not os.path.exists(args.log_dir): 97 | os.makedirs(args.log_dir) 98 | 99 | logger = Logger(os.path.join(args.log_dir, 'log.txt'), mode) 100 | # Record args 101 | log_args(args, logger) 102 | 103 | set_seed(args.seed) 104 | 105 | # Data 106 | # Test data for label_shift_step is not implemented yet 107 | test_data = None 108 | test_loader = None 109 | if args.shift_type == 'confounder': 110 | train_data, val_data, test_data 
= prepare_data(args, train=True) 111 | elif args.shift_type == 'label_shift_step': 112 | train_data, val_data = prepare_data(args, train=True) 113 | 114 | loader_kwargs = {'batch_size': args.batch_size, 'num_workers': 4, 'pin_memory': True} 115 | train_loader = train_data.get_loader(train=True, reweight_groups=args.reweight_groups, **loader_kwargs) 116 | val_loader = val_data.get_loader(train=False, reweight_groups=None, **loader_kwargs) 117 | if test_data is not None: 118 | test_loader = test_data.get_loader(train=False, reweight_groups=None, **loader_kwargs) 119 | 120 | data = {} 121 | data['train_loader'] = train_loader 122 | data['val_loader'] = val_loader 123 | data['test_loader'] = test_loader 124 | data['train_data'] = train_data 125 | data['val_data'] = val_data 126 | data['test_data'] = test_data 127 | n_classes = train_data.n_classes 128 | 129 | log_data(data, logger) 130 | 131 | ## Initialize model 132 | pretrained = not args.train_from_scratch 133 | if resume: 134 | model = torch.load(os.path.join(args.log_dir, 'last_model.pth')) 135 | d = train_data.input_size()[0] 136 | elif model_attributes[args.model]['feature_type'] in ('precomputed', 'raw_flattened'): 137 | assert pretrained 138 | # Load precomputed features 139 | d = train_data.input_size()[0] 140 | model = nn.Linear(d, n_classes) 141 | model.has_aux_logits = False 142 | elif args.model == 'resnet50': 143 | model = torchvision.models.resnet50(pretrained=pretrained) 144 | d = model.fc.in_features 145 | model.fc = nn.Linear(d, n_classes) 146 | elif args.model == 'resnet34': 147 | model = torchvision.models.resnet34(pretrained=pretrained) 148 | d = model.fc.in_features 149 | model.fc = nn.Linear(d, n_classes) 150 | elif args.model == 'wideresnet50': 151 | model = torchvision.models.wide_resnet50_2(pretrained=pretrained) 152 | d = model.fc.in_features 153 | model.fc = nn.Linear(d, n_classes) 154 | elif args.model == 'bert': 155 | assert args.dataset == 'MultiNLI' 156 | 157 | from transformers import BertConfig, BertForSequenceClassification 158 | config_class = BertConfig 159 | model_class = BertForSequenceClassification 160 | 161 | config = config_class.from_pretrained( 162 | 'bert-base-uncased', 163 | num_labels=3, 164 | finetuning_task='mnli') 165 | model = model_class.from_pretrained( 166 | 'bert-base-uncased', 167 | from_tf=False, 168 | config=config) 169 | else: 170 | raise ValueError('Model not recognized.') 171 | 172 | logger.flush() 173 | 174 | ## Define the objective 175 | if args.hinge: 176 | assert args.dataset in ['CelebA', 'CUB'] # Only supports binary 177 | 178 | def hinge_loss(yhat, y): 179 | # The torch loss takes in three arguments so we need to split yhat 180 | # It also expects classes in {+1.0, -1.0} whereas by default we give them in {0, 1} 181 | # Furthermore, if y = 1 it expects the first input to be higher instead of the second, 182 | # so we need to swap yhat[:, 0] and yhat[:, 1]... 
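            # In effect the per-sample loss is max(0, 1 - (logit_true - logit_other)),
            # i.e., a standard binary hinge loss on the logit margin.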
183 | torch_loss = torch.nn.MarginRankingLoss(margin=1.0, reduction='none') 184 | y = (y.float() * 2.0) - 1.0 185 | return torch_loss(yhat[:, 1], yhat[:, 0], y) 186 | 187 | criterion = hinge_loss 188 | else: 189 | criterion = torch.nn.CrossEntropyLoss(reduction='none') 190 | 191 | if resume: 192 | df = pd.read_csv(os.path.join(args.log_dir, 'test.csv')) 193 | epoch_offset = df.loc[len(df) - 1, 'epoch'] + 1 194 | logger.write(f'starting from epoch {epoch_offset}') 195 | else: 196 | epoch_offset = 0 197 | train_csv_logger = CSVBatchLogger(os.path.join(args.log_dir, 'train.csv'), train_data.n_groups, mode=mode) 198 | val_csv_logger = CSVBatchLogger(os.path.join(args.log_dir, 'val.csv'), train_data.n_groups, mode=mode) 199 | test_csv_logger = CSVBatchLogger(os.path.join(args.log_dir, 'test.csv'), train_data.n_groups, mode=mode) 200 | 201 | train(model, criterion, data, logger, train_csv_logger, val_csv_logger, test_csv_logger, args, 202 | epoch_offset=epoch_offset) 203 | 204 | train_csv_logger.close() 205 | val_csv_logger.close() 206 | test_csv_logger.close() 207 | 208 | 209 | def check_args(args): 210 | if args.shift_type == 'confounder': 211 | assert args.confounder_names 212 | assert args.target_name 213 | elif args.shift_type.startswith('label_shift'): 214 | assert args.minority_fraction 215 | assert args.imbalance_ratio 216 | 217 | 218 | if __name__ == '__main__': 219 | main() 220 | -------------------------------------------------------------------------------- /real_datasets/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | from torch.optim import AdamW 6 | from tqdm import tqdm 7 | from transformers import get_linear_schedule_with_warmup 8 | 9 | from utils.loss_utils import LossComputer 10 | 11 | 12 | def run_epoch(epoch, model, optimizer, loader, loss_computer, logger, csv_logger, args, 13 | is_training, show_progress=False, log_every=50, scheduler=None): 14 | """ 15 | scheduler is only used inside this function if model is bert. 
16 | """ 17 | 18 | if is_training: 19 | model.train() 20 | if args.model == 'bert': 21 | model.zero_grad() 22 | else: 23 | model.eval() 24 | 25 | if show_progress: 26 | prog_bar_loader = tqdm(loader) 27 | else: 28 | prog_bar_loader = loader 29 | 30 | with torch.set_grad_enabled(is_training): 31 | for batch_idx, batch in enumerate(prog_bar_loader): 32 | 33 | batch = tuple(t.cuda() for t in batch) 34 | x = batch[0] 35 | y = batch[1] 36 | g = batch[2] 37 | if args.model == 'bert': 38 | input_ids = x[:, :, 0] 39 | input_masks = x[:, :, 1] 40 | segment_ids = x[:, :, 2] 41 | outputs = model( 42 | input_ids=input_ids, 43 | attention_mask=input_masks, 44 | token_type_ids=segment_ids, 45 | labels=y 46 | )[1] # [1] returns logits 47 | else: 48 | outputs = model(x) 49 | 50 | loss_main = loss_computer.loss(outputs, y, g, is_training) 51 | 52 | if is_training: 53 | if args.model == 'bert': 54 | loss_main.backward() 55 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 56 | optimizer.step() 57 | scheduler.step() 58 | model.zero_grad() 59 | else: 60 | optimizer.zero_grad() 61 | loss_main.backward() 62 | optimizer.step() 63 | 64 | if is_training and (batch_idx + 1) % log_every == 0: 65 | csv_logger.log(epoch, batch_idx, loss_computer.get_stats(model, args)) 66 | csv_logger.flush() 67 | loss_computer.log_stats(logger, is_training) 68 | loss_computer.reset_stats() 69 | 70 | if (not is_training) or loss_computer.batch_count > 0: 71 | csv_logger.log(epoch, batch_idx, loss_computer.get_stats(model, args)) 72 | csv_logger.flush() 73 | loss_computer.log_stats(logger, is_training) 74 | if is_training: 75 | loss_computer.reset_stats() 76 | 77 | 78 | def train(model, criterion, dataset, 79 | logger, train_csv_logger, val_csv_logger, test_csv_logger, 80 | args, epoch_offset): 81 | model = model.cuda() 82 | 83 | # process generalization adjustment stuff 84 | adjustments = [float(c) for c in args.generalization_adjustment.split(',')] 85 | assert len(adjustments) in (1, dataset['train_data'].n_groups) 86 | if len(adjustments) == 1: 87 | adjustments = np.array(adjustments * dataset['train_data'].n_groups) 88 | else: 89 | adjustments = np.array(adjustments) 90 | 91 | train_loss_computer = LossComputer( 92 | criterion, 93 | is_robust=args.robust, 94 | dataset=dataset['train_data'], 95 | alpha=args.alpha, 96 | gamma=args.gamma, 97 | adj=adjustments, 98 | step_size=args.robust_step_size, 99 | normalize_loss=args.use_normalized_loss, 100 | btl=args.btl, 101 | min_var_weight=args.minimum_variational_weight) 102 | 103 | # BERT uses its own scheduler and optimizer 104 | if args.model == 'bert': 105 | no_decay = ['bias', 'LayerNorm.weight'] 106 | optimizer_grouped_parameters = [ 107 | {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 108 | 'weight_decay': args.weight_decay}, 109 | {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 110 | ] 111 | optimizer = AdamW( 112 | optimizer_grouped_parameters, 113 | lr=args.lr, 114 | eps=args.adam_epsilon) 115 | t_total = len(dataset['train_loader']) * args.n_epochs 116 | print(f'\nt_total is {t_total}\n') 117 | scheduler = get_linear_schedule_with_warmup( 118 | optimizer, 119 | num_warmup_steps=args.warmup_steps, 120 | num_training_steps=t_total) 121 | else: 122 | optimizer = torch.optim.SGD( 123 | filter(lambda p: p.requires_grad, model.parameters()), 124 | lr=args.lr, 125 | momentum=0.9, 126 | weight_decay=args.weight_decay) 127 | if args.scheduler: 128 | scheduler = 
torch.optim.lr_scheduler.ReduceLROnPlateau( 129 | optimizer, 130 | 'min', 131 | factor=0.1, 132 | patience=5, 133 | threshold=0.0001, 134 | min_lr=0, 135 | eps=1e-08) 136 | else: 137 | scheduler = None 138 | 139 | best_val_acc = 0 140 | best_avg_val_acc = 0 141 | for epoch in range(epoch_offset, epoch_offset + args.n_epochs): 142 | logger.write('\nEpoch [%d]:\n' % epoch) 143 | logger.write(f'Training:\n') 144 | run_epoch( 145 | epoch, model, optimizer, 146 | dataset['train_loader'], 147 | train_loss_computer, 148 | logger, train_csv_logger, args, 149 | is_training=True, 150 | show_progress=args.show_progress, 151 | log_every=args.log_every, 152 | scheduler=scheduler) 153 | 154 | logger.write(f'\nValidation:\n') 155 | val_loss_computer = LossComputer( 156 | criterion, 157 | is_robust=args.robust, 158 | dataset=dataset['val_data'], 159 | step_size=args.robust_step_size, 160 | alpha=args.alpha) 161 | run_epoch( 162 | epoch, model, optimizer, 163 | dataset['val_loader'], 164 | val_loss_computer, 165 | logger, val_csv_logger, args, 166 | is_training=False) 167 | 168 | # Test set; don't print to avoid peeking 169 | if dataset['test_data'] is not None: 170 | test_loss_computer = LossComputer( 171 | criterion, 172 | is_robust=args.robust, 173 | dataset=dataset['test_data'], 174 | step_size=args.robust_step_size, 175 | alpha=args.alpha) 176 | run_epoch( 177 | epoch, model, optimizer, 178 | dataset['test_loader'], 179 | test_loss_computer, 180 | None, test_csv_logger, args, 181 | is_training=False) 182 | 183 | # Inspect learning rates 184 | if (epoch + 1) % 1 == 0: 185 | for param_group in optimizer.param_groups: 186 | curr_lr = param_group['lr'] 187 | logger.write('Current lr: %f\n' % curr_lr) 188 | 189 | if args.scheduler and args.model != 'bert': 190 | if args.robust: 191 | val_loss, _ = val_loss_computer.compute_robust_loss_greedy(val_loss_computer.avg_group_loss, 192 | val_loss_computer.avg_group_loss) 193 | else: 194 | val_loss = val_loss_computer.avg_actual_loss 195 | scheduler.step(val_loss) # scheduler step to update lr at the end of epoch 196 | 197 | # if epoch % args.save_step == 0: 198 | # torch.save(model, os.path.join(args.log_dir, '%d_model.pth' % epoch)) 199 | 200 | if args.save_last: 201 | torch.save(model, os.path.join(args.log_dir, 'last_model.pth')) 202 | 203 | if args.save_best: 204 | 205 | curr_val_acc = min(val_loss_computer.avg_group_acc) 206 | curr_avg_val_acc = val_loss_computer.avg_acc 207 | 208 | logger.write(f'Current average validation accuracy: {curr_avg_val_acc}\n') 209 | logger.write(f'Current worst-group validation accuracy: {curr_val_acc}\n') 210 | if curr_val_acc > best_val_acc: 211 | best_val_acc = curr_val_acc 212 | torch.save(model, os.path.join(args.log_dir, 'best_model.pth')) 213 | logger.write(f'Best worst-group model saved at epoch {epoch}\n') 214 | if curr_avg_val_acc > best_avg_val_acc: 215 | best_avg_val_acc = curr_avg_val_acc 216 | torch.save(model, os.path.join(args.log_dir, 'best_avg_acc_model.pth')) 217 | logger.write(f'Best average-accuracy model saved at epoch {epoch}\n') 218 | 219 | if args.automatic_adjustment: 220 | gen_gap = val_loss_computer.avg_group_loss - train_loss_computer.exp_avg_loss 221 | adjustments = gen_gap * torch.sqrt(train_loss_computer.group_counts) 222 | train_loss_computer.adj = adjustments 223 | logger.write('Adjustments updated\n') 224 | for group_idx in range(train_loss_computer.n_groups): 225 | logger.write( 226 | f' {train_loss_computer.get_group_name(group_idx)}:\t' 227 | f'adj = 
{train_loss_computer.adj[group_idx]:.3f}\n') 228 | logger.write('\n') 229 | -------------------------------------------------------------------------------- /real_datasets/utils/eval_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | import numpy as np 5 | import pandas as pd 6 | 7 | 8 | def extract_data(data, transform=None, ): 9 | zs, ys, preds, gs = data['feature'], data['label'], data.get( 10 | 'pred', None), data['group'] 11 | if transform is not None: 12 | zs = transform(zs) 13 | 14 | return zs, ys, gs, preds 15 | 16 | 17 | def check_row_exist_in_df(row, df=None, df_path=None): 18 | # check if a row exists in a dataframe 19 | if df is None: 20 | if not os.path.exists(df_path): return False 21 | df = pd.read_csv(df_path) 22 | arrays = [] 23 | special_cols = [] 24 | for col, val in row.items(): 25 | if isinstance(val, list) or isinstance(val, tuple) or isinstance(val, np.ndarray): 26 | special_cols.append(col) 27 | continue 28 | arrays.append(df[col].values == val) 29 | exist_rows = np.prod(arrays, axis=0) 30 | n_exist = np.sum(exist_rows) 31 | if n_exist == 0 or len(special_cols) == 0: 32 | return n_exist > 0 33 | else: 34 | for col in special_cols: 35 | elements = np.array(row[col]) 36 | if np.isin(elements, df.loc[exist_rows][col].values).mean() < 1: 37 | return False 38 | else: 39 | continue 40 | return True 41 | 42 | 43 | def save_df(df, save_path, subset=None, verbose=False, drop_duplicates=True): 44 | if os.path.exists(save_path): 45 | orig_df = pd.read_csv(save_path) 46 | df = pd.concat([orig_df, df]) 47 | if drop_duplicates: 48 | df = df.drop_duplicates(subset=subset, 49 | keep='last', 50 | ignore_index=True) 51 | df.to_csv(save_path, index=False) 52 | if verbose: 53 | print("Saved to", save_path) 54 | 55 | 56 | def measure_group_accs(clf, zs, ys, gs, include_avg_acc=True): 57 | accs = {} 58 | if include_avg_acc: 59 | accs['avg_acc'] = clf.score(zs, ys) 60 | worst_group = None 61 | worst_acc = np.inf 62 | for g in np.unique(gs): 63 | g_idx = gs == g 64 | acc = clf.score(zs[g_idx], ys[g_idx]) 65 | accs[f'acc-{int(g)}'] = acc 66 | if acc < worst_acc: 67 | worst_group = g 68 | worst_acc = acc 69 | return accs, worst_acc, worst_group 70 | 71 | 72 | def group2env(groups, n_envs): 73 | # if the group is defined by id_class*n_envs+id_env, 74 | # this function can convert it to id_env 75 | return groups % n_envs 76 | 77 | 78 | def load_data(args): 79 | log_dir = os.path.join( 80 | args.root_dir, f'{args.dataset}/{args.algo}/s{args.seed}/') 81 | prefix = args.feature_file_prefix + args.model_select + '_' 82 | if not os.path.exists(log_dir + f'/{prefix}train_data.p'): 83 | raise ValueError(f"No parsed {prefix}train_data.p at {log_dir}") 84 | 85 | train_data = pickle.load(open(log_dir + f'/{prefix}train_data.p', 'rb')) 86 | val_data = pickle.load(open(log_dir + f'/{prefix}val_data.p', 'rb')) 87 | test_data = pickle.load(open(log_dir + f'/{prefix}test_data.p', 'rb')) 88 | return train_data, val_data, test_data, log_dir 89 | 90 | def update_args(args,specs:dict): 91 | for k,v in specs.items(): 92 | if hasattr(args,k): 93 | setattr(args,k,v) 94 | else: 95 | raise ValueError(f"No attribute {k} in args") 96 | return args 97 | -------------------------------------------------------------------------------- /real_datasets/utils/loss_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class LossComputer: 5 | def __init__(self, criterion, is_robust, 
dataset, alpha=None, gamma=0.1, adj=None, min_var_weight=0, step_size=0.01, 6 | normalize_loss=False, btl=False): 7 | self.criterion = criterion 8 | self.is_robust = is_robust 9 | self.gamma = gamma 10 | self.alpha = alpha 11 | self.min_var_weight = min_var_weight 12 | self.step_size = step_size 13 | self.normalize_loss = normalize_loss 14 | self.btl = btl 15 | 16 | self.n_groups = dataset.n_groups 17 | self.group_counts = dataset.group_counts().cuda() 18 | self.group_frac = self.group_counts / self.group_counts.sum() 19 | self.group_str = dataset.group_str 20 | 21 | if adj is not None: 22 | self.adj = torch.from_numpy(adj).float().cuda() 23 | else: 24 | self.adj = torch.zeros(self.n_groups).float().cuda() 25 | 26 | if is_robust: 27 | assert alpha, 'alpha must be specified' 28 | 29 | # quantities maintained throughout training 30 | self.adv_probs = torch.ones(self.n_groups).cuda() / self.n_groups 31 | self.exp_avg_loss = torch.zeros(self.n_groups).cuda() 32 | self.exp_avg_initialized = torch.zeros(self.n_groups).byte().cuda() 33 | 34 | self.reset_stats() 35 | 36 | def loss(self, yhat, y, group_idx=None, is_training=False): 37 | # compute per-sample and per-group losses 38 | per_sample_losses = self.criterion(yhat, y) 39 | group_loss, group_count = self.compute_group_avg(per_sample_losses, group_idx) 40 | group_acc, group_count = self.compute_group_avg((torch.argmax(yhat, 1) == y).float(), group_idx) 41 | 42 | # update historical losses 43 | self.update_exp_avg_loss(group_loss, group_count) 44 | 45 | # compute overall loss 46 | if self.is_robust and not self.btl: 47 | actual_loss, weights = self.compute_robust_loss(group_loss, group_count) 48 | elif self.is_robust and self.btl: 49 | actual_loss, weights = self.compute_robust_loss_btl(group_loss, group_count) 50 | else: 51 | actual_loss = per_sample_losses.mean() 52 | weights = None 53 | 54 | # update stats 55 | self.update_stats(actual_loss, group_loss, group_acc, group_count, weights) 56 | 57 | return actual_loss 58 | 59 | def compute_robust_loss(self, group_loss, group_count): 60 | adjusted_loss = group_loss 61 | if torch.all(self.adj > 0): 62 | adjusted_loss += self.adj / torch.sqrt(self.group_counts) 63 | if self.normalize_loss: 64 | adjusted_loss = adjusted_loss / (adjusted_loss.sum()) 65 | self.adv_probs = self.adv_probs * torch.exp(self.step_size * adjusted_loss.data) 66 | self.adv_probs = self.adv_probs / (self.adv_probs.sum()) 67 | 68 | robust_loss = group_loss @ self.adv_probs 69 | return robust_loss, self.adv_probs 70 | 71 | def compute_robust_loss_btl(self, group_loss, group_count): 72 | adjusted_loss = self.exp_avg_loss + self.adj / torch.sqrt(self.group_counts) 73 | return self.compute_robust_loss_greedy(group_loss, adjusted_loss) 74 | 75 | def compute_robust_loss_greedy(self, group_loss, ref_loss): 76 | sorted_idx = ref_loss.sort(descending=True)[1] 77 | sorted_loss = group_loss[sorted_idx] 78 | sorted_frac = self.group_frac[sorted_idx] 79 | 80 | mask = torch.cumsum(sorted_frac, dim=0) <= self.alpha 81 | weights = mask.float() * sorted_frac / self.alpha 82 | last_idx = mask.sum() 83 | weights[last_idx] = 1 - weights.sum() 84 | weights = sorted_frac * self.min_var_weight + weights * (1 - self.min_var_weight) 85 | 86 | robust_loss = sorted_loss @ weights 87 | 88 | # sort the weights back 89 | _, unsort_idx = sorted_idx.sort() 90 | unsorted_weights = weights[unsort_idx] 91 | return robust_loss, unsorted_weights 92 | 93 | def compute_group_avg(self, losses, group_idx): 94 | # compute observed counts and mean loss for each group 
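        # group_map is an (n_groups, batch_size) one-hot indicator matrix, so group_map @ losses sums the
        # per-sample losses within each group; dividing by the zero-safe count below gives the group mean.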
95 | group_map = (group_idx == torch.arange(self.n_groups).unsqueeze(1).long().cuda()).float() 96 | group_count = group_map.sum(1) 97 | group_denom = group_count + (group_count == 0).float() # avoid nans 98 | group_loss = (group_map @ losses.view(-1)) / group_denom 99 | return group_loss, group_count 100 | 101 | def update_exp_avg_loss(self, group_loss, group_count): 102 | prev_weights = (1 - self.gamma * (group_count > 0).float()) * (self.exp_avg_initialized > 0).float() 103 | curr_weights = 1 - prev_weights 104 | self.exp_avg_loss = self.exp_avg_loss * prev_weights + group_loss * curr_weights 105 | self.exp_avg_initialized = (self.exp_avg_initialized > 0) + (group_count > 0) 106 | 107 | def reset_stats(self): 108 | self.processed_data_counts = torch.zeros(self.n_groups).cuda() 109 | self.update_data_counts = torch.zeros(self.n_groups).cuda() 110 | self.update_batch_counts = torch.zeros(self.n_groups).cuda() 111 | self.avg_group_loss = torch.zeros(self.n_groups).cuda() 112 | self.avg_group_acc = torch.zeros(self.n_groups).cuda() 113 | self.avg_per_sample_loss = 0. 114 | self.avg_actual_loss = 0. 115 | self.avg_acc = 0. 116 | self.batch_count = 0. 117 | 118 | def update_stats(self, actual_loss, group_loss, group_acc, group_count, weights=None): 119 | # avg group loss 120 | denom = self.processed_data_counts + group_count 121 | denom += (denom == 0).float() 122 | prev_weight = self.processed_data_counts / denom 123 | curr_weight = group_count / denom 124 | self.avg_group_loss = prev_weight * self.avg_group_loss + curr_weight * group_loss 125 | 126 | # avg group acc 127 | self.avg_group_acc = prev_weight * self.avg_group_acc + curr_weight * group_acc 128 | 129 | # batch-wise average actual loss 130 | denom = self.batch_count + 1 131 | self.avg_actual_loss = (self.batch_count / denom) * self.avg_actual_loss + (1 / denom) * actual_loss 132 | 133 | # counts 134 | self.processed_data_counts += group_count 135 | if self.is_robust: 136 | self.update_data_counts += group_count * ((weights > 0).float()) 137 | self.update_batch_counts += ((group_count * weights) > 0).float() 138 | else: 139 | self.update_data_counts += group_count 140 | self.update_batch_counts += (group_count > 0).float() 141 | self.batch_count += 1 142 | 143 | # avg per-sample quantities 144 | group_frac = self.processed_data_counts / (self.processed_data_counts.sum()) 145 | self.avg_per_sample_loss = group_frac @ self.avg_group_loss 146 | self.avg_acc = group_frac @ self.avg_group_acc 147 | 148 | def get_model_stats(self, model, args, stats_dict): 149 | model_norm_sq = 0. 
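        # accumulate the squared L2 norm of all parameters; 'reg_loss' below reports the corresponding
        # weight-decay penalty, (weight_decay / 2) * ||theta||^2, for logging purposes.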
150 | for param in model.parameters(): 151 | model_norm_sq += torch.norm(param) ** 2 152 | stats_dict['model_norm_sq'] = model_norm_sq.item() 153 | stats_dict['reg_loss'] = args.weight_decay / 2 * model_norm_sq.item() 154 | return stats_dict 155 | 156 | def get_stats(self, model=None, args=None): 157 | stats_dict = {} 158 | for idx in range(self.n_groups): 159 | stats_dict[f'avg_loss_group:{idx}'] = self.avg_group_loss[idx].item() 160 | stats_dict[f'exp_avg_loss_group:{idx}'] = self.exp_avg_loss[idx].item() 161 | stats_dict[f'avg_acc_group:{idx}'] = self.avg_group_acc[idx].item() 162 | stats_dict[f'processed_data_count_group:{idx}'] = self.processed_data_counts[idx].item() 163 | stats_dict[f'update_data_count_group:{idx}'] = self.update_data_counts[idx].item() 164 | stats_dict[f'update_batch_count_group:{idx}'] = self.update_batch_counts[idx].item() 165 | 166 | stats_dict['avg_actual_loss'] = self.avg_actual_loss.item() 167 | stats_dict['avg_per_sample_loss'] = self.avg_per_sample_loss.item() 168 | stats_dict['avg_acc'] = self.avg_acc.item() 169 | 170 | # Model stats 171 | if model is not None: 172 | assert args is not None 173 | stats_dict = self.get_model_stats(model, args, stats_dict) 174 | 175 | return stats_dict 176 | 177 | def log_stats(self, logger, is_training): 178 | if logger is None: 179 | return 180 | 181 | logger.write(f'Average incurred loss: {self.avg_per_sample_loss.item():.3f} \n') 182 | logger.write(f'Average sample loss: {self.avg_actual_loss.item():.3f} \n') 183 | logger.write(f'Average acc: {self.avg_acc.item():.3f} \n') 184 | for group_idx in range(self.n_groups): 185 | logger.write( 186 | f' {self.group_str(group_idx)} ' 187 | f'[n = {int(self.processed_data_counts[group_idx])}]:\t' 188 | f'loss = {self.avg_group_loss[group_idx]:.3f} ' 189 | f'exp loss = {self.exp_avg_loss[group_idx]:.3f} ' 190 | f'adjusted loss = {self.exp_avg_loss[group_idx] + self.adj[group_idx] / torch.sqrt(self.group_counts)[group_idx]:.3f} ' 191 | f'adv prob = {self.adv_probs[group_idx]:3f} ' 192 | f'acc = {self.avg_group_acc[group_idx]:.3f}\n') 193 | logger.flush() 194 | -------------------------------------------------------------------------------- /real_datasets/utils/train_utils.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import os 3 | import sys 4 | 5 | import numpy as np 6 | import torch 7 | 8 | 9 | def check_args(args): 10 | if args.shift_type == 'confounder': 11 | assert args.confounder_names 12 | assert args.target_name 13 | elif args.shift_type.startswith('label_shift'): 14 | assert args.minority_fraction 15 | assert args.imbalance_ratio 16 | 17 | 18 | class Logger(object): 19 | def __init__(self, fpath=None, mode='w'): 20 | self.console = sys.stdout 21 | self.file = None 22 | if fpath is not None: 23 | self.file = open(fpath, mode) 24 | 25 | def __del__(self): 26 | self.close() 27 | 28 | def __enter__(self): 29 | pass 30 | 31 | def __exit__(self, *args): 32 | self.close() 33 | 34 | def write(self, msg): 35 | self.console.write(msg) 36 | if self.file is not None: 37 | self.file.write(msg) 38 | 39 | def flush(self): 40 | self.console.flush() 41 | if self.file is not None: 42 | self.file.flush() 43 | os.fsync(self.file.fileno()) 44 | 45 | def close(self): 46 | self.console.close() 47 | if self.file is not None: 48 | self.file.close() 49 | 50 | 51 | class CSVBatchLogger: 52 | def __init__(self, csv_path, n_groups, mode='w'): 53 | columns = ['epoch', 'batch'] 54 | for idx in range(n_groups): 55 | 
columns.append(f'avg_loss_group:{idx}') 56 | columns.append(f'exp_avg_loss_group:{idx}') 57 | columns.append(f'avg_acc_group:{idx}') 58 | columns.append(f'processed_data_count_group:{idx}') 59 | columns.append(f'update_data_count_group:{idx}') 60 | columns.append(f'update_batch_count_group:{idx}') 61 | columns.append('avg_actual_loss') 62 | columns.append('avg_per_sample_loss') 63 | columns.append('avg_acc') 64 | columns.append('model_norm_sq') 65 | columns.append('reg_loss') 66 | 67 | self.path = csv_path 68 | self.file = open(csv_path, mode) 69 | self.columns = columns 70 | self.writer = csv.DictWriter(self.file, fieldnames=columns) 71 | if mode == 'w': 72 | self.writer.writeheader() 73 | 74 | def log(self, epoch, batch, stats_dict): 75 | stats_dict['epoch'] = epoch 76 | stats_dict['batch'] = batch 77 | self.writer.writerow(stats_dict) 78 | 79 | def flush(self): 80 | self.file.flush() 81 | 82 | def close(self): 83 | self.file.close() 84 | 85 | 86 | class AverageMeter(object): 87 | """Computes and stores the average and current value""" 88 | 89 | def __init__(self): 90 | self.reset() 91 | 92 | def reset(self): 93 | self.val = 0 94 | self.avg = 0 95 | self.sum = 0 96 | self.count = 0 97 | 98 | def update(self, val, n=1): 99 | self.val = val 100 | self.sum += val * n 101 | self.count += n 102 | self.avg = self.sum / self.count 103 | 104 | 105 | def accuracy(output, target, topk=(1,)): 106 | """Computes the precision@k for the specified values of k""" 107 | maxk = max(topk) 108 | batch_size = target.size(0) 109 | 110 | _, pred = output.topk(maxk, 1, True, True) 111 | pred = pred.t() 112 | temp = target.view(1, -1).expand_as(pred) 113 | temp = temp.cuda() 114 | correct = pred.eq(temp) 115 | 116 | res = [] 117 | for k in topk: 118 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 119 | res.append(correct_k.mul_(100.0 / batch_size)) 120 | return res 121 | 122 | 123 | def set_seed(seed): 124 | """Sets seed""" 125 | if torch.cuda.is_available(): 126 | torch.cuda.manual_seed(seed) 127 | torch.manual_seed(seed) 128 | np.random.seed(seed) 129 | torch.backends.cudnn.benchmark = False 130 | torch.backends.cudnn.deterministic = True 131 | 132 | 133 | def log_args(args, logger): 134 | for argname, argval in vars(args).items(): 135 | logger.write(f'{argname.replace("_", " ").capitalize()}: {argval}\n') 136 | logger.write('\n') 137 | -------------------------------------------------------------------------------- /real_datasets/utils_glue.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | """ BERT classification fine-tuning: utilities to work with GLUE tasks """
17 | 
18 | from __future__ import absolute_import, division, print_function
19 | 
20 | import csv
21 | import logging
22 | import os
23 | import sys
24 | from io import open
25 | 
26 | from scipy.stats import pearsonr, spearmanr
27 | from sklearn.metrics import matthews_corrcoef, f1_score
28 | 
29 | logger = logging.getLogger(__name__)
30 | 
31 | 
32 | class InputExample(object):
33 |     """A single training/test example for simple sequence classification."""
34 | 
35 |     def __init__(self, guid, text_a, text_b=None, label=None):
36 |         """Constructs an InputExample.
37 | 
38 |         Args:
39 |             guid: Unique id for the example.
40 |             text_a: string. The untokenized text of the first sequence. For single
41 |             sequence tasks, only this sequence must be specified.
42 |             text_b: (Optional) string. The untokenized text of the second sequence.
43 |             Must only be specified for sequence pair tasks.
44 |             label: (Optional) string. The label of the example. This should be
45 |             specified for train and dev examples, but not for test examples.
46 |         """
47 |         self.guid = guid
48 |         self.text_a = text_a
49 |         self.text_b = text_b
50 |         self.label = label
51 | 
52 | 
53 | class InputFeatures(object):
54 |     """A single set of features of data."""
55 | 
56 |     def __init__(self, input_ids, input_mask, segment_ids, label_id):
57 |         self.input_ids = input_ids
58 |         self.input_mask = input_mask
59 |         self.segment_ids = segment_ids
60 |         self.label_id = label_id
61 | 
62 | 
63 | class DataProcessor(object):
64 |     """Base class for data converters for sequence classification data sets."""
65 | 
66 |     def get_train_examples(self, data_dir):
67 |         """Gets a collection of `InputExample`s for the train set."""
68 |         raise NotImplementedError()
69 | 
70 |     def get_dev_examples(self, data_dir):
71 |         """Gets a collection of `InputExample`s for the dev set."""
72 |         raise NotImplementedError()
73 | 
74 |     def get_labels(self):
75 |         """Gets the list of labels for this data set."""
76 |         raise NotImplementedError()
77 | 
78 |     @classmethod
79 |     def _read_tsv(cls, input_file, quotechar=None):
80 |         """Reads a tab separated value file."""
81 |         with open(input_file, "r", encoding="utf-8-sig") as f:
82 |             reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
83 |             lines = []
84 |             for line in reader:
85 |                 if sys.version_info[0] == 2:
86 |                     line = list(unicode(cell, 'utf-8') for cell in line)
87 |                 lines.append(line)
88 |             return lines
89 | 
90 | 
91 | class MrpcProcessor(DataProcessor):
92 |     """Processor for the MRPC data set (GLUE version)."""
93 | 
94 |     def get_train_examples(self, data_dir):
95 |         """See base class."""
96 |         logger.info("LOOKING AT {}".format(os.path.join(data_dir, "train.tsv")))
97 |         return self._create_examples(
98 |             self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
99 | 
100 |     def get_dev_examples(self, data_dir):
101 |         """See base class."""
102 |         return self._create_examples(
103 |             self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
104 | 
105 |     def get_labels(self):
106 |         """See base class."""
107 |         return ["0", "1"]
108 | 
109 |     def _create_examples(self, lines, set_type):
110 |         """Creates examples for the training and dev sets."""
111 |         examples = []
112 |         for (i, line) in enumerate(lines):
113 |             if i == 0:
114 |                 continue
115 |             guid = "%s-%s" % (set_type, i)
116 |             text_a = line[3]
117 |             text_b = line[4]
118 |             label = line[0]
119 |             examples.append(
120 |                 InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
121 |         return examples
122 | 
123 | 
124 | class MnliProcessor(DataProcessor):
125 |     """Processor for the MultiNLI data set (GLUE version)."""
126 | 
127 |     def get_train_examples(self, data_dir):
128 |         """See base class."""
129 |         return self._create_examples(
130 |             self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
131 | 
132 |     def get_dev_examples(self, data_dir):
133 |         """See base class."""
134 |         return self._create_examples(
135 |             self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")),
136 |             "dev_matched")
137 | 
138 |     def get_labels(self):
139 |         """See base class."""
140 |         return ["contradiction", "entailment", "neutral"]
141 | 
142 |     def _create_examples(self, lines, set_type):
143 |         """Creates examples for the training and dev sets."""
144 |         examples = []
145 |         for (i, line) in enumerate(lines):
146 |             if i == 0:
147 |                 continue
148 |             guid = "%s-%s" % (set_type, line[0])
149 |             text_a = line[8]
150 |             text_b = line[9]
151 |             label = line[-1]
152 |             examples.append(
153 |                 InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
154 |         return examples
155 | 
156 | 
157 | class MnliMismatchedProcessor(MnliProcessor):
158 |     """Processor for the MultiNLI Mismatched data set (GLUE version)."""
159 | 
160 |     def get_dev_examples(self, data_dir):
161 |         """See base class."""
162 |         return self._create_examples(
163 |             self._read_tsv(os.path.join(data_dir, "dev_mismatched.tsv")),
164 |             "dev_matched")
165 | 
166 | 
167 | class ColaProcessor(DataProcessor):
168 |     """Processor for the CoLA data set (GLUE version)."""
169 | 
170 |     def get_train_examples(self, data_dir):
171 |         """See base class."""
172 |         return self._create_examples(
173 |             self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
174 | 
175 |     def get_dev_examples(self, data_dir):
176 |         """See base class."""
177 |         return self._create_examples(
178 |             self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
179 | 
180 |     def get_labels(self):
181 |         """See base class."""
182 |         return ["0", "1"]
183 | 
184 |     def _create_examples(self, lines, set_type):
185 |         """Creates examples for the training and dev sets."""
186 |         examples = []
187 |         for (i, line) in enumerate(lines):
188 |             guid = "%s-%s" % (set_type, i)
189 |             text_a = line[3]
190 |             label = line[1]
191 |             examples.append(
192 |                 InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
193 |         return examples
194 | 
195 | 
196 | class Sst2Processor(DataProcessor):
197 |     """Processor for the SST-2 data set (GLUE version)."""
198 | 
199 |     def get_train_examples(self, data_dir):
200 |         """See base class."""
201 |         return self._create_examples(
202 |             self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
203 | 
204 |     def get_dev_examples(self, data_dir):
205 |         """See base class."""
206 |         return self._create_examples(
207 |             self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
208 | 
209 |     def get_labels(self):
210 |         """See base class."""
211 |         return ["0", "1"]
212 | 
213 |     def _create_examples(self, lines, set_type):
214 |         """Creates examples for the training and dev sets."""
215 |         examples = []
216 |         for (i, line) in enumerate(lines):
217 |             if i == 0:
218 |                 continue
219 |             guid = "%s-%s" % (set_type, i)
220 |             text_a = line[0]
221 |             label = line[1]
222 |             examples.append(
223 |                 InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
224 |         return examples
225 | 
226 | 
227 | class StsbProcessor(DataProcessor):
228 |     """Processor for the STS-B data set (GLUE version)."""
229 | 
230 |     def get_train_examples(self, data_dir):
231 |         """See base class."""
232 |         return self._create_examples(
233 |             self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
234 | 
235 |     def get_dev_examples(self, data_dir):
236 |         """See base class."""
237 |         return self._create_examples(
238 |             self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
239 | 
240 |     def get_labels(self):
241 |         """See base class."""
242 |         return [None]
243 | 
244 |     def _create_examples(self, lines, set_type):
245 |         """Creates examples for the training and dev sets."""
246 |         examples = []
247 |         for (i, line) in enumerate(lines):
248 |             if i == 0:
249 |                 continue
250 |             guid = "%s-%s" % (set_type, line[0])
251 |             text_a = line[7]
252 |             text_b = line[8]
253 |             label = line[-1]
254 |             examples.append(
255 |                 InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
256 |         return examples
257 | 
258 | 
259 | class QqpProcessor(DataProcessor):
260 |     """Processor for the QQP data set (GLUE version)."""
261 | 
262 |     def get_train_examples(self, data_dir):
263 |         """See base class."""
264 |         return self._create_examples(
265 |             self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
266 | 
267 |     def get_dev_examples(self, data_dir):
268 |         """See base class."""
269 |         return self._create_examples(
270 |             self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
271 | 
272 |     def get_labels(self):
273 |         """See base class."""
274 |         return ["0", "1"]
275 | 
276 |     def _create_examples(self, lines, set_type):
277 |         """Creates examples for the training and dev sets."""
278 |         examples = []
279 |         for (i, line) in enumerate(lines):
280 |             if i == 0:
281 |                 continue
282 |             guid = "%s-%s" % (set_type, line[0])
283 |             try:
284 |                 text_a = line[3]
285 |                 text_b = line[4]
286 |                 label = line[5]
287 |             except IndexError:
288 |                 continue
289 |             examples.append(
290 |                 InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
291 |         return examples
292 | 
293 | 
294 | class QnliProcessor(DataProcessor):
295 |     """Processor for the QNLI data set (GLUE version)."""
296 | 
297 |     def get_train_examples(self, data_dir):
298 |         """See base class."""
299 |         return self._create_examples(
300 |             self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
301 | 
302 |     def get_dev_examples(self, data_dir):
303 |         """See base class."""
304 |         return self._create_examples(
305 |             self._read_tsv(os.path.join(data_dir, "dev.tsv")),
306 |             "dev_matched")
307 | 
308 |     def get_labels(self):
309 |         """See base class."""
310 |         return ["entailment", "not_entailment"]
311 | 
312 |     def _create_examples(self, lines, set_type):
313 |         """Creates examples for the training and dev sets."""
314 |         examples = []
315 |         for (i, line) in enumerate(lines):
316 |             if i == 0:
317 |                 continue
318 |             guid = "%s-%s" % (set_type, line[0])
319 |             text_a = line[1]
320 |             text_b = line[2]
321 |             label = line[-1]
322 |             examples.append(
323 |                 InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
324 |         return examples
325 | 
326 | 
327 | class RteProcessor(DataProcessor):
328 |     """Processor for the RTE data set (GLUE version)."""
329 | 
330 |     def get_train_examples(self, data_dir):
331 |         """See base class."""
332 |         return self._create_examples(
333 |             self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
334 | 
335 |     def get_dev_examples(self, data_dir):
336 |         """See base class."""
337 |         return self._create_examples(
338 |             self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
339 | 
340 |     def get_labels(self):
341 |         """See base class."""
342 |         return ["entailment", "not_entailment"]
343 | 
344 |     def _create_examples(self, lines, set_type):
345 |         """Creates examples for the training and dev sets."""
346 |         examples = []
347 |         for (i, line) in enumerate(lines):
348 |             if i == 0:
349 |                 continue
350 |             guid = "%s-%s" % (set_type, line[0])
351 |             text_a = line[1]
352 |             text_b = line[2]
353 |             label = line[-1]
354 |             examples.append(
355 |                 InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
356 |         return examples
357 | 
358 | 
359 | class WnliProcessor(DataProcessor):
360 |     """Processor for the WNLI data set (GLUE version)."""
361 | 
362 |     def get_train_examples(self, data_dir):
363 |         """See base class."""
364 |         return self._create_examples(
365 |             self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
366 | 
367 |     def get_dev_examples(self, data_dir):
368 |         """See base class."""
369 |         return self._create_examples(
370 |             self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
371 | 
372 |     def get_labels(self):
373 |         """See base class."""
374 |         return ["0", "1"]
375 | 
376 |     def _create_examples(self, lines, set_type):
377 |         """Creates examples for the training and dev sets."""
378 |         examples = []
379 |         for (i, line) in enumerate(lines):
380 |             if i == 0:
381 |                 continue
382 |             guid = "%s-%s" % (set_type, line[0])
383 |             text_a = line[1]
384 |             text_b = line[2]
385 |             label = line[-1]
386 |             examples.append(
387 |                 InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
388 |         return examples
389 | 
390 | 
391 | def convert_examples_to_features(examples, label_list, max_seq_length,
392 |                                  tokenizer, output_mode,
393 |                                  cls_token_at_end=False,
394 |                                  cls_token='[CLS]',
395 |                                  cls_token_segment_id=1,
396 |                                  sep_token='[SEP]',
397 |                                  sep_token_extra=False,
398 |                                  pad_on_left=False,
399 |                                  pad_token=0,
400 |                                  pad_token_segment_id=0,
401 |                                  sequence_a_segment_id=0,
402 |                                  sequence_b_segment_id=1,
403 |                                  mask_padding_with_zero=True):
404 |     """ Loads a data file into a list of `InputBatch`s
405 |         `cls_token_at_end` defines the location of the CLS token:
406 |             - False (Default, BERT/XLM pattern): [CLS] + A + [SEP] + B + [SEP]
407 |             - True (XLNet/GPT pattern): A + [SEP] + B + [SEP] + [CLS]
408 |         `cls_token_segment_id` defines the segment id associated to the CLS token (0 for BERT, 2 for XLNet)
409 |     """
410 | 
411 |     label_map = {label : i for i, label in enumerate(label_list)}
412 | 
413 |     features = []
414 |     for (ex_index, example) in enumerate(examples):
415 |         if ex_index % 10000 == 0:
416 |             logger.info("Writing example %d of %d" % (ex_index, len(examples)))
417 | 
418 |         tokens_a = tokenizer.tokenize(example.text_a)
419 | 
420 |         tokens_b = None
421 |         if example.text_b:
422 |             tokens_b = tokenizer.tokenize(example.text_b)
423 |             # Modifies `tokens_a` and `tokens_b` in place so that the total
424 |             # length is less than the specified length.
425 |             # Account for [CLS], [SEP], [SEP] with "- 3". "- 4" for RoBERTa.
426 |             special_tokens_count = 4 if sep_token_extra else 3
427 |             _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - special_tokens_count)
428 |         else:
429 |             # Account for [CLS] and [SEP] with "- 2" and with "- 3" for RoBERTa.
430 |             special_tokens_count = 3 if sep_token_extra else 2
431 |             if len(tokens_a) > max_seq_length - special_tokens_count:
432 |                 tokens_a = tokens_a[:(max_seq_length - special_tokens_count)]
433 | 
434 |         # The convention in BERT is:
435 |         # (a) For sequence pairs:
436 |         #  tokens:   [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
437 |         #  type_ids:   0   0  0    0    0     0       0   0   1  1  1  1   1   1
438 |         # (b) For single sequences:
439 |         #  tokens:   [CLS] the dog is hairy . [SEP]
440 |         #  type_ids:   0   0   0   0  0     0   0
441 |         #
442 |         # Where "type_ids" are used to indicate whether this is the first
443 |         # sequence or the second sequence. The embedding vectors for `type=0` and
444 |         # `type=1` were learned during pre-training and are added to the wordpiece
445 |         # embedding vector (and position vector). This is not *strictly* necessary
446 |         # since the [SEP] token unambiguously separates the sequences, but it makes
447 |         # it easier for the model to learn the concept of sequences.
448 |         #
449 |         # For classification tasks, the first vector (corresponding to [CLS]) is
450 |         # used as the "sentence vector". Note that this only makes sense because
451 |         # the entire model is fine-tuned.
452 |         tokens = tokens_a + [sep_token]
453 |         if sep_token_extra:
454 |             # roberta uses an extra separator b/w pairs of sentences
455 |             tokens += [sep_token]
456 |         segment_ids = [sequence_a_segment_id] * len(tokens)
457 | 
458 |         if tokens_b:
459 |             tokens += tokens_b + [sep_token]
460 |             segment_ids += [sequence_b_segment_id] * (len(tokens_b) + 1)
461 | 
462 |         if cls_token_at_end:
463 |             tokens = tokens + [cls_token]
464 |             segment_ids = segment_ids + [cls_token_segment_id]
465 |         else:
466 |             tokens = [cls_token] + tokens
467 |             segment_ids = [cls_token_segment_id] + segment_ids
468 | 
469 |         input_ids = tokenizer.convert_tokens_to_ids(tokens)
470 | 
471 |         # The mask has 1 for real tokens and 0 for padding tokens. Only real
472 |         # tokens are attended to.
473 |         input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
474 | 
475 |         # Zero-pad up to the sequence length.
476 |         padding_length = max_seq_length - len(input_ids)
477 |         if pad_on_left:
478 |             input_ids = ([pad_token] * padding_length) + input_ids
479 |             input_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + input_mask
480 |             segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids
481 |         else:
482 |             input_ids = input_ids + ([pad_token] * padding_length)
483 |             input_mask = input_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
484 |             segment_ids = segment_ids + ([pad_token_segment_id] * padding_length)
485 | 
486 |         assert len(input_ids) == max_seq_length
487 |         assert len(input_mask) == max_seq_length
488 |         assert len(segment_ids) == max_seq_length
489 | 
490 |         if output_mode == "classification":
491 |             label_id = label_map[example.label]
492 |         elif output_mode == "regression":
493 |             label_id = float(example.label)
494 |         else:
495 |             raise KeyError(output_mode)
496 | 
497 |         if ex_index < 5:
498 |             logger.info("*** Example ***")
499 |             logger.info("guid: %s" % (example.guid))
500 |             logger.info("tokens: %s" % " ".join(
501 |                 [str(x) for x in tokens]))
502 |             logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
503 |             logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
504 |             logger.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
505 |             logger.info("label: %s (id = %d)" % (example.label, label_id))
506 | 
507 |         features.append(
508 |             InputFeatures(input_ids=input_ids,
509 |                           input_mask=input_mask,
510 |                           segment_ids=segment_ids,
511 |                           label_id=label_id))
512 |     return features
513 | 
514 | 
515 | def _truncate_seq_pair(tokens_a, tokens_b, max_length):
516 |     """Truncates a sequence pair in place to the maximum length."""
517 | 
518 |     # This is a simple heuristic which will always truncate the longer sequence
519 |     # one token at a time. This makes more sense than truncating an equal percent
520 |     # of tokens from each, since if one sequence is very short then each token
521 |     # that's truncated likely contains more information than a longer sequence.
522 |     while True:
523 |         total_length = len(tokens_a) + len(tokens_b)
524 |         if total_length <= max_length:
525 |             break
526 |         if len(tokens_a) > len(tokens_b):
527 |             tokens_a.pop()
528 |         else:
529 |             tokens_b.pop()
530 | 
531 | 
532 | def simple_accuracy(preds, labels):
533 |     return (preds == labels).mean()
534 | 
535 | 
536 | def acc_and_f1(preds, labels):
537 |     acc = simple_accuracy(preds, labels)
538 |     f1 = f1_score(y_true=labels, y_pred=preds)
539 |     return {
540 |         "acc": acc,
541 |         "f1": f1,
542 |         "acc_and_f1": (acc + f1) / 2,
543 |     }
544 | 
545 | 
546 | def pearson_and_spearman(preds, labels):
547 |     pearson_corr = pearsonr(preds, labels)[0]
548 |     spearman_corr = spearmanr(preds, labels)[0]
549 |     return {
550 |         "pearson": pearson_corr,
551 |         "spearmanr": spearman_corr,
552 |         "corr": (pearson_corr + spearman_corr) / 2,
553 |     }
554 | 
555 | 
556 | def compute_metrics(task_name, preds, labels):
557 |     assert len(preds) == len(labels)
558 |     if task_name == "cola":
559 |         return {"mcc": matthews_corrcoef(labels, preds)}
560 |     elif task_name == "sst-2":
561 |         return {"acc": simple_accuracy(preds, labels)}
562 |     elif task_name == "mrpc":
563 |         return acc_and_f1(preds, labels)
564 |     elif task_name == "sts-b":
565 |         return pearson_and_spearman(preds, labels)
566 |     elif task_name == "qqp":
567 |         return acc_and_f1(preds, labels)
568 |     elif task_name == "mnli":
569 |         return {"acc": simple_accuracy(preds, labels)}
570 |     elif task_name == "mnli-mm":
571 |         return {"acc": simple_accuracy(preds, labels)}
572 |     elif task_name == "qnli":
573 |         return {"acc": simple_accuracy(preds, labels)}
574 |     elif task_name == "rte":
575 |         return {"acc": simple_accuracy(preds, labels)}
576 |     elif task_name == "wnli":
577 |         return {"acc": simple_accuracy(preds, labels)}
578 |     else:
579 |         raise KeyError(task_name)
580 | 
581 | processors = {
582 |     "cola": ColaProcessor,
583 |     "mnli": MnliProcessor,
584 |     "mnli-mm": MnliMismatchedProcessor,
585 |     "mrpc": MrpcProcessor,
586 |     "sst-2": Sst2Processor,
587 |     "sts-b": StsbProcessor,
588 |     "qqp": QqpProcessor,
589 |     "qnli": QnliProcessor,
590 |     "rte": RteProcessor,
591 |     "wnli": WnliProcessor,
592 | }
593 | 
594 | output_modes = {
595 |     "cola": "classification",
596 |     "mnli": "classification",
597 |     "mnli-mm": "classification",
598 |     "mrpc": "classification",
599 |     "sst-2": "classification",
600 |     "sts-b": "regression",
601 |     "qqp": "classification",
602 |     "qnli": "classification",
603 |     "rte": "classification",
604 |     "wnli": "classification",
605 | }
606 | 
607 | GLUE_TASKS_NUM_LABELS = {
608 |     "cola": 2,
609 |     "mnli": 3,
610 |     "mrpc": 2,
611 |     "sst-2": 2,
612 |     "sts-b": 1,
613 |     "qqp": 2,
614 |     "qnli": 2,
615 |     "rte": 2,
616 |     "wnli": 2,
617 | }
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | numpy
3 | pandas
4 | torchvision
5 | pillow
6 | matplotlib
7 | seaborn
8 | scikit-learn
9 | tqdm
10 | scipy
11 | transformers
12 | joblib
--------------------------------------------------------------------------------
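`utils_glue.py` above follows the standard `pytorch-transformers` GLUE recipe: a task-specific processor from `processors` turns TSV rows into `InputExample`s, `convert_examples_to_features` tokenizes and pads them into `InputFeatures`, and `compute_metrics` scores predictions with the task's metric. The snippet below is a minimal usage sketch, not code from this repository: the data directory is a placeholder, `bert-base-uncased` is just an example checkpoint, and the BERT-specific keyword arguments follow the function's own docstring (segment id 0 for BERT's `[CLS]`).

```python
# Minimal sketch of how the GLUE utilities fit together (illustrative only).
# Assumes a directory containing MultiNLI-style TSVs (train.tsv, dev_matched.tsv).
import numpy as np
from transformers import BertTokenizer

from utils_glue import (compute_metrics, convert_examples_to_features,
                        output_modes, processors)

task = "mnli"
data_dir = "path/to/multinli_tsvs"       # placeholder path

processor = processors[task]()           # -> MnliProcessor
label_list = processor.get_labels()      # ["contradiction", "entailment", "neutral"]
examples = processor.get_dev_examples(data_dir)

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
features = convert_examples_to_features(
    examples, label_list, max_seq_length=128,
    tokenizer=tokenizer, output_mode=output_modes[task],
    cls_token=tokenizer.cls_token, sep_token=tokenizer.sep_token,
    cls_token_segment_id=0,              # 0 for BERT (2 would be used for XLNet)
    pad_token_segment_id=0)

# input_ids / input_mask / segment_ids would normally be fed to a BERT model;
# here we only illustrate the metric helper with placeholder predictions.
labels = np.array([f.label_id for f in features])
preds = labels.copy()
print(compute_metrics(task, preds, labels))  # {'acc': 1.0}
```

Within this repo, these helpers are presumably consumed by the MultiNLI-related scripts under `real_datasets/` (e.g. `parse_features.py`); the sketch only illustrates the call pattern.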
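One detail worth highlighting is the truncation policy in `_truncate_seq_pair`: tokens are popped one at a time from whichever sequence is currently longer, so a short second sentence is never whittled away just to make room for a long first one. A tiny illustration (the token lists are made up):

```python
from utils_glue import _truncate_seq_pair

tokens_a = ["a"] * 10   # hypothetical long first sentence
tokens_b = ["b"] * 4    # shorter second sentence
_truncate_seq_pair(tokens_a, tokens_b, max_length=9)

# Tokens were removed only from the longer list, so the short sentence survives intact:
print(len(tokens_a), len(tokens_b))  # 5 4
```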