├── LICENSE ├── README.md ├── algorithms.py ├── command_launchers.py ├── datasets.py ├── hparams_registry.py ├── lib ├── augmentations.py ├── fast_data_loader.py ├── misc.py ├── query.py ├── reporting.py └── wide_resnet.py ├── model_selection.py ├── networks.py ├── requirements.txt ├── scripts ├── __init__.py ├── __pycache__ │ ├── download.cpython-36.pyc │ └── download.cpython-37.pyc ├── collect_results.py ├── download.py ├── list_top_hparams.py ├── models.py ├── save_images.py └── utils.py ├── submit.py ├── sweep.py ├── test ├── __init__.py ├── helpers.py ├── lib │ ├── __init__.py │ ├── test_misc.py │ └── test_query.py ├── scripts │ ├── __init__.py │ ├── test_collect_results.py │ ├── test_sweep.py │ └── test_train.py ├── test_datasets.py ├── test_hparams_registry.py ├── test_model_selection.py ├── test_models.py ├── test_networks.py └── visual │ ├── mix.py │ ├── show1by1.py │ ├── show_smooth.py │ ├── show_smooth_v.py │ ├── swap.py │ └── swap_augmix.py ├── train.py └── wilds ├── __init__.py ├── common ├── grouper.py └── utils.py └── datasets ├── camelyon17_dataset.py └── wilds_dataset.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Hanlin Zhang, Yi-Fan Zhang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Towards Principled Disentanglement for Domain Generalization, CVPR, 2022 (Oral) 2 | 3 | [![made-with-python](https://img.shields.io/badge/Made%20with-Python-red.svg)](#python) 4 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) 5 | 6 | DDG is a PyTorch implementation of [Towards Principled Disentanglement for Domain Generalization](https://arxiv.org/abs/2111.13839) based on [DomainBed](https://github.com/facebookresearch/DomainBed). 7 | ## Available datasets 8 | 9 | The [currently available datasets](datasets.py) are: 10 | 11 | * RotatedMNIST ([Ghifary et al., 2015](https://arxiv.org/abs/1508.07680)) 12 | * VLCS ([Fang et al., 2013](https://openaccess.thecvf.com/content_iccv_2013/papers/Fang_Unbiased_Metric_Learning_2013_ICCV_paper.pdf)) 13 | * PACS ([Li et al., 2017](https://arxiv.org/abs/1710.03077)) 14 | * WILDS ([Koh et al., 2020](https://arxiv.org/abs/2012.07421)) Camelyon17 ([Bandi et al., 2019](https://pubmed.ncbi.nlm.nih.gov/30716025/)) about tumor detection in tissues 15 | 16 | Send us a PR to add your dataset! Any custom image dataset with folder structure `dataset/domain/class/image.xyz` is readily usable. 
While we include some datasets from the [WILDS project](https://wilds.stanford.edu/), please use their [official code](https://github.com/p-lambda/wilds/) if you wish to participate in their leaderboard. 17 | 18 | ## Available model selection criteria 19 | 20 | [Model selection criteria](model_selection.py) differ in what data is used to choose the best hyper-parameters for a given model: 21 | 22 | * `IIDAccuracySelectionMethod`: A random subset from the data of the training domains. 23 | * `LeaveOneOutSelectionMethod`: A random subset from the data of a held-out (not training, not testing) domain. 24 | * `OracleSelectionMethod`: A random subset from the data of the test domain. 25 | 26 | ## Quick start 27 | 28 | Download the datasets: 29 | 30 | ```bash 31 | python scripts/download.py \ 32 | --data-dir /my/datasets/path 33 | ``` 34 | 35 | Train a model: 36 | 37 | ```bash 38 | python train.py \ 39 | --data-dir /my/datasets/path \ 40 | --algorithm ERM \ 41 | --dataset RotatedMNIST 42 | ``` 43 | 44 | Pretrain the decoder of the DDG model: 45 | 46 | ```bash 47 | python train.py \ 48 | --data-dir /my/datasets/path \ 49 | --algorithm DDG \ 50 | --dataset PACS \ 51 | --stage 0 52 | ``` 53 | 54 | Train the DDG model with the pretrained decoder: 55 | 56 | ```bash 57 | python train.py \ 58 | --data-dir /my/datasets/path \ 59 | --algorithm DDG \ 60 | --gen-dir /my/models/model.pkl \ 61 | --dataset PACS \ 62 | --stage 1 63 | ``` 64 | 65 | ### Citation 66 | If you find this repo useful, please consider citing: 67 | ``` 68 | @inproceedings{zhang2022DDG, 69 | title={Towards principled disentanglement for domain generalization}, 70 | author={Zhang, Hanlin and Zhang, Yi-Fan and Liu, Weiyang and Weller, Adrian and Sch{\"o}lkopf, Bernhard and Xing, Eric P}, 71 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 72 | pages={8024--8034}, 73 | year={2022} 74 | } 75 | ``` 76 | --------------------------------------------------------------------------------
# --------------------------------------------------------------------------------
# /command_launchers.py
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

"""
A command launcher launches a list of commands on a cluster; implement your own
launcher to add support for your cluster. We've provided an example launcher
which runs all commands serially on the local machine.
"""

import subprocess
import time
import torch


def local_launcher(commands):
    """Launch commands serially on the local machine."""
    for cmd in commands:
        subprocess.call(cmd, shell=True)


def dummy_launcher(commands):
    """
    Doesn't run anything; instead, prints each command.
    Useful for testing.
    """
    for cmd in commands:
        print(f'Dummy launcher: {cmd}')


def multi_gpu_launcher(commands):
    """
    Launch commands on the local machine, using all GPUs in parallel.

    Each visible GPU runs at most one command at a time; a command is pinned
    to its GPU by prefixing it with ``CUDA_VISIBLE_DEVICES=<idx>``. Returns
    after every launched command has exited. The caller's ``commands`` list
    is not modified.

    Raises:
        RuntimeError: if ``torch.cuda.device_count()`` is 0 (the scheduling
            loop below could otherwise never launch anything and would spin
            forever).
    """
    print('WARNING: using experimental multi_gpu_launcher.')
    n_gpus = torch.cuda.device_count()
    if n_gpus == 0:
        raise RuntimeError(
            'multi_gpu_launcher requires at least one visible CUDA device.')

    # Work on a copy so the caller's list is not drained as a side effect.
    pending = list(commands)
    procs_by_gpu = [None] * n_gpus

    while pending:
        for gpu_idx in range(n_gpus):
            proc = procs_by_gpu[gpu_idx]
            if (proc is None) or (proc.poll() is not None):
                # Nothing is running on this GPU; launch a command.
                cmd = pending.pop(0)
                procs_by_gpu[gpu_idx] = subprocess.Popen(
                    f'CUDA_VISIBLE_DEVICES={gpu_idx} {cmd}', shell=True)
                # At most one launch per polling round.
                break
        time.sleep(1)

    # Wait for the last few tasks to finish before returning
    for p in procs_by_gpu:
        if p is not None:
            p.wait()


REGISTRY = {
    'local': local_launcher,
    'dummy': dummy_launcher,
    'multi_gpu': multi_gpu_launcher
}

try:
    # Optional site-specific launchers; silently absent outside that setup.
    from domainbed import facebook
    facebook.register_command_launchers(REGISTRY)
except ImportError:
    pass

# --------------------------------------------------------------------------------
# /hparams_registry.py
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import numpy as np
from lib import misc


def _define_hparam(hparams, hparam_name, default_val, random_val_fn):
    # NOTE(review): not called anywhere in this module, and it stores a tuple
    # that contains `hparams` itself; kept only for compatibility with any
    # external caller.
    hparams[hparam_name] = (hparams, hparam_name, default_val, random_val_fn)


def _hparams(algorithm, dataset, random_seed, stage):
    """
    Global registry of hyperparams. Each entry is a (default, random) tuple.
    New algorithms / networks / etc. should add entries here.

    Args:
        algorithm: algorithm name, e.g. 'ERM', 'DANN', 'DDG', 'DDG_AugMix'.
        dataset: dataset name; names containing 'MNIST' and the entries of
            SMALL_IMAGES get dedicated values.
        random_seed: seed mixed (via misc.seed_hash) with each hparam name, so
            every random value depends only on (seed, name), not on the
            order of definition.
        stage: DDG training stage (0 = decoder pretraining); stored as the
            'stage' hparam for DDG algorithms.

    Returns:
        dict mapping hparam name -> (default_value, random_value).
    """
    SMALL_IMAGES = ['Debug28', 'RotatedMNIST', 'ColoredMNIST']

    hparams = {}

    def _hparam(name, default_val, random_val_fn):
        """Define a hyperparameter. random_val_fn takes a RandomState and
        returns a random hyperparameter value."""
        assert(name not in hparams)
        random_state = np.random.RandomState(
            misc.seed_hash(random_seed, name)
        )
        hparams[name] = (default_val, random_val_fn(random_state))

    # Unconditional hparam definitions.

    _hparam('data_augmentation', True, lambda r: True)
    _hparam('resnet18', True, lambda r: False)
    _hparam('resnet_dropout', 0.0, lambda r: r.choice([0., 0.1, 0.5]))
    _hparam('class_balanced', False, lambda r: False)
    _hparam('nonlinear_classifier', False, lambda r: bool(r.choice([False, True])))

    # Algorithm-specific hparam definitions. Each block of code below
    # corresponds to exactly one algorithm.

    if 'MNIST' in dataset:
        _hparam('is_mnist', True, lambda r: True)
    else:
        _hparam('is_mnist', False, lambda r: False)

    if algorithm in ['DANN', 'CDANN']:
        _hparam('lambda', 1.0, lambda r: 10**r.uniform(-2, 2))
        _hparam('weight_decay_d', 0., lambda r: 10**r.uniform(-6, -2))
        _hparam('d_steps_per_g_step', 5, lambda r: int(2**r.uniform(0, 3)))
        _hparam('grad_penalty', 0., lambda r: 10**r.uniform(-2, 1))
        _hparam('beta1', 0.5, lambda r: r.choice([0., 0.5]))
        _hparam('mlp_width', 256, lambda r: int(2 ** r.uniform(6, 10)))
        _hparam('mlp_depth', 3, lambda r: int(r.choice([3, 4, 5])))
        _hparam('mlp_dropout', 0., lambda r: r.choice([0., 0.1, 0.5]))

    elif algorithm == "RSC":
        _hparam('rsc_f_drop_factor', 1/3, lambda r: r.uniform(0, 0.5))
        _hparam('rsc_b_drop_factor', 1/3, lambda r: r.uniform(0, 0.5))

    elif algorithm == "SagNet":
        _hparam('sag_w_adv', 0.1, lambda r: 10**r.uniform(-2, 1))

    elif algorithm == "IRM":
        _hparam('irm_lambda', 1e2, lambda r: 10**r.uniform(-1, 5))
        _hparam('irm_penalty_anneal_iters', 500, lambda r: int(10**r.uniform(0, 4)))

    elif algorithm == "Mixup":
        _hparam('mixup_alpha', 0.2, lambda r: 10**r.uniform(-1, -1))

    elif algorithm == "GroupDRO":
        _hparam('groupdro_eta', 1e-2, lambda r: 10**r.uniform(-3, -1))

    elif algorithm == "MMD" or algorithm == "CORAL":
        _hparam('mmd_gamma', 1., lambda r: 10**r.uniform(-1, 1))

    elif algorithm == "MLDG":
        _hparam('mldg_beta', 1., lambda r: 10**r.uniform(-1, 1))

    elif algorithm == "MTL":
        _hparam('mtl_ema', .99, lambda r: r.choice([0.5, 0.9, 0.99, 1.]))

    elif algorithm == "VREx":
        _hparam('vrex_lambda', 1e1, lambda r: 10**r.uniform(-1, 5))
        _hparam('vrex_penalty_anneal_iters', 500, lambda r: int(10**r.uniform(0, 4)))

    elif algorithm == "SD":
        _hparam('sd_reg', 0.1, lambda r: 10**r.uniform(-5, -1))

    if 'DDG' in algorithm:
        _hparam('is_ddg', True, lambda r: True)
        if algorithm == 'DDG_AugMix':
            _hparam('is_augmix', True, lambda r: True)
        else:
            _hparam('is_augmix', False, lambda r: False)
        if 'MNIST' in dataset:
            _hparam('steps', 10000, lambda r: 10000)
            _hparam('stage', stage, lambda r: stage)
            _hparam('margin', 0.025, lambda r: 0.025)
            _hparam('recon_id_w', 0.5, lambda r: r.choice([0.1, 0.2, 0.5, 1.0]))
            _hparam('recon_x_w', 0.5, lambda r: r.choice([1., 2., 5., 10.]))
        elif stage == 0:
            # Decoder pretraining.
            _hparam('steps', 25000, lambda r: 25000)
            _hparam('stage', stage, lambda r: stage)
            _hparam('margin', 0.025, lambda r: 0.025)
            _hparam('recon_id_w', 0.5, lambda r: r.choice([0.1, 0.2, 0.5, 1.0]))
            _hparam('recon_x_w', 0.5, lambda r: r.choice([1., 2., 5., 10.]))
        else:
            _hparam('steps', 10000, lambda r: 10000)
            _hparam('stage', stage, lambda r: stage)
            _hparam('recon_id_w', 0.5, lambda r: r.choice([0.1, 0.2, 0.5, 1.0]))
            _hparam('margin', 0.25, lambda r: r.choice([0.1, 0.25, 0.5, 0.75]))
            _hparam('recon_xp_w', 0.5, lambda r: r.choice([1., 2., 5., 10.]))
            _hparam('recon_xn_w', 0.5, lambda r: r.choice([1., 2., 5., 10.]))
        # NOTE(review): the collapsed source does not show whether the block
        # below belonged to the `else` branch or to every DDG stage; defining
        # it for every stage is the safe reading (extra keys are harmless,
        # missing ones would raise KeyError). Also note that the default and
        # random values of 'eta' (and of the DDG 'batch_size' further down)
        # differ — confirm that is intended.
        _hparam('max_cyc_w', 2.0, lambda r: r.choice([1.0, 2.0, 4.0]))
        _hparam('max_w', 2.0, lambda r: r.choice([0.5, 1.0, 2.0]))
        _hparam('gan_w', 1.0, lambda r: r.choice([0.5, 1.0, 2.0]))
        _hparam('eta', 0.01, lambda r: 0.05)
        _hparam('recon_x_cyc_w', 0.0, lambda r: r.choice([0.1, 0.2, 0.5, 1.0]))
        _hparam('warm_iter_r', .2, lambda r: r.choice([.1, .2, .3, .4, .5]))
        _hparam('warm_scale', 5e-3, lambda r: 10**r.uniform(-5, -3))
    else:
        _hparam('is_ddg', False, lambda r: False)

    # Dataset-and-algorithm-specific hparam definitions. Each block of code
    # below corresponds to exactly one hparam. Avoid nested conditionals.

    if dataset in SMALL_IMAGES:
        _hparam('lr', 1e-3, lambda r: 10**r.uniform(-4.5, -2.5))
    elif 'DDG' in algorithm:
        _hparam('lr', 2e-5, lambda r: 2e-5)
    else:
        _hparam('lr', 5e-5, lambda r: 10**r.uniform(-5, -3.5))

    if dataset in SMALL_IMAGES:
        _hparam('weight_decay', 0., lambda r: 0.)
    else:
        _hparam('weight_decay', 0., lambda r: 10**r.uniform(-6, -2))

    if dataset in SMALL_IMAGES:
        _hparam('batch_size', 64, lambda r: int(2**r.uniform(3, 9)))
    elif algorithm == 'ARM':
        _hparam('batch_size', 8, lambda r: 8)
    elif 'DDG' in algorithm:
        _hparam('batch_size', 2, lambda r: 4)
    elif dataset == 'DomainNet':
        _hparam('batch_size', 32, lambda r: int(2**r.uniform(3, 5)))
    else:
        _hparam('batch_size', 32, lambda r: int(2**r.uniform(3, 5.5)))

    if algorithm in ['DANN', 'CDANN'] and dataset in SMALL_IMAGES:
        _hparam('lr_g', 1e-3, lambda r: 10**r.uniform(-4.5, -2.5))
    elif algorithm in ['DANN', 'CDANN']:
        _hparam('lr_g', 5e-5, lambda r: 10**r.uniform(-5, -3.5))
    elif 'DDG' in algorithm:
        _hparam('lr_g', 1e-4, lambda r: 10**r.uniform(-5, -3.5))

    if algorithm in ['DANN', 'CDANN'] and dataset in SMALL_IMAGES:
        _hparam('lr_d', 1e-3, lambda r: 10**r.uniform(-4.5, -2.5))
    elif algorithm in ['DANN', 'CDANN']:
        _hparam('lr_d', 5e-5, lambda r: 10**r.uniform(-5, -3.5))
    elif 'DDG' in algorithm:
        _hparam('lr_d', 1e-4, lambda r: 10**r.uniform(-5, -3.5))

    if algorithm in ['DANN', 'CDANN'] and dataset in SMALL_IMAGES:
        _hparam('weight_decay_g', 0., lambda r: 0.)
    elif algorithm in ['DANN', 'CDANN', 'DDG', 'DDG_AugMix']:
        _hparam('weight_decay_g', 0.0005, lambda r: 10**r.uniform(-6, -2))

    return hparams


def default_hparams(algorithm, dataset, stage=0):
    """Return {name: default value} for the given algorithm/dataset/stage."""
    return {a: b for a, (b, c) in
            _hparams(algorithm, dataset, 0, stage).items()}


def random_hparams(algorithm, dataset, seed, stage=0):
    """Return {name: randomly sampled value}, deterministic given `seed`."""
    return {a: c for a, (b, c) in _hparams(algorithm, dataset, seed, stage).items()}

# --------------------------------------------------------------------------------
# /lib/augmentations.py
# --------------------------------------------------------------------------------
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base augmentations operators."""

import numpy as np
from PIL import Image, ImageOps, ImageEnhance

# ImageNet code should change this value
IMAGE_SIZE = 32


def int_parameter(level, maxval):
    """Helper function to scale `val` between 0 and maxval .
    Args:
      level: Level of the operation that will be between [0, `PARAMETER_MAX`]
        (PARAMETER_MAX is the hard-coded divisor 10 below).
      maxval: Maximum value that the operation can have. This will be scaled to
        level/PARAMETER_MAX.
    Returns:
      An int that results from scaling `maxval` according to `level`.
    """
    return int(level * maxval / 10)


def float_parameter(level, maxval):
    """Helper function to scale `val` between 0 and maxval.
    Args:
      level: Level of the operation that will be between [0, `PARAMETER_MAX`]
        (PARAMETER_MAX is the hard-coded divisor 10 below).
      maxval: Maximum value that the operation can have. This will be scaled to
        level/PARAMETER_MAX.
    Returns:
      A float that results from scaling `maxval` according to `level`.
    """
    return float(level) * maxval / 10.


def sample_level(n):
    """Draw a severity uniformly from [0.1, n)."""
    return np.random.uniform(low=0.1, high=n)


def autocontrast(pil_img, _):
    return ImageOps.autocontrast(pil_img)


def equalize(pil_img, _):
    return ImageOps.equalize(pil_img)


def posterize(pil_img, level):
    level = int_parameter(sample_level(level), 4)
    return ImageOps.posterize(pil_img, 4 - level)


def rotate(pil_img, level):
    degrees = int_parameter(sample_level(level), 30)
    if np.random.uniform() > 0.5:
        degrees = -degrees
    return pil_img.rotate(degrees, resample=Image.BILINEAR)


def solarize(pil_img, level):
    level = int_parameter(sample_level(level), 256)
    return ImageOps.solarize(pil_img, 256 - level)


def shear_x(pil_img, level):
    level = float_parameter(sample_level(level), 0.3)
    if np.random.uniform() > 0.5:
        level = -level
    return pil_img.transform((IMAGE_SIZE, IMAGE_SIZE),
                             Image.AFFINE, (1, level, 0, 0, 1, 0),
                             resample=Image.BILINEAR)


def shear_y(pil_img, level):
    level = float_parameter(sample_level(level), 0.3)
    if np.random.uniform() > 0.5:
        level = -level
    return pil_img.transform((IMAGE_SIZE, IMAGE_SIZE),
                             Image.AFFINE, (1, 0, 0, level, 1, 0),
                             resample=Image.BILINEAR)


def translate_x(pil_img, level):
    level = int_parameter(sample_level(level), IMAGE_SIZE / 3)
    if np.random.random() > 0.5:
        level = -level
    return pil_img.transform((IMAGE_SIZE, IMAGE_SIZE),
                             Image.AFFINE, (1, 0, level, 0, 1, 0),
                             resample=Image.BILINEAR)


def translate_y(pil_img, level):
    level = int_parameter(sample_level(level), IMAGE_SIZE / 3)
    if np.random.random() > 0.5:
        level = -level
    return pil_img.transform((IMAGE_SIZE, IMAGE_SIZE),
                             Image.AFFINE, (1, 0, 0, 0, 1, level),
                             resample=Image.BILINEAR)


# operation that overlaps with ImageNet-C's test set
def color(pil_img, level):
    level = float_parameter(sample_level(level), 1.8) + 0.1
    return ImageEnhance.Color(pil_img).enhance(level)


# operation that overlaps with ImageNet-C's test set
def contrast(pil_img, level):
    level = float_parameter(sample_level(level), 1.8) + 0.1
    return ImageEnhance.Contrast(pil_img).enhance(level)


# operation that overlaps with ImageNet-C's test set
def brightness(pil_img, level):
    level = float_parameter(sample_level(level), 1.8) + 0.1
    return ImageEnhance.Brightness(pil_img).enhance(level)


# operation that overlaps with ImageNet-C's test set
def sharpness(pil_img, level):
    level = float_parameter(sample_level(level), 1.8) + 0.1
    return ImageEnhance.Sharpness(pil_img).enhance(level)


augmentations = [
    autocontrast, equalize, posterize, rotate, solarize, shear_x, shear_y,
    translate_x, translate_y
]

augmentations_all = [
    autocontrast, equalize, posterize, rotate, solarize, shear_x, shear_y,
    translate_x, translate_y, color, contrast, brightness, sharpness
]

# --------------------------------------------------------------------------------
# /lib/fast_data_loader.py
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import torch


class _InfiniteSampler(torch.utils.data.Sampler):
    """Wraps another Sampler to yield an infinite stream."""
    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        while True:
            for batch in self.sampler:
                yield batch


class InfiniteDataLoader:
    """Endless stream of batches sampled with replacement from `dataset`.

    Args:
        dataset: map-style dataset.
        weights: per-example sampling weights (sequence or tensor), or None
            for uniform sampling.
        batch_size: examples per batch; incomplete batches are dropped.
        num_workers: DataLoader worker processes.
    """
    def __init__(self, dataset, weights, batch_size, num_workers):
        super().__init__()

        # `is not None` rather than truthiness: `if weights:` raises
        # "Boolean value of Tensor with more than one value is ambiguous"
        # when weights is a tensor (as make_weights_for_balanced_classes
        # returns).
        if weights is not None:
            sampler = torch.utils.data.WeightedRandomSampler(
                weights, replacement=True, num_samples=batch_size)
        else:
            sampler = torch.utils.data.RandomSampler(dataset, replacement=True)

        batch_sampler = torch.utils.data.BatchSampler(
            sampler,
            batch_size=batch_size,
            drop_last=True)

        self._infinite_iterator = iter(torch.utils.data.DataLoader(
            dataset,
            num_workers=num_workers,
            batch_sampler=_InfiniteSampler(batch_sampler)
        ))

    def __iter__(self):
        while True:
            yield next(self._infinite_iterator)

    def __len__(self):
        # The stream is infinite, so it has no length.
        raise ValueError


class FastDataLoader:
    """DataLoader wrapper with slightly improved speed by not respawning worker
    processes at every epoch."""
    def __init__(self, dataset, batch_size, num_workers):
        super().__init__()

        batch_sampler = torch.utils.data.BatchSampler(
            torch.utils.data.RandomSampler(dataset, replacement=False),
            batch_size=batch_size,
            drop_last=False
        )

        self._infinite_iterator = iter(torch.utils.data.DataLoader(
            dataset,
            num_workers=num_workers,
            batch_sampler=_InfiniteSampler(batch_sampler)
        ))

        # Number of batches in one pass over the dataset.
        self._length = len(batch_sampler)

    def __iter__(self):
        for _ in range(len(self)):
            yield next(self._infinite_iterator)

    def __len__(self):
        return self._length
# --------------------------------------------------------------------------------
# /lib/misc.py
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

"""
Things that don't belong anywhere else
"""

import hashlib
import json
import os
import sys
from shutil import copyfile
import lib.augmentations as augmentations

import numpy as np
import torch
import tqdm
from collections import Counter
from sklearn.metrics import confusion_matrix


def make_weights_for_balanced_classes(dataset):
    """Return a tensor of per-example weights giving every class equal total mass.

    Each example of class c gets weight 1 / (count[c] * n_classes), so the
    weights sum to 1. `dataset` must yield (x, y) pairs with int-convertible y.
    """
    counts = Counter()
    classes = []
    for _, y in dataset:
        y = int(y)
        counts[y] += 1
        classes.append(y)

    n_classes = len(counts)

    weight_per_class = {}
    for y in counts:
        weight_per_class[y] = 1 / (counts[y] * n_classes)

    weights = torch.zeros(len(dataset))
    for i, y in enumerate(classes):
        weights[i] = weight_per_class[int(y)]

    return weights


def pdb():
    """Drop into pdb with the real stdout restored (e.g. when a Tee is active)."""
    sys.stdout = sys.__stdout__
    import pdb
    print("Launching PDB, enter 'n' to step to parent function.")
    pdb.set_trace()


def seed_hash(*args):
    """
    Derive an integer hash from all args, for use as a random seed.
    """
    args_str = str(args)
    return int(hashlib.md5(args_str.encode("utf-8")).hexdigest(), 16) % (2**31)


def print_separator():
    print("="*80)


def print_row(row, colwidth=10, latex=False):
    """Print `row` as fixed-width columns (or as a LaTeX table row)."""
    if latex:
        sep = " & "
        end_ = "\\\\"
    else:
        sep = " "
        end_ = ""

    def format_val(x):
        if np.issubdtype(type(x), np.floating):
            x = "{:.10f}".format(x)
        return str(x).ljust(colwidth)[:colwidth]
    print(sep.join([format_val(x) for x in row]), end_)


class _SplitDataset(torch.utils.data.Dataset):
    """Used by split_dataset"""
    def __init__(self, underlying_dataset, keys):
        super(_SplitDataset, self).__init__()
        self.underlying_dataset = underlying_dataset
        self.keys = keys

    def __getitem__(self, key):
        return self.underlying_dataset[self.keys[key]]

    def __len__(self):
        return len(self.keys)


def split_dataset(dataset, n, seed=0):
    """
    Return a pair of datasets corresponding to a random split of the given
    dataset, with n datapoints in the first dataset and the rest in the last,
    using the given random seed
    """
    assert(n <= len(dataset))
    keys = list(range(len(dataset)))
    np.random.RandomState(seed).shuffle(keys)
    keys_1 = keys[:n]
    keys_2 = keys[n:]
    return _SplitDataset(dataset, keys_1), _SplitDataset(dataset, keys_2)


def random_pairs_of_minibatches(minibatches):
    """Pair each minibatch with the next one in a random cyclic order.

    Both members of a pair are truncated to the shorter length.
    """
    perm = torch.randperm(len(minibatches)).tolist()
    pairs = []

    for i in range(len(minibatches)):
        j = i + 1 if i < (len(minibatches) - 1) else 0

        xi, yi = minibatches[perm[i]][0], minibatches[perm[i]][1]
        xj, yj = minibatches[perm[j]][0], minibatches[perm[j]][1]

        min_n = min(len(xi), len(xj))

        pairs.append(((xi[:min_n], yi[:min_n]), (xj[:min_n], yj[:min_n])))

    return pairs


def sample_tuple_of_minibatches(minibatches, device):
    """Build one (anchor, negative) tuple per minibatch.

    For minibatch i, a "negative" minibatch is found by following a random
    permutation until a minibatch whose label compares unequal to y_i is
    reached; x_p / x_np are the inputs of the first minibatch whose label
    equals y_i / y_n.

    NOTE(review): as written originally this helper could not run (it
    unpacked three values into two names, called list.append with two
    arguments, indexed a list with numpy arrays, and could loop forever).
    Those defects are fixed, but `y_n == y` must evaluate to a single truth
    value for the loop to work, and `d` is element i of the per-example
    domain labels (not necessarily the domain of minibatch i) — confirm the
    intended input format before relying on this.

    Returns:
        list of ((x, y, d, x_p), (x_n, y_n, d_n, x_np)), one per minibatch.

    Raises:
        ValueError: if no minibatch reachable through the permutation has a
            label different from minibatch i.
    """
    disc_labels = torch.cat([
        torch.full((x.shape[0], ), i, dtype=torch.int64, device=device)
        for i, (x, y) in enumerate(minibatches)
    ])
    n = len(minibatches)
    perm = torch.randperm(n).tolist()
    tuples = []
    labels = np.array([minibatches[k][1] for k in range(n)])

    for i in range(n):
        x, y, d = minibatches[i][0], minibatches[i][1], disc_labels[i]
        j = perm[i]
        x_n, y_n, d_n = minibatches[j][0], minibatches[j][1], disc_labels[j]
        # A permutation cycle has at most n members, so n steps visit every
        # reachable candidate; stop there instead of looping forever.
        steps = 0
        while y_n == y:
            steps += 1
            if steps > n:
                raise ValueError(
                    'sample_tuple_of_minibatches: no minibatch with a label '
                    'different from minibatch {}'.format(i))
            j = perm[j]
            x_n, y_n, d_n = minibatches[j][0], minibatches[j][1], disc_labels[j]

        # First minibatch index whose label matches y (resp. y_n).
        pos_idx = int(np.flatnonzero(labels == y)[0])
        pos_n_idx = int(np.flatnonzero(labels == y_n)[0])
        x_p, x_np = minibatches[pos_idx][0], minibatches[pos_n_idx][0]

        tuples.append(((x, y, d, x_p), (x_n, y_n, d_n, x_np)))

    return tuples


def plot_confusion(matrix):
    pass


def accuracy(network, loader, weights, device, args=None, step=None, is_ddg=False):
    """Return the (optionally weighted) accuracy of `network` over `loader`.

    Args:
        network: model exposing `predict(x)`; put in eval mode while
            evaluating and always switched back to train mode afterwards.
        loader: iterable of (x, y) batches, or (x, y, extra) batches when
            `is_ddg` is True (the third element is ignored).
        weights: per-example weights consumed in loader order, or None for
            uniform weights.
        device: device that inputs, labels and weights are moved to.
        args, step: unused; kept for call-site compatibility.
        is_ddg: whether batches carry a third element.

    Returns:
        Weighted fraction of correct predictions. A single-column prediction
        is treated as positive iff > 0; otherwise the argmax is used.
    """
    correct = 0
    total = 0
    weights_offset = 0

    network.eval()
    try:
        with torch.no_grad():
            for batch in loader:
                if is_ddg:
                    x, y, _ = batch
                else:
                    x, y = batch
                x = x.to(device)
                y = y.to(device)
                p = network.predict(x)
                if weights is None:
                    batch_weights = torch.ones(len(x))
                else:
                    batch_weights = weights[weights_offset: weights_offset + len(x)]
                    weights_offset += len(x)
                batch_weights = batch_weights.to(device)
                if p.size(1) == 1:
                    # NOTE(review): with y of shape (B,), p.gt(0) of shape
                    # (B, 1) broadcasts eq() to (B, B) — confirm y's shape.
                    correct += (p.gt(0).eq(y).float() * batch_weights.view(-1, 1)).sum().item()
                else:
                    correct += (p.argmax(1).eq(y).float() * batch_weights).sum().item()
                total += batch_weights.sum().item()
    finally:
        # Restore training mode even if evaluation raised.
        network.train()

    return correct / total


class Tee:
    """File-like object that writes to the original stdout and to `fname`."""
    def __init__(self, fname, mode="a"):
        self.stdout = sys.stdout
        self.file = open(fname, mode)

    def write(self, message):
        self.stdout.write(message)
        self.file.write(message)
        self.flush()

    def flush(self):
        self.stdout.flush()
        self.file.flush()


# AugMix ops below operate on 224x224 images.
augmentations.IMAGE_SIZE = 224


def aug(image, preprocess):
    """Perform AugMix augmentations and compute mixture.
    Args:
      image: PIL.Image input image
      preprocess: Preprocessing function which should return a torch tensor.
    Returns:
      mixed: Augmented and mixed image.
    """
    aug_list = augmentations.augmentations
    mixture_width = 3
    mixture_depth = -1
    aug_severity = 1
    ws = np.float32(
        np.random.dirichlet([1] * mixture_width))
    m = np.float32(np.random.beta(1, 1))

    mix = torch.zeros_like(preprocess(image))
    for i in range(mixture_width):
        image_aug = image.copy()
        depth = mixture_depth if mixture_depth > 0 else np.random.randint(
            1, 4)
        for _ in range(depth):
            op = np.random.choice(aug_list)
            image_aug = op(image_aug, aug_severity)
        # Preprocessing commutes since all coefficients are convex
        mix += ws[i] * preprocess(image_aug)

    mixed = (1 - m) * preprocess(image) + m * mix
    return mixed


def Augmix(x, preprocess, no_jsd):
    """Return one AugMix view, or (clean, view1, view2) for the JSD loss."""
    if no_jsd:
        return aug(x, preprocess)
    else:
        return preprocess(x), aug(x, preprocess), aug(x, preprocess)

# --------------------------------------------------------------------------------
# /lib/query.py
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

"""Small query library."""

import collections
import inspect
import json
import types
import unittest
import warnings
import math

import numpy as np


def make_selector_fn(selector):
    """
    If selector is callable, return selector.
    Otherwise, return a function corresponding to the selector string. Examples
    of valid selector strings and the corresponding functions:
        x       lambda obj: obj['x']
        x.y     lambda obj: obj['x']['y']
        x,y     lambda obj: (obj['x'], obj['y'])

    Raises:
        TypeError: if selector is neither a string nor callable.
    """
    if isinstance(selector, str):
        if ',' in selector:
            parts = selector.split(',')
            part_selectors = [make_selector_fn(part) for part in parts]
            return lambda obj: tuple(sel(obj) for sel in part_selectors)
        elif '.' in selector:
            parts = selector.split('.')
            part_selectors = [make_selector_fn(part) for part in parts]

            def f(obj):
                for sel in part_selectors:
                    obj = sel(obj)
                return obj
            return f
        else:
            key = selector.strip()
            return lambda obj: obj[key]
    elif callable(selector):
        # Accepts lambdas/functions as before, plus builtins, partials, etc.
        return selector
    else:
        raise TypeError


def hashable(obj):
    """Return obj if hashable, else a canonical JSON string standing in for it."""
    try:
        hash(obj)
        return obj
    except TypeError:
        return json.dumps({'_': obj}, sort_keys=True)


class Q(object):
    """Thin list wrapper with chainable query helpers."""
    def __init__(self, list_):
        super(Q, self).__init__()
        self._list = list_

    def __len__(self):
        return len(self._list)

    def __getitem__(self, key):
        return self._list[key]

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            return self._list == other._list
        else:
            return self._list == other

    def __str__(self):
        return str(self._list)

    def __repr__(self):
        return repr(self._list)

    def _append(self, item):
        """Unsafe, be careful you know what you're doing."""
        self._list.append(item)

    def group(self, selector):
        """
        Group elements by selector and return a list of (group, group_records)
        tuples.
        """
        selector = make_selector_fn(selector)
        groups = {}
        for x in self._list:
            group = selector(x)
            group_key = hashable(group)
            if group_key not in groups:
                groups[group_key] = (group, Q([]))
            groups[group_key][1]._append(x)
        results = [groups[key] for key in sorted(groups.keys())]
        return Q(results)

    def group_map(self, selector, fn):
        """
        Group elements by selector, apply fn to each group, and return a list
        of the results.
        """
        return self.group(selector).map(fn)

    def map(self, fn):
        """
        map self onto fn. If fn takes multiple args, tuple-unpacking
        is applied.
        """
        if len(inspect.signature(fn).parameters) > 1:
            return Q([fn(*x) for x in self._list])
        else:
            return Q([fn(x) for x in self._list])

    def select(self, selector):
        selector = make_selector_fn(selector)
        return Q([selector(x) for x in self._list])

    def min(self):
        return min(self._list)

    def max(self):
        return max(self._list)

    def sum(self):
        return sum(self._list)

    def len(self):
        return len(self._list)

    def mean(self):
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            return float(np.mean(self._list))

    def std(self):
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            return float(np.std(self._list))

    def mean_std(self):
        return (self.mean(), self.std())

    def argmax(self, selector):
        selector = make_selector_fn(selector)
        return max(self._list, key=selector)

    def filter(self, fn):
        return Q([x for x in self._list if fn(x)])

    def filter_equals(self, selector, value):
        """like [x for x in y if x.selector == value]"""
        selector = make_selector_fn(selector)
        return self.filter(lambda r: selector(r) == value)

    def filter_not_none(self):
        return self.filter(lambda r: r is not None)

    def filter_not_nan(self):
        return self.filter(lambda r: not np.isnan(r))

    def flatten(self):
        return Q([y for x in self._list for y in x])

    def unique(self):
        result = []
        result_set = set()
        for x in self._list:
            hashable_x = hashable(x)
            if hashable_x not in result_set:
                result_set.add(hashable_x)
                result.append(x)
        return Q(result)

    def sorted(self, key=None):
        """Sort by key; NaN keys sort first (treated as -inf)."""
        if key is None:
            key = lambda x: x

        def key2(x):
            x = key(x)
            if isinstance(x, (np.floating, float)) and np.isnan(x):
                return float('-inf')
            else:
                return x
        return Q(sorted(self._list, key=key2))

# --------------------------------------------------------------------------------
# /lib/reporting.py
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import collections

import json
import os

import tqdm

try:
    # This repository's modules import each other as `lib.*`
    # (e.g. `from lib import misc`).
    from lib.query import Q
except ImportError:
    # Fallback for an upstream DomainBed package layout.
    from domainbed.lib.query import Q


def load_records(path):
    """Load every JSON record from `<path>/<subdir>/results.jsonl` into a Q.

    Subdirectories without a readable results file are skipped; blank lines
    are ignored.
    """
    records = []
    for subdir in tqdm.tqdm(os.listdir(path), ncols=80, leave=False):
        results_path = os.path.join(path, subdir, "results.jsonl")
        try:
            with open(results_path, "r") as f:
                for line in f:
                    # json.loads tolerates the trailing newline; slicing it
                    # off with [:-1] truncated a last line written without one.
                    if line.strip():
                        records.append(json.loads(line))
        except IOError:
            pass

    return Q(records)


def get_grouped_records(records):
    """Group records by (trial_seed, dataset, algorithm, test_env). Because
    records can have multiple test envs, a given record may appear in more than
    one group."""
    result = collections.defaultdict(list)
    for r in records:
        for test_env in r["args"]["test_envs"]:
            group = (r["args"]["trial_seed"],
                     r["args"]["dataset"],
                     r["args"]["algorithm"],
                     test_env)
            result[group].append(r)
    return Q([{"trial_seed": t, "dataset": d, "algorithm": a, "test_env": e,
               "records": Q(r)} for (t, d, a, e), r in result.items()])

# --------------------------------------------------------------------------------
# /lib/wide_resnet.py (continues below)
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
# All Rights Reserved

"""
From https://github.com/meliketoy/wide-resnet.pytorch
"""

import sys

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.autograd import Variable


def conv3x3(in_planes, out_planes, stride=1):
    """Return a 3x3 convolution with padding 1 and a bias term."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=True)


def conv_init(m):
    """Initialise a conv or batch-norm module in place (for ``Module.apply``).

    Conv weights get Xavier-uniform init with gain sqrt(2) and BatchNorm
    weights are set to 1; biases are zeroed. Parameters that are ``None``
    (``bias=False`` convs, ``affine=False`` norms) are left alone instead of
    crashing.
    """
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        init.xavier_uniform_(m.weight, gain=np.sqrt(2))
        if m.bias is not None:
            init.constant_(m.bias, 0)
    elif classname.find('BatchNorm') != -1:
        if m.weight is not None:
            init.constant_(m.weight, 1)
        if m.bias is not None:
            init.constant_(m.bias, 0)


class wide_basic(nn.Module):
    """Pre-activation wide residual block.

    forward: BN -> ReLU -> conv3x3 -> dropout -> BN -> ReLU -> conv3x3
    (strided), plus a shortcut that is identity, or a strided 1x1 conv when
    the stride or channel count changes.
    """

    def __init__(self, in_planes, planes, dropout_rate, stride=1):
        super(wide_basic, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, padding=1, bias=True)
        self.dropout = nn.Dropout(p=dropout_rate)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=stride, padding=1, bias=True)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_planes, planes, kernel_size=1, stride=stride,
                    bias=True), )

    def forward(self, x):
        out = self.dropout(self.conv1(F.relu(self.bn1(x))))
        out = self.conv2(F.relu(self.bn2(out)))
        out += self.shortcut(x)

        return out


class Wide_ResNet(nn.Module):
    """Wide Resnet with the softmax layer chopped off.

    Args:
        input_shape: input shape; only ``input_shape[0]`` (channels) is used.
        depth: network depth, must satisfy depth = 6n + 4.
        widen_factor: channel multiplier k (stage widths 16k, 32k, 64k).
        dropout_rate: dropout probability inside every block.

    ``n_outputs`` is the feature dimension returned by ``forward``.
    """

    def __init__(self, input_shape, depth, widen_factor, dropout_rate):
        super(Wide_ResNet, self).__init__()
        self.in_planes = 16

        assert ((depth - 4) % 6 == 0), 'Wide-resnet depth should be 6n+4'
        # Blocks per stage; integer division keeps it an int.
        n = (depth - 4) // 6
        k = widen_factor

        # print('| Wide-Resnet %dx%d' % (depth, k))
        nStages = [16, 16 * k, 32 * k, 64 * k]

        self.conv1 = conv3x3(input_shape[0], nStages[0])
        self.layer1 = self._wide_layer(
            wide_basic, nStages[1], n, dropout_rate, stride=1)
        self.layer2 = self._wide_layer(
            wide_basic, nStages[2], n, dropout_rate, stride=2)
        self.layer3 = self._wide_layer(
            wide_basic, nStages[3], n, dropout_rate, stride=2)
        # NOTE(review): momentum=0.9 follows the upstream repo; in PyTorch
        # this weights the *new* batch statistics by 0.9.
        self.bn1 = nn.BatchNorm2d(nStages[3], momentum=0.9)

        self.n_outputs = nStages[3]

    def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride``."""
        strides = [stride] + [1] * (int(num_blocks) - 1)
        layers = []

        for stride in strides:
            layers.append(block(self.in_planes, planes, dropout_rate, stride))
            self.in_planes = planes

        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = F.relu(self.bn1(out))
        # NOTE(review): kernel 8 assumes an 8x8 final map (32x32 input); for
        # larger inputs [:, :, 0, 0] keeps only the top-left pooled cell.
        out = F.avg_pool2d(out, 8)
        return out[:, :, 0, 0]


# ----- /model_selection.py -----
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import itertools
import numpy as np


def get_test_records(records):
    """Given records with a common test env, get the test records (i.e. the
    records with *only* that single test env and no other test envs)"""
    return records.filter(lambda r: len(r['args']['test_envs']) == 1)


class SelectionMethod:
    """Abstract class whose subclasses implement strategies for model
    selection across hparams and timesteps. Not instantiable; use the
    classmethods directly."""

    def __init__(self):
        raise TypeError

    @classmethod
    def run_acc(cls, run_records):
        """
        Given records from a run, return a {val_acc, test_acc} dict representing
        the best val-acc and corresponding test-acc for that run, or None.
        """
        raise NotImplementedError

    @classmethod
    def hparams_accs(cls, records):
        """
        Given all records from a single (dataset, algorithm, test env) pair,
        return a list of (run_acc, records) tuples sorted by descending
        val_acc. Runs whose run_acc is None are dropped.
        """
        return (records.group('args.hparams_seed')
                .map(lambda _, run_records:
                     (
                         cls.run_acc(run_records),
                         run_records
                     )
                     ).filter(lambda x: x[0] is not None)
                .sorted(key=lambda x: x[0]['val_acc'])[::-1]
                )

    @classmethod
    def sweep_acc(cls, records):
        """
        Given all records from a single (dataset, algorithm, test env) pair,
        return the test acc of the run with the highest val acc, or None when
        no run produced a usable result.
        """
        _hparams_accs = cls.hparams_accs(records)
        if len(_hparams_accs):
            return _hparams_accs[0][0]['test_acc']
        else:
            return None


class OracleSelectionMethod(SelectionMethod):
    """Like Selection method which picks argmax(test_out_acc) across all hparams
    and checkpoints, but instead of taking the argmax over all
    checkpoints, we pick the last checkpoint, i.e. no early stopping."""
    name = "test-domain validation set (oracle)"

    @classmethod
    def run_acc(cls, run_records):
        run_records = run_records.filter(lambda r:
                                         len(r['args']['test_envs']) == 1)
        if not len(run_records):
            return None
        test_env = run_records[0]['args']['test_envs'][0]
        test_out_acc_key = 'env{}_out_acc'.format(test_env)
        test_in_acc_key = 'env{}_in_acc'.format(test_env)
        chosen_record = run_records.sorted(lambda r: r['step'])[-1]
        return {
            'val_acc': chosen_record[test_out_acc_key],
            'test_acc': chosen_record[test_in_acc_key]
        }


class IIDAccuracySelectionMethod(SelectionMethod):
    """Picks argmax(mean(env_out_acc for env in train_envs))"""
    name = "training-domain validation set"

    @classmethod
    def _step_acc(cls, record):
        """Given a single record, return a {val_acc, test_acc} dict."""
        test_env = record['args']['test_envs'][0]
        val_env_keys = []
        # Environments are numbered contiguously from 0; stop at the first
        # missing env{i}_out_acc key.
        for i in itertools.count():
            if f'env{i}_out_acc' not in record:
                break
            if i != test_env:
                val_env_keys.append(f'env{i}_out_acc')
        test_in_acc_key = 'env{}_in_acc'.format(test_env)
        return {
            'val_acc': np.mean([record[key] for key in val_env_keys]),
            'test_acc': record[test_in_acc_key]
        }

    @classmethod
    def run_acc(cls, run_records):
        test_records = get_test_records(run_records)
        if not len(test_records):
            return None
        return test_records.map(cls._step_acc).argmax('val_acc')


class LeaveOneOutSelectionMethod(SelectionMethod):
    """Picks (hparams, step) by leave-one-out cross validation."""
    name = "leave-one-domain-out cross-validation"

    @classmethod
    def _step_acc(cls, records):
        """Return the {val_acc, test_acc} for a group of records corresponding
        to a single step, or None if any held-out env is missing."""
        test_records = get_test_records(records)
        if len(test_records) != 1:
            return None

        test_env = test_records[0]['args']['test_envs'][0]
        n_envs = 0
        for i in itertools.count():
            if f'env{i}_out_acc' not in records[0]:
                break
            n_envs += 1
        # -1 marks "no record held this env out".
        val_accs = np.zeros(n_envs) - 1
        for r in records.filter(lambda r: len(r['args']['test_envs']) == 2):
            val_env = (set(r['args']['test_envs']) - set([test_env])).pop()
            val_accs[val_env] = r['env{}_in_acc'.format(val_env)]
        val_accs = list(val_accs[:test_env]) + list(val_accs[test_env + 1:])
        if any(v == -1 for v in val_accs):
            return None
        val_acc = np.sum(val_accs) / (n_envs - 1)
        return {
            'val_acc': val_acc,
            'test_acc': test_records[0]['env{}_in_acc'.format(test_env)]
        }

    @classmethod
    def run_acc(cls, records):
        step_accs = records.group('step').map(lambda step, step_records:
                                              cls._step_acc(step_records)
                                              ).filter_not_none()
        if len(step_accs):
            return step_accs.argmax('val_acc')
        else:
            return None


# ----- /requirements.txt -----
# numpy==1.20.3
# wilds==1.2.2
# imageio==2.9.0
# gdown==3.13.0
# torchvision==0.8.2
# torch==1.7.1
# tqdm==4.62.2
# backpack==0.1
#   NOTE(review): if BackPACK is intended, its PyPI name is
#   backpack-for-pytorch; confirm this pin.
# parameterized==0.8.1
# Pillow==8.3.2
# ----- /scripts/__init__.py -----
# https://raw.githubusercontent.com/hlzhang109/DDG/b5cd1822f1a413ae7e263fc5a11f00b490b9c72c/scripts/__init__.py
# ----- /scripts/__pycache__/download.cpython-36.pyc -----
# https://raw.githubusercontent.com/hlzhang109/DDG/b5cd1822f1a413ae7e263fc5a11f00b490b9c72c/scripts/__pycache__/download.cpython-36.pyc
# ----- /scripts/__pycache__/download.cpython-37.pyc -----
# https://raw.githubusercontent.com/hlzhang109/DDG/b5cd1822f1a413ae7e263fc5a11f00b490b9c72c/scripts/__pycache__/download.cpython-37.pyc
# ----- /scripts/collect_results.py -----
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import collections


import argparse
import functools
import glob
import pickle
import itertools
import json
import os
import random
import sys

import numpy as np
import tqdm

from domainbed import datasets
from domainbed import algorithms
from domainbed.lib import misc, reporting
from domainbed import model_selection
from domainbed.lib.query import Q
import warnings


def format_mean(data, latex):
    """Given a list of datapoints, return (mean, stderr, text).

    Mean and standard error are scaled to percent. For empty ``data`` the
    result is (None, None, "X").
    """
    values = list(data)
    if len(values) == 0:
        return None, None, "X"
    mean = 100 * np.mean(values)
    # Standard error of the mean: std(values) / sqrt(n).
    err = 100 * np.std(values) / np.sqrt(len(values))
    if latex:
        return mean, err, "{:.1f} $\\pm$ {:.1f}".format(mean, err)
    else:
        return mean, err, "{:.1f} +/- {:.1f}".format(mean, err)


def print_table(table, header_text, row_labels, col_labels, colwidth=10,
                latex=True):
    """Pretty-print a 2D array of data, optionally with row/col labels.

    Note: mutates ``table`` in place (row labels are inserted at the front
    of each row and the column labels are inserted as the first row).
    """
    print("")

    if latex:
        num_cols = len(table[0])
        print("\\begin{center}")
        print("\\adjustbox{max width=\\textwidth}{%")
        print("\\begin{tabular}{l" + "c" * num_cols + "}")
        print("\\toprule")
    else:
        print("--------", header_text)

    for row, label in zip(table, row_labels):
        row.insert(0, label)

    if latex:
        col_labels = ["\\textbf{" + str(col_label).replace("%", "\\%") + "}"
                      for col_label in col_labels]
    table.insert(0, col_labels)

    for r, row in enumerate(table):
        misc.print_row(row, colwidth=colwidth, latex=latex)
        if latex and r == 0:
            print("\\midrule")
    if latex:
        print("\\bottomrule")
        print("\\end{tabular}}")
        print("\\end{center}")


def print_results_tables(records, selection_method, latex):
    """Given all records, print a results table for each dataset, followed
    by an averages table across datasets, using ``selection_method``."""
    grouped_records = reporting.get_grouped_records(records).map(lambda group:
        {**group, "sweep_acc": selection_method.sweep_acc(group["records"])}
    ).filter(lambda g: g["sweep_acc"] is not None)

    # read algorithm names and sort (predefined order)
    alg_names = Q(records).select("args.algorithm").unique()
    alg_names = ([n for n in algorithms.ALGORITHMS if n in alg_names] +
                 [n for n in alg_names if n not in algorithms.ALGORITHMS])

    # read dataset names and sort (lexicographic order)
    dataset_names = Q(records).select("args.dataset").unique().sorted()
    dataset_names = [d for d in datasets.DATASETS if d in dataset_names]

    for dataset in dataset_names:
        if latex:
            print()
            print("\\subsubsection{{{}}}".format(dataset))
        test_envs = range(datasets.num_environments(dataset))

        table = [[None for _ in [*test_envs, "Avg"]] for _ in alg_names]
        for i, algorithm in enumerate(alg_names):
            means = []
            for j, test_env in enumerate(test_envs):
                trial_accs = (grouped_records
                    .filter_equals(
                        "dataset, algorithm, test_env",
                        (dataset, algorithm, test_env)
                    ).select("sweep_acc"))
                mean, err, table[i][j] = format_mean(trial_accs, latex)
                means.append(mean)
            if None in means:
                table[i][-1] = "X"
            else:
                table[i][-1] = "{:.1f}".format(sum(means) / len(means))

        col_labels = [
            "Algorithm",
            *datasets.get_dataset_class(dataset).ENVIRONMENTS,
            "Avg"
        ]
        header_text = (f"Dataset: {dataset}, "
                       f"model selection method: {selection_method.name}")
        print_table(table, header_text, alg_names, list(col_labels),
                    colwidth=20, latex=latex)

    # Print an "averages" table
    if latex:
        print()
        print("\\subsubsection{Averages}")

    table = [[None for _ in [*dataset_names, "Avg"]] for _ in alg_names]
    for i, algorithm in enumerate(alg_names):
        means = []
        for j, dataset in enumerate(dataset_names):
            trial_averages = (grouped_records
                .filter_equals("algorithm, dataset", (algorithm, dataset))
                .group("trial_seed")
                .map(lambda trial_seed, group:
                    group.select("sweep_acc").mean()
                )
            )
            mean, err, table[i][j] = format_mean(trial_averages, latex)
            means.append(mean)
        if None in means:
            table[i][-1] = "X"
        else:
            table[i][-1] = "{:.1f}".format(sum(means) / len(means))

    col_labels = ["Algorithm", *dataset_names, "Avg"]
    header_text = f"Averages, model selection method: {selection_method.name}"
    print_table(table, header_text, alg_names, col_labels, colwidth=25,
                latex=latex)


if __name__ == "__main__":
    np.set_printoptions(suppress=True)

    parser = argparse.ArgumentParser(
        description="Domain generalization testbed")
    parser.add_argument("--input_dir", type=str, default="")
    parser.add_argument("--latex", action="store_true")
    args = parser.parse_args()

    results_file = "results.tex" if args.latex else "results.txt"

    sys.stdout = misc.Tee(os.path.join(args.input_dir, results_file), "w")

    records = reporting.load_records(args.input_dir)

    if args.latex:
        print("\\documentclass{article}")
        print("\\usepackage{booktabs}")
        print("\\usepackage{adjustbox}")
        print("\\begin{document}")
        print("\\section{Full DomainBed results}")
        print("% Total records:", len(records))
    else:
        print("Total records:", len(records))

    SELECTION_METHODS = [
        model_selection.IIDAccuracySelectionMethod,
        model_selection.LeaveOneOutSelectionMethod,
        model_selection.OracleSelectionMethod,
    ]

    for selection_method in SELECTION_METHODS:
        if args.latex:
            print()
            print("\\subsection{{Model selection: {}}}".format(
                selection_method.name))
        print_results_tables(records, selection_method, args.latex)

    if args.latex:
        print("\\end{document}")


# ----- /scripts/download.py -----
from torchvision.datasets import MNIST
import xml.etree.ElementTree as ET
from zipfile import ZipFile
import argparse
import tarfile
import shutil
import gdown
import uuid
import json
import os


# utils #######################################################################

def stage_path(data_dir, name):
    """Return ``data_dir/name``, creating the directory if it is missing."""
    full_path = os.path.join(data_dir, name)

    if not os.path.exists(full_path):
        os.makedirs(full_path)

    return full_path


def download_and_extract(url, dst, remove=True):
    """Download ``url`` to ``dst`` with gdown and unpack it next to ``dst``.

    ``.tar.gz``, ``.tar`` and ``.zip`` files are extracted into
    ``os.path.dirname(dst)``; the downloaded file is deleted afterwards when
    ``remove`` is true.
    """
    gdown.download(url, dst, quiet=False)

    # NOTE(review): tarfile.extractall trusts member paths; these archives
    # come from fixed dataset URLs, but a tampered archive could write
    # outside the target directory.
    if dst.endswith(".tar.gz"):
        with tarfile.open(dst, "r:gz") as tar:
            tar.extractall(os.path.dirname(dst))

    if dst.endswith(".tar"):
        with tarfile.open(dst, "r:") as tar:
            tar.extractall(os.path.dirname(dst))

    if dst.endswith(".zip"):
        with ZipFile(dst, "r") as zf:
            zf.extractall(os.path.dirname(dst))

    if remove:
        os.remove(dst)


# VLCS ########################################################################

# Slower, but builds dataset from the original sources
#
# def download_vlcs(data_dir):
#     full_path = stage_path(data_dir, "VLCS")
#
#     tmp_path = os.path.join(full_path, "tmp/")
#     if not os.path.exists(tmp_path):
#         os.makedirs(tmp_path)
#
#     with open("domainbed/misc/vlcs_files.txt", "r") as f:
#         lines = f.readlines()
#         files = [line.strip().split() for line in lines]
#
#     download_and_extract("http://pjreddie.com/media/files/VOCtrainval_06-Nov-2007.tar",
#                          os.path.join(tmp_path, "voc2007_trainval.tar"))
#
#     download_and_extract("https://drive.google.com/uc?id=1I8ydxaAQunz9R_qFFdBFtw6rFTUW9goz",
#                          os.path.join(tmp_path, "caltech101.tar.gz"))
#
#     download_and_extract("http://groups.csail.mit.edu/vision/Hcontext/data/sun09_hcontext.tar",
#                          os.path.join(tmp_path, "sun09_hcontext.tar"))
#
#     tar = tarfile.open(os.path.join(tmp_path, "sun09.tar"), "r:")
#     tar.extractall(tmp_path)
#     tar.close()
#
#     for src, dst in files:
#         class_folder = os.path.join(data_dir, dst)
#
#         if not os.path.exists(class_folder):
#             os.makedirs(class_folder)
#
#         dst = os.path.join(class_folder, uuid.uuid4().hex + ".jpg")
#
#         if "labelme" in src:
#             # download labelme from the web
#             gdown.download(src, dst, quiet=False)
#         else:
#             src = os.path.join(tmp_path, src)
#             shutil.copyfile(src, dst)
#
#     shutil.rmtree(tmp_path)


def download_vlcs(data_dir):
    """Download the pre-built VLCS archive into ``data_dir``."""
    # Original URL: http://www.eecs.qmul.ac.uk/~dl307/project_iccv2017
    full_path = stage_path(data_dir, "VLCS")

    download_and_extract("https://drive.google.com/uc?id=1skwblH1_okBwxWxmRsp9_qi15hyPpxg8",
                         os.path.join(data_dir, "VLCS.tar.gz"))


# MNIST #######################################################################

def download_mnist(data_dir):
    """Download MNIST via torchvision into ``data_dir/MNIST``."""
    # Original URL: http://yann.lecun.com/exdb/mnist/
    full_path = stage_path(data_dir, "MNIST")
    MNIST(full_path, download=True)


# PACS ########################################################################

def download_pacs(data_dir):
    """Download PACS and move the extracted ``kfold`` folder to ``data_dir/PACS``."""
    # Original URL: http://www.eecs.qmul.ac.uk/~dl307/project_iccv2017
    full_path = stage_path(data_dir, "PACS")

    download_and_extract("https://drive.google.com/uc?id=0B6x7gtvErXgfbF9CSk53UkRxVzg",
                         os.path.join(data_dir, "PACS.zip"))

    os.rename(os.path.join(data_dir, "kfold"),
              full_path)


# Office-Home #################################################################

def download_office_home(data_dir):
    """Download Office-Home and move it to ``data_dir/office_home``."""
    # Original URL: http://hemanthdv.org/OfficeHome-Dataset/
    full_path = stage_path(data_dir, "office_home")

    download_and_extract("https://drive.google.com/uc?id=0B81rNlvomiwed0V1YUxQdC1uOTg",
                         os.path.join(data_dir, "office_home.zip"))

    os.rename(os.path.join(data_dir, "OfficeHomeDataset_10072016"),
              full_path)


# DomainNET ###################################################################

def download_domain_net(data_dir):
    """Download all DomainNet domains, then delete the listed duplicate files."""
    # Original URL: http://ai.bu.edu/M3SDA/
    full_path = stage_path(data_dir, "domain_net")

    urls = [
        "http://csr.bu.edu/ftp/visda/2019/multi-source/groundtruth/clipart.zip",
        "http://csr.bu.edu/ftp/visda/2019/multi-source/infograph.zip",
        "http://csr.bu.edu/ftp/visda/2019/multi-source/groundtruth/painting.zip",
        "http://csr.bu.edu/ftp/visda/2019/multi-source/quickdraw.zip",
        "http://csr.bu.edu/ftp/visda/2019/multi-source/real.zip",
        "http://csr.bu.edu/ftp/visda/2019/multi-source/sketch.zip"
    ]

    for url in urls:
        download_and_extract(url, os.path.join(full_path, url.split("/")[-1]))

    with open("domainbed/misc/domain_net_duplicates.txt", "r") as f:
        for line in f:
            try:
                os.remove(os.path.join(full_path, line.strip()))
            except OSError:
                # Already gone (or never extracted): nothing to remove.
                pass


# TerraIncognita ##############################################################

def download_terra_incognita(data_dir):
    """Download TerraIncognita and re-organise it by location and category.

    Images from ``include_locations`` whose annotation category is in
    ``include_categories`` are copied to
    ``<full_path>/location_<id>/<category>/``. The raw image folder and the
    annotation JSON are deleted afterwards.
    """
    # Original URL: https://beerys.github.io/CaltechCameraTraps/
    full_path = stage_path(data_dir, "terra_incognita")

    download_and_extract(
        "http://www.vision.caltech.edu/~sbeery/datasets/caltechcameratraps18/eccv_18_all_images_sm.tar.gz",
        os.path.join(full_path, "terra_incognita_images.tar.gz"))

    download_and_extract(
        "http://www.vision.caltech.edu/~sbeery/datasets/caltechcameratraps18/eccv_18_all_annotations.tar.gz",
        os.path.join(full_path, "terra_incognita_annotations.tar.gz"))

    include_locations = [38, 46, 100, 43]

    include_categories = [
        "bird", "bobcat", "cat", "coyote", "dog", "empty", "opossum", "rabbit",
        "raccoon", "squirrel"
    ]

    images_folder = os.path.join(full_path, "eccv_18_all_images_sm/")
    annotations_file = os.path.join(full_path, "CaltechCameraTrapsECCV18.json")
    destination_folder = full_path

    # Per-location, per-category count of copied annotations.
    stats = {}

    if not os.path.exists(destination_folder):
        os.mkdir(destination_folder)

    with open(annotations_file, "r") as f:
        data = json.load(f)

    category_dict = {}
    for item in data['categories']:
        category_dict[item['id']] = item['name']

    # Index annotations by image id once, keeping file order. The old code
    # rescanned every annotation for every image (O(images * annotations)).
    annotations_by_image = {}
    for annotation in data['annotations']:
        annotations_by_image.setdefault(
            annotation['image_id'], []).append(annotation)

    for image in data['images']:
        image_location = image['location']

        if image_location not in include_locations:
            continue

        loc_folder = os.path.join(destination_folder,
                                  'location_' + str(image_location) + '/')

        if not os.path.exists(loc_folder):
            os.mkdir(loc_folder)

        image_id = image['id']
        image_fname = image['file_name']

        for annotation in annotations_by_image.get(image_id, []):
            if image_location not in stats:
                stats[image_location] = {}

            category = category_dict[annotation['category_id']]

            if category not in include_categories:
                continue

            # Count every copy; the first one used to be recorded as 0.
            stats[image_location][category] = (
                stats[image_location].get(category, 0) + 1)

            loc_cat_folder = os.path.join(loc_folder, category + '/')

            if not os.path.exists(loc_cat_folder):
                os.mkdir(loc_cat_folder)

            dst_path = os.path.join(loc_cat_folder, image_fname)
            src_path = os.path.join(images_folder, image_fname)

            shutil.copyfile(src_path, dst_path)

    shutil.rmtree(images_folder)
    os.remove(annotations_file)


# SVIRO #################################################################

def download_sviro(data_dir):
    """Download SVIRO and move it to ``data_dir/sviro``."""
    # Original URL: https://sviro.kl.dfki.de
    full_path = stage_path(data_dir, "sviro")

    download_and_extract("https://sviro.kl.dfki.de/?wpdmdl=1731",
                         os.path.join(data_dir, "sviro_grayscale_rectangle_classification.zip"))

    os.rename(os.path.join(data_dir, "SVIRO_DOMAINBED"),
              full_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Download datasets')
    parser.add_argument('--data_dir', type=str, required=True)
    args = parser.parse_args()

    # download_mnist(args.data_dir)
    # download_pacs(args.data_dir)
    # download_office_home(args.data_dir)
    # download_domain_net(args.data_dir)
    # download_vlcs(args.data_dir)
    download_terra_incognita(args.data_dir)
    download_sviro(args.data_dir)
# ------------------------------------------------------------------------------
/scripts/list_top_hparams.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | """ 4 | Example usage: 5 | python -u -m domainbed.scripts.list_top_hparams \ 6 | --input_dir domainbed/misc/test_sweep_data --algorithm ERM \ 7 | --dataset VLCS --test_env 0 8 | """ 9 | 10 | import collections 11 | 12 | 13 | import argparse 14 | import functools 15 | import glob 16 | import pickle 17 | import itertools 18 | import json 19 | import os 20 | import random 21 | import sys 22 | 23 | import numpy as np 24 | import tqdm 25 | 26 | from domainbed import datasets 27 | from domainbed import algorithms 28 | from domainbed.lib import misc, reporting 29 | from domainbed import model_selection 30 | from domainbed.lib.query import Q 31 | import warnings 32 | 33 | def todo_rename(records, selection_method, latex): 34 | 35 | grouped_records = reporting.get_grouped_records(records).map(lambda group: 36 | { **group, "sweep_acc": selection_method.sweep_acc(group["records"]) } 37 | ).filter(lambda g: g["sweep_acc"] is not None) 38 | 39 | # read algorithm names and sort (predefined order) 40 | alg_names = Q(records).select("args.algorithm").unique() 41 | alg_names = ([n for n in algorithms.ALGORITHMS if n in alg_names] + 42 | [n for n in alg_names if n not in algorithms.ALGORITHMS]) 43 | 44 | # read dataset names and sort (lexicographic order) 45 | dataset_names = Q(records).select("args.dataset").unique().sorted() 46 | dataset_names = [d for d in datasets.DATASETS if d in dataset_names] 47 | 48 | for dataset in dataset_names: 49 | if latex: 50 | print() 51 | print("\\subsubsection{{{}}}".format(dataset)) 52 | test_envs = range(datasets.num_environments(dataset)) 53 | 54 | table = [[None for _ in [*test_envs, "Avg"]] for _ in alg_names] 55 | for i, algorithm in enumerate(alg_names): 56 | means = [] 57 | for j, test_env in enumerate(test_envs): 58 | trial_accs = (grouped_records 59 
| .filter_equals( 60 | "dataset, algorithm, test_env", 61 | (dataset, algorithm, test_env) 62 | ).select("sweep_acc")) 63 | mean, err, table[i][j] = format_mean(trial_accs, latex) 64 | means.append(mean) 65 | if None in means: 66 | table[i][-1] = "X" 67 | else: 68 | table[i][-1] = "{:.1f}".format(sum(means) / len(means)) 69 | 70 | col_labels = [ 71 | "Algorithm", 72 | *datasets.get_dataset_class(dataset).ENVIRONMENTS, 73 | "Avg" 74 | ] 75 | header_text = (f"Dataset: {dataset}, " 76 | f"model selection method: {selection_method.name}") 77 | print_table(table, header_text, alg_names, list(col_labels), 78 | colwidth=20, latex=latex) 79 | 80 | # Print an "averages" table 81 | if latex: 82 | print() 83 | print("\\subsubsection{Averages}") 84 | 85 | table = [[None for _ in [*dataset_names, "Avg"]] for _ in alg_names] 86 | for i, algorithm in enumerate(alg_names): 87 | means = [] 88 | for j, dataset in enumerate(dataset_names): 89 | trial_averages = (grouped_records 90 | .filter_equals("algorithm, dataset", (algorithm, dataset)) 91 | .group("trial_seed") 92 | .map(lambda trial_seed, group: 93 | group.select("sweep_acc").mean() 94 | ) 95 | ) 96 | mean, err, table[i][j] = format_mean(trial_averages, latex) 97 | means.append(mean) 98 | if None in means: 99 | table[i][-1] = "X" 100 | else: 101 | table[i][-1] = "{:.1f}".format(sum(means) / len(means)) 102 | 103 | col_labels = ["Algorithm", *dataset_names, "Avg"] 104 | header_text = f"Averages, model selection method: {selection_method.name}" 105 | print_table(table, header_text, alg_names, col_labels, colwidth=25, 106 | latex=latex) 107 | 108 | if __name__ == "__main__": 109 | np.set_printoptions(suppress=True) 110 | 111 | parser = argparse.ArgumentParser( 112 | description="Domain generalization testbed") 113 | parser.add_argument("--input_dir", required=True) 114 | parser.add_argument('--dataset', required=True) 115 | parser.add_argument('--algorithm', required=True) 116 | parser.add_argument('--test_env', type=int, 
required=True) 117 | args = parser.parse_args() 118 | 119 | records = reporting.load_records(args.input_dir) 120 | print("Total records:", len(records)) 121 | 122 | records = reporting.get_grouped_records(records) 123 | records = records.filter( 124 | lambda r: 125 | r['dataset'] == args.dataset and 126 | r['algorithm'] == args.algorithm and 127 | r['test_env'] == args.test_env 128 | ) 129 | 130 | SELECTION_METHODS = [ 131 | model_selection.IIDAccuracySelectionMethod, 132 | model_selection.LeaveOneOutSelectionMethod, 133 | model_selection.OracleSelectionMethod, 134 | ] 135 | 136 | for selection_method in SELECTION_METHODS: 137 | print(f'Model selection: {selection_method.name}') 138 | 139 | for group in records: 140 | print(f"trial_seed: {group['trial_seed']}") 141 | best_hparams = selection_method.hparams_accs(group['records']) 142 | for run_acc, hparam_records in best_hparams: 143 | print(f"\t{run_acc}") 144 | for r in hparam_records: 145 | assert(r['hparams'] == hparam_records[0]['hparams']) 146 | print("\t\thparams:") 147 | for k, v in sorted(hparam_records[0]['hparams'].items()): 148 | print('\t\t\t{}: {}'.format(k, v)) 149 | print("\t\toutput_dirs:") 150 | output_dirs = hparam_records.select('args.output_dir').unique() 151 | for output_dir in output_dirs: 152 | print(f"\t\t\t{output_dir}") -------------------------------------------------------------------------------- /scripts/models.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2018 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torch.nn import init 9 | from torchvision import models 10 | 11 | ###################################################################### 12 | def weights_init_kaiming(m): 13 | classname = m.__class__.__name__ 14 | if classname.find('Conv') != -1: 15 | nn.init.kaiming_normal_(m.weight.data, mode='fan_in') 16 | elif classname.find('Linear') != -1: 17 | nn.init.kaiming_normal_(m.weight.data, mode='fan_out') 18 | nn.init.constant_(m.bias.data, 0.) 19 | elif classname.find('BatchNorm2d') != -1: 20 | nn.init.normal_(m.weight.data, mean=1., std=0.02) 21 | nn.init.constant_(m.bias.data, 0.0) 22 | elif classname.find('InstanceNorm1d') != -1: 23 | nn.init.normal_(m.weight.data, 1.0, 0.02) 24 | nn.init.constant_(m.bias.data, 0.0) 25 | 26 | def weights_init_classifier(m): 27 | classname = m.__class__.__name__ 28 | if classname.find('Linear') != -1: 29 | init.normal_(m.weight.data, std=0.001) 30 | init.constant_(m.bias.data, 0.0) 31 | 32 | def fix_bn(m): 33 | classname = m.__class__.__name__ 34 | if classname.find('BatchNorm') != -1: 35 | m.eval() 36 | 37 | # Defines the new fc layer and classification layer 38 | # |--Linear--|--bn--|--relu--|--Linear--| 39 | class ClassBlock(nn.Module): 40 | def __init__(self, input_dim, class_num, droprate=0.5, relu=False, num_bottleneck=512): 41 | super(ClassBlock, self).__init__() 42 | add_block = [] 43 | add_block += [nn.Linear(input_dim, num_bottleneck)] 44 | #num_bottleneck = input_dim # We remove the input_dim 45 | add_block += [nn.BatchNorm1d(num_bottleneck, affine=True)] 46 | if relu: 47 | add_block += [nn.LeakyReLU(0.1)] 48 | if droprate>0: 49 | add_block += [nn.Dropout(p=droprate)] 50 | add_block = nn.Sequential(*add_block) 51 | add_block.apply(weights_init_kaiming) 52 | 53 | classifier = [] 54 | classifier += [nn.Linear(num_bottleneck, class_num)] 55 | classifier = nn.Sequential(*classifier) 56 | classifier.apply(weights_init_classifier) 57 | 58 | self.add_block = 
add_block 59 | self.ave_pool = nn.AdaptiveMaxPool2d((1,1)) 60 | self.classifier = classifier 61 | def forward(self, x, ave_pool = False): 62 | if ave_pool: 63 | x = self.ave_pool(x) 64 | x = self.add_block(x)# [B, 512] 65 | x = self.classifier(x) 66 | return x 67 | 68 | class domain_discriminator(nn.Module): 69 | 70 | def __init__(self, rp_size, optimizer, lr, momentum, weight_decay,n_outputs=512): 71 | super(domain_discriminator, self).__init__() 72 | 73 | self.domain_discriminator = nn.Sequential() 74 | self.domain_discriminator.add_module('d_fc1', nn.Linear(rp_size, 512)) 75 | self.domain_discriminator.add_module('d_relu1', nn.ReLU()) 76 | self.domain_discriminator.add_module('d_drop1', nn.Dropout(0.2)) 77 | 78 | self.domain_discriminator.add_module('d_fc2', nn.Linear(512, 256)) 79 | self.domain_discriminator.add_module('d_relu2', nn.ReLU()) 80 | self.domain_discriminator.add_module('d_drop2', nn.Dropout(0.2)) 81 | self.domain_discriminator.add_module('d_fc3', nn.Linear(256, 2)) 82 | self.domain_discriminator.add_module('d_sfmax', nn.LogSoftmax(dim=1)) 83 | #self.domain_discriminator.add_module('d_relu2', nn.ReLU()) 84 | #self.domain_discriminator.add_module('d_drop2', nn.Dropout()) 85 | #self.domain_discriminator.add_module('d_fc3', nn.Linear(1024, 1)) 86 | 87 | self.optimizer = optimizer(list(self.domain_discriminator.parameters()), lr=lr, momentum=momentum, weight_decay=weight_decay) 88 | 89 | self.initialize_params() 90 | 91 | # TODO Check the RP size 92 | self.projection = nn.Linear(n_outputs, rp_size, bias=False) 93 | with torch.no_grad(): 94 | self.projection.weight.div_(torch.norm(self.projection.weight, keepdim=True)) 95 | 96 | def forward(self, input_data): 97 | #reverse_feature = ReverseLayer.apply(input_data, alpha) # Make sure there will be no problem when updating discs params 98 | feature = input_data.view(input_data.size(0), -1) 99 | feature_proj = self.projection(feature) 100 | 101 | domain_output = self.domain_discriminator(feature_proj) 102 | 
103 | return domain_output 104 | 105 | def initialize_params(self): 106 | 107 | for layer in self.modules(): 108 | if isinstance(layer, torch.nn.Conv2d): 109 | init.kaiming_normal_(layer.weight, a=0, mode='fan_out') 110 | elif isinstance(layer, torch.nn.Linear): 111 | init.kaiming_uniform_(layer.weight) 112 | elif isinstance(layer, torch.nn.BatchNorm2d) or isinstance(layer, torch.nn.BatchNorm1d): 113 | layer.weight.data.fill_(1) 114 | layer.bias.data.zero_() 115 | -------------------------------------------------------------------------------- /scripts/save_images.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | """ 4 | Save some representative images from each dataset to disk. 5 | """ 6 | import random 7 | import torch 8 | import argparse 9 | import hparams_registry 10 | import datasets 11 | import imageio 12 | import torchvision.utils as vutils 13 | import os 14 | from tqdm import tqdm 15 | 16 | def __write_images(image_outputs, display_image_num, file_name, run): 17 | image_outputs = [images.expand(-1, 3, -1, -1) for images in image_outputs] # expand gray-scale images to 3 channels 18 | image_tensor = torch.cat([images[:display_image_num] for images in image_outputs], 0) 19 | image_grid = vutils.make_grid(image_tensor.data, nrow=display_image_num, padding=0, normalize=True, scale_each=True) 20 | vutils.save_image(image_grid, file_name, nrow=1) 21 | run.log_image('images', file_name) 22 | 23 | 24 | def write_2images(image_outputs, display_image_num, image_directory, postfix, run): 25 | n = len(image_outputs) 26 | __write_images(image_outputs[0:n], display_image_num, '%s/gen_%s.jpg' % (image_directory, postfix), run) 27 | #__write_images(image_outputs[n//2:n], display_image_num, '%s/gen_b2a_%s.jpg' % (image_directory, postfix), run) 28 | 29 | if __name__ == '__main__': 30 | parser = argparse.ArgumentParser(description='Domain generalization') 31 | 
parser.add_argument('--data_dir', type=str) 32 | parser.add_argument('--output_dir', type=str) 33 | args = parser.parse_args() 34 | 35 | os.makedirs(args.output_dir, exist_ok=True) 36 | datasets_to_save = ['OfficeHome', 'TerraIncognita', 'DomainNet', 'RotatedMNIST', 'ColoredMNIST', 'SVIRO'] 37 | 38 | for dataset_name in tqdm(datasets_to_save): 39 | hparams = hparams_registry.default_hparams('ERM', dataset_name) 40 | dataset = datasets.get_dataset_class(dataset_name)( 41 | args.data_dir, 42 | list(range(datasets.num_environments(dataset_name))), 43 | hparams) 44 | for env_idx, env in enumerate(tqdm(dataset)): 45 | for i in tqdm(range(50)): 46 | idx = random.choice(list(range(len(env)))) 47 | x, y = env[idx] 48 | while y > 10: 49 | idx = random.choice(list(range(len(env)))) 50 | x, y = env[idx] 51 | if x.shape[0] == 2: 52 | x = torch.cat([x, torch.zeros_like(x)], dim=0)[:3,:,:] 53 | if x.min() < 0: 54 | mean = torch.tensor([0.485, 0.456, 0.406])[:,None,None] 55 | std = torch.tensor([0.229, 0.224, 0.225])[:,None,None] 56 | x = (x * std) + mean 57 | assert(x.min() >= 0) 58 | assert(x.max() <= 1) 59 | x = (x * 255.99) 60 | x = x.numpy().astype('uint8').transpose(1,2,0) 61 | imageio.imwrite( 62 | os.path.join(args.output_dir, 63 | f'{dataset_name}_env{env_idx}{dataset.ENVIRONMENTS[env_idx]}_{i}_idx{idx}_class{y}.png'), 64 | x) 65 | -------------------------------------------------------------------------------- /scripts/utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hlzhang109/DDG/b5cd1822f1a413ae7e263fc5a11f00b490b9c72c/scripts/utils.py -------------------------------------------------------------------------------- /submit.py: -------------------------------------------------------------------------------- 1 | from azureml.core import Workspace, Datastore 2 | from azureml.train.dnn import PyTorch 3 | from azureml.core import Experiment, Environment, ScriptRunConfig, Dataset 4 | from 
azureml.contrib.core.k8srunconfig import K8sComputeConfiguration 5 | from azureml.train.estimator import Estimator 6 | from sys import argv 7 | script, i, seed, dataset, algorithm = argv 8 | #setup cluster 9 | ws = Workspace.from_config() 10 | print(ws.name, ws.resource_group, ws.location, ws.subscription_id, sep='\n') 11 | 12 | from azureml.core.compute import ComputeTarget 13 | from azureml.contrib.core.compute.k8scompute import AksCompute 14 | for key, target in ws.compute_targets.items(): 15 | if type(target) is AksCompute: 16 | print('Found compute target:{}\ttype:{}\tprovisioning_state:{}\tlocation:{}'.format(target.name, target.type, target.provisioning_state, target.location)) 17 | 18 | compute_target = ComputeTarget(workspace=ws, name="itpscusv100cl")# researchvc-eus 19 | experiment_name = '%s_%s_d%s_seed%s'%(dataset, algorithm, str(i), str(seed)) 20 | #experiment_name = '%s_%s_e2e_seed%s'%(dataset, str(i), str(seed)) 21 | print(experiment_name) 22 | Datastore.register_azure_blob_container( 23 | workspace=ws, 24 | datastore_name='yifan_data', # just a name to refer the Datastore 25 | account_name="yifanzhang", 26 | container_name="data", 27 | account_key="pv0wggRvdq2Xf1hMmXWqlz0xm0hmaugghPFfqrD5G2J8BQJ7If6/9G2RAMjjv7o/21RZATGVvUfKiQ9g+Yvduw==") 28 | ds = Datastore(ws, "yifan_data") 29 | Datastore.register_azure_blob_container( 30 | workspace=ws, 31 | datastore_name='yifan_model', # just a name to refer the Datastore 32 | account_name="yifanzhang", 33 | container_name="model", 34 | account_key="pv0wggRvdq2Xf1hMmXWqlz0xm0hmaugghPFfqrD5G2J8BQJ7If6/9G2RAMjjv7o/21RZATGVvUfKiQ9g+Yvduw==") 35 | ds_model = Datastore(ws, "yifan_model") 36 | experiment = Experiment(ws, name=experiment_name) 37 | 38 | config = Estimator( 39 | compute_target=compute_target, 40 | use_gpu=False, 41 | custom_docker_image="mcr.microsoft.com/azureml/base-gpu:openmpi3.1.2-cuda10.0-cudnn7-ubuntu16.04", 42 | source_directory='./', 43 | entry_script='train_submit.py', # './DGdata/' 44 | 
script_params={ '--data_dir': ds.path('./DGdata') , '--gen_dir':ds_model.path('./DG-Net/mnist_gen.pkl'), '--dataset':dataset, '--test_envs':i, '--stage':1,'--algorithm': algorithm, '--seed':seed}, 45 | pip_packages=['torch','torchvision', 'pyyaml', 'tqdm', 'wilds','imageio'] 46 | ) 47 | 48 | # set up pytorch environment 49 | # env = Environment.from_conda_specification(name='disdg_yifan1',file_path='/home/v-yifanzhang/DisDG/submit/environment.yml') 50 | # config.run_config.environment = env 51 | 52 | compute_config = K8sComputeConfiguration() 53 | compute_config.configuration = { 54 | 'enable_ipython': False, 55 | 'enable_tensorboard': False, 56 | 'enable_ssh': True, 57 | 'gpu_count':1, 58 | 'preemption_allowed':False, 59 | # 'ssh_public_key' : 'ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDesrvA7BerK+5Ko3yZ6Yg1kKlPGzYFT+0j5PXrrralgTq5EFDLf6UzmgwzyRq6zXYGu4IXMrJOopcacKwU68stzW78u/8jzlxzf5cKpqvZOjFyiUCUyn1rWEGuvUV0GBGmuZEYuIgEXyhBpVGV3K6nRkDbEjBISMeIPN+NXeHIUcodVTJcdly9+kFGSPtBUGTf6D/jCOL1AqS3ti0+fss9Q2n4Y6W5QpI2X+7qbIuIkg82/gbPehIN6ua54ojhejY6d5GNE+5eAw6aIV6/KejNfjWMGiAeepa7t0znQ8v6Dow+i1YdNFogq/wfe5yiGry3b4Gnwx04RX9cpHRmO9q3 cbzhang@server' 60 | 'ssh_public_key' : 'ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAICIlTNE5k4t6TqtiLv17bmepthbZmldge88YpKmxfvvL yifanzhang@BPQ4BV2', 61 | } 62 | config.run_config.cmk8scompute = compute_config 63 | 64 | #setup jobs 65 | run = experiment.submit(config) 66 | # run.wait_for_completion(show_output=True) 67 | run.get_tags() -------------------------------------------------------------------------------- /sweep.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | 3 | """ 4 | Run sweeps 5 | """ 6 | 7 | import argparse 8 | import copy 9 | import getpass 10 | import hashlib 11 | import json 12 | import os 13 | import random 14 | import shutil 15 | import time 16 | import uuid 17 | 18 | import numpy as np 19 | import torch 20 | 21 | import datasets 22 | import hparams_registry 23 | import algorithms 24 | from lib import misc 25 | import command_launchers 26 | 27 | import tqdm 28 | import shlex 29 | 30 | class Job: 31 | NOT_LAUNCHED = 'Not launched' 32 | INCOMPLETE = 'Incomplete' 33 | DONE = 'Done' 34 | 35 | def __init__(self, train_args, sweep_output_dir): 36 | args_str = json.dumps(train_args, sort_keys=True) 37 | args_hash = hashlib.md5(args_str.encode('utf-8')).hexdigest() 38 | self.output_dir = os.path.join(sweep_output_dir, args_hash) 39 | 40 | self.train_args = copy.deepcopy(train_args) 41 | self.train_args['output_dir'] = self.output_dir 42 | command = ['python', '-m', 'domainbed.scripts.train'] 43 | for k, v in sorted(self.train_args.items()): 44 | if isinstance(v, list): 45 | v = ' '.join([str(v_) for v_ in v]) 46 | elif isinstance(v, str): 47 | v = shlex.quote(v) 48 | command.append(f'--{k} {v}') 49 | self.command_str = ' '.join(command) 50 | 51 | if os.path.exists(os.path.join(self.output_dir, 'done')): 52 | self.state = Job.DONE 53 | elif os.path.exists(self.output_dir): 54 | self.state = Job.INCOMPLETE 55 | else: 56 | self.state = Job.NOT_LAUNCHED 57 | 58 | def __str__(self): 59 | job_info = (self.train_args['dataset'], 60 | self.train_args['algorithm'], 61 | self.train_args['test_envs'], 62 | self.train_args['hparams_seed']) 63 | return '{}: {} {}'.format( 64 | self.state, 65 | self.output_dir, 66 | job_info) 67 | 68 | @staticmethod 69 | def launch(jobs, launcher_fn): 70 | print('Launching...') 71 | jobs = jobs.copy() 72 | np.random.shuffle(jobs) 73 | print('Making job directories:') 74 | for job in tqdm.tqdm(jobs, leave=False): 75 | os.makedirs(job.output_dir, exist_ok=True) 76 | commands 
= [job.command_str for job in jobs] 77 | launcher_fn(commands) 78 | print(f'Launched {len(jobs)} jobs!') 79 | 80 | @staticmethod 81 | def delete(jobs): 82 | print('Deleting...') 83 | for job in jobs: 84 | shutil.rmtree(job.output_dir) 85 | print(f'Deleted {len(jobs)} jobs!') 86 | 87 | def all_test_env_combinations(n): 88 | """ 89 | For a dataset with n >= 3 envs, return all combinations of 1 and 2 test 90 | envs. 91 | """ 92 | assert(n >= 3) 93 | for i in range(n): 94 | yield [i] 95 | for j in range(i+1, n): 96 | yield [i, j] 97 | 98 | def make_args_list(n_trials, dataset_names, algorithms, n_hparams_from, n_hparams, steps, 99 | data_dir, task, holdout_fraction, single_test_envs, hparams): 100 | args_list = [] 101 | for trial_seed in range(n_trials): 102 | for dataset in dataset_names: 103 | for algorithm in algorithms: 104 | if single_test_envs: 105 | all_test_envs = [ 106 | [i] for i in range(datasets.num_environments(dataset))] 107 | else: 108 | all_test_envs = all_test_env_combinations( 109 | datasets.num_environments(dataset)) 110 | for test_envs in all_test_envs: 111 | for hparams_seed in range(n_hparams_from, n_hparams): 112 | train_args = {} 113 | train_args['dataset'] = dataset 114 | train_args['algorithm'] = algorithm 115 | train_args['test_envs'] = test_envs 116 | train_args['holdout_fraction'] = holdout_fraction 117 | train_args['hparams_seed'] = hparams_seed 118 | train_args['data_dir'] = data_dir 119 | train_args['task'] = task 120 | train_args['trial_seed'] = trial_seed 121 | train_args['seed'] = misc.seed_hash(dataset, 122 | algorithm, test_envs, hparams_seed, trial_seed) 123 | if steps is not None: 124 | train_args['steps'] = steps 125 | if hparams is not None: 126 | train_args['hparams'] = hparams 127 | args_list.append(train_args) 128 | return args_list 129 | 130 | def ask_for_confirmation(): 131 | response = input('Are you sure? 
(y/n) ') 132 | if not response.lower().strip()[:1] == "y": 133 | print('Nevermind!') 134 | exit(0) 135 | 136 | DATASETS = [d for d in datasets.DATASETS if "Debug" not in d] 137 | 138 | if __name__ == "__main__": 139 | parser = argparse.ArgumentParser(description='Run a sweep') 140 | parser.add_argument('command', choices=['launch', 'delete_incomplete']) 141 | parser.add_argument('--datasets', nargs='+', type=str, default=DATASETS) 142 | parser.add_argument('--algorithms', nargs='+', type=str, default=algorithms.ALGORITHMS) 143 | parser.add_argument('--task', type=str, default="domain_generalization") 144 | parser.add_argument('--n_hparams_from', type=int, default=0) 145 | parser.add_argument('--n_hparams', type=int, default=20) 146 | parser.add_argument('--output_dir', type=str, required=True) 147 | parser.add_argument('--data_dir', type=str, required=True) 148 | parser.add_argument('--seed', type=int, default=0) 149 | parser.add_argument('--n_trials', type=int, default=3) 150 | parser.add_argument('--command_launcher', type=str, required=True) 151 | parser.add_argument('--steps', type=int, default=None) 152 | parser.add_argument('--hparams', type=str, default=None) 153 | parser.add_argument('--holdout_fraction', type=float, default=0.2) 154 | parser.add_argument('--single_test_envs', action='store_true') 155 | parser.add_argument('--skip_confirmation', action='store_true') 156 | args = parser.parse_args() 157 | 158 | args_list = make_args_list( 159 | n_trials=args.n_trials, 160 | dataset_names=args.datasets, 161 | algorithms=args.algorithms, 162 | n_hparams_from=args.n_hparams_from, 163 | n_hparams=args.n_hparams, 164 | steps=args.steps, 165 | data_dir=args.data_dir, 166 | task=args.task, 167 | holdout_fraction=args.holdout_fraction, 168 | single_test_envs=args.single_test_envs, 169 | hparams=args.hparams 170 | ) 171 | 172 | jobs = [Job(train_args, args.output_dir) for train_args in args_list] 173 | 174 | for job in jobs: 175 | print(job) 176 | print("{} jobs: {} 
done, {} incomplete, {} not launched.".format( 177 | len(jobs), 178 | len([j for j in jobs if j.state == Job.DONE]), 179 | len([j for j in jobs if j.state == Job.INCOMPLETE]), 180 | len([j for j in jobs if j.state == Job.NOT_LAUNCHED])) 181 | ) 182 | 183 | if args.command == 'launch': 184 | to_launch = [j for j in jobs if j.state == Job.NOT_LAUNCHED] 185 | print(f'About to launch {len(to_launch)} jobs.') 186 | if not args.skip_confirmation: 187 | ask_for_confirmation() 188 | launcher_fn = command_launchers.REGISTRY[args.command_launcher] 189 | Job.launch(to_launch, launcher_fn) 190 | 191 | elif args.command == 'delete_incomplete': 192 | to_delete = [j for j in jobs if j.state == Job.INCOMPLETE] 193 | print(f'About to delete {len(to_delete)} jobs.') 194 | if not args.skip_confirmation: 195 | ask_for_confirmation() 196 | Job.delete(to_delete) 197 | -------------------------------------------------------------------------------- /test/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | 4 | -------------------------------------------------------------------------------- /test/helpers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | 3 | import torch 4 | 5 | DEBUG_DATASETS = ['Debug28', 'Debug224'] 6 | 7 | def make_minibatches(dataset, batch_size): 8 | """Test helper to make a minibatches array like train.py""" 9 | minibatches = [] 10 | for env in dataset: 11 | X = torch.stack([env[i][0] for i in range(batch_size)]).cuda() # first batch_size samples of each env, moved to the GPU (a CUDA device is required) 12 | y = torch.stack([torch.as_tensor(env[i][1]) 13 | for i in range(batch_size)]).cuda() 14 | minibatches.append((X, y)) # one (inputs, labels) pair per environment 15 | return minibatches 16 | -------------------------------------------------------------------------------- /test/lib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | 4 | -------------------------------------------------------------------------------- /test/lib/test_misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | import unittest 4 | from domainbed.lib import misc 5 | 6 | class TestMisc(unittest.TestCase): 7 | 8 | def test_make_weights_for_balanced_classes(self): 9 | dataset = [('A', 0), ('B', 1), ('C', 0), ('D', 2), ('E', 3), ('F', 0)] 10 | result = misc.make_weights_for_balanced_classes(dataset) 11 | self.assertEqual(result.sum(), 1) # weights are expected to be normalized 12 | self.assertEqual(result[0], result[2]) 13 | self.assertEqual(result[1], result[3]) 14 | self.assertEqual(3 * result[0], result[1]) # class 0 has 3 samples and class 1 has 1, so each class-0 sample gets a third of the weight 15 | -------------------------------------------------------------------------------- /test/lib/test_query.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates.
All Rights Reserved 2 | 3 | import unittest 4 | from domainbed.lib.query import Q, make_selector_fn 5 | 6 | class TestQuery(unittest.TestCase): 7 | def test_everything(self): 8 | numbers = Q([1, 4, 2]) 9 | people = Q([ 10 | {'name': 'Bob', 'age': 40}, 11 | {'name': 'Alice', 'age': 20}, 12 | {'name': 'Bob', 'age': 10} 13 | ]) 14 | 15 | self.assertEqual(numbers.select(lambda x: 2*x), [2, 8, 4]) 16 | 17 | self.assertEqual(numbers.min(), 1) 18 | self.assertEqual(numbers.max(), 4) 19 | self.assertEqual(numbers.mean(), 7/3) # (1 + 4 + 2) / 3 20 | 21 | self.assertEqual(people.select('name'), ['Bob', 'Alice', 'Bob']) 22 | 23 | self.assertEqual( 24 | set(people.group('name').map(lambda _,g: g.select('age').mean())), 25 | set([25, 20]) # Bob: (40 + 10) / 2, Alice: 20 26 | ) 27 | 28 | self.assertEqual(people.argmax('age'), people[0]) 29 | 30 | def test_group_by_unhashable(self): # the group keys here are dicts, which are unhashable 31 | jobs = Q([ 32 | {'hparams': {1:2}, 'score': 3}, 33 | {'hparams': {1:2}, 'score': 4}, 34 | {'hparams': {2:4}, 'score': 5} 35 | ]) 36 | grouped = jobs.group('hparams') 37 | self.assertEqual(grouped, [ 38 | ({1:2}, [jobs[0], jobs[1]]), 39 | ({2:4}, [jobs[2]]) 40 | ]) 41 | 42 | def test_comma_selector(self): 43 | struct = {'a': {'b': 1}, 'c': 2} 44 | fn = make_selector_fn('a.b,c') # comma-separated selectors yield a tuple; dots select nested keys 45 | self.assertEqual(fn(struct), (1, 2)) 46 | 47 | def test_unique(self): 48 | numbers = Q([1,2,1,3,2,1,3,1,2,3]) 49 | self.assertEqual(numbers.unique(), [1,2,3]) # first-occurrence order is preserved 50 | -------------------------------------------------------------------------------- /test/scripts/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | 4 | -------------------------------------------------------------------------------- /test/scripts/test_collect_results.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates.
All Rights Reserved 2 | 3 | import argparse 4 | import itertools 5 | import json 6 | import os 7 | import subprocess 8 | import sys 9 | import time 10 | import unittest 11 | import uuid 12 | 13 | import torch 14 | 15 | from domainbed import datasets 16 | from domainbed import hparams_registry 17 | from domainbed import algorithms 18 | from domainbed import networks 19 | from domainbed.test import helpers 20 | from domainbed.scripts import collect_results 21 | 22 | from parameterized import parameterized 23 | import io 24 | import textwrap 25 | 26 | class TestCollectResults(unittest.TestCase): 27 | 28 | def test_format_mean(self): 29 | self.assertEqual( 30 | collect_results.format_mean([0.1, 0.2, 0.3], False)[2], 31 | '20.0 +/- 4.7') 32 | self.assertEqual( 33 | collect_results.format_mean([0.1, 0.2, 0.3], True)[2], 34 | '20.0 $\pm$ 4.7') 35 | 36 | def test_print_table_non_latex(self): 37 | temp_out = io.StringIO() 38 | sys.stdout = temp_out 39 | table = [['1', '2'], ['3', '4']] 40 | collect_results.print_table(table, 'Header text', ['R1', 'R2'], 41 | ['C1', 'C2'], colwidth=10, latex=False) 42 | sys.stdout = sys.__stdout__ 43 | self.assertEqual( 44 | temp_out.getvalue(), 45 | textwrap.dedent(""" 46 | -------- Header text 47 | C1 C2 48 | R1 1 2 49 | R2 3 4 50 | """) 51 | ) 52 | 53 | def test_print_table_latex(self): 54 | temp_out = io.StringIO() 55 | sys.stdout = temp_out 56 | table = [['1', '2'], ['3', '4']] 57 | collect_results.print_table(table, 'Header text', ['R1', 'R2'], 58 | ['C1', 'C2'], colwidth=10, latex=True) 59 | sys.stdout = sys.__stdout__ 60 | self.assertEqual( 61 | temp_out.getvalue(), 62 | textwrap.dedent(r""" 63 | \begin{center} 64 | \adjustbox{max width=\textwidth}{% 65 | \begin{tabular}{lcc} 66 | \toprule 67 | \textbf{C1 & \textbf{C2 \\ 68 | \midrule 69 | R1 & 1 & 2 \\ 70 | R2 & 3 & 4 \\ 71 | \bottomrule 72 | \end{tabular}} 73 | \end{center} 74 | """) 75 | ) 76 | 77 | def test_get_grouped_records(self): 78 | pass # TODO 79 | 80 | def 
test_print_results_tables(self): 81 | pass # TODO 82 | 83 | def test_load_records(self): 84 | pass # TODO 85 | 86 | def test_end_to_end(self): 87 | """ 88 | Test that collect_results.py's output matches a manually-verified 89 | ground-truth when run on a given directory of test sweep data. 90 | 91 | If you make any changes to the output of collect_results.py, you'll need 92 | to update the ground-truth and manually verify that it's still 93 | correct. The command used to update the ground-truth is: 94 | 95 | python -m domainbed.scripts.collect_results --input_dir=domainbed/misc/test_sweep_data \ 96 | | tee domainbed/misc/test_sweep_results.txt 97 | 98 | Furthermore, if you make any changes to the data format, you'll also 99 | need to rerun the test sweep. The command used to run the test sweep is: 100 | 101 | python -m domainbed.scripts.sweep launch --data_dir=$DATA_DIR \ 102 | --output_dir=domainbed/misc/test_sweep_data --algorithms ERM \ 103 | --datasets VLCS --steps 1001 --n_hparams 2 --n_trials 2 \ 104 | --command_launcher local 105 | """ 106 | result = subprocess.run('python -m domainbed.scripts.collect_results' 107 | ' --input_dir=domainbed/misc/test_sweep_data', shell=True, 108 | stdout=subprocess.PIPE) 109 | 110 | with open('domainbed/misc/test_sweep_results.txt', 'r') as f: 111 | ground_truth = f.read() 112 | 113 | self.assertEqual(result.stdout.decode('utf8'), ground_truth) 114 | -------------------------------------------------------------------------------- /test/scripts/test_sweep.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | 3 | import argparse 4 | import itertools 5 | import json 6 | import os 7 | import subprocess 8 | import sys 9 | import time 10 | import unittest 11 | import uuid 12 | 13 | import torch 14 | 15 | from domainbed import datasets 16 | from domainbed import hparams_registry 17 | from domainbed import algorithms 18 | from domainbed import networks 19 | from domainbed.test import helpers 20 | from domainbed.scripts import sweep 21 | 22 | from parameterized import parameterized 23 | 24 | class TestSweep(unittest.TestCase): 25 | 26 | def test_job(self): 27 | """Test that a newly-created job has valid 28 | output_dir, state, and command_str properties.""" 29 | train_args = {'foo': 'bar'} 30 | sweep_output_dir = f'/tmp/{str(uuid.uuid4())}' 31 | job = sweep.Job(train_args, sweep_output_dir) 32 | self.assertTrue(job.output_dir.startswith(sweep_output_dir)) 33 | self.assertEqual(job.state, sweep.Job.NOT_LAUNCHED) 34 | self.assertEqual(job.command_str, 35 | f'python -m domainbed.scripts.train --foo bar --output_dir {job.output_dir}') 36 | 37 | def test_job_launch(self): 38 | """Test that launching a job calls the launcher_fn with appropriate 39 | arguments, and sets the job to INCOMPLETE state.""" 40 | train_args = {'foo': 'bar'} 41 | sweep_output_dir = f'/tmp/{str(uuid.uuid4())}' 42 | job = sweep.Job(train_args, sweep_output_dir) 43 | 44 | launcher_fn_called = False 45 | def launcher_fn(commands): 46 | nonlocal launcher_fn_called 47 | launcher_fn_called = True 48 | self.assertEqual(len(commands), 1) 49 | self.assertEqual(commands[0], job.command_str) 50 | 51 | sweep.Job.launch([job], launcher_fn) 52 | self.assertTrue(launcher_fn_called) 53 | 54 | job = sweep.Job(train_args, sweep_output_dir) # re-create so the state is re-read from disk 55 | self.assertEqual(job.state, sweep.Job.INCOMPLETE) 56 | 57 | def test_job_delete(self): 58 | """Test that deleting a launched job returns it to the NOT_LAUNCHED 59 | state""" 60 | train_args = {'foo': 'bar'} 61 | sweep_output_dir = f'/tmp/{str(uuid.uuid4())}' 62
| job = sweep.Job(train_args, sweep_output_dir) 63 | sweep.Job.launch([job], (lambda commands: None)) 64 | sweep.Job.delete([job]) 65 | 66 | job = sweep.Job(train_args, sweep_output_dir) 67 | self.assertEqual(job.state, sweep.Job.NOT_LAUNCHED) 68 | 69 | 70 | def test_make_args_list(self): 71 | """Test that, for a typical input, make_args_list returns a list 72 | of the correct length""" 73 | args_list = sweep.make_args_list( 74 | n_trials=2, 75 | dataset_names=['Debug28'], 76 | algorithms=['ERM'], 77 | n_hparams_from=0, 78 | n_hparams=3, 79 | steps=123, 80 | data_dir='/tmp/data', 81 | task='domain_generalization', 82 | holdout_fraction=0.2, 83 | single_test_envs=False, 84 | hparams=None 85 | ) 86 | assert(len(args_list) == 2*3*(3+3)) # trials x hparams x test-env combos (3 singles + 3 pairs, implying Debug28 has 3 envs) 87 | 88 | @unittest.skipIf('DATA_DIR' not in os.environ, 'needs DATA_DIR environment ' 89 | 'variable') 90 | def test_end_to_end(self): 91 | output_dir = os.path.join('/tmp', str(uuid.uuid4())) 92 | result = subprocess.run(f'python -m domainbed.scripts.sweep launch ' 93 | f'--data_dir={os.environ["DATA_DIR"]} --output_dir={output_dir} ' 94 | f'--algorithms ERM --datasets Debug28 --n_hparams 1 --n_trials 1 ' 95 | f'--command_launcher dummy --skip_confirmation', 96 | shell=True, capture_output=True) 97 | stdout_lines = result.stdout.decode('utf8').split("\n") 98 | dummy_launcher_lines = [l for l in stdout_lines 99 | if l.startswith('Dummy launcher:')] 100 | self.assertEqual(len(dummy_launcher_lines), 6) # 1 trial x 1 hparam x 6 test-env combos 101 | 102 | # Now run it again and make sure it doesn't try to relaunch those jobs 103 | result = subprocess.run(f'python -m domainbed.scripts.sweep launch ' 104 | f'--data_dir={os.environ["DATA_DIR"]} --output_dir={output_dir} ' 105 | f'--algorithms ERM --datasets Debug28 --n_hparams 1 --n_trials 1 ' 106 | f'--command_launcher dummy --skip_confirmation', 107 | shell=True, capture_output=True) 108 | stdout_lines = result.stdout.decode('utf8').split("\n") 109 | dummy_launcher_lines = [l for l in stdout_lines 110 | if l.startswith('Dummy
launcher:')] 111 | self.assertEqual(len(dummy_launcher_lines), 0) 112 | 113 | # Delete the incomplete jobs, try launching again, and make sure they 114 | # get relaunched. 115 | subprocess.run(f'python -m domainbed.scripts.sweep delete_incomplete ' 116 | f'--data_dir={os.environ["DATA_DIR"]} --output_dir={output_dir} ' 117 | f'--algorithms ERM --datasets Debug28 --n_hparams 1 --n_trials 1 ' 118 | f'--command_launcher dummy --skip_confirmation', 119 | shell=True, capture_output=True) 120 | 121 | result = subprocess.run(f'python -m domainbed.scripts.sweep launch ' 122 | f'--data_dir={os.environ["DATA_DIR"]} --output_dir={output_dir} ' 123 | f'--algorithms ERM --datasets Debug28 --n_hparams 1 --n_trials 1 ' 124 | f'--command_launcher dummy --skip_confirmation', 125 | shell=True, capture_output=True) 126 | stdout_lines = result.stdout.decode('utf8').split("\n") 127 | dummy_launcher_lines = [l for l in stdout_lines 128 | if l.startswith('Dummy launcher:')] 129 | self.assertEqual(len(dummy_launcher_lines), 6) 130 | -------------------------------------------------------------------------------- /test/scripts/test_train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates.
# All Rights Reserved

import json
import os
import shutil
import subprocess
import sys
import unittest
import uuid

import torch


class TestTrain(unittest.TestCase):
    """End-to-end check of the training entry point."""

    @unittest.skipIf('DATA_DIR' not in os.environ, 'needs DATA_DIR environment '
                     'variable')
    def test_end_to_end(self):
        """Test that train.py runs 501 steps and reaches sane accuracies.

        Launches a RotatedMNIST training run in a fresh output directory, then
        checks the last record of ``results.jsonl`` and the ``out.txt`` log.
        The output directory is removed when the test finishes.
        """
        output_dir = os.path.join('/tmp', str(uuid.uuid4()))
        os.makedirs(output_dir, exist_ok=True)
        self.addCleanup(shutil.rmtree, output_dir, ignore_errors=True)

        # NOTE(review): this module path comes from upstream DomainBed; in this
        # repository train.py sits at the top level -- confirm the package
        # layout the tests are run from.
        # A list argv avoids shell quoting problems with DATA_DIR, and
        # sys.executable runs the same interpreter as the test. check=True
        # makes a crashed training run fail here with its exit status instead
        # of later with a confusing FileNotFoundError on results.jsonl.
        subprocess.run(
            [sys.executable, '-m', 'domainbed.scripts.train',
             '--dataset', 'RotatedMNIST',
             f'--data_dir={os.environ["DATA_DIR"]}',
             f'--output_dir={output_dir}',
             '--steps=501'],
            check=True)

        with open(os.path.join(output_dir, 'results.jsonl')) as f:
            # rstrip instead of l[:-1]: the final line may have no trailing
            # newline, and slicing would then drop its closing brace.
            lines = [line.rstrip('\n') for line in f if line.strip()]
        last_epoch = json.loads(lines[-1])
        self.assertEqual(last_epoch['step'], 500)
        # Conservative values; anything lower and something's likely wrong.
        self.assertGreater(last_epoch['env0_in_acc'], 0.80)
        for env in (1, 2, 3):
            self.assertGreater(last_epoch[f'env{env}_in_acc'], 0.95)

        with open(os.path.join(output_dir, 'out.txt')) as f:
            text = f.read()
        self.assertIn('500', text)

# --------------------------------------------------------------------------------
# /test/test_datasets.py:
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
# All Rights Reserved

"""Unit tests."""

import argparse
import itertools
import json
import os
import subprocess
import sys
import time
import unittest
import uuid

import torch

from domainbed import datasets
from domainbed import hparams_registry
from domainbed import algorithms
from domainbed import networks

from parameterized import parameterized

from domainbed.test import helpers


class TestDatasets(unittest.TestCase):
    """Smoke tests that run against every registered dataset."""

    @parameterized.expand(itertools.product(datasets.DATASETS))
    @unittest.skipIf('DATA_DIR' not in os.environ,
                     'needs DATA_DIR environment variable')
    def test_dataset_erm(self, dataset_name):
        """Load ``dataset_name`` from $DATA_DIR and run one ERM update on it.

        Also checks that ``datasets.num_environments`` agrees with the number
        of environments the loaded dataset actually exposes.
        """
        data_root = os.environ['DATA_DIR']
        hparams = hparams_registry.default_hparams('ERM', dataset_name)
        dataset_cls = datasets.get_dataset_class(dataset_name)
        dataset = dataset_cls(data_root, [], hparams)
        n_envs = len(dataset)

        self.assertEqual(datasets.num_environments(dataset_name), n_envs)

        erm_cls = algorithms.get_algorithm_class('ERM')
        model = erm_cls(dataset.input_shape, dataset.num_classes, n_envs,
                        hparams).cuda()
        # Minibatches are built by the shared test helper (batch size 8).
        model.update(helpers.make_minibatches(dataset, 8))

# --------------------------------------------------------------------------------
# /test/test_hparams_registry.py:
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
# All Rights Reserved

import unittest
import itertools

from domainbed import hparams_registry
from domainbed import datasets
from domainbed import algorithms

from parameterized import parameterized


class TestHparamsRegistry(unittest.TestCase):
    """Properties of hparams_registry that must hold for every algorithm/dataset."""

    @parameterized.expand(itertools.product(algorithms.ALGORITHMS, datasets.DATASETS))
    def test_random_hparams_deterministic(self, algorithm_name, dataset_name):
        """Two draws of random_hparams with the same seed (0) must be identical."""
        first, second = (
            hparams_registry.random_hparams(algorithm_name, dataset_name, 0)
            for _ in range(2)
        )
        self.assertEqual(first.keys(), second.keys())
        for name, value in first.items():
            # The key is passed as the failure message to pinpoint the mismatch.
            self.assertEqual(value, second[name], name)

# --------------------------------------------------------------------------------
# /test/test_model_selection.py:
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
# All Rights Reserved

"""Unit tests."""

import argparse
import itertools
import json
import os
import subprocess
import sys
import time
import unittest
import uuid

import torch

from domainbed import model_selection
from domainbed.lib.query import Q

from parameterized import parameterized


def make_record(step, hparams_seed, envs):
    """Build a fake results record for feeding to the selection methods.

    Args:
        step: value stored under ``'step'``.
        hparams_seed: value stored under ``record['args']['hparams_seed']``.
        envs: list of ``(in_acc, out_acc, is_test_env)`` tuples, one per
            environment. Environment ``i`` gets ``env{i}_in_acc`` and
            ``env{i}_out_acc`` keys and is appended to
            ``record['args']['test_envs']`` when ``is_test_env`` is true.

    Returns:
        dict with ``'args'``, ``'step'`` and the per-environment accuracy keys.
    """
    result = {
        'args': {'test_envs': [], 'hparams_seed': hparams_seed},
        'step': step,
    }
    for i, (in_acc, out_acc, is_test_env) in enumerate(envs):
        if is_test_env:
            result['args']['test_envs'].append(i)
        result[f'env{i}_in_acc'] = in_acc
        result[f'env{i}_out_acc'] = out_acc
    return result


class TestSelectionMethod(unittest.TestCase):
    """Tests for the base ``SelectionMethod.sweep_acc`` via a trivial subclass."""

    class MySelectionMethod(model_selection.SelectionMethod):
        @classmethod
        def run_acc(cls, run_records):
            """Report env0's out/in accuracy of the first record as val/test."""
            # First parameter is the class (classmethod), hence ``cls``.
            return {
                'val_acc': run_records[0]['env0_out_acc'],
                'test_acc': run_records[0]['env0_in_acc'],
            }

    def test_sweep_acc(self):
        """sweep_acc reports the test acc of the run with the highest val acc."""
        sweep_records = Q([
            make_record(0, 0, [(0.7, 0.8, True)]),
            make_record(0, 1, [(0.9, 0.5, True)]),
        ])
        self.assertEqual(self.MySelectionMethod.sweep_acc(sweep_records), 0.7)

    def test_sweep_acc_empty(self):
        """sweep_acc of an empty sweep is None."""
        self.assertIsNone(self.MySelectionMethod.sweep_acc(Q([])))


class TestOracleSelectionMethod(unittest.TestCase):

    def test_run_acc_best_first(self):
        """Test run_acc() when the run has two records and the best one comes
        first; the expected result is still the later (step 1) record."""
        run_records = Q([
            make_record(0, 0, [(0.75, 0.70, True)]),
            make_record(1, 0, [(0.65, 0.60, True)]),
        ])
        self.assertEqual(
            model_selection.OracleSelectionMethod.run_acc(run_records),
            {'val_acc': 0.60, 'test_acc': 0.65}
        )

    def test_run_acc_best_last(self):
        """Test run_acc() when the run has two records and the best one comes
        last."""
        run_records = Q([
            make_record(0, 0, [(0.75, 0.70, True)]),
            make_record(1, 0, [(0.85, 0.80, True)]),
        ])
        self.assertEqual(
            model_selection.OracleSelectionMethod.run_acc(run_records),
            {'val_acc': 0.80, 'test_acc': 0.85}
        )

    def test_run_acc_empty(self):
        """Test run_acc() when there are no valid records to choose from."""
        self.assertIsNone(
            model_selection.OracleSelectionMethod.run_acc(Q([])))


class TestIIDAccuracySelectionMethod(unittest.TestCase):

    def test_run_acc(self):
        """val_acc averages the non-test envs' out acc of the chosen record."""
        run_records = Q([
            make_record(0, 0,
                        [(0.1, 0.2, True), (0.5, 0.6, False), (0.6, 0.7, False)]),
            make_record(1, 0,
                        [(0.3, 0.4, True), (0.6, 0.7, False), (0.7, 0.8, False)]),
        ])
        self.assertEqual(
            model_selection.IIDAccuracySelectionMethod.run_acc(run_records),
            {'val_acc': 0.75, 'test_acc': 0.3}
        )

    def test_run_acc_empty(self):
        self.assertIsNone(
            model_selection.IIDAccuracySelectionMethod.run_acc(Q([])))


class TestLeaveOneOutSelectionMethod(unittest.TestCase):

    def test_run_acc(self):
        run_records = Q([
            make_record(0, 0,
                        [(0.1, 0., True), (0.0, 0., False), (0.0, 0., False)]),
            make_record(0, 0,
                        [(0.0, 0., True), (0.5, 0., True), (0., 0., False)]),
            make_record(0, 0,
                        [(0.0, 0., True), (0.0, 0., False), (0.6, 0., True)]),
        ])
        self.assertEqual(
            model_selection.LeaveOneOutSelectionMethod.run_acc(run_records),
            {'val_acc': 0.55, 'test_acc': 0.1}
        )

    def test_run_acc_empty(self):
        """With one leave-one-out record missing, run_acc returns None."""
        run_records = Q([
            make_record(0, 0,
                        [(0.1, 0., True), (0.0, 0., False), (0.0, 0., False)]),
            make_record(0, 0,
                        [(0.0, 0., True), (0.5, 0., True), (0., 0., False)]),
        ])
        self.assertIsNone(
            model_selection.LeaveOneOutSelectionMethod.run_acc(run_records))

# --------------------------------------------------------------------------------
# /test/test_models.py:
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

"""Unit tests."""

import argparse
import itertools
import json
import os
import subprocess
import sys
import time
import unittest
import uuid

import torch

from domainbed import datasets
from domainbed import hparams_registry
from domainbed import algorithms
from domainbed import networks
from domainbed.test import helpers

from parameterized import parameterized


class TestAlgorithms(unittest.TestCase):

    @parameterized.expand(itertools.product(helpers.DEBUG_DATASETS, algorithms.ALGORITHMS))
    def test_init_update_predict(self, dataset_name, algorithm_name):
        """Test that a given algorithm inits, updates and predicts without raising
        errors, and that predict() returns (batch_size, num_classes) logits."""
        batch_size = 8
        hparams = hparams_registry.default_hparams(algorithm_name, dataset_name)
        dataset = datasets.get_dataset_class(dataset_name)('', [], hparams)
        minibatches = helpers.make_minibatches(dataset, batch_size)
        algorithm_class = algorithms.get_algorithm_class(algorithm_name)
        algorithm = algorithm_class(dataset.input_shape, dataset.num_classes,
                                    len(dataset), hparams).cuda()
        # A few update steps; each must return something (not None).
        for _ in range(3):
            self.assertIsNotNone(algorithm.update(minibatches))
        algorithm.eval()
        self.assertEqual(list(algorithm.predict(minibatches[0][0]).shape),
                         [batch_size, dataset.num_classes])

# --------------------------------------------------------------------------------
# /test/test_networks.py:
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc.
# and its affiliates. All Rights Reserved

import argparse
import itertools
import json
import os
import subprocess
import sys
import time
import unittest
import uuid

import torch

from domainbed import datasets
from domainbed import hparams_registry
from domainbed import algorithms
from domainbed import networks
from domainbed.test import helpers

from parameterized import parameterized


class TestNetworks(unittest.TestCase):
    """Shape checks for modules built by ``networks``."""

    @parameterized.expand(itertools.product(helpers.DEBUG_DATASETS))
    def test_featurizer(self, dataset_name):
        """Featurizer turns one minibatch of inputs into (batch, n_outputs) features."""
        batch = 8
        hparams = hparams_registry.default_hparams('ERM', dataset_name)
        dataset = datasets.get_dataset_class(dataset_name)('', [], hparams)
        # Inputs (x) of the first environment's minibatch.
        first_inputs = helpers.make_minibatches(dataset, batch)[0][0]
        featurizer = networks.Featurizer(dataset.input_shape, hparams).cuda()
        features = featurizer(first_inputs)
        self.assertEqual(list(features.shape), [batch, featurizer.n_outputs])

# --------------------------------------------------------------------------------
# /test/visual/mix.py:
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
# All Rights Reserved

import argparse
import collections
import json
import os
import random
import sys
import time
import copy
import uuid
import numpy as np
import PIL
import torch
import torchvision
import torch.utils.data
import yaml
from azureml.core import Run
run = Run.get_context()
import datasets
import hparams_registry
import algorithms_gen as algorithms
# NOTE: `import numpy.random as random` used to follow here; it shadowed the
# stdlib `random` above, so Python's RNG was never seeded. NumPy's global RNG
# is seeded explicitly with np.random.seed below.
from lib import misc
from scripts.save_images import write_2images
from lib.fast_data_loader import InfiniteDataLoader

# Directory handed to write_2images for the mixed-image grids.
MIX_OUT_DIR = "test/visual/results/mix"


def get_config(config):
    """Parse the YAML file at path ``config`` and return its contents.

    Uses yaml.safe_load: yaml.load without an explicit Loader can construct
    arbitrary Python objects and raises TypeError on PyYAML >= 6.
    """
    with open(config, 'r') as stream:
        return yaml.safe_load(stream)


parser = argparse.ArgumentParser(description='Domain generalization')
parser.add_argument('--data_dir', type=str, default='/home/v-yifanzhang/datasets')
parser.add_argument('--dataset', type=str, default="PACS")
parser.add_argument('--gen_dir', type=str, default="outputs/model.pkl", help="if not empty, the generator of DEDF will be loaded")
parser.add_argument('--algorithm', type=str, default="DDG")
parser.add_argument('--hparams', type=str,
                    help='JSON-serialized hparams dict')
parser.add_argument('--hparams_seed', type=int, default=0,
                    help='Seed for random hparams (0 means "default hparams")')
parser.add_argument('--trial_seed', type=int, default=0,
                    help='Trial number (used for seeding split_dataset and '
                    'random_hparams).')
parser.add_argument('--seed', type=int, default=0,
                    help='Seed for everything else')
parser.add_argument('--holdout_fraction', type=float, default=0.2)
parser.add_argument('--test_envs', type=int, nargs='+', default=[0])
args = parser.parse_args()

# If we ever want to implement checkpointing, just persist these values
# every once in a while, and then load them from disk here.
start_step = 0
algorithm_dict = None

if args.hparams_seed == 0:
    hparams = hparams_registry.default_hparams(args.algorithm, args.dataset)
else:
    hparams = hparams_registry.random_hparams(
        args.algorithm, args.dataset,
        misc.seed_hash(args.hparams_seed, args.trial_seed))
if args.hparams:
    hparams.update(json.loads(args.hparams))
# Fixed visualization batch size; overrides any value from --hparams.
hparams['batch_size'] = 4
print('HParams:')
for k, v in sorted(hparams.items()):
    print('\t{}: {}'.format(k, v))

# Seed Python's, NumPy's and torch's global RNGs.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = "cuda" if torch.cuda.is_available() else "cpu"

if args.dataset in vars(datasets):
    dataset = vars(datasets)[args.dataset](args.data_dir,
                                           args.test_envs, hparams)
else:
    raise NotImplementedError

in_splits = []
out_splits = []
for env_i, env in enumerate(dataset):
    out, in_ = misc.split_dataset(env,
                                  int(len(env) * args.holdout_fraction),
                                  misc.seed_hash(args.trial_seed, env_i))
    in_splits.append((in_, None))
    out_splits.append((out, None))
# One infinite loader per environment not listed in --test_envs.
train_loaders = [InfiniteDataLoader(
    dataset=env,
    weights=env_weights,
    batch_size=hparams['batch_size'],
    num_workers=dataset.N_WORKERS)
    for i, (env, env_weights) in enumerate(in_splits)
    if i not in args.test_envs]

algorithm_class = algorithms.get_algorithm_class(args.algorithm)
algorithm = algorithm_class(dataset.input_shape, dataset.num_classes,
                            len(dataset) - len(args.test_envs), hparams)

if algorithm_dict is not None:
    algorithm.load_state_dict(algorithm_dict)

algorithm.to(device)
# Copy checkpoint entries whose key contains 'id_featurizer' or 'gen' into the
# algorithm; map_location lets a GPU-saved checkpoint load on a CPU-only host.
pretext_model = torch.load(args.gen_dir, map_location=device)['model_dict']
alg_dict = algorithm.state_dict()
state_dict = {k: v for k, v in pretext_model.items()
              if k in alg_dict and ('id_featurizer' in k or 'gen' in k)}
alg_dict.update(state_dict)
algorithm.load_state_dict(alg_dict)

train_minibatches_iterator = zip(*train_loaders)


def mix_v(x_a, x_b, x_c, pretrain_model, alpha=0.5):
    """Decode each x_a[i] with a blend of x_b[i]'s and x_c[i]'s id features.

    For every index i, ``gen.encode(single(x_a[i]))`` is decoded together with
    ``alpha * f_b + (1 - alpha) * f_c``, where f_b / f_c come from
    ``pretrain_model.id_featurizer``.

    Returns:
        ``(x_a, x_mix, x_b, x_c)`` with the decoded images concatenated in x_mix.
    """
    model = pretrain_model
    x_mix = []
    for i in range(x_a.size(0)):
        s_a = model.gen.encode(model.single(x_a[i].unsqueeze(0)))
        f_b, _ = model.id_featurizer(x_b[i].unsqueeze(0))
        f_c, _ = model.id_featurizer(x_c[i].unsqueeze(0))
        x_mix.append(model.gen.decode(s_a, (alpha * f_b + (1 - alpha) * f_c)))
    return x_a, torch.cat(x_mix), x_b, x_c


def mix_s(x_a, x_b, x_c, pretrain_model, alpha=0.75):
    """Decode x_a[i]'s id features with a blend of x_a[i]'s and x_b[i]'s encodings.

    For every index i, ``alpha * s_a + (1 - alpha) * s_b`` (generator
    encodings) is decoded with ``f_a`` from ``pretrain_model.id_featurizer``.
    ``x_c`` is accepted for signature parity with mix_v and is unused.

    Returns:
        ``(x_a, x_mix, x_b)`` with the decoded images concatenated in x_mix.
    """
    model = pretrain_model
    x_mix = []
    for i in range(x_a.size(0)):
        s_a = model.gen.encode(model.single(x_a[i].unsqueeze(0)))
        s_b = model.gen.encode(model.single(x_b[i].unsqueeze(0)))
        f_a, _ = model.id_featurizer(x_a[i].unsqueeze(0))
        model.id_featurizer(x_b[i].unsqueeze(0))  # kept: original also ran it
        x_mix.append(model.gen.decode((alpha * s_a + (1 - alpha) * s_b), f_a))
    return x_a, torch.cat(x_mix), x_b


os.makedirs(MIX_OUT_DIR, exist_ok=True)
# Pure inference: no_grad avoids building autograd graphs for every batch.
with torch.no_grad():
    for step in range(start_step, 20):
        minibatches_device = [(x.to(device), y, pos) for x, y, pos in next(train_minibatches_iterator)]
        minibatches_device_neg = [(x.to(device), y, pos) for x, y, pos in next(train_minibatches_iterator)]
        minibatches_device_neg_2 = [(x.to(device), y, pos) for x, y, pos in next(train_minibatches_iterator)]
        images_a = torch.cat([x for x, y, pos in minibatches_device])
        images_b = torch.cat([x for x, y, pos in minibatches_device_neg])
        images_c = torch.cat([x for x, y, pos in minibatches_device_neg_2])
        image_outputs = mix_v(images_a, images_b, images_c, algorithm)
        write_2images(image_outputs, hparams['batch_size'], MIX_OUT_DIR, 'mix_v_%08d' % (step + 1), run)
        image_outputs = mix_s(images_a, images_b, images_c, algorithm)
        write_2images(image_outputs, hparams['batch_size'], MIX_OUT_DIR, 'mix_s_%08d' % (step + 1), run)

# --------------------------------------------------------------------------------
# /test/visual/show1by1.py:
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import argparse
import json
import os
import random
import numpy as np
from torch.autograd import Variable
import torch
import torch.utils.data
import yaml
from azureml.core import Run
run = Run.get_context()
import datasets
import hparams_registry
import algorithms_gen as algorithms
# NOTE: `import numpy.random as random` removed -- it shadowed stdlib random.
from lib import misc
import imageio
from PIL import Image
from scripts.save_images import write_2images
from lib.fast_data_loader import InfiniteDataLoader

# Output directory for the per-pair JPEGs.
OUT_DIR = 'test/visual/results/1by1'

parser = argparse.ArgumentParser(description='Domain generalization')
parser.add_argument('--data_dir', type=str, default='/home/v-yifanzhang/datasets')
parser.add_argument('--dataset', type=str, default="PACS")
parser.add_argument('--gen_dir', type=str, default="outputs/model.pkl", help="if not empty, the generator of DEDF will be loaded")
parser.add_argument('--algorithm', type=str, default="DDG")
parser.add_argument('--hparams', type=str,
                    help='JSON-serialized hparams dict')
parser.add_argument('--hparams_seed', type=int, default=0,
                    help='Seed for random hparams (0 means "default hparams")')
parser.add_argument('--trial_seed', type=int, default=0,
                    help='Trial number (used for seeding split_dataset and '
                    'random_hparams).')
parser.add_argument('--seed', type=int, default=0,
                    help='Seed for everything else')
parser.add_argument('--holdout_fraction', type=float, default=0.2)
parser.add_argument('--test_envs', type=int, nargs='+', default=[0])
args = parser.parse_args()

# If we ever want to implement checkpointing, just persist these values
# every once in a while, and then load them from disk here.
start_step = 0
algorithm_dict = None

if args.hparams_seed == 0:
    hparams = hparams_registry.default_hparams(args.algorithm, args.dataset)
else:
    hparams = hparams_registry.random_hparams(
        args.algorithm, args.dataset,
        misc.seed_hash(args.hparams_seed, args.trial_seed))
if args.hparams:
    hparams.update(json.loads(args.hparams))
# One image per environment per batch; overrides any value from --hparams.
hparams['batch_size'] = 1
print('HParams:')
for k, v in sorted(hparams.items()):
    print('\t{}: {}'.format(k, v))

# Seed Python's, NumPy's and torch's global RNGs.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = "cuda" if torch.cuda.is_available() else "cpu"

if args.dataset in vars(datasets):
    dataset = vars(datasets)[args.dataset](args.data_dir,
                                           args.test_envs, hparams)
else:
    raise NotImplementedError

in_splits = []
out_splits = []
for env_i, env in enumerate(dataset):
    out, in_ = misc.split_dataset(env,
                                  int(len(env) * args.holdout_fraction),
                                  misc.seed_hash(args.trial_seed, env_i))
    in_splits.append((in_, None))
    out_splits.append((out, None))
# One infinite loader per environment not listed in --test_envs.
train_loaders = [InfiniteDataLoader(
    dataset=env,
    weights=env_weights,
    batch_size=hparams['batch_size'],
    num_workers=dataset.N_WORKERS)
    for i, (env, env_weights) in enumerate(in_splits)
    if i not in args.test_envs]

algorithm_class = algorithms.get_algorithm_class(args.algorithm)
algorithm = algorithm_class(dataset.input_shape, dataset.num_classes,
                            len(dataset) - len(args.test_envs), hparams)

if algorithm_dict is not None:
    algorithm.load_state_dict(algorithm_dict)

algorithm.to(device)
# Copy checkpoint entries whose key contains 'id_featurizer' or 'gen' into the
# algorithm; map_location lets a GPU-saved checkpoint load on a CPU-only host.
pretext_model = torch.load(args.gen_dir, map_location=device)['model_dict']
alg_dict = algorithm.state_dict()
state_dict = {k: v for k, v in pretext_model.items()
              if k in alg_dict and ('id_featurizer' in k or 'gen' in k)}
alg_dict.update(state_dict)
algorithm.load_state_dict(alg_dict)

train_minibatches_iterator = zip(*train_loaders)

minibatches_device = [(x.to(device), y, pos) for x, y, pos in next(train_minibatches_iterator)]
minibatches_device_neg = [(x.to(device), y, pos) for x, y, pos in next(train_minibatches_iterator)]
images_b = torch.cat([x for x, y, pos in minibatches_device_neg])


def recover(inp):
    """Turn a normalized (C, H, W) image tensor into an (H, W, C) array in [0, 255].

    Undoes the ImageNet mean/std normalization, scales by 255 and clips.
    """
    inp = inp.cpu().numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = inp * 255.0
    inp = np.clip(inp, 0, 255)
    return inp


def to_gray(half=False):
    """Return a function averaging dim 1 (channels) into one channel, optionally as fp16."""
    def forward(x):
        x = torch.mean(x, dim=1, keepdim=True)
        if half:
            x = x.half()
        return x
    return forward


gray = to_gray(False)
bg_img = gray(images_b)
encode = algorithm.gen.encode  # encode function
id_encode = algorithm.id_featurizer  # encode function
decode = algorithm.gen.decode  # decode function
os.makedirs(OUT_DIR, exist_ok=True)
with torch.no_grad():
    # For each id image, decode it against every background encoding and save
    # one JPEG per (background i, id image count) pair.
    for count, (id_img, _, _) in enumerate(minibatches_device):
        id_img = id_img.to(device)  # was .cuda(): broke the CPU fallback
        s = encode(bg_img)
        f, _ = id_encode(id_img)
        for i in range(s.size(0)):
            outputs = decode(s[i].unsqueeze(0), f)
            tmp = recover(outputs[0].data.cpu())
            pic = Image.fromarray(tmp.astype('uint8'))
            pic.save('%s/rainbow_%d_%d.jpg' % (OUT_DIR, i, count))

# --------------------------------------------------------------------------------
# /test/visual/show_smooth.py:
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import argparse
import json
import random
import numpy as np
from torch.autograd import Variable
import torch
import torch.utils.data
import yaml
from azureml.core import Run
run = Run.get_context()
import datasets
import hparams_registry
import algorithms_gen as algorithms
# NOTE: `import numpy.random as random` removed -- it shadowed stdlib random.
from lib import misc
import imageio
from PIL import Image
from scripts.save_images import write_2images
from lib.fast_data_loader import InfiniteDataLoader

parser = argparse.ArgumentParser(description='Domain generalization')
parser.add_argument('--data_dir', type=str, default='/home/v-yifanzhang/datasets')
parser.add_argument('--dataset', type=str, default="PACS")
parser.add_argument('--gen_dir', type=str, default="outputs/model.pkl", help="if not empty, the generator of DEDF will be loaded")
parser.add_argument('--algorithm', type=str, default="DDG")
parser.add_argument('--hparams', type=str,
                    help='JSON-serialized hparams dict')
parser.add_argument('--hparams_seed', type=int, default=0,
                    help='Seed for random hparams (0 means "default hparams")')
parser.add_argument('--trial_seed', type=int, default=0,
                    help='Trial number (used for seeding split_dataset and '
                    'random_hparams).')
parser.add_argument('--seed', type=int, default=0,
                    help='Seed for everything else')
parser.add_argument('--holdout_fraction', type=float, default=0.2)
parser.add_argument('--test_envs', type=int, nargs='+', default=[0])
args = parser.parse_args()

# If we ever want to implement checkpointing, just persist these values
# every once in a while, and then load them from disk here.
start_step = 0
algorithm_dict = None

if args.hparams_seed == 0:
    hparams = hparams_registry.default_hparams(args.algorithm, args.dataset)
else:
    hparams = hparams_registry.random_hparams(
        args.algorithm, args.dataset,
        misc.seed_hash(args.hparams_seed, args.trial_seed))
if args.hparams:
    hparams.update(json.loads(args.hparams))
# Fixed visualization batch size; overrides any value from --hparams.
hparams['batch_size'] = 4
print('HParams:')
for k, v in sorted(hparams.items()):
    print('\t{}: {}'.format(k, v))

# Seed Python's, NumPy's and torch's global RNGs.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = "cuda" if torch.cuda.is_available() else "cpu"

if args.dataset in vars(datasets):
    dataset = vars(datasets)[args.dataset](args.data_dir,
                                           args.test_envs, hparams)
else:
    raise NotImplementedError

in_splits = []
out_splits = []
for env_i, env in enumerate(dataset):
    out, in_ = misc.split_dataset(env,
                                  int(len(env) * args.holdout_fraction),
                                  misc.seed_hash(args.trial_seed, env_i))
    in_splits.append((in_, None))
    out_splits.append((out, None))
# One infinite loader per environment not listed in --test_envs.
train_loaders = [InfiniteDataLoader(
    dataset=env,
    weights=env_weights,
    batch_size=hparams['batch_size'],
    num_workers=dataset.N_WORKERS)
    for i, (env, env_weights) in enumerate(in_splits)
    if i not in args.test_envs]

algorithm_class = algorithms.get_algorithm_class(args.algorithm)
algorithm = algorithm_class(dataset.input_shape, dataset.num_classes,
                            len(dataset) - len(args.test_envs), hparams)

if algorithm_dict is not None:
    algorithm.load_state_dict(algorithm_dict)

algorithm.to(device)
# Copy checkpoint entries whose key contains 'id_featurizer' or 'gen' into the
# algorithm; map_location lets a GPU-saved checkpoint load on a CPU-only host.
pretext_model = torch.load(args.gen_dir, map_location=device)['model_dict']
alg_dict = algorithm.state_dict()
state_dict = {k: v for k, v in pretext_model.items()
              if k in alg_dict and ('id_featurizer' in k or 'gen' in k)}
alg_dict.update(state_dict)
algorithm.load_state_dict(alg_dict)

train_minibatches_iterator = zip(*train_loaders)

minibatches_device = [(x.to(device), y, pos) for x, y, pos in next(train_minibatches_iterator)]
minibatches_device_neg = [(x.to(device), y, pos) for x, y, pos in next(train_minibatches_iterator)]


def recover(inp):
    """Turn a normalized (C, H, W) CPU tensor into an (H, W, C) array in [0, 255].

    Undoes the ImageNet mean/std normalization, scales by 255 and clips.
    """
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = inp * 255.0
    inp = np.clip(inp, 0, 255)
    return inp


def to_gray(half=False):
    """Return a function averaging dim 1 (channels) into one channel, optionally as fp16."""
    def forward(x):
        x = torch.mean(x, dim=1, keepdim=True)
        if half:
            x = x.half()
        return x
    return forward


# Rows of the output strip: one per image of the batch of 4.
N_ROWS = 4
# Interpolation weights 0.0, 0.1, ..., 1.0.
N_STEPS = 11

im = {}
bg_img, _, _ = minibatches_device_neg[0]
gray = to_gray(False)
ori_img = bg_img
bg_img = gray(bg_img).to(device)  # was .cuda(): broke the CPU fallback
gif = []
encode = algorithm.gen.encode  # encode function
id_encode = algorithm.id_featurizer  # encode function
decode = algorithm.gen.decode  # decode function
with torch.no_grad():
    for id_img, _, _ in minibatches_device:
        id_img = id_img.to(device)
        s = encode(bg_img)
        f, _ = id_encode(id_img)
        for count in range(N_ROWS):
            input1 = recover(ori_img[count].squeeze().data.cpu())
            im[count] = input1
            gif.append(input1)
            # Interpolation partner: 0<->1, and via negative indexing
            # 2 -> s[-1] (row 3) and 3 -> s[-2] (row 2).
            partner = 1 - count
            f_tmp = f[count, :].view(1, -1)
            for i in range(N_STEPS):
                tmp_s = 0.1*i*s[count, :, :, :] + (1-0.1*i)*s[partner, :, :, :]
                outputs = decode(tmp_s.unsqueeze(0), f_tmp)
                tmp = recover(outputs[0].data.cpu())
                im[count] = np.concatenate((im[count], tmp), axis=1)
                gif.append(tmp)
        # Only the first environment's batch is visualized.
        break

# save long image
pic = np.concatenate([im[k] for k in range(N_ROWS)], axis=0)
pic = Image.fromarray(pic.astype('uint8'))
pic.save('smooth-v.jpg')

# save gif
imageio.mimsave('./smooth-v.gif', gif)

# --------------------------------------------------------------------------------
# /test/visual/show_smooth_v.py:
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import argparse
import json
import random
import numpy as np
from torch.autograd import Variable
import torch
import torch.utils.data
import yaml
from azureml.core import Run
run = Run.get_context()
import datasets
import hparams_registry
import algorithms_gen as algorithms
# NOTE: `import numpy.random as random` removed -- it shadowed stdlib random.
from lib import misc
import imageio
from PIL import Image
from scripts.save_images import write_2images
from lib.fast_data_loader import InfiniteDataLoader

parser = argparse.ArgumentParser(description='Domain generalization')
parser.add_argument('--data_dir', type=str, default='/home/v-yifanzhang/datasets')
parser.add_argument('--dataset', type=str, default="RotatedMNIST")
parser.add_argument('--gen_dir', type=str, default="models/mnist_gen.pkl", help="if not empty, the generator of DEDF will be loaded")
parser.add_argument('--algorithm', type=str, default="DDG")
parser.add_argument('--hparams', type=str,
                    help='JSON-serialized hparams dict')
parser.add_argument('--hparams_seed', type=int, default=0,
                    help='Seed for random hparams (0 means "default hparams")')
parser.add_argument('--trial_seed', type=int, default=0,
                    help='Trial number (used for seeding split_dataset and '
                    'random_hparams).')
parser.add_argument('--seed', type=int, default=15,
                    help='Seed for everything else')
parser.add_argument('--holdout_fraction', type=float, default=0.2)
parser.add_argument('--test_envs', type=int, nargs='+', default=[0])
args = parser.parse_args()

# If we ever want to implement checkpointing, just persist these values
# every once in a while, and then load them from disk here.
start_step = 0
algorithm_dict = None

if args.hparams_seed == 0:
    hparams = hparams_registry.default_hparams(args.algorithm, args.dataset)
else:
    hparams = hparams_registry.random_hparams(
        args.algorithm, args.dataset,
        misc.seed_hash(args.hparams_seed, args.trial_seed))
if args.hparams:
    hparams.update(json.loads(args.hparams))
# Fixed visualization batch size; overrides any value from --hparams.
hparams['batch_size'] = 4
print('HParams:')
for k, v in sorted(hparams.items()):
    print('\t{}: {}'.format(k, v))

# Seed Python's, NumPy's and torch's global RNGs.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = "cuda" if torch.cuda.is_available() else "cpu"

if args.dataset in vars(datasets):
    dataset = vars(datasets)[args.dataset](args.data_dir,
                                           args.test_envs, hparams)
else:
    raise NotImplementedError

in_splits = []
out_splits = []
for env_i, env in enumerate(dataset):
    out, in_ = misc.split_dataset(env,
                                  int(len(env) * args.holdout_fraction),
                                  misc.seed_hash(args.trial_seed, env_i))
    in_splits.append((in_, None))
    out_splits.append((out, None))
# One infinite loader per environment not listed in --test_envs.
train_loaders = [InfiniteDataLoader(
    dataset=env,
    weights=env_weights,
    batch_size=hparams['batch_size'],
    num_workers=dataset.N_WORKERS)
    for i, (env, env_weights) in enumerate(in_splits)
    if i not in args.test_envs]

algorithm_class = algorithms.get_algorithm_class(args.algorithm)
algorithm = algorithm_class(dataset.input_shape, dataset.num_classes,
                            len(dataset) - len(args.test_envs), hparams)

if algorithm_dict is not None:
    algorithm.load_state_dict(algorithm_dict)

algorithm.to(device)
# Copy checkpoint entries whose key contains 'id_featurizer' or 'gen' into the
# algorithm; map_location lets a GPU-saved checkpoint load on a CPU-only host.
pretext_model = torch.load(args.gen_dir, map_location=device)['model_dict']
alg_dict = algorithm.state_dict()
state_dict = {k: v for k, v in pretext_model.items()
              if k in alg_dict and ('id_featurizer' in k or 'gen' in k)}
alg_dict.update(state_dict)
algorithm.load_state_dict(alg_dict)

train_minibatches_iterator = zip(*train_loaders)

minibatches_device = [(x.to(device), y, pos) for x, y, pos in next(train_minibatches_iterator)]
minibatches_device_neg = [(x.to(device), y, pos) for x, y, pos in next(train_minibatches_iterator)]


def recover(inp):
    """Turn a CPU image tensor into a float array in [0, 255].

    Inputs with more than 2 dims are treated as (C, H, W): transposed to
    (H, W, C) and un-normalized with the ImageNet mean/std. 2-D inputs are
    only scaled. The result is multiplied by 255 and clipped.
    """
    if len(inp.shape) > 2:
        inp = inp.numpy().transpose((1, 2, 0))
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        inp = std * inp + mean
    else:
        inp = inp.numpy()
    inp = inp * 255.0
    inp = np.clip(inp, 0, 255)
    return inp


def to_gray(half=False):
    """Return a function averaging dim 1 (channels) into one channel, optionally as fp16."""
    def forward(x):
        x = torch.mean(x, dim=1, keepdim=True)
        if half:
            x = x.half()
        return x
    return forward


# Interpolation weights 0.0, 0.1, ..., 1.0.
N_STEPS = 11

im = {}
bg_img, _, _ = minibatches_device_neg[0]
gray = to_gray(False)
bg_img = gray(bg_img).to(device)  # was .cuda(): broke the CPU fallback
gif = []
encode = algorithm.gen.encode  # encode function
id_encode = algorithm.id_featurizer  # encode function
decode = algorithm.gen.decode  # decode function
with torch.no_grad():
    for id_img, _, _ in minibatches_device:
        # Overwrite the first two id images with ones from a fresh batch. A
        # separate name avoids rebinding the list being iterated.
        fresh_batch = [(x.to(device), y, pos) for x, y, pos in next(train_minibatches_iterator)]
        id_img[:2] = fresh_batch[0][0][:2]
        id_img = id_img.to(device)
        s = encode(bg_img)
        f, _ = id_encode(id_img)
        for count in range(hparams['batch_size']):
            input1 = recover(id_img[count].squeeze().data.cpu())
            im[count] = input1
            gif.append(input1)
            s_tmp = s[count, :, :, :] if len(s.shape) == 4 else s[count, :]
            # decode receives a duplicated pair (batch of 2); only the first
            # decoded image is kept below.
            s_pair = torch.cat([s_tmp.unsqueeze(0), s_tmp.unsqueeze(0)])
            for i in range(N_STEPS):
                # Partner index 1 - count wraps to negative indices for
                # count >= 2 (2 -> row 3, 3 -> row 2 with a batch of 4).
                tmp_f = (0.1*i*f[count] + (1-0.1*i)*f[1-count]).view(1, -1)
                outputs = decode(s_pair, torch.cat([tmp_f, tmp_f]))[0]
                tmp = recover(outputs[0].data.cpu())
                im[count] = np.concatenate((im[count], tmp), axis=1)
                gif.append(tmp)
        # Only the first environment's batch is visualized.
        break

# save long image (one row per image of the batch)
pic = np.concatenate([im[k] for k in range(hparams['batch_size'])], axis=0)
pic = Image.fromarray(pic.astype('uint8'))
pic.save('smooth_.jpg')

# save gif
imageio.mimsave('./smooth.gif', gif)

# --------------------------------------------------------------------------------
# /test/visual/swap.py:
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
All Rights Reserved 2 | 3 | import argparse 4 | import collections 5 | import json 6 | import os 7 | import random 8 | import sys 9 | import time 10 | import copy 11 | import uuid 12 | import numpy as np 13 | import PIL 14 | import torch 15 | import torchvision 16 | import torch.utils.data 17 | import yaml 18 | from azureml.core import Run 19 | run = Run.get_context() 20 | import datasets 21 | import hparams_registry 22 | import algorithms_gen as algorithms 23 | import numpy.random as random 24 | from lib import misc 25 | from scripts.save_images import write_2images 26 | from lib.fast_data_loader import InfiniteDataLoader 27 | 28 | def get_config(config): 29 | with open(config, 'r') as stream: 30 | return yaml.load(stream) 31 | 32 | 33 | parser = argparse.ArgumentParser(description='Domain generalization') 34 | parser.add_argument('--data_dir', type=str, default='/home/v-yifanzhang/datasets') 35 | parser.add_argument('--dataset', type=str, default="RotatedMNIST") 36 | parser.add_argument('--gen_dir', type=str, default="models/mnist.pkl", help="if not empty, the generator of DEDF will be loaded") 37 | parser.add_argument('--algorithm', type=str, default="DDG") 38 | parser.add_argument('--hparams', type=str, 39 | help='JSON-serialized hparams dict') 40 | parser.add_argument('--hparams_seed', type=int, default=0, 41 | help='Seed for random hparams (0 means "default hparams")') 42 | parser.add_argument('--trial_seed', type=int, default=0, 43 | help='Trial number (used for seeding split_dataset and ' 44 | 'random_hparams).') 45 | parser.add_argument('--seed', type=int, default=0, 46 | help='Seed for everything else') 47 | parser.add_argument('--holdout_fraction', type=float, default=0.2) 48 | parser.add_argument('--test_envs', type=int, nargs='+', default=[0]) 49 | args = parser.parse_args() 50 | 51 | # If we ever want to implement checkpointing, just persist these values 52 | # every once in a while, and then load them from disk here. 
53 | start_step = 0 54 | algorithm_dict = None 55 | 56 | if args.hparams_seed == 0: 57 | hparams = hparams_registry.default_hparams(args.algorithm, args.dataset) 58 | else: 59 | hparams = hparams_registry.random_hparams(args.algorithm, args.dataset, 60 | misc.seed_hash(args.hparams_seed, args.trial_seed)) 61 | if args.hparams: 62 | hparams.update(json.loads(args.hparams)) 63 | hparams['batch_size'] = 2 64 | print('HParams:') 65 | for k, v in sorted(hparams.items()): 66 | print('\t{}: {}'.format(k, v)) 67 | 68 | random.seed(args.seed) 69 | np.random.seed(args.seed) 70 | torch.manual_seed(args.seed) 71 | torch.backends.cudnn.deterministic = True 72 | torch.backends.cudnn.benchmark = False 73 | 74 | if torch.cuda.is_available(): 75 | device = "cuda" 76 | else: 77 | device = "cpu" 78 | 79 | if args.dataset in vars(datasets): 80 | dataset = vars(datasets)[args.dataset](args.data_dir, 81 | args.test_envs, hparams) 82 | else: 83 | raise NotImplementedError 84 | 85 | in_splits = [] 86 | out_splits = [] 87 | uda_splits = [] 88 | for env_i, env in enumerate(dataset): 89 | 90 | out, in_ = misc.split_dataset(env, 91 | int(len(env)*args.holdout_fraction), 92 | misc.seed_hash(args.trial_seed, env_i)) 93 | in_splits.append((in_, None)) 94 | out_splits.append((out, None)) 95 | train_loaders = [InfiniteDataLoader( 96 | dataset=env, 97 | weights=env_weights, 98 | batch_size=hparams['batch_size'], 99 | num_workers=dataset.N_WORKERS) 100 | for i, (env, env_weights) in enumerate(in_splits) 101 | if i not in args.test_envs] 102 | 103 | algorithm_class = algorithms.get_algorithm_class(args.algorithm) 104 | algorithm = algorithm_class(dataset.input_shape, dataset.num_classes, 105 | len(dataset) - len(args.test_envs), hparams) 106 | 107 | if algorithm_dict is not None: 108 | algorithm.load_state_dict(algorithm_dict) 109 | 110 | algorithm.to(device) 111 | pretext_model = torch.load(args.gen_dir)['model_dict'] 112 | alg_dict = algorithm.state_dict() 113 | ignored_keys = [] 114 | state_dict = 
{k: v for k, v in pretext_model.items() if k in alg_dict.keys() and ('id_featurizer' in k or 'gen' in k)} 115 | alg_dict.update(state_dict) 116 | algorithm.load_state_dict(alg_dict) 117 | 118 | train_minibatches_iterator = zip(*train_loaders) 119 | 120 | for step in range(start_step, 10): 121 | minibatches_device = [(x.to(device), y, pos) for x,y,pos in next(train_minibatches_iterator)] 122 | minibatches_device_neg = [(x.to(device), y, pos) for x,y,pos in next(train_minibatches_iterator)] 123 | images_a = torch.cat([x for x, y, pos in minibatches_device]) 124 | images_b = torch.cat([x for x, y, pos in minibatches_device_neg]) 125 | perm = torch.randperm(len(images_b)).tolist() 126 | image_outputs = algorithm.sample(images_a, images_b[perm]) 127 | write_2images(image_outputs, hparams['batch_size']*len(train_loaders), "test/visual/results", 'train_%08d' % (step + 1), run) -------------------------------------------------------------------------------- /test/visual/swap_augmix.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | 3 | import argparse 4 | import collections 5 | import json 6 | import os 7 | import random 8 | import sys 9 | import time 10 | import copy 11 | import uuid 12 | import numpy as np 13 | import PIL 14 | import torch 15 | import torchvision 16 | import torch.utils.data 17 | import yaml 18 | from azureml.core import Run 19 | from torchvision import transforms 20 | run = Run.get_context() 21 | import datasets 22 | import hparams_registry 23 | import algorithms_gen as algorithms 24 | import numpy.random as random 25 | from lib import misc 26 | from scripts.save_images import write_2images 27 | from lib.fast_data_loader import InfiniteDataLoader 28 | from lib.misc import Augmix 29 | 30 | def get_config(config): 31 | with open(config, 'r') as stream: 32 | return yaml.load(stream) 33 | 34 | 35 | parser = argparse.ArgumentParser(description='Domain generalization') 36 | parser.add_argument('--data_dir', type=str, default='/home/v-yifanzhang/datasets') 37 | parser.add_argument('--dataset', type=str, default="PACS") 38 | parser.add_argument('--gen_dir', type=str, default="models/mnist.pkl", help="if not empty, the generator of DEDF will be loaded") 39 | parser.add_argument('--algorithm', type=str, default="DDG_AugMix") 40 | parser.add_argument('--hparams', type=str, 41 | help='JSON-serialized hparams dict') 42 | parser.add_argument('--hparams_seed', type=int, default=0, 43 | help='Seed for random hparams (0 means "default hparams")') 44 | parser.add_argument('--trial_seed', type=int, default=0, 45 | help='Trial number (used for seeding split_dataset and ' 46 | 'random_hparams).') 47 | parser.add_argument('--seed', type=int, default=0, 48 | help='Seed for everything else') 49 | parser.add_argument('--holdout_fraction', type=float, default=0.2) 50 | parser.add_argument('--test_envs', type=int, nargs='+', default=[0]) 51 | args = parser.parse_args() 52 | 53 | # If we ever want to implement checkpointing, just persist these values 54 | # every once in a 
while, and then load them from disk here. 55 | start_step = 0 56 | algorithm_dict = None 57 | 58 | if args.hparams_seed == 0: 59 | hparams = hparams_registry.default_hparams(args.algorithm, args.dataset) 60 | else: 61 | hparams = hparams_registry.random_hparams(args.algorithm, args.dataset, 62 | misc.seed_hash(args.hparams_seed, args.trial_seed)) 63 | if args.hparams: 64 | hparams.update(json.loads(args.hparams)) 65 | hparams['batch_size'] = 2 66 | print('HParams:') 67 | for k, v in sorted(hparams.items()): 68 | print('\t{}: {}'.format(k, v)) 69 | 70 | random.seed(args.seed) 71 | np.random.seed(args.seed) 72 | torch.manual_seed(args.seed) 73 | torch.backends.cudnn.deterministic = True 74 | torch.backends.cudnn.benchmark = False 75 | 76 | if torch.cuda.is_available(): 77 | device = "cuda" 78 | else: 79 | device = "cpu" 80 | 81 | if args.dataset in vars(datasets): 82 | dataset = vars(datasets)[args.dataset](args.data_dir, 83 | args.test_envs, hparams) 84 | else: 85 | raise NotImplementedError 86 | 87 | in_splits = [] 88 | out_splits = [] 89 | uda_splits = [] 90 | for env_i, env in enumerate(dataset): 91 | 92 | out, in_ = misc.split_dataset(env, 93 | int(len(env)*args.holdout_fraction), 94 | misc.seed_hash(args.trial_seed, env_i)) 95 | in_splits.append((in_, None)) 96 | out_splits.append((out, None)) 97 | train_loaders = [InfiniteDataLoader( 98 | dataset=env, 99 | weights=env_weights, 100 | batch_size=hparams['batch_size'], 101 | num_workers=dataset.N_WORKERS) 102 | for i, (env, env_weights) in enumerate(in_splits) 103 | if i not in args.test_envs] 104 | 105 | train_minibatches_iterator = zip(*train_loaders) 106 | mean = [0.485, 0.456, 0.406] 107 | std = [0.229, 0.224, 0.225] 108 | preprocess = transforms.Compose( 109 | [transforms.ToTensor(), 110 | transforms.Normalize(mean, std)]) 111 | TO_pil = transforms.ToPILImage() 112 | 113 | def sample(x_a, x_b, pretrain_model=None): 114 | device = "cuda" if x_a.is_cuda else "cpu" 115 | x_as, x_bs, x_a_aug, x_b_aug, x_a_aug1, 
x_b_aug1 = [], [], [], [], [], [] 116 | for image_a, image_b in zip(x_a, x_b): 117 | x_b_, x_ab1, x_ab2= Augmix(TO_pil(image_b.cpu()), preprocess, no_jsd=False) 118 | x_a_, x_ba1, x_ba2= Augmix(TO_pil(image_a.cpu()), preprocess, no_jsd=False) 119 | x_a_aug.append(x_ba1.to(device).unsqueeze(0)); x_a_aug1.append(x_ba2.to(device).unsqueeze(0)) 120 | x_b_aug.append(x_ab1.to(device).unsqueeze(0)); x_b_aug1.append(x_ab2.to(device).unsqueeze(0)) 121 | x_as.append(x_a_.to(device).unsqueeze(0)); x_bs.append(x_b_.to(device).unsqueeze(0)) 122 | x_a_aug, x_a_aug1=torch.cat(x_a_aug), torch.cat(x_a_aug1) 123 | x_b_aug, x_b_aug1 = torch.cat(x_b_aug), torch.cat(x_b_aug1) 124 | x_as, x_bs = torch.cat(x_as), torch.cat(x_bs) 125 | return x_as, x_a_aug, x_a_aug1, x_bs, x_b_aug, x_b_aug1 126 | 127 | for step in range(start_step, 5): 128 | minibatches_device = [(x.to(device), y, pos) for x,y,pos in next(train_minibatches_iterator)] 129 | minibatches_device_neg = [(x.to(device), y, pos) for x,y,pos in next(train_minibatches_iterator)] 130 | images_a = torch.cat([x for x, y, pos in minibatches_device]) 131 | images_b = torch.cat([x for x, y, pos in minibatches_device_neg]) 132 | perm = torch.randperm(len(images_b)).tolist() 133 | image_outputs = sample(images_a, images_b[perm]) 134 | write_2images(image_outputs, hparams['batch_size']*len(train_loaders), "test/visual/results/aug_mix", 'Augmix_%08d' % (step + 1), run) -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | 3 | import argparse 4 | import collections 5 | import json 6 | import os 7 | import random 8 | import sys 9 | import time 10 | import copy 11 | import uuid 12 | import numpy as np 13 | import PIL 14 | import torch 15 | import torchvision 16 | import torch.utils.data 17 | import yaml 18 | import datasets 19 | import hparams_registry 20 | import algorithms 21 | import numpy.random as random 22 | from lib import misc 23 | from scripts.save_images import write_2images 24 | from lib.fast_data_loader import InfiniteDataLoader, FastDataLoader 25 | 26 | def get_config(config): 27 | with open(config, 'r') as stream: 28 | return yaml.load(stream) 29 | 30 | if __name__ == "__main__": 31 | parser = argparse.ArgumentParser(description='Domain generalization') 32 | parser.add_argument('--data_dir', type=str, default='/data1/yifan.zhang/datasets/DGdata/') 33 | parser.add_argument('--dataset', type=str, default="PACS") 34 | parser.add_argument('--algorithm', type=str, default="DDG") 35 | parser.add_argument('--gen_dir', type=str, default="", help="if not empty, the generator of DEDF will be loaded") 36 | parser.add_argument('--stage', type=int, default=1, 37 | help='hyperparameter for DDG, 0:train the gan, 1: train the model') 38 | parser.add_argument('--task', type=str, default="domain_generalization", 39 | help='domain_generalization | domain_adaptation') 40 | parser.add_argument('--hparams', type=str, 41 | help='JSON-serialized hparams dict') 42 | parser.add_argument('--hparams_seed', type=int, default=0, 43 | help='Seed for random hparams (0 means "default hparams")') 44 | parser.add_argument('--image_display_iter', type=int, default=500, 45 | help='Epochs interval for showing the generated images') 46 | parser.add_argument('--trial_seed', type=int, default=0, 47 | help='Trial number (used for seeding split_dataset and ' 48 | 'random_hparams).') 49 | parser.add_argument('--seed', type=int, default=7, 50 | help='Seed for everything else') 51 | 
parser.add_argument('--steps', type=int, default=None, 52 | help='Number of steps. Default is dataset-dependent.') 53 | parser.add_argument('--checkpoint_freq', type=int, default=None, 54 | help='Checkpoint every N steps. Default is dataset-dependent.') 55 | parser.add_argument('--test_envs', type=int, nargs='+', default=[1]) 56 | parser.add_argument('--output_dir', type=str, default="train_outputs") 57 | parser.add_argument('--holdout_fraction', type=float, default=0.2) 58 | parser.add_argument('--uda_holdout_fraction', type=float, default=0) 59 | parser.add_argument('--skip_model_save', action='store_true') 60 | parser.add_argument('--save_model_every_checkpoint', action='store_true') 61 | args = parser.parse_args() 62 | 63 | # If we ever want to implement checkpointing, just persist these values 64 | # every once in a while, and then load them from disk here. 65 | start_step = 0 66 | algorithm_dict = None 67 | 68 | os.makedirs(args.output_dir, exist_ok=True) 69 | sys.stdout = misc.Tee(os.path.join(args.output_dir, 'out.txt')) 70 | sys.stderr = misc.Tee(os.path.join(args.output_dir, 'err.txt')) 71 | 72 | print("Environment:") 73 | print("\tPython: {}".format(sys.version.split(" ")[0])) 74 | print("\tPyTorch: {}".format(torch.__version__)) 75 | print("\tTorchvision: {}".format(torchvision.__version__)) 76 | print("\tCUDA: {}".format(torch.version.cuda)) 77 | print("\tCUDNN: {}".format(torch.backends.cudnn.version())) 78 | print("\tNumPy: {}".format(np.__version__)) 79 | print("\tPIL: {}".format(PIL.__version__)) 80 | 81 | print('Args:') 82 | for k, v in sorted(vars(args).items()): 83 | print('\t{}: {}'.format(k, v)) 84 | 85 | if args.hparams_seed == 0: 86 | hparams = hparams_registry.default_hparams(args.algorithm, args.dataset, stage=args.stage) 87 | else: 88 | hparams = hparams_registry.random_hparams(args.algorithm, args.dataset, 89 | misc.seed_hash(args.hparams_seed, args.trial_seed), stage=args.stage) 90 | if args.hparams: 91 | 
hparams.update(json.loads(args.hparams)) 92 | 93 | print('HParams:') 94 | for k, v in sorted(hparams.items()): 95 | print('\t{}: {}'.format(k, v)) 96 | 97 | random.seed(args.seed) 98 | np.random.seed(args.seed) 99 | torch.manual_seed(args.seed) 100 | torch.backends.cudnn.deterministic = True 101 | torch.backends.cudnn.benchmark = False 102 | 103 | if torch.cuda.is_available(): 104 | device = "cuda" 105 | else: 106 | device = "cpu" 107 | 108 | if args.dataset in vars(datasets): 109 | dataset = vars(datasets)[args.dataset](args.data_dir, 110 | args.test_envs, hparams) 111 | else: 112 | raise NotImplementedError 113 | 114 | # Split each env into an 'in-split' and an 'out-split'. We'll train on 115 | # each in-split except the test envs, and evaluate on all splits. 116 | 117 | # To allow unsupervised domain adaptation experiments, we split each test 118 | # env into 'in-split', 'uda-split' and 'out-split'. The 'in-split' is used 119 | # by collect_results.py to compute classification accuracies. The 120 | # 'out-split' is used by the Oracle model selectino method. The unlabeled 121 | # samples in 'uda-split' are passed to the algorithm at training time if 122 | # args.task == "domain_adaptation". If we are interested in comparing 123 | # domain generalization and domain adaptation results, then domain 124 | # generalization algorithms should create the same 'uda-splits', which will 125 | # be discared at training. 
126 | in_splits = [] 127 | out_splits = [] 128 | uda_splits = [] 129 | for env_i, env in enumerate(dataset): 130 | uda = [] 131 | 132 | out, in_ = misc.split_dataset(env, 133 | int(len(env)*args.holdout_fraction), 134 | misc.seed_hash(args.trial_seed, env_i)) 135 | 136 | if env_i in args.test_envs: 137 | uda, in_ = misc.split_dataset(in_, 138 | int(len(in_)*args.uda_holdout_fraction), 139 | misc.seed_hash(args.trial_seed, env_i)) 140 | 141 | if hparams['class_balanced']: 142 | in_weights = misc.make_weights_for_balanced_classes(in_) 143 | out_weights = misc.make_weights_for_balanced_classes(out) 144 | if uda is not None: 145 | uda_weights = misc.make_weights_for_balanced_classes(uda) 146 | else: 147 | in_weights, out_weights, uda_weights = None, None, None 148 | in_splits.append((in_, in_weights)) 149 | out_splits.append((out, out_weights)) 150 | if len(uda): 151 | uda_splits.append((uda, uda_weights)) 152 | 153 | train_loaders = [InfiniteDataLoader( 154 | dataset=env, 155 | weights=env_weights, 156 | batch_size=hparams['batch_size'], 157 | num_workers=dataset.N_WORKERS) 158 | for i, (env, env_weights) in enumerate(in_splits) 159 | if i not in args.test_envs] 160 | 161 | uda_loaders = [InfiniteDataLoader( 162 | dataset=env, 163 | weights=env_weights, 164 | batch_size=hparams['batch_size'], 165 | num_workers=dataset.N_WORKERS) 166 | for i, (env, env_weights) in enumerate(uda_splits) 167 | if i in args.test_envs] 168 | 169 | eval_loaders = [FastDataLoader( 170 | dataset=env, 171 | batch_size=64, 172 | num_workers=dataset.N_WORKERS) 173 | for env, _ in (in_splits + out_splits + uda_splits)] 174 | eval_weights = [None for _, weights in (in_splits + out_splits + uda_splits)] 175 | eval_loader_names = ['env{}_in'.format(i) 176 | for i in range(len(in_splits))] 177 | eval_loader_names += ['env{}_out'.format(i) 178 | for i in range(len(out_splits))] 179 | eval_loader_names += ['env{}_uda'.format(i) 180 | for i in range(len(uda_splits))] 181 | 182 | algorithm_class = 
algorithms.get_algorithm_class(args.algorithm) 183 | algorithm = algorithm_class(dataset.input_shape, dataset.num_classes, 184 | len(dataset) - len(args.test_envs), hparams) 185 | 186 | if algorithm_dict is not None: 187 | algorithm.load_state_dict(algorithm_dict) 188 | 189 | algorithm.to(device) 190 | if args.algorithm == 'DDG' and args.gen_dir and hparams['stage'] == 1: 191 | pretext_model = torch.load(args.gen_dir)['model_dict'] 192 | alg_dict = algorithm.state_dict() 193 | ignored_keys = [] 194 | state_dict = {k: v for k, v in pretext_model.items() if k in alg_dict.keys() and ('id_featurizer' in k or 'gen' in k)} 195 | alg_dict.update(state_dict) 196 | algorithm.load_state_dict(alg_dict) 197 | algorithm_copy = copy.deepcopy(algorithm) 198 | algorithm_copy.eval() 199 | else: 200 | algorithm_copy = None 201 | 202 | train_minibatches_iterator = zip(*train_loaders) 203 | uda_minibatches_iterator = zip(*uda_loaders) 204 | checkpoint_vals = collections.defaultdict(lambda: []) 205 | 206 | steps_per_epoch = min([len(env)/hparams['batch_size'] for env,_ in in_splits]) if args.algorithm is not 'DDG' else min([len(env)/hparams['batch_size']/2 for env,_ in in_splits]) 207 | print("steps per epoch", steps_per_epoch) 208 | 209 | n_steps = args.steps or dataset.N_STEPS 210 | if 'DDG' in args.algorithm: 211 | n_steps = hparams['steps'] 212 | checkpoint_freq = args.checkpoint_freq or dataset.CHECKPOINT_FREQ 213 | 214 | def save_checkpoint(filename): 215 | save_dict = { 216 | "args": vars(args), 217 | "model_input_shape": dataset.input_shape, 218 | "model_num_classes": dataset.num_classes, 219 | "model_num_domains": len(dataset) - len(args.test_envs), 220 | "model_hparams": hparams, 221 | "model_dict": algorithm.cpu().state_dict() 222 | } 223 | torch.save(save_dict, os.path.join(args.output_dir, filename)) 224 | 225 | 226 | last_results_keys = None 227 | print("n_steps", n_steps) 228 | for step in range(start_step, n_steps): 229 | step_start_time = time.time() 230 | if args.task 
== "domain_adaptation": 231 | uda_device = [x.to(device) 232 | for x,_ in next(uda_minibatches_iterator)] 233 | else: 234 | uda_device = None 235 | 236 | if 'DDG' in args.algorithm: 237 | minibatches_device = [(x.to(device), y.to(device), pos.to(device)) for x,y,pos in next(train_minibatches_iterator)] 238 | minibatches_device_neg = [(x.to(device), y.to(device), pos.to(device)) for x,y,pos in next(train_minibatches_iterator)] 239 | step_vals = algorithm.update(minibatches_device, minibatches_device_neg, pretrain_model=algorithm_copy) 240 | else: 241 | minibatches_device = [(x.to(device), y.to(device)) for x,y in next(train_minibatches_iterator)] 242 | step_vals = algorithm.update(minibatches_device, uda_device) 243 | checkpoint_vals['step_time'].append(time.time() - step_start_time) 244 | 245 | for key, val in step_vals.items(): 246 | checkpoint_vals[key].append(val) 247 | if not os.path.exists('train_outputs/images'): 248 | print("Creating directory: {}".format('train_outputs/images')) 249 | os.makedirs('train_outputs/images') 250 | if (step % checkpoint_freq == 0) or (step == n_steps - 1): 251 | results = { 252 | 'step': step, 253 | 'epoch': step / steps_per_epoch, 254 | } 255 | 256 | for key, val in checkpoint_vals.items(): 257 | results[key] = np.mean(val) 258 | 259 | evals = zip(eval_loader_names, eval_loaders, eval_weights) 260 | for name, loader, weights in evals: 261 | acc = misc.accuracy(algorithm, loader, weights, device, args=args, step=step, is_ddg=hparams['is_ddg']) 262 | results[name+'_acc'] = acc 263 | 264 | results_keys = sorted(results.keys()) 265 | if results_keys != last_results_keys: 266 | misc.print_row(results_keys, colwidth=12) 267 | last_results_keys = results_keys 268 | misc.print_row([results[key] for key in results_keys], 269 | colwidth=12) 270 | 271 | results.update({ 272 | 'hparams': hparams, 273 | 'args': vars(args) 274 | }) 275 | 276 | epochs_path = os.path.join(args.output_dir, 'results.jsonl') 277 | with open(epochs_path, 'a') as f: 
278 | f.write(json.dumps(results, sort_keys=True) + "\n") 279 | 280 | algorithm_dict = algorithm.state_dict() 281 | start_step = step + 1 282 | checkpoint_vals = collections.defaultdict(lambda: []) 283 | 284 | save_checkpoint('model.pkl') 285 | 286 | with open(os.path.join(args.output_dir, 'done'), 'w') as f: 287 | f.write('done') 288 | -------------------------------------------------------------------------------- /wilds/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hlzhang109/DDG/b5cd1822f1a413ae7e263fc5a11f00b490b9c72c/wilds/__init__.py -------------------------------------------------------------------------------- /wilds/common/grouper.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from wilds.common.utils import get_counts 4 | from wilds.datasets.wilds_dataset import WILDSSubset 5 | import warnings 6 | 7 | class Grouper: 8 | """ 9 | Groupers group data points together based on their metadata. 10 | They are used for training and evaluation, 11 | e.g., to measure the accuracies of different groups of data. 12 | """ 13 | def __init__(self): 14 | raise NotImplementedError 15 | 16 | @property 17 | def n_groups(self): 18 | """ 19 | The number of groups defined by this Grouper. 20 | """ 21 | return self._n_groups 22 | 23 | def metadata_to_group(self, metadata, return_counts=False): 24 | """ 25 | Args: 26 | - metadata (Tensor): An n x d matrix containing d metadata fields 27 | for n different points. 28 | - return_counts (bool): If True, return group counts as well. 29 | Output: 30 | - group (Tensor): An n-length vector of groups. 31 | - group_counts (Tensor): Optional, depending on return_counts. 32 | An n_group-length vector of integers containing the 33 | numbers of data points in each group in the metadata. 
34 | """ 35 | raise NotImplementedError 36 | 37 | def group_str(self, group): 38 | """ 39 | Args: 40 | - group (int): A single integer representing a group. 41 | Output: 42 | - group_str (str): A string containing the pretty name of that group. 43 | """ 44 | raise NotImplementedError 45 | 46 | def group_field_str(self, group): 47 | """ 48 | Args: 49 | - group (int): A single integer representing a group. 50 | Output: 51 | - group_str (str): A string containing the name of that group. 52 | """ 53 | raise NotImplementedError 54 | 55 | class CombinatorialGrouper(Grouper): 56 | def __init__(self, dataset, groupby_fields): 57 | """ 58 | CombinatorialGroupers form groups by taking all possible combinations of the metadata 59 | fields specified in groupby_fields, in lexicographical order. 60 | For example, if: 61 | dataset.metadata_fields = ['country', 'time', 'y'] 62 | groupby_fields = ['country', 'time'] 63 | and if in dataset.metadata, country is in {0, 1} and time is in {0, 1, 2}, 64 | then the grouper will assign groups in the following way: 65 | country = 0, time = 0 -> group 0 66 | country = 1, time = 0 -> group 1 67 | country = 0, time = 1 -> group 2 68 | country = 1, time = 1 -> group 3 69 | country = 0, time = 2 -> group 4 70 | country = 1, time = 2 -> group 5 71 | 72 | If groupby_fields is None, then all data points are assigned to group 0. 73 | 74 | Args: 75 | - dataset (WILDSDataset) 76 | - groupby_fields (list of str) 77 | """ 78 | if isinstance(dataset, WILDSSubset): 79 | raise ValueError("Grouper should be defined for the full dataset, not a subset") 80 | self.groupby_fields = groupby_fields 81 | 82 | if groupby_fields is None: 83 | self._n_groups = 1 84 | else: 85 | # We assume that the metadata fields are integers, 86 | # so we can measure the cardinality of each field by taking its max + 1. 87 | # Note that this might result in some empty groups. 
88 | self.groupby_field_indices = [i for (i, field) in enumerate(dataset.metadata_fields) if field in groupby_fields] 89 | if len(self.groupby_field_indices) != len(self.groupby_fields): 90 | raise ValueError('At least one group field not found in dataset.metadata_fields') 91 | grouped_metadata = dataset.metadata_array[:, self.groupby_field_indices] 92 | if not isinstance(grouped_metadata, torch.LongTensor): 93 | grouped_metadata_long = grouped_metadata.long() 94 | if not torch.all(grouped_metadata == grouped_metadata_long): 95 | warnings.warn(f'CombinatorialGrouper: converting metadata with fields [{", ".join(groupby_fields)}] into long') 96 | grouped_metadata = grouped_metadata_long 97 | for idx, field in enumerate(self.groupby_fields): 98 | min_value = grouped_metadata[:,idx].min() 99 | if min_value < 0: 100 | raise ValueError(f"Metadata for CombinatorialGrouper cannot have values less than 0: {field}, {min_value}") 101 | if min_value > 0: 102 | warnings.warn(f"Minimum metadata value for CombinatorialGrouper is not 0 ({field}, {min_value}). 
This will result in empty groups") 103 | self.cardinality = 1 + torch.max( 104 | grouped_metadata, dim=0)[0] 105 | cumprod = torch.cumprod(self.cardinality, dim=0) 106 | self._n_groups = cumprod[-1].item() 107 | self.factors_np = np.concatenate(([1], cumprod[:-1])) 108 | self.factors = torch.from_numpy(self.factors_np) 109 | self.metadata_map = dataset.metadata_map 110 | 111 | def metadata_to_group(self, metadata, return_counts=False): 112 | if self.groupby_fields is None: 113 | groups = torch.zeros(metadata.shape[0], dtype=torch.long) 114 | else: 115 | groups = metadata[:, self.groupby_field_indices].long() @ self.factors 116 | 117 | if return_counts: 118 | group_counts = get_counts(groups, self._n_groups) 119 | return groups, group_counts 120 | else: 121 | return groups 122 | 123 | def group_str(self, group): 124 | if self.groupby_fields is None: 125 | return 'all' 126 | 127 | # group is just an integer, not a Tensor 128 | n = len(self.factors_np) 129 | metadata = np.zeros(n) 130 | for i in range(n-1): 131 | metadata[i] = (group % self.factors_np[i+1]) // self.factors_np[i] 132 | metadata[n-1] = group // self.factors_np[n-1] 133 | group_name = '' 134 | for i in reversed(range(n)): 135 | meta_val = int(metadata[i]) 136 | if self.metadata_map is not None: 137 | if self.groupby_fields[i] in self.metadata_map: 138 | meta_val = self.metadata_map[self.groupby_fields[i]][meta_val] 139 | group_name += f'{self.groupby_fields[i]} = {meta_val}, ' 140 | group_name = group_name[:-2] 141 | return group_name 142 | 143 | # a_n = S / x_n 144 | # a_{n-1} = (S % x_n) / x_{n-1} 145 | # a_{n-2} = (S % x_{n-1}) / x_{n-2} 146 | # ... 147 | # 148 | # g = 149 | # a_1 * x_1 + 150 | # a_2 * x_2 + ... 
151 | # a_n * x_n 152 | 153 | def group_field_str(self, group): 154 | return self.group_str(group).replace('=', ':').replace(',','_').replace(' ','') 155 | -------------------------------------------------------------------------------- /wilds/common/utils.py: -------------------------------------------------------------------------------- 1 | import torch #, torch_scatter 2 | import numpy as np 3 | from torch.utils.data import Subset 4 | from pandas.api.types import CategoricalDtype 5 | 6 | def minimum(numbers, empty_val=0.): 7 | if isinstance(numbers, torch.Tensor): 8 | if numbers.numel()==0: 9 | return torch.tensor(empty_val, device=numbers.device) 10 | else: 11 | return numbers[~torch.isnan(numbers)].min() 12 | elif isinstance(numbers, np.ndarray): 13 | if numbers.size==0: 14 | return np.array(empty_val) 15 | else: 16 | return np.nanmin(numbers) 17 | else: 18 | if len(numbers)==0: 19 | return empty_val 20 | else: 21 | return min(numbers) 22 | 23 | def maximum(numbers, empty_val=0.): 24 | if isinstance(numbers, torch.Tensor): 25 | if numbers.numel()==0: 26 | return torch.tensor(empty_val, device=numbers.device) 27 | else: 28 | return numbers[~torch.isnan(numbers)].max() 29 | elif isinstance(numbers, np.ndarray): 30 | if numbers.size==0: 31 | return np.array(empty_val) 32 | else: 33 | return np.nanmax(numbers) 34 | else: 35 | if len(numbers)==0: 36 | return empty_val 37 | else: 38 | return max(numbers) 39 | 40 | def split_into_groups(g): 41 | """ 42 | Args: 43 | - g (Tensor): Vector of groups 44 | Returns: 45 | - groups (Tensor): Unique groups present in g 46 | - group_indices (list): List of Tensors, where the i-th tensor is the indices of the 47 | elements of g that equal groups[i]. 48 | Has the same length as len(groups). 49 | - unique_counts (Tensor): Counts of each element in groups. 50 | Has the same length as len(groups). 
51 | """ 52 | unique_groups, unique_counts = torch.unique(g, sorted=False, return_counts=True) 53 | group_indices = [] 54 | for group in unique_groups: 55 | group_indices.append( 56 | torch.nonzero(g == group, as_tuple=True)[0]) 57 | return unique_groups, group_indices, unique_counts 58 | 59 | def get_counts(g, n_groups): 60 | """ 61 | This differs from split_into_groups in how it handles missing groups. 62 | get_counts always returns a count Tensor of length n_groups, 63 | whereas split_into_groups returns a unique_counts Tensor 64 | whose length is the number of unique groups present in g. 65 | Args: 66 | - g (Tensor): Vector of groups 67 | Returns: 68 | - counts (Tensor): A list of length n_groups, denoting the count of each group. 69 | """ 70 | unique_groups, unique_counts = torch.unique(g, sorted=False, return_counts=True) 71 | counts = torch.zeros(n_groups, device=g.device) 72 | counts[unique_groups] = unique_counts.float() 73 | return counts 74 | 75 | def avg_over_groups(v, g, n_groups): 76 | """ 77 | Args: 78 | v (Tensor): Vector containing the quantity to average over. 79 | g (Tensor): Vector of the same length as v, containing group information. 
80 | Returns: 81 | group_avgs (Tensor): Vector of length num_groups 82 | group_counts (Tensor) 83 | """ 84 | assert v.device==g.device 85 | device = v.device 86 | assert v.numel()==g.numel() 87 | group_count = get_counts(g, n_groups) 88 | #group_avgs = torch_scatter.scatter(src=v, index=g, dim_size=n_groups, reduce='mean') 89 | group_avgs = None 90 | return group_avgs, group_count 91 | 92 | def map_to_id_array(df, ordered_map={}): 93 | maps = {} 94 | array = np.zeros(df.shape) 95 | for i, c in enumerate(df.columns): 96 | if c in ordered_map: 97 | category_type = CategoricalDtype(categories=ordered_map[c], ordered=True) 98 | else: 99 | category_type = 'category' 100 | series = df[c].astype(category_type) 101 | maps[c] = series.cat.categories.values 102 | array[:,i] = series.cat.codes.values 103 | return maps, array 104 | 105 | def subsample_idxs(idxs, num=5000, take_rest=False, seed=None): 106 | seed = (seed + 541433) if seed is not None else None 107 | rng = np.random.default_rng(seed) 108 | 109 | idxs = idxs.copy() 110 | rng.shuffle(idxs) 111 | if take_rest: 112 | idxs = idxs[num:] 113 | else: 114 | idxs = idxs[:num] 115 | return idxs 116 | 117 | 118 | def shuffle_arr(arr, seed=None): 119 | seed = (seed + 548207) if seed is not None else None 120 | rng = np.random.default_rng(seed) 121 | 122 | arr = arr.copy() 123 | rng.shuffle(arr) 124 | return arr 125 | 126 | def threshold_at_recall(y_pred, y_true, global_recall=60): 127 | """ Calculate the model threshold to use to achieve a desired global_recall level. 
Assumes that 128 | y_true is a vector of the true binary labels.""" 129 | return np.percentile(y_pred[y_true == 1], 100-global_recall) 130 | -------------------------------------------------------------------------------- /wilds/datasets/camelyon17_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import pandas as pd 4 | from PIL import Image 5 | import numpy as np 6 | from wilds.datasets.wilds_dataset import WILDSDataset 7 | from wilds.common.grouper import CombinatorialGrouper 8 | #from wilds.common.metrics.all_metrics import Accuracy 9 | 10 | class Camelyon17Dataset(WILDSDataset): 11 | """ 12 | The CAMELYON17-wilds histopathology dataset. 13 | This is a modified version of the original CAMELYON17 dataset. 14 | 15 | Supported `split_scheme`: 16 | 'official' or 'in-dist' 17 | 18 | Input (x): 19 | 96x96 image patches extracted from histopathology slides. 20 | 21 | Label (y): 22 | y is binary. It is 1 if the central 32x32 region contains any tumor tissue, and 0 otherwise. 23 | 24 | Metadata: 25 | Each patch is annotated with the ID of the hospital it came from (integer from 0 to 4) 26 | and the slide it came from (integer from 0 to 49). 27 | 28 | Website: 29 | https://camelyon17.grand-challenge.org/ 30 | 31 | Original publication: 32 | @article{bandi2018detection, 33 | title={From detection of individual metastases to classification of lymph node status at the patient level: the camelyon17 challenge}, 34 | author={Bandi, Peter and Geessink, Oscar and Manson, Quirine and Van Dijk, Marcory and Balkenhol, Maschenka and Hermsen, Meyke and Bejnordi, Babak Ehteshami and Lee, Byungjae and Paeng, Kyunghyun and Zhong, Aoxiao and others}, 35 | journal={IEEE transactions on medical imaging}, 36 | volume={38}, 37 | number={2}, 38 | pages={550--560}, 39 | year={2018}, 40 | publisher={IEEE} 41 | } 42 | 43 | License: 44 | This dataset is in the public domain and is distributed under CC0. 
45 | https://creativecommons.org/publicdomain/zero/1.0/ 46 | """ 47 | 48 | _dataset_name = 'camelyon17' 49 | _versions_dict = { 50 | '1.0': { 51 | 'download_url': 'https://worksheets.codalab.org/rest/bundles/0xe45e15f39fb54e9d9e919556af67aabe/contents/blob/', 52 | 'compressed_size': 10_658_709_504}} 53 | 54 | def __init__(self, version=None, root_dir='data', download=False, split_scheme='official'): 55 | self._version = version 56 | self._data_dir = self.initialize_data_dir(root_dir, download) 57 | self._original_resolution = (96,96) 58 | 59 | # Read in metadata 60 | self._metadata_df = pd.read_csv( 61 | os.path.join(self._data_dir, 'metadata.csv'), 62 | index_col=0, 63 | dtype={'patient': 'str'}) 64 | 65 | # Get the y values 66 | self._y_array = torch.LongTensor(self._metadata_df['tumor'].values) 67 | self._y_size = 1 68 | self._n_classes = 2 69 | 70 | # Get filenames 71 | self._input_array = [ 72 | f'patches/patient_{patient}_node_{node}/patch_patient_{patient}_node_{node}_x_{x}_y_{y}.png' 73 | for patient, node, x, y in 74 | self._metadata_df.loc[:, ['patient', 'node', 'x_coord', 'y_coord']].itertuples(index=False, name=None)] 75 | 76 | # Extract splits 77 | # Note that the hospital numbering here is different from what's in the paper, 78 | # where to avoid confusing readers we used a 1-indexed scheme and just labeled the test hospital as 5. 79 | # Here, the numbers are 0-indexed. 
80 | test_center = 2 81 | val_center = 1 82 | 83 | self._split_dict = { 84 | 'train': 0, 85 | 'id_val': 1, 86 | 'test': 2, 87 | 'val': 3 88 | } 89 | self._split_names = { 90 | 'train': 'Train', 91 | 'id_val': 'Validation (ID)', 92 | 'test': 'Test', 93 | 'val': 'Validation (OOD)', 94 | } 95 | centers = self._metadata_df['center'].values.astype('long') 96 | num_centers = int(np.max(centers)) + 1 97 | val_center_mask = (self._metadata_df['center'] == val_center) 98 | test_center_mask = (self._metadata_df['center'] == test_center) 99 | self._metadata_df.loc[val_center_mask, 'split'] = self.split_dict['val'] 100 | self._metadata_df.loc[test_center_mask, 'split'] = self.split_dict['test'] 101 | 102 | self._split_scheme = split_scheme 103 | if self._split_scheme == 'official': 104 | pass 105 | elif self._split_scheme == 'in-dist': 106 | # For the in-distribution oracle, 107 | # we move slide 23 (corresponding to patient 042, node 3 in the original dataset) 108 | # from the test set to the training set 109 | slide_mask = (self._metadata_df['slide'] == 23) 110 | self._metadata_df.loc[slide_mask, 'split'] = self.split_dict['train'] 111 | else: 112 | raise ValueError(f'Split scheme {self._split_scheme} not recognized') 113 | self._split_array = self._metadata_df['split'].values 114 | 115 | self._metadata_array = torch.stack( 116 | (torch.LongTensor(centers), 117 | torch.LongTensor(self._metadata_df['slide'].values), 118 | self._y_array), 119 | dim=1) 120 | self._metadata_fields = ['hospital', 'slide', 'y'] 121 | 122 | self._eval_grouper = CombinatorialGrouper( 123 | dataset=self, 124 | groupby_fields=['slide']) 125 | 126 | super().__init__(root_dir, download, split_scheme) 127 | 128 | def get_input(self, idx): 129 | """ 130 | Returns x for a given idx. 
131 | """ 132 | img_filename = os.path.join( 133 | self.data_dir, 134 | self._input_array[idx]) 135 | x = Image.open(img_filename).convert('RGB') 136 | return x 137 | 138 | 139 | ''' 140 | def eval(self, y_pred, y_true, metadata, prediction_fn=None): 141 | """ 142 | Computes all evaluation metrics. 143 | Args: 144 | - y_pred (Tensor): Predictions from a model. By default, they are predicted labels (LongTensor). 145 | But they can also be other model outputs such that prediction_fn(y_pred) 146 | are predicted labels. 147 | - y_true (LongTensor): Ground-truth labels 148 | - metadata (Tensor): Metadata 149 | - prediction_fn (function): A function that turns y_pred into predicted labels 150 | Output: 151 | - results (dictionary): Dictionary of evaluation metrics 152 | - results_str (str): String summarizing the evaluation metrics 153 | """ 154 | metric = Accuracy(prediction_fn=prediction_fn) 155 | return self.standard_group_eval( 156 | metric, 157 | self._eval_grouper, 158 | y_pred, y_true, metadata) 159 | ''' -------------------------------------------------------------------------------- /wilds/datasets/wilds_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | import torch 5 | import numpy as np 6 | 7 | class WILDSDataset: 8 | """ 9 | Shared dataset class for all WILDS datasets. 10 | Each data point in the dataset is an (x, y, metadata) tuple, where: 11 | - x is the input features 12 | - y is the target 13 | - metadata is a vector of relevant information, e.g., domain. 14 | For convenience, metadata also contains y. 
15 | """ 16 | DEFAULT_SPLITS = {'train': 0, 'val': 1, 'test': 2} 17 | DEFAULT_SPLIT_NAMES = {'train': 'Train', 'val': 'Validation', 'test': 'Test'} 18 | 19 | def __init__(self, root_dir, download, split_scheme): 20 | if len(self._metadata_array.shape) == 1: 21 | self._metadata_array = self._metadata_array.unsqueeze(1) 22 | self.check_init() 23 | 24 | def __len__(self): 25 | return len(self.y_array) 26 | 27 | def __getitem__(self, idx): 28 | # Any transformations are handled by the WILDSSubset 29 | # since different subsets (e.g., train vs test) might have different transforms 30 | x = self.get_input(idx) 31 | y = self.y_array[idx] 32 | metadata = self.metadata_array[idx] 33 | return x, y, metadata 34 | 35 | def get_input(self, idx): 36 | """ 37 | Args: 38 | - idx (int): Index of a data point 39 | Output: 40 | - x (Tensor): Input features of the idx-th data point 41 | """ 42 | raise NotImplementedError 43 | 44 | def eval(self, y_pred, y_true, metadata): 45 | """ 46 | Args: 47 | - y_pred (Tensor): Predicted targets 48 | - y_true (Tensor): True targets 49 | - metadata (Tensor): Metadata 50 | Output: 51 | - results (dict): Dictionary of results 52 | - results_str (str): Pretty print version of the results 53 | """ 54 | raise NotImplementedError 55 | 56 | def get_subset(self, split, frac=1.0, transform=None): 57 | """ 58 | Args: 59 | - split (str): Split identifier, e.g., 'train', 'val', 'test'. 60 | Must be in self.split_dict. 61 | - frac (float): What fraction of the split to randomly sample. 62 | Used for fast development on a small dataset. 63 | - transform (function): Any data transformations to be applied to the input x. 64 | Output: 65 | - subset (WILDSSubset): A (potentially subsampled) subset of the WILDSDataset. 
66 | """ 67 | if split not in self.split_dict: 68 | raise ValueError(f"Split {split} not found in dataset's split_dict.") 69 | split_mask = self.split_array == self.split_dict[split] 70 | split_idx = np.where(split_mask)[0] 71 | if frac < 1.0: 72 | num_to_retain = int(np.round(float(len(split_idx)) * frac)) 73 | split_idx = np.sort(np.random.permutation(split_idx)[:num_to_retain]) 74 | subset = WILDSSubset(self, split_idx, transform) 75 | return subset 76 | 77 | def check_init(self): 78 | """ 79 | Convenience function to check that the WILDSDataset is properly configured. 80 | """ 81 | required_attrs = ['_dataset_name', '_data_dir', 82 | '_split_scheme', '_split_array', 83 | '_y_array', '_y_size', 84 | '_metadata_fields', '_metadata_array'] 85 | for attr_name in required_attrs: 86 | assert hasattr(self, attr_name), f'WILDSDataset is missing {attr_name}.' 87 | 88 | # Check that data directory exists 89 | if not os.path.exists(self.data_dir): 90 | raise ValueError( 91 | f'{self.data_dir} does not exist yet. 
Please generate the dataset first.') 92 | 93 | # Check splits 94 | assert self.split_dict.keys()==self.split_names.keys() 95 | assert 'train' in self.split_dict 96 | assert 'val' in self.split_dict 97 | 98 | # Check that required arrays are Tensors 99 | assert isinstance(self.y_array, torch.Tensor), 'y_array must be a torch.Tensor' 100 | assert isinstance(self.metadata_array, torch.Tensor), 'metadata_array must be a torch.Tensor' 101 | 102 | # Check that dimensions match 103 | assert len(self.y_array) == len(self.metadata_array) 104 | assert len(self.split_array) == len(self.metadata_array) 105 | 106 | # Check metadata 107 | assert len(self.metadata_array.shape) == 2 108 | assert len(self.metadata_fields) == self.metadata_array.shape[1] 109 | # For convenience, include y in metadata_fields if y_size == 1 110 | if self.y_size == 1: 111 | assert 'y' in self.metadata_fields 112 | 113 | @property 114 | def latest_version(cls): 115 | def is_later(u, v): 116 | """Returns true if u is a later version than v.""" 117 | u_major, u_minor = tuple(map(int, u.split('.'))) 118 | v_major, v_minor = tuple(map(int, v.split('.'))) 119 | if (u_major > v_major) or ( 120 | (u_major == v_major) and (u_minor > v_minor)): 121 | return True 122 | else: 123 | return False 124 | 125 | latest_version = '0.0' 126 | for key in cls.versions_dict.keys(): 127 | if is_later(key, latest_version): 128 | latest_version = key 129 | return latest_version 130 | 131 | @property 132 | def dataset_name(self): 133 | """ 134 | A string that identifies the dataset, e.g., 'amazon', 'camelyon17'. 135 | """ 136 | return self._dataset_name 137 | 138 | @property 139 | def version(self): 140 | """ 141 | A string that identifies the dataset version, e.g., '1.0'. 
142 | """ 143 | if self._version is None: 144 | return self.latest_version 145 | else: 146 | return self._version 147 | 148 | @property 149 | def versions_dict(self): 150 | """ 151 | A dictionary where each key is a version string (e.g., '1.0') 152 | and each value is a dictionary containing the 'download_url' and 153 | 'compressed_size' keys. 154 | 155 | 'download_url' is the URL for downloading the dataset archive. 156 | If None, the dataset cannot be downloaded automatically 157 | (e.g., because it first requires accepting a usage agreement). 158 | 159 | 'compressed_size' is the approximate size of the compressed dataset in bytes. 160 | """ 161 | return self._versions_dict 162 | 163 | @property 164 | def data_dir(self): 165 | """ 166 | The full path to the folder in which the dataset is stored. 167 | """ 168 | return self._data_dir 169 | 170 | @property 171 | def collate(self): 172 | """ 173 | Torch function to collate items in a batch. 174 | By default returns None -> uses default torch collate. 175 | """ 176 | return getattr(self, '_collate', None) 177 | 178 | @property 179 | def split_scheme(self): 180 | """ 181 | A string identifier of how the split is constructed, 182 | e.g., 'standard', 'in-dist', 'user', etc. 183 | """ 184 | return self._split_scheme 185 | 186 | @property 187 | def split_dict(self): 188 | """ 189 | A dictionary mapping splits to integer identifiers (used in split_array), 190 | e.g., {'train': 0, 'val': 1, 'test': 2}. 191 | Keys should match up with split_names. 192 | """ 193 | return getattr(self, '_split_dict', WILDSDataset.DEFAULT_SPLITS) 194 | 195 | @property 196 | def split_names(self): 197 | """ 198 | A dictionary mapping splits to their pretty names, 199 | e.g., {'train': 'Train', 'val': 'Validation', 'test': 'Test'}. 200 | Keys should match up with split_dict. 
201 | """ 202 | return getattr(self, '_split_names', WILDSDataset.DEFAULT_SPLIT_NAMES) 203 | 204 | @property 205 | def split_array(self): 206 | """ 207 | An array of integers, with split_array[i] representing what split the i-th data point 208 | belongs to. 209 | """ 210 | return self._split_array 211 | 212 | @property 213 | def y_array(self): 214 | """ 215 | A Tensor of targets (e.g., labels for classification tasks), 216 | with y_array[i] representing the target of the i-th data point. 217 | y_array[i] can contain multiple elements. 218 | """ 219 | return self._y_array 220 | 221 | @property 222 | def y_size(self): 223 | """ 224 | The number of dimensions/elements in the target, i.e., len(y_array[i]). 225 | For standard classification/regression tasks, y_size = 1. 226 | For multi-task or structured prediction settings, y_size > 1. 227 | Used for logging and to configure models to produce appropriately-sized output. 228 | """ 229 | return self._y_size 230 | 231 | @property 232 | def n_classes(self): 233 | """ 234 | Number of classes for single-task classification datasets. 235 | Used for logging and to configure models to produce appropriately-sized output. 236 | None by default. 237 | Leave as None if not applicable (e.g., regression or multi-task classification). 238 | """ 239 | return getattr(self, '_n_classes', None) 240 | 241 | @property 242 | def is_classification(self): 243 | """ 244 | Boolean. True if the task is classification, and false otherwise. 245 | Used for logging purposes. 246 | """ 247 | return (self.n_classes is not None) 248 | 249 | @property 250 | def metadata_fields(self): 251 | """ 252 | A list of strings naming each column of the metadata table, e.g., ['hospital', 'y']. 253 | Must include 'y'. 254 | """ 255 | return self._metadata_fields 256 | 257 | @property 258 | def metadata_array(self): 259 | """ 260 | A Tensor of metadata, with the i-th row representing the metadata associated with 261 | the i-th data point. 
The columns correspond to the metadata_fields defined above. 262 | """ 263 | return self._metadata_array 264 | 265 | @property 266 | def metadata_map(self): 267 | """ 268 | An optional dictionary that, for each metadata field, contains a list that maps from 269 | integers (in metadata_array) to a string representing what that integer means. 270 | This is only used for logging, so that we print out more intelligible metadata values. 271 | Each key must be in metadata_fields. 272 | For example, if we have 273 | metadata_fields = ['hospital', 'y'] 274 | metadata_map = {'hospital': ['East', 'West']} 275 | then if metadata_array[i, 0] == 0, the i-th data point belongs to the 'East' hospital 276 | while if metadata_array[i, 0] == 1, it belongs to the 'West' hospital. 277 | """ 278 | return getattr(self, '_metadata_map', None) 279 | 280 | @property 281 | def original_resolution(self): 282 | """ 283 | Original image resolution for image datasets. 284 | """ 285 | return getattr(self, '_original_resolution', None) 286 | 287 | def initialize_data_dir(self, root_dir, download): 288 | """ 289 | Helper function for downloading/updating the dataset if required. 290 | Note that we only do a version check for datasets where the download_url is set. 291 | Currently, this includes all datasets except Yelp. 292 | Datasets for which we don't control the download, like Yelp, 293 | might not handle versions similarly. 294 | """ 295 | if self.version not in self.versions_dict: 296 | raise ValueError(f'Version {self.version} not supported. 
Must be in {self.versions_dict.keys()}.') 297 | 298 | download_url = self.versions_dict[self.version]['download_url'] 299 | compressed_size = self.versions_dict[self.version]['compressed_size'] 300 | 301 | os.makedirs(root_dir, exist_ok=True) 302 | 303 | data_dir = os.path.join(root_dir, f'{self.dataset_name}_v{self.version}') 304 | version_file = os.path.join(data_dir, f'RELEASE_v{self.version}.txt') 305 | current_major_version, current_minor_version = tuple(map(int, self.version.split('.'))) 306 | 307 | # Check if we specified the latest version. Otherwise, print a warning. 308 | latest_major_version, latest_minor_version = tuple(map(int, self.latest_version.split('.'))) 309 | if latest_major_version > current_major_version: 310 | print( 311 | f'*****************************\n' 312 | f'{self.dataset_name} has been updated to version {self.latest_version}.\n' 313 | f'You are currently using version {self.version}.\n' 314 | f'We highly recommend updating the dataset by not specifying the older version in the command-line argument or dataset constructor.\n' 315 | f'See https://wilds.stanford.edu/changelog for changes.\n' 316 | f'*****************************\n') 317 | elif latest_minor_version > current_minor_version: 318 | print( 319 | f'*****************************\n' 320 | f'{self.dataset_name} has been updated to version {self.latest_version}.\n' 321 | f'You are currently using version {self.version}.\n' 322 | f'Please consider updating the dataset.\n' 323 | f'See https://wilds.stanford.edu/changelog for changes.\n' 324 | f'*****************************\n') 325 | 326 | # If the data_dir exists and contains the right RELEASE file, 327 | # we assume the dataset is correctly set up 328 | if os.path.exists(data_dir) and os.path.exists(version_file): 329 | return data_dir 330 | 331 | # If the data_dir exists and does not contain the right RELEASE file, but it is not empty and the download_url is not set, 332 | # we assume the dataset is correctly set up 333 | if 
((os.path.exists(data_dir)) and 334 | (len(os.listdir(data_dir)) > 0) and 335 | (download_url is None)): 336 | return data_dir 337 | 338 | # Otherwise, we assume the dataset needs to be downloaded. 339 | # If download == False, then return an error. 340 | if download == False: 341 | if download_url is None: 342 | raise FileNotFoundError(f'The {self.dataset_name} dataset could not be found in {data_dir}. {self.dataset_name} cannot be automatically downloaded. Please download it manually.') 343 | else: 344 | raise FileNotFoundError(f'The {self.dataset_name} dataset could not be found in {data_dir}. Initialize the dataset with download=True to download the dataset. If you are using the example script, run with --download. This might take some time for large datasets.') 345 | 346 | # Otherwise, proceed with downloading. 347 | if download_url is None: 348 | raise ValueError(f'Sorry, {self.dataset_name} cannot be automatically downloaded. Please download it manually.') 349 | 350 | from wilds.datasets.download_utils import download_and_extract_archive 351 | print(f'Downloading dataset to {data_dir}...') 352 | print(f'You can also download the dataset manually at https://wilds.stanford.edu/downloads.') 353 | try: 354 | start_time = time.time() 355 | download_and_extract_archive( 356 | url=download_url, 357 | download_root=data_dir, 358 | filename='archive.tar.gz', 359 | remove_finished=True, 360 | size=compressed_size) 361 | 362 | download_time_in_minutes = (time.time() - start_time) / 60 363 | print(f"It took {round(download_time_in_minutes, 2)} minutes to download and uncompress the dataset.") 364 | except Exception as e: 365 | print(f"\n{os.path.join(data_dir, 'archive.tar.gz')} may be corrupted. 
Please try deleting it and rerunning this command.\n") 366 | print(f"Exception: ", e) 367 | 368 | return data_dir 369 | 370 | @staticmethod 371 | def standard_eval(metric, y_pred, y_true): 372 | """ 373 | Args: 374 | - metric (Metric): Metric to use for eval 375 | - y_pred (Tensor): Predicted targets 376 | - y_true (Tensor): True targets 377 | Output: 378 | - results (dict): Dictionary of results 379 | - results_str (str): Pretty print version of the results 380 | """ 381 | results = { 382 | **metric.compute(y_pred, y_true), 383 | } 384 | results_str = ( 385 | f"Average {metric.name}: {results[metric.agg_metric_field]:.3f}\n" 386 | ) 387 | return results, results_str 388 | 389 | @staticmethod 390 | def standard_group_eval(metric, grouper, y_pred, y_true, metadata, aggregate=True): 391 | """ 392 | Args: 393 | - metric (Metric): Metric to use for eval 394 | - grouper (CombinatorialGrouper): Grouper object that converts metadata into groups 395 | - y_pred (Tensor): Predicted targets 396 | - y_true (Tensor): True targets 397 | - metadata (Tensor): Metadata 398 | Output: 399 | - results (dict): Dictionary of results 400 | - results_str (str): Pretty print version of the results 401 | """ 402 | results, results_str = {}, '' 403 | if aggregate: 404 | results.update(metric.compute(y_pred, y_true)) 405 | results_str += f"Average {metric.name}: {results[metric.agg_metric_field]:.3f}\n" 406 | g = grouper.metadata_to_group(metadata) 407 | group_results = metric.compute_group_wise(y_pred, y_true, g, grouper.n_groups) 408 | for group_idx in range(grouper.n_groups): 409 | group_str = grouper.group_field_str(group_idx) 410 | group_metric = group_results[metric.group_metric_field(group_idx)] 411 | group_counts = group_results[metric.group_count_field(group_idx)] 412 | results[f'{metric.name}_{group_str}'] = group_metric 413 | results[f'count_{group_str}'] = group_counts 414 | if group_results[metric.group_count_field(group_idx)] == 0: 415 | continue 416 | results_str += ( 417 | f' 
{grouper.group_str(group_idx)} ' 418 | f"[n = {group_results[metric.group_count_field(group_idx)]:6.0f}]:\t" 419 | f"{metric.name} = {group_results[metric.group_metric_field(group_idx)]:5.3f}\n") 420 | results[f'{metric.worst_group_metric_field}'] = group_results[f'{metric.worst_group_metric_field}'] 421 | results_str += f"Worst-group {metric.name}: {group_results[metric.worst_group_metric_field]:.3f}\n" 422 | return results, results_str 423 | 424 | 425 | class WILDSSubset(WILDSDataset): 426 | def __init__(self, dataset, indices, transform): 427 | """ 428 | This acts like torch.utils.data.Subset, but on WILDSDatasets. 429 | We pass in transform explicitly because it can potentially vary at 430 | training vs. test time, if we're using data augmentation. 431 | """ 432 | self.dataset = dataset 433 | self.indices = indices 434 | inherited_attrs = ['_dataset_name', '_data_dir', '_collate', 435 | '_split_scheme', '_split_dict', '_split_names', 436 | '_y_size', '_n_classes', 437 | '_metadata_fields', '_metadata_map'] 438 | for attr_name in inherited_attrs: 439 | if hasattr(dataset, attr_name): 440 | setattr(self, attr_name, getattr(dataset, attr_name)) 441 | self.transform = transform 442 | 443 | def __getitem__(self, idx): 444 | x, y, metadata = self.dataset[self.indices[idx]] 445 | if self.transform is not None: 446 | x = self.transform(x) 447 | return x, y, metadata 448 | 449 | def __len__(self): 450 | return len(self.indices) 451 | 452 | @property 453 | def split_array(self): 454 | return self.dataset._split_array[self.indices] 455 | 456 | @property 457 | def y_array(self): 458 | return self.dataset._y_array[self.indices] 459 | 460 | @property 461 | def metadata_array(self): 462 | return self.dataset.metadata_array[self.indices] 463 | 464 | def eval(self, y_pred, y_true, metadata): 465 | return self.dataset.eval(y_pred, y_true, metadata) 466 | --------------------------------------------------------------------------------