├── __init__.py ├── exps ├── __init__.py └── vis_attention.py ├── domainbed ├── lib │ ├── __init__.py │ ├── writers.py │ ├── reporting.py │ ├── fast_data_loader.py │ ├── wide_resnet.py │ ├── logger.py │ ├── query.py │ ├── misc.py │ └── swa_utils.py ├── losses │ ├── __init__.py │ ├── variance_loss.py │ ├── importance_loss.py │ └── kullback_leibler_divergence.py ├── __init__.py ├── scripts │ ├── __init__.py │ ├── save_images.py │ ├── list_top_hparams.py │ ├── collect_results.py │ ├── sweep.py │ ├── download.py │ └── train.py ├── test │ ├── __init__.py │ ├── lib │ │ ├── __init__.py │ │ ├── test_misc.py │ │ └── test_query.py │ ├── scripts │ │ ├── __init__.py │ │ ├── test_train.py │ │ ├── test_collect_results.py │ │ └── test_sweep.py │ ├── helpers.py │ ├── test_hparams_registry.py │ ├── test_networks.py │ ├── test_models.py │ ├── test_datasets.py │ └── test_model_selection.py ├── command_launchers.py ├── backbones.py ├── evaluator.py ├── model_selection.py ├── registry.py ├── networks.py ├── hparams_registry.py ├── datasets.py └── algorithms.py ├── requirements.txt ├── LICENSE └── README.md /__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /exps/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /domainbed/lib/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /domainbed/losses/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /domainbed/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | -------------------------------------------------------------------------------- /domainbed/scripts/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | -------------------------------------------------------------------------------- /domainbed/test/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | 4 | -------------------------------------------------------------------------------- /domainbed/test/lib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | 4 | -------------------------------------------------------------------------------- /domainbed/test/scripts/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | 3 | 4 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | gdown==4.2.0 2 | numpy>=1.22 3 | Pillow==9.0.1 4 | prettytable==2.1.0 5 | sconf==0.2.3 6 | tensorboardX==2.5 -------------------------------------------------------------------------------- /domainbed/losses/variance_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def variance_loss(x): 4 | mean_weightings = x.mean(dim=1).mean(dim=0) 5 | std = torch.std(x, dim=1) 6 | coefficients_of_variation = std / mean_weightings 7 | var_loss = -(coefficients_of_variation.sum(dim=0) / x.shape[0]) 8 | return var_loss 9 | -------------------------------------------------------------------------------- /domainbed/losses/importance_loss.py: -------------------------------------------------------------------------------- 1 | def importance(x): 2 | return x.sum(dim=0) 3 | 4 | 5 | def squared_coefficient_of_variation(x): 6 | x = x.float() 7 | cv_squared = x.var() / (x.mean()**2 + 1e-10) 8 | return cv_squared 9 | 10 | # Maximum loss = num_experts 11 | def importance_loss(x): 12 | imp = importance(x) 13 | cv_squared = squared_coefficient_of_variation(imp) 14 | return cv_squared -------------------------------------------------------------------------------- /domainbed/losses/kullback_leibler_divergence.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from losses.importance_loss import importance 5 | 6 | def kl_divergence(x, num_experts): 7 | p = (importance(x) + 1e-10) / x.shape[0] 8 | q = torch.full(size=(num_experts, ), fill_value= 1.0 / num_experts, device=p.device) 9 | divergence = torch.sum(p * torch.log(p / q)) 10 | return divergence 11 | -------------------------------------------------------------------------------- /domainbed/test/helpers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | import torch 4 | 5 | DEBUG_DATASETS = ['Debug224'] 6 | DEBUG_ALGORITHMS = ['SFMOE'] 7 | 8 | def make_minibatches(dataset, batch_size): 9 | """Test helper to make a minibatches array like train.py""" 10 | minibatches = [] 11 | for env in dataset: 12 | X = torch.stack([env[i][0] for i in range(batch_size)]).cuda() 13 | y = torch.stack([torch.as_tensor(env[i][1]) 14 | for i in range(batch_size)]).cuda() 15 | minibatches.append((X, y)) 16 | return minibatches 17 | -------------------------------------------------------------------------------- /domainbed/test/lib/test_misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | 3 | import unittest 4 | from domainbed.lib import misc 5 | 6 | class TestMisc(unittest.TestCase): 7 | 8 | def test_make_weights_for_balanced_classes(self): 9 | dataset = [('A', 0), ('B', 1), ('C', 0), ('D', 2), ('E', 3), ('F', 0)] 10 | result = misc.make_weights_for_balanced_classes(dataset) 11 | self.assertEqual(result.sum(), 1) 12 | self.assertEqual(result[0], result[2]) 13 | self.assertEqual(result[1], result[3]) 14 | self.assertEqual(3 * result[0], result[1]) 15 | -------------------------------------------------------------------------------- /domainbed/lib/writers.py: -------------------------------------------------------------------------------- 1 | class Writer: 2 | def add_scalars(self, tag_scalar_dic, global_step): 3 | raise NotImplementedError() 4 | 5 | def add_scalars_with_prefix(self, tag_scalar_dic, global_step, prefix): 6 | tag_scalar_dic = {prefix + k: v for k, v in tag_scalar_dic.items()} 7 | self.add_scalars(tag_scalar_dic, global_step) 8 | 9 | 10 | class TBWriter(Writer): 11 | def __init__(self, dir_path): 12 | from tensorboardX import SummaryWriter 13 | 14 | self.writer = SummaryWriter(dir_path, flush_secs=30) 15 | 16 | def add_scalars(self, tag_scalar_dic, global_step): 17 | for tag, scalar in tag_scalar_dic.items(): 18 | self.writer.add_scalar(tag, scalar, global_step) 19 | 20 | 21 | def get_writer(dir_path): 22 | """ 23 | Args: 24 | dir_path: tb dir 25 | """ 26 | writer = TBWriter(dir_path) 27 | 28 | return writer 29 | -------------------------------------------------------------------------------- /domainbed/test/test_hparams_registry.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | import unittest 4 | import itertools 5 | 6 | from domainbed import hparams_registry 7 | from domainbed import datasets 8 | from domainbed import algorithms 9 | 10 | from parameterized import parameterized 11 | 12 | class TestHparamsRegistry(unittest.TestCase): 13 | 14 | @parameterized.expand(itertools.product(algorithms.ALGORITHMS, datasets.DATASETS)) 15 | def test_random_hparams_deterministic(self, algorithm_name, dataset_name): 16 | """Test that hparams_registry.random_hparams is deterministic""" 17 | a = hparams_registry.random_hparams(algorithm_name, dataset_name, 0) 18 | b = hparams_registry.random_hparams(algorithm_name, dataset_name, 0) 19 | self.assertEqual(a.keys(), b.keys()) 20 | for key in a.keys(): 21 | self.assertEqual(a[key], b[key], key) 22 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License 2 | 3 | Copyright (c) Facebook, Inc. and its affiliates. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 
8 | 9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 10 | -------------------------------------------------------------------------------- /domainbed/test/test_networks.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | import argparse 4 | import itertools 5 | import json 6 | import os 7 | import subprocess 8 | import sys 9 | import time 10 | import unittest 11 | import uuid 12 | 13 | import torch 14 | 15 | from domainbed import datasets 16 | from domainbed import hparams_registry 17 | from domainbed import algorithms 18 | from domainbed import networks 19 | from domainbed.test import helpers 20 | 21 | from parameterized import parameterized 22 | 23 | 24 | class TestNetworks(unittest.TestCase): 25 | 26 | @parameterized.expand(itertools.product(helpers.DEBUG_DATASETS)) 27 | def test_featurizer(self, dataset_name): 28 | """Test that Featurizer() returns a module which can take a 29 | correctly-sized input and return a correctly-sized output.""" 30 | batch_size = 8 31 | hparams = hparams_registry.default_hparams('EIL', dataset_name) 32 | dataset = datasets.get_dataset_class(dataset_name)('', [], hparams) 33 | input_ = helpers.make_minibatches(dataset, batch_size)[0][0] 34 | input_shape = dataset.input_shape 35 | algorithm = networks.Featurizer(input_shape, hparams).cuda() 36 | output = algorithm(input_) 37 | self.assertEqual(list(output.shape), [batch_size, algorithm.n_outputs]) 38 | -------------------------------------------------------------------------------- /domainbed/lib/reporting.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | import collections 4 | 5 | import json 6 | import os 7 | 8 | import tqdm 9 | 10 | from domainbed.lib.query import Q 11 | 12 | def load_records(path): 13 | records = [] 14 | for i, subdir in tqdm.tqdm(list(enumerate(os.listdir(path))), 15 | ncols=80, 16 | leave=False): 17 | results_path = os.path.join(path, subdir, "results.jsonl") 18 | try: 19 | with open(results_path, "r") as f: 20 | for line in f: 21 | records.append(json.loads(line[:-1])) 22 | except IOError: 23 | pass 24 | 25 | return Q(records) 26 | 27 | def get_grouped_records(records): 28 | """Group records by (trial_seed, dataset, algorithm, test_env). 
Because 29 | records can have multiple test envs, a given record may appear in more than 30 | one group.""" 31 | result = collections.defaultdict(lambda: []) 32 | for r in records: 33 | for test_env in r["args"]["test_envs"]: 34 | group = (r["args"]["trial_seed"], 35 | r["args"]["dataset"], 36 | r["args"]["algorithm"], 37 | test_env) 38 | result[group].append(r) 39 | return Q([{"trial_seed": t, "dataset": d, "algorithm": a, "test_env": e, 40 | "records": Q(r)} for (t,d,a,e),r in result.items()]) 41 | -------------------------------------------------------------------------------- /domainbed/test/test_models.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | """Unit tests.""" 4 | 5 | import argparse 6 | import itertools 7 | import json 8 | import os 9 | import subprocess 10 | import sys 11 | import time 12 | import unittest 13 | import uuid 14 | 15 | import torch 16 | 17 | from domainbed import datasets 18 | from domainbed import hparams_registry 19 | from domainbed import algorithms 20 | from domainbed import networks 21 | from domainbed.test import helpers 22 | 23 | from parameterized import parameterized 24 | 25 | 26 | class TestAlgorithms(unittest.TestCase): 27 | 28 | @parameterized.expand(itertools.product(helpers.DEBUG_DATASETS, helpers.DEBUG_ALGORITHMS)) 29 | def test_init_update_predict(self, dataset_name, algorithm_name): 30 | """Test that a given algorithm inits, updates and predicts without raising 31 | errors.""" 32 | batch_size = 8 33 | hparams = hparams_registry.default_hparams(algorithm_name, dataset_name) 34 | dataset = datasets.get_dataset_class(dataset_name)('', [], hparams) 35 | minibatches = helpers.make_minibatches(dataset, batch_size) 36 | algorithm_class = algorithms.get_algorithm_class(algorithm_name) 37 | algorithm = algorithm_class(dataset.input_shape, dataset.num_classes, len(dataset), 38 | hparams).cuda() 39 | for _ in range(3): 40 | self.assertIsNotNone(algorithm.update(minibatches)) 41 | algorithm.eval() 42 | self.assertEqual(list(algorithm.predict(minibatches[0][0]).shape), 43 | [batch_size, dataset.num_classes]) 44 | -------------------------------------------------------------------------------- /domainbed/test/test_datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | """Unit tests.""" 4 | 5 | import argparse 6 | import itertools 7 | import json 8 | import os 9 | import subprocess 10 | import sys 11 | import time 12 | import unittest 13 | import uuid 14 | 15 | import torch 16 | 17 | from domainbed import datasets 18 | from domainbed import hparams_registry 19 | from domainbed import algorithms 20 | from domainbed import networks 21 | 22 | from parameterized import parameterized 23 | 24 | from domainbed.test import helpers 25 | 26 | class TestDatasets(unittest.TestCase): 27 | 28 | @parameterized.expand(itertools.product(datasets.DATASETS)) 29 | @unittest.skipIf('DATA_DIR' not in os.environ, 'needs DATA_DIR environment ' 30 | 'variable') 31 | def test_dataset_erm(self, dataset_name): 32 | """ 33 | Test that ERM can complete one step on a given dataset without raising 34 | an error. 35 | Also test that num_environments() works correctly. 
36 | """ 37 | batch_size = 8 38 | hparams = hparams_registry.default_hparams('ERM', dataset_name) 39 | dataset = datasets.get_dataset_class(dataset_name)( 40 | os.environ['DATA_DIR'], [], hparams) 41 | self.assertEqual(datasets.num_environments(dataset_name), 42 | len(dataset)) 43 | algorithm = algorithms.get_algorithm_class('ERM')( 44 | dataset.input_shape, 45 | dataset.num_classes, 46 | len(dataset), 47 | hparams).cuda() 48 | minibatches = helpers.make_minibatches(dataset, batch_size) 49 | algorithm.update(minibatches) 50 | -------------------------------------------------------------------------------- /domainbed/test/lib/test_query.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | import unittest 4 | from domainbed.lib.query import Q, make_selector_fn 5 | 6 | class TestQuery(unittest.TestCase): 7 | def test_everything(self): 8 | numbers = Q([1, 4, 2]) 9 | people = Q([ 10 | {'name': 'Bob', 'age': 40}, 11 | {'name': 'Alice', 'age': 20}, 12 | {'name': 'Bob', 'age': 10} 13 | ]) 14 | 15 | self.assertEqual(numbers.select(lambda x: 2*x), [2, 8, 4]) 16 | 17 | self.assertEqual(numbers.min(), 1) 18 | self.assertEqual(numbers.max(), 4) 19 | self.assertEqual(numbers.mean(), 7/3) 20 | 21 | self.assertEqual(people.select('name'), ['Bob', 'Alice', 'Bob']) 22 | 23 | self.assertEqual( 24 | set(people.group('name').map(lambda _,g: g.select('age').mean())), 25 | set([25, 20]) 26 | ) 27 | 28 | self.assertEqual(people.argmax('age'), people[0]) 29 | 30 | def test_group_by_unhashable(self): 31 | jobs = Q([ 32 | {'hparams': {1:2}, 'score': 3}, 33 | {'hparams': {1:2}, 'score': 4}, 34 | {'hparams': {2:4}, 'score': 5} 35 | ]) 36 | grouped = jobs.group('hparams') 37 | self.assertEqual(grouped, [ 38 | ({1:2}, [jobs[0], jobs[1]]), 39 | ({2:4}, [jobs[2]]) 40 | ]) 41 | 42 | def test_comma_selector(self): 43 | struct = {'a': {'b': 1}, 'c': 2} 44 | fn = make_selector_fn('a.b,c') 45 | self.assertEqual(fn(struct), (1, 2)) 46 | 47 | def test_unique(self): 48 | numbers = Q([1,2,1,3,2,1,3,1,2,3]) 49 | self.assertEqual(numbers.unique(), [1,2,3]) 50 | -------------------------------------------------------------------------------- /domainbed/test/scripts/test_train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | 3 | # import argparse 4 | # import itertools 5 | import json 6 | import os 7 | import subprocess 8 | # import sys 9 | # import time 10 | import unittest 11 | import uuid 12 | 13 | import torch 14 | 15 | # import datasets 16 | # import hparams_registry 17 | # import algorithms 18 | # import networks 19 | # from parameterized import parameterized 20 | 21 | # import test.helpers 22 | 23 | class TestTrain(unittest.TestCase): 24 | 25 | @unittest.skipIf('DATA_DIR' not in os.environ, 'needs DATA_DIR environment ' 26 | 'variable') 27 | def test_end_to_end(self): 28 | """Test that train.py successfully completes one step""" 29 | output_dir = os.path.join('/tmp', str(uuid.uuid4())) 30 | os.makedirs(output_dir, exist_ok=True) 31 | 32 | subprocess.run(f'python -m domainbed.scripts.train --dataset RotatedMNIST ' 33 | f'--data_dir={os.environ["DATA_DIR"]} --output_dir={output_dir} ' 34 | f'--steps=501', shell=True) 35 | 36 | with open(os.path.join(output_dir, 'results.jsonl')) as f: 37 | lines = [l[:-1] for l in f] 38 | last_epoch = json.loads(lines[-1]) 39 | self.assertEqual(last_epoch['step'], 500) 40 | # Conservative values; anything lower and something's likely wrong. 41 | self.assertGreater(last_epoch['env0_in_acc'], 0.80) 42 | self.assertGreater(last_epoch['env1_in_acc'], 0.95) 43 | self.assertGreater(last_epoch['env2_in_acc'], 0.95) 44 | self.assertGreater(last_epoch['env3_in_acc'], 0.95) 45 | self.assertGreater(last_epoch['env3_in_acc'], 0.95) 46 | 47 | with open(os.path.join(output_dir, 'out.txt')) as f: 48 | text = f.read() 49 | self.assertTrue('500' in text) 50 | -------------------------------------------------------------------------------- /domainbed/scripts/save_images.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | """ 4 | Save some representative images from each dataset to disk. 
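Usage: python -m domainbed.scripts.save_images --data_dir=<data_dir> --output_dir=<output_dir>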
5 | """ 6 | import random 7 | import torch 8 | import argparse 9 | from domainbed import hparams_registry 10 | from domainbed import datasets 11 | import imageio 12 | import os 13 | from tqdm import tqdm 14 | 15 | if __name__ == '__main__': 16 | parser = argparse.ArgumentParser(description='Domain generalization') 17 | parser.add_argument('--data_dir', type=str) 18 | parser.add_argument('--output_dir', type=str) 19 | args = parser.parse_args() 20 | 21 | os.makedirs(args.output_dir, exist_ok=True) 22 | datasets_to_save = ['OfficeHome', 'TerraIncognita', 'DomainNet', 'RotatedMNIST', 'ColoredMNIST', 'SVIRO'] 23 | 24 | for dataset_name in tqdm(datasets_to_save): 25 | hparams = hparams_registry.default_hparams('ERM', dataset_name) 26 | dataset = datasets.get_dataset_class(dataset_name)( 27 | args.data_dir, 28 | list(range(datasets.num_environments(dataset_name))), 29 | hparams) 30 | for env_idx, env in enumerate(tqdm(dataset)): 31 | for i in tqdm(range(50)): 32 | idx = random.choice(list(range(len(env)))) 33 | x, y = env[idx] 34 | while y > 10: 35 | idx = random.choice(list(range(len(env)))) 36 | x, y = env[idx] 37 | if x.shape[0] == 2: 38 | x = torch.cat([x, torch.zeros_like(x)], dim=0)[:3,:,:] 39 | if x.min() < 0: 40 | mean = torch.tensor([0.485, 0.456, 0.406])[:,None,None] 41 | std = torch.tensor([0.229, 0.224, 0.225])[:,None,None] 42 | x = (x * std) + mean 43 | assert(x.min() >= 0) 44 | assert(x.max() <= 1) 45 | x = (x * 255.99) 46 | x = x.numpy().astype('uint8').transpose(1,2,0) 47 | imageio.imwrite( 48 | os.path.join(args.output_dir, 49 | f'{dataset_name}_env{env_idx}{dataset.ENVIRONMENTS[env_idx]}_{i}_idx{idx}_class{y}.png'), 50 | x) 51 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Welcome to Generalizable Mixture-of-Experts for Domain Generalization 2 | 3 | 🔥 Our paper [Sparse Mixture-of-Experts are Domain Generalizable Learners](https://openreview.net/forum?id=RecZ9nB9Q4) has officially been accepted as ICLR 2023 for Oral presentation. 4 | 5 | 🔥 GMoE-S/16 model currently [ranks top place](https://paperswithcode.com/sota/domain-generalization-on-domainnet) among multiple DG datasets without extra pre-training data. (Our GMoE-S/16 is initilized from [DeiT-S/16](https://github.com/facebookresearch/deit/blob/main/README_deit.md), which was only pretrained on ImageNet-1K 2012) 6 | 7 | Wondering why GMoEs have astonishing performance? 🤯 Let's investigate the generalization ability of model architecture itself and see the great potentials of Sparse Mixture-of-Experts (MoE) architecture. 8 | 9 | ### Preparation 10 | 11 | ```sh 12 | pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu116 13 | 14 | python3 -m pip uninstall tutel -y 15 | python3 -m pip install --user --upgrade git+https://github.com/microsoft/tutel@main 16 | 17 | pip3 install -r requirements.txt 18 | ``` 19 | 20 | ### Datasets 21 | 22 | ```sh 23 | python3 -m domainbed.scripts.download \ 24 | --data_dir=./domainbed/data 25 | ``` 26 | 27 | ### Environments 28 | 29 | Environment details used in paper for the main experiments on Nvidia V100 GPU. 
30 | 31 | ```shell 32 | Environment: 33 | Python: 3.9.12 34 | PyTorch: 1.12.0+cu116 35 | Torchvision: 0.13.0+cu116 36 | CUDA: 11.6 37 | CUDNN: 8302 38 | NumPy: 1.19.5 39 | PIL: 9.2.0 40 | ``` 41 | 42 | ## Start Training 43 | 44 | Train a model: 45 | 46 | ```sh 47 | python3 -m domainbed.scripts.train\ 48 | --data_dir=./domainbed/data/OfficeHome/\ 49 | --algorithm GMOE\ 50 | --dataset OfficeHome\ 51 | --test_env 2 52 | ``` 53 | 54 | ## Hyper-params 55 | 56 | We put hparams for each dataset into 57 | ```sh 58 | ./domainbed/hparams_registry.py 59 | ``` 60 | 61 | Basically, you just need to choose `--algorithm` and `--dataset`. The optimal hparams will be loaded accordingly. 62 | 63 | ## License 64 | 65 | This source code is released under the MIT license, included [here](LICENSE). 66 | 67 | ## Acknowledgement 68 | 69 | The MoE module is built on [Tutel MoE](https://github.com/microsoft/tutel). 70 | -------------------------------------------------------------------------------- /domainbed/command_launchers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | """ 4 | A command launcher launches a list of commands on a cluster; implement your own 5 | launcher to add support for your cluster. We've provided an example launcher 6 | which runs all commands serially on the local machine. 7 | """ 8 | import os 9 | import subprocess 10 | import time 11 | 12 | import torch.cuda 13 | 14 | 15 | def local_launcher(commands): 16 | """Launch commands serially on the local machine.""" 17 | for cmd in commands: 18 | subprocess.call(cmd, shell=True) 19 | 20 | 21 | def dummy_launcher(commands): 22 | """ 23 | Doesn't run anything; instead, prints each command. 24 | Useful for testing. 25 | """ 26 | for cmd in commands: 27 | print(f'Dummy launcher: {cmd}') 28 | 29 | 30 | def multi_gpu_launcher(commands): 31 | """ 32 | Launch commands on the local machine, using all GPUs in parallel. 33 | """ 34 | print('WARNING: using experimental multi_gpu_launcher.') 35 | gpu_count = torch.cuda.device_count() 36 | # gpu_count = 5 37 | n_gpus = os.getenv("CUDA_VISIBLE_DEVICES").split(',') 38 | # n_gpus = ['3', '4', '5', '6', '7'] 39 | print('*' * 80) 40 | print(n_gpus) 41 | print('*' * 80) 42 | procs_by_gpu = [None] * gpu_count 43 | 44 | while len(commands) > 0: 45 | for idx in range(gpu_count): 46 | cur_gpu = n_gpus[idx] 47 | proc = procs_by_gpu[idx] 48 | if (proc is None) or (proc.poll() is not None): 49 | # Nothing is running on this GPU; launch a command. 50 | cmd = commands.pop(0) 51 | print(f'CUDA_VISIBLE_DEVICES={cur_gpu} {cmd}') 52 | new_proc = subprocess.Popen( 53 | f'CUDA_VISIBLE_DEVICES={cur_gpu} {cmd}', shell=True) 54 | procs_by_gpu[idx] = new_proc 55 | break 56 | time.sleep(1) 57 | 58 | # Wait for the last few tasks to finish before returning 59 | for p in procs_by_gpu: 60 | if p is not None: 61 | p.wait() 62 | 63 | 64 | REGISTRY = { 65 | 'local': local_launcher, 66 | 'dummy': dummy_launcher, 67 | 'multi_gpu': multi_gpu_launcher 68 | } 69 | 70 | try: 71 | from domainbed import facebook 72 | 73 | facebook.register_command_launchers(REGISTRY) 74 | except ImportError: 75 | pass 76 | -------------------------------------------------------------------------------- /domainbed/lib/fast_data_loader.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | 3 | import torch 4 | 5 | class _InfiniteSampler(torch.utils.data.Sampler): 6 | """Wraps another Sampler to yield an infinite stream.""" 7 | def __init__(self, sampler): 8 | self.sampler = sampler 9 | 10 | def __iter__(self): 11 | while True: 12 | for batch in self.sampler: 13 | yield batch 14 | 15 | class InfiniteDataLoader: 16 | def __init__(self, dataset, weights, batch_size, num_workers): 17 | super().__init__() 18 | 19 | if weights is not None: 20 | sampler = torch.utils.data.WeightedRandomSampler(weights, 21 | replacement=True, 22 | num_samples=batch_size) 23 | else: 24 | sampler = torch.utils.data.RandomSampler(dataset, 25 | replacement=True) 26 | 27 | if weights == None: 28 | weights = torch.ones(len(dataset)) 29 | 30 | batch_sampler = torch.utils.data.BatchSampler( 31 | sampler, 32 | batch_size=batch_size, 33 | drop_last=True) 34 | 35 | self._infinite_iterator = iter(torch.utils.data.DataLoader( 36 | dataset, 37 | num_workers=num_workers, 38 | batch_sampler=_InfiniteSampler(batch_sampler) 39 | )) 40 | 41 | def __iter__(self): 42 | while True: 43 | yield next(self._infinite_iterator) 44 | 45 | def __len__(self): 46 | raise ValueError 47 | 48 | class FastDataLoader: 49 | """DataLoader wrapper with slightly improved speed by not respawning worker 50 | processes at every epoch.""" 51 | def __init__(self, dataset, batch_size, num_workers): 52 | super().__init__() 53 | 54 | batch_sampler = torch.utils.data.BatchSampler( 55 | torch.utils.data.RandomSampler(dataset, replacement=False), 56 | batch_size=batch_size, 57 | drop_last=False 58 | ) 59 | 60 | self._infinite_iterator = iter(torch.utils.data.DataLoader( 61 | dataset, 62 | num_workers=num_workers, 63 | batch_sampler=_InfiniteSampler(batch_sampler) 64 | )) 65 | 66 | self._length = len(batch_sampler) 67 | 68 | def __iter__(self): 69 | for _ in range(len(self)): 70 | yield next(self._infinite_iterator) 71 | 72 | def __len__(self): 73 | return self._length 74 | -------------------------------------------------------------------------------- /domainbed/lib/wide_resnet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | 3 | """ 4 | From https://github.com/meliketoy/wide-resnet.pytorch 5 | """ 6 | 7 | import sys 8 | 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | import torch.nn.init as init 14 | from torch.autograd import Variable 15 | 16 | 17 | def conv3x3(in_planes, out_planes, stride=1): 18 | return nn.Conv2d( 19 | in_planes, 20 | out_planes, 21 | kernel_size=3, 22 | stride=stride, 23 | padding=1, 24 | bias=True) 25 | 26 | 27 | def conv_init(m): 28 | classname = m.__class__.__name__ 29 | if classname.find('Conv') != -1: 30 | init.xavier_uniform_(m.weight, gain=np.sqrt(2)) 31 | init.constant_(m.bias, 0) 32 | elif classname.find('BatchNorm') != -1: 33 | init.constant_(m.weight, 1) 34 | init.constant_(m.bias, 0) 35 | 36 | 37 | class wide_basic(nn.Module): 38 | def __init__(self, in_planes, planes, dropout_rate, stride=1): 39 | super(wide_basic, self).__init__() 40 | self.bn1 = nn.BatchNorm2d(in_planes) 41 | self.conv1 = nn.Conv2d( 42 | in_planes, planes, kernel_size=3, padding=1, bias=True) 43 | self.dropout = nn.Dropout(p=dropout_rate) 44 | self.bn2 = nn.BatchNorm2d(planes) 45 | self.conv2 = nn.Conv2d( 46 | planes, planes, kernel_size=3, stride=stride, padding=1, bias=True) 47 | 48 | self.shortcut = nn.Sequential() 49 | if stride != 1 or in_planes != planes: 50 | self.shortcut = nn.Sequential( 51 | nn.Conv2d( 52 | in_planes, planes, kernel_size=1, stride=stride, 53 | bias=True), ) 54 | 55 | def forward(self, x): 56 | out = self.dropout(self.conv1(F.relu(self.bn1(x)))) 57 | out = self.conv2(F.relu(self.bn2(out))) 58 | out += self.shortcut(x) 59 | 60 | return out 61 | 62 | 63 | class Wide_ResNet(nn.Module): 64 | """Wide Resnet with the softmax layer chopped off""" 65 | def __init__(self, input_shape, depth, widen_factor, dropout_rate): 66 | super(Wide_ResNet, self).__init__() 67 | self.in_planes = 16 68 | 69 | assert ((depth - 4) % 6 == 0), 'Wide-resnet depth should be 6n+4' 70 | n = (depth - 4) / 6 71 | k = widen_factor 72 | 73 | # print('| Wide-Resnet %dx%d' % (depth, k)) 74 | nStages = [16, 16 * k, 32 * k, 64 * k] 75 | 76 | self.conv1 = conv3x3(input_shape[0], nStages[0]) 77 | self.layer1 = self._wide_layer( 78 | wide_basic, nStages[1], n, dropout_rate, stride=1) 79 | self.layer2 = self._wide_layer( 80 | wide_basic, nStages[2], n, dropout_rate, stride=2) 81 | self.layer3 = self._wide_layer( 82 | wide_basic, nStages[3], n, dropout_rate, stride=2) 83 | self.bn1 = nn.BatchNorm2d(nStages[3], momentum=0.9) 84 | 85 | self.n_outputs = nStages[3] 86 | 87 | def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride): 88 | strides = [stride] + [1] * (int(num_blocks) - 1) 89 | layers = [] 90 | 91 | for stride in strides: 92 | layers.append(block(self.in_planes, planes, dropout_rate, stride)) 93 | self.in_planes = planes 94 | 95 | return nn.Sequential(*layers) 96 | 97 | def forward(self, x): 98 | out = self.conv1(x) 99 | out = self.layer1(out) 100 | out = self.layer2(out) 101 | out = self.layer3(out) 102 | out = F.relu(self.bn1(out)) 103 | out = F.avg_pool2d(out, 8) 104 | return out[:, :, 0, 0] 105 | -------------------------------------------------------------------------------- /domainbed/lib/logger.py: -------------------------------------------------------------------------------- 1 | """ Singleton Logger """ 2 | import sys 3 | import logging 4 | 5 | 6 | def levelize(levelname): 7 | """Convert levelname to level only if it is levelname""" 8 | if isinstance(levelname, str): 9 | return logging.getLevelName(levelname) 
10 | else: 11 | return levelname # already level 12 | 13 | 14 | class ColorFormatter(logging.Formatter): 15 | color_dic = { 16 | "DEBUG": 37, # white 17 | "INFO": 36, # cyan 18 | "WARNING": 33, # yellow 19 | "ERROR": 31, # red 20 | "CRITICAL": 41, # white on red bg 21 | } 22 | 23 | def format(self, record): 24 | color = self.color_dic.get(record.levelname, 37) # default white 25 | record.levelname = "\033[{}m{}\033[0m".format(color, record.levelname) 26 | return logging.Formatter.format(self, record) 27 | 28 | 29 | class Logger(logging.Logger): 30 | NAME = "SingletonLogger" 31 | 32 | @classmethod 33 | def get(cls, file_path=None, level="INFO", colorize=True, track_code=False): 34 | logging.setLoggerClass(cls) 35 | logger = logging.getLogger(cls.NAME) 36 | logging.setLoggerClass(logging.Logger) # restore 37 | logger.setLevel(level) 38 | 39 | if logger.hasHandlers(): 40 | # If logger already got all handlers (# handlers == 2), use the logger. 41 | # else, re-set handlers. 42 | if len(logger.handlers) == 2: 43 | return logger 44 | 45 | logger.handlers.clear() 46 | 47 | log_format = "%(levelname)s %(asctime)s | %(message)s" 48 | # log_format = '%(asctime)s | %(message)s' 49 | if track_code: 50 | log_format = ( 51 | "%(levelname)s::%(asctime)s | [%(filename)s] [%(funcName)s:%(lineno)d] " 52 | "%(message)s" 53 | ) 54 | date_format = "%m/%d %H:%M:%S" 55 | if colorize: 56 | formatter = ColorFormatter(log_format, date_format) 57 | else: 58 | formatter = logging.Formatter(log_format, date_format) 59 | 60 | # standard output handler 61 | # NOTE as default, StreamHandler use stderr stream instead of stdout stream. 62 | # Use StreamHandler(sys.stdout) for stdout stream. 63 | stream_handler = logging.StreamHandler(sys.stdout) 64 | stream_handler.setFormatter(formatter) 65 | logger.addHandler(stream_handler) 66 | 67 | if file_path: 68 | # file output handler 69 | file_handler = logging.FileHandler(file_path) 70 | file_handler.setFormatter(formatter) 71 | logger.addHandler(file_handler) 72 | 73 | logger.propagate = False 74 | 75 | return logger 76 | 77 | def nofmt(self, msg, *args, level="INFO", **kwargs): 78 | level = levelize(level) 79 | formatters = self.remove_formats() 80 | super().log(level, msg, *args, **kwargs) 81 | self.set_formats(formatters) 82 | 83 | def remove_formats(self): 84 | """Remove all formats from logger""" 85 | formatters = [] 86 | for handler in self.handlers: 87 | formatters.append(handler.formatter) 88 | handler.setFormatter(logging.Formatter("%(message)s")) 89 | 90 | return formatters 91 | 92 | def set_formats(self, formatters): 93 | """Set formats to every handler of logger""" 94 | for handler, formatter in zip(self.handlers, formatters): 95 | handler.setFormatter(formatter) 96 | 97 | def set_file_handler(self, file_path): 98 | file_handler = logging.FileHandler(file_path) 99 | formatter = self.handlers[0].formatter 100 | file_handler.setFormatter(formatter) 101 | self.addHandler(file_handler) 102 | -------------------------------------------------------------------------------- /domainbed/backbones.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Kakao Brain. All Rights Reserved. 
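# Backbone factory: get_backbone(name, preserve_readout, pretrained) returns a
# (network, n_outputs) pair covering torchvision ResNets, MoCo-v3 / Barlow Twins
# ResNet-50, CLIP image encoders, and SWAG RegNetY-16GF, e.g.
#   network, n_outputs = get_backbone("resnet50", preserve_readout=False, pretrained=True)
#   feats = network(images)  # [B, n_outputs] once the readout fc is replaced by Identity()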
2 | 3 | import torch 4 | import torch.nn as nn 5 | import torchvision.models 6 | import clip 7 | 8 | 9 | def clip_imageencoder(name): 10 | model, _preprocess = clip.load(name, device="cpu") 11 | imageencoder = model.visual 12 | 13 | return imageencoder 14 | 15 | 16 | class Identity(nn.Module): 17 | """An identity layer""" 18 | 19 | def __init__(self): 20 | super(Identity, self).__init__() 21 | 22 | def forward(self, x): 23 | return x 24 | 25 | 26 | def torchhub_load(repo, model, **kwargs): 27 | try: 28 | # torch >= 1.10 29 | network = torch.hub.load(repo, model=model, skip_validation=True, **kwargs) 30 | except TypeError: 31 | # torch 1.7.1 32 | network = torch.hub.load(repo, model=model, **kwargs) 33 | 34 | return network 35 | 36 | 37 | def get_backbone(name, preserve_readout, pretrained): 38 | if not pretrained: 39 | assert name in ["resnet50", "swag_regnety_16gf"], "Only RN50/RegNet supports non-pretrained network" 40 | 41 | if name == "resnet18": 42 | network = torchvision.models.resnet18(pretrained=True) 43 | n_outputs = 512 44 | elif name == "resnet50": 45 | network = torchvision.models.resnet50(pretrained=pretrained) 46 | n_outputs = 2048 47 | elif name == "resnet50_barlowtwins": 48 | network = torch.hub.load('facebookresearch/barlowtwins:main', 'resnet50') 49 | n_outputs = 2048 50 | elif name == "resnet50_moco": 51 | network = torchvision.models.resnet50() 52 | 53 | # download pretrained model of MoCo v3: https://dl.fbaipublicfiles.com/moco-v3/r-50-1000ep/r-50-1000ep.pth.tar 54 | ckpt_path = "./r-50-1000ep.pth.tar" 55 | 56 | # https://github.com/facebookresearch/moco-v3/blob/main/main_lincls.py#L172 57 | print("=> loading checkpoint '{}'".format(ckpt_path)) 58 | checkpoint = torch.load(ckpt_path, map_location="cpu") 59 | 60 | # rename moco pre-trained keys 61 | state_dict = checkpoint['state_dict'] 62 | linear_keyword = "fc" # resnet linear keyword 63 | for k in list(state_dict.keys()): 64 | # retain only base_encoder up to before the embedding layer 65 | if k.startswith('module.base_encoder') and not k.startswith('module.base_encoder.%s' % linear_keyword): 66 | # remove prefix 67 | state_dict[k[len("module.base_encoder."):]] = state_dict[k] 68 | # delete renamed or unused k 69 | del state_dict[k] 70 | 71 | msg = network.load_state_dict(state_dict, strict=False) 72 | assert set(msg.missing_keys) == {"%s.weight" % linear_keyword, "%s.bias" % linear_keyword} 73 | 74 | print("=> loaded pre-trained model '{}'".format(ckpt_path)) 75 | 76 | n_outputs = 2048 77 | elif name.startswith("clip_resnet"): 78 | name = "RN" + name[11:] 79 | network = clip_imageencoder(name) 80 | n_outputs = network.output_dim 81 | elif name == "clip_vit-b16": 82 | network = clip_imageencoder("ViT-B/16") 83 | n_outputs = network.output_dim 84 | elif name == "swag_regnety_16gf": 85 | # No readout layer as default 86 | network = torchhub_load("facebookresearch/swag", model="regnety_16gf", pretrained=pretrained) 87 | 88 | network.head = nn.Sequential( 89 | nn.AdaptiveAvgPool2d(1), 90 | nn.Flatten(1), 91 | ) 92 | n_outputs = 3024 93 | else: 94 | raise ValueError(name) 95 | 96 | if not preserve_readout: 97 | # remove readout layer (but left GAP and flatten) 98 | # final output shape: [B, n_outputs] 99 | if name.startswith("resnet"): 100 | del network.fc 101 | network.fc = Identity() 102 | 103 | return network, n_outputs 104 | -------------------------------------------------------------------------------- /domainbed/test/scripts/test_collect_results.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | import argparse 4 | import itertools 5 | import json 6 | import os 7 | import subprocess 8 | import sys 9 | import time 10 | import unittest 11 | import uuid 12 | 13 | import torch 14 | 15 | from domainbed import datasets 16 | from domainbed import hparams_registry 17 | from domainbed import algorithms 18 | from domainbed import networks 19 | from domainbed.test import helpers 20 | from domainbed.scripts import collect_results 21 | 22 | from parameterized import parameterized 23 | import io 24 | import textwrap 25 | 26 | class TestCollectResults(unittest.TestCase): 27 | 28 | def test_format_mean(self): 29 | self.assertEqual( 30 | collect_results.format_mean([0.1, 0.2, 0.3], False)[2], 31 | '20.0 +/- 4.7') 32 | self.assertEqual( 33 | collect_results.format_mean([0.1, 0.2, 0.3], True)[2], 34 | '20.0 $\pm$ 4.7') 35 | 36 | def test_print_table_non_latex(self): 37 | temp_out = io.StringIO() 38 | sys.stdout = temp_out 39 | table = [['1', '2'], ['3', '4']] 40 | collect_results.print_table(table, 'Header text', ['R1', 'R2'], 41 | ['C1', 'C2'], colwidth=10, latex=False) 42 | sys.stdout = sys.__stdout__ 43 | self.assertEqual( 44 | temp_out.getvalue(), 45 | textwrap.dedent(""" 46 | -------- Header text 47 | C1 C2 48 | R1 1 2 49 | R2 3 4 50 | """) 51 | ) 52 | 53 | def test_print_table_latex(self): 54 | temp_out = io.StringIO() 55 | sys.stdout = temp_out 56 | table = [['1', '2'], ['3', '4']] 57 | collect_results.print_table(table, 'Header text', ['R1', 'R2'], 58 | ['C1', 'C2'], colwidth=10, latex=True) 59 | sys.stdout = sys.__stdout__ 60 | self.assertEqual( 61 | temp_out.getvalue(), 62 | textwrap.dedent(r""" 63 | \begin{center} 64 | \adjustbox{max width=\textwidth}{% 65 | \begin{tabular}{lcc} 66 | \toprule 67 | \textbf{C1 & \textbf{C2 \\ 68 | \midrule 69 | R1 & 1 & 2 \\ 70 | R2 & 3 & 4 \\ 71 | \bottomrule 72 | \end{tabular}} 73 | \end{center} 74 | """) 75 | ) 76 | 77 | def test_get_grouped_records(self): 78 | pass # TODO 79 | 80 | def test_print_results_tables(self): 81 | pass # TODO 82 | 83 | def test_load_records(self): 84 | pass # TODO 85 | 86 | def test_end_to_end(self): 87 | """ 88 | Test that collect_results.py's output matches a manually-verified 89 | ground-truth when run on a given directory of test sweep data. 90 | 91 | If you make any changes to the output of collect_results.py, you'll need 92 | to update the ground-truth and manually verify that it's still 93 | correct. The command used to update the ground-truth is: 94 | 95 | python -m domainbed.scripts.collect_results --input_dir=domainbed/misc/test_sweep_data \ 96 | | tee domainbed/misc/test_sweep_results.txt 97 | 98 | Furthermore, if you make any changes to the data format, you'll also 99 | need to rerun the test sweep. 
The command used to run the test sweep is: 100 | 101 | python -m domainbed.scripts.sweep launch --data_dir=$DATA_DIR \ 102 | --output_dir=domainbed/misc/test_sweep_data --algorithms ERM \ 103 | --datasets VLCS --steps 1001 --n_hparams 2 --n_trials 2 \ 104 | --command_launcher local 105 | """ 106 | result = subprocess.run('python -m domainbed.scripts.collect_results' 107 | ' --input_dir=domainbed/misc/test_sweep_data', shell=True, 108 | stdout=subprocess.PIPE) 109 | 110 | with open('domainbed/misc/test_sweep_results.txt', 'r') as f: 111 | ground_truth = f.read() 112 | 113 | self.assertEqual(result.stdout.decode('utf8'), ground_truth) 114 | -------------------------------------------------------------------------------- /domainbed/evaluator.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from domainbed.lib.fast_data_loader import FastDataLoader 6 | 7 | if torch.cuda.is_available(): 8 | device = "cuda" 9 | else: 10 | device = "cpu" 11 | 12 | 13 | def accuracy_from_loader(algorithm, loader, weights, debug=False): 14 | correct = 0 15 | total = 0 16 | losssum = 0.0 17 | weights_offset = 0 18 | 19 | algorithm.eval() 20 | 21 | for i, batch in enumerate(loader): 22 | x = batch["x"].to(device) 23 | y = batch["y"].to(device) 24 | 25 | with torch.no_grad(): 26 | logits = algorithm.predict(x) 27 | loss = F.cross_entropy(logits, y).item() 28 | 29 | B = len(x) 30 | losssum += loss * B 31 | 32 | if weights is None: 33 | batch_weights = torch.ones(len(x)) 34 | else: 35 | batch_weights = weights[weights_offset : weights_offset + len(x)] 36 | weights_offset += len(x) 37 | batch_weights = batch_weights.to(device) 38 | if logits.size(1) == 1: 39 | correct += (logits.gt(0).eq(y).float() * batch_weights).sum().item() 40 | else: 41 | correct += (logits.argmax(1).eq(y).float() * batch_weights).sum().item() 42 | total += batch_weights.sum().item() 43 | 44 | if debug: 45 | break 46 | 47 | algorithm.train() 48 | 49 | acc = correct / total 50 | loss = losssum / total 51 | return acc, loss 52 | 53 | 54 | def accuracy(algorithm, loader_kwargs, weights, **kwargs): 55 | if isinstance(loader_kwargs, dict): 56 | loader = FastDataLoader(**loader_kwargs) 57 | elif isinstance(loader_kwargs, FastDataLoader): 58 | loader = loader_kwargs 59 | else: 60 | raise ValueError(loader_kwargs) 61 | return accuracy_from_loader(algorithm, loader, weights, **kwargs) 62 | 63 | 64 | class Evaluator: 65 | def __init__( 66 | self, test_envs, eval_meta, n_envs, logger, evalmode="fast", debug=False, target_env=None 67 | ): 68 | all_envs = list(range(n_envs)) 69 | train_envs = sorted(set(all_envs) - set(test_envs)) 70 | self.test_envs = test_envs 71 | self.train_envs = train_envs 72 | self.eval_meta = eval_meta 73 | self.n_envs = n_envs 74 | self.logger = logger 75 | self.evalmode = evalmode 76 | self.debug = debug 77 | 78 | if target_env is not None: 79 | self.set_target_env(target_env) 80 | 81 | def set_target_env(self, target_env): 82 | """When len(test_envs) == 2, you can specify target env for computing exact test acc.""" 83 | self.test_envs = [target_env] 84 | 85 | def evaluate(self, algorithm, ret_losses=False): 86 | n_train_envs = len(self.train_envs) 87 | n_test_envs = len(self.test_envs) 88 | assert n_test_envs == 1 89 | summaries = collections.defaultdict(float) 90 | # for key order 91 | summaries["test_in"] = 0.0 92 | summaries["test_out"] = 0.0 93 | summaries["train_in"] = 0.0 94 | 
summaries["train_out"] = 0.0 95 | accuracies = {} 96 | losses = {} 97 | 98 | # order: in_splits + out_splits. 99 | for name, loader_kwargs, weights in self.eval_meta: 100 | # env\d_[in|out] 101 | env_name, inout = name.split("_") 102 | env_num = int(env_name[3:]) 103 | 104 | skip_eval = self.evalmode == "fast" and inout == "in" and env_num not in self.test_envs 105 | if skip_eval: 106 | continue 107 | 108 | is_test = env_num in self.test_envs 109 | acc, loss = accuracy(algorithm, loader_kwargs, weights, debug=self.debug) 110 | accuracies[name] = acc 111 | losses[name] = loss 112 | 113 | if env_num in self.train_envs: 114 | summaries["train_" + inout] += acc / n_train_envs 115 | if inout == "out": 116 | summaries["tr_" + inout + "loss"] += loss / n_train_envs 117 | elif is_test: 118 | summaries["test_" + inout] += acc / n_test_envs 119 | 120 | if ret_losses: 121 | return accuracies, summaries, losses 122 | else: 123 | return accuracies, summaries 124 | -------------------------------------------------------------------------------- /domainbed/test/test_model_selection.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | """Unit tests.""" 4 | 5 | import argparse 6 | import itertools 7 | import json 8 | import os 9 | import subprocess 10 | import sys 11 | import time 12 | import unittest 13 | import uuid 14 | 15 | import torch 16 | 17 | from domainbed import model_selection 18 | from domainbed.lib.query import Q 19 | 20 | from parameterized import parameterized 21 | 22 | def make_record(step, hparams_seed, envs): 23 | """envs is a list of (in_acc, out_acc, is_test_env) tuples""" 24 | result = { 25 | 'args': {'test_envs': [], 'hparams_seed': hparams_seed}, 26 | 'step': step 27 | } 28 | for i, (in_acc, out_acc, is_test_env) in enumerate(envs): 29 | if is_test_env: 30 | result['args']['test_envs'].append(i) 31 | result[f'env{i}_in_acc'] = in_acc 32 | result[f'env{i}_out_acc'] = out_acc 33 | return result 34 | 35 | class TestSelectionMethod(unittest.TestCase): 36 | 37 | class MySelectionMethod(model_selection.SelectionMethod): 38 | @classmethod 39 | def run_acc(self, run_records): 40 | return { 41 | 'val_acc': run_records[0]['env0_out_acc'], 42 | 'test_acc': run_records[0]['env0_in_acc'] 43 | } 44 | 45 | def test_sweep_acc(self): 46 | sweep_records = Q([ 47 | make_record(0, 0, [(0.7, 0.8, True)]), 48 | make_record(0, 1, [(0.9, 0.5, True)]) 49 | ]) 50 | 51 | self.assertEqual( 52 | self.MySelectionMethod.sweep_acc(sweep_records), 53 | 0.7 54 | ) 55 | 56 | def test_sweep_acc_empty(self): 57 | self.assertEqual( 58 | self.MySelectionMethod.sweep_acc(Q([])), 59 | None 60 | ) 61 | 62 | class TestOracleSelectionMethod(unittest.TestCase): 63 | 64 | def test_run_acc_best_first(self): 65 | """Test run_acc() when the run has two records and the best one comes 66 | first""" 67 | run_records = Q([ 68 | make_record(0, 0, [(0.75, 0.70, True)]), 69 | make_record(1, 0, [(0.65, 0.60, True)]) 70 | ]) 71 | self.assertEqual( 72 | model_selection.OracleSelectionMethod.run_acc(run_records), 73 | {'val_acc': 0.60, 'test_acc': 0.65} 74 | ) 75 | 76 | def test_run_acc_best_last(self): 77 | """Test run_acc() when the run has two records and the best one comes 78 | last""" 79 | run_records = Q([ 80 | make_record(0, 0, [(0.75, 0.70, True)]), 81 | make_record(1, 0, [(0.85, 0.80, True)]) 82 | ]) 83 | self.assertEqual( 84 | model_selection.OracleSelectionMethod.run_acc(run_records), 85 | {'val_acc': 0.80, 'test_acc': 
0.85} 86 | ) 87 | 88 | def test_run_acc_empty(self): 89 | """Test run_acc() when there are no valid records to choose from.""" 90 | self.assertEqual( 91 | model_selection.OracleSelectionMethod.run_acc(Q([])), 92 | None 93 | ) 94 | 95 | class TestIIDAccuracySelectionMethod(unittest.TestCase): 96 | 97 | def test_run_acc(self): 98 | run_records = Q([ 99 | make_record(0, 0, 100 | [(0.1, 0.2, True), (0.5, 0.6, False), (0.6, 0.7, False)]), 101 | make_record(1, 0, 102 | [(0.3, 0.4, True), (0.6, 0.7, False), (0.7, 0.8, False)]), 103 | ]) 104 | self.assertEqual( 105 | model_selection.IIDAccuracySelectionMethod.run_acc(run_records), 106 | {'val_acc': 0.75, 'test_acc': 0.3} 107 | ) 108 | 109 | def test_run_acc_empty(self): 110 | self.assertEqual( 111 | model_selection.IIDAccuracySelectionMethod.run_acc(Q([])), 112 | None) 113 | 114 | class TestLeaveOneOutSelectionMethod(unittest.TestCase): 115 | 116 | def test_run_acc(self): 117 | run_records = Q([ 118 | make_record(0, 0, 119 | [(0.1, 0., True), (0.0, 0., False), (0.0, 0., False)]), 120 | make_record(0, 0, 121 | [(0.0, 0., True), (0.5, 0., True), (0., 0., False)]), 122 | make_record(0, 0, 123 | [(0.0, 0., True), (0.0, 0., False), (0.6, 0., True)]), 124 | ]) 125 | self.assertEqual( 126 | model_selection.LeaveOneOutSelectionMethod.run_acc(run_records), 127 | {'val_acc': 0.55, 'test_acc': 0.1} 128 | ) 129 | 130 | def test_run_acc_empty(self): 131 | run_records = Q([ 132 | make_record(0, 0, 133 | [(0.1, 0., True), (0.0, 0., False), (0.0, 0., False)]), 134 | make_record(0, 0, 135 | [(0.0, 0., True), (0.5, 0., True), (0., 0., False)]), 136 | ]) 137 | self.assertEqual( 138 | model_selection.LeaveOneOutSelectionMethod.run_acc(run_records), 139 | None 140 | ) 141 | -------------------------------------------------------------------------------- /domainbed/test/scripts/test_sweep.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | 3 | import argparse 4 | import itertools 5 | import json 6 | import os 7 | import subprocess 8 | import sys 9 | import time 10 | import unittest 11 | import uuid 12 | 13 | import torch 14 | 15 | from domainbed import datasets 16 | from domainbed import hparams_registry 17 | from domainbed import algorithms 18 | from domainbed import networks 19 | from domainbed.test import helpers 20 | from domainbed.scripts import sweep 21 | 22 | from parameterized import parameterized 23 | 24 | class TestSweep(unittest.TestCase): 25 | 26 | def test_job(self): 27 | """Test that a newly-created job has valid 28 | output_dir, state, and command_str properties.""" 29 | train_args = {'foo': 'bar'} 30 | sweep_output_dir = f'/tmp/{str(uuid.uuid4())}' 31 | job = sweep.Job(train_args, sweep_output_dir) 32 | self.assertTrue(job.output_dir.startswith(sweep_output_dir)) 33 | self.assertEqual(job.state, sweep.Job.NOT_LAUNCHED) 34 | self.assertEqual(job.command_str, 35 | f'python -m domainbed.scripts.train --foo bar --output_dir {job.output_dir}') 36 | 37 | def test_job_launch(self): 38 | """Test that launching a job calls the launcher_fn with appropariate 39 | arguments, and sets the job to INCOMPLETE state.""" 40 | train_args = {'foo': 'bar'} 41 | sweep_output_dir = f'/tmp/{str(uuid.uuid4())}' 42 | job = sweep.Job(train_args, sweep_output_dir) 43 | 44 | launcher_fn_called = False 45 | def launcher_fn(commands): 46 | nonlocal launcher_fn_called 47 | launcher_fn_called = True 48 | self.assertEqual(len(commands), 1) 49 | self.assertEqual(commands[0], job.command_str) 50 | 51 | sweep.Job.launch([job], launcher_fn) 52 | self.assertTrue(launcher_fn_called) 53 | 54 | job = sweep.Job(train_args, sweep_output_dir) 55 | self.assertEqual(job.state, sweep.Job.INCOMPLETE) 56 | 57 | def test_job_delete(self): 58 | """Test that deleting a launched job returns it to the NOT_LAUNCHED 59 | state""" 60 | train_args = {'foo': 'bar'} 61 | sweep_output_dir = f'/tmp/{str(uuid.uuid4())}' 62 | job = sweep.Job(train_args, sweep_output_dir) 63 | sweep.Job.launch([job], (lambda commands: None)) 64 | sweep.Job.delete([job]) 65 | 66 | job = sweep.Job(train_args, sweep_output_dir) 67 | self.assertEqual(job.state, sweep.Job.NOT_LAUNCHED) 68 | 69 | 70 | def test_make_args_list(self): 71 | """Test that, for a typical input, make_job_list returns a list 72 | of the correct length""" 73 | args_list = sweep.make_args_list( 74 | n_trials=2, 75 | dataset_names=['Debug28'], 76 | algorithms=['ERM'], 77 | n_hparams_from=0, 78 | n_hparams=3, 79 | steps=123, 80 | data_dir='/tmp/data', 81 | task='domain_generalization', 82 | holdout_fraction=0.2, 83 | single_test_envs=False, 84 | hparams=None 85 | ) 86 | assert(len(args_list) == 2*3*(3+3)) 87 | 88 | @unittest.skipIf('DATA_DIR' not in os.environ, 'needs DATA_DIR environment ' 89 | 'variable') 90 | def test_end_to_end(self): 91 | output_dir = os.path.join('/tmp', str(uuid.uuid4())) 92 | result = subprocess.run(f'python -m domainbed.scripts.sweep launch ' 93 | f'--data_dir={os.environ["DATA_DIR"]} --output_dir={output_dir} ' 94 | f'--algorithms ERM --datasets Debug28 --n_hparams 1 --n_trials 1 ' 95 | f'--command_launcher dummy --skip_confirmation', 96 | shell=True, capture_output=True) 97 | stdout_lines = result.stdout.decode('utf8').split("\n") 98 | dummy_launcher_lines = [l for l in stdout_lines 99 | if l.startswith('Dummy launcher:')] 100 | self.assertEqual(len(dummy_launcher_lines), 6) 101 | 102 | # Now run it again and make sure it doesn't try to relaunch those jobs 103 | result = 
subprocess.run(f'python -m domainbed.scripts.sweep launch ' 104 | f'--data_dir={os.environ["DATA_DIR"]} --output_dir={output_dir} ' 105 | f'--algorithms ERM --datasets Debug28 --n_hparams 1 --n_trials 1 ' 106 | f'--command_launcher dummy --skip_confirmation', 107 | shell=True, capture_output=True) 108 | stdout_lines = result.stdout.decode('utf8').split("\n") 109 | dummy_launcher_lines = [l for l in stdout_lines 110 | if l.startswith('Dummy launcher:')] 111 | self.assertEqual(len(dummy_launcher_lines), 0) 112 | 113 | # Delete the incomplete jobs, try launching again, and make sure they 114 | # get relaunched. 115 | subprocess.run(f'python -m domainbed.scripts.sweep delete_incomplete ' 116 | f'--data_dir={os.environ["DATA_DIR"]} --output_dir={output_dir} ' 117 | f'--algorithms ERM --datasets Debug28 --n_hparams 1 --n_trials 1 ' 118 | f'--command_launcher dummy --skip_confirmation', 119 | shell=True, capture_output=True) 120 | 121 | result = subprocess.run(f'python -m domainbed.scripts.sweep launch ' 122 | f'--data_dir={os.environ["DATA_DIR"]} --output_dir={output_dir} ' 123 | f'--algorithms ERM --datasets Debug28 --n_hparams 1 --n_trials 1 ' 124 | f'--command_launcher dummy --skip_confirmation', 125 | shell=True, capture_output=True) 126 | stdout_lines = result.stdout.decode('utf8').split("\n") 127 | dummy_launcher_lines = [l for l in stdout_lines 128 | if l.startswith('Dummy launcher:')] 129 | self.assertEqual(len(dummy_launcher_lines), 6) 130 | -------------------------------------------------------------------------------- /domainbed/model_selection.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | import itertools 4 | 5 | import numpy as np 6 | 7 | 8 | def get_test_records(records): 9 | """Given records with a common test env, get the test records (i.e. the 10 | records with *only* that single test env and no other test envs)""" 11 | return records.filter(lambda r: len(r['args']['test_envs']) == 1) 12 | 13 | 14 | class SelectionMethod: 15 | """Abstract class whose subclasses implement strategies for model 16 | selection across hparams and timesteps.""" 17 | 18 | def __init__(self): 19 | raise TypeError 20 | 21 | @classmethod 22 | def run_acc(self, run_records): 23 | """ 24 | Given records from a run, return a {val_acc, test_acc} dict representing 25 | the best val-acc and corresponding test-acc for that run. 26 | """ 27 | raise NotImplementedError 28 | 29 | @classmethod 30 | def hparams_accs(self, records): 31 | """ 32 | Given all records from a single (dataset, algorithm, test env) pair, 33 | return a sorted list of (run_acc, records) tuples. 34 | """ 35 | return (records.group('args.hparams_seed') 36 | .map(lambda _, run_records: 37 | ( 38 | self.run_acc(run_records), 39 | run_records 40 | ) 41 | ).filter(lambda x: x[0] is not None) 42 | .sorted(key=lambda x: x[0]['val_acc'])[::-1] 43 | ) 44 | 45 | @classmethod 46 | def sweep_acc(self, records): 47 | """ 48 | Given all records from a single (dataset, algorithm, test env) pair, 49 | return the mean test acc of the k runs with the top val accs. 
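(Concretely, this implementation returns the test acc of the single run with the best val acc.)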
50 | """ 51 | _hparams_accs = self.hparams_accs(records) 52 | if len(_hparams_accs): 53 | return _hparams_accs[0][0]['test_acc'] 54 | else: 55 | return None 56 | 57 | 58 | class OracleSelectionMethod(SelectionMethod): 59 | """Like Selection method which picks argmax(test_out_acc) across all hparams 60 | and checkpoints, but instead of taking the argmax over all 61 | checkpoints, we pick the last checkpoint, i.e. no early stopping.""" 62 | name = "test-domain validation set (oracle)" 63 | 64 | @classmethod 65 | def run_acc(self, run_records): 66 | run_records = run_records.filter(lambda r: len(r['args']['test_envs']) == 1) 67 | if not len(run_records): 68 | return None 69 | test_env = run_records[0]['args']['test_envs'][0] 70 | test_out_acc_key = 'env{}_out_acc'.format(test_env) 71 | test_in_acc_key = 'env{}_in_acc'.format(test_env) 72 | chosen_record = run_records.sorted(lambda r: r['step'])[-1] 73 | return { 74 | 'val_acc': chosen_record[test_out_acc_key], 75 | 'test_acc': chosen_record[test_in_acc_key] 76 | } 77 | 78 | 79 | class IIDAccuracySelectionMethod(SelectionMethod): 80 | """Picks argmax(mean(env_out_acc for env in train_envs))""" 81 | name = "training-domain validation set" 82 | 83 | @classmethod 84 | def _step_acc(self, record): 85 | """Given a single record, return a {val_acc, test_acc} dict.""" 86 | test_env = record['args']['test_envs'][0] 87 | val_env_keys = [] 88 | for i in itertools.count(): 89 | if f'env{i}_out_acc' not in record: 90 | break 91 | if i != test_env: 92 | val_env_keys.append(f'env{i}_out_acc') 93 | test_in_acc_key = 'env{}_in_acc'.format(test_env) 94 | return { 95 | 'val_acc': np.mean([record[key] for key in val_env_keys]), 96 | 'test_acc': record[test_in_acc_key] 97 | } 98 | 99 | @classmethod 100 | def run_acc(self, run_records): 101 | test_records = get_test_records(run_records) 102 | if not len(test_records): 103 | return None 104 | return test_records.map(self._step_acc).argmax('val_acc') 105 | 106 | 107 | class LeaveOneOutSelectionMethod(SelectionMethod): 108 | """Picks (hparams, step) by leave-one-out cross validation.""" 109 | name = "leave-one-domain-out cross-validation" 110 | 111 | @classmethod 112 | def _step_acc(self, records): 113 | """Return the {val_acc, test_acc} for a group of records corresponding 114 | to a single step.""" 115 | test_records = get_test_records(records) 116 | if len(test_records) != 1: 117 | return None 118 | 119 | test_env = test_records[0]['args']['test_envs'][0] 120 | n_envs = 0 121 | for i in itertools.count(): 122 | if f'env{i}_out_acc' not in records[0]: 123 | break 124 | n_envs += 1 125 | val_accs = np.zeros(n_envs) - 1 126 | for r in records.filter(lambda r: len(r['args']['test_envs']) == 2): 127 | val_env = (set(r['args']['test_envs']) - set([test_env])).pop() 128 | val_accs[val_env] = r['env{}_in_acc'.format(val_env)] 129 | val_accs = list(val_accs[:test_env]) + list(val_accs[test_env + 1:]) 130 | if any([v == -1 for v in val_accs]): 131 | return None 132 | val_acc = np.sum(val_accs) / (n_envs - 1) 133 | return { 134 | 'val_acc': val_acc, 135 | 'test_acc': test_records[0]['env{}_in_acc'.format(test_env)] 136 | } 137 | 138 | @classmethod 139 | def run_acc(self, records): 140 | step_accs = records.group('step').map(lambda step, step_records: 141 | self._step_acc(step_records) 142 | ).filter_not_none() 143 | if len(step_accs): 144 | return step_accs.argmax('val_acc') 145 | else: 146 | return None 147 | -------------------------------------------------------------------------------- /domainbed/lib/query.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | """Small query library.""" 4 | 5 | import collections 6 | import inspect 7 | import json 8 | import types 9 | import unittest 10 | import warnings 11 | import math 12 | 13 | import numpy as np 14 | 15 | 16 | def make_selector_fn(selector): 17 | """ 18 | If selector is a function, return selector. 19 | Otherwise, return a function corresponding to the selector string. Examples 20 | of valid selector strings and the corresponding functions: 21 | x lambda obj: obj['x'] 22 | x.y lambda obj: obj['x']['y'] 23 | x,y lambda obj: (obj['x'], obj['y']) 24 | """ 25 | if isinstance(selector, str): 26 | if ',' in selector: 27 | parts = selector.split(',') 28 | part_selectors = [make_selector_fn(part) for part in parts] 29 | return lambda obj: tuple(sel(obj) for sel in part_selectors) 30 | elif '.' in selector: 31 | parts = selector.split('.') 32 | part_selectors = [make_selector_fn(part) for part in parts] 33 | def f(obj): 34 | for sel in part_selectors: 35 | obj = sel(obj) 36 | return obj 37 | return f 38 | else: 39 | key = selector.strip() 40 | return lambda obj: obj[key] 41 | elif isinstance(selector, types.FunctionType): 42 | return selector 43 | else: 44 | raise TypeError 45 | 46 | def hashable(obj): 47 | try: 48 | hash(obj) 49 | return obj 50 | except TypeError: 51 | return json.dumps({'_':obj}, sort_keys=True) 52 | 53 | class Q(object): 54 | def __init__(self, list_): 55 | super(Q, self).__init__() 56 | self._list = list_ 57 | 58 | def __len__(self): 59 | return len(self._list) 60 | 61 | def __getitem__(self, key): 62 | return self._list[key] 63 | 64 | def __eq__(self, other): 65 | if isinstance(other, self.__class__): 66 | return self._list == other._list 67 | else: 68 | return self._list == other 69 | 70 | def __str__(self): 71 | return str(self._list) 72 | 73 | def __repr__(self): 74 | return repr(self._list) 75 | 76 | def _append(self, item): 77 | """Unsafe, be careful you know what you're doing.""" 78 | self._list.append(item) 79 | 80 | def group(self, selector): 81 | """ 82 | Group elements by selector and return a list of (group, group_records) 83 | tuples. 84 | """ 85 | selector = make_selector_fn(selector) 86 | groups = {} 87 | for x in self._list: 88 | group = selector(x) 89 | group_key = hashable(group) 90 | if group_key not in groups: 91 | groups[group_key] = (group, Q([])) 92 | groups[group_key][1]._append(x) 93 | results = [groups[key] for key in sorted(groups.keys())] 94 | return Q(results) 95 | 96 | def group_map(self, selector, fn): 97 | """ 98 | Group elements by selector, apply fn to each group, and return a list 99 | of the results. 100 | """ 101 | return self.group(selector).map(fn) 102 | 103 | def map(self, fn): 104 | """ 105 | map self onto fn. If fn takes multiple args, tuple-unpacking 106 | is applied. 
107 | """ 108 | if len(inspect.signature(fn).parameters) > 1: 109 | return Q([fn(*x) for x in self._list]) 110 | else: 111 | return Q([fn(x) for x in self._list]) 112 | 113 | def select(self, selector): 114 | selector = make_selector_fn(selector) 115 | return Q([selector(x) for x in self._list]) 116 | 117 | def min(self): 118 | return min(self._list) 119 | 120 | def max(self): 121 | return max(self._list) 122 | 123 | def sum(self): 124 | return sum(self._list) 125 | 126 | def len(self): 127 | return len(self._list) 128 | 129 | def mean(self): 130 | with warnings.catch_warnings(): 131 | warnings.simplefilter("ignore") 132 | return float(np.mean(self._list)) 133 | 134 | def std(self): 135 | with warnings.catch_warnings(): 136 | warnings.simplefilter("ignore") 137 | return float(np.std(self._list)) 138 | 139 | def mean_std(self): 140 | return (self.mean(), self.std()) 141 | 142 | def argmax(self, selector): 143 | selector = make_selector_fn(selector) 144 | return max(self._list, key=selector) 145 | 146 | def filter(self, fn): 147 | return Q([x for x in self._list if fn(x)]) 148 | 149 | def filter_equals(self, selector, value): 150 | """like [x for x in y if x.selector == value]""" 151 | selector = make_selector_fn(selector) 152 | return self.filter(lambda r: selector(r) == value) 153 | 154 | def filter_not_none(self): 155 | return self.filter(lambda r: r is not None) 156 | 157 | def filter_not_nan(self): 158 | return self.filter(lambda r: not np.isnan(r)) 159 | 160 | def flatten(self): 161 | return Q([y for x in self._list for y in x]) 162 | 163 | def unique(self): 164 | result = [] 165 | result_set = set() 166 | for x in self._list: 167 | hashable_x = hashable(x) 168 | if hashable_x not in result_set: 169 | result_set.add(hashable_x) 170 | result.append(x) 171 | return Q(result) 172 | 173 | def sorted(self, key=None): 174 | if key is None: 175 | key = lambda x: x 176 | def key2(x): 177 | x = key(x) 178 | if isinstance(x, (np.floating, float)) and np.isnan(x): 179 | return float('-inf') 180 | else: 181 | return x 182 | return Q(sorted(self._list, key=key2)) 183 | -------------------------------------------------------------------------------- /domainbed/scripts/list_top_hparams.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | 3 | """ 4 | Example usage: 5 | python -u -m domainbed.scripts.list_top_hparams \ 6 | --input_dir domainbed/misc/test_sweep_data --algorithm ERM \ 7 | --dataset VLCS --test_env 0 8 | """ 9 | 10 | import collections 11 | 12 | 13 | import argparse 14 | import functools 15 | import glob 16 | import pickle 17 | import itertools 18 | import json 19 | import os 20 | import random 21 | import sys 22 | 23 | import numpy as np 24 | import tqdm 25 | 26 | from domainbed import datasets 27 | from domainbed import algorithms 28 | from domainbed.lib import misc, reporting 29 | from domainbed import model_selection 30 | from domainbed.lib.query import Q 31 | import warnings 32 | 33 | def todo_rename(records, selection_method, latex): 34 | 35 | grouped_records = reporting.get_grouped_records(records).map(lambda group: 36 | { **group, "sweep_acc": selection_method.sweep_acc(group["records"]) } 37 | ).filter(lambda g: g["sweep_acc"] is not None) 38 | 39 | # read algorithm names and sort (predefined order) 40 | alg_names = Q(records).select("args.algorithm").unique() 41 | alg_names = ([n for n in algorithms.ALGORITHMS if n in alg_names] + 42 | [n for n in alg_names if n not in algorithms.ALGORITHMS]) 43 | 44 | # read dataset names and sort (lexicographic order) 45 | dataset_names = Q(records).select("args.dataset").unique().sorted() 46 | dataset_names = [d for d in datasets.DATASETS if d in dataset_names] 47 | 48 | for dataset in dataset_names: 49 | if latex: 50 | print() 51 | print("\\subsubsection{{{}}}".format(dataset)) 52 | test_envs = range(datasets.num_environments(dataset)) 53 | 54 | table = [[None for _ in [*test_envs, "Avg"]] for _ in alg_names] 55 | for i, algorithm in enumerate(alg_names): 56 | means = [] 57 | for j, test_env in enumerate(test_envs): 58 | trial_accs = (grouped_records 59 | .filter_equals( 60 | "dataset, algorithm, test_env", 61 | (dataset, algorithm, test_env) 62 | ).select("sweep_acc")) 63 | mean, err, table[i][j] = format_mean(trial_accs, latex) 64 | means.append(mean) 65 | if None in means: 66 | table[i][-1] = "X" 67 | else: 68 | table[i][-1] = "{:.1f}".format(sum(means) / len(means)) 69 | 70 | col_labels = [ 71 | "Algorithm", 72 | *datasets.get_dataset_class(dataset).ENVIRONMENTS, 73 | "Avg" 74 | ] 75 | header_text = (f"Dataset: {dataset}, " 76 | f"model selection method: {selection_method.name}") 77 | print_table(table, header_text, alg_names, list(col_labels), 78 | colwidth=20, latex=latex) 79 | 80 | # Print an "averages" table 81 | if latex: 82 | print() 83 | print("\\subsubsection{Averages}") 84 | 85 | table = [[None for _ in [*dataset_names, "Avg"]] for _ in alg_names] 86 | for i, algorithm in enumerate(alg_names): 87 | means = [] 88 | for j, dataset in enumerate(dataset_names): 89 | trial_averages = (grouped_records 90 | .filter_equals("algorithm, dataset", (algorithm, dataset)) 91 | .group("trial_seed") 92 | .map(lambda trial_seed, group: 93 | group.select("sweep_acc").mean() 94 | ) 95 | ) 96 | mean, err, table[i][j] = format_mean(trial_averages, latex) 97 | means.append(mean) 98 | if None in means: 99 | table[i][-1] = "X" 100 | else: 101 | table[i][-1] = "{:.1f}".format(sum(means) / len(means)) 102 | 103 | col_labels = ["Algorithm", *dataset_names, "Avg"] 104 | header_text = f"Averages, model selection method: {selection_method.name}" 105 | print_table(table, header_text, alg_names, col_labels, colwidth=25, 106 | latex=latex) 107 | 108 | if __name__ == "__main__": 109 | np.set_printoptions(suppress=True) 110 | 111 | parser = 
argparse.ArgumentParser( 112 | description="Domain generalization testbed") 113 | parser.add_argument("--input_dir", required=True) 114 | parser.add_argument('--dataset', required=True) 115 | parser.add_argument('--algorithm', required=True) 116 | parser.add_argument('--test_env', type=int, required=True) 117 | args = parser.parse_args() 118 | 119 | records = reporting.load_records(args.input_dir) 120 | print("Total records:", len(records)) 121 | 122 | records = reporting.get_grouped_records(records) 123 | records = records.filter( 124 | lambda r: 125 | r['dataset'] == args.dataset and 126 | r['algorithm'] == args.algorithm and 127 | r['test_env'] == args.test_env 128 | ) 129 | 130 | SELECTION_METHODS = [ 131 | model_selection.IIDAccuracySelectionMethod, 132 | model_selection.LeaveOneOutSelectionMethod, 133 | model_selection.OracleSelectionMethod, 134 | ] 135 | 136 | for selection_method in SELECTION_METHODS: 137 | print(f'Model selection: {selection_method.name}') 138 | 139 | for group in records: 140 | print(f"trial_seed: {group['trial_seed']}") 141 | best_hparams = selection_method.hparams_accs(group['records']) 142 | for run_acc, hparam_records in best_hparams: 143 | print(f"\t{run_acc}") 144 | for r in hparam_records: 145 | assert(r['hparams'] == hparam_records[0]['hparams']) 146 | print("\t\thparams:") 147 | for k, v in sorted(hparam_records[0]['hparams'].items()): 148 | print('\t\t\t{}: {}'.format(k, v)) 149 | print("\t\toutput_dirs:") 150 | output_dirs = hparam_records.select('args.output_dir').unique() 151 | for output_dir in output_dirs: 152 | print(f"\t\t\t{output_dir}") -------------------------------------------------------------------------------- /domainbed/registry.py: -------------------------------------------------------------------------------- 1 | """ Model Registry 2 | Hacked together by / Copyright 2020 Ross Wightman 3 | """ 4 | 5 | import sys 6 | import re 7 | import fnmatch 8 | from collections import defaultdict 9 | from copy import deepcopy 10 | 11 | __all__ = ['list_models', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules', 12 | 'is_pretrained_cfg_key', 'has_pretrained_cfg_key', 'get_pretrained_cfg_value', 'is_model_pretrained'] 13 | 14 | _module_to_models = defaultdict(set) # dict of sets to check membership of model in module 15 | _model_to_module = {} # mapping of model names to module names 16 | _model_entrypoints = {} # mapping of model names to entrypoint fns 17 | _model_has_pretrained = set() # set of model names that have pretrained weight url present 18 | _model_pretrained_cfgs = dict() # central repo for model default_cfgs 19 | 20 | 21 | def register_model(fn): 22 | # lookup containing module 23 | mod = sys.modules[fn.__module__] 24 | module_name_split = fn.__module__.split('.') 25 | module_name = module_name_split[-1] if len(module_name_split) else '' 26 | 27 | # add model to __all__ in module 28 | model_name = fn.__name__ 29 | if hasattr(mod, '__all__'): 30 | mod.__all__.append(model_name) 31 | else: 32 | mod.__all__ = [model_name] 33 | 34 | # add entries to registry dict/sets 35 | _model_entrypoints[model_name] = fn 36 | _model_to_module[model_name] = module_name 37 | _module_to_models[module_name].add(model_name) 38 | has_valid_pretrained = False # check if model has a pretrained url to allow filtering on this 39 | if hasattr(mod, 'default_cfgs') and model_name in mod.default_cfgs: 40 | # this will catch all models that have entrypoint matching cfg key, but miss any aliasing 41 | # entrypoints or non-matching combos 42 | cfg = 
mod.default_cfgs[model_name] 43 | has_valid_pretrained = ( 44 | ('url' in cfg and 'http' in cfg['url']) or 45 | ('file' in cfg and cfg['file']) or 46 | ('hf_hub_id' in cfg and cfg['hf_hub_id']) 47 | ) 48 | _model_pretrained_cfgs[model_name] = mod.default_cfgs[model_name] 49 | if has_valid_pretrained: 50 | _model_has_pretrained.add(model_name) 51 | return fn 52 | 53 | 54 | def _natural_key(string_): 55 | return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] 56 | 57 | 58 | def list_models(filter='', module='', pretrained=False, exclude_filters='', name_matches_cfg=False): 59 | """ Return list of available model names, sorted alphabetically 60 | 61 | Args: 62 | filter (str) - Wildcard filter string that works with fnmatch 63 | module (str) - Limit model selection to a specific sub-module (ie 'gen_efficientnet') 64 | pretrained (bool) - Include only models with pretrained weights if True 65 | exclude_filters (str or list[str]) - Wildcard filters to exclude models after including them with filter 66 | name_matches_cfg (bool) - Include only models w/ model_name matching default_cfg name (excludes some aliases) 67 | 68 | Example: 69 | model_list('gluon_resnet*') -- returns all models starting with 'gluon_resnet' 70 | model_list('*resnext*, 'resnet') -- returns all models with 'resnext' in 'resnet' module 71 | """ 72 | if module: 73 | all_models = list(_module_to_models[module]) 74 | else: 75 | all_models = _model_entrypoints.keys() 76 | if filter: 77 | models = [] 78 | include_filters = filter if isinstance(filter, (tuple, list)) else [filter] 79 | for f in include_filters: 80 | include_models = fnmatch.filter(all_models, f) # include these models 81 | if len(include_models): 82 | models = set(models).union(include_models) 83 | else: 84 | models = all_models 85 | if exclude_filters: 86 | if not isinstance(exclude_filters, (tuple, list)): 87 | exclude_filters = [exclude_filters] 88 | for xf in exclude_filters: 89 | exclude_models = fnmatch.filter(models, xf) # exclude these models 90 | if len(exclude_models): 91 | models = set(models).difference(exclude_models) 92 | if pretrained: 93 | models = _model_has_pretrained.intersection(models) 94 | if name_matches_cfg: 95 | models = set(_model_pretrained_cfgs).intersection(models) 96 | return list(sorted(models, key=_natural_key)) 97 | 98 | 99 | def is_model(model_name): 100 | """ Check if a model name exists 101 | """ 102 | return model_name in _model_entrypoints 103 | 104 | 105 | def model_entrypoint(model_name): 106 | """Fetch a model entrypoint for specified model name 107 | """ 108 | return _model_entrypoints[model_name] 109 | 110 | 111 | def list_modules(): 112 | """ Return list of module names that contain models / model entrypoints 113 | """ 114 | modules = _module_to_models.keys() 115 | return list(sorted(modules)) 116 | 117 | 118 | def is_model_in_modules(model_name, module_names): 119 | """Check if a model exists within a subset of modules 120 | Args: 121 | model_name (str) - name of model to check 122 | module_names (tuple, list, set) - names of modules to search in 123 | """ 124 | assert isinstance(module_names, (tuple, list, set)) 125 | return any(model_name in _module_to_models[n] for n in module_names) 126 | 127 | 128 | def is_model_pretrained(model_name): 129 | return model_name in _model_has_pretrained 130 | 131 | 132 | def get_pretrained_cfg(model_name): 133 | if model_name in _model_pretrained_cfgs: 134 | return deepcopy(_model_pretrained_cfgs[model_name]) 135 | return {} 136 | 137 | 138 | def 
has_pretrained_cfg_key(model_name, cfg_key): 139 | """ Query model default_cfgs for existence of a specific key. 140 | """ 141 | if model_name in _model_pretrained_cfgs and cfg_key in _model_pretrained_cfgs[model_name]: 142 | return True 143 | return False 144 | 145 | 146 | def is_pretrained_cfg_key(model_name, cfg_key): 147 | """ Return truthy value for specified model default_cfg key, False if does not exist. 148 | """ 149 | if model_name in _model_pretrained_cfgs and _model_pretrained_cfgs[model_name].get(cfg_key, False): 150 | return True 151 | return False 152 | 153 | 154 | def get_pretrained_cfg_value(model_name, cfg_key): 155 | """ Get a specific model default_cfg value by key. None if it doesn't exist. 156 | """ 157 | if model_name in _model_pretrained_cfgs: 158 | return _model_pretrained_cfgs[model_name].get(cfg_key, None) 159 | return None -------------------------------------------------------------------------------- /domainbed/scripts/collect_results.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | import argparse 4 | import os 5 | import sys 6 | 7 | import numpy as np 8 | 9 | sys.path.append("/mnt/lustre/bli/projects/EIL/domainbed") 10 | from domainbed import algorithms 11 | from domainbed import datasets 12 | from domainbed import model_selection 13 | from domainbed.lib import misc, reporting 14 | from domainbed.lib.query import Q 15 | 16 | 17 | def format_mean(data, latex): 18 | """Given a list of datapoints, return a string describing their mean and 19 | standard error""" 20 | if len(data) == 0: 21 | return None, None, "X" 22 | mean = 100 * np.mean(list(data)) 23 | err = 100 * np.std(list(data) / np.sqrt(len(data))) 24 | if latex: 25 | return mean, err, "{:.1f} $\\pm$ {:.1f}".format(mean, err) 26 | else: 27 | return mean, err, "{:.1f} +/- {:.1f}".format(mean, err) 28 | 29 | 30 | def print_table(table, header_text, row_labels, col_labels, colwidth=10, 31 | latex=True): 32 | """Pretty-print a 2D array of data, optionally with row/col labels""" 33 | print("") 34 | 35 | if latex: 36 | num_cols = len(table[0]) 37 | print("\\begin{center}") 38 | print("\\adjustbox{max width=\\textwidth}{%") 39 | print("\\begin{tabular}{l" + "c" * num_cols + "}") 40 | print("\\toprule") 41 | else: 42 | print("--------", header_text) 43 | 44 | for row, label in zip(table, row_labels): 45 | row.insert(0, label) 46 | 47 | if latex: 48 | col_labels = ["\\textbf{" + str(col_label).replace("%", "\\%") + "}" 49 | for col_label in col_labels] 50 | table.insert(0, col_labels) 51 | 52 | for r, row in enumerate(table): 53 | misc.print_row(row, colwidth=colwidth, latex=latex) 54 | if latex and r == 0: 55 | print("\\midrule") 56 | if latex: 57 | print("\\bottomrule") 58 | print("\\end{tabular}}") 59 | print("\\end{center}") 60 | 61 | 62 | def print_results_tables(records, selection_method, latex): 63 | """Given all records, print a results table for each dataset.""" 64 | grouped_records = reporting.get_grouped_records(records).map(lambda group: 65 | {**group, "sweep_acc": selection_method.sweep_acc(group["records"])} 66 | ).filter(lambda g: g["sweep_acc"] is not None) 67 | 68 | # read algorithm names and sort (predefined order) 69 | alg_names = Q(records).select("args.algorithm").unique() 70 | alg_names = ([n for n in algorithms.ALGORITHMS if n in alg_names] + 71 | [n for n in alg_names if n not in algorithms.ALGORITHMS]) 72 | 73 | # read dataset names and sort (lexicographic order) 74 | 
dataset_names = Q(records).select("args.dataset").unique().sorted() 75 | dataset_names = [d for d in datasets.DATASETS if d in dataset_names] 76 | 77 | for dataset in dataset_names: 78 | if latex: 79 | print() 80 | print("\\subsubsection{{{}}}".format(dataset)) 81 | test_envs = range(datasets.num_environments(dataset)) 82 | 83 | table = [[None for _ in [*test_envs, "Avg"]] for _ in alg_names] 84 | for i, algorithm in enumerate(alg_names): 85 | means = [] 86 | for j, test_env in enumerate(test_envs): 87 | trial_accs = (grouped_records.filter_equals("dataset, algorithm, test_env", (dataset, algorithm, test_env)).select("sweep_acc")) 88 | mean, err, table[i][j] = format_mean(trial_accs, latex) 89 | means.append(mean) 90 | if None in means: 91 | table[i][-1] = "X" 92 | else: 93 | table[i][-1] = "{:.1f}".format(sum(means) / len(means)) 94 | 95 | col_labels = [ 96 | "Algorithm", 97 | *datasets.get_dataset_class(dataset).ENVIRONMENTS, 98 | "Avg" 99 | ] 100 | header_text = (f"Dataset: {dataset}, " 101 | f"model selection method: {selection_method.name}") 102 | print_table(table, header_text, alg_names, list(col_labels), 103 | colwidth=20, latex=latex) 104 | 105 | # Print an "averages" table 106 | if latex: 107 | print() 108 | print("\\subsubsection{Averages}") 109 | 110 | table = [[None for _ in [*dataset_names, "Avg"]] for _ in alg_names] 111 | for i, algorithm in enumerate(alg_names): 112 | means = [] 113 | for j, dataset in enumerate(dataset_names): 114 | trial_averages = (grouped_records 115 | .filter_equals("algorithm, dataset", (algorithm, dataset)) 116 | .group("trial_seed") 117 | .map(lambda trial_seed, group: 118 | group.select("sweep_acc").mean() 119 | ) 120 | ) 121 | mean, err, table[i][j] = format_mean(trial_averages, latex) 122 | means.append(mean) 123 | if None in means: 124 | table[i][-1] = "X" 125 | else: 126 | table[i][-1] = "{:.1f}".format(sum(means) / len(means)) 127 | 128 | col_labels = ["Algorithm", *dataset_names, "Avg"] 129 | header_text = f"Averages, model selection method: {selection_method.name}" 130 | print_table(table, header_text, alg_names, col_labels, colwidth=25, 131 | latex=latex) 132 | 133 | 134 | if __name__ == "__main__": 135 | np.set_printoptions(suppress=True) 136 | 137 | parser = argparse.ArgumentParser( 138 | description="Domain generalization testbed") 139 | parser.add_argument("--input_dir", type=str, required=True) 140 | parser.add_argument("--latex", action="store_true") 141 | args = parser.parse_args() 142 | 143 | results_file = "results.tex" if args.latex else "results.txt" 144 | 145 | sys.stdout = misc.Tee(os.path.join(args.input_dir, results_file), "w") 146 | 147 | records = reporting.load_records(args.input_dir) 148 | 149 | if args.latex: 150 | print("\\documentclass{article}") 151 | print("\\usepackage{booktabs}") 152 | print("\\usepackage{adjustbox}") 153 | print("\\begin{document}") 154 | print("\\section{Full DomainBed results}") 155 | print("% Total records:", len(records)) 156 | else: 157 | print("Total records:", len(records)) 158 | 159 | SELECTION_METHODS = [ 160 | model_selection.IIDAccuracySelectionMethod, 161 | model_selection.LeaveOneOutSelectionMethod, 162 | model_selection.OracleSelectionMethod, 163 | ] 164 | 165 | for selection_method in SELECTION_METHODS: 166 | if args.latex: 167 | print() 168 | print("\\subsection{{Model selection: {}}}".format( 169 | selection_method.name)) 170 | print_results_tables(records, selection_method, args.latex) 171 | 172 | if args.latex: 173 | print("\\end{document}") 174 | 
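# Usage note (editorial sketch, not part of the original script): the
# --input_dir value below is a placeholder for wherever a sweep wrote its runs.
#
#   python -m domainbed.scripts.collect_results --input_dir=/tmp/my_sweep
#   python -m domainbed.scripts.collect_results --input_dir=/tmp/my_sweep --latex
#
# The first form tees a plain-text results.txt into the input directory; the
# --latex form writes results.tex instead, containing one table per dataset
# plus an "Averages" table, repeated for each model-selection method above.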
-------------------------------------------------------------------------------- /domainbed/networks.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | import copy 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torchvision.models 9 | 10 | from domainbed.lib import wide_resnet 11 | 12 | 13 | def remove_batch_norm_from_resnet(model): 14 | fuse = torch.nn.utils.fusion.fuse_conv_bn_eval 15 | model.eval() 16 | 17 | model.conv1 = fuse(model.conv1, model.bn1) 18 | model.bn1 = Identity() 19 | 20 | for name, module in model.named_modules(): 21 | if name.startswith("layer") and len(name) == 6: 22 | for b, bottleneck in enumerate(module): 23 | for name2, module2 in bottleneck.named_modules(): 24 | if name2.startswith("conv"): 25 | bn_name = "bn" + name2[-1] 26 | setattr(bottleneck, name2, 27 | fuse(module2, getattr(bottleneck, bn_name))) 28 | setattr(bottleneck, bn_name, Identity()) 29 | if isinstance(bottleneck.downsample, torch.nn.Sequential): 30 | bottleneck.downsample[0] = fuse(bottleneck.downsample[0], 31 | bottleneck.downsample[1]) 32 | bottleneck.downsample[1] = Identity() 33 | model.train() 34 | return model 35 | 36 | 37 | class Identity(nn.Module): 38 | """An identity layer""" 39 | def __init__(self): 40 | super(Identity, self).__init__() 41 | 42 | def forward(self, x): 43 | return x 44 | 45 | 46 | class MLP(nn.Module): 47 | """Just an MLP""" 48 | def __init__(self, n_inputs, n_outputs, hparams): 49 | super(MLP, self).__init__() 50 | self.input = nn.Linear(n_inputs, hparams['mlp_width']) 51 | self.dropout = nn.Dropout(hparams['mlp_dropout']) 52 | self.hiddens = nn.ModuleList([ 53 | nn.Linear(hparams['mlp_width'], hparams['mlp_width']) 54 | for _ in range(hparams['mlp_depth']-2)]) 55 | self.output = nn.Linear(hparams['mlp_width'], n_outputs) 56 | self.n_outputs = n_outputs 57 | 58 | def forward(self, x): 59 | x = self.input(x) 60 | x = self.dropout(x) 61 | x = F.relu(x) 62 | for hidden in self.hiddens: 63 | x = hidden(x) 64 | x = self.dropout(x) 65 | x = F.relu(x) 66 | x = self.output(x) 67 | return x 68 | 69 | 70 | class ResNet(torch.nn.Module): 71 | """ResNet with the softmax chopped off and the batchnorm frozen""" 72 | def __init__(self, input_shape, hparams): 73 | super(ResNet, self).__init__() 74 | if hparams['resnet18']: 75 | self.network = torchvision.models.resnet18(pretrained=True) 76 | self.n_outputs = 512 77 | else: 78 | self.network = torchvision.models.resnet50(pretrained=True) 79 | self.n_outputs = 2048 80 | 81 | # self.network = remove_batch_norm_from_resnet(self.network) 82 | 83 | # adapt number of channels 84 | nc = input_shape[0] 85 | if nc != 3: 86 | tmp = self.network.conv1.weight.data.clone() 87 | 88 | self.network.conv1 = nn.Conv2d( 89 | nc, 64, kernel_size=(7, 7), 90 | stride=(2, 2), padding=(3, 3), bias=False) 91 | 92 | for i in range(nc): 93 | self.network.conv1.weight.data[:, i, :, :] = tmp[:, i % 3, :, :] 94 | 95 | # save memory 96 | del self.network.fc 97 | self.network.fc = Identity() 98 | 99 | self.freeze_bn() 100 | self.hparams = hparams 101 | self.dropout = nn.Dropout(hparams['resnet_dropout']) 102 | 103 | def forward(self, x): 104 | """Encode x into a feature vector of size n_outputs.""" 105 | return self.dropout(self.network(x)) 106 | 107 | def train(self, mode=True): 108 | """ 109 | Override the default train() to freeze the BN parameters 110 | """ 111 | super().train(mode) 112 | self.freeze_bn() 
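    # Editorial note: train() re-applies freeze_bn() because nn.Module.train(True)
    # would otherwise switch every BatchNorm2d back into training mode. Keeping
    # those layers in eval() mode means the pretrained running statistics are
    # used as-is and never updated, which is presumably what keeps the
    # frozen-batchnorm behaviour promised in the class docstring stable across
    # small per-domain minibatches.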
113 | 114 | def freeze_bn(self): 115 | for m in self.network.modules(): 116 | if isinstance(m, nn.BatchNorm2d): 117 | m.eval() 118 | 119 | 120 | class MNIST_CNN(nn.Module): 121 | """ 122 | Hand-tuned architecture for MNIST. 123 | Weirdness I've noticed so far with this architecture: 124 | - adding a linear layer after the mean-pool in features hurts 125 | RotatedMNIST-100 generalization severely. 126 | """ 127 | n_outputs = 128 128 | 129 | def __init__(self, input_shape): 130 | super(MNIST_CNN, self).__init__() 131 | self.conv1 = nn.Conv2d(input_shape[0], 64, 3, 1, padding=1) 132 | self.conv2 = nn.Conv2d(64, 128, 3, stride=2, padding=1) 133 | self.conv3 = nn.Conv2d(128, 128, 3, 1, padding=1) 134 | self.conv4 = nn.Conv2d(128, 128, 3, 1, padding=1) 135 | 136 | self.bn0 = nn.GroupNorm(8, 64) 137 | self.bn1 = nn.GroupNorm(8, 128) 138 | self.bn2 = nn.GroupNorm(8, 128) 139 | self.bn3 = nn.GroupNorm(8, 128) 140 | 141 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 142 | 143 | def forward(self, x): 144 | x = self.conv1(x) 145 | x = F.relu(x) 146 | x = self.bn0(x) 147 | 148 | x = self.conv2(x) 149 | x = F.relu(x) 150 | x = self.bn1(x) 151 | 152 | x = self.conv3(x) 153 | x = F.relu(x) 154 | x = self.bn2(x) 155 | 156 | x = self.conv4(x) 157 | x = F.relu(x) 158 | x = self.bn3(x) 159 | 160 | x = self.avgpool(x) 161 | x = x.view(len(x), -1) 162 | return x 163 | 164 | 165 | class ContextNet(nn.Module): 166 | def __init__(self, input_shape): 167 | super(ContextNet, self).__init__() 168 | 169 | # Keep same dimensions 170 | padding = (5 - 1) // 2 171 | self.context_net = nn.Sequential( 172 | nn.Conv2d(input_shape[0], 64, 5, padding=padding), 173 | nn.BatchNorm2d(64), 174 | nn.ReLU(), 175 | nn.Conv2d(64, 64, 5, padding=padding), 176 | nn.BatchNorm2d(64), 177 | nn.ReLU(), 178 | nn.Conv2d(64, 1, 5, padding=padding), 179 | ) 180 | 181 | def forward(self, x): 182 | return self.context_net(x) 183 | 184 | 185 | class PositionalEmbedding1D(nn.Module): 186 | """Adds (optionally learned) positional embeddings to the inputs.""" 187 | 188 | def __init__(self, seq_len, dim): 189 | super().__init__() 190 | self.pos_embedding = nn.Parameter(torch.zeros(1, seq_len, dim)) 191 | 192 | def forward(self, x): 193 | """Input has shape `(batch_size, seq_len, emb_dim)`""" 194 | return x + self.pos_embedding 195 | 196 | 197 | def Featurizer(input_shape, hparams=None): 198 | """Auto-select an appropriate featurizer for the given input shape.""" 199 | if len(input_shape) == 1: 200 | return MLP(input_shape[0], hparams["mlp_width"], hparams) 201 | elif input_shape[1:3] == (28, 28): 202 | return MNIST_CNN(input_shape) 203 | elif input_shape[1:3] == (32, 32): 204 | return wide_resnet.Wide_ResNet(input_shape, 16, 2, 0.) 
205 | elif input_shape[1:3] == (224, 224): 206 | return ResNet(input_shape, hparams) 207 | else: 208 | raise NotImplementedError 209 | 210 | 211 | def Classifier(in_features, out_features, is_nonlinear=False): 212 | if is_nonlinear: 213 | return torch.nn.Sequential( 214 | torch.nn.Linear(in_features, in_features // 2), 215 | torch.nn.ReLU(), 216 | torch.nn.Linear(in_features // 2, in_features // 4), 217 | torch.nn.ReLU(), 218 | torch.nn.Linear(in_features // 4, out_features)) 219 | else: 220 | return torch.nn.Linear(in_features, out_features) 221 | 222 | 223 | class WholeFish(nn.Module): 224 | def __init__(self, input_shape, num_classes, hparams, weights=None): 225 | super(WholeFish, self).__init__() 226 | featurizer = Featurizer(input_shape, hparams) 227 | classifier = Classifier( 228 | featurizer.n_outputs, 229 | num_classes, 230 | hparams['nonlinear_classifier']) 231 | self.net = nn.Sequential( 232 | featurizer, classifier 233 | ) 234 | if weights is not None: 235 | self.load_state_dict(copy.deepcopy(weights)) 236 | 237 | def reset_weights(self, weights): 238 | self.load_state_dict(copy.deepcopy(weights)) 239 | 240 | def forward(self, x): 241 | return self.net(x) -------------------------------------------------------------------------------- /domainbed/scripts/sweep.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | """ 4 | Run sweeps 5 | """ 6 | 7 | import argparse 8 | import copy 9 | import hashlib 10 | import json 11 | import os 12 | import shlex 13 | import shutil 14 | import sys 15 | 16 | import numpy as np 17 | import tqdm 18 | 19 | sys.path.append("/mnt/lustre/bli/projects/EIL/domainbed") 20 | 21 | from domainbed import algorithms 22 | from domainbed import command_launchers 23 | from domainbed import datasets 24 | from domainbed.lib import misc 25 | 26 | 27 | class Job: 28 | NOT_LAUNCHED = 'Not launched' 29 | INCOMPLETE = 'Incomplete' 30 | DONE = 'Done' 31 | 32 | def __init__(self, train_args, sweep_output_dir): 33 | args_str = json.dumps(train_args, sort_keys=True) 34 | args_hash = hashlib.md5(args_str.encode('utf-8')).hexdigest() 35 | # args_hash = train_args['dataset'] + '_' + train_args['algorithm'] + '_' + train_args['test_envs'] + args_hash 36 | self.output_dir = os.path.join(sweep_output_dir, args_hash) 37 | 38 | self.train_args = copy.deepcopy(train_args) 39 | self.train_args['output_dir'] = self.output_dir 40 | command = ['python', '-m', 'domainbed.scripts.train'] 41 | for k, v in sorted(self.train_args.items()): 42 | if isinstance(v, list): 43 | v = ' '.join([str(v_) for v_ in v]) 44 | elif isinstance(v, str): 45 | v = shlex.quote(v) 46 | command.append(f'--{k} {v}') 47 | self.command_str = ' '.join(command) 48 | 49 | if os.path.exists(os.path.join(self.output_dir, 'done')): 50 | self.state = Job.DONE 51 | elif os.path.exists(self.output_dir): 52 | self.state = Job.INCOMPLETE 53 | else: 54 | self.state = Job.NOT_LAUNCHED 55 | 56 | def __str__(self): 57 | job_info = (self.train_args['dataset'], 58 | self.train_args['algorithm'], 59 | self.train_args['test_envs'], 60 | self.train_args['hparams_seed']) 61 | return '{}: {} {}'.format( 62 | self.state, 63 | self.output_dir, 64 | job_info) 65 | 66 | @staticmethod 67 | def launch(jobs, launcher_fn): 68 | print('Launching...') 69 | jobs = jobs.copy() 70 | np.random.shuffle(jobs) 71 | print('Making job directories:') 72 | for job in tqdm.tqdm(jobs, leave=False): 73 | os.makedirs(job.output_dir, exist_ok=True) 74 
| commands = [job.command_str for job in jobs] 75 | launcher_fn(commands) 76 | print(f'Launched {len(jobs)} jobs!') 77 | 78 | @staticmethod 79 | def delete(jobs): 80 | print('Deleting...') 81 | for job in jobs: 82 | shutil.rmtree(job.output_dir) 83 | print(f'Deleted {len(jobs)} jobs!') 84 | 85 | 86 | def all_test_env_combinations(n): 87 | """ 88 | For a dataset with n >= 3 envs, return all combinations of 1 and 2 test 89 | envs. 90 | """ 91 | assert (n >= 3) 92 | for i in range(n): 93 | yield [i] 94 | for j in range(i + 1, n): 95 | yield [i, j] 96 | 97 | 98 | def make_args_list(n_trials, dataset_names, algorithms, n_hparams_from, n_hparams, steps, 99 | data_dir, task, holdout_fraction, single_test_envs, hparams): 100 | args_list = [] 101 | for trial_seed in range(n_trials): 102 | for dataset in dataset_names: 103 | for algorithm in algorithms: 104 | if single_test_envs: 105 | all_test_envs = [ 106 | [i] for i in range(datasets.num_environments(dataset))] 107 | else: 108 | all_test_envs = all_test_env_combinations( 109 | datasets.num_environments(dataset)) 110 | for test_envs in all_test_envs: 111 | for hparams_seed in range(n_hparams_from, n_hparams): 112 | train_args = {} 113 | train_args['dataset'] = dataset 114 | train_args['algorithm'] = algorithm 115 | train_args['test_envs'] = test_envs 116 | train_args['holdout_fraction'] = holdout_fraction 117 | train_args['hparams_seed'] = hparams_seed 118 | train_args['data_dir'] = data_dir 119 | train_args['task'] = task 120 | train_args['trial_seed'] = trial_seed 121 | train_args['seed'] = misc.seed_hash(dataset, algorithm, test_envs, hparams_seed, trial_seed) 122 | if steps is not None: 123 | train_args['steps'] = steps 124 | if hparams is not None: 125 | train_args['hparams'] = hparams 126 | 127 | args_list.append(train_args) 128 | return args_list 129 | 130 | 131 | def ask_for_confirmation(): 132 | response = input('Are you sure? 
(y/n) ') 133 | if not response.lower().strip()[:1] == "y": 134 | print('Nevermind!') 135 | exit(0) 136 | 137 | 138 | DATASETS = [d for d in datasets.DATASETS if "Debug" not in d] 139 | 140 | if __name__ == "__main__": 141 | parser = argparse.ArgumentParser(description='Run a sweep') 142 | parser.add_argument('command', choices=['launch', 'delete_incomplete', 'delete_and_launch']) 143 | parser.add_argument('--datasets', nargs='+', type=str, default=DATASETS) 144 | parser.add_argument('--algorithms', nargs='+', type=str, default=algorithms.ALGORITHMS) 145 | parser.add_argument('--task', type=str, default="domain_generalization") 146 | parser.add_argument('--n_hparams_from', type=int, default=0) 147 | parser.add_argument('--n_hparams', type=int, default=20) 148 | parser.add_argument('--output_dir', type=str, required=True) 149 | parser.add_argument('--data_dir', type=str, default='/mnt/lustre/share/boli/domainbed_data') 150 | parser.add_argument('--seed', type=int, default=0) 151 | parser.add_argument('--n_trials', type=int, default=3) 152 | parser.add_argument('--command_launcher', type=str, required=True) 153 | parser.add_argument('--steps', type=int, default=None) 154 | parser.add_argument('--hparams', type=str, default=None) 155 | parser.add_argument('--holdout_fraction', type=float, default=0.2) 156 | parser.add_argument('--single_test_envs', action='store_true') 157 | parser.add_argument('--skip_confirmation', action='store_true') 158 | args = parser.parse_args() 159 | 160 | args_list = make_args_list( 161 | n_trials=args.n_trials, 162 | dataset_names=args.datasets, 163 | algorithms=args.algorithms, 164 | n_hparams_from=args.n_hparams_from, 165 | n_hparams=args.n_hparams, 166 | steps=args.steps, 167 | data_dir=args.data_dir, 168 | task=args.task, 169 | holdout_fraction=args.holdout_fraction, 170 | single_test_envs=args.single_test_envs, 171 | hparams=args.hparams 172 | ) 173 | 174 | jobs = [Job(train_args, args.output_dir) for train_args in args_list] 175 | 176 | for job in jobs: 177 | print(job) 178 | print("{} jobs: {} done, {} incomplete, {} not launched.".format( 179 | len(jobs), 180 | len([j for j in jobs if j.state == Job.DONE]), 181 | len([j for j in jobs if j.state == Job.INCOMPLETE]), 182 | len([j for j in jobs if j.state == Job.NOT_LAUNCHED])) 183 | ) 184 | 185 | if args.command == 'launch': 186 | to_launch = [j for j in jobs if j.state == Job.NOT_LAUNCHED] 187 | print(f'About to launch {len(to_launch)} jobs.') 188 | if not args.skip_confirmation: 189 | ask_for_confirmation() 190 | launcher_fn = command_launchers.REGISTRY[args.command_launcher] 191 | Job.launch(to_launch, launcher_fn) 192 | 193 | elif args.command == 'delete_incomplete': 194 | to_delete = [j for j in jobs if j.state == Job.INCOMPLETE] 195 | print(f'About to delete {len(to_delete)} jobs.') 196 | if not args.skip_confirmation: 197 | ask_for_confirmation() 198 | Job.delete(to_delete) 199 | 200 | elif args.command == 'delete_and_launch': 201 | to_delete = [j for j in jobs if j.state == Job.INCOMPLETE] 202 | print(f'About to delete {len(to_delete)} jobs.') 203 | Job.delete(to_delete) 204 | 205 | for j in jobs: 206 | if j.state == Job.INCOMPLETE: 207 | j.state = Job.NOT_LAUNCHED 208 | 209 | to_launch = [j for j in jobs if j.state == Job.NOT_LAUNCHED] 210 | print(f'About to launch {len(to_launch)} jobs.') 211 | launcher_fn = command_launchers.REGISTRY[args.command_launcher] 212 | Job.launch(to_launch, launcher_fn) 213 | -------------------------------------------------------------------------------- 
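For orientation, the following is a minimal sketch (not part of the repository) of how the pieces of sweep.py above fit together: it enumerates the jobs a small sweep would create and prints their states without launching anything. The output directory and data_dir are placeholders, and Debug28 is the tiny debug dataset used by the tests earlier in this document.

from domainbed.scripts import sweep

# Enumerate the training configurations for one trial of ERM on Debug28,
# with two hyperparameter seeds and one test environment per job.
args_list = sweep.make_args_list(
    n_trials=1,
    dataset_names=['Debug28'],
    algorithms=['ERM'],
    n_hparams_from=0,
    n_hparams=2,
    steps=None,
    data_dir='/path/to/data',   # placeholder
    task='domain_generalization',
    holdout_fraction=0.2,
    single_test_envs=True,
    hparams=None,
)

# Wrap each configuration in a Job; nothing has been launched yet, so every
# job reports the NOT_LAUNCHED state and carries a ready-to-run command_str.
jobs = [sweep.Job(train_args, '/tmp/example_sweep') for train_args in args_list]
for job in jobs:
    print(job)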
/domainbed/lib/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | """ 4 | Things that don't belong anywhere else 5 | """ 6 | 7 | import hashlib 8 | import operator 9 | import sys 10 | from collections import Counter 11 | from collections import OrderedDict 12 | from datetime import datetime 13 | from numbers import Number 14 | 15 | import numpy as np 16 | import torch 17 | 18 | 19 | def merge_dictlist(dictlist): 20 | """Merge list of dicts into dict of lists, by grouping same key. 21 | """ 22 | ret = { 23 | k: [] 24 | for k in dictlist[0].keys() 25 | } 26 | for dic in dictlist: 27 | for data_key, v in dic.items(): 28 | ret[data_key].append(v) 29 | return ret 30 | 31 | 32 | def index_conditional_iterate(skip_condition, iterable, index): 33 | for i, x in enumerate(iterable): 34 | if skip_condition(i): 35 | continue 36 | 37 | if index: 38 | yield i, x 39 | else: 40 | yield x 41 | 42 | 43 | class SplitIterator: 44 | def __init__(self, test_envs): 45 | self.test_envs = test_envs 46 | 47 | def train(self, iterable, index=False): 48 | return index_conditional_iterate(lambda idx: idx in self.test_envs, iterable, index) 49 | 50 | def test(self, iterable, index=False): 51 | return index_conditional_iterate(lambda idx: idx not in self.test_envs, iterable, index) 52 | 53 | 54 | def timestamp(fmt="%y%m%d_%H-%M-%S"): 55 | return datetime.now().strftime(fmt) 56 | 57 | 58 | def l2_between_dicts(dict_1, dict_2): 59 | assert len(dict_1) == len(dict_2) 60 | dict_1_values = [dict_1[key] for key in sorted(dict_1.keys())] 61 | dict_2_values = [dict_2[key] for key in sorted(dict_1.keys())] 62 | return ( 63 | torch.cat(tuple([t.view(-1) for t in dict_1_values])) - 64 | torch.cat(tuple([t.view(-1) for t in dict_2_values])) 65 | ).pow(2).mean() 66 | 67 | 68 | class MovingAverage: 69 | 70 | def __init__(self, ema, oneminusema_correction=True): 71 | self.ema = ema 72 | self.ema_data = {} 73 | self._updates = 0 74 | self._oneminusema_correction = oneminusema_correction 75 | 76 | def update(self, dict_data): 77 | ema_dict_data = {} 78 | for name, data in dict_data.items(): 79 | data = data.view(1, -1) 80 | if self._updates == 0: 81 | previous_data = torch.zeros_like(data) 82 | else: 83 | previous_data = self.ema_data[name] 84 | 85 | ema_data = self.ema * previous_data + (1 - self.ema) * data 86 | if self._oneminusema_correction: 87 | # correction by 1/(1 - self.ema) 88 | # so that the gradients amplitude backpropagated in data is independent of self.ema 89 | ema_dict_data[name] = ema_data / (1 - self.ema) 90 | else: 91 | ema_dict_data[name] = ema_data 92 | self.ema_data[name] = ema_data.clone().detach() 93 | 94 | self._updates += 1 95 | return ema_dict_data 96 | 97 | 98 | def make_weights_for_balanced_classes(dataset): 99 | counts = Counter() 100 | classes = [] 101 | for _, y in dataset: 102 | y = int(y) 103 | counts[y] += 1 104 | classes.append(y) 105 | 106 | n_classes = len(counts) 107 | 108 | weight_per_class = {} 109 | for y in counts: 110 | weight_per_class[y] = 1 / (counts[y] * n_classes) 111 | 112 | weights = torch.zeros(len(dataset)) 113 | for i, y in enumerate(classes): 114 | weights[i] = weight_per_class[int(y)] 115 | 116 | return weights 117 | 118 | 119 | def pdb(): 120 | sys.stdout = sys.__stdout__ 121 | import pdb 122 | print("Launching PDB, enter 'n' to step to parent function.") 123 | pdb.set_trace() 124 | 125 | 126 | def seed_hash(*args): 127 | """ 128 | Derive an integer hash from 
all args, for use as a random seed. 129 | """ 130 | args_str = str(args) 131 | return int(hashlib.md5(args_str.encode("utf-8")).hexdigest(), 16) % (2 ** 31) 132 | 133 | 134 | def print_separator(): 135 | print("=" * 80) 136 | 137 | 138 | def print_row(row, colwidth=10, latex=False): 139 | if latex: 140 | sep = " & " 141 | end_ = "\\\\" 142 | else: 143 | sep = " " 144 | end_ = "" 145 | 146 | def format_val(x): 147 | if np.issubdtype(type(x), np.floating): 148 | x = "{:.4f}".format(x) 149 | return str(x).ljust(colwidth)[:colwidth] 150 | 151 | print(sep.join([format_val(x) for x in row]), end_) 152 | 153 | 154 | class _SplitDataset(torch.utils.data.Dataset): 155 | """Used by split_dataset""" 156 | 157 | def __init__(self, underlying_dataset, keys): 158 | super(_SplitDataset, self).__init__() 159 | self.underlying_dataset = underlying_dataset 160 | self.keys = keys 161 | 162 | def __getitem__(self, key): 163 | return self.underlying_dataset[self.keys[key]] 164 | 165 | def __len__(self): 166 | return len(self.keys) 167 | 168 | 169 | def split_dataset(dataset, n, seed=0): 170 | """ 171 | Return a pair of datasets corresponding to a random split of the given 172 | dataset, with n datapoints in the first dataset and the rest in the last, 173 | using the given random seed 174 | """ 175 | assert (n <= len(dataset)) 176 | keys = list(range(len(dataset))) 177 | np.random.RandomState(seed).shuffle(keys) 178 | keys_1 = keys[:n] 179 | keys_2 = keys[n:] 180 | return _SplitDataset(dataset, keys_1), _SplitDataset(dataset, keys_2) 181 | 182 | 183 | def random_pairs_of_minibatches(minibatches): 184 | perm = torch.randperm(len(minibatches)).tolist() 185 | pairs = [] 186 | 187 | for i in range(len(minibatches)): 188 | j = i + 1 if i < (len(minibatches) - 1) else 0 189 | 190 | xi, yi = minibatches[perm[i]][0], minibatches[perm[i]][1] 191 | xj, yj = minibatches[perm[j]][0], minibatches[perm[j]][1] 192 | 193 | min_n = min(len(xi), len(xj)) 194 | 195 | pairs.append(((xi[:min_n], yi[:min_n]), (xj[:min_n], yj[:min_n]))) 196 | 197 | return pairs 198 | 199 | 200 | def accuracy(network, loader, weights, device): 201 | correct = 0 202 | total = 0 203 | weights_offset = 0 204 | 205 | network.eval() 206 | with torch.no_grad(): 207 | for x, y in loader: 208 | x = x.to(device) 209 | y = y.to(device) 210 | p = network.predict(x) 211 | if weights is None: 212 | batch_weights = torch.ones(len(x)) 213 | else: 214 | batch_weights = weights[weights_offset: weights_offset + len(x)] 215 | weights_offset += len(x) 216 | batch_weights = batch_weights.to(device) 217 | if p.size(1) == 1: 218 | correct += (p.gt(0).eq(y).float() * batch_weights.view(-1, 1)).sum().item() 219 | else: 220 | correct += (p.argmax(1).eq(y).float() * batch_weights).sum().item() 221 | total += batch_weights.sum().item() 222 | network.train() 223 | 224 | return correct / total 225 | 226 | 227 | class Tee: 228 | def __init__(self, fname, mode="a"): 229 | self.stdout = sys.stdout 230 | self.file = open(fname, mode) 231 | 232 | def write(self, message): 233 | self.stdout.write(message) 234 | self.file.write(message) 235 | self.flush() 236 | 237 | def flush(self): 238 | self.stdout.flush() 239 | self.file.flush() 240 | 241 | 242 | class ParamDict(OrderedDict): 243 | """Code adapted from https://github.com/Alok/rl_implementations/tree/master/reptile. 244 | A dictionary where the values are Tensors, meant to represent weights of 245 | a model. 
This subclass lets you perform arithmetic on weights directly.""" 246 | 247 | def __init__(self, *args, **kwargs): 248 | super().__init__(*args, *kwargs) 249 | 250 | def _prototype(self, other, op): 251 | if isinstance(other, Number): 252 | return ParamDict({k: op(v, other) for k, v in self.items()}) 253 | elif isinstance(other, dict): 254 | return ParamDict({k: op(self[k], other[k]) for k in self}) 255 | else: 256 | raise NotImplementedError 257 | 258 | def __add__(self, other): 259 | return self._prototype(other, operator.add) 260 | 261 | def __rmul__(self, other): 262 | return self._prototype(other, operator.mul) 263 | 264 | __mul__ = __rmul__ 265 | 266 | def __neg__(self): 267 | return ParamDict({k: -v for k, v in self.items()}) 268 | 269 | def __rsub__(self, other): 270 | # a- b := a + (-b) 271 | return self.__add__(other.__neg__()) 272 | 273 | __sub__ = __rsub__ 274 | 275 | def __truediv__(self, other): 276 | return self._prototype(other, operator.truediv) 277 | -------------------------------------------------------------------------------- /domainbed/lib/swa_utils.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/pytorch/pytorch/blob/master/torch/optim/swa_utils.py 2 | import copy 3 | import warnings 4 | import math 5 | from copy import deepcopy 6 | 7 | import torch 8 | from torch.nn import Module 9 | from torch.optim.lr_scheduler import _LRScheduler 10 | 11 | from domainbed.ur_networks import URResNet 12 | 13 | 14 | class AveragedModel(Module): 15 | def filter(self, model): 16 | if isinstance(model, AveragedModel): 17 | # prevent nested averagedmodel 18 | model = model.module 19 | 20 | if hasattr(model, "get_forward_model"): 21 | model = model.get_forward_model() 22 | # URERM models use URNetwork, which manages features internally. 
23 | for m in model.modules(): 24 | if isinstance(m, URResNet): 25 | m.clear_features() 26 | 27 | return model 28 | 29 | def __init__(self, model, device=None, avg_fn=None, rm_optimizer=False): 30 | super(AveragedModel, self).__init__() 31 | self.start_step = -1 32 | self.end_step = -1 33 | model = self.filter(model) 34 | self.module = deepcopy(model) 35 | self.module.zero_grad(set_to_none=True) 36 | if rm_optimizer: 37 | for k, v in vars(self.module).items(): 38 | if isinstance(v, torch.optim.Optimizer): 39 | setattr(self.module, k, None) 40 | # print(f"{k} -> {getattr(self.module, k)}") 41 | if device is not None: 42 | self.module = self.module.to(device) 43 | self.register_buffer('n_averaged', torch.tensor(0, dtype=torch.long, device=device)) 44 | if avg_fn is None: 45 | def avg_fn(averaged_model_parameter, model_parameter, num_averaged): 46 | return averaged_model_parameter + \ 47 | (model_parameter - averaged_model_parameter) / (num_averaged + 1) 48 | self.avg_fn = avg_fn 49 | 50 | def forward(self, *args, **kwargs): 51 | # return self.predict(*args, **kwargs) 52 | return self.module(*args, **kwargs) 53 | 54 | def predict(self, *args, **kwargs): 55 | return self.module.predict(*args, **kwargs) 56 | 57 | @property 58 | def network(self): 59 | return self.module.network 60 | 61 | def update_parameters(self, model, step=None, start_step=None, end_step=None): 62 | model = self.filter(model) 63 | for p_swa, p_model in zip(self.parameters(), model.parameters()): 64 | device = p_swa.device 65 | p_model_ = p_model.detach().to(device) 66 | if self.n_averaged == 0: 67 | p_swa.detach().copy_(p_model_) 68 | else: 69 | p_swa.detach().copy_(self.avg_fn(p_swa.detach(), p_model_, 70 | self.n_averaged.to(device))) 71 | self.n_averaged += 1 72 | 73 | if step is not None: 74 | if start_step is None: 75 | start_step = step 76 | if end_step is None: 77 | end_step = step 78 | 79 | if start_step is not None: 80 | if self.n_averaged == 1: 81 | self.start_step = start_step 82 | 83 | if end_step is not None: 84 | self.end_step = end_step 85 | 86 | def clone(self): 87 | clone = copy.deepcopy(self.module) 88 | clone.optimizer = clone.new_optimizer(clone.network.parameters()) 89 | return clone 90 | 91 | 92 | @torch.no_grad() 93 | def update_bn(iterator, model, n_steps, device='cuda'): 94 | momenta = {} 95 | for module in model.modules(): 96 | if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): 97 | module.running_mean = torch.zeros_like(module.running_mean) 98 | module.running_var = torch.ones_like(module.running_var) 99 | momenta[module] = module.momentum 100 | 101 | if not momenta: 102 | return 103 | 104 | was_training = model.training 105 | model.train() 106 | for module in momenta.keys(): 107 | module.momentum = None 108 | module.num_batches_tracked *= 0 109 | 110 | # for input in loader: 111 | for i in range(n_steps): 112 | # batches_dictlist: [{env0_data_key: tensor, env0_...}, env1_..., ...] 113 | batches_dictlist = next(iterator) 114 | x = torch.cat([ 115 | dic["x"] for dic in batches_dictlist 116 | ]) 117 | x = x.to(device) 118 | 119 | model(x) 120 | 121 | for bn_module in momenta.keys(): 122 | bn_module.momentum = momenta[bn_module] 123 | model.train(was_training) 124 | 125 | 126 | class SWALR(_LRScheduler): 127 | r"""Anneals the learning rate in each parameter group to a fixed value. 128 | This learning rate scheduler is meant to be used with Stochastic Weight 129 | Averaging (SWA) method (see `torch.optim.swa_utils.AveragedModel`). 
130 | Arguments: 131 | optimizer (torch.optim.Optimizer): wrapped optimizer 132 | swa_lrs (float or list): the learning rate value for all param groups 133 | together or separately for each group. 134 | annealing_epochs (int): number of epochs in the annealing phase 135 | (default: 10) 136 | annealing_strategy (str): "cos" or "linear"; specifies the annealing 137 | strategy: "cos" for cosine annealing, "linear" for linear annealing 138 | (default: "cos") 139 | last_epoch (int): the index of the last epoch (default: 'cos') 140 | The :class:`SWALR` scheduler is can be used together with other 141 | schedulers to switch to a constant learning rate late in the training 142 | as in the example below. 143 | Example: 144 | >>> loader, optimizer, model = ... 145 | >>> lr_lambda = lambda epoch: 0.9 146 | >>> scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer, 147 | >>> lr_lambda=lr_lambda) 148 | >>> swa_scheduler = torch.optim.swa_utils.SWALR(optimizer, 149 | >>> anneal_strategy="linear", anneal_epochs=20, swa_lr=0.05) 150 | >>> swa_start = 160 151 | >>> for i in range(300): 152 | >>> for input, target in loader: 153 | >>> optimizer.zero_grad() 154 | >>> loss_fn(model(input), target).backward() 155 | >>> optimizer.step() 156 | >>> if i > swa_start: 157 | >>> swa_scheduler.step() 158 | >>> else: 159 | >>> scheduler.step() 160 | .. _Averaging Weights Leads to Wider Optima and Better Generalization: 161 | https://arxiv.org/abs/1803.05407 162 | """ 163 | def __init__(self, optimizer, swa_lr, anneal_epochs=10, anneal_strategy='cos', last_epoch=-1): 164 | swa_lrs = self._format_param(optimizer, swa_lr) 165 | for swa_lr, group in zip(swa_lrs, optimizer.param_groups): 166 | group['swa_lr'] = swa_lr 167 | if anneal_strategy not in ['cos', 'linear']: 168 | raise ValueError("anneal_strategy must by one of 'cos' or 'linear', " 169 | "instead got {}".format(anneal_strategy)) 170 | elif anneal_strategy == 'cos': 171 | self.anneal_func = self._cosine_anneal 172 | elif anneal_strategy == 'linear': 173 | self.anneal_func = self._linear_anneal 174 | if not isinstance(anneal_epochs, int) or anneal_epochs < 1: 175 | raise ValueError("anneal_epochs must be a positive integer, got {}".format( 176 | anneal_epochs)) 177 | self.anneal_epochs = anneal_epochs 178 | 179 | super(SWALR, self).__init__(optimizer, last_epoch) 180 | 181 | @staticmethod 182 | def _format_param(optimizer, swa_lrs): 183 | if isinstance(swa_lrs, (list, tuple)): 184 | if len(swa_lrs) != len(optimizer.param_groups): 185 | raise ValueError("swa_lr must have the same length as " 186 | "optimizer.param_groups: swa_lr has {}, " 187 | "optimizer.param_groups has {}".format( 188 | len(swa_lrs), len(optimizer.param_groups))) 189 | return swa_lrs 190 | else: 191 | return [swa_lrs] * len(optimizer.param_groups) 192 | 193 | @staticmethod 194 | def _linear_anneal(t): 195 | return t 196 | 197 | @staticmethod 198 | def _cosine_anneal(t): 199 | return (1 - math.cos(math.pi * t)) / 2 200 | 201 | @staticmethod 202 | def _get_initial_lr(lr, swa_lr, alpha): 203 | if alpha == 1: 204 | return swa_lr 205 | return (lr - alpha * swa_lr) / (1 - alpha) 206 | 207 | def get_lr(self): 208 | if not self._get_lr_called_within_step: 209 | warnings.warn("To get the last learning rate computed by the scheduler, " 210 | "please use `get_last_lr()`.", UserWarning) 211 | step = self._step_count - 1 212 | prev_t = max(0, min(1, (step - 1) / self.anneal_epochs)) 213 | prev_alpha = self.anneal_func(prev_t) 214 | prev_lrs = [self._get_initial_lr(group['lr'], group['swa_lr'], 
prev_alpha) 215 | for group in self.optimizer.param_groups] 216 | t = max(0, min(1, step / self.anneal_epochs)) 217 | alpha = self.anneal_func(t) 218 | return [group['swa_lr'] * alpha + lr * (1 - alpha) 219 | for group, lr in zip(self.optimizer.param_groups, prev_lrs)] 220 | -------------------------------------------------------------------------------- /domainbed/hparams_registry.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import numpy as np 3 | 4 | from domainbed.lib import misc 5 | 6 | 7 | def _define_hparam(hparams, hparam_name, default_val, random_val_fn): 8 | hparams[hparam_name] = (hparams, hparam_name, default_val, random_val_fn) 9 | 10 | 11 | def _hparams(algorithm, dataset, random_seed): 12 | """ 13 | Global registry of hyperparams. Each entry is a (default, random) tuple. 14 | New algorithms / networks / etc. should add entries here. 15 | """ 16 | SMALL_IMAGES = ['Debug28', 'RotatedMNIST', 'ColoredMNIST'] 17 | 18 | hparams = {} 19 | 20 | def _hparam(name, default_val, random_val_fn): 21 | """Define a hyperparameter. random_val_fn takes a RandomState and 22 | returns a random hyperparameter value.""" 23 | # assert (name not in hparams) 24 | random_state = np.random.RandomState( 25 | misc.seed_hash(random_seed, name) 26 | ) 27 | hparams[name] = (default_val, random_val_fn(random_state)) 28 | 29 | # Unconditional hparam definitions. 30 | 31 | _hparam('data_augmentation', True, lambda r: True) 32 | _hparam('resnet18', False, lambda r: False) 33 | _hparam('resnet_dropout', 0., lambda r: r.choice([0., 0.1, 0.5])) 34 | _hparam('class_balanced', False, lambda r: False) 35 | # TODO: nonlinear classifiers disabled 36 | _hparam('nonlinear_classifier', False, 37 | lambda r: bool(r.choice([False, False]))) 38 | hparams["optimizer"] = ("adam", "adam") 39 | 40 | hparams["val_augment"] = (False, False) # augmentation for in-domain validation set 41 | hparams["freeze_bn"] = (True, True) 42 | hparams["pretrained"] = (True, True) # only for ResNet 43 | # Algorithm-specific hparam definitions. Each block of code below 44 | # corresponds to exactly one algorithm. 
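How a caller consumes this registry (a minimal sketch; the numbers shown follow from the unconditional and dataset defaults registered above for plain ERM on a non-MNIST dataset such as PACS, before any of the algorithm-specific branches below apply):

    from domainbed import hparams_registry

    hp = hparams_registry.default_hparams('ERM', 'PACS')   # element 0 of every (default, random) tuple
    assert hp['lr'] == 3e-5 and hp['batch_size'] == 32 and hp['resnet_dropout'] == 0.0

    # element 1: each value is drawn from a RandomState seeded with seed_hash(seed, name),
    # so every hyperparameter gets its own reproducible random draw
    hp_rand = hparams_registry.random_hparams('ERM', 'PACS', seed=7)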
45 | 46 | if algorithm in ['DANN', 'CDANN']: 47 | _hparam('lambda', 1.0, lambda r: 10 ** r.uniform(-2, 2)) 48 | _hparam('weight_decay_d', 0., lambda r: 10 ** r.uniform(-6, -2)) 49 | _hparam('d_steps_per_g_step', 1, lambda r: int(2 ** r.uniform(0, 3))) 50 | _hparam('grad_penalty', 0., lambda r: 10 ** r.uniform(-2, 1)) 51 | _hparam('beta1', 0.5, lambda r: r.choice([0., 0.5])) 52 | _hparam('mlp_width', 256, lambda r: int(2 ** r.uniform(6, 10))) 53 | _hparam('mlp_depth', 3, lambda r: int(r.choice([3, 4, 5]))) 54 | _hparam('mlp_dropout', 0., lambda r: r.choice([0., 0.1, 0.5])) 55 | 56 | elif algorithm == 'Fish': 57 | _hparam('meta_lr', 0.5, lambda r: r.choice([0.05, 0.1, 0.5])) 58 | 59 | elif algorithm == "RSC": 60 | _hparam('rsc_f_drop_factor', 1 / 3, lambda r: r.uniform(0, 0.5)) 61 | _hparam('rsc_b_drop_factor', 1 / 3, lambda r: r.uniform(0, 0.5)) 62 | 63 | elif algorithm == "SagNet": 64 | _hparam('sag_w_adv', 0.1, lambda r: 10 ** r.uniform(-2, 1)) 65 | 66 | elif algorithm == "IRM" or algorithm == 'IRM_IN21k': 67 | _hparam('irm_lambda', 1e2, lambda r: 10 ** r.uniform(-1, 5)) 68 | _hparam('irm_penalty_anneal_iters', 500, 69 | lambda r: int(10 ** r.uniform(0, 4))) 70 | 71 | elif algorithm == "Mixup": 72 | _hparam('mixup_alpha', 0.2, lambda r: 10 ** r.uniform(-1, -1)) 73 | 74 | elif algorithm == "GroupDRO": 75 | _hparam('groupdro_eta', 1e-2, lambda r: 10 ** r.uniform(-3, -1)) 76 | 77 | elif algorithm == "MMD" or algorithm == "CORAL": 78 | _hparam('mmd_gamma', 1., lambda r: 10 ** r.uniform(-1, 1)) 79 | 80 | elif algorithm == "MLDG": 81 | _hparam('mldg_beta', 1., lambda r: 10 ** r.uniform(-1, 1)) 82 | 83 | elif algorithm == "MTL": 84 | _hparam('mtl_ema', .99, lambda r: r.choice([0.5, 0.9, 0.99, 1.])) 85 | 86 | elif algorithm == "VREx": 87 | _hparam('vrex_lambda', 1e1, lambda r: 10 ** r.uniform(-1, 5)) 88 | _hparam('vrex_penalty_anneal_iters', 500, 89 | lambda r: int(10 ** r.uniform(0, 4))) 90 | 91 | elif algorithm == "SD": 92 | _hparam('sd_reg', 0.1, lambda r: 10 ** r.uniform(-5, -1)) 93 | 94 | elif algorithm == "ANDMask": 95 | _hparam('tau', 1, lambda r: r.uniform(0.5, 1.)) 96 | 97 | elif algorithm == "IGA": 98 | _hparam('penalty', 1000, lambda r: 10 ** r.uniform(1, 5)) 99 | 100 | elif algorithm == "SANDMask": 101 | _hparam('tau', 1.0, lambda r: r.uniform(0.0, 1.)) 102 | _hparam('k', 1e+1, lambda r: 10 ** r.uniform(-3, 5)) 103 | 104 | elif algorithm == "Fishr": 105 | _hparam('lambda', 1000., lambda r: 10 ** r.uniform(1., 4.)) 106 | _hparam('penalty_anneal_iters', 1500, lambda r: int(r.uniform(0., 5000.))) 107 | _hparam('ema', 0.95, lambda r: r.uniform(0.90, 0.99)) 108 | 109 | elif algorithm == "TRM": 110 | _hparam('cos_lambda', 1e-4, lambda r: 10 ** r.uniform(-5, 0)) 111 | _hparam('iters', 200, lambda r: int(10 ** r.uniform(0, 4))) 112 | _hparam('groupdro_eta', 1e-2, lambda r: 10 ** r.uniform(-3, -1)) 113 | 114 | elif algorithm == "IB_ERM": 115 | _hparam('ib_lambda', 1e2, lambda r: 10 ** r.uniform(-1, 5)) 116 | _hparam('ib_penalty_anneal_iters', 500, 117 | lambda r: int(10 ** r.uniform(0, 4))) 118 | 119 | elif algorithm == "IB_IRM": 120 | _hparam('irm_lambda', 1e2, lambda r: 10 ** r.uniform(-1, 5)) 121 | _hparam('irm_penalty_anneal_iters', 500, 122 | lambda r: int(10 ** r.uniform(0, 4))) 123 | _hparam('ib_lambda', 1e2, lambda r: 10 ** r.uniform(-1, 5)) 124 | _hparam('ib_penalty_anneal_iters', 500, 125 | lambda r: int(10 ** r.uniform(0, 4))) 126 | 127 | elif algorithm == "CAD" or algorithm == "CondCAD": 128 | _hparam('lmbda', 1e-1, lambda r: r.choice([1e-4, 1e-3, 1e-2, 1e-1, 1, 1e1, 1e2])) 129 | 
_hparam('temperature', 0.1, lambda r: r.choice([0.05, 0.1])) 130 | _hparam('is_normalized', False, lambda r: False) 131 | _hparam('is_project', False, lambda r: False) 132 | _hparam('is_flipped', True, lambda r: True) 133 | 134 | # Dataset-and-algorithm-specific hparam definitions. Each block of code 135 | # below corresponds to exactly one hparam. Avoid nested conditionals. 136 | 137 | if dataset in SMALL_IMAGES: 138 | _hparam('lr', 1e-3, lambda r: 10 ** r.uniform(-4.5, -2.5)) 139 | else: 140 | _hparam('lr', 3e-5, lambda r: 10 ** r.uniform(-5, -3.5)) 141 | 142 | if dataset in SMALL_IMAGES: 143 | _hparam('weight_decay', 0., lambda r: 0.) 144 | else: 145 | _hparam('weight_decay', 0., lambda r: 10 ** r.uniform(-6, -2)) 146 | 147 | if dataset in SMALL_IMAGES: 148 | _hparam('batch_size', 64, lambda r: int(2 ** r.uniform(3, 9))) 149 | elif algorithm == 'ARM': 150 | _hparam('batch_size', 8, lambda r: 8) 151 | elif dataset == 'DomainNet': 152 | _hparam('batch_size', 32, lambda r: int(2 ** r.uniform(3, 5))) 153 | else: 154 | _hparam('batch_size', 32, lambda r: int(2 ** r.uniform(3, 5.5))) 155 | 156 | if algorithm in ['DANN', 'CDANN'] and dataset in SMALL_IMAGES: 157 | _hparam('lr_g', 1e-3, lambda r: 10 ** r.uniform(-4.5, -2.5)) 158 | elif algorithm in ['DANN', 'CDANN']: 159 | _hparam('lr_g', 5e-5, lambda r: 10 ** r.uniform(-5, -3.5)) 160 | 161 | if algorithm in ['DANN', 'CDANN'] and dataset in SMALL_IMAGES: 162 | _hparam('lr_d', 1e-3, lambda r: 10 ** r.uniform(-4.5, -2.5)) 163 | elif algorithm in ['DANN', 'CDANN']: 164 | _hparam('lr_d', 5e-5, lambda r: 10 ** r.uniform(-5, -3.5)) 165 | 166 | if algorithm in ['DANN', 'CDANN'] and dataset in SMALL_IMAGES: 167 | _hparam('weight_decay_g', 0., lambda r: 0.) 168 | elif algorithm in ['DANN', 'CDANN']: 169 | _hparam('weight_decay_g', 0., lambda r: 10 ** r.uniform(-6, -2)) 170 | 171 | if 'GMOE' in algorithm: 172 | if dataset == 'VLCS': 173 | _hparam('lr', 3e-5, lambda r: 10 ** r.uniform(-4.5, -2.5)) 174 | _hparam('resnet_dropout', 0.5, lambda r: r.choice([0., 0.1, 0.5])) 175 | _hparam('weight_decay', 1e-6, lambda r: 0.) 176 | 177 | if dataset == 'PACS': 178 | _hparam('lr', 3e-5, lambda r: 10 ** r.uniform(-4.5, -2.5)) 179 | _hparam('resnet_dropout', 0.0, lambda r: r.choice([0., 0.1, 0.5])) 180 | _hparam('weight_decay', 1e-6, lambda r: 0.) 181 | 182 | if dataset == 'OfficeHome': 183 | _hparam('lr', 1e-5, lambda r: 10 ** r.uniform(-4.5, -2.5)) 184 | _hparam('resnet_dropout', 0.1, lambda r: r.choice([0., 0.1, 0.5])) 185 | _hparam('weight_decay', 1e-6, lambda r: 0.) 186 | 187 | if dataset == 'TerraIncognita': 188 | _hparam('lr', 5e-5, lambda r: 10 ** r.uniform(-4.5, -2.5)) 189 | _hparam('resnet_dropout', 0.0, lambda r: r.choice([0., 0.1, 0.5])) 190 | _hparam('weight_decay', 1e-4, lambda r: 0.) 191 | 192 | if dataset == 'DomainNet': 193 | _hparam('lr', 5e-5, lambda r: 10 ** r.uniform(-4.5, -2.5)) 194 | _hparam('resnet_dropout', 0.1, lambda r: r.choice([0., 0.1, 0.5])) 195 | _hparam('weight_decay', 0, lambda r: 0.) 196 | 197 | if dataset == 'CUB': 198 | _hparam('lr', 5e-5, lambda r: 10 ** r.uniform(-4.5, -2.5)) 199 | _hparam('resnet_dropout', 0.1, lambda r: r.choice([0., 0.1, 0.5])) 200 | _hparam('weight_decay', 0, lambda r: 0.) 
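# Note: these GMOE branches re-register 'lr', 'resnet_dropout' and 'weight_decay', which
# were already defined unconditionally / per-dataset earlier in this function. This only
# works because the duplicate-name assert inside _hparam() is commented out, so the later
# definition silently overwrites the generic default for the chosen dataset.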
201 | 202 | return hparams 203 | 204 | 205 | def default_hparams(algorithm, dataset): 206 | return {a: b for a, (b, c) in _hparams(algorithm, dataset, 0).items()} 207 | 208 | 209 | def random_hparams(algorithm, dataset, seed): 210 | return {a: c for a, (b, c) in _hparams(algorithm, dataset, seed).items()} 211 | -------------------------------------------------------------------------------- /exps/vis_attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import os 15 | import sys 16 | import argparse 17 | import cv2 18 | import random 19 | import colorsys 20 | import requests 21 | from io import BytesIO 22 | import glob 23 | 24 | import skimage.io 25 | from skimage.measure import find_contours 26 | import matplotlib.pyplot as plt 27 | from matplotlib.patches import Polygon 28 | import torch 29 | import torch.nn as nn 30 | import torchvision 31 | from torchvision import transforms as pth_transforms 32 | import numpy as np 33 | from PIL import Image 34 | from collections import OrderedDict 35 | 36 | sys.path.append("/mnt/lustre/bli/projects/EIL") 37 | from domainbed import vision_transformer, vision_transformer_hybrid 38 | 39 | 40 | def apply_mask(image, mask, color, alpha=0.5): 41 | for c in range(3): 42 | image[:, :, c] = image[:, :, c] * (1 - alpha * mask) + alpha * mask * color[c] * 255 43 | return image 44 | 45 | 46 | def random_colors(N, bright=True): 47 | """ 48 | Generate random colors. 49 | """ 50 | brightness = 1.0 if bright else 0.7 51 | hsv = [(i / N, 1, brightness) for i in range(N)] 52 | colors = list(map(lambda c: colorsys.hsv_to_rgb(*c), hsv)) 53 | random.shuffle(colors) 54 | return colors 55 | 56 | 57 | def display_instances(image, mask, fname="test", figsize=(5, 5), blur=False, contour=True, alpha=0.5): 58 | fig = plt.figure(figsize=figsize, frameon=False) 59 | ax = plt.Axes(fig, [0., 0., 1., 1.]) 60 | ax.set_axis_off() 61 | fig.add_axes(ax) 62 | ax = plt.gca() 63 | 64 | N = 1 65 | mask = mask[None, :, :] 66 | # Generate random colors 67 | colors = random_colors(N) 68 | 69 | # Show area outside image boundaries. 70 | height, width = image.shape[:2] 71 | margin = 0 72 | ax.set_ylim(height + margin, -margin) 73 | ax.set_xlim(-margin, width + margin) 74 | ax.axis('off') 75 | masked_image = image.astype(np.uint32).copy() 76 | for i in range(N): 77 | color = colors[i] 78 | _mask = mask[i] 79 | if blur: 80 | _mask = cv2.blur(_mask, (10, 10)) 81 | # Mask 82 | masked_image = apply_mask(masked_image, _mask, color, alpha) 83 | # Mask Polygon 84 | # Pad to ensure proper polygons for masks that touch image edges. 
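# skimage.measure.find_contours cannot close a contour that runs off the array edge, so the
# mask is first embedded in a one-pixel zero border below; the 'verts - 1' applied to each
# vertex then shifts the polygon back into the original, unpadded coordinates.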
85 | if contour: 86 | padded_mask = np.zeros((_mask.shape[0] + 2, _mask.shape[1] + 2)) 87 | padded_mask[1:-1, 1:-1] = _mask 88 | contours = find_contours(padded_mask, 0.5) 89 | for verts in contours: 90 | # Subtract the padding and flip (y, x) to (x, y) 91 | verts = np.fliplr(verts) - 1 92 | p = Polygon(verts, facecolor="none", edgecolor=color) 93 | ax.add_patch(p) 94 | ax.imshow(masked_image.astype(np.uint8), aspect='auto') 95 | fig.savefig(fname) 96 | print(f"{fname} saved.") 97 | return 98 | 99 | 100 | if __name__ == '__main__': 101 | parser = argparse.ArgumentParser('Visualize Self-Attention maps') 102 | parser.add_argument('--arch', default='vit_small', type=str, 103 | choices=['vit_tiny', 'vit_small', 'vit_base'], help='Architecture (support only ViT atm).') 104 | parser.add_argument('--patch_size', default=16, type=int, help='Patch resolution of the model.') 105 | parser.add_argument('--pretrained_weights', default='', type=str, 106 | help="Path to pretrained weights to load.") 107 | parser.add_argument("--checkpoint_key", default="teacher", type=str, 108 | help='Key to use in the checkpoint (example: "teacher")') 109 | parser.add_argument("--image_path", default="/mnt/lustre/bli/data/domain_net/real", type=str, help="Path of the image to load.") 110 | parser.add_argument("--image_size", default=(224, 224), type=int, nargs="+", help="Resize image.") 111 | parser.add_argument('--output_dir', default='./attn_output_vit', help='Path where to save visualizations.') 112 | parser.add_argument("--threshold", type=float, default=None, help="""We visualize masks obtained by thresholding the self-attention maps to keep xx% of the mass.""") 113 | args = parser.parse_args() 114 | 115 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 116 | # build & load model 117 | model_path = '{project_path}/sweep/output/{exp_name}/d2c8a444c1472737722e9354afe0f994/model.pkl' 118 | model = vision_transformer.deit_small_patch16_224(pretrained=True, num_classes=0, moe_interval=24, num_experts=4, Hierachical=False).cuda() 119 | state_dict = torch.load(model_path)['model_dict'] 120 | only_weights = OrderedDict() 121 | for item in state_dict.keys(): 122 | if 'head' not in item: 123 | only_weights[item.replace('model.', '')] = state_dict[item] 124 | 125 | for p in model.parameters(): 126 | p.requires_grad = False 127 | 128 | # model.load_state_dict(only_weights, strict=False) 129 | model.eval() 130 | import pickle 131 | 132 | image_list = [] 133 | 134 | # for filename in glob.glob("/mnt/lustre/bli/data/domain_net/real/**/*.jpg"): 135 | # image_list.append(filename) 136 | # 137 | # random.shuffle(image_list) 138 | # image_list = image_list[:1000] 139 | # 140 | # with open('test_image_list.pkl', 'wb') as fp: 141 | # pickle.dump(image_list, fp) 142 | 143 | with open('test_image_list.pkl', 'rb') as fp: 144 | image_list = pickle.load(fp) 145 | 146 | for img_full_path in image_list: 147 | img_name = img_full_path.split('/')[-1] 148 | if img_full_path is None: 149 | # user has not specified any image - we use our own image 150 | print("Please use the `--image_path` argument to indicate the path of the image you wish to visualize.") 151 | print("Since no image path have been provided, we take the first image in our paper.") 152 | response = requests.get("https://dl.fbaipublicfiles.com/dino/img.png") 153 | img = Image.open(BytesIO(response.content)) 154 | img = img.convert('RGB') 155 | elif os.path.isfile(img_full_path): 156 | with open(img_full_path, 'rb') as f: 157 | img = Image.open(f) 
158 | img = img.convert('RGB') 159 | else: 160 | print(f"Provided image path {img_full_path} is non valid.") 161 | sys.exit(1) 162 | transform = pth_transforms.Compose([ 163 | pth_transforms.Resize(args.image_size), 164 | pth_transforms.ToTensor(), 165 | pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 166 | ]) 167 | img = transform(img) 168 | 169 | # make the image divisible by the patch size 170 | w, h = img.shape[1] - img.shape[1] % args.patch_size, img.shape[2] - img.shape[2] % args.patch_size 171 | img = img[:, :w, :h].unsqueeze(0) 172 | 173 | w_featmap = img.shape[-2] // args.patch_size 174 | h_featmap = img.shape[-1] // args.patch_size 175 | 176 | attentions = model.get_last_selfattention(img.to(device)) 177 | 178 | nh = attentions.shape[1] # number of head 179 | 180 | # we keep only the output patch attention 181 | attentions = attentions[0, :, 0, 2:].reshape(nh, -1) 182 | 183 | if args.threshold is not None: 184 | # we keep only a certain percentage of the mass 185 | val, idx = torch.sort(attentions) 186 | val /= torch.sum(val, dim=1, keepdim=True) 187 | cumval = torch.cumsum(val, dim=1) 188 | th_attn = cumval > (1 - args.threshold) 189 | idx2 = torch.argsort(idx) 190 | for head in range(nh): 191 | th_attn[head] = th_attn[head][idx2[head]] 192 | th_attn = th_attn.reshape(nh, w_featmap, h_featmap).float() 193 | # interpolate 194 | th_attn = nn.functional.interpolate(th_attn.unsqueeze(0), scale_factor=args.patch_size, mode="nearest")[0].cpu().numpy() 195 | 196 | attentions = attentions.reshape(nh, w_featmap, h_featmap) 197 | attentions = nn.functional.interpolate(attentions.unsqueeze(0), scale_factor=args.patch_size, mode="nearest")[0].cpu().numpy() 198 | 199 | # save attentions heatmaps 200 | os.makedirs(args.output_dir, exist_ok=True) 201 | torchvision.utils.save_image(torchvision.utils.make_grid(img, normalize=True, scale_each=True), os.path.join(args.output_dir, img_name)) 202 | for j in range(nh): 203 | fname = os.path.join(args.output_dir, "{}_attn_head".format(img_name.replace('.jpg', '')) + str(j) + ".png") 204 | plt.imsave(fname=fname, arr=attentions[j], format='png') 205 | print(f"{fname} saved.") 206 | 207 | if args.threshold is not None: 208 | image = skimage.io.imread(img_full_path) 209 | from skimage.transform import rescale, resize, downscale_local_mean 210 | 211 | image = resize(image, (224, 224), anti_aliasing=True) 212 | for j in range(nh): 213 | display_instances(image, th_attn[j], fname=os.path.join(args.output_dir, "{}_mask_th".format(img_name.replace('.jpg', '')) + str(args.threshold) + "_head" + str(j) + ".png"), blur=False) 214 | -------------------------------------------------------------------------------- /domainbed/scripts/download.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | 3 | from torchvision.datasets import MNIST 4 | import xml.etree.ElementTree as ET 5 | from zipfile import ZipFile 6 | import argparse 7 | import tarfile 8 | import shutil 9 | import gdown 10 | import uuid 11 | import json 12 | import os 13 | 14 | from wilds.datasets.camelyon17_dataset import Camelyon17Dataset 15 | from wilds.datasets.fmow_dataset import FMoWDataset 16 | 17 | 18 | # utils ####################################################################### 19 | 20 | def stage_path(data_dir, name): 21 | full_path = os.path.join(data_dir, name) 22 | 23 | if not os.path.exists(full_path): 24 | os.makedirs(full_path) 25 | 26 | return full_path 27 | 28 | 29 | def download_and_extract(url, dst, remove=True): 30 | gdown.download(url, dst, quiet=False) 31 | 32 | if dst.endswith(".tar.gz"): 33 | tar = tarfile.open(dst, "r:gz") 34 | tar.extractall(os.path.dirname(dst)) 35 | tar.close() 36 | 37 | if dst.endswith(".tar"): 38 | tar = tarfile.open(dst, "r:") 39 | tar.extractall(os.path.dirname(dst)) 40 | tar.close() 41 | 42 | if dst.endswith(".zip"): 43 | zf = ZipFile(dst, "r") 44 | zf.extractall(os.path.dirname(dst)) 45 | zf.close() 46 | 47 | if remove: 48 | os.remove(dst) 49 | 50 | 51 | # VLCS ######################################################################## 52 | 53 | # Slower, but builds dataset from the original sources 54 | # 55 | # def download_vlcs(data_dir): 56 | # full_path = stage_path(data_dir, "VLCS") 57 | # 58 | # tmp_path = os.path.join(full_path, "tmp/") 59 | # if not os.path.exists(tmp_path): 60 | # os.makedirs(tmp_path) 61 | # 62 | # with open("domainbed/misc/vlcs_files.txt", "r") as f: 63 | # lines = f.readlines() 64 | # files = [line.strip().split() for line in lines] 65 | # 66 | # download_and_extract("http://pjreddie.com/media/files/VOCtrainval_06-Nov-2007.tar", 67 | # os.path.join(tmp_path, "voc2007_trainval.tar")) 68 | # 69 | # download_and_extract("https://drive.google.com/uc?id=1I8ydxaAQunz9R_qFFdBFtw6rFTUW9goz", 70 | # os.path.join(tmp_path, "caltech101.tar.gz")) 71 | # 72 | # download_and_extract("http://groups.csail.mit.edu/vision/Hcontext/data/sun09_hcontext.tar", 73 | # os.path.join(tmp_path, "sun09_hcontext.tar")) 74 | # 75 | # tar = tarfile.open(os.path.join(tmp_path, "sun09.tar"), "r:") 76 | # tar.extractall(tmp_path) 77 | # tar.close() 78 | # 79 | # for src, dst in files: 80 | # class_folder = os.path.join(data_dir, dst) 81 | # 82 | # if not os.path.exists(class_folder): 83 | # os.makedirs(class_folder) 84 | # 85 | # dst = os.path.join(class_folder, uuid.uuid4().hex + ".jpg") 86 | # 87 | # if "labelme" in src: 88 | # # download labelme from the web 89 | # gdown.download(src, dst, quiet=False) 90 | # else: 91 | # src = os.path.join(tmp_path, src) 92 | # shutil.copyfile(src, dst) 93 | # 94 | # shutil.rmtree(tmp_path) 95 | 96 | 97 | def download_vlcs(data_dir): 98 | # Original URL: http://www.eecs.qmul.ac.uk/~dl307/project_iccv2017 99 | full_path = stage_path(data_dir, "VLCS") 100 | 101 | download_and_extract("https://drive.google.com/uc?id=1skwblH1_okBwxWxmRsp9_qi15hyPpxg8", 102 | os.path.join(data_dir, "VLCS.tar.gz")) 103 | 104 | 105 | # MNIST ####################################################################### 106 | 107 | def download_mnist(data_dir): 108 | # Original URL: http://yann.lecun.com/exdb/mnist/ 109 | full_path = stage_path(data_dir, "MNIST") 110 | MNIST(full_path, download=True) 111 | 112 | 113 | # PACS ######################################################################## 114 | 115 | def download_pacs(data_dir): 116 | # 
Original URL: http://www.eecs.qmul.ac.uk/~dl307/project_iccv2017 117 | full_path = stage_path(data_dir, "PACS") 118 | 119 | download_and_extract("https://drive.google.com/uc?id=1JFr8f805nMUelQWWmfnJR3y4_SYoN5Pd", 120 | os.path.join(data_dir, "PACS.zip")) 121 | 122 | os.rename(os.path.join(data_dir, "kfold"), 123 | full_path) 124 | 125 | 126 | # Office-Home ################################################################# 127 | 128 | def download_office_home(data_dir): 129 | # Original URL: http://hemanthdv.org/OfficeHome-Dataset/ 130 | full_path = stage_path(data_dir, "office_home") 131 | 132 | download_and_extract("https://drive.google.com/uc?id=1uY0pj7oFsjMxRwaD3Sxy0jgel0fsYXLC", 133 | os.path.join(data_dir, "office_home.zip")) 134 | 135 | os.rename(os.path.join(data_dir, "OfficeHomeDataset_10072016"), 136 | full_path) 137 | 138 | 139 | # DomainNET ################################################################### 140 | 141 | def download_domain_net(data_dir): 142 | # Original URL: http://ai.bu.edu/M3SDA/ 143 | full_path = stage_path(data_dir, "domain_net") 144 | 145 | urls = [ 146 | "http://csr.bu.edu/ftp/visda/2019/multi-source/groundtruth/clipart.zip", 147 | "http://csr.bu.edu/ftp/visda/2019/multi-source/infograph.zip", 148 | "http://csr.bu.edu/ftp/visda/2019/multi-source/groundtruth/painting.zip", 149 | "http://csr.bu.edu/ftp/visda/2019/multi-source/quickdraw.zip", 150 | "http://csr.bu.edu/ftp/visda/2019/multi-source/real.zip", 151 | "http://csr.bu.edu/ftp/visda/2019/multi-source/sketch.zip" 152 | ] 153 | 154 | for url in urls: 155 | download_and_extract(url, os.path.join(full_path, url.split("/")[-1])) 156 | 157 | with open("domainbed/misc/domain_net_duplicates.txt", "r") as f: 158 | for line in f.readlines(): 159 | try: 160 | os.remove(os.path.join(full_path, line.strip())) 161 | except OSError: 162 | pass 163 | 164 | 165 | # TerraIncognita ############################################################## 166 | 167 | def download_terra_incognita(data_dir): 168 | # Original URL: https://beerys.github.io/CaltechCameraTraps/ 169 | # New URL: http://lila.science/datasets/caltech-camera-traps 170 | 171 | full_path = stage_path(data_dir, "terra_incognita") 172 | 173 | download_and_extract( 174 | "https://lilablobssc.blob.core.windows.net/caltechcameratraps/eccv_18_all_images_sm.tar.gz", 175 | os.path.join(full_path, "terra_incognita_images.tar.gz")) 176 | 177 | download_and_extract( 178 | "https://lilablobssc.blob.core.windows.net/caltechcameratraps/labels/caltech_camera_traps.json.zip", 179 | os.path.join(full_path, "caltech_camera_traps.json.zip")) 180 | 181 | include_locations = ["38", "46", "100", "43"] 182 | 183 | include_categories = [ 184 | "bird", "bobcat", "cat", "coyote", "dog", "empty", "opossum", "rabbit", 185 | "raccoon", "squirrel" 186 | ] 187 | 188 | images_folder = os.path.join(full_path, "eccv_18_all_images_sm/") 189 | annotations_file = os.path.join(full_path, "caltech_images_20210113.json") 190 | destination_folder = full_path 191 | 192 | stats = {} 193 | 194 | if not os.path.exists(destination_folder): 195 | os.mkdir(destination_folder) 196 | 197 | with open(annotations_file, "r") as f: 198 | data = json.load(f) 199 | 200 | category_dict = {} 201 | for item in data['categories']: 202 | category_dict[item['id']] = item['name'] 203 | 204 | for image in data['images']: 205 | image_location = image['location'] 206 | 207 | if image_location not in include_locations: 208 | continue 209 | 210 | loc_folder = os.path.join(destination_folder, 211 | 'location_' + 
str(image_location) + '/') 212 | 213 | if not os.path.exists(loc_folder): 214 | os.mkdir(loc_folder) 215 | 216 | image_id = image['id'] 217 | image_fname = image['file_name'] 218 | 219 | for annotation in data['annotations']: 220 | if annotation['image_id'] == image_id: 221 | if image_location not in stats: 222 | stats[image_location] = {} 223 | 224 | category = category_dict[annotation['category_id']] 225 | 226 | if category not in include_categories: 227 | continue 228 | 229 | if category not in stats[image_location]: 230 | stats[image_location][category] = 0 231 | else: 232 | stats[image_location][category] += 1 233 | 234 | loc_cat_folder = os.path.join(loc_folder, category + '/') 235 | 236 | if not os.path.exists(loc_cat_folder): 237 | os.mkdir(loc_cat_folder) 238 | 239 | dst_path = os.path.join(loc_cat_folder, image_fname) 240 | src_path = os.path.join(images_folder, image_fname) 241 | 242 | shutil.copyfile(src_path, dst_path) 243 | 244 | shutil.rmtree(images_folder) 245 | os.remove(annotations_file) 246 | 247 | 248 | # SVIRO ################################################################# 249 | 250 | def download_sviro(data_dir): 251 | # Original URL: https://sviro.kl.dfki.de 252 | full_path = stage_path(data_dir, "sviro") 253 | 254 | download_and_extract("https://sviro.kl.dfki.de/?wpdmdl=1731", 255 | os.path.join(data_dir, "sviro_grayscale_rectangle_classification.zip")) 256 | 257 | os.rename(os.path.join(data_dir, "SVIRO_DOMAINBED"), 258 | full_path) 259 | 260 | 261 | if __name__ == "__main__": 262 | parser = argparse.ArgumentParser(description='Download datasets') 263 | parser.add_argument('--data_dir', type=str, required=True) 264 | args = parser.parse_args() 265 | 266 | # download_mnist(args.data_dir) 267 | # download_pacs(args.data_dir) 268 | # download_office_home(args.data_dir) 269 | # download_domain_net(args.data_dir) 270 | # download_vlcs(args.data_dir) 271 | download_terra_incognita(args.data_dir) 272 | # download_sviro(args.data_dir) 273 | # Camelyon17Dataset(root_dir=args.data_dir, download=True) 274 | # FMoWDataset(root_dir=args.data_dir, download=True) 275 | -------------------------------------------------------------------------------- /domainbed/scripts/train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | 3 | import argparse 4 | import collections 5 | import json 6 | import os 7 | import random 8 | import sys 9 | import time 10 | 11 | sys.path.append("/mnt/lustre/bli/projects/EIL/domainbed") 12 | os.environ['WANDB_API_KEY'] = 'abc1859572354a66fc85b2ad1d1009add929cbfa' 13 | 14 | import wandb 15 | import PIL 16 | import numpy as np 17 | import torch 18 | import torch.utils.data 19 | import torchvision 20 | 21 | from domainbed import algorithms 22 | from domainbed import datasets 23 | from domainbed import hparams_registry 24 | from domainbed.lib import misc 25 | from domainbed.lib.fast_data_loader import InfiniteDataLoader, FastDataLoader 26 | 27 | if __name__ == "__main__": 28 | parser = argparse.ArgumentParser(description='Domain generalization') 29 | parser.add_argument('--data_dir', type=str, default='/mnt/lustre/share/boli/domainbed_data') 30 | parser.add_argument('--dataset', type=str, default="RotatedMNIST") 31 | parser.add_argument('--algorithm', type=str, default="ERM") 32 | parser.add_argument('--task', type=str, default="domain_generalization", 33 | choices=["domain_generalization", "domain_adaptation"]) 34 | parser.add_argument('--hparams', type=str, 35 | help='JSON-serialized hparams dict') 36 | parser.add_argument('--hparams_seed', type=int, default=0, 37 | help='Seed for random hparams (0 means "default hparams")') 38 | parser.add_argument('--trial_seed', type=int, default=0, 39 | help='Trial number (used for seeding split_dataset and ' 40 | 'random_hparams).') 41 | parser.add_argument('--seed', type=int, default=0, 42 | help='Seed for everything else') 43 | parser.add_argument('--batch_size', type=int, default=None) 44 | parser.add_argument('--drop_out', type=float, default=None) 45 | parser.add_argument('--lr', type=float, default=None) 46 | parser.add_argument('--weight_decay', type=float, default=None) 47 | parser.add_argument('--steps', type=int, default=None, 48 | help='Number of steps. Default is dataset-dependent.') 49 | parser.add_argument('--checkpoint_freq', type=int, default=None, 50 | help='Checkpoint every N steps. Default is dataset-dependent.') 51 | parser.add_argument('--test_envs', type=int, nargs='+', default=[0]) 52 | parser.add_argument('--output_dir', type=str, default="train_output") 53 | parser.add_argument('--holdout_fraction', type=float, default=0.2) 54 | parser.add_argument('--uda_holdout_fraction', type=float, default=0, 55 | help="For domain adaptation, % of test to use unlabeled for training.") 56 | parser.add_argument('--skip_model_save', action='store_true') 57 | parser.add_argument('--save_model_every_checkpoint', action='store_true') 58 | args = parser.parse_args() 59 | 60 | # If we ever want to implement checkpointing, just persist these values 61 | # every once in a while, and then load them from disk here. 
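A minimal sketch of the resume logic this comment gestures at (hypothetical: the stock save_checkpoint() further down stores 'model_dict' but not the step, so a 'start_step' field would have to be added to the saved dict first); otherwise execution falls back to the fresh-start values assigned just below:

    resume_path = os.path.join(args.output_dir, 'model.pkl')
    if os.path.isfile(resume_path):
        ckpt = torch.load(resume_path)
        start_step = ckpt.get('start_step', 0)   # hypothetical extra field
        algorithm_dict = ckpt['model_dict']      # restored via algorithm.load_state_dict() below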
62 | start_step = 0 63 | algorithm_dict = None 64 | 65 | os.makedirs(args.output_dir, exist_ok=True) 66 | sys.stdout = misc.Tee(os.path.join(args.output_dir, 'out.txt')) 67 | sys.stderr = misc.Tee(os.path.join(args.output_dir, 'err.txt')) 68 | # if "Debug" not in args.dataset: 69 | # wandb.init(project="sparse-moe", 70 | # entity='drluodian', 71 | # config={'dataset': args.dataset, 72 | # 'algorithm': args.algorithm, 73 | # 'test_envs': args.test_envs}, 74 | # settings=wandb.Settings(start_method="fork")) 75 | # wandb.init(settings=wandb.Settings(start_method='thread')) 76 | # print("Environment:") 77 | # print("\tPython: {}".format(sys.version.split(" ")[0])) 78 | # print("\tPyTorch: {}".format(torch.__version__)) 79 | # print("\tTorchvision: {}".format(torchvision.__version__)) 80 | # print("\tCUDA: {}".format(torch.version.cuda)) 81 | # print("\tCUDNN: {}".format(torch.backends.cudnn.version())) 82 | # print("\tNumPy: {}".format(np.__version__)) 83 | # print("\tPIL: {}".format(PIL.__version__)) 84 | # 85 | # print('Args:') 86 | # for k, v in sorted(vars(args).items()): 87 | # print('\t{}: {}'.format(k, v)) 88 | 89 | if args.hparams_seed == 0: 90 | hparams = hparams_registry.default_hparams(args.algorithm, args.dataset) 91 | else: 92 | hparams = hparams_registry.random_hparams(args.algorithm, args.dataset, 93 | misc.seed_hash(args.hparams_seed, args.trial_seed)) 94 | if args.hparams: 95 | hparams.update(json.loads(args.hparams)) 96 | 97 | if args.batch_size is not None: 98 | hparams['batch_size'] = args.batch_size 99 | if args.drop_out is not None: 100 | hparams['drop_out'] = args.drop_out 101 | if args.lr is not None: 102 | hparams['lr'] = args.lr 103 | if args.weight_decay is not None: 104 | hparams['weight_decay'] = args.weight_decay 105 | 106 | # print('HParams:') 107 | # for k, v in sorted(hparams.items()): 108 | # print('\t{}: {}'.format(k, v)) 109 | 110 | random.seed(args.seed) 111 | np.random.seed(args.seed) 112 | torch.manual_seed(args.seed) 113 | torch.backends.cudnn.deterministic = True 114 | torch.backends.cudnn.benchmark = False 115 | 116 | if torch.cuda.is_available(): 117 | device = "cuda" 118 | else: 119 | device = "cpu" 120 | 121 | if args.dataset in vars(datasets): 122 | dataset = vars(datasets)[args.dataset](args.data_dir, 123 | args.test_envs, hparams) 124 | else: 125 | raise NotImplementedError 126 | 127 | # Split each env into an 'in-split' and an 'out-split'. We'll train on 128 | # each in-split except the test envs, and evaluate on all splits. 129 | 130 | # To allow unsupervised domain adaptation experiments, we split each test 131 | # env into 'in-split', 'uda-split' and 'out-split'. The 'in-split' is used 132 | # by collect_results.py to compute classification accuracies. The 133 | # 'out-split' is used by the Oracle model selectino method. The unlabeled 134 | # samples in 'uda-split' are passed to the algorithm at training time if 135 | # args.task == "domain_adaptation". If we are interested in comparing 136 | # domain generalization and domain adaptation results, then domain 137 | # generalization algorithms should create the same 'uda-splits', which will 138 | # be discared at training. 
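# Concretely: with the defaults above (holdout_fraction=0.2, uda_holdout_fraction=0) an
# environment of 1,000 examples is split below into out = 200 and in_ = 800; for a test env
# run with --uda_holdout_fraction 0.5, that in-split would be further divided into
# uda = 400 and in_ = 400 (misc.split_dataset places the requested count in its first
# return value, as in upstream DomainBed).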
139 | in_splits = [] 140 | out_splits = [] 141 | uda_splits = [] 142 | for env_i, env in enumerate(dataset): 143 | uda = [] 144 | 145 | out, in_ = misc.split_dataset(env, int(len(env) * args.holdout_fraction), misc.seed_hash(args.trial_seed, env_i)) 146 | if env_i in args.test_envs: 147 | uda, in_ = misc.split_dataset(in_, int(len(in_) * args.uda_holdout_fraction), misc.seed_hash(args.trial_seed, env_i)) 148 | 149 | if hparams['class_balanced']: 150 | in_weights = misc.make_weights_for_balanced_classes(in_) 151 | out_weights = misc.make_weights_for_balanced_classes(out) 152 | if uda is not None: 153 | uda_weights = misc.make_weights_for_balanced_classes(uda) 154 | else: 155 | in_weights, out_weights, uda_weights = None, None, None 156 | in_splits.append((in_, in_weights)) 157 | out_splits.append((out, out_weights)) 158 | if len(uda): 159 | uda_splits.append((uda, uda_weights)) 160 | 161 | if args.task == "domain_adaptation" and len(uda_splits) == 0: 162 | raise ValueError("Not enough unlabeled samples for domain adaptation.") 163 | 164 | train_loaders = [InfiniteDataLoader( 165 | dataset=env, 166 | weights=env_weights, 167 | batch_size=hparams['batch_size'], 168 | num_workers=dataset.N_WORKERS) 169 | for i, (env, env_weights) in enumerate(in_splits) 170 | if i not in args.test_envs] 171 | 172 | uda_loaders = [InfiniteDataLoader( 173 | dataset=env, 174 | weights=env_weights, 175 | batch_size=hparams['batch_size'], 176 | num_workers=dataset.N_WORKERS) 177 | for i, (env, env_weights) in enumerate(uda_splits) 178 | if i in args.test_envs] 179 | 180 | eval_loaders = [FastDataLoader( 181 | dataset=env, 182 | batch_size=64, 183 | num_workers=dataset.N_WORKERS) 184 | for env, _ in (in_splits + out_splits + uda_splits)] 185 | eval_weights = [None for _, weights in (in_splits + out_splits + uda_splits)] 186 | eval_loader_names = ['env{}_in'.format(i) 187 | for i in range(len(in_splits))] 188 | eval_loader_names += ['env{}_out'.format(i) 189 | for i in range(len(out_splits))] 190 | eval_loader_names += ['env{}_uda'.format(i) 191 | for i in range(len(uda_splits))] 192 | 193 | algorithm_class = algorithms.get_algorithm_class(args.algorithm) 194 | algorithm = algorithm_class(dataset.input_shape, dataset.num_classes, 195 | len(dataset) - len(args.test_envs), hparams) 196 | 197 | if algorithm_dict is not None: 198 | algorithm.load_state_dict(algorithm_dict) 199 | 200 | algorithm.to(device) 201 | 202 | train_minibatches_iterator = zip(*train_loaders) 203 | uda_minibatches_iterator = zip(*uda_loaders) 204 | checkpoint_vals = collections.defaultdict(lambda: []) 205 | 206 | steps_per_epoch = min([len(env) / hparams['batch_size'] for env, _ in in_splits]) 207 | 208 | n_steps = args.steps or dataset.N_STEPS 209 | checkpoint_freq = args.checkpoint_freq or dataset.CHECKPOINT_FREQ 210 | 211 | 212 | def save_checkpoint(filename): 213 | if args.skip_model_save: 214 | return 215 | save_dict = { 216 | "args": vars(args), 217 | "model_input_shape": dataset.input_shape, 218 | "model_num_classes": dataset.num_classes, 219 | "model_num_domains": len(dataset) - len(args.test_envs), 220 | "model_hparams": hparams, 221 | "model_dict": algorithm.state_dict() 222 | } 223 | torch.save(save_dict, os.path.join(args.output_dir, filename)) 224 | 225 | 226 | last_results_keys = None 227 | for step in range(start_step, n_steps): 228 | step_start_time = time.time() 229 | minibatches_device = [(x.to(device), y.to(device)) for x, y in next(train_minibatches_iterator)] 230 | if args.task == "domain_adaptation": 231 | uda_device = 
[x.to(device) for x, _ in next(uda_minibatches_iterator)] 232 | else: 233 | uda_device = None 234 | step_vals = algorithm.update(minibatches_device) 235 | checkpoint_vals['step_time'].append(time.time() - step_start_time) 236 | 237 | for key, val in step_vals.items(): 238 | checkpoint_vals[key].append(val) 239 | 240 | if (step % checkpoint_freq == 0) or (step == n_steps - 1): 241 | results = { 242 | 'step': step, 243 | 'epoch': step / steps_per_epoch, 244 | } 245 | 246 | for key, val in checkpoint_vals.items(): 247 | results[key] = np.mean(val) 248 | 249 | evals = zip(eval_loader_names, eval_loaders, eval_weights) 250 | for name, loader, weights in evals: 251 | acc = misc.accuracy(algorithm, loader, weights, device) 252 | results[name + '_acc'] = acc 253 | 254 | results['algorithm'] = args.algorithm 255 | results['dataset'] = args.dataset 256 | results['test_envs'] = args.test_envs 257 | results['mem_gb'] = torch.cuda.max_memory_allocated() / (1024. * 1024. * 1024.) 258 | 259 | results_keys = sorted(results.keys()) 260 | if results_keys != last_results_keys: 261 | misc.print_row(results_keys, colwidth=12) 262 | last_results_keys = results_keys 263 | misc.print_row([results[key] for key in results_keys], 264 | colwidth=12) 265 | 266 | # if wandb.run: 267 | # wandb.log(results) 268 | results.update({ 269 | 'hparams': hparams, 270 | 'args': vars(args) 271 | }) 272 | 273 | epochs_path = os.path.join(args.output_dir, 'results.jsonl') 274 | with open(epochs_path, 'a') as f: 275 | f.write(json.dumps(results, sort_keys=True) + "\n") 276 | 277 | algorithm_dict = algorithm.state_dict() 278 | start_step = step + 1 279 | checkpoint_vals = collections.defaultdict(lambda: []) 280 | 281 | if args.save_model_every_checkpoint: 282 | save_checkpoint(f'model_step{step}.pkl') 283 | 284 | save_checkpoint('model.pkl') 285 | 286 | with open(os.path.join(args.output_dir, 'done'), 'w') as f: 287 | f.write('done') 288 | -------------------------------------------------------------------------------- /domainbed/datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | 3 | import os 4 | 5 | import torch 6 | import torchvision.datasets.folder 7 | from PIL import Image, ImageFile 8 | from torch.utils.data import TensorDataset 9 | from torchvision import transforms 10 | from torchvision.datasets import MNIST, ImageFolder 11 | from torchvision.transforms.functional import rotate 12 | from wilds.datasets.camelyon17_dataset import Camelyon17Dataset 13 | from wilds.datasets.fmow_dataset import FMoWDataset 14 | 15 | ImageFile.LOAD_TRUNCATED_IMAGES = True 16 | 17 | DATASETS = [ 18 | # Debug 19 | "Debug28", 20 | "Debug224", 21 | # Small images 22 | "ColoredMNIST", 23 | "RotatedMNIST", 24 | # Big images 25 | "CUB", 26 | "VLCS", 27 | "PACS", 28 | "OfficeHome", 29 | "TerraIncognita", 30 | "DomainNet", 31 | "SVIRO", 32 | # WILDS datasets 33 | "WILDSCamelyon", 34 | "WILDSFMoW" 35 | ] 36 | 37 | 38 | def get_dataset_class(dataset_name): 39 | """Return the dataset class with the given name.""" 40 | if dataset_name not in globals(): 41 | raise NotImplementedError("Dataset not found: {}".format(dataset_name)) 42 | return globals()[dataset_name] 43 | 44 | 45 | def num_environments(dataset_name): 46 | return len(get_dataset_class(dataset_name).ENVIRONMENTS) 47 | 48 | 49 | class MultipleDomainDataset: 50 | N_STEPS = 5001 # Default, subclasses may override 51 | CHECKPOINT_FREQ = 100 # Default, subclasses may override 52 | N_WORKERS = 8 # Default, subclasses may override 53 | ENVIRONMENTS = None # Subclasses should override 54 | INPUT_SHAPE = None # Subclasses should override 55 | 56 | def __getitem__(self, index): 57 | return self.datasets[index] 58 | 59 | def __len__(self): 60 | return len(self.datasets) 61 | 62 | 63 | class Debug(MultipleDomainDataset): 64 | def __init__(self, root, test_envs, hparams): 65 | super().__init__() 66 | self.input_shape = self.INPUT_SHAPE 67 | self.num_classes = 2 68 | self.datasets = [] 69 | for _ in [0, 1, 2]: 70 | self.datasets.append( 71 | TensorDataset( 72 | torch.randn(16, *self.INPUT_SHAPE), 73 | torch.randint(0, self.num_classes, (16,)) 74 | ) 75 | ) 76 | 77 | 78 | class Debug28(Debug): 79 | N_WORKERS = 0 80 | INPUT_SHAPE = (3, 28, 28) 81 | ENVIRONMENTS = ['0', '1', '2'] 82 | 83 | 84 | class Debug224(Debug): 85 | N_WORKERS = 0 86 | INPUT_SHAPE = (3, 224, 224) 87 | ENVIRONMENTS = ['0', '1', '2'] 88 | 89 | 90 | class MultipleEnvironmentMNIST(MultipleDomainDataset): 91 | def __init__(self, root, environments, dataset_transform, input_shape, 92 | num_classes): 93 | super().__init__() 94 | if root is None: 95 | raise ValueError('Data directory not specified!') 96 | 97 | original_dataset_tr = MNIST(root, train=True, download=True) 98 | original_dataset_te = MNIST(root, train=False, download=True) 99 | 100 | original_images = torch.cat((original_dataset_tr.data, 101 | original_dataset_te.data)) 102 | 103 | original_labels = torch.cat((original_dataset_tr.targets, 104 | original_dataset_te.targets)) 105 | 106 | shuffle = torch.randperm(len(original_images)) 107 | 108 | original_images = original_images[shuffle] 109 | original_labels = original_labels[shuffle] 110 | 111 | self.datasets = [] 112 | 113 | for i in range(len(environments)): 114 | images = original_images[i::len(environments)] 115 | labels = original_labels[i::len(environments)] 116 | self.datasets.append(dataset_transform(images, labels, environments[i])) 117 | 118 | self.input_shape = input_shape 119 | self.num_classes = num_classes 120 | 121 | 122 | class ColoredMNIST(MultipleEnvironmentMNIST): 123 | ENVIRONMENTS = ['+90%', '+80%', '-90%'] 124 | 125 | def 
__init__(self, root, test_envs, hparams): 126 | super(ColoredMNIST, self).__init__(root, [0.1, 0.2, 0.9], 127 | self.color_dataset, (2, 28, 28,), 2) 128 | 129 | self.input_shape = (2, 28, 28,) 130 | self.num_classes = 2 131 | 132 | def color_dataset(self, images, labels, environment): 133 | # # Subsample 2x for computational convenience 134 | # images = images.reshape((-1, 28, 28))[:, ::2, ::2] 135 | # Assign a binary label based on the digit 136 | labels = (labels < 5).float() 137 | # Flip label with probability 0.25 138 | labels = self.torch_xor_(labels, 139 | self.torch_bernoulli_(0.25, len(labels))) 140 | 141 | # Assign a color based on the label; flip the color with probability e 142 | colors = self.torch_xor_(labels, 143 | self.torch_bernoulli_(environment, 144 | len(labels))) 145 | images = torch.stack([images, images], dim=1) 146 | # Apply the color to the image by zeroing out the other color channel 147 | images[torch.tensor(range(len(images))), ( 148 | 1 - colors).long(), :, :] *= 0 149 | 150 | x = images.float().div_(255.0) 151 | y = labels.view(-1).long() 152 | 153 | return TensorDataset(x, y) 154 | 155 | def torch_bernoulli_(self, p, size): 156 | return (torch.rand(size) < p).float() 157 | 158 | def torch_xor_(self, a, b): 159 | return (a - b).abs() 160 | 161 | 162 | class RotatedMNIST(MultipleEnvironmentMNIST): 163 | ENVIRONMENTS = ['0', '15', '30', '45', '60', '75'] 164 | 165 | def __init__(self, root, test_envs, hparams): 166 | super(RotatedMNIST, self).__init__(root, [0, 15, 30, 45, 60, 75], 167 | self.rotate_dataset, (1, 28, 28,), 10) 168 | 169 | def rotate_dataset(self, images, labels, angle): 170 | rotation = transforms.Compose([ 171 | transforms.ToPILImage(), 172 | transforms.Lambda(lambda x: rotate(x, angle, fill=(0,), 173 | interpolation=torchvision.transforms.InterpolationMode.BILINEAR)), 174 | transforms.ToTensor()]) 175 | 176 | x = torch.zeros(len(images), 1, 28, 28) 177 | for i in range(len(images)): 178 | x[i] = rotation(images[i]) 179 | 180 | y = labels.view(-1) 181 | 182 | return TensorDataset(x, y) 183 | 184 | 185 | class MultipleEnvironmentImageFolder(MultipleDomainDataset): 186 | def __init__(self, root, test_envs, augment, hparams): 187 | super().__init__() 188 | environments = [f.name for f in os.scandir(root) if f.is_dir()] 189 | environments = sorted(environments) 190 | 191 | transform = transforms.Compose([ 192 | transforms.Resize((224, 224)), 193 | transforms.ToTensor(), 194 | transforms.Normalize( 195 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 196 | ]) 197 | 198 | augment_transform = transforms.Compose([ 199 | # transforms.Resize((224,224)), 200 | transforms.RandomResizedCrop(224, scale=(0.7, 1.0)), 201 | transforms.RandomHorizontalFlip(), 202 | transforms.ColorJitter(0.3, 0.3, 0.3, 0.3), 203 | transforms.RandomGrayscale(p=0.1), 204 | transforms.ToTensor(), 205 | transforms.Normalize( 206 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 207 | ]) 208 | 209 | self.datasets = [] 210 | for i, environment in enumerate(environments): 211 | 212 | if augment and (i not in test_envs): 213 | env_transform = augment_transform 214 | else: 215 | env_transform = transform 216 | 217 | path = os.path.join(root, environment) 218 | env_dataset = ImageFolder(path, transform=env_transform) 219 | 220 | self.datasets.append(env_dataset) 221 | 222 | self.input_shape = (3, 224, 224,) 223 | self.num_classes = len(self.datasets[-1].classes) 224 | 225 | 226 | class CUB(MultipleEnvironmentImageFolder): 227 | CHECKPOINT_FREQ = 300 228 | ENVIRONMENTS = 
["Candy", "Mosaic", "Natural", "Udnie"] 229 | 230 | def __init__(self, root, test_envs, hparams): 231 | self.dir = os.path.join(root, "CUB_DG/") 232 | super().__init__(self.dir, test_envs, hparams['data_augmentation'], hparams) 233 | 234 | 235 | class VLCS(MultipleEnvironmentImageFolder): 236 | CHECKPOINT_FREQ = 300 237 | ENVIRONMENTS = ["C", "L", "S", "V"] 238 | 239 | def __init__(self, root, test_envs, hparams): 240 | self.dir = os.path.join(root, "VLCS/") 241 | super().__init__(self.dir, test_envs, hparams['data_augmentation'], hparams) 242 | 243 | 244 | class PACS(MultipleEnvironmentImageFolder): 245 | CHECKPOINT_FREQ = 300 246 | ENVIRONMENTS = ["A", "C", "P", "S"] 247 | 248 | def __init__(self, root, test_envs, hparams): 249 | self.dir = os.path.join(root, "PACS/") 250 | super().__init__(self.dir, test_envs, hparams['data_augmentation'], hparams) 251 | 252 | 253 | class DomainNet(MultipleEnvironmentImageFolder): 254 | CHECKPOINT_FREQ = 500 255 | N_STEPS = 15001 256 | ENVIRONMENTS = ["clip", "info", "paint", "quick", "real", "sketch"] 257 | 258 | def __init__(self, root, test_envs, hparams): 259 | self.dir = os.path.join(root, "domain_net/") 260 | super().__init__(self.dir, test_envs, hparams['data_augmentation'], hparams) 261 | 262 | 263 | class OfficeHome(MultipleEnvironmentImageFolder): 264 | CHECKPOINT_FREQ = 300 265 | ENVIRONMENTS = ["A", "C", "P", "R"] 266 | 267 | def __init__(self, root, test_envs, hparams): 268 | self.dir = os.path.join(root, "office_home/") 269 | super().__init__(self.dir, test_envs, hparams['data_augmentation'], hparams) 270 | 271 | 272 | class TerraIncognita(MultipleEnvironmentImageFolder): 273 | # may need larger weight decay 274 | CHECKPOINT_FREQ = 300 275 | ENVIRONMENTS = ["L100", "L38", "L43", "L46"] 276 | 277 | def __init__(self, root, test_envs, hparams): 278 | self.dir = os.path.join(root, "terra_incognita/") 279 | super().__init__(self.dir, test_envs, hparams['data_augmentation'], hparams) 280 | 281 | 282 | class SVIRO(MultipleEnvironmentImageFolder): 283 | CHECKPOINT_FREQ = 300 284 | ENVIRONMENTS = ["aclass", "escape", "hilux", "i3", "lexus", "tesla", "tiguan", "tucson", "x5", "zoe"] 285 | 286 | def __init__(self, root, test_envs, hparams): 287 | self.dir = os.path.join(root, "sviro/") 288 | super().__init__(self.dir, test_envs, hparams['data_augmentation'], hparams) 289 | 290 | 291 | class WILDSEnvironment: 292 | def __init__( 293 | self, 294 | wilds_dataset, 295 | metadata_name, 296 | metadata_value, 297 | transform=None): 298 | self.name = metadata_name + "_" + str(metadata_value) 299 | 300 | metadata_index = wilds_dataset.metadata_fields.index(metadata_name) 301 | metadata_array = wilds_dataset.metadata_array 302 | subset_indices = torch.where( 303 | metadata_array[:, metadata_index] == metadata_value)[0] 304 | 305 | self.dataset = wilds_dataset 306 | self.indices = subset_indices 307 | self.transform = transform 308 | 309 | def __getitem__(self, i): 310 | x = self.dataset.get_input(self.indices[i]) 311 | if type(x).__name__ != "Image": 312 | x = Image.fromarray(x) 313 | 314 | y = self.dataset.y_array[self.indices[i]] 315 | if self.transform is not None: 316 | x = self.transform(x) 317 | return x, y 318 | 319 | def __len__(self): 320 | return len(self.indices) 321 | 322 | 323 | class WILDSDataset(MultipleDomainDataset): 324 | INPUT_SHAPE = (3, 224, 224) 325 | 326 | def __init__(self, dataset, metadata_name, test_envs, augment, hparams): 327 | super().__init__() 328 | 329 | transform = transforms.Compose([ 330 | transforms.Resize((224, 224)), 331 
| transforms.ToTensor(), 332 | transforms.Normalize( 333 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 334 | ]) 335 | 336 | augment_transform = transforms.Compose([ 337 | transforms.Resize((224, 224)), 338 | transforms.RandomResizedCrop(224, scale=(0.7, 1.0)), 339 | transforms.RandomHorizontalFlip(), 340 | transforms.ColorJitter(0.3, 0.3, 0.3, 0.3), 341 | transforms.RandomGrayscale(), 342 | transforms.ToTensor(), 343 | transforms.Normalize( 344 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 345 | ]) 346 | 347 | self.datasets = [] 348 | 349 | for i, metadata_value in enumerate( 350 | self.metadata_values(dataset, metadata_name)): 351 | if augment and (i not in test_envs): 352 | env_transform = augment_transform 353 | else: 354 | env_transform = transform 355 | 356 | env_dataset = WILDSEnvironment( 357 | dataset, metadata_name, metadata_value, env_transform) 358 | 359 | self.datasets.append(env_dataset) 360 | 361 | self.input_shape = (3, 224, 224,) 362 | self.num_classes = dataset.n_classes 363 | 364 | def metadata_values(self, wilds_dataset, metadata_name): 365 | metadata_index = wilds_dataset.metadata_fields.index(metadata_name) 366 | metadata_vals = wilds_dataset.metadata_array[:, metadata_index] 367 | return sorted(list(set(metadata_vals.view(-1).tolist()))) 368 | 369 | 370 | class WILDSCamelyon(WILDSDataset): 371 | ENVIRONMENTS = ["hospital_0", "hospital_1", "hospital_2", "hospital_3", 372 | "hospital_4"] 373 | 374 | def __init__(self, root, test_envs, hparams): 375 | dataset = Camelyon17Dataset(root_dir=root) 376 | super().__init__( 377 | dataset, "hospital", test_envs, hparams['data_augmentation'], hparams) 378 | 379 | 380 | class WILDSFMoW(WILDSDataset): 381 | ENVIRONMENTS = ["region_0", "region_1", "region_2", "region_3", 382 | "region_4", "region_5"] 383 | 384 | def __init__(self, root, test_envs, hparams): 385 | dataset = FMoWDataset(root_dir=root) 386 | super().__init__( 387 | dataset, "region", test_envs, hparams['data_augmentation'], hparams) 388 | -------------------------------------------------------------------------------- /domainbed/algorithms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | 3 | import sys 4 | from itertools import chain 5 | 6 | import timm 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import wandb 11 | import torch.autograd as autograd 12 | from domainbed.lib.misc import ( 13 | random_pairs_of_minibatches, ParamDict, MovingAverage, l2_between_dicts 14 | ) 15 | from copy import deepcopy 16 | import copy 17 | 18 | sys.path.append('/mnt/lustre/bli/projects/EIL/domainbed') 19 | import vision_transformer, vision_transformer_hybrid 20 | from collections import defaultdict, OrderedDict 21 | 22 | try: 23 | from backpack import backpack, extend 24 | from backpack.extensions import BatchGrad 25 | except: 26 | backpack = None 27 | 28 | from domainbed import networks 29 | from domainbed import resnet_variants 30 | import torchvision.models as models 31 | 32 | ALGORITHMS = [ 33 | 'ERM', 34 | 'Fish', 35 | 'IRM', 36 | 'GroupDRO', 37 | 'Mixup', 38 | 'MLDG', 39 | 'CORAL', 40 | 'MMD', 41 | 'DANN', 42 | 'CDANN', 43 | 'MTL', 44 | 'SagNet', 45 | 'ARM', 46 | 'VREx', 47 | 'RSC', 48 | 'SD', 49 | 'ANDMask', 50 | 'SANDMask', 51 | 'IGA', 52 | 'SelfReg', 53 | "Fishr", 54 | 'TRM', 55 | 'IB_ERM', 56 | 'IB_IRM', 57 | 'CAD', 58 | 'CondCAD', 59 | 'GMOE' 60 | ] 61 | 62 | 63 | def get_algorithm_class(algorithm_name): 64 | """Return the algorithm class with the given name.""" 65 | if algorithm_name not in globals(): 66 | raise NotImplementedError("Algorithm not found: {}".format(algorithm_name)) 67 | return globals()[algorithm_name] 68 | 69 | 70 | class Algorithm(torch.nn.Module): 71 | """ 72 | A subclass of Algorithm implements a domain generalization algorithm. 73 | Subclasses should implement the following: 74 | - update() 75 | - predict() 76 | """ 77 | transforms = {} 78 | 79 | def __init__(self, input_shape, num_classes, num_domains, hparams): 80 | super(Algorithm, self).__init__() 81 | self.hparams = hparams 82 | 83 | def update(self, minibatches, unlabeled=None): 84 | """ 85 | Perform one update step, given a list of (x, y) tuples for all 86 | environments. 87 | 88 | Admits an optional list of unlabeled minibatches from the test domains, 89 | when task is domain_adaptation. 90 | """ 91 | raise NotImplementedError 92 | 93 | def predict(self, x): 94 | raise NotImplementedError 95 | 96 | 97 | class MovingAvg: 98 | def __init__(self, network): 99 | self.network = network 100 | self.network_sma = copy.deepcopy(network) 101 | self.network_sma.eval() 102 | self.sma_start_iter = 100 103 | self.global_iter = 0 104 | self.sma_count = 0 105 | 106 | def update_sma(self): 107 | self.global_iter += 1 108 | if self.global_iter >= self.sma_start_iter: 109 | self.sma_count += 1 110 | for param_q, param_k in zip(self.network.parameters(), self.network_sma.parameters()): 111 | param_k.data = (param_k.data * self.sma_count + param_q.data) / (1. 
+ self.sma_count) 112 | else: 113 | for param_q, param_k in zip(self.network.parameters(), self.network_sma.parameters()): 114 | param_k.data = param_q.data 115 | 116 | 117 | class ERM_SMA(Algorithm, MovingAvg): 118 | """ 119 | Empirical Risk Minimization (ERM) with Simple Moving Average (SMA) prediction model 120 | """ 121 | 122 | def __init__(self, input_shape, num_classes, num_domains, hparams): 123 | Algorithm.__init__(self, input_shape, num_classes, num_domains, hparams) 124 | self.featurizer = networks.Featurizer(input_shape, self.hparams) 125 | self.classifier = networks.Classifier( 126 | self.featurizer.n_outputs, 127 | num_classes, 128 | self.hparams['nonlinear_classifier']) 129 | self.network = nn.Sequential(self.featurizer, self.classifier) 130 | self.optimizer = torch.optim.Adam( 131 | self.network.parameters(), 132 | lr=self.hparams["lr"], 133 | weight_decay=self.hparams['weight_decay'] 134 | ) 135 | MovingAvg.__init__(self, self.network) 136 | 137 | def update(self, minibatches, unlabeled=None): 138 | all_x = torch.cat([x for x, y in minibatches]) 139 | all_y = torch.cat([y for x, y in minibatches]) 140 | loss = F.cross_entropy(self.network(all_x), all_y) 141 | self.optimizer.zero_grad() 142 | loss.backward() 143 | self.optimizer.step() 144 | self.update_sma() 145 | return {'loss': loss.item()} 146 | 147 | def predict(self, x): 148 | self.network_sma.eval() 149 | return self.network_sma(x) 150 | 151 | 152 | class ERM(Algorithm): 153 | """ 154 | Empirical Risk Minimization (ERM) 155 | """ 156 | 157 | def __init__(self, input_shape, num_classes, num_domains, hparams): 158 | super(ERM, self).__init__(input_shape, num_classes, num_domains, 159 | hparams) 160 | self.featurizer = networks.Featurizer(input_shape, self.hparams) 161 | self.classifier = networks.Classifier( 162 | self.featurizer.n_outputs, 163 | num_classes, 164 | self.hparams['nonlinear_classifier']) 165 | 166 | self.network = nn.Sequential(self.featurizer, self.classifier).cuda() 167 | self.optimizer = torch.optim.Adam( 168 | self.network.parameters(), 169 | lr=self.hparams["lr"], 170 | weight_decay=self.hparams['weight_decay'] 171 | ) 172 | 173 | def update(self, minibatches, unlabeled=None): 174 | all_x = torch.cat([x for x, y in minibatches]) 175 | all_y = torch.cat([y for x, y in minibatches]) 176 | loss = F.cross_entropy(self.predict(all_x), all_y) 177 | 178 | self.optimizer.zero_grad() 179 | loss.backward() 180 | self.optimizer.step() 181 | 182 | return {'loss': loss.item()} 183 | 184 | def predict(self, x): 185 | return self.network(x) 186 | 187 | 188 | class GMOE(Algorithm): 189 | """ 190 | SFMOE 191 | """ 192 | 193 | def __init__(self, input_shape, num_classes, num_domains, hparams): 194 | super(GMOE, self).__init__(input_shape, num_classes, num_domains, hparams) 195 | self.model = vision_transformer.deit_small_patch16_224(pretrained=True, num_classes=num_classes, moe_layers=['F'] * 8 + ['S', 'F'] * 2, mlp_ratio=4., num_experts=6, is_tutel=True, drop_path_rate=0.1, router='cosine_top').cuda() 196 | self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.hparams["lr"], weight_decay=self.hparams['weight_decay']) 197 | 198 | def update(self, minibatches, unlabeled=None): 199 | all_x = torch.cat([x for x, y in minibatches]) 200 | all_y = torch.cat([y for x, y in minibatches]) 201 | loss = F.cross_entropy(self.predict(all_x), all_y) 202 | loss_aux_list = [] 203 | for block in self.model.blocks: 204 | if getattr(block, 'aux_loss') is not None: 205 | loss_aux_list.append(block.aux_loss) 206 | 207 | 
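# Blocks that contain a sparse-MoE layer expose their router's auxiliary loss via
# block.aux_loss; the per-layer terms collected above are summed below and added,
# unweighted, to the cross-entropy task loss before the backward pass.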
    def update(self, minibatches, unlabeled=None):
        all_x = torch.cat([x for x, y in minibatches])
        all_y = torch.cat([y for x, y in minibatches])
        loss = F.cross_entropy(self.predict(all_x), all_y)
        loss_aux_list = []
        for block in self.model.blocks:
            # Default to None so plain (non-MoE) blocks are skipped instead of raising.
            if getattr(block, 'aux_loss', None) is not None:
                loss_aux_list.append(block.aux_loss)

        loss_aux = 0
        for layer_loss in loss_aux_list:
            loss_aux += layer_loss

        loss += loss_aux
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return {'loss': loss.item(), 'loss_aux': loss_aux.item()}

    def predict(self, x, forward_feature=False):
        if forward_feature:
            return self.model.forward_features(x)
        else:
            prediction = self.model(x)
            if isinstance(prediction, tuple):
                return (prediction[0] + prediction[1]) / 2
            else:
                return prediction


class Fish(Algorithm):
    """
    Implementation of Fish, as seen in Gradient Matching for Domain
    Generalization, Shi et al. 2021.
    """

    def __init__(self, input_shape, num_classes, num_domains, hparams):
        super(Fish, self).__init__(input_shape, num_classes, num_domains,
                                   hparams)
        self.input_shape = input_shape
        self.num_classes = num_classes

        self.network = networks.WholeFish(input_shape, num_classes, hparams)
        self.optimizer = torch.optim.Adam(
            self.network.parameters(),
            lr=self.hparams["lr"],
            weight_decay=self.hparams['weight_decay']
        )
        self.optimizer_inner_state = None

    def create_clone(self, device):
        self.network_inner = networks.WholeFish(self.input_shape, self.num_classes, self.hparams,
                                                weights=self.network.state_dict()).to(device)
        self.optimizer_inner = torch.optim.Adam(
            self.network_inner.parameters(),
            lr=self.hparams["lr"],
            weight_decay=self.hparams['weight_decay']
        )
        if self.optimizer_inner_state is not None:
            self.optimizer_inner.load_state_dict(self.optimizer_inner_state)

    def fish(self, meta_weights, inner_weights, lr_meta):
        meta_weights = ParamDict(meta_weights)
        inner_weights = ParamDict(inner_weights)
        meta_weights += lr_meta * (inner_weights - meta_weights)
        return meta_weights

    def update(self, minibatches, unlabeled=None):
        self.create_clone(minibatches[0][0].device)

        for x, y in minibatches:
            loss = F.cross_entropy(self.network_inner(x), y)
            self.optimizer_inner.zero_grad()
            loss.backward()
            self.optimizer_inner.step()

        self.optimizer_inner_state = self.optimizer_inner.state_dict()
        meta_weights = self.fish(
            meta_weights=self.network.state_dict(),
            inner_weights=self.network_inner.state_dict(),
            lr_meta=self.hparams["meta_lr"]
        )
        self.network.reset_weights(meta_weights)

        return {'loss': loss.item()}

    def predict(self, x):
        return self.network(x)


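# DANN trains a domain discriminator on top of the featurizer and plays a minimax
# game: `d_steps_per_g_step` discriminator updates (classifying which environment a
# feature came from, plus a gradient penalty on the discriminator input) alternate
# with one generator update, in which the featurizer and classifier minimize the task
# loss minus `lambda` times the discriminator loss. The conditional variant adds a
# class embedding to the features before they reach the discriminator.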
class AbstractDANN(Algorithm):
    """Domain-Adversarial Neural Networks (abstract class)"""

    def __init__(self, input_shape, num_classes, num_domains,
                 hparams, conditional, class_balance):

        super(AbstractDANN, self).__init__(input_shape, num_classes, num_domains,
                                           hparams)

        self.register_buffer('update_count', torch.tensor([0]))
        self.conditional = conditional
        self.class_balance = class_balance

        # Algorithms
        self.featurizer = networks.Featurizer(input_shape, self.hparams)
        self.classifier = networks.Classifier(
            self.featurizer.n_outputs,
            num_classes,
            self.hparams['nonlinear_classifier'])
        self.discriminator = networks.MLP(self.featurizer.n_outputs,
                                          num_domains, self.hparams)
        self.class_embeddings = nn.Embedding(num_classes,
                                             self.featurizer.n_outputs)

        # Optimizers
        self.disc_opt = torch.optim.Adam(
            (list(self.discriminator.parameters()) +
             list(self.class_embeddings.parameters())),
            lr=self.hparams["lr_d"],
            weight_decay=self.hparams['weight_decay_d'],
            betas=(self.hparams['beta1'], 0.9))

        self.gen_opt = torch.optim.Adam(
            (list(self.featurizer.parameters()) +
             list(self.classifier.parameters())),
            lr=self.hparams["lr_g"],
            weight_decay=self.hparams['weight_decay_g'],
            betas=(self.hparams['beta1'], 0.9))

    def update(self, minibatches, unlabeled=None):
        device = "cuda" if minibatches[0][0].is_cuda else "cpu"
        self.update_count += 1
        all_x = torch.cat([x for x, y in minibatches])
        all_y = torch.cat([y for x, y in minibatches])
        all_z = self.featurizer(all_x)
        if self.conditional:
            disc_input = all_z + self.class_embeddings(all_y)
        else:
            disc_input = all_z
        disc_out = self.discriminator(disc_input)
        disc_labels = torch.cat([
            torch.full((x.shape[0],), i, dtype=torch.int64, device=device)
            for i, (x, y) in enumerate(minibatches)
        ])

        if self.class_balance:
            y_counts = F.one_hot(all_y).sum(dim=0)
            weights = 1. / (y_counts[all_y] * y_counts.shape[0]).float()
            disc_loss = F.cross_entropy(disc_out, disc_labels, reduction='none')
            disc_loss = (weights * disc_loss).sum()
        else:
            disc_loss = F.cross_entropy(disc_out, disc_labels)

        disc_softmax = F.softmax(disc_out, dim=1)
        input_grad = autograd.grad(disc_softmax[:, disc_labels].sum(),
                                   [disc_input], create_graph=True)[0]
        grad_penalty = (input_grad ** 2).sum(dim=1).mean(dim=0)
        disc_loss += self.hparams['grad_penalty'] * grad_penalty

        d_steps_per_g = self.hparams['d_steps_per_g_step']
        if (self.update_count.item() % (1 + d_steps_per_g) < d_steps_per_g):
            self.disc_opt.zero_grad()
            disc_loss.backward()
            self.disc_opt.step()
            return {'disc_loss': disc_loss.item()}
        else:
            all_preds = self.classifier(all_z)
            classifier_loss = F.cross_entropy(all_preds, all_y)
            gen_loss = (classifier_loss +
                        (self.hparams['lambda'] * -disc_loss))
            self.disc_opt.zero_grad()
            self.gen_opt.zero_grad()
            gen_loss.backward()
            self.gen_opt.step()
            return {'gen_loss': gen_loss.item()}

    def predict(self, x):
        return self.classifier(self.featurizer(x))


class DANN(AbstractDANN):
    """Unconditional DANN"""

    def __init__(self, input_shape, num_classes, num_domains, hparams):
        super(DANN, self).__init__(input_shape, num_classes, num_domains,
                                   hparams, conditional=False, class_balance=False)


# NOTE: 'CDANN' is listed in ALGORITHMS, but the class below is commented out, so
# get_algorithm_class('CDANN') will raise NotImplementedError with this file as-is.
# class CDANN(AbstractDANN):
#     """Conditional DANN"""
#
#     def __init__(self, input_shape, num_classes, num_domains, hparams):
#         super(CDANN, self).__init__(input_shape, num_classes, num_domains,
#                                     hparams, conditional=True, class_balance=True)


class IRM(ERM):
    """Invariant Risk Minimization"""

    def __init__(self, input_shape, num_classes, num_domains, hparams):
        super(IRM, self).__init__(input_shape, num_classes, num_domains,
                                  hparams)
        self.register_buffer('update_count', torch.tensor([0]))

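    # The IRMv1 penalty: the cross-entropy is routed through a dummy scale factor
    # fixed at 1.0, and the penalty is the dot product of the gradients of that
    # scale computed on two halves of the minibatch. Because the halves are
    # independent, this product estimates the squared norm of the scale gradient
    # (the IRMv1 penalty) without the bias a single-batch squared norm would add.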
    @staticmethod
    def _irm_penalty(logits, y):
        device = "cuda" if logits[0][0].is_cuda else "cpu"
        scale = torch.tensor(1.).to(device).requires_grad_()
        loss_1 = F.cross_entropy(logits[::2] * scale, y[::2])
        loss_2 = F.cross_entropy(logits[1::2] * scale, y[1::2])
        grad_1 = autograd.grad(loss_1, [scale], create_graph=True)[0]
        grad_2 = autograd.grad(loss_2, [scale], create_graph=True)[0]
        result = torch.sum(grad_1 * grad_2)
        return result

    def update(self, minibatches, unlabeled=None):
        device = "cuda" if minibatches[0][0].is_cuda else "cpu"
        penalty_weight = (self.hparams['irm_lambda'] if self.update_count
                          >= self.hparams['irm_penalty_anneal_iters'] else
                          1.0)
        nll = 0.
        penalty = 0.

        all_x = torch.cat([x for x, y in minibatches])
        all_logits = self.network(all_x)
        all_logits_idx = 0
        for i, (x, y) in enumerate(minibatches):
            logits = all_logits[all_logits_idx:all_logits_idx + x.shape[0]]
            all_logits_idx += x.shape[0]
            nll += F.cross_entropy(logits, y)
            penalty += self._irm_penalty(logits, y)
        nll /= len(minibatches)
        penalty /= len(minibatches)
        loss = nll + (penalty_weight * penalty)

        if self.update_count == self.hparams['irm_penalty_anneal_iters']:
            # Reset Adam, because it doesn't like the sharp jump in gradient
            # magnitudes that happens at this step.
            self.optimizer = torch.optim.Adam(
                self.network.parameters(),
                lr=self.hparams["lr"],
                weight_decay=self.hparams['weight_decay'])

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.update_count += 1
        return {'loss': loss.item(), 'nll': nll.item(),
                'penalty': penalty.item()}


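# Fishr (Rame et al., 2022) matches the per-domain variances of per-sample classifier
# gradients. BackPACK's BatchGrad extension yields the individual gradients in a
# single backward pass; the per-domain variances are smoothed with an exponential
# moving average, and the penalty is the distance (via l2_between_dicts) between each
# domain's variance and the average across domains, averaged over domains.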
class Fishr(Algorithm):
    "Invariant Gradient Variances for Out-of-distribution Generalization"

    def __init__(self, input_shape, num_classes, num_domains, hparams):
        assert backpack is not None, "Install backpack with: 'pip install backpack-for-pytorch==1.3.0'"
        super(Fishr, self).__init__(input_shape, num_classes, num_domains, hparams)
        self.num_domains = num_domains

        self.featurizer = networks.Featurizer(input_shape, self.hparams)
        self.classifier = extend(
            networks.Classifier(
                self.featurizer.n_outputs,
                num_classes,
                self.hparams['nonlinear_classifier'],
            )
        )
        self.network = nn.Sequential(self.featurizer, self.classifier)

        self.register_buffer("update_count", torch.tensor([0]))
        self.bce_extended = extend(nn.CrossEntropyLoss(reduction='none'))
        self.ema_per_domain = [
            MovingAverage(ema=self.hparams["ema"], oneminusema_correction=True)
            for _ in range(self.num_domains)
        ]
        self._init_optimizer()

    def _init_optimizer(self):
        self.optimizer = torch.optim.Adam(
            list(self.featurizer.parameters()) + list(self.classifier.parameters()),
            lr=self.hparams["lr"],
            weight_decay=self.hparams["weight_decay"],
        )

    def update(self, minibatches, unlabeled=None):
        assert len(minibatches) == self.num_domains
        all_x = torch.cat([x for x, y in minibatches])
        all_y = torch.cat([y for x, y in minibatches])
        len_minibatches = [x.shape[0] for x, y in minibatches]

        all_z = self.featurizer(all_x)
        all_logits = self.classifier(all_z)

        penalty = self.compute_fishr_penalty(all_logits, all_y, len_minibatches)
        all_nll = F.cross_entropy(all_logits, all_y)

        penalty_weight = 0
        if self.update_count >= self.hparams["penalty_anneal_iters"]:
            penalty_weight = self.hparams["lambda"]
            if self.update_count == self.hparams["penalty_anneal_iters"] != 0:
                # Reset Adam as in IRM or V-REx, because it may not like the sharp jump in
                # gradient magnitudes that happens at this step.
                self._init_optimizer()
        self.update_count += 1

        objective = all_nll + penalty_weight * penalty
        self.optimizer.zero_grad()
        objective.backward()
        self.optimizer.step()

        return {'loss': objective.item(), 'nll': all_nll.item(), 'penalty': penalty.item()}

    def compute_fishr_penalty(self, all_logits, all_y, len_minibatches):
        dict_grads = self._get_grads(all_logits, all_y)
        grads_var_per_domain = self._get_grads_var_per_domain(dict_grads, len_minibatches)
        return self._compute_distance_grads_var(grads_var_per_domain)

    def _get_grads(self, logits, y):
        self.optimizer.zero_grad()
        loss = self.bce_extended(logits, y).sum()
        with backpack(BatchGrad()):
            loss.backward(
                inputs=list(self.classifier.parameters()), retain_graph=True, create_graph=True
            )

        # Compute individual gradients for all samples across all domains simultaneously.
        dict_grads = OrderedDict(
            [
                (name, weights.grad_batch.clone().view(weights.grad_batch.size(0), -1))
                for name, weights in self.classifier.named_parameters()
            ]
        )
        return dict_grads

    def _get_grads_var_per_domain(self, dict_grads, len_minibatches):
        # Per-domain variance of the per-sample gradients.
        grads_var_per_domain = [{} for _ in range(self.num_domains)]
        for name, _grads in dict_grads.items():
            all_idx = 0
            for domain_id, bsize in enumerate(len_minibatches):
                env_grads = _grads[all_idx:all_idx + bsize]
                all_idx += bsize
                env_mean = env_grads.mean(dim=0, keepdim=True)
                env_grads_centered = env_grads - env_mean
                grads_var_per_domain[domain_id][name] = (env_grads_centered).pow(2).mean(dim=0)

        # Smooth the variances with a per-domain exponential moving average.
        for domain_id in range(self.num_domains):
            grads_var_per_domain[domain_id] = self.ema_per_domain[domain_id].update(
                grads_var_per_domain[domain_id]
            )

        return grads_var_per_domain

    def _compute_distance_grads_var(self, grads_var_per_domain):
        # Gradient variances averaged across domains.
        grads_var = OrderedDict(
            [
                (
                    name,
                    torch.stack(
                        [
                            grads_var_per_domain[domain_id][name]
                            for domain_id in range(self.num_domains)
                        ],
                        dim=0
                    ).mean(dim=0)
                )
                for name in grads_var_per_domain[0].keys()
            ]
        )

        penalty = 0
        for domain_id in range(self.num_domains):
            penalty += l2_between_dicts(grads_var_per_domain[domain_id], grads_var)
        return penalty / self.num_domains

    def predict(self, x):
        return self.network(x)
--------------------------------------------------------------------------------