├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md └── domainbed ├── __init__.py ├── algorithms.py ├── algorithms_inference.py ├── command_launchers.py ├── datasets.py ├── hparams_registry.py ├── lib ├── fast_data_loader.py ├── misc.py ├── query.py ├── reporting.py └── wide_resnet.py ├── model_selection.py ├── networks.py └── scripts ├── __init__.py ├── collect_results.py ├── download.py ├── inference.py ├── list_top_hparams.py ├── save_images.py ├── sweep.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | *__pycache__* 2 | slurmconfig/* 3 | *pyc 4 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | Facebook has adopted a Code of Conduct that we expect project participants to adhere to. 4 | Please read the [full text](https://code.fb.com/codeofconduct/) 5 | so that you can understand what actions will and will not be tolerated. 6 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to `ModelRatatouille` 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `master`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. 
You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## License 30 | By contributing to `ModelRatatouille`, you agree that your contributions 31 | will be licensed under the LICENSE file in the root directory of this source 32 | tree. 33 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License 2 | 3 | Copyright (c) Facebook, Inc. and its affiliates. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 10 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Model Ratatouille: Recycling Diverse Models for Out-of-Distribution Generalization 2 | 3 | Official PyTorch implementation of model ratatouille | [paper](https://arxiv.org/abs/2212.10445) 4 | 5 | [Alexandre Ramé](https://alexrame.github.io/), [Kartik Ahuja](https://ahujak.github.io/), [Jianyu Zhang](https://scholar.google.com/citations?user=srn7ay8AAAAJ&hl=en), [Matthieu Cord](http://webia.lip6.fr/~cord/), [Léon Bottou](https://leon.bottou.org/), [David Lopez-Paz](http://lopezpaz.org/) 6 | 7 | ## TL;DR 8 | 9 | We propose a new fine-tuning strategy that improves OOD generalization in computer vision by recycling and averaging weights specialized on diverse auxiliary tasks. 10 | 11 | ## Abstract 12 | 13 | Foundation models are redefining how AI systems are built. Practitioners now follow a standard procedure to build their machine learning solutions: from a pre-trained foundation model, they fine-tune the weights on the target task of interest. Then, the Internet is swarmed by a handful of foundation models fine-tuned on many diverse tasks: these individual fine-tunings exist in isolation without benefiting from each other. In our opinion, this is a missed opportunity, as these specialized models contain rich and diverse features. In this paper, we thus propose model ratatouille, a new strategy to recycle the multiple fine-tunings of the same foundation model on diverse auxiliary tasks. 
Specifically, we repurpose these auxiliary weights as initializations for multiple parallel fine-tunings on the target task; then, we average all fine-tuned weights to obtain the final model. This recycling strategy aims at maximizing the diversity in weights by leveraging the diversity in auxiliary tasks. Empirically, it improves the state of the art on the reference DomainBed benchmark for out-of-distribution generalization. Looking forward, this work contributes to the emerging paradigm of updatable machine learning where, akin to open-source software development, the community collaborates to reliably update machine learning models. 14 | 15 | # Setup 16 | 17 | ## Codebase and DomainBed 18 | 19 | Our code is adapted from the open-source [DomainBed GitHub repository](https://github.com/facebookresearch/DomainBed/), a PyTorch benchmark of datasets and algorithms for evaluating OOD generalization, introduced in [In Search of Lost Domain Generalization, ICLR 2021](https://openreview.net/forum?id=lQdXeXDoWtI). More specifically, our code extends the [DiWA GitHub repository](https://github.com/alexrame/diwa), which averages the weights of the models obtained from the hyperparameter search instead of selecting a single one: this strategy is motivated and explained in the [model soups, ICML 2022](https://arxiv.org/abs/2203.05482) and [DiWA, NeurIPS 2022](https://arxiv.org/abs/2205.09739) papers.
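Weight averaging, as used by model soups and DiWA, can be sketched in a few lines of Python. The snippet below is a minimal sketch under stated assumptions. It is not the code of `domainbed/scripts/inference.py`.

* The helpers `uniform_average` and `greedy_average` are hypothetical.
* The `evaluate` callback is also hypothetical; it is assumed to return the ID validation accuracy of a model built from a given state dict.
* The greedy variant follows the greedy recipe described in the model soups paper. Models are ranked by their individual validation score, and each one is kept only if adding it does not lower the score of the average.

```python
from typing import Callable, Dict, List, Sequence, Tuple

# Hypothetical alias: values are torch tensors in practice, but plain floats
# also work, which keeps this sketch testable without PyTorch.
StateDict = Dict[str, object]


def uniform_average(state_dicts: Sequence[StateDict]) -> StateDict:
    """Average every entry of several state dicts with equal weights."""
    if not state_dicts:
        raise ValueError("need at least one state dict")
    averaged: StateDict = {}
    for key, first in state_dicts[0].items():
        # Integer buffers (e.g. BatchNorm's num_batches_tracked) are not
        # meaningfully averaged; keep the first model's value instead.
        if hasattr(first, "is_floating_point") and not first.is_floating_point():
            averaged[key] = first
        else:
            averaged[key] = sum(sd[key] for sd in state_dicts) / len(state_dicts)
    return averaged


def greedy_average(
    state_dicts: Sequence[StateDict],
    evaluate: Callable[[StateDict], float],
) -> Tuple[StateDict, List[int]]:
    """Greedy selection in the spirit of model soups (hypothetical helper).

    `evaluate` is assumed to return an ID validation score (higher is better).
    """
    if not state_dicts:
        raise ValueError("need at least one state dict")
    scores = [evaluate(sd) for sd in state_dicts]
    # Visit models from best to worst individual score.
    order = sorted(range(len(state_dicts)), key=lambda i: scores[i], reverse=True)
    selected = [order[0]]
    best_score = scores[order[0]]
    for i in order[1:]:
        candidate = selected + [i]
        score = evaluate(uniform_average([state_dicts[j] for j in candidate]))
        # Keep the model only if the averaged model does not get worse.
        if score >= best_score:
            selected, best_score = candidate, score
    return uniform_average([state_dicts[j] for j in selected]), selected
```

With PyTorch, the averaged dictionary could then be loaded with `network.load_state_dict(averaged)`. All averaged models must share the same architecture, so that their state dicts have identical keys and tensor shapes.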
20 | 21 | ## Package requirements 22 | 23 | * python == 3.7.10 24 | * torch == 1.12.1 25 | * torchvision == 0.13.1 26 | * numpy == 1.21.5 27 | 28 | ## Datasets 29 | 30 | We consider the following [datasets](domainbed/datasets.py): 31 | 32 | * VLCS ([Fang et al., 2013](https://openaccess.thecvf.com/content_iccv_2013/papers/Fang_Unbiased_Metric_Learning_2013_ICCV_paper.pdf)) 33 | * PACS ([Li et al., 2017](https://arxiv.org/abs/1710.03077)) 34 | * OfficeHome ([Venkateswara et al., 2017](https://arxiv.org/abs/1706.07522)) 35 | * A TerraIncognita ([Beery et al., 2018](https://arxiv.org/abs/1807.04975)) subset 36 | * DomainNet ([Peng et al., 2019](http://ai.bu.edu/M3SDA/)) 37 | 38 | You can download the datasets with the following command: 39 | 40 | ```sh 41 | python3 -m domainbed.scripts.download --data_dir ${data_dir} 42 | ``` 43 | 44 | # Ratatouille: procedure details 45 | 46 | Our procedure has three stages: 47 | 48 | 1. Auxiliary trainings: create a pool of specialized models on various auxiliary tasks. 49 | 2. Target trainings: apply the standard hyperparameter search starting from these auxiliary initializations. 50 | 3. Weight selection: average the fine-tuned weights. 51 | 52 | All experiments are saved in `${expe_dir}`. 53 | 54 | ## Building a pool of specialized auxiliary weights 55 | 56 | For real-world applications, we envision that specialized weights may be downloaded from collaborative open-source repositories of neural networks. In this repository, we instead populate the folder `${expe_dir}/aux` by fine-tuning on DomainBed's datasets. Specifically, we use the `sweep` script with one of VLCS, PACS, OfficeHome, TerraIncognita or DomainNet as the `${auxiliary_dataset}`.
57 | 58 | ```sh 59 | mkdir -p ${expe_dir}/lp # dir containing the linear probe runs 60 | mkdir -p ${expe_dir}/aux # dir containing the auxiliary runs 61 | 62 | for auxiliary_dataset in VLCS PACS OfficeHome TerraIncognita DomainNet 63 | do 64 | python -m domainbed.scripts.sweep launch \ 65 | --data_dir ${data_dir} \ 66 | --dataset ${auxiliary_dataset} \ 67 | --test_env -1 \ 68 | --output_dir_lp ${expe_dir}/lp/${auxiliary_dataset}_notest \ 69 | --output_dir ${expe_dir}/aux/${auxiliary_dataset}_notest \ 70 | --n_hparams 4 \ 71 | --n_trials 1 72 | done 73 | ``` 74 | 75 | Here, `--test_env -1` means that we train on all domains simultaneously: there is no OOD test domain for auxiliary trainings. `--output_dir_lp` is where the shared linear probe is saved, and `--output_dir` is where the auxiliary hyperparameter sweep is saved. We only need 4 runs in the hyperparameter search (`--n_hparams 4`) and a single data split (`--n_trials 1`). First, if `output_dir_lp` does not exist, we linear probe (lp) the classifier (to prevent [feature distortion](https://openreview.net/forum?id=UYneFzXSJWh)): this classifier initialization is then used in all subsequent runs. Second, we populate `output_dir` with `n_hparams` ERM runs, sampling hyperparameters from the distributions defined [here](domainbed/hparams_registry.py). 76 | 77 | Critically, this procedure is agnostic to the target task, and thus is done only once. 78 | 79 | ## Fine-tunings on the target task 80 | 81 | Now we focus on a given `${target_dataset}` and on one `${test_env}` domain held out as the test domain; the other domains are used for training. As before, we use the `sweep` script.
82 | 83 | ```sh 84 | mkdir -p ${expe_dir}/target # dir containing the target runs 85 | target_dataset=OfficeHome # or any other DomainBed dataset 86 | test_env=0 # or any integer between 0 and 3 87 | 88 | python -m domainbed.scripts.sweep launch \ 89 | --data_dir ${data_dir} \ 90 | --dataset ${target_dataset} \ 91 | --test_env ${test_env} \ 92 | --output_dir_lp ${expe_dir}/lp/${target_dataset}_test${test_env} \ 93 | --output_dir ${expe_dir}/target/${target_dataset}_withaux \ 94 | --aux_dir ${expe_dir}/aux \ 95 | --n_hparams 20 \ 96 | --n_trials 1 97 | ``` 98 | 99 | Here, `--test_env` is the domain not seen during training and kept apart for OOD evaluation. `--output_dir_lp` is where the shared linear probe is saved, and `--output_dir` is where the target hyperparameter sweep is saved. `--n_hparams 20` is the default number of hyperparameter configurations, though 5 already provides good results. The argument `aux_dir` is the directory containing the pool of auxiliary runs used to initialize the featurizer. To prevent any information leakage, the code discards from `aux_dir` the models trained on `${target_dataset}`: in brief, we ensure that `${target_dataset}` $\neq$ `${auxiliary_dataset}`. 100 | 101 | ## Average the fine-tuned weights 102 | 103 | Ratatouille's main contribution is to establish the linear mode connectivity across models fine-tuned on the target task starting from different initializations. We thus average the weights obtained from the previous sweep. 104 | 105 | ```sh 106 | python -m domainbed.scripts.inference \ 107 | --data_dir ${data_dir} \ 108 | --dataset ${target_dataset} \ 109 | --test_env ${test_env} \ 110 | --input_dir ${expe_dir}/target/${target_dataset}_withaux \ 111 | --weight_selection uniform \ 112 | --trial_seed 0 113 | ``` 114 | 115 | Use `--weight_selection greedy` instead of `uniform` for greedy weight selection. If you want to obtain standard deviations on different data splits, set `--n_trials 3` in the sweep command.
Then you can set `trial_seed` to `0`, `1` or `2`; you can also average all `60` weights from the `3` trials by setting `trial_seed` to `-1`, which we call `uniform`$^\dagger$. 116 | 117 | 118 | # Baselines 119 | 120 | ## Inter-training 121 | 122 | Inter-training selects the best model among the previous runs based on ID validation accuracy. To reproduce the results, call: 123 | 124 | ```sh 125 | python -m domainbed.scripts.collect_results --input_dir ${expe_dir}/target/${target_dataset}_withaux 126 | ``` 127 | 128 | ## Vanilla fine-tuning and Soups/DiWA 129 | 130 | You first need to launch a new sweep without auxiliary initializations, i.e., with `--aux_dir none`. The other arguments, abbreviated as `...` below, are the same as before; only the output directory changes. 131 | ```sh 132 | python -m domainbed.scripts.sweep launch \ 133 | ... \ 134 | --output_dir ${expe_dir}/target/${target_dataset}_noaux \ 135 | --aux_dir none 136 | ``` 137 | 138 | Then call `collect_results.py` (for vanilla fine-tuning) or `inference.py` (for Soups/DiWA) with `--input_dir ${expe_dir}/target/${target_dataset}_noaux`. In brief, model ratatouille is to inter-training what model soups is to vanilla fine-tuning. 139 | 140 | ## Fusing 141 | 142 | Add `--fusing_range 4` to the previous sweep command to linearly interpolate at initialization as in [fusing](https://arxiv.org/abs/2204.03044): rather than selecting a single checkpoint at initialization, fusing linearly interpolates multiple auxiliary featurizers. The value `4` specifies how the interpolation coefficients are sampled. 143 | 144 | ```sh 145 | python -m domainbed.scripts.sweep launch \ 146 | ... \ 147 | --output_dir ${expe_dir}/target/${target_dataset}_withaux_fusing4 \ 148 | --fusing_range 4 149 | ``` 150 | 151 | # Results 152 | 153 | Ratatouille sets a new state of the art on DomainBed.
154 | 155 | | Algorithm | Selection | PACS | VLCS | OfficeHome | TerraInc | DomainNet | Avg | 156 | |---|---|---|---|---|---|---|---| 157 | | Vanilla fine-tuning | ID val | 85.5 | 77.5 | 66.5 | 46.1 | 40.9 | 63.3 | 158 | | Coral | ID val | 86.2 | 78.8 | 68.7 | 47.6 | 41.5 | 64.6 | 159 | | SWAD | Loss-aware | 88.1 | **79.1** | 70.6 | 50.0 | 46.5 | 66.9 | 160 | |---|---|---|---|---|---|---|---| 161 | | ERM | ID val | 85.9 | 78.1 | 69.4 | 50.4 | 44.3 | 65.6 | 162 | | Soups/DiWA | Greedy | 88.0 | 78.5 | 71.5 | 51.6 | **47.7** | 67.5 | 163 | | Soups/DiWA | Uniform | 88.7 | 78.4 | 72.1 | 51.4 | 47.4 | 67.6 | 164 | | Soups/DiWA$^{\dagger}$ | Uniform$^{\dagger}$ | 89.0 | 78.6 | 72.8 | 51.9 | **47.7** | 68.0 | 165 | |---|---|---|---|---|---|---|---| 166 | | Inter-training | ID val | 89.0 | 77.7 | 69.9 | 46.7 | 44.5 | 65.6 | 167 | | Fusing | ID val | 88.0 | 78.5 | 71.5 | 46.7 | 44.4 | 65.8 | 168 | | Ratatouille | Uniform | 89.5 | 78.5 | 73.1 | 51.8 | 47.5 | 68.1 | 169 | | Ratatouille | Greedy | **90.5** | 78.7 | 73.4 | 49.2 | **47.7** | 67.9 | 170 | | Ratatouille$^{\dagger}$ | Uniform$^{\dagger}$ | 89.8 | 78.3 | **73.5** | **52.0** | **47.7** | **68.3** | 171 | 172 | 173 | # Other information 174 | 175 | ## License 176 | 177 | This source code is released under the MIT license, included [here](LICENSE). 178 | 179 | ## Why ratatouille? 180 | 181 | We named our method after this traditional French dish for two main reasons. First, ratatouille is often used as a way to recycle leftover vegetables. Second, ratatouille is better prepared by cooking each ingredient separately before mixing them: this technique ensures that each ingredient “will taste truly of itself”, as [noted](https://www.bbc.com/travel/article/20200812-the-right-way-to-make-ratatouille) by chef Joël Robuchon.
182 | 183 | ## Citation 184 | 185 | If you find this code useful for your research, please consider citing our work: 186 | 187 | ``` 188 | @article{rame2022recycling, 189 | title={Model Ratatouille: Recycling Diverse Models for Out-of-Distribution Generalization}, 190 | author={Ram{\'e}, Alexandre and Ahuja, Kartik and Zhang, Jianyu and Cord, Matthieu and Bottou, L{\'e}on and Lopez-Paz, David}, 191 | journal={arXiv preprint arXiv:2212.10445}, 192 | year={2022} 193 | } 194 | ``` 195 | 196 | Correspondence to alexandre.rame at isir.upmc.fr 197 | -------------------------------------------------------------------------------- /domainbed/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | -------------------------------------------------------------------------------- /domainbed/algorithms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import os 7 | 8 | from domainbed.lib import misc 9 | from domainbed import networks 10 | 11 | ALGORITHMS = [ 12 | 'ERM', 13 | ] 14 | 15 | def get_algorithm_class(algorithm_name): 16 | """Return the algorithm class with the given name.""" 17 | if algorithm_name not in globals(): 18 | raise NotImplementedError("Algorithm not found: {}".format(algorithm_name)) 19 | return globals()[algorithm_name] 20 | 21 | class Algorithm(torch.nn.Module): 22 | """ 23 | A subclass of Algorithm implements a domain generalization algorithm. 
24 | Subclasses should implement the following: 25 | - update() 26 | - predict() 27 | """ 28 | def __init__(self, input_shape, num_classes, hparams): 29 | super(Algorithm, self).__init__() 30 | self.input_shape = input_shape 31 | self.num_classes = num_classes 32 | self.hparams = hparams 33 | 34 | def update(self, minibatches, unlabeled=None): 35 | """ 36 | Perform one update step, given a list of (x, y) tuples for all 37 | environments. 38 | 39 | Admits an optional list of unlabeled minibatches from the test domains, 40 | when task is domain_adaptation. 41 | """ 42 | raise NotImplementedError 43 | 44 | def predict(self, x): 45 | raise NotImplementedError 46 | 47 | class ERM(Algorithm): 48 | """ 49 | Empirical Risk Minimization (ERM) 50 | """ 51 | 52 | def __init__(self, input_shape, num_classes, hparams, what_is_trainable=False, path_init="", dict_featurizers_aux={}): 53 | 54 | super(ERM, self).__init__(input_shape, num_classes, hparams) 55 | self._what_is_trainable = what_is_trainable 56 | self._dict_featurizers_aux = dict_featurizers_aux 57 | self._create_network() 58 | self._load_network(path_init) 59 | self.register_buffer('update_count', torch.tensor([0])) 60 | self._init_optimizer() 61 | 62 | def _create_network(self): 63 | self.featurizer = networks.Featurizer(self.input_shape, self.hparams) 64 | self.classifier = networks.Classifier( 65 | self.featurizer.n_outputs, 66 | self.num_classes) 67 | 68 | self.network = nn.Sequential(self.featurizer, self.classifier) 69 | 70 | def _load_network(self, path_init): 71 | if path_init: 72 | assert os.path.exists(path_init) 73 | self.network.load_state_dict(torch.load(path_init), strict=True) 74 | 75 | if self._dict_featurizers_aux: 76 | list_featurizers_aux = [] 77 | list_lambdas_aux = [] 78 | for path_featurizer_aux, lambda_aux in self._dict_featurizers_aux.items(): 79 | featurizer_aux = networks.Featurizer(self.input_shape, self.hparams) 80 | if path_featurizer_aux != 'imagenet': 81 | 
featurizer_aux.load_state_dict(torch.load(path_featurizer_aux), strict=True) 82 | list_featurizers_aux.append(featurizer_aux) 83 | list_lambdas_aux.append(lambda_aux) 84 | 85 | if len(list_featurizers_aux) == 1: 86 | wa_weights = {k:v for k, v in list_featurizers_aux[0].named_parameters()} 87 | else: 88 | # for fusing at initialization 89 | wa_weights = misc.get_name_waparameters( 90 | list_featurizers_aux, 91 | list_lambdas_aux) 92 | for name, param in self.featurizer.named_parameters(): 93 | param.data = wa_weights[name] 94 | 95 | def _need_lp(self): 96 | return len([key for key in self._dict_featurizers_aux.keys() if key != "imagenet"]) 97 | 98 | def _get_training_parameters(self): 99 | if self._need_lp(): 100 | # apply another lp linear probe only when the featurizer is not transferred directly from ImageNet 101 | if self.update_count == self.hparams["lp_steps"]: 102 | print(f"Now back to update {self._what_is_trainable}") 103 | what_is_trainable = self._what_is_trainable 104 | else: 105 | assert self.update_count == 0 106 | what_is_trainable = "classifier" 107 | else: 108 | what_is_trainable = self._what_is_trainable 109 | 110 | if what_is_trainable in ["all"]: 111 | training_parameters = self.network.parameters() 112 | else: 113 | assert what_is_trainable in ["classifier"] 114 | training_parameters = self.classifier.parameters() 115 | return training_parameters 116 | 117 | def _init_optimizer(self): 118 | training_parameters = self._get_training_parameters() 119 | self.optimizer = torch.optim.Adam( 120 | training_parameters, 121 | lr=self.hparams["lr"], 122 | weight_decay=self.hparams['weight_decay'] 123 | ) 124 | 125 | def update(self, minibatches): 126 | all_x = torch.cat([x for x,y in minibatches]) 127 | all_y = torch.cat([y for x,y in minibatches]) 128 | loss = F.cross_entropy(self.predict(all_x), all_y) 129 | 130 | self.optimizer.zero_grad() 131 | loss.backward() 132 | self.optimizer.step() 133 | 134 | if self._need_lp() and self.update_count == 
self.hparams["lp_steps"]: 135 | self._init_optimizer() 136 | self.update_count += 1 137 | 138 | return {'loss': loss.item()} 139 | 140 | def predict(self, x): 141 | return self.network(x) 142 | 143 | def save_path_for_future_init(self, path_init): 144 | assert not os.path.exists(path_init), "The initialization has already been saved" 145 | torch.save(self.network.state_dict(), path_init) 146 | -------------------------------------------------------------------------------- /domainbed/algorithms_inference.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | import copy 4 | from domainbed import algorithms 5 | 6 | class ERM(algorithms.ERM): 7 | 8 | def __init__(self, input_shape, num_classes, hparams): 9 | algorithms.Algorithm.__init__(self, input_shape, num_classes, hparams) 10 | algorithms.ERM._create_network(self) 11 | 12 | class WA(algorithms.ERM): 13 | 14 | def __init__(self, input_shape, num_classes): 15 | """ 16 | """ 17 | algorithms.Algorithm.__init__(self, input_shape, num_classes, hparams={}) 18 | self.network_wa = None 19 | self.global_count = 0 20 | 21 | def add_weights(self, network): 22 | if self.network_wa is None: 23 | self.network_wa = copy.deepcopy(network) 24 | else: 25 | for param_q, param_k in zip(network.parameters(), self.network_wa.parameters()): 26 | param_k.data = (param_k.data * self.global_count + param_q.data) / (1. + self.global_count) 27 | self.global_count += 1 28 | 29 | def predict(self, x): 30 | return self.network_wa(x) 31 | -------------------------------------------------------------------------------- /domainbed/command_launchers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | 3 | """ 4 | A command launcher launches a list of commands on a cluster; implement your own 5 | launcher to add support for your cluster. We've provided an example launcher 6 | which runs commands in parallel on all GPUs of the local machine. 7 | """ 8 | 9 | import subprocess 10 | import time 11 | import torch 12 | 13 | def multi_gpu_launcher(commands): 14 | """ 15 | Launch commands on the local machine, using all GPUs in parallel. 16 | """ 17 | n_gpus = torch.cuda.device_count() 18 | print(f'WARNING: using experimental multi_gpu_launcher with {n_gpus} gpus.') 19 | procs_by_gpu = [None] * n_gpus 20 | 21 | while len(commands) > 0: 22 | for gpu_idx in range(n_gpus): 23 | proc = procs_by_gpu[gpu_idx] 24 | if (proc is None) or (proc.poll() is not None): 25 | # Nothing is running on this GPU; launch a command. 26 | cmd = commands.pop(0) 27 | new_proc = subprocess.Popen( 28 | f'CUDA_VISIBLE_DEVICES={gpu_idx} {cmd}', shell=True) 29 | procs_by_gpu[gpu_idx] = new_proc 30 | break 31 | time.sleep(1) 32 | 33 | # Wait for the last few tasks to finish before returning 34 | for p in procs_by_gpu: 35 | if p is not None: 36 | p.wait() 37 | -------------------------------------------------------------------------------- /domainbed/datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates.
All Rights Reserved 2 | 3 | import os 4 | import torch 5 | from PIL import Image, ImageFile 6 | from torchvision import transforms 7 | import torchvision.datasets.folder 8 | from torch.utils.data import TensorDataset, Subset 9 | from torchvision.datasets import MNIST, ImageFolder 10 | from torchvision.transforms.functional import rotate 11 | 12 | ImageFile.LOAD_TRUNCATED_IMAGES = True 13 | 14 | DATASETS = [ 15 | # Debug 16 | "Debug28", 17 | "Debug224", 18 | # Small images 19 | "ColoredMNIST", 20 | "RotatedMNIST", 21 | # Big images 22 | "VLCS", 23 | "PACS", 24 | "OfficeHome", 25 | "TerraIncognita", 26 | "DomainNet", 27 | "SVIRO", 28 | # WILDS datasets 29 | "WILDSCamelyon", 30 | "WILDSFMoW" 31 | ] 32 | 33 | def get_dataset_class(dataset_name): 34 | """Return the dataset class with the given name.""" 35 | if dataset_name not in globals(): 36 | raise NotImplementedError("Dataset not found: {}".format(dataset_name)) 37 | return globals()[dataset_name] 38 | 39 | 40 | def num_environments(dataset_name): 41 | return len(get_dataset_class(dataset_name).ENVIRONMENTS) 42 | 43 | 44 | class MultipleDomainDataset: 45 | N_STEPS = 5001 # Default, subclasses may override 46 | CHECKPOINT_FREQ = 100 # Default, subclasses may override 47 | N_WORKERS = 8 # Default, subclasses may override 48 | ENVIRONMENTS = None # Subclasses should override 49 | INPUT_SHAPE = None # Subclasses should override 50 | 51 | def __getitem__(self, index): 52 | return self.datasets[index] 53 | 54 | def __len__(self): 55 | return len(self.datasets) 56 | 57 | 58 | class Debug(MultipleDomainDataset): 59 | def __init__(self, root, test_envs, hparams): 60 | super().__init__() 61 | self.input_shape = self.INPUT_SHAPE 62 | self.num_classes = 2 63 | self.datasets = [] 64 | for _ in [0, 1, 2]: 65 | self.datasets.append( 66 | TensorDataset( 67 | torch.randn(16, *self.INPUT_SHAPE), 68 | torch.randint(0, self.num_classes, (16,)) 69 | ) 70 | ) 71 | 72 | class Debug28(Debug): 73 | INPUT_SHAPE = (3, 28, 28) 74 | 
ENVIRONMENTS = ['0', '1', '2'] 75 | 76 | class Debug224(Debug): 77 | INPUT_SHAPE = (3, 224, 224) 78 | ENVIRONMENTS = ['0', '1', '2'] 79 | 80 | 81 | class MultipleEnvironmentMNIST(MultipleDomainDataset): 82 | def __init__(self, root, environments, dataset_transform, input_shape, 83 | num_classes): 84 | super().__init__() 85 | if root is None: 86 | raise ValueError('Data directory not specified!') 87 | 88 | original_dataset_tr = MNIST(root, train=True, download=True) 89 | original_dataset_te = MNIST(root, train=False, download=True) 90 | 91 | original_images = torch.cat((original_dataset_tr.data, 92 | original_dataset_te.data)) 93 | 94 | original_labels = torch.cat((original_dataset_tr.targets, 95 | original_dataset_te.targets)) 96 | 97 | shuffle = torch.randperm(len(original_images)) 98 | 99 | original_images = original_images[shuffle] 100 | original_labels = original_labels[shuffle] 101 | 102 | self.datasets = [] 103 | 104 | for i in range(len(environments)): 105 | images = original_images[i::len(environments)] 106 | labels = original_labels[i::len(environments)] 107 | self.datasets.append(dataset_transform(images, labels, environments[i])) 108 | 109 | self.input_shape = input_shape 110 | self.num_classes = num_classes 111 | 112 | 113 | class ColoredMNIST(MultipleEnvironmentMNIST): 114 | ENVIRONMENTS = ['+90%', '+80%', '-90%'] 115 | CHECKPOINT_FREQ = 100 116 | def __init__(self, root, test_envs, hparams): 117 | super(ColoredMNIST, self).__init__(root, [0.1, 0.2, 0.9], 118 | self.color_dataset, (2, 28, 28,), 2) 119 | 120 | self.input_shape = (2, 28, 28,) 121 | self.num_classes = 2 122 | 123 | def color_dataset(self, images, labels, environment): 124 | # # Subsample 2x for computational convenience 125 | # images = images.reshape((-1, 28, 28))[:, ::2, ::2] 126 | # Assign a binary label based on the digit 127 | labels = (labels < 5).float() 128 | # Flip label with probability 0.25 129 | labels = self.torch_xor_(labels, 130 | self.torch_bernoulli_(0.25, len(labels))) 
131 | 132 | # Assign a color based on the label; flip the color with probability e 133 | colors = self.torch_xor_(labels, 134 | self.torch_bernoulli_(environment, 135 | len(labels))) 136 | images = torch.stack([images, images], dim=1) 137 | # Apply the color to the image by zeroing out the other color channel 138 | images[torch.tensor(range(len(images))), ( 139 | 1 - colors).long(), :, :] *= 0 140 | 141 | x = images.float().div_(255.0) 142 | y = labels.view(-1).long() 143 | 144 | return TensorDataset(x, y) 145 | 146 | def torch_bernoulli_(self, p, size): 147 | return (torch.rand(size) < p).float() 148 | 149 | def torch_xor_(self, a, b): 150 | return (a - b).abs() 151 | 152 | 153 | class RotatedMNIST(MultipleEnvironmentMNIST): 154 | ENVIRONMENTS = ['0', '15', '30', '45', '60', '75'] 155 | 156 | def __init__(self, root, test_envs, hparams): 157 | super(RotatedMNIST, self).__init__(root, [0, 15, 30, 45, 60, 75], 158 | self.rotate_dataset, (1, 28, 28,), 10) 159 | 160 | def rotate_dataset(self, images, labels, angle): 161 | rotation = transforms.Compose([ 162 | transforms.ToPILImage(), 163 | transforms.Lambda(lambda x: rotate(x, angle, fill=(0,), 164 | interpolation=torchvision.transforms.InterpolationMode.BILINEAR)), 165 | transforms.ToTensor()]) 166 | 167 | x = torch.zeros(len(images), 1, 28, 28) 168 | for i in range(len(images)): 169 | x[i] = rotation(images[i]) 170 | 171 | y = labels.view(-1) 172 | 173 | return TensorDataset(x, y) 174 | 175 | 176 | class MultipleEnvironmentImageFolder(MultipleDomainDataset): 177 | def __init__(self, root, test_envs, augment, hparams): 178 | super().__init__() 179 | environments = [f.name for f in os.scandir(root) if f.is_dir()] 180 | environments = sorted(environments) 181 | 182 | transform = transforms.Compose([ 183 | transforms.Resize((224,224)), 184 | transforms.ToTensor(), 185 | transforms.Normalize( 186 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 187 | ]) 188 | 189 | augment_transform = transforms.Compose([ 190 | # 
transforms.Resize((224,224)), 191 | transforms.RandomResizedCrop(224, scale=(0.7, 1.0)), 192 | transforms.RandomHorizontalFlip(), 193 | transforms.ColorJitter(0.3, 0.3, 0.3, 0.3), 194 | transforms.RandomGrayscale(), 195 | transforms.ToTensor(), 196 | transforms.Normalize( 197 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 198 | ]) 199 | 200 | self.datasets = [] 201 | for i, environment in enumerate(environments): 202 | 203 | if augment and (i not in test_envs): 204 | env_transform = augment_transform 205 | else: 206 | env_transform = transform 207 | 208 | path = os.path.join(root, environment) 209 | env_dataset = ImageFolder(path, 210 | transform=env_transform) 211 | 212 | self.datasets.append(env_dataset) 213 | 214 | self.input_shape = (3, 224, 224,) 215 | self.num_classes = len(self.datasets[-1].classes) 216 | 217 | class VLCS(MultipleEnvironmentImageFolder): 218 | CHECKPOINT_FREQ = 50 219 | ENVIRONMENTS = ["C", "L", "S", "V"] 220 | def __init__(self, root, test_envs, hparams): 221 | self.dir = os.path.join(root, "VLCS/") 222 | super().__init__(self.dir, test_envs, hparams['data_augmentation'], hparams) 223 | 224 | class PACS(MultipleEnvironmentImageFolder): 225 | CHECKPOINT_FREQ = 100 226 | ENVIRONMENTS = ["A", "C", "P", "S"] 227 | def __init__(self, root, test_envs, hparams): 228 | self.dir = os.path.join(root, "PACS/") 229 | super().__init__(self.dir, test_envs, hparams['data_augmentation'], hparams) 230 | 231 | class DomainNet(MultipleEnvironmentImageFolder): 232 | CHECKPOINT_FREQ = 500 233 | N_STEPS = 15001 # DomainNet requires more training steps, as previously observed in SWAD, MA or DiWA 234 | ENVIRONMENTS = ["clip", "info", "paint", "quick", "real", "sketch"] 235 | def __init__(self, root, test_envs, hparams): 236 | self.dir = os.path.join(root, "domain_net/") 237 | super().__init__(self.dir, test_envs, hparams['data_augmentation'], hparams) 238 | 239 | class OfficeHome(MultipleEnvironmentImageFolder): 240 | CHECKPOINT_FREQ = 100 241 | 
ENVIRONMENTS = ["A", "C", "P", "R"] 242 | def __init__(self, root, test_envs, hparams): 243 | self.dir = os.path.join(root, "office_home/") 244 | super().__init__(self.dir, test_envs, hparams['data_augmentation'], hparams) 245 | 246 | class TerraIncognita(MultipleEnvironmentImageFolder): 247 | CHECKPOINT_FREQ = 100 248 | ENVIRONMENTS = ["L100", "L38", "L43", "L46"] 249 | def __init__(self, root, test_envs, hparams): 250 | self.dir = os.path.join(root, "terra_incognita/") 251 | super().__init__(self.dir, test_envs, hparams['data_augmentation'], hparams) 252 | 253 | class SVIRO(MultipleEnvironmentImageFolder): 254 | CHECKPOINT_FREQ = 300 255 | ENVIRONMENTS = ["aclass", "escape", "hilux", "i3", "lexus", "tesla", "tiguan", "tucson", "x5", "zoe"] 256 | def __init__(self, root, test_envs, hparams): 257 | self.dir = os.path.join(root, "sviro/") 258 | super().__init__(self.dir, test_envs, hparams['data_augmentation'], hparams) 259 | 260 | 261 | class WILDSEnvironment: 262 | def __init__( 263 | self, 264 | wilds_dataset, 265 | metadata_name, 266 | metadata_value, 267 | transform=None): 268 | self.name = metadata_name + "_" + str(metadata_value) 269 | 270 | metadata_index = wilds_dataset.metadata_fields.index(metadata_name) 271 | metadata_array = wilds_dataset.metadata_array 272 | subset_indices = torch.where( 273 | metadata_array[:, metadata_index] == metadata_value)[0] 274 | 275 | self.dataset = wilds_dataset 276 | self.indices = subset_indices 277 | self.transform = transform 278 | 279 | def __getitem__(self, i): 280 | x = self.dataset.get_input(self.indices[i]) 281 | if type(x).__name__ != "Image": 282 | x = Image.fromarray(x) 283 | 284 | y = self.dataset.y_array[self.indices[i]] 285 | if self.transform is not None: 286 | x = self.transform(x) 287 | return x, y 288 | 289 | def __len__(self): 290 | return len(self.indices) 291 | 292 | 293 | class WILDSDataset(MultipleDomainDataset): 294 | INPUT_SHAPE = (3, 224, 224) 295 | def __init__(self, dataset, metadata_name, 
test_envs, augment, hparams): 296 | super().__init__() 297 | 298 | transform = transforms.Compose([ 299 | transforms.Resize((224, 224)), 300 | transforms.ToTensor(), 301 | transforms.Normalize( 302 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 303 | ]) 304 | 305 | augment_transform = transforms.Compose([ 306 | transforms.Resize((224, 224)), 307 | transforms.RandomResizedCrop(224, scale=(0.7, 1.0)), 308 | transforms.RandomHorizontalFlip(), 309 | transforms.ColorJitter(0.3, 0.3, 0.3, 0.3), 310 | transforms.RandomGrayscale(), 311 | transforms.ToTensor(), 312 | transforms.Normalize( 313 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 314 | ]) 315 | 316 | self.datasets = [] 317 | 318 | for i, metadata_value in enumerate( 319 | self.metadata_values(dataset, metadata_name)): 320 | if augment and (i not in test_envs): 321 | env_transform = augment_transform 322 | else: 323 | env_transform = transform 324 | 325 | env_dataset = WILDSEnvironment( 326 | dataset, metadata_name, metadata_value, env_transform) 327 | 328 | self.datasets.append(env_dataset) 329 | 330 | self.input_shape = (3, 224, 224,) 331 | self.num_classes = dataset.n_classes 332 | 333 | def metadata_values(self, wilds_dataset, metadata_name): 334 | metadata_index = wilds_dataset.metadata_fields.index(metadata_name) 335 | metadata_vals = wilds_dataset.metadata_array[:, metadata_index] 336 | return sorted(list(set(metadata_vals.view(-1).tolist()))) 337 | 338 | 339 | class WILDSCamelyon(WILDSDataset): 340 | ENVIRONMENTS = [ "hospital_0", "hospital_1", "hospital_2", "hospital_3", 341 | "hospital_4"] 342 | def __init__(self, root, test_envs, hparams): 343 | from wilds.datasets.camelyon17_dataset import Camelyon17Dataset 344 | dataset = Camelyon17Dataset(root_dir=root) 345 | super().__init__( 346 | dataset, "hospital", test_envs, hparams['data_augmentation'], hparams) 347 | 348 | 349 | class WILDSFMoW(WILDSDataset): 350 | ENVIRONMENTS = [ "region_0", "region_1", "region_2", "region_3", 351 | 
"region_4", "region_5"] 352 | def __init__(self, root, test_envs, hparams): 353 | from wilds.datasets.fmow_dataset import FMoWDataset 354 | dataset = FMoWDataset(root_dir=root) 355 | super().__init__( 356 | dataset, "region", test_envs, hparams['data_augmentation'], hparams) 357 | -------------------------------------------------------------------------------- /domainbed/hparams_registry.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import numpy as np 3 | from domainbed.lib import misc 4 | 5 | 6 | def _hparams(algorithm, dataset, random_seed): 7 | """ 8 | Global registry of hyperparams. Each entry is a (default, random) tuple. 9 | New algorithms / networks / etc. should add entries here. 10 | """ 11 | 12 | hparams = {} 13 | 14 | def _hparam(name, default_val, random_val_fn): 15 | """Define a hyperparameter. random_val_fn takes a RandomState and 16 | returns a random hyperparameter value.""" 17 | assert(name not in hparams) 18 | random_state = np.random.RandomState( 19 | misc.seed_hash(random_seed, name) 20 | ) 21 | hparams[name] = (default_val, random_val_fn(random_state)) 22 | 23 | # Unconditional hparam definitions. 
24 | _hparam('data_augmentation', True, lambda r: True) 25 | _hparam('resnet_dropout', 0., lambda r: r.choice([0., 0.1, 0.5])) 26 | 27 | ## Mild hyperparameter ranges as first defined in SWAD (https://arxiv.org/abs/2102.08604) and DiWA 28 | _hparam('lr', 5e-5, lambda r: r.choice([1e-5, 3e-5, 5e-5])) 29 | _hparam('weight_decay', 0, lambda r: r.choice([1e-4, 1e-6])) 30 | _hparam('batch_size', 32, lambda r: 32) 31 | _hparam('lp_steps', 200., lambda r: r.choice([200])) 32 | 33 | return hparams 34 | 35 | 36 | def default_hparams(algorithm, dataset): 37 | return {a: b for a, (b, c) in _hparams(algorithm, dataset, 0).items()} 38 | 39 | 40 | def random_hparams(algorithm, dataset, seed): 41 | return {a: c for a, (b, c) in _hparams(algorithm, dataset, seed).items()} 42 | -------------------------------------------------------------------------------- /domainbed/lib/fast_data_loader.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | 3 | import torch 4 | 5 | class _InfiniteSampler(torch.utils.data.Sampler): 6 | """Wraps another Sampler to yield an infinite stream.""" 7 | def __init__(self, sampler): 8 | self.sampler = sampler 9 | 10 | def __iter__(self): 11 | while True: 12 | for batch in self.sampler: 13 | yield batch 14 | 15 | class InfiniteDataLoader: 16 | def __init__(self, dataset, weights, batch_size, num_workers): 17 | super().__init__() 18 | 19 | if weights is not None: 20 | sampler = torch.utils.data.WeightedRandomSampler(weights, 21 | replacement=True, 22 | num_samples=batch_size) 23 | else: 24 | sampler = torch.utils.data.RandomSampler(dataset, 25 | replacement=True) 26 | 27 | if weights is None: 28 | weights = torch.ones(len(dataset)) 29 | 30 | batch_sampler = torch.utils.data.BatchSampler( 31 | sampler, 32 | batch_size=batch_size, 33 | drop_last=True) 34 | 35 | self._infinite_iterator = iter(torch.utils.data.DataLoader( 36 | dataset, 37 | num_workers=num_workers, 38 | batch_sampler=_InfiniteSampler(batch_sampler) 39 | )) 40 | 41 | def __iter__(self): 42 | while True: 43 | yield next(self._infinite_iterator) 44 | 45 | def __len__(self): 46 | raise ValueError 47 | 48 | class FastDataLoader: 49 | """DataLoader wrapper with slightly improved speed by not respawning worker 50 | processes at every epoch.""" 51 | def __init__(self, dataset, batch_size, num_workers): 52 | super().__init__() 53 | 54 | batch_sampler = torch.utils.data.BatchSampler( 55 | torch.utils.data.RandomSampler(dataset, replacement=False), 56 | batch_size=batch_size, 57 | drop_last=False 58 | ) 59 | 60 | self._infinite_iterator = iter(torch.utils.data.DataLoader( 61 | dataset, 62 | num_workers=num_workers, 63 | batch_sampler=_InfiniteSampler(batch_sampler) 64 | )) 65 | 66 | self._length = len(batch_sampler) 67 | 68 | def __iter__(self): 69 | for _ in range(len(self)): 70 | yield next(self._infinite_iterator) 71 | 72 | def __len__(self): 73 | return self._length 74 |
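The two loaders in fast_data_loader.py differ in how they consume one persistent iterator: `_InfiniteSampler` restarts the wrapped `BatchSampler` forever, and `FastDataLoader` reads exactly `len(batch_sampler)` batches per epoch from that single stream, so each epoch is one full reshuffled pass without re-creating the `DataLoader`. The following is a torch-free sketch of that index bookkeeping only, assuming plain lists of indices stand in for PyTorch samplers; worker processes are not modeled, and `batch_indices`, `infinite_batches` and `EpochView` are hypothetical names, not part of the repository.

```python
import math
import random
from typing import Iterator, List


def batch_indices(n: int, batch_size: int, rng: random.Random) -> Iterator[List[int]]:
    """One pass over a fresh permutation of range(n), cut into batches.

    Hypothetical stand-in for BatchSampler(RandomSampler(ds, replacement=False),
    batch_size, drop_last=False); the last batch may be smaller.
    """
    perm = list(range(n))
    rng.shuffle(perm)
    for start in range(0, n, batch_size):
        yield perm[start:start + batch_size]


def infinite_batches(n: int, batch_size: int, seed: int = 0) -> Iterator[List[int]]:
    """Like _InfiniteSampler: restart the finite pass forever, reshuffling each time."""
    rng = random.Random(seed)
    while True:
        yield from batch_indices(n, batch_size, rng)


class EpochView:
    """Like FastDataLoader: one persistent stream, read len(self) batches per epoch."""

    def __init__(self, n: int, batch_size: int, seed: int = 0):
        self._stream = infinite_batches(n, batch_size, seed)
        # Same count as len(BatchSampler(..., drop_last=False)) == ceil(n / batch_size).
        self._length = math.ceil(n / batch_size)

    def __len__(self) -> int:
        return self._length

    def __iter__(self) -> Iterator[List[int]]:
        for _ in range(len(self)):
            yield next(self._stream)


if __name__ == "__main__":
    loader = EpochView(n=10, batch_size=4)
    for epoch in range(2):
        batches = list(loader)
        print(epoch, [len(b) for b in batches])  # batch sizes 4, 4, 2 in each epoch
        # Each epoch consumes exactly one full pass, so every index appears once.
        assert sorted(i for b in batches for i in b) == list(range(10))
```

Keeping the iterator alive is what lets the real `DataLoader` reuse its workers, and matching the epoch length to `len(batch_sampler)` is what keeps each epoch a complete pass. `InfiniteDataLoader`, by contrast, samples with replacement and drops the last incomplete batch, so it has no notion of an epoch and its `__len__` raises.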
-------------------------------------------------------------------------------- /domainbed/lib/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | """ 4 | Things that don't belong anywhere else 5 | """ 6 | 7 | import hashlib 8 | import json 9 | import os 10 | import re 11 | import copy 12 | 13 | import itertools 14 | import sys 15 | from shutil import copyfile 16 | from collections import OrderedDict, defaultdict 17 | from numbers import Number 18 | import operator 19 | 20 | import numpy as np 21 | import torch 22 | import tqdm 23 | 24 | 25 | def _detect_folders(args): 26 | folders = [os.path.join(args.aux_dir, path) for path in os.listdir(args.aux_dir) 27 | if not path.startswith(args.dataset)] 28 | folders = sorted([folder for folder in folders if os.path.isdir(folder)]) 29 | print(f"Discovered {len(folders)} auxiliary folders") 30 | # assert 20 % (len(folders) + 1) == 0 31 | return folders 32 | 33 | def _detect_path_in_subfolder(args, folder): 34 | subfolders = [os.path.join(folder, path) for path in os.listdir(folder)] 35 | subfolders = sorted([subfolder for subfolder in subfolders if os.path.isdir(subfolder) and "done" in os.listdir(subfolder) and "featurizer_last.pkl" in os.listdir(subfolder)]) 36 | assert len(subfolders), f"No subfolders found with finished auxiliary trainings in {folder}" 37 | index_subfolder = args.hparams_seed % len(subfolders) 38 | return os.path.join(subfolders[index_subfolder], "featurizer_last.pkl") 39 | 40 | def get_featurizer_aux(args): 41 | folders = _detect_folders(args) 42 | index_folder = args.hparams_seed % (len(folders) + 1) 43 | if index_folder == 0: 44 | # featurizer transferred directly from ImageNet, so no auxiliary training is needed 45 | return "imagenet" 46 | return _detect_path_in_subfolder(args, folders[index_folder-1]) 47 | 48 | def get_list_featurizers_aux(args): 49 | folders = 
_detect_folders(args) 50 | list_aux_featurizers = ["imagenet"] 51 | for folder in folders: 52 | list_aux_featurizers.append(_detect_path_in_subfolder(args, folder)) 53 | return list_aux_featurizers 54 | 55 | def get_name_waparameters(featurizers_aux, lambdas_aux): 56 | weights = {} 57 | list_gen_named_params = [featurizer.named_parameters() for featurizer in featurizers_aux] 58 | for name_0, param_0 in featurizers_aux[0].named_parameters(): 59 | named_params = [next(gen_named_params) for gen_named_params in list_gen_named_params] 60 | new_data = torch.zeros_like(param_0.data) 61 | for i in range(len(featurizers_aux)): 62 | name_i, param_i = named_params[i] 63 | assert name_0 == name_i 64 | new_data = new_data + lambdas_aux[i] * param_i 65 | weights[name_0] = new_data 66 | return weights 67 | 68 | def np_encoder(object): 69 | if isinstance(object, np.generic): 70 | return object.item() 71 | 72 | def get_score(results, test_envs, metric_key="acc"): 73 | val_env_keys = [] 74 | for i in itertools.count(): 75 | acc_key = f'env{i}_out_' + metric_key 76 | if acc_key in results: 77 | if i not in test_envs: 78 | val_env_keys.append(acc_key) 79 | else: 80 | break 81 | assert i > 0 82 | return np.mean([results[key] for key in val_env_keys]) 83 | 84 | class MergeDataset(torch.utils.data.Dataset): 85 | def __init__(self, datasets): 86 | super(MergeDataset, self).__init__() 87 | self.datasets = datasets 88 | 89 | def __getitem__(self, key): 90 | count = 0 91 | for d in self.datasets: 92 | if key - count >= len(d): 93 | count += len(d) 94 | else: 95 | return d[key - count] 96 | raise ValueError(key) 97 | 98 | def __len__(self): 99 | return sum([len(d) for d in self.datasets]) 100 | 101 | def pdb(): 102 | sys.stdout = sys.__stdout__ 103 | import pdb 104 | print("Launching PDB, enter 'n' to step to parent function.") 105 | pdb.set_trace() 106 | 107 | def seed_hash(*args): 108 | """ 109 | Derive an integer hash from all args, for use as a random seed. 
110 | """ 111 | args_str = str(args) 112 | return int(hashlib.md5(args_str.encode("utf-8")).hexdigest(), 16) % (2**31) 113 | 114 | def print_separator(): 115 | print("="*80) 116 | 117 | def print_row(row, colwidth=10, latex=False): 118 | if latex: 119 | sep = " & " 120 | end_ = "\\\\" 121 | else: 122 | sep = " " 123 | end_ = "" 124 | 125 | def format_val(x): 126 | if np.issubdtype(type(x), np.floating): 127 | x = "{:.10f}".format(x) 128 | return str(x).ljust(colwidth)[:colwidth] 129 | print(sep.join([format_val(x) for x in row]), end_) 130 | 131 | 132 | class _SplitDataset(torch.utils.data.Dataset): 133 | """Used by split_dataset""" 134 | def __init__(self, underlying_dataset, keys): 135 | super(_SplitDataset, self).__init__() 136 | self.underlying_dataset = underlying_dataset 137 | self.keys = keys 138 | def __getitem__(self, key): 139 | return self.underlying_dataset[self.keys[key]] 140 | def __len__(self): 141 | return len(self.keys) 142 | 143 | def split_dataset(dataset, n, seed=0): 144 | """ 145 | Return a pair of datasets corresponding to a random split of the given 146 | dataset, with n datapoints in the first dataset and the rest in the last, 147 | using the given random seed 148 | """ 149 | assert(n <= len(dataset)) 150 | keys = list(range(len(dataset))) 151 | np.random.RandomState(seed).shuffle(keys) 152 | keys_1 = keys[:n] 153 | keys_2 = keys[n:] 154 | return _SplitDataset(dataset, keys_1), _SplitDataset(dataset, keys_2) 155 | 156 | def random_pairs_of_minibatches(minibatches): 157 | perm = torch.randperm(len(minibatches)).tolist() 158 | pairs = [] 159 | 160 | for i in range(len(minibatches)): 161 | j = i + 1 if i < (len(minibatches) - 1) else 0 162 | 163 | xi, yi = minibatches[perm[i]][0], minibatches[perm[i]][1] 164 | xj, yj = minibatches[perm[j]][0], minibatches[perm[j]][1] 165 | 166 | min_n = min(len(xi), len(xj)) 167 | 168 | pairs.append(((xi[:min_n], yi[:min_n]), (xj[:min_n], yj[:min_n]))) 169 | 170 | return pairs 171 | 172 | def 
split_meta_train_test(minibatches, num_meta_test=1): 173 | n_domains = len(minibatches) 174 | perm = torch.randperm(n_domains).tolist() 175 | pairs = [] 176 | meta_train = perm[:(n_domains-num_meta_test)] 177 | meta_test = perm[-num_meta_test:] 178 | 179 | for i,j in zip(meta_train, itertools.cycle(meta_test)): 180 | xi, yi = minibatches[i][0], minibatches[i][1] 181 | xj, yj = minibatches[j][0], minibatches[j][1] 182 | 183 | min_n = min(len(xi), len(xj)) 184 | pairs.append(((xi[:min_n], yi[:min_n]), (xj[:min_n], yj[:min_n]))) 185 | 186 | return pairs 187 | 188 | def accuracy(network, loader, device): 189 | correct = 0 190 | total = 0 191 | 192 | network.eval() 193 | with torch.no_grad(): 194 | for x, y in loader: 195 | x = x.to(device) 196 | y = y.to(device) 197 | p = network.predict(x) 198 | batch_weights = torch.ones(len(x)).to(device) 199 | if p.size(1) == 1: 200 | correct += (p.gt(0).eq(y).float() * batch_weights.view(-1, 1)).sum().item() 201 | else: 202 | correct += (p.argmax(1).eq(y).float() * batch_weights).sum().item() 203 | total += batch_weights.sum().item() 204 | network.train() 205 | 206 | return correct / total 207 | 208 | class Tee: 209 | def __init__(self, fname, mode="a"): 210 | self.stdout = sys.stdout 211 | self.file = open(fname, mode) 212 | 213 | def write(self, message): 214 | self.stdout.write(message) 215 | self.file.write(message) 216 | self.flush() 217 | 218 | def flush(self): 219 | self.stdout.flush() 220 | self.file.flush() 221 | 222 | class ParamDict(OrderedDict): 223 | """Code adapted from https://github.com/Alok/rl_implementations/tree/master/reptile. 224 | A dictionary where the values are Tensors, meant to represent weights of 225 | a model.
This subclass lets you perform arithmetic on weights directly.""" 226 | 227 | def __init__(self, *args, **kwargs): 228 | super().__init__(*args, **kwargs) 229 | 230 | def _prototype(self, other, op): 231 | if isinstance(other, Number): 232 | return ParamDict({k: op(v, other) for k, v in self.items()}) 233 | elif isinstance(other, dict): 234 | return ParamDict({k: op(self[k], other[k]) for k in self}) 235 | else: 236 | raise NotImplementedError 237 | 238 | def __add__(self, other): 239 | return self._prototype(other, operator.add) 240 | 241 | def __rmul__(self, other): 242 | return self._prototype(other, operator.mul) 243 | 244 | __mul__ = __rmul__ 245 | 246 | def __neg__(self): 247 | return ParamDict({k: -v for k, v in self.items()}) 248 | 249 | def __sub__(self, other): 250 | # a - b := a + (-b) 251 | return self.__add__(other.__neg__()) 252 | 253 | def __rsub__(self, other): 254 | # b - a := (-a) + b, where b is the left operand (e.g. a Number) 255 | return self.__neg__().__add__(other) 256 | 257 | def __truediv__(self, other): 258 | return self._prototype(other, operator.truediv) 259 | -------------------------------------------------------------------------------- /domainbed/lib/query.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | """Small query library.""" 4 | 5 | import collections 6 | import inspect 7 | import json 8 | import types 9 | import unittest 10 | import warnings 11 | import math 12 | 13 | import numpy as np 14 | 15 | 16 | def make_selector_fn(selector): 17 | """ 18 | If selector is a function, return selector. 19 | Otherwise, return a function corresponding to the selector string.
Examples 20 | of valid selector strings and the corresponding functions: 21 | x lambda obj: obj['x'] 22 | x.y lambda obj: obj['x']['y'] 23 | x,y lambda obj: (obj['x'], obj['y']) 24 | """ 25 | if isinstance(selector, str): 26 | if ',' in selector: 27 | parts = selector.split(',') 28 | part_selectors = [make_selector_fn(part) for part in parts] 29 | return lambda obj: tuple(sel(obj) for sel in part_selectors) 30 | elif '.' in selector: 31 | parts = selector.split('.') 32 | part_selectors = [make_selector_fn(part) for part in parts] 33 | def f(obj): 34 | for sel in part_selectors: 35 | obj = sel(obj) 36 | return obj 37 | return f 38 | else: 39 | key = selector.strip() 40 | return lambda obj: obj[key] 41 | elif isinstance(selector, types.FunctionType): 42 | return selector 43 | else: 44 | raise TypeError 45 | 46 | def hashable(obj): 47 | try: 48 | hash(obj) 49 | return obj 50 | except TypeError: 51 | return json.dumps({'_':obj}, sort_keys=True) 52 | 53 | class Q(object): 54 | def __init__(self, list_): 55 | super(Q, self).__init__() 56 | self._list = list_ 57 | 58 | def __len__(self): 59 | return len(self._list) 60 | 61 | def __getitem__(self, key): 62 | return self._list[key] 63 | 64 | def __eq__(self, other): 65 | if isinstance(other, self.__class__): 66 | return self._list == other._list 67 | else: 68 | return self._list == other 69 | 70 | def __str__(self): 71 | return str(self._list) 72 | 73 | def __repr__(self): 74 | return repr(self._list) 75 | 76 | def _append(self, item): 77 | """Unsafe, be careful you know what you're doing.""" 78 | self._list.append(item) 79 | 80 | def group(self, selector): 81 | """ 82 | Group elements by selector and return a list of (group, group_records) 83 | tuples. 
84 | """ 85 | selector = make_selector_fn(selector) 86 | groups = {} 87 | for x in self._list: 88 | group = selector(x) 89 | group_key = hashable(group) 90 | if group_key not in groups: 91 | groups[group_key] = (group, Q([])) 92 | groups[group_key][1]._append(x) 93 | results = [groups[key] for key in sorted(groups.keys())] 94 | return Q(results) 95 | 96 | def group_map(self, selector, fn): 97 | """ 98 | Group elements by selector, apply fn to each group, and return a list 99 | of the results. 100 | """ 101 | return self.group(selector).map(fn) 102 | 103 | def map(self, fn): 104 | """ 105 | map self onto fn. If fn takes multiple args, tuple-unpacking 106 | is applied. 107 | """ 108 | if len(inspect.signature(fn).parameters) > 1: 109 | return Q([fn(*x) for x in self._list]) 110 | else: 111 | return Q([fn(x) for x in self._list]) 112 | 113 | def select(self, selector): 114 | selector = make_selector_fn(selector) 115 | return Q([selector(x) for x in self._list]) 116 | 117 | def min(self): 118 | return min(self._list) 119 | 120 | def max(self): 121 | return max(self._list) 122 | 123 | def sum(self): 124 | return sum(self._list) 125 | 126 | def len(self): 127 | return len(self._list) 128 | 129 | def mean(self): 130 | with warnings.catch_warnings(): 131 | warnings.simplefilter("ignore") 132 | return float(np.mean(self._list)) 133 | 134 | def std(self): 135 | with warnings.catch_warnings(): 136 | warnings.simplefilter("ignore") 137 | return float(np.std(self._list)) 138 | 139 | def mean_std(self): 140 | return (self.mean(), self.std()) 141 | 142 | def argmax(self, selector): 143 | selector = make_selector_fn(selector) 144 | return max(self._list, key=selector) 145 | 146 | def filter(self, fn): 147 | return Q([x for x in self._list if fn(x)]) 148 | 149 | def filter_equals(self, selector, value): 150 | """like [x for x in y if x.selector == value]""" 151 | selector = make_selector_fn(selector) 152 | return self.filter(lambda r: selector(r) == value) 153 | 154 | def 
filter_not_none(self): 155 | return self.filter(lambda r: r is not None) 156 | 157 | def filter_not_nan(self): 158 | return self.filter(lambda r: not np.isnan(r)) 159 | 160 | def flatten(self): 161 | return Q([y for x in self._list for y in x]) 162 | 163 | def unique(self): 164 | result = [] 165 | result_set = set() 166 | for x in self._list: 167 | hashable_x = hashable(x) 168 | if hashable_x not in result_set: 169 | result_set.add(hashable_x) 170 | result.append(x) 171 | return Q(result) 172 | 173 | def sorted(self, key=None): 174 | if key is None: 175 | key = lambda x: x 176 | def key2(x): 177 | x = key(x) 178 | if isinstance(x, (np.floating, float)) and np.isnan(x): 179 | return float('-inf') 180 | else: 181 | return x 182 | return Q(sorted(self._list, key=key2)) 183 | -------------------------------------------------------------------------------- /domainbed/lib/reporting.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | import collections 4 | 5 | import json 6 | import os 7 | 8 | import tqdm 9 | 10 | from domainbed.lib.query import Q 11 | 12 | def load_records(path): 13 | records = [] 14 | for i, subdir in tqdm.tqdm(list(enumerate(os.listdir(path))), 15 | ncols=80, 16 | leave=False): 17 | results_path = os.path.join(path, subdir, "results.jsonl") 18 | try: 19 | with open(results_path, "r") as f: 20 | for line in f: 21 | records.append(json.loads(line[:-1])) 22 | except IOError: 23 | pass 24 | 25 | return Q(records) 26 | 27 | def get_grouped_records(records): 28 | """Group records by (trial_seed, dataset, algorithm, test_env). 
Because 29 | records can have multiple test envs, a given record may appear in more than 30 | one group.""" 31 | result = collections.defaultdict(lambda: []) 32 | for r in records: 33 | for test_env in r["args"]["test_envs"]: 34 | group = (r["args"]["trial_seed"], 35 | r["args"]["dataset"], 36 | r["args"]["algorithm"], 37 | test_env) 38 | result[group].append(r) 39 | return Q([{"trial_seed": t, "dataset": d, "algorithm": a, "test_env": e, 40 | "records": Q(r)} for (t,d,a,e),r in result.items()]) 41 | -------------------------------------------------------------------------------- /domainbed/lib/wide_resnet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | """ 4 | From https://github.com/meliketoy/wide-resnet.pytorch 5 | """ 6 | 7 | import sys 8 | 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | import torch.nn.init as init 14 | from torch.autograd import Variable 15 | 16 | 17 | def conv3x3(in_planes, out_planes, stride=1): 18 | return nn.Conv2d( 19 | in_planes, 20 | out_planes, 21 | kernel_size=3, 22 | stride=stride, 23 | padding=1, 24 | bias=True) 25 | 26 | 27 | def conv_init(m): 28 | classname = m.__class__.__name__ 29 | if classname.find('Conv') != -1: 30 | init.xavier_uniform_(m.weight, gain=np.sqrt(2)) 31 | init.constant_(m.bias, 0) 32 | elif classname.find('BatchNorm') != -1: 33 | init.constant_(m.weight, 1) 34 | init.constant_(m.bias, 0) 35 | 36 | 37 | class wide_basic(nn.Module): 38 | def __init__(self, in_planes, planes, dropout_rate, stride=1): 39 | super(wide_basic, self).__init__() 40 | self.bn1 = nn.BatchNorm2d(in_planes) 41 | self.conv1 = nn.Conv2d( 42 | in_planes, planes, kernel_size=3, padding=1, bias=True) 43 | self.dropout = nn.Dropout(p=dropout_rate) 44 | self.bn2 = nn.BatchNorm2d(planes) 45 | self.conv2 = nn.Conv2d( 46 | planes, planes, kernel_size=3, stride=stride, 
padding=1, bias=True) 47 | 48 | self.shortcut = nn.Sequential() 49 | if stride != 1 or in_planes != planes: 50 | self.shortcut = nn.Sequential( 51 | nn.Conv2d( 52 | in_planes, planes, kernel_size=1, stride=stride, 53 | bias=True), ) 54 | 55 | def forward(self, x): 56 | out = self.dropout(self.conv1(F.relu(self.bn1(x)))) 57 | out = self.conv2(F.relu(self.bn2(out))) 58 | out += self.shortcut(x) 59 | 60 | return out 61 | 62 | 63 | class Wide_ResNet(nn.Module): 64 | """Wide Resnet with the softmax layer chopped off""" 65 | def __init__(self, input_shape, depth, widen_factor, dropout_rate): 66 | super(Wide_ResNet, self).__init__() 67 | self.in_planes = 16 68 | 69 | assert ((depth - 4) % 6 == 0), 'Wide-resnet depth should be 6n+4' 70 | n = (depth - 4) / 6 71 | k = widen_factor 72 | 73 | # print('| Wide-Resnet %dx%d' % (depth, k)) 74 | nStages = [16, 16 * k, 32 * k, 64 * k] 75 | 76 | self.conv1 = conv3x3(input_shape[0], nStages[0]) 77 | self.layer1 = self._wide_layer( 78 | wide_basic, nStages[1], n, dropout_rate, stride=1) 79 | self.layer2 = self._wide_layer( 80 | wide_basic, nStages[2], n, dropout_rate, stride=2) 81 | self.layer3 = self._wide_layer( 82 | wide_basic, nStages[3], n, dropout_rate, stride=2) 83 | self.bn1 = nn.BatchNorm2d(nStages[3], momentum=0.9) 84 | 85 | self.n_outputs = nStages[3] 86 | 87 | def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride): 88 | strides = [stride] + [1] * (int(num_blocks) - 1) 89 | layers = [] 90 | 91 | for stride in strides: 92 | layers.append(block(self.in_planes, planes, dropout_rate, stride)) 93 | self.in_planes = planes 94 | 95 | return nn.Sequential(*layers) 96 | 97 | def forward(self, x): 98 | out = self.conv1(x) 99 | out = self.layer1(out) 100 | out = self.layer2(out) 101 | out = self.layer3(out) 102 | out = F.relu(self.bn1(out)) 103 | out = F.avg_pool2d(out, 8) 104 | return out[:, :, 0, 0] 105 | -------------------------------------------------------------------------------- /domainbed/model_selection.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | import itertools 4 | import numpy as np 5 | 6 | def get_test_records(records): 7 | """Given records with a common test env, get the test records (i.e. the 8 | records with *only* that single test env and no other test envs)""" 9 | return records.filter(lambda r: len(r['args']['test_envs']) == 1) 10 | 11 | class SelectionMethod: 12 | """Abstract class whose subclasses implement strategies for model 13 | selection across hparams and timesteps.""" 14 | 15 | def __init__(self): 16 | raise TypeError 17 | 18 | @classmethod 19 | def run_acc(self, run_records): 20 | """ 21 | Given records from a run, return a {val_acc, test_acc} dict representing 22 | the best val-acc and corresponding test-acc for that run. 23 | """ 24 | raise NotImplementedError 25 | 26 | @classmethod 27 | def hparams_accs(self, records): 28 | """ 29 | Given all records from a single (dataset, algorithm, test env) triple, 30 | return a list of (run_acc, records) tuples sorted by decreasing val_acc. 31 | """ 32 | return (records.group('args.hparams_seed') 33 | .map(lambda _, run_records: 34 | ( 35 | self.run_acc(run_records), 36 | run_records 37 | ) 38 | ).filter(lambda x: x[0] is not None) 39 | .sorted(key=lambda x: x[0]['val_acc'])[::-1] 40 | ) 41 | 42 | @classmethod 43 | def sweep_acc(self, records): 44 | """ 45 | Given all records from a single (dataset, algorithm, test env) triple, 46 | return the test acc of the run with the highest val acc (None if no run qualifies). 47 | """ 48 | _hparams_accs = self.hparams_accs(records) 49 | if len(_hparams_accs): 50 | return _hparams_accs[0][0]['test_acc'] 51 | else: 52 | return None 53 | 54 | class OracleSelectionMethod(SelectionMethod): 55 | """Like a selection method that picks argmax(test_out_acc) across all hparams 56 | and checkpoints, except that instead of taking the argmax over all 57 | checkpoints, it picks the last checkpoint of each run, i.e.
no early stopping."""
    name = "test-domain validation set (oracle)"

    @classmethod
    def run_acc(self, run_records):
        run_records = run_records.filter(lambda r:
            len(r['args']['test_envs']) == 1)
        if not len(run_records):
            return None
        test_env = run_records[0]['args']['test_envs'][0]
        test_out_acc_key = 'env{}_out_acc'.format(test_env)
        test_in_acc_key = 'env{}_in_acc'.format(test_env)
        chosen_record = run_records.sorted(lambda r: r['step'])[-1]
        return {
            'val_acc': chosen_record[test_out_acc_key],
            'test_acc': chosen_record[test_in_acc_key]
        }

class IIDAccuracySelectionMethod(SelectionMethod):
    """Picks argmax(mean(env_out_acc for env in train_envs))"""
    name = "training-domain validation set"

    @classmethod
    def _step_acc(self, record):
        """Given a single record, return a {val_acc, test_acc} dict."""
        test_env = record['args']['test_envs'][0]
        val_env_keys = []
        for i in itertools.count():
            if f'env{i}_out_acc' not in record:
                break
            if i != test_env:
                val_env_keys.append(f'env{i}_out_acc')
        test_in_acc_key = 'env{}_in_acc'.format(test_env)
        return {
            'val_acc': np.mean([record[key] for key in val_env_keys]),
            'test_acc': record[test_in_acc_key]
        }

    @classmethod
    def run_acc(self, run_records):
        test_records = get_test_records(run_records)
        if not len(test_records):
            return None
        return test_records.map(self._step_acc).argmax('val_acc')

class LeaveOneOutSelectionMethod(SelectionMethod):
    """Picks (hparams, step) by leave-one-out cross validation."""
    name = "leave-one-domain-out cross-validation"

    @classmethod
    def _step_acc(self, records):
        """Return the {val_acc, test_acc} for a group of records corresponding
        to a single step."""
        test_records = get_test_records(records)
        if len(test_records) != 1:
            return None

        test_env = test_records[0]['args']['test_envs'][0]
        n_envs = 0
        for i in itertools.count():
            if f'env{i}_out_acc' not in records[0]:
                break
            n_envs += 1
        val_accs = np.zeros(n_envs) - 1
        for r in records.filter(lambda r: len(r['args']['test_envs']) == 2):
            val_env = (set(r['args']['test_envs']) - set([test_env])).pop()
            val_accs[val_env] = r['env{}_in_acc'.format(val_env)]
        val_accs = list(val_accs[:test_env]) + list(val_accs[test_env+1:])
        if any([v==-1 for v in val_accs]):
            return None
        val_acc = np.sum(val_accs) / (n_envs-1)
        return {
            'val_acc': val_acc,
            'test_acc': test_records[0]['env{}_in_acc'.format(test_env)]
        }

    @classmethod
    def run_acc(self, records):
        step_accs = records.group('step').map(lambda step, step_records:
            self._step_acc(step_records)
        ).filter_not_none()
        if len(step_accs):
            return step_accs.argmax('val_acc')
        else:
            return None
--------------------------------------------------------------------------------
/domainbed/networks.py:
--------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models

from domainbed.lib import wide_resnet
import copy


def remove_batch_norm_from_resnet(model):
    fuse = torch.nn.utils.fusion.fuse_conv_bn_eval
    model.eval()

    model.conv1 = fuse(model.conv1, model.bn1)
    model.bn1 = Identity()

    for name, module in model.named_modules():
        if name.startswith("layer") and len(name) == 6:
            for b, bottleneck in enumerate(module):
                for name2, module2 in bottleneck.named_modules():
                    if name2.startswith("conv"):
                        bn_name = "bn" + name2[-1]
                        setattr(bottleneck, name2,
                                fuse(module2, getattr(bottleneck, bn_name)))
                        setattr(bottleneck, bn_name, Identity())
                if isinstance(bottleneck.downsample, torch.nn.Sequential):
                    bottleneck.downsample[0] = fuse(bottleneck.downsample[0],
                                                    bottleneck.downsample[1])
                    bottleneck.downsample[1] = Identity()
    model.train()
    return model


class Identity(nn.Module):
    """An identity layer"""
    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        return x


class MLP(nn.Module):
    """Just an MLP"""
    def __init__(self, n_inputs, n_outputs, hparams):
        super(MLP, self).__init__()
        self.input = nn.Linear(n_inputs, hparams['mlp_width'])
        self.dropout = nn.Dropout(hparams['mlp_dropout'])
        self.hiddens = nn.ModuleList([
            nn.Linear(hparams['mlp_width'], hparams['mlp_width'])
            for _ in range(hparams['mlp_depth']-2)])
        self.output = nn.Linear(hparams['mlp_width'], n_outputs)
        self.n_outputs = n_outputs

    def forward(self, x):
        x = self.input(x)
        x = self.dropout(x)
        x = F.relu(x)
        for hidden in self.hiddens:
            x = hidden(x)
            x = self.dropout(x)
            x = F.relu(x)
        x = self.output(x)
        return x


class ResNet(torch.nn.Module):
    """ResNet with the softmax chopped off and the batchnorm frozen"""
    def __init__(self, input_shape, hparams):
        super(ResNet, self).__init__()
        if hparams.get('resnet18', False):
            self.network = torchvision.models.resnet18(pretrained=True)
            self.n_outputs = 512
        else:
            self.network = torchvision.models.resnet50(pretrained=True)
            self.n_outputs = 2048

        # self.network = remove_batch_norm_from_resnet(self.network)

        # adapt number of channels
        nc = input_shape[0]
        if nc != 3:
            tmp = self.network.conv1.weight.data.clone()

            self.network.conv1 = nn.Conv2d(
                nc, 64, kernel_size=(7, 7),
                stride=(2, 2), padding=(3, 3), bias=False)

            for i in range(nc):
                self.network.conv1.weight.data[:, i, :, :] = tmp[:, i % 3, :, :]

        # save memory
        del self.network.fc
        self.network.fc = Identity()

        self.freeze_bn()
        self.hparams = hparams
        self.dropout = nn.Dropout(hparams['resnet_dropout'])

    def forward(self, x):
        """Encode x into a feature vector of size n_outputs."""
        return self.dropout(self.network(x))

    def train(self, mode=True):
        """
        Override the default train() to freeze the BN parameters
        """
        super().train(mode)
        self.freeze_bn()

    def freeze_bn(self):
        for m in self.network.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()


class MNIST_CNN(nn.Module):
    """
    Hand-tuned architecture for MNIST.
    Weirdness I've noticed so far with this architecture:
    - adding a linear layer after the mean-pool in features hurts
        RotatedMNIST-100 generalization severely.
    """
    n_outputs = 128

    def __init__(self, input_shape):
        super(MNIST_CNN, self).__init__()
        self.conv1 = nn.Conv2d(input_shape[0], 64, 3, 1, padding=1)
        self.conv2 = nn.Conv2d(64, 128, 3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(128, 128, 3, 1, padding=1)
        self.conv4 = nn.Conv2d(128, 128, 3, 1, padding=1)

        self.bn0 = nn.GroupNorm(8, 64)
        self.bn1 = nn.GroupNorm(8, 128)
        self.bn2 = nn.GroupNorm(8, 128)
        self.bn3 = nn.GroupNorm(8, 128)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.bn0(x)

        x = self.conv2(x)
        x = F.relu(x)
        x = self.bn1(x)

        x = self.conv3(x)
        x = F.relu(x)
        x = self.bn2(x)

        x = self.conv4(x)
        x = F.relu(x)
        x = self.bn3(x)

        x = self.avgpool(x)
        x = x.view(len(x), -1)
        return x


class ContextNet(nn.Module):
    def __init__(self, input_shape):
        super(ContextNet, self).__init__()

        # Keep same dimensions
        padding = (5 - 1) // 2
        self.context_net = nn.Sequential(
            nn.Conv2d(input_shape[0], 64, 5, padding=padding),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, 5, padding=padding),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 1, 5, padding=padding),
        )

    def forward(self, x):
        return self.context_net(x)


def Featurizer(input_shape, hparams):
    """Auto-select an appropriate featurizer for the given input shape."""
    if len(input_shape) == 1:
        return MLP(input_shape[0], hparams["mlp_width"], hparams)
    elif input_shape[1:3] == (28, 28):
        return MNIST_CNN(input_shape)
    elif input_shape[1:3] == (32, 32):
        return wide_resnet.Wide_ResNet(input_shape, 16, 2, 0.)
    elif input_shape[1:3] == (224, 224):
        return ResNet(input_shape, hparams)
    else:
        raise NotImplementedError


def Classifier(in_features, out_features):
    return torch.nn.Linear(in_features, out_features)


class WholeFish(nn.Module):
    def __init__(self, input_shape, num_classes, hparams, weights=None):
        super(WholeFish, self).__init__()
        featurizer = Featurizer(input_shape, hparams)
        classifier = Classifier(
            featurizer.n_outputs,
            num_classes)
        self.net = nn.Sequential(
            featurizer, classifier
        )
        if weights is not None:
            self.load_state_dict(copy.deepcopy(weights))

    def reset_weights(self, weights):
        self.load_state_dict(copy.deepcopy(weights))

    def forward(self, x):
        return self.net(x)
--------------------------------------------------------------------------------
/domainbed/scripts/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

--------------------------------------------------------------------------------
/domainbed/scripts/collect_results.py:
--------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import collections


import argparse
import functools
import glob
import pickle
import itertools
import json
import os
import random
import sys

import numpy as np
import tqdm

from domainbed import datasets
from domainbed import algorithms
from domainbed.lib import misc, reporting
from domainbed import model_selection
from domainbed.lib.query import Q
import warnings

def format_mean(data, latex):
    """Given a list of datapoints, return (mean, standard error, formatted
    string), with the mean and standard error expressed in percent"""
    if len(data) == 0:
        return None, None, "X"
    mean = 100 * np.mean(list(data))
    # standard error of the mean: std / sqrt(n)
    err = 100 * np.std(list(data)) / np.sqrt(len(data))
    if latex:
        return mean, err, "{:.1f} $\\pm$ {:.1f}".format(mean, err)
    else:
        return mean, err, "{:.1f} +/- {:.1f}".format(mean, err)

def print_table(table, header_text, row_labels, col_labels, colwidth=10,
    latex=True):
    """Pretty-print a 2D array of data, optionally with row/col labels"""
    print("")

    if latex:
        num_cols = len(table[0])
        print("\\begin{center}")
        print("\\adjustbox{max width=\\textwidth}{%")
        print("\\begin{tabular}{l" + "c" * num_cols + "}")
        print("\\toprule")
    else:
        print("--------", header_text)

    for row, label in zip(table, row_labels):
        row.insert(0, label)

    if latex:
        col_labels = ["\\textbf{" + str(col_label).replace("%", "\\%") + "}"
            for col_label in col_labels]
    table.insert(0, col_labels)

    for r, row in enumerate(table):
        misc.print_row(row, colwidth=colwidth, latex=latex)
        if latex and r == 0:
            print("\\midrule")
    if latex:
        print("\\bottomrule")
        print("\\end{tabular}}")
        print("\\end{center}")

def print_results_tables(records, selection_method, latex):
    """Given all records, print a results table for each dataset."""
    grouped_records = reporting.get_grouped_records(records).map(lambda group:
        { **group, "sweep_acc": selection_method.sweep_acc(group["records"]) }
    ).filter(lambda g: g["sweep_acc"] is not None)

    # read algorithm names and sort (predefined order)
    alg_names = Q(records).select("args.algorithm").unique()
    alg_names = ([n for n in algorithms.ALGORITHMS if n in alg_names] +
        [n for n in alg_names if n not in algorithms.ALGORITHMS])

    # read dataset names and sort (predefined order of datasets.DATASETS)
    dataset_names = Q(records).select("args.dataset").unique().sorted()
    dataset_names = [d for d in datasets.DATASETS if d in dataset_names]

    for dataset in dataset_names:
        if latex:
            print()
            print("\\subsubsection{{{}}}".format(dataset))
        test_envs = range(datasets.num_environments(dataset))

        table = [[None for _ in [*test_envs, "Avg"]] for _ in alg_names]
        for i, algorithm in enumerate(alg_names):
            means = []
            for j, test_env in enumerate(test_envs):
                trial_accs = (grouped_records
                    .filter_equals(
                        "dataset, algorithm, test_env",
                        (dataset, algorithm, test_env)
                    ).select("sweep_acc"))
                mean, err, table[i][j] = format_mean(trial_accs, latex)
                means.append(mean)
            if None in means:
                table[i][-1] = "X"
            else:
                table[i][-1] = "{:.1f}".format(sum(means) / len(means))

        col_labels = [
            "Algorithm",
            *datasets.get_dataset_class(dataset).ENVIRONMENTS,
            "Avg"
        ]
        header_text = (f"Dataset: {dataset}, "
            f"model selection method: {selection_method.name}")
        print_table(table, header_text, alg_names, list(col_labels),
            colwidth=20, latex=latex)

    # Print an "averages" table
    if latex:
        print()
        print("\\subsubsection{Averages}")

    table = [[None for _ in [*dataset_names, "Avg"]] for _ in alg_names]
    for i, algorithm in enumerate(alg_names):
        means = []
        for j, dataset in enumerate(dataset_names):
            trial_averages = (grouped_records
                .filter_equals("algorithm, dataset", (algorithm, dataset))
                .group("trial_seed")
                .map(lambda trial_seed, group:
                    group.select("sweep_acc").mean()
                )
            )
            mean, err, table[i][j] = format_mean(trial_averages, latex)
            means.append(mean)
        if None in means:
            table[i][-1] = "X"
        else:
            table[i][-1] = "{:.1f}".format(sum(means) / len(means))

    col_labels = ["Algorithm", *dataset_names, "Avg"]
    header_text = f"Averages, model selection method: {selection_method.name}"
    print_table(table, header_text, alg_names, col_labels, colwidth=25,
        latex=latex)

if __name__ == "__main__":
    np.set_printoptions(suppress=True)

    parser = argparse.ArgumentParser(
        description="Domain generalization testbed")
    parser.add_argument("--input_dir", type=str, required=True)
    parser.add_argument("--latex", action="store_true")
    args = parser.parse_args()

    results_file = "results.tex" if args.latex else "results.txt"

    sys.stdout = misc.Tee(os.path.join(args.input_dir, results_file), "w")

    records = reporting.load_records(args.input_dir)

    if args.latex:
        print("\\documentclass{article}")
        print("\\usepackage{booktabs}")
        print("\\usepackage{adjustbox}")
        print("\\begin{document}")
        print("\\section{Full DomainBed results}")
        print("% Total records:", len(records))
    else:
        print("Total records:", len(records))

    SELECTION_METHODS = [
        model_selection.IIDAccuracySelectionMethod,
        model_selection.LeaveOneOutSelectionMethod,
        model_selection.OracleSelectionMethod,
    ]

    for selection_method in SELECTION_METHODS:
        if args.latex:
            print()
            print("\\subsection{{Model selection: {}}}".format(
                selection_method.name))
        print_results_tables(records, selection_method, args.latex)

    if args.latex:
        print("\\end{document}")
--------------------------------------------------------------------------------
/domainbed/scripts/download.py:
--------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from torchvision.datasets import MNIST
import xml.etree.ElementTree as ET
from zipfile import ZipFile
import argparse
import tarfile
import shutil
import gdown
import uuid
import json
import os

from wilds.datasets.camelyon17_dataset import Camelyon17Dataset
from wilds.datasets.fmow_dataset import FMoWDataset


# utils #######################################################################

def stage_path(data_dir, name):
    full_path = os.path.join(data_dir, name)

    if not os.path.exists(full_path):
        os.makedirs(full_path)

    return full_path


def download_and_extract(url, dst, remove=True):
    gdown.download(url, dst, quiet=False)

    if dst.endswith(".tar.gz"):
        tar = tarfile.open(dst, "r:gz")
        tar.extractall(os.path.dirname(dst))
        tar.close()

    if dst.endswith(".tar"):
        tar = tarfile.open(dst, "r:")
        tar.extractall(os.path.dirname(dst))
        tar.close()

    if dst.endswith(".zip"):
        zf = ZipFile(dst, "r")
        zf.extractall(os.path.dirname(dst))
        zf.close()

    if remove:
        os.remove(dst)


# VLCS ########################################################################

# Slower, but builds dataset from the original sources
#
# def download_vlcs(data_dir):
#     full_path = stage_path(data_dir, "VLCS")
#
#     tmp_path = os.path.join(full_path, "tmp/")
#     if not os.path.exists(tmp_path):
#         os.makedirs(tmp_path)
#
#     with open("domainbed/misc/vlcs_files.txt", "r") as f:
#         lines = f.readlines()
#         files = [line.strip().split() for line in lines]
#
#     download_and_extract("http://pjreddie.com/media/files/VOCtrainval_06-Nov-2007.tar",
#                          os.path.join(tmp_path, "voc2007_trainval.tar"))
#
#     download_and_extract("https://drive.google.com/uc?id=1I8ydxaAQunz9R_qFFdBFtw6rFTUW9goz",
#                          os.path.join(tmp_path, "caltech101.tar.gz"))
#
#     download_and_extract("http://groups.csail.mit.edu/vision/Hcontext/data/sun09_hcontext.tar",
#                          os.path.join(tmp_path, "sun09_hcontext.tar"))
#
#     tar = tarfile.open(os.path.join(tmp_path, "sun09.tar"), "r:")
#     tar.extractall(tmp_path)
#     tar.close()
#
#     for src, dst in files:
#         class_folder = os.path.join(data_dir, dst)
#
#         if not os.path.exists(class_folder):
#             os.makedirs(class_folder)
#
#         dst = os.path.join(class_folder, uuid.uuid4().hex + ".jpg")
#
#         if "labelme" in src:
#             # download labelme from the web
#             gdown.download(src, dst, quiet=False)
#         else:
#             src = os.path.join(tmp_path, src)
#             shutil.copyfile(src, dst)
#
#     shutil.rmtree(tmp_path)


def download_vlcs(data_dir):
    # Original URL: http://www.eecs.qmul.ac.uk/~dl307/project_iccv2017
    full_path = stage_path(data_dir, "VLCS")

    download_and_extract("https://drive.google.com/uc?id=1skwblH1_okBwxWxmRsp9_qi15hyPpxg8",
                         os.path.join(data_dir, "VLCS.tar.gz"))


# MNIST #######################################################################

def download_mnist(data_dir):
    # Original URL: http://yann.lecun.com/exdb/mnist/
    full_path = stage_path(data_dir, "MNIST")
    MNIST(full_path, download=True)


# PACS ########################################################################

def download_pacs(data_dir):
    # Original URL: http://www.eecs.qmul.ac.uk/~dl307/project_iccv2017
    full_path = stage_path(data_dir, "PACS")

    download_and_extract("https://drive.google.com/uc?id=1JFr8f805nMUelQWWmfnJR3y4_SYoN5Pd",
                         os.path.join(data_dir, "PACS.zip"))

    os.rename(os.path.join(data_dir, "kfold"),
              full_path)


# Office-Home #################################################################

def download_office_home(data_dir):
    # Original URL: http://hemanthdv.org/OfficeHome-Dataset/
    full_path = stage_path(data_dir, "office_home")

    download_and_extract("https://drive.google.com/uc?id=1uY0pj7oFsjMxRwaD3Sxy0jgel0fsYXLC",
                         os.path.join(data_dir, "office_home.zip"))

    os.rename(os.path.join(data_dir, "OfficeHomeDataset_10072016"),
              full_path)


# DomainNET ###################################################################

def download_domain_net(data_dir):
    # Original URL: http://ai.bu.edu/M3SDA/
    full_path = stage_path(data_dir, "domain_net")

    urls = [
        "http://csr.bu.edu/ftp/visda/2019/multi-source/groundtruth/clipart.zip",
        "http://csr.bu.edu/ftp/visda/2019/multi-source/infograph.zip",
        "http://csr.bu.edu/ftp/visda/2019/multi-source/groundtruth/painting.zip",
        "http://csr.bu.edu/ftp/visda/2019/multi-source/quickdraw.zip",
        "http://csr.bu.edu/ftp/visda/2019/multi-source/real.zip",
        "http://csr.bu.edu/ftp/visda/2019/multi-source/sketch.zip"
    ]

    for url in urls:
        download_and_extract(url, os.path.join(full_path, url.split("/")[-1]))

    with open("domainbed/misc/domain_net_duplicates.txt", "r") as f:
        for line in f.readlines():
            try:
                os.remove(os.path.join(full_path, line.strip()))
            except OSError:
                pass


# TerraIncognita ##############################################################

def download_terra_incognita(data_dir):
    # Original URL: https://beerys.github.io/CaltechCameraTraps/
    # New URL: http://lila.science/datasets/caltech-camera-traps

    full_path = stage_path(data_dir, "terra_incognita")

    download_and_extract(
        "https://lilablobssc.blob.core.windows.net/caltechcameratraps/eccv_18_all_images_sm.tar.gz",
        os.path.join(full_path, "terra_incognita_images.tar.gz"))

    download_and_extract(
        "https://lilablobssc.blob.core.windows.net/caltechcameratraps/labels/caltech_camera_traps.json.zip",
        os.path.join(full_path, "caltech_camera_traps.json.zip"))

    include_locations = ["38", "46", "100", "43"]

    include_categories = [
        "bird", "bobcat", "cat", "coyote", "dog", "empty", "opossum", "rabbit",
        "raccoon", "squirrel"
    ]

    images_folder = os.path.join(full_path, "eccv_18_all_images_sm/")
    annotations_file = os.path.join(full_path, "caltech_images_20210113.json")
    destination_folder = full_path

    stats = {}

    if not os.path.exists(destination_folder):
        os.mkdir(destination_folder)

    with open(annotations_file, "r") as f:
        data = json.load(f)

    category_dict = {}
    for item in data['categories']:
        category_dict[item['id']] = item['name']

    for image in data['images']:
        image_location = image['location']

        if image_location not in include_locations:
            continue

        loc_folder = os.path.join(destination_folder,
                                  'location_' + str(image_location) + '/')

        if not os.path.exists(loc_folder):
            os.mkdir(loc_folder)

        image_id = image['id']
        image_fname = image['file_name']

        for annotation in data['annotations']:
            if annotation['image_id'] == image_id:
                if image_location not in stats:
                    stats[image_location] = {}

                category = category_dict[annotation['category_id']]

                if category not in include_categories:
                    continue

                # count kept images per (location, category); the first one counts as 1
                if category not in stats[image_location]:
                    stats[image_location][category] = 0
                stats[image_location][category] += 1

                loc_cat_folder = os.path.join(loc_folder, category + '/')

                if not os.path.exists(loc_cat_folder):
                    os.mkdir(loc_cat_folder)

                dst_path = os.path.join(loc_cat_folder, image_fname)
                src_path = os.path.join(images_folder, image_fname)

                shutil.copyfile(src_path, dst_path)

    shutil.rmtree(images_folder)
    os.remove(annotations_file)


# SVIRO #################################################################

def download_sviro(data_dir):
    # Original URL: https://sviro.kl.dfki.de
    full_path = stage_path(data_dir, "sviro")

    download_and_extract("https://sviro.kl.dfki.de/?wpdmdl=1731",
                         os.path.join(data_dir, "sviro_grayscale_rectangle_classification.zip"))

    os.rename(os.path.join(data_dir, "SVIRO_DOMAINBED"),
              full_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Download datasets')
    parser.add_argument('--data_dir', type=str, required=True)
    args = parser.parse_args()

    # download_mnist(args.data_dir)
    # download_pacs(args.data_dir)
    # download_office_home(args.data_dir)
    # download_domain_net(args.data_dir)
    # download_vlcs(args.data_dir)
    download_terra_incognita(args.data_dir)
    # download_sviro(args.data_dir)
    # Camelyon17Dataset(root_dir=args.data_dir, download=True)
    # FMoWDataset(root_dir=args.data_dir, download=True)
--------------------------------------------------------------------------------
/domainbed/scripts/inference.py:
--------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import argparse
import os
import json
import random
import numpy as np
import torch
import torch.utils.data
from domainbed import datasets, algorithms_inference
from domainbed.lib import misc
from domainbed.lib.fast_data_loader import FastDataLoader


def _get_args():
    parser = argparse.ArgumentParser(description='Inference with weight averaging')

    parser.add_argument('--data_dir', type=str)
    parser.add_argument('--input_dir', type=str)
    parser.add_argument('--dataset', type=str)
    parser.add_argument('--test_env', type=int)
    parser.add_argument('--weight_selection', type=str, default="uniform")  # or "greedy"
    parser.add_argument(
        '--trial_seed',
        type=int,
        default=0,
        help='trial seed of the checkpoints to average; -1 uses all seeds '
             '(uniform selection only)',
    )

    inf_args = parser.parse_args()
    return inf_args


def create_splits(domain, inf_args, dataset, _filter):
    splits = []

    for env_i, env in enumerate(dataset):
        if domain == "test" and env_i != inf_args.test_env:
            continue
        elif domain == "train" and env_i == inf_args.test_env:
            continue

        if _filter == "full":
            splits.append(env)
        else:
            out_, in_ = misc.split_dataset(
                env, int(len(env) * 0.2), misc.seed_hash(inf_args.trial_seed, env_i)
            )
            if _filter == "in":
                splits.append(in_)
            elif _filter == "out":
                splits.append(out_)
            else:
                raise ValueError(_filter)

    return splits


def get_dict_folder_to_score(inf_args):
    output_folders = [
        os.path.join(input_dir, path)
        for input_dir in inf_args.input_dir.split(",")
        for path in os.listdir(input_dir)
    ]
    output_folders = [
        output_folder for output_folder in output_folders
        if os.path.isdir(output_folder) and "done" in os.listdir(output_folder) and "model_best.pkl" in os.listdir(output_folder)
    ]

    dict_folder_to_score = {}
    for folder in output_folders:
        model_path = os.path.join(folder, "model_best.pkl")
        save_dict = torch.load(model_path)
        train_args = save_dict["args"]

        if train_args["dataset"] != inf_args.dataset:
            continue
        if train_args["test_envs"] != [inf_args.test_env]:
            continue
        if train_args["trial_seed"] != inf_args.trial_seed and inf_args.trial_seed != -1:
            continue
        score = misc.get_score(
            json.loads(save_dict["results"]),
            [inf_args.test_env])
        dict_folder_to_score[folder] = score

    if len(dict_folder_to_score) == 0:
        raise ValueError(f"No folders found for: {inf_args}")
    return dict_folder_to_score

def get_wa_results(
    good_checkpoints, dataset, data_names, data_splits, device
):
    wa_algorithm = algorithms_inference.WA(
        dataset.input_shape,
        dataset.num_classes,
    )
    for folder in good_checkpoints:
        save_dict = torch.load(os.path.join(folder, "model_best.pkl"))
        train_args = save_dict["args"]

        # load individual weights
        algorithm = algorithms_inference.ERM(
            dataset.input_shape, dataset.num_classes,
            save_dict["model_hparams"]
        )
        algorithm.load_state_dict(save_dict["model_dict"], strict=False)
        wa_algorithm.add_weights(algorithm.network)
        del algorithm

    wa_algorithm.to(device)
    wa_algorithm.eval()
    random.seed(train_args["seed"])
    np.random.seed(train_args["seed"])
    torch.manual_seed(train_args["seed"])
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    data_loaders = [
        FastDataLoader(
            dataset=split,
            batch_size=64,
            num_workers=dataset.N_WORKERS
        ) for split in data_splits
    ]

    data_evals = zip(data_names, data_loaders)
    dict_results = {}

    for name, loader in data_evals:
        print(f"Inference at {name}")
        dict_results[name + "_acc"] = misc.accuracy(wa_algorithm, loader, device)

    dict_results["length"] = len(good_checkpoints)
    return dict_results



def print_results(dict_results):
    results_keys = sorted(list(dict_results.keys()))
    misc.print_row(results_keys, colwidth=12)
    misc.print_row([dict_results[key] for key in results_keys], colwidth=12)


def main():
    inf_args = _get_args()
    device = "cuda" if torch.cuda.is_available() else "cpu"

    print(f"Begin DiWA for: {inf_args} with device: {device}")

    if inf_args.dataset in vars(datasets):
        dataset_class = vars(datasets)[inf_args.dataset]
        dataset = dataset_class(
            inf_args.data_dir, [inf_args.test_env], hparams={"data_augmentation": False}
        )
    else:
        raise NotImplementedError

    # load individual folders and their corresponding scores on train_out
    dict_folder_to_score = get_dict_folder_to_score(inf_args)

    # load data: test and optionally train_out for greedy weight selection
    data_splits, data_names = [], []
    dict_domain_to_filter = {"test": "full"}
    if inf_args.weight_selection == "greedy":
        assert inf_args.trial_seed != -1
        dict_domain_to_filter["train"] = "out"
    for domain in dict_domain_to_filter:
        _data_splits = create_splits(domain, inf_args, dataset, dict_domain_to_filter[domain])
        if domain == "train":
            data_splits.append(misc.MergeDataset(_data_splits))
        else:
            data_splits.append(_data_splits[0])
        data_names.append(domain)

    ## sort individual members by decreasing accuracy on train_out
    sorted_checkpoints = sorted(dict_folder_to_score.keys(), key=lambda x: dict_folder_to_score[x], reverse=True)
    for ckpt in sorted_checkpoints:
        print("Found", ckpt, dict_folder_to_score[ckpt])

    # compute score after weight averaging
    if inf_args.weight_selection == "greedy":
        # greedy weight selection
        selected_indexes = []
        best_result = -float("inf")
        dict_best_results = {}
        ## incrementally add them to the WA
        for i in range(0, len(sorted_checkpoints)):
            selected_indexes.append(i)
            selected_checkpoints = [sorted_checkpoints[index] for index in selected_indexes]

            ood_results = get_wa_results(
                selected_checkpoints, dataset, data_names, data_splits, device
            )
            ood_results["i"] = i
            ## accept only if the WA's train_out accuracy does not decrease
            if ood_results["train_acc"] >= best_result:
                dict_best_results = ood_results
                ood_results["accept"] = 1
                best_result = ood_results["train_acc"]
                print(f"Accepting index {i}")
            else:
                ood_results["accept"] = 0
                selected_indexes.pop(-1)
                print(f"Skipping index {i}")
            print_results(ood_results)

        ## print final scores
        dict_best_results["final"] = 1
        print_results(dict_best_results)

    elif inf_args.weight_selection == "uniform":
        dict_results = get_wa_results(
            sorted_checkpoints, dataset, data_names, data_splits, device
        )
        print_results(dict_results)

    else:
        raise ValueError(inf_args.weight_selection)


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/domainbed/scripts/list_top_hparams.py:
--------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

"""
Example usage:
python -u -m domainbed.scripts.list_top_hparams \
    --input_dir domainbed/misc/test_sweep_data --algorithm ERM \
    --dataset VLCS --test_env 0
"""

import collections


import argparse
import functools
import glob
import pickle
import itertools
import json
import os
import random
import sys

import numpy as np
import tqdm

from domainbed import datasets
from domainbed import algorithms
from domainbed.lib import misc, reporting
from domainbed import model_selection
from domainbed.lib.query import Q
# table helpers used by todo_rename() below
from domainbed.scripts.collect_results import format_mean, print_table
import warnings

def todo_rename(records, selection_method, latex):

    grouped_records = reporting.get_grouped_records(records).map(lambda group:
        { **group, "sweep_acc": selection_method.sweep_acc(group["records"]) }
    ).filter(lambda g: g["sweep_acc"] is not None)

    # read algorithm names and sort (predefined order)
    alg_names = Q(records).select("args.algorithm").unique()
    alg_names = ([n for n in algorithms.ALGORITHMS if n in alg_names] +
        [n for n in alg_names if n not in algorithms.ALGORITHMS])

    # read dataset names and sort (predefined order of datasets.DATASETS)
    dataset_names = Q(records).select("args.dataset").unique().sorted()
    dataset_names = [d for d in datasets.DATASETS if d in dataset_names]

    for dataset in dataset_names:
        if latex:
            print()
            print("\\subsubsection{{{}}}".format(dataset))
        test_envs = range(datasets.num_environments(dataset))

        table = [[None for _ in [*test_envs, "Avg"]] for _ in alg_names]
        for i, algorithm in enumerate(alg_names):
            means = []
            for j, test_env in enumerate(test_envs):
                trial_accs = (grouped_records
                    .filter_equals(
                        "dataset, algorithm, test_env",
                        (dataset, algorithm, test_env)
                    ).select("sweep_acc"))
                mean, err, table[i][j] = format_mean(trial_accs, latex)
                means.append(mean)
            if None in means:
                table[i][-1] = "X"
            else:
                table[i][-1] = "{:.1f}".format(sum(means) / len(means))

        col_labels = [
            "Algorithm",
            *datasets.get_dataset_class(dataset).ENVIRONMENTS,
            "Avg"
        ]
        header_text = (f"Dataset: {dataset}, "
            f"model selection method: {selection_method.name}")
        print_table(table, header_text, alg_names, list(col_labels),
            colwidth=20, latex=latex)

    # Print an "averages" table
    if latex:
        print()
        print("\\subsubsection{Averages}")

    table = [[None for _ in [*dataset_names, "Avg"]] for _ in alg_names]
    for i, algorithm in enumerate(alg_names):
        means = []
        for j, dataset in enumerate(dataset_names):
            trial_averages = (grouped_records
                .filter_equals("algorithm, dataset", (algorithm, dataset))
                .group("trial_seed")
                .map(lambda trial_seed, group:
                    group.select("sweep_acc").mean()
                )
            )
            mean, err, table[i][j] = format_mean(trial_averages, latex)
            means.append(mean)
        if None in means:
            table[i][-1] = "X"
        else:
            table[i][-1] = "{:.1f}".format(sum(means) / len(means))

    col_labels = ["Algorithm", *dataset_names, "Avg"]
    header_text = f"Averages, model selection method: {selection_method.name}"
    print_table(table, header_text, alg_names, col_labels, colwidth=25,
        latex=latex)

if __name__ == "__main__":
    np.set_printoptions(suppress=True)

    parser = argparse.ArgumentParser(
        description="Domain generalization testbed")
    parser.add_argument("--input_dir", required=True)
    parser.add_argument('--dataset', required=True)
    parser.add_argument('--algorithm', required=True)
    parser.add_argument('--test_env', type=int, required=True)
    args = parser.parse_args()

    records = reporting.load_records(args.input_dir)
    print("Total records:", len(records))
| 122 | records = reporting.get_grouped_records(records) 123 | records = records.filter( 124 | lambda r: 125 | r['dataset'] == args.dataset and 126 | r['algorithm'] == args.algorithm and 127 | r['test_env'] == args.test_env 128 | ) 129 | 130 | SELECTION_METHODS = [ 131 | model_selection.IIDAccuracySelectionMethod, 132 | model_selection.LeaveOneOutSelectionMethod, 133 | model_selection.OracleSelectionMethod, 134 | ] 135 | 136 | for selection_method in SELECTION_METHODS: 137 | print(f'Model selection: {selection_method.name}') 138 | 139 | for group in records: 140 | print(f"trial_seed: {group['trial_seed']}") 141 | best_hparams = selection_method.hparams_accs(group['records']) 142 | for run_acc, hparam_records in best_hparams: 143 | print(f"\t{run_acc}") 144 | for r in hparam_records: 145 | assert(r['hparams'] == hparam_records[0]['hparams']) 146 | print("\t\thparams:") 147 | for k, v in sorted(hparam_records[0]['hparams'].items()): 148 | print('\t\t\t{}: {}'.format(k, v)) 149 | print("\t\toutput_dirs:") 150 | output_dirs = hparam_records.select('args.output_dir').unique() 151 | for output_dir in output_dirs: 152 | print(f"\t\t\t{output_dir}") -------------------------------------------------------------------------------- /domainbed/scripts/save_images.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | """ 4 | Save some representative images from each dataset to disk. 
5 | """ 6 | import random 7 | import torch 8 | import argparse 9 | from domainbed import hparams_registry 10 | from domainbed import datasets 11 | import imageio 12 | import os 13 | from tqdm import tqdm 14 | 15 | if __name__ == '__main__': 16 | parser = argparse.ArgumentParser(description='Domain generalization') 17 | parser.add_argument('--data_dir', type=str) 18 | parser.add_argument('--output_dir', type=str) 19 | args = parser.parse_args() 20 | 21 | os.makedirs(args.output_dir, exist_ok=True) 22 | datasets_to_save = ['OfficeHome', 'TerraIncognita', 'DomainNet', 'RotatedMNIST', 'ColoredMNIST', 'SVIRO'] 23 | 24 | for dataset_name in tqdm(datasets_to_save): 25 | hparams = hparams_registry.default_hparams('ERM', dataset_name) 26 | dataset = datasets.get_dataset_class(dataset_name)( 27 | args.data_dir, 28 | list(range(datasets.num_environments(dataset_name))), 29 | hparams) 30 | for env_idx, env in enumerate(tqdm(dataset)): 31 | for i in tqdm(range(50)): 32 | idx = random.choice(list(range(len(env)))) 33 | x, y = env[idx] 34 | while y > 10: 35 | idx = random.choice(list(range(len(env)))) 36 | x, y = env[idx] 37 | if x.shape[0] == 2: 38 | x = torch.cat([x, torch.zeros_like(x)], dim=0)[:3,:,:] 39 | if x.min() < 0: 40 | mean = torch.tensor([0.485, 0.456, 0.406])[:,None,None] 41 | std = torch.tensor([0.229, 0.224, 0.225])[:,None,None] 42 | x = (x * std) + mean 43 | assert(x.min() >= 0) 44 | assert(x.max() <= 1) 45 | x = (x * 255.99) 46 | x = x.numpy().astype('uint8').transpose(1,2,0) 47 | imageio.imwrite( 48 | os.path.join(args.output_dir, 49 | f'{dataset_name}_env{env_idx}{dataset.ENVIRONMENTS[env_idx]}_{i}_idx{idx}_class{y}.png'), 50 | x) 51 | -------------------------------------------------------------------------------- /domainbed/scripts/sweep.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | 3 | """ 4 | Run sweeps 5 | """ 6 | 7 | import argparse 8 | import copy 9 | import hashlib 10 | import json 11 | import os 12 | import shutil 13 | import gc 14 | import torch 15 | import numpy as np 16 | 17 | from domainbed.lib import misc 18 | from domainbed import command_launchers 19 | from domainbed.scripts import train as train_script 20 | 21 | import tqdm 22 | import shlex 23 | 24 | class Job: 25 | NOT_LAUNCHED = 'Not launched' 26 | INCOMPLETE = 'Incomplete' 27 | DONE = 'Done' 28 | 29 | def __init__(self, train_args, sweep_output_dir): 30 | args_str = json.dumps(train_args, sort_keys=True) 31 | 32 | args_hash = hashlib.md5(args_str.encode('utf-8')).hexdigest() 33 | self.output_dir = os.path.join(sweep_output_dir, args_hash) 34 | 35 | self.train_args = copy.deepcopy(train_args) 36 | self.train_args['output_dir'] = self.output_dir 37 | command = ['python', '-m', 'domainbed.scripts.train'] 38 | for k, v in sorted(self.train_args.items()): 39 | if isinstance(v, list): 40 | v = ' '.join([str(v_) for v_ in v]) 41 | elif isinstance(v, str): 42 | v = shlex.quote(v) 43 | command.append(f'--{k} {v}') 44 | self.command_str = ' '.join(command) 45 | 46 | if os.path.exists(os.path.join(self.output_dir, 'done')): 47 | self.state = Job.DONE 48 | elif os.path.exists(self.output_dir): 49 | self.state = Job.INCOMPLETE 50 | else: 51 | self.state = Job.NOT_LAUNCHED 52 | 53 | def __str__(self): 54 | job_info = (self.train_args['dataset'], 55 | self.train_args['algorithm'], 56 | self.train_args['test_envs'], 57 | self.train_args['hparams_seed']) 58 | return '{}: {} {}'.format( 59 | self.state, 60 | self.output_dir, 61 | job_info) 62 | 63 | @staticmethod 64 | def launch(jobs, launcher_fn): 65 | print('Launching...') 66 | jobs = jobs.copy() 67 | np.random.shuffle(jobs) 68 | print('Making job directories:') 69 | for job in tqdm.tqdm(jobs, leave=False): 70 | os.makedirs(job.output_dir, exist_ok=True) 71 | commands = [job.command_str for job in jobs] 72 | 
launcher_fn(commands) 73 | print(f'Launched {len(jobs)} jobs!') 74 | 75 | @staticmethod 76 | def delete(jobs): 77 | print('Deleting...') 78 | for job in jobs: 79 | shutil.rmtree(job.output_dir) 80 | print(f'Deleted {len(jobs)} jobs!') 81 | 82 | def all_test_env_combinations(n): 83 | """ 84 | For a dataset with n >= 3 envs, return all combinations of 1 and 2 test 85 | envs. 86 | """ 87 | assert(n >= 3) 88 | for i in range(n): 89 | yield [i] 90 | for j in range(i+1, n): 91 | yield [i, j] 92 | 93 | def make_args_lp(args): 94 | dict_args_lp = {} 95 | dict_args_lp['output_dir'] = args.output_dir_lp 96 | dict_args_lp['dataset'] = args.dataset 97 | dict_args_lp['algorithm'] = args.algorithm 98 | dict_args_lp['test_envs'] = args.test_env 99 | dict_args_lp['data_dir'] = args.data_dir 100 | dict_args_lp["what_is_trainable"] = "classifier" 101 | raw_args_lp = [ 102 | item 103 | for key, value in dict_args_lp.items() 104 | for item in ["--" + key, str(value)] 105 | ] 106 | return raw_args_lp 107 | 108 | 109 | def make_args_list(args): 110 | args_list = [] 111 | for trial_seed in range(args.n_trials): 112 | for hparams_seed in range(args.n_hparams_from, args.n_hparams): 113 | train_args = {} 114 | train_args['trial_seed'] = trial_seed 115 | train_args['path_init'] = args.path_init 116 | train_args['dataset'] = args.dataset 117 | train_args['algorithm'] = args.algorithm 118 | train_args['test_envs'] = [args.test_env] 119 | train_args['hparams_seed'] = hparams_seed 120 | train_args['data_dir'] = args.data_dir 121 | train_args['aux_dir'] = args.aux_dir 122 | train_args['fusing_range'] = args.fusing_range 123 | 124 | train_args['seed'] = misc.seed_hash( 125 | args.dataset, 126 | args.algorithm, 127 | [args.test_env], 128 | hparams_seed, 129 | trial_seed) 130 | args_list.append(train_args) 131 | return args_list 132 | 133 | 134 | if __name__ == "__main__": 135 | parser = argparse.ArgumentParser(description='Run a sweep') 136 | parser.add_argument('command', choices=['launch', 
'delete']) 137 | parser.add_argument('--dataset', type=str, required=True) 138 | parser.add_argument('--test_env', type=int, required=True) 139 | parser.add_argument('--algorithm', default="ERM") 140 | parser.add_argument('--n_hparams_from', type=int, default=0) 141 | parser.add_argument('--n_hparams', type=int, default=20) 142 | parser.add_argument('--output_dir', type=str, required=True) 143 | parser.add_argument('--data_dir', type=str, required=True) 144 | parser.add_argument('--seed', type=int, default=0) 145 | parser.add_argument('--n_trials', type=int, default=3) 146 | 147 | # New args for ratatouille 148 | parser.add_argument('--aux_dir', type=str, default="") 149 | parser.add_argument('--fusing_range', type=float, default=-1) 150 | parser.add_argument('--output_dir_lp', type=str, default=None) 151 | 152 | args = parser.parse_args() 153 | 154 | # 1. LP procedure to initialize a shared classifier 155 | if args.output_dir_lp is None: 156 | args.path_init = "" 157 | else: 158 | if os.path.exists(os.path.join(args.output_dir_lp, 'done')): 159 | print('LP already done.') 160 | # elif os.path.isdir(args.output_dir_lp): 161 | # print("incomplete") 162 | else: 163 | print('Do LP.') 164 | raw_args_lp = make_args_lp(args) 165 | train_args_lp = train_script.parse_args(raw_args_lp) 166 | train_script.main(train_args_lp) 167 | print('Done LP.') 168 | gc.collect() 169 | torch.cuda.empty_cache() 170 | # be sure to free gpus memory 171 | args.path_init = os.path.join(args.output_dir_lp, "network_best.pkl") 172 | assert os.path.exists(args.path_init) 173 | print('Init path is ready.') 174 | 175 | # 2. 
Hyperparamerer sweep 176 | args_list = make_args_list(args) 177 | jobs = [Job(train_args, args.output_dir) for train_args in args_list] 178 | for job in jobs: 179 | print(job) 180 | print("{} jobs: {} done, {} incomplete, {} not launched.".format( 181 | len(jobs), 182 | len([j for j in jobs if j.state == Job.DONE]), 183 | len([j for j in jobs if j.state == Job.INCOMPLETE]), 184 | len([j for j in jobs if j.state == Job.NOT_LAUNCHED])) 185 | ) 186 | 187 | if args.command == 'launch': 188 | to_launch = [j for j in jobs if j.state in [Job.NOT_LAUNCHED, Job.INCOMPLETE]] 189 | print(f'About to launch {len(to_launch)} jobs.') 190 | Job.launch(to_launch, launcher_fn = command_launchers.multi_gpu_launcher) 191 | 192 | elif args.command == 'delete': 193 | to_delete = [j for j in jobs if j.state == Job.INCOMPLETE] 194 | print(f'About to delete {len(to_delete)} jobs.') 195 | Job.delete(to_delete) 196 | -------------------------------------------------------------------------------- /domainbed/scripts/train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | 3 | import argparse 4 | import collections 5 | import json 6 | import os 7 | import random 8 | import sys 9 | import time 10 | import math 11 | 12 | import numpy as np 13 | import PIL 14 | import torch 15 | import torchvision 16 | import torch.utils.data 17 | 18 | from domainbed import datasets 19 | from domainbed import hparams_registry 20 | from domainbed import algorithms 21 | from domainbed.lib import misc 22 | from domainbed.lib.fast_data_loader import InfiniteDataLoader, FastDataLoader 23 | 24 | 25 | def main(args): 26 | os.makedirs(args.output_dir, exist_ok=True) 27 | sys.stdout = misc.Tee(os.path.join(args.output_dir, 'out.txt')) 28 | sys.stderr = misc.Tee(os.path.join(args.output_dir, 'err.txt')) 29 | 30 | print("Environment:") 31 | print("\tPython: {}".format(sys.version.split(" ")[0])) 32 | print("\tPyTorch: {}".format(torch.__version__)) 33 | print("\tTorchvision: {}".format(torchvision.__version__)) 34 | print("\tCUDA: {}".format(torch.version.cuda)) 35 | print("\tCUDNN: {}".format(torch.backends.cudnn.version())) 36 | print("\tNumPy: {}".format(np.__version__)) 37 | print("\tPIL: {}".format(PIL.__version__)) 38 | 39 | print('Args:') 40 | for k, v in sorted(vars(args).items()): 41 | print('\t{}: {}'.format(k, v)) 42 | 43 | if args.hparams_seed == 0: 44 | hparams = hparams_registry.default_hparams(args.algorithm, args.dataset) 45 | else: 46 | hparams = hparams_registry.random_hparams(args.algorithm, args.dataset, 47 | misc.seed_hash(args.hparams_seed, args.trial_seed)) 48 | 49 | print('HParams:') 50 | for k, v in sorted(hparams.items()): 51 | print('\t{}: {}'.format(k, v)) 52 | 53 | random.seed(args.seed) 54 | np.random.seed(args.seed) 55 | torch.manual_seed(args.seed) 56 | torch.backends.cudnn.deterministic = True 57 | torch.backends.cudnn.benchmark = False 58 | 59 | if torch.cuda.is_available(): 60 | device = "cuda" 61 | else: 62 | device = "cpu" 63 | 64 | if args.dataset in vars(datasets): 65 | dataset = 
vars(datasets)[args.dataset](args.data_dir, args.test_envs, hparams) 66 | else: 67 | raise NotImplementedError 68 | 69 | # Split each env into an 'in-split' and an 'out-split'. We'll train on 70 | # each in-split except the test envs, and evaluate on all splits. 71 | in_splits = [] 72 | out_splits = [] 73 | for env_i, env in enumerate(dataset): 74 | out, in_ = misc.split_dataset(env, 75 | int(len(env)*args.holdout_fraction), 76 | misc.seed_hash(args.trial_seed, env_i)) 77 | in_splits.append(in_) 78 | out_splits.append(out) 79 | train_loaders = [InfiniteDataLoader( 80 | dataset=env, 81 | weights=None, 82 | batch_size=hparams['batch_size'], 83 | num_workers=dataset.N_WORKERS) 84 | for i, env in enumerate(in_splits) 85 | if i not in args.test_envs] 86 | eval_loaders = [FastDataLoader( 87 | dataset=env, 88 | batch_size=64, 89 | num_workers=dataset.N_WORKERS) 90 | for env in in_splits + out_splits] 91 | eval_loader_names = ['env{}_in'.format(i) 92 | for i in range(len(in_splits))] 93 | eval_loader_names += ['env{}_out'.format(i) 94 | for i in range(len(out_splits))] 95 | 96 | dict_featurizers_aux = {} 97 | if args.aux_dir not in ["", "none"]: 98 | if args.fusing_range >= 0: 99 | # interpolating multiple featurizers 100 | list_featurizers_aux = misc.get_list_featurizers_aux(args) 101 | list_kappas_aux = [math.exp(args.fusing_range * random.random()) for _ in list_featurizers_aux] 102 | dict_featurizers_aux = { 103 | featurizer_aux: kappa_aux / sum(list_kappas_aux) 104 | for featurizer_aux, kappa_aux in zip(list_featurizers_aux, list_kappas_aux) 105 | } 106 | else: 107 | # selecting only one single auxiliary featurizer 108 | dict_featurizers_aux = {misc.get_featurizer_aux(args): 1.} 109 | print(f"Dictionnary mapping featurizers to interpolating lambda: {dict_featurizers_aux}") 110 | 111 | algorithm = algorithms.get_algorithm_class(args.algorithm)( 112 | input_shape=dataset.input_shape, 113 | num_classes=dataset.num_classes, 114 | hparams=hparams, 115 | 
what_is_trainable=args.what_is_trainable, 116 | path_init=args.path_init, 117 | dict_featurizers_aux=dict_featurizers_aux) 118 | algorithm.to(device) 119 | 120 | train_minibatches_iterator = zip(*train_loaders) 121 | checkpoint_vals = collections.defaultdict(lambda: []) 122 | 123 | steps_per_epoch = min([len(env)/hparams['batch_size'] for env in in_splits]) 124 | n_steps = dataset.N_STEPS 125 | 126 | def save_checkpoint(results=None, suffix="best"): 127 | save_dict = { 128 | "args": vars(args), 129 | "model_input_shape": dataset.input_shape, 130 | "model_num_classes": dataset.num_classes, 131 | "model_hparams": hparams, 132 | "model_dict": algorithm.state_dict() 133 | } 134 | if results is not None: 135 | save_dict["results"] = results 136 | torch.save(save_dict, os.path.join(args.output_dir, "model_" + suffix + ".pkl")) 137 | torch.save(algorithm.network.state_dict(), os.path.join(args.output_dir, "network_" + suffix + ".pkl")) 138 | 139 | best_score = -float("inf") 140 | last_results_keys = None 141 | for step in range(0, n_steps): 142 | step_start_time = time.time() 143 | minibatches_device = [(x.to(device), y.to(device)) for x,y in next(train_minibatches_iterator)] 144 | step_vals = algorithm.update(minibatches_device) 145 | checkpoint_vals['step_time'].append(time.time() - step_start_time) 146 | 147 | for key, val in step_vals.items(): 148 | checkpoint_vals[key].append(val) 149 | 150 | if (step % dataset.CHECKPOINT_FREQ == 0) or (step == n_steps - 1): 151 | results = { 152 | 'step': step, 153 | 'epoch': step / steps_per_epoch, 154 | } 155 | 156 | for key, val in checkpoint_vals.items(): 157 | results[key] = np.mean(val) 158 | 159 | for name, loader, in zip(eval_loader_names, eval_loaders): 160 | acc = misc.accuracy(algorithm, loader, device) 161 | results[name+'_acc'] = acc 162 | 163 | results['mem_gb'] = torch.cuda.max_memory_allocated() / (1024.*1024.*1024.) 
164 | 165 | results_keys = sorted(results.keys()) 166 | if results_keys != last_results_keys: 167 | misc.print_row(results_keys, colwidth=12) 168 | last_results_keys = results_keys 169 | misc.print_row([results[key] for key in results_keys], 170 | colwidth=12) 171 | 172 | results.update({ 173 | 'hparams': hparams, 174 | 'args': vars(args) 175 | }) 176 | 177 | with open(os.path.join(args.output_dir, 'results.jsonl'), 'a') as f: 178 | f.write(json.dumps(results, sort_keys=True, default=misc.np_encoder) + "\n") 179 | 180 | ## DiWA ## 181 | current_score = misc.get_score(results, args.test_envs) 182 | if current_score > best_score: 183 | best_score = current_score 184 | print(f"Saving new best score at step: {step} at path: model_best.pkl") 185 | save_checkpoint( 186 | results=json.dumps(results, sort_keys=True, default=misc.np_encoder), 187 | ) 188 | algorithm.to(device) 189 | 190 | checkpoint_vals = collections.defaultdict(lambda: []) 191 | 192 | # saving the last featurizer's weights 193 | torch.save(algorithm.featurizer.state_dict(), os.path.join(args.output_dir, "featurizer_last.pkl")) 194 | with open(os.path.join(args.output_dir, 'done'), 'w') as f: 195 | f.write('done') 196 | algorithm.cpu() 197 | 198 | def parse_args(raw_args=None): 199 | parser = argparse.ArgumentParser(description='Domain generalization') 200 | parser.add_argument('--data_dir', type=str) 201 | parser.add_argument('--output_dir', type=str, required=True) 202 | parser.add_argument('--dataset', type=str) 203 | parser.add_argument('--test_envs', type=int, nargs='+', default=[]) 204 | parser.add_argument('--algorithm', type=str, default="ERM") 205 | parser.add_argument('--hparams_seed', type=int, default=0, 206 | help='Seed for random hparams (0 means "default hparams")') 207 | parser.add_argument('--trial_seed', type=int, default=0, 208 | help='Trial number (used for seeding split_dataset and ' 209 | 'random_hparams).') 210 | parser.add_argument('--seed', type=int, default=0, 211 | help='Seed for 
everything else') 212 | parser.add_argument('--holdout_fraction', type=float, default=0.2) 213 | 214 | # New args for ratatouille 215 | parser.add_argument('--what_is_trainable', type=str, default="all") 216 | parser.add_argument('--path_init', type=str, default="") 217 | parser.add_argument('--aux_dir', type=str, default="") 218 | parser.add_argument('--fusing_range', type=float, default=-1) 219 | args = parser.parse_args(raw_args) 220 | return args 221 | 222 | 223 | if __name__ == "__main__": 224 | args = parse_args() 225 | main(args) 226 | --------------------------------------------------------------------------------
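The greedy branch of `domainbed/scripts/inference.py` adds checkpoints to the weight average (WA) one at a time. It keeps each checkpoint only if the averaged model's `train_acc` does not drop. The following framework-free sketch shows that control flow. It assumes each checkpoint is a flat list of floats. It also uses a hypothetical `score_fn` in place of `get_wa_results(...)["train_acc"]`, so it illustrates the loop rather than reproducing the repository's implementation.

```python
from typing import Callable, List, Sequence, Tuple


def average_weights(checkpoints: Sequence[Sequence[float]]) -> List[float]:
    """Uniform weight averaging: element-wise mean of flat parameter vectors."""
    n = len(checkpoints)
    return [sum(values) / n for values in zip(*checkpoints)]


def greedy_weight_selection(
    sorted_checkpoints: Sequence[Sequence[float]],
    score_fn: Callable[[List[float]], float],
) -> Tuple[List[int], float]:
    """Add checkpoints in order; keep a candidate only if the averaged
    model's score does not decrease (the same ">=" test as inference.py)."""
    selected: List[int] = []
    best_score = -float("inf")
    for i in range(len(sorted_checkpoints)):
        selected.append(i)
        averaged = average_weights([sorted_checkpoints[j] for j in selected])
        score = score_fn(averaged)
        if score >= best_score:
            best_score = score   # accept candidate i
        else:
            selected.pop()       # reject candidate i and restore the previous soup
    return selected, best_score


# Toy usage: the hypothetical score is the negative squared distance to [1.0, 1.0].
checkpoints = [[1.5, 0.5], [5.0, 5.0], [0.5, 1.5]]
selected, best = greedy_weight_selection(
    checkpoints, lambda w: -sum((x - 1.0) ** 2 for x in w)
)
```

The comparison uses `>=`, so a checkpoint that leaves the score unchanged is kept. Candidates are visited in the order of `sorted_checkpoints`, so the final selection depends on how the checkpoints were sorted beforehand.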
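`save_images.py` converts an ImageNet-normalized tensor back to 8-bit pixels in three steps. It computes `x * std + mean`, multiplies by 255.99, truncates to `uint8`, and transposes from channels-first to channels-last. The sketch below reproduces that arithmetic on plain nested lists so it runs without PyTorch. It assumes a 3-channel input normalized with the constants shown. The script applies the conversion only when `x.min() < 0`, and this sketch does not cover the 2-channel ColoredMNIST padding.

```python
from typing import List, Sequence

# ImageNet channel statistics, the same constants used in save_images.py.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def to_uint8_hwc(x_chw: Sequence[Sequence[Sequence[float]]]) -> List[List[List[int]]]:
    """Undo ImageNet normalization on a 3 x H x W image and return H x W x 3 ints."""
    channels, height, width = len(x_chw), len(x_chw[0]), len(x_chw[0][0])
    out = [[[0] * channels for _ in range(width)] for _ in range(height)]
    for c in range(channels):
        for h in range(height):
            for w in range(width):
                # Inverse of the normalization (v - mean) / std.
                v = x_chw[c][h][w] * IMAGENET_STD[c] + IMAGENET_MEAN[c]
                if not 0.0 <= v <= 1.0:
                    raise ValueError(f"pixel {v} outside [0, 1] after de-normalization")
                # Scale by 255.99 and truncate, so 1.0 maps to 255 rather than 256.
                out[h][w][c] = int(v * 255.99)
    return out
```

The factor 255.99 (instead of 256) keeps the brightest value inside the `uint8` range. The factor 255 would also work, but then only an exact 1.0 maps to 255.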
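`Job.__init__` in `sweep.py` names each run's output directory after an MD5 hash of its JSON-serialized arguments. The hash is computed before `output_dir` itself is added. It then builds a `python -m domainbed.scripts.train` command from the sorted arguments. The two helpers below restate that logic for illustration only; `job_output_dir` and `job_command` are hypothetical names, not functions in the repository.

```python
import hashlib
import json
import os
import shlex
from typing import Dict


def job_output_dir(train_args: Dict[str, object], sweep_output_dir: str) -> str:
    """Directory named after the MD5 of the arguments serialized with sorted keys."""
    args_str = json.dumps(train_args, sort_keys=True)
    return os.path.join(sweep_output_dir, hashlib.md5(args_str.encode("utf-8")).hexdigest())


def job_command(train_args: Dict[str, object], output_dir: str) -> str:
    """Build the command string in the same way as Job.__init__."""
    args = dict(train_args, output_dir=output_dir)
    command = ["python", "-m", "domainbed.scripts.train"]
    for key, value in sorted(args.items()):
        if isinstance(value, list):
            value = " ".join(str(v) for v in value)   # nargs='+' flags, e.g. --test_envs 0 1
        elif isinstance(value, str):
            value = shlex.quote(value)                # protect paths containing spaces
        command.append(f"--{key} {value}")
    return " ".join(command)
```

`json.dumps(..., sort_keys=True)` canonicalizes the key order. As a result, re-running a sweep with the same arguments maps every job to the same directory. This is what lets `Job` report runs as `Done` or `Incomplete`. MD5 serves only as a deterministic name here, not as a security measure.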
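For the linear-probing (LP) step, `sweep.py` calls `make_args_lp` to flatten a dictionary into `["--key", "value", ...]`. It passes that list to `train.parse_args(raw_args)` and runs `train.main` in-process. The sketch below mimics that round trip. `parse_train_args` is a hypothetical, reduced copy of `train.py`'s parser that declares only the flags used here.

```python
import argparse
from typing import Dict, List, Optional


def dict_to_argv(options: Dict[str, object]) -> List[str]:
    """Flatten {"key": value} into ["--key", "value", ...], as make_args_lp does."""
    return [item for key, value in options.items() for item in ("--" + key, str(value))]


def parse_train_args(raw_args: Optional[List[str]] = None) -> argparse.Namespace:
    """Hypothetical subset of train.py's parse_args (only the flags used here)."""
    parser = argparse.ArgumentParser(description="Domain generalization")
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument("--dataset", type=str)
    parser.add_argument("--test_envs", type=int, nargs="+", default=[])
    parser.add_argument("--what_is_trainable", type=str, default="all")
    parser.add_argument("--path_init", type=str, default="")
    return parser.parse_args(raw_args)


argv = dict_to_argv({
    "output_dir": "lp_out",
    "dataset": "PACS",
    "test_envs": 0,                      # a single int, like args.test_env in sweep.py
    "what_is_trainable": "classifier",   # LP: only the classifier is trained
})
lp_args = parse_train_args(argv)
```

Because every value goes through `str(value)`, this flattening only works for scalar values. A list such as `[0, 1]` would become the single token `'[0, 1]'`, which `type=int` rejects. A multi-environment value would therefore need to be expanded into separate tokens first.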
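In `train.py`, a non-negative `--fusing_range` r gives each auxiliary featurizer k an unnormalized weight κ_k = exp(r · u_k), where u_k is drawn from `random.random()`. The weights are then normalized to λ_k = κ_k / Σ_j κ_j. The function below restates that computation in isolation. Featurizers are represented by hypothetical string names, and an explicit `random.Random` is passed in for reproducibility. The script itself uses the global `random` module, which `main` seeds with `args.seed` beforehand.

```python
import math
import random
from typing import Dict, Optional, Sequence


def sample_fusing_lambdas(
    featurizers: Sequence[str],
    fusing_range: float,
    rng: Optional[random.Random] = None,
) -> Dict[str, float]:
    """Draw kappa_k = exp(fusing_range * u_k) with u_k ~ U[0, 1), then normalize to sum to 1."""
    rng = rng or random.Random()
    kappas = [math.exp(fusing_range * rng.random()) for _ in featurizers]
    total = sum(kappas)
    return {name: kappa / total for name, kappa in zip(featurizers, kappas)}
```

With `fusing_range = 0`, every κ equals 1 and the coefficients are uniform. Larger values let the ratio between the largest and smallest coefficient grow toward, but never reach, exp(fusing_range), since u_k lies in [0, 1). A negative `fusing_range` skips this sampling entirely: the script then uses a single auxiliary featurizer with weight 1.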