├── .gitignore ├── LICENSE.md ├── README.md ├── data ├── dataset.py └── wilds_datasets.py ├── main_erm.py ├── models ├── miro_nets.py ├── resnet.py └── timm_model_wrapper.py ├── requirements.yml ├── scripts ├── domain_net_duplicates.txt ├── download.py └── download_models.sh └── util.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | model_checkpoints/**/** 3 | scc 4 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Piotr Teterwak 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ERM++: An Improved Baseline for Domain Generalization 2 | 3 | Official PyTorch implementation of ERM++: An Improved Baseline for Domain Generalization 4 | 5 | Piotr Teterwak, Kuniaki Saito, Theodoros Tsiligkaridis, Kate Saenko, Bryan A. Plummer 6 | 7 | 8 | 9 | 10 | ## Installation 11 | 12 | ### Dependencies 13 | 14 | ```sh 15 | conda env create -f requirements.yml 16 | ``` 17 | 18 | ### Dataset Download 19 | 20 | 21 | ```sh 22 | python -m scripts.download --data_dir=/my/datasets/path 23 | 24 | #For PACS 25 | git clone https://github.com/MachineLearning2020/Homework3-PACS 26 | 27 | ``` 28 | 29 | ### Model Download 30 | ```sh 31 | cd scripts 32 | bash download_models.sh 33 | 34 | ``` 35 | 36 | This downloads all models except for the MEAL distilled models. To download those, 37 | please see the MEAL GitHub [repository](https://github.com/szq0214/MEAL-V2). 38 | 39 | ### Data Path Specification 40 | 41 | Modify the data paths at the top of ```data/dataset.py```. 42 | 43 | ## Running ERM++ 44 | 45 | An example command is given below; it splits off 20% of the training data for validation.
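In the commands below, `<save_name>` and `<save_dir>` are placeholders for a run name and a checkpoint directory of your choice. `--train-val-split 0.8` keeps 80% of each training domain for training and holds out the remaining 20% for validation (see `_construct_dataset_helper` in `data/dataset.py`); `--linear-steps 500` updates only the final linear classifier for the first 500 steps before fine-tuning the whole network; and `--sma` enables weight averaging via the `MovingAvg` wrapper imported from `util.py` (not included in this listing), which appears to start averaging at the iteration given by `--sma-start-iter`. `--save-freq 1000` runs validation and writes a checkpoint every 1000 steps.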
46 | ```sh 47 | python main_erm.py --save_name <save_name> --dataset domainnet --training_data "clipart infograph real quickdraw sketch" --validation_data "clipart infograph real quickdraw sketch" --sma --save_dir <save_dir> --steps 60000 --train-val-split 0.8 --lr 5e-5 --save-freq 1000 --linear-steps 500 --sma-start-iter 600 --arch resnet_timm_augmix 48 | ``` 49 | 50 | Then, find the number of steps corresponding to the highest validation accuracy (printed in the log), and retrain on the full training data: 51 | 52 | ```sh 53 | python main_erm.py --save_name <save_name> --dataset domainnet --training_data "clipart infograph real quickdraw sketch" --validation_data painting --sma --save_dir <save_dir> --steps 60000 --lr 5e-5 --save-freq 1000 --linear-steps 500 --sma-start-iter 600 --arch resnet_timm_augmix 54 | ``` 55 | 56 | 57 | To see the domain names for each dataset, see the ```data/dataset.py``` file and search for the transform_dict variables. 58 | 59 | 60 | 61 | ## License and Acknowledgements 62 | 63 | This project is released under the MIT license, included [here](./LICENSE.md). 64 | 65 | This project includes some code from [facebookresearch/DomainBed](https://github.com/facebookresearch/DomainBed) (MIT license), [kakaobrain/miro](https://github.com/kakaobrain/miro) (MIT license), and [salesforce/ensemble-of-averages](https://github.com/salesforce/ensemble-of-averages). The structure and some text of the README are borrowed from the MIRO repository. 66 | -------------------------------------------------------------------------------- /data/dataset.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | 4 | from data.wilds_datasets import get_fmow 5 | 6 | 7 | import torch 8 | import torchvision.datasets as datasets 9 | 10 | 11 | TERRAINCOGNITA_PATH = "/projectnb/ivc-ml/piotrt/data/terra_incognita/terra_incognita/" 12 | DOMAINNET_PATH = "/projectnb/ivc-ml/piotrt/data/domainnet" 13 | OFFICEHOME_PATH = "/projectnb/ivc-ml/piotrt/data/office_home" 14 | PACS_PATH = "/projectnb/ivc-ml/piotrt/data/PACS" 15 | VLCS_PATH = "/projectnb/ivc-ml/piotrt/data/VLCS" 16 | 17 | 18 | def _construct_dataset_helper(args, dataset_dict, train_transform, val_transform): 19 | train_datasets = [] 20 | val_datasets = [] 21 | train_dataset_lengths = [] 22 | test_dataset = None 23 | 24 | for d in args.training_data: 25 | train_datasets.append(dataset_dict[d]) 26 | train_dataset_lengths.append(len(dataset_dict[d])) 27 | 28 | for d in args.validation_data: 29 | val_datasets.append(dataset_dict[d]) 30 | 31 | if args.train_val_split > 0: 32 | datasets_split_train = [] 33 | datasets_split_val = [] 34 | for d in train_datasets: 35 | lengths = [int(len(d) * args.train_val_split)] 36 | lengths.append(len(d) - lengths[0]) 37 | train_split, val_split = torch.utils.data.random_split( 38 | d, lengths, torch.Generator().manual_seed(42) 39 | ) 40 | train_split.dataset = copy.copy(d) 41 | train_split.dataset.transform = train_transform 42 | datasets_split_train.append(train_split) 43 | for idx, d in enumerate(val_datasets): 44 | lengths = [int(len(d) * args.train_val_split)] 45 | lengths.append(len(d) - lengths[0]) 46 | train_split, val_split = torch.utils.data.random_split( 47 | d, lengths, torch.Generator().manual_seed(42) 48 | ) 49 | val_split.dataset.transform = val_transform 50 | datasets_split_val.append(val_split) 51 | 52 | train_datasets = datasets_split_train 53 | test_dataset = datasets_split_val 54 | 55 | else: 56 | 57 | test_dataset = val_datasets 58 | 59 | return train_datasets,
test_dataset 60 | 61 | 62 | def construct_dataset(args, train_transform, val_transform): 63 | 64 | if args.dataset == "domainnet": 65 | 66 | num_classes = 345 67 | 68 | transform_dict = { 69 | "sketch": train_transform, 70 | "real": train_transform, 71 | "clipart": train_transform, 72 | "infograph": train_transform, 73 | "quickdraw": train_transform, 74 | "painting": train_transform, 75 | } 76 | 77 | for d in args.validation_data: 78 | transform_dict[d] = val_transform 79 | 80 | sketch_dataset = datasets.ImageFolder( 81 | os.path.join(DOMAINNET_PATH, "sketch"), 82 | transform=transform_dict["sketch"], 83 | ) 84 | real_dataset = datasets.ImageFolder( 85 | os.path.join(DOMAINNET_PATH, "real"), 86 | transform=transform_dict["real"], 87 | ) 88 | clipart_dataset = datasets.ImageFolder( 89 | os.path.join(DOMAINNET_PATH, "clipart"), 90 | transform=transform_dict["clipart"], 91 | ) 92 | infograph_dataset = datasets.ImageFolder( 93 | os.path.join(DOMAINNET_PATH, "infograph"), 94 | transform=transform_dict["infograph"], 95 | ) 96 | quickdraw_dataset = datasets.ImageFolder( 97 | os.path.join(DOMAINNET_PATH, "quickdraw"), 98 | transform=transform_dict["quickdraw"], 99 | ) 100 | painting_dataset = datasets.ImageFolder( 101 | os.path.join(DOMAINNET_PATH, "painting"), 102 | transform=transform_dict["painting"], 103 | ) 104 | 105 | dataset_dict = { 106 | "sketch": sketch_dataset, 107 | "real": real_dataset, 108 | "clipart": clipart_dataset, 109 | "infograph": infograph_dataset, 110 | "quickdraw": quickdraw_dataset, 111 | "painting": painting_dataset, 112 | } 113 | 114 | elif args.dataset == "terraincognita": 115 | 116 | num_classes = 10 117 | 118 | transform_dict = { 119 | "location_100": train_transform, 120 | "location_38": train_transform, 121 | "location_43": train_transform, 122 | "location_46": train_transform, 123 | } 124 | 125 | for d in args.validation_data: 126 | transform_dict[d] = val_transform 127 | 128 | location_100_dataset = datasets.ImageFolder( 129 | os.path.join(TERRAINCOGNITA_PATH, "location_100"), 130 | transform=transform_dict["location_100"], 131 | ) 132 | location_38_dataset = datasets.ImageFolder( 133 | os.path.join(TERRAINCOGNITA_PATH, "location_38"), 134 | transform=transform_dict["location_38"], 135 | ) 136 | location_43_dataset = datasets.ImageFolder( 137 | os.path.join(TERRAINCOGNITA_PATH, "location_43"), 138 | transform=transform_dict["location_43"], 139 | ) 140 | location_46_dataset = datasets.ImageFolder( 141 | os.path.join(TERRAINCOGNITA_PATH, "location_46"), 142 | transform=transform_dict["location_46"], 143 | ) 144 | 145 | dataset_dict = { 146 | "location_100": location_100_dataset, 147 | "location_38": location_38_dataset, 148 | "location_43": location_43_dataset, 149 | "location_46": location_46_dataset, 150 | } 151 | 152 | elif args.dataset == "officehome": 153 | 154 | num_classes = 65 155 | 156 | transform_dict = { 157 | "art": train_transform, 158 | "clipart": train_transform, 159 | "product": train_transform, 160 | "real": train_transform, 161 | } 162 | 163 | for d in args.validation_data: 164 | transform_dict[d] = val_transform 165 | 166 | art_dataset = datasets.ImageFolder( 167 | os.path.join(OFFICEHOME_PATH, "Art"), 168 | transform=transform_dict["art"], 169 | ) 170 | clipart_dataset = datasets.ImageFolder( 171 | os.path.join(OFFICEHOME_PATH, "Clipart"), 172 | transform=transform_dict["clipart"], 173 | ) 174 | product_dataset = datasets.ImageFolder( 175 | os.path.join(OFFICEHOME_PATH, "Product"), 176 | transform=transform_dict["product"], 177 | ) 178 | 
real_dataset = datasets.ImageFolder( 179 | os.path.join(OFFICEHOME_PATH, "Real"), 180 | transform=transform_dict["real"], 181 | ) 182 | 183 | dataset_dict = { 184 | "art": art_dataset, 185 | "clipart": clipart_dataset, 186 | "product": product_dataset, 187 | "real": real_dataset, 188 | } 189 | 190 | elif args.dataset == "pacs": 191 | 192 | num_classes = 7 193 | 194 | transform_dict = { 195 | "art_painting": train_transform, 196 | "cartoon": train_transform, 197 | "photo": train_transform, 198 | "sketch": train_transform, 199 | } 200 | 201 | for d in args.validation_data: 202 | transform_dict[d] = val_transform 203 | 204 | art_painting_dataset = datasets.ImageFolder( 205 | os.path.join(PACS_PATH, "art_painting"), 206 | transform=transform_dict["art_painting"], 207 | ) 208 | cartoon_dataset = datasets.ImageFolder( 209 | os.path.join(PACS_PATH, "cartoon"), 210 | transform=transform_dict["cartoon"], 211 | ) 212 | photo_dataset = datasets.ImageFolder( 213 | os.path.join(PACS_PATH, "photo"), 214 | transform=transform_dict["photo"], 215 | ) 216 | sketch_dataset = datasets.ImageFolder( 217 | os.path.join(PACS_PATH, "sketch"), 218 | transform=transform_dict["sketch"], 219 | ) 220 | 221 | dataset_dict = { 222 | "art_painting": art_painting_dataset, 223 | "cartoon": cartoon_dataset, 224 | "photo": photo_dataset, 225 | "sketch": sketch_dataset, 226 | } 227 | 228 | elif args.dataset == "vlcs": 229 | 230 | num_classes = 5 231 | 232 | transform_dict = { 233 | "caltech101": train_transform, 234 | "labelme": train_transform, 235 | "sun09": train_transform, 236 | "voc2007": train_transform, 237 | } 238 | 239 | for d in args.validation_data: 240 | transform_dict[d] = val_transform 241 | 242 | caltech101_dataset = datasets.ImageFolder( 243 | os.path.join(VLCS_PATH, "Caltech101"), 244 | transform=transform_dict["caltech101"], 245 | ) 246 | labelme_dataset = datasets.ImageFolder( 247 | os.path.join(VLCS_PATH, "LabelMe"), 248 | transform=transform_dict["labelme"], 249 | ) 250 | sun09_dataset = datasets.ImageFolder( 251 | os.path.join(VLCS_PATH, "SUN09"), 252 | transform=transform_dict["sun09"], 253 | ) 254 | voc2007_dataset = datasets.ImageFolder( 255 | os.path.join(VLCS_PATH, "VOC2007"), 256 | transform=transform_dict["voc2007"], 257 | ) 258 | 259 | dataset_dict = { 260 | "caltech101": caltech101_dataset, 261 | "labelme": labelme_dataset, 262 | "sun09": sun09_dataset, 263 | "voc2007": voc2007_dataset, 264 | } 265 | 266 | elif args.dataset == "wilds_fmow": 267 | 268 | num_classes = 62 269 | 270 | train_datasets = [] 271 | val_datasets = [] 272 | train_dataset_lengths = [] 273 | test_dataset = None 274 | 275 | datasets_list = get_fmow(train_transform, val_transform, args.validation_data) 276 | 277 | region0_dataset = datasets_list[0] 278 | region1_dataset = datasets_list[1] 279 | region2_dataset = datasets_list[2] 280 | region3_dataset = datasets_list[3] 281 | region4_dataset = datasets_list[4] 282 | region5_dataset = datasets_list[5] 283 | 284 | dataset_dict = { 285 | "region0": region0_dataset, 286 | "region1": region1_dataset, 287 | "region2": region2_dataset, 288 | "region3": region3_dataset, 289 | "region4": region4_dataset, 290 | "region5": region5_dataset, 291 | } 292 | 293 | train_dataset, test_dataset = _construct_dataset_helper( 294 | args, dataset_dict, train_transform, val_transform 295 | ) 296 | 297 | return train_dataset, test_dataset, num_classes 298 | -------------------------------------------------------------------------------- /data/wilds_datasets.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from PIL import Image 4 | from wilds.datasets.fmow_dataset import FMoWDataset 5 | 6 | 7 | def metadata_values(wilds_dataset, metadata_name): 8 | metadata_index = wilds_dataset.metadata_fields.index(metadata_name) 9 | metadata_vals = wilds_dataset.metadata_array[:, metadata_index] 10 | return sorted(list(set(metadata_vals.view(-1).tolist()))) 11 | 12 | 13 | class WILDSEnvironment: 14 | def __init__(self, wilds_dataset, metadata_name, metadata_value, transform=None): 15 | self.name = metadata_name + "_" + str(metadata_value) 16 | 17 | metadata_index = wilds_dataset.metadata_fields.index(metadata_name) 18 | metadata_array = wilds_dataset.metadata_array 19 | subset_indices = torch.where( 20 | metadata_array[:, metadata_index] == metadata_value 21 | )[0] 22 | 23 | self.dataset = wilds_dataset 24 | self.indices = subset_indices 25 | self.transform = transform 26 | 27 | def __getitem__(self, i): 28 | x = self.dataset.get_input(self.indices[i]) 29 | if type(x).__name__ != "Image": 30 | x = Image.fromarray(x) 31 | 32 | y = self.dataset.y_array[self.indices[i]] 33 | if self.transform is not None: 34 | x = self.transform(x) 35 | return x, y 36 | 37 | def __len__(self): 38 | return len(self.indices) 39 | 40 | 41 | def get_fmow(train_transform, val_transform, validation_data): 42 | 43 | validation_data_dict = { 44 | "region0": 0, 45 | "region1": 1, 46 | "region2": 2, 47 | "region3": 3, 48 | "region4": 4, 49 | "region5": 5, 50 | } 51 | 52 | validation_data_list = [validation_data_dict[d] for d in validation_data] 53 | 54 | dataset = FMoWDataset(root_dir="/projectnb/ivc-ml/piotrt/data/WILDS") 55 | metadata_name = "region" 56 | datasets = [] 57 | for i, metadata_value in enumerate(metadata_values(dataset, metadata_name)): 58 | if i not in validation_data_list: 59 | env_transform = train_transform 60 | else: 61 | env_transform = val_transform 62 | 63 | env_dataset = WILDSEnvironment( 64 | dataset, metadata_name, metadata_value, env_transform 65 | ) 66 | 67 | datasets.append(env_dataset) 68 | 69 | return datasets 70 | -------------------------------------------------------------------------------- /main_erm.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 3 | import argparse 4 | import os 5 | import random 6 | import shutil 7 | import time 8 | import warnings 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.parallel 13 | import torch.backends.cudnn as cudnn 14 | import torch.optim 15 | import torch.utils.data 16 | import torch.utils.data.distributed 17 | import torchvision.transforms as transforms 18 | from data.dataset import construct_dataset 19 | from models.resnet import resnet18, resnet50, wide_resnet50_2 20 | from models.timm_model_wrapper import TimmWrapper 21 | from util import MovingAvg, AverageMeter, ProgressMeter, accuracy 22 | import timm 23 | import numpy as np 24 | 25 | 26 | from PIL import ImageFile 27 | 28 | ImageFile.LOAD_TRUNCATED_IMAGES = True 29 | 30 | 31 | model_names = ["regnet", "resnet", "resnet_timm_augmix", "resnet_timm_a1", "meal_v2"] 32 | dataset_names = [ 33 | "domainnet", 34 | "terraincognita", 35 | "officehome", 36 | "pacs", 37 | "vlcs", 38 | "wilds_fmow", 39 | ] 40 | 41 | parser = argparse.ArgumentParser(description="ERM++ training") 42 | parser.add_argument( 43 | "-a", 44 | "--arch", 45 | metavar="ARCH", 46 | default="resnet50", 47 | choices=model_names, 48 | help="model architecture: " + " | ".join(model_names) + " (default: resnet50)", 49 | ) 50 | parser.add_argument( 51 | "--dataset", default="domainnet", choices=dataset_names, help="which dataset to use" 52 | ) 53 | parser.add_argument( 54 | "-j", 55 | "--workers", 56 | default=32, 57 | type=int, 58 | metavar="N", 59 | help="number of data loading workers (default: 32)", 60 | ) 61 | parser.add_argument( 62 | "--steps", default=5000, type=int, metavar="N", help="number of total steps to run" 63 | ) 64 | parser.add_argument( 65 | "--linear-steps", 66 | default=-1, 67 | type=int, 68 | metavar="N", 69 | help="number of total steps to run", 70 | ) 71 | parser.add_argument( 72 | "--sma-start-iter", 73 | default=100, 74 | type=int, 75 | metavar="N", 76 | help="Where to start model averaging.", 77 | ) 78 | parser.add_argument( 79 | "--accum-iter", 80 | default=1, 81 | type=int, 82 | metavar="N", 83 | help="number of steps between updates", 84 | ) 85 | parser.add_argument( 86 | "--start-epoch", 87 | default=0, 88 | type=int, 89 | metavar="N", 90 | help="manual epoch number (useful on restarts)", 91 | ) 92 | parser.add_argument( 93 | "-b", 94 | "--batch-size", 95 | default=32, 96 | type=int, 97 | metavar="N", 98 | help="mini-batch size (default: 256), this is the total " 99 | "batch size of all GPUs on the current node when " 100 | "using Data Parallel or Distributed Data Parallel", 101 | ) 102 | parser.add_argument( 103 | "--lr", 104 | "--learning-rate", 105 | default=5e-5, 106 | type=float, 107 | metavar="LR", 108 | help="initial learning rate", 109 | dest="lr", 110 | ) 111 | parser.add_argument( 112 | "--miro-weight", 113 | default=0.1, 114 | type=float, 115 | metavar="MW", 116 | help="initial learning rate", 117 | dest="miro_weight", 118 | ) 119 | parser.add_argument( 120 | "--wd", 121 | "--weight-decay", 122 | default=0.0, 123 | type=float, 124 | metavar="W", 125 | help="weight decay (default: 0.)", 126 | dest="weight_decay", 127 | ) 128 | parser.add_argument( 129 | "-p", 130 | "--print-freq", 131 | default=10, 132 | type=int, 133 | metavar="N", 134 | help="print frequency (default: 10)", 135 | ) 136 | parser.add_argument( 137 | "--eval-path", 138 | default="", 139 | type=str, 140 | metavar="PATH", 141 | help="path to latest checkpoint (default: none)", 142 | ) 143 | parser.add_argument( 144 | "-e", 145 | "--evaluate", 
146 | dest="evaluate", 147 | action="store_true", 148 | help="evaluate model on validation set", 149 | ) 150 | parser.add_argument( 151 | "--seed", default=None, type=int, help="seed for initializing training. " 152 | ) 153 | parser.add_argument( 154 | "--training_data", 155 | default=["sketch", "real"], 156 | type=str, 157 | nargs="*", 158 | help="training subsets", 159 | ) 160 | parser.add_argument( 161 | "--validation_data", 162 | default=["painting"], 163 | type=str, 164 | nargs="*", 165 | help="testing subsets", 166 | ) 167 | 168 | 169 | parser.add_argument( 170 | "--save_name", default="", type=str, help="name of saved checkpoint" 171 | ) 172 | parser.add_argument( 173 | "--save-dir", 174 | default="model_checkpoints/", 175 | type=str, 176 | help="name of saved checkpoint", 177 | ) 178 | parser.add_argument("--freeze-bn", dest="freeze_bn", action="store_true") 179 | parser.add_argument("--miro", dest="miro", action="store_true") 180 | parser.add_argument( 181 | "--save-freq", default=-1, type=int, help="how often to save checkpoints in steps" 182 | ) 183 | parser.add_argument( 184 | "--train-val-split", default=-1, type=float, help="how much to split train val" 185 | ) 186 | 187 | parser.add_argument("--pretrained", dest="pretrained", action="store_true") 188 | parser.add_argument("--sma", dest="sma", action="store_true") 189 | parser.set_defaults(pretrained=False) 190 | parser.set_defaults(sma=False) 191 | parser.set_defaults(freeze_bn=False) 192 | parser.set_defaults(miro=False) 193 | 194 | best_acc1 = 0 195 | best_steps = 0 196 | 197 | if torch.cuda.is_available(): 198 | device = "cuda" 199 | else: 200 | device = "cpu" 201 | 202 | 203 | def main(): 204 | args = parser.parse_args() 205 | 206 | if args.seed is not None: 207 | random.seed(args.seed) 208 | torch.manual_seed(args.seed) 209 | cudnn.deterministic = True 210 | warnings.warn( 211 | "You have chosen to seed training. " 212 | "This will turn on the CUDNN deterministic setting, " 213 | "which can slow down your training considerably! " 214 | "You may see unexpected behavior when restarting " 215 | "from checkpoints." 
216 | ) 217 | 218 | global best_acc1 219 | global best_steps 220 | 221 | train_transform = transforms.Compose( 222 | [ 223 | transforms.RandomResizedCrop(224, scale=(0.7, 1.0)), 224 | transforms.RandomHorizontalFlip(), 225 | transforms.ColorJitter(0.3, 0.3, 0.3, 0.3), 226 | transforms.RandomGrayscale(), 227 | transforms.ToTensor(), 228 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 229 | ] 230 | ) 231 | 232 | val_transform = transforms.Compose( 233 | [ 234 | transforms.Resize((224, 224)), 235 | transforms.ToTensor(), 236 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 237 | ] 238 | ) 239 | 240 | train_datasets, test_dataset, num_classes = construct_dataset( 241 | args, train_transform, val_transform 242 | ) 243 | 244 | batch_size_list = [args.batch_size] * len(train_datasets) 245 | 246 | train_loader = [ 247 | torch.utils.data.DataLoader( 248 | train_dataset, 249 | batch_size=bs, 250 | shuffle=True, 251 | num_workers=args.workers, 252 | pin_memory=True, 253 | drop_last=True, 254 | ) 255 | for bs, train_dataset in zip(batch_size_list, train_datasets) 256 | ] 257 | 258 | val_loader = torch.utils.data.DataLoader( 259 | torch.utils.data.ConcatDataset(test_dataset), 260 | batch_size=args.batch_size, 261 | shuffle=True, 262 | num_workers=args.workers, 263 | pin_memory=True, 264 | drop_last=False, 265 | ) 266 | 267 | # create model 268 | print("=> creating model '{}'".format(args.arch)) 269 | if args.arch == "resnet50": 270 | model = resnet50( 271 | pretrained=args.pretrained, 272 | num_classes=num_classes, 273 | freeze_bn=args.freeze_bn, 274 | projection_head=args.projection_head, 275 | ) 276 | elif args.arch == "resnet_timm_a1": 277 | model = TimmWrapper( 278 | timm.create_model( 279 | "resnet50", 280 | pretrained=False, 281 | num_classes=num_classes, 282 | features_only=args.miro, 283 | ), 284 | freeze_bn=args.freeze_bn, 285 | miro=args.miro, 286 | num_classes=num_classes, 287 | ) 288 | state_dict = { 289 | "model." + k.split("module.")[-1]: v 290 | for (k, v) in torch.load( 291 | "model_checkpoints/resnet_a1/resnet50_a1_0-14fe96d1.pth" 292 | ).items() 293 | } 294 | del state_dict["model.fc.weight"] 295 | del state_dict["model.fc.bias"] 296 | msg = model.load_state_dict(state_dict, strict=False) 297 | assert set(msg.missing_keys) == {"model.fc.weight", "model.fc.bias"} 298 | elif args.arch == "resnet_timm_augmix": 299 | model = TimmWrapper( 300 | timm.create_model( 301 | "resnet50", 302 | pretrained=False, 303 | num_classes=num_classes, 304 | features_only=args.miro, 305 | ), 306 | freeze_bn=args.freeze_bn, 307 | miro=args.miro, 308 | num_classes=num_classes, 309 | ) 310 | state_dict = { 311 | "model." 
+ k.split("module.")[-1]: v 312 | for (k, v) in torch.load( 313 | "model_checkpoints/augmix/resnet50_ram-a26f946b.pth" 314 | ).items() 315 | } 316 | del state_dict["model.fc.weight"] 317 | del state_dict["model.fc.bias"] 318 | msg = model.load_state_dict(state_dict, strict=False) 319 | assert set(msg.missing_keys) == {"model.fc.weight", "model.fc.bias"} 320 | elif args.arch == "meal_v2": 321 | model = timm.create_model("resnet50", pretrained=False, num_classes=num_classes) 322 | state_dict = { 323 | k.split("module.")[-1]: v 324 | for (k, v) in torch.load( 325 | "model_checkpoints/meal_v2/MEALV2_ResNet50_224.pth" 326 | ).items() 327 | } 328 | del state_dict["fc.weight"] 329 | del state_dict["fc.bias"] 330 | msg = model.load_state_dict(state_dict, strict=False) 331 | assert set(msg.missing_keys) == {"model.fc.weight", "model.fc.bias"} 332 | else: 333 | raise RuntimeError("Invalid architechture specified") 334 | 335 | if args.miro: 336 | featurizer = TimmWrapper( 337 | timm.create_model( 338 | "resnet50", pretrained=True, num_classes=num_classes, features_only=True 339 | ), 340 | freeze_bn=True, 341 | miro=args.miro, 342 | num_classes=num_classes, 343 | freeze_all=True, 344 | ).to(device) 345 | shapes = miro_nets.get_shapes(featurizer, (3, 224, 224)) 346 | mean_encoders = nn.ModuleList( 347 | [miro_nets.MeanEncoder(shape).to(device) for shape in shapes] 348 | ) 349 | var_encoders = nn.ModuleList( 350 | [miro_nets.VarianceEncoder(shape).to(device) for shape in shapes] 351 | ) 352 | else: 353 | featurizer = None 354 | mean_encoders = None 355 | var_encoders = None 356 | 357 | print(model) 358 | model_parameters = filter(lambda p: p.requires_grad, model.parameters()) 359 | params = sum([np.prod(p.size()) for p in model_parameters]) 360 | print("Num params:{}".format(params)) 361 | 362 | model = torch.nn.DataParallel(model).to(device) 363 | 364 | criterion = nn.CrossEntropyLoss().to(device) 365 | 366 | if args.miro: 367 | backbone_parameters = [ 368 | {"params": model.parameters()}, 369 | {"params": mean_encoders.parameters(), "lr": args.lr * 10}, 370 | {"params": var_encoders.parameters(), "lr": args.lr * 10}, 371 | ] 372 | else: 373 | backbone_parameters = model.parameters() 374 | 375 | optimizer = torch.optim.Adam( 376 | backbone_parameters, args.lr, weight_decay=args.weight_decay 377 | ) 378 | 379 | linear_parameters = [] 380 | 381 | for n, p in model.named_parameters(): 382 | if "fc" in n: 383 | linear_parameters.append(p) 384 | 385 | linear_optimizer = torch.optim.Adam( 386 | linear_parameters, args.lr, weight_decay=args.weight_decay 387 | ) 388 | 389 | # Load model for eval 390 | if args.eval_path: 391 | if os.path.isfile(args.eval_path): 392 | print("=> loading checkpoint '{}'".format(args.eval_path)) 393 | checkpoint = torch.load(args.eval_path) 394 | model.load_state_dict(checkpoint["state_dict"]) 395 | print( 396 | "=> loaded checkpoint '{}' (epoch {})".format( 397 | args.eval_path, checkpoint["epoch"] 398 | ) 399 | ) 400 | else: 401 | print("=> no checkpoint found at '{}'".format(args.eval_path)) 402 | 403 | cudnn.benchmark = True 404 | 405 | model = MovingAvg(model, args.sma, args.sma_start_iter) 406 | 407 | if args.evaluate: 408 | validate(val_loader, model, criterion, args, steps) 409 | return 410 | 411 | int(args.steps / len(train_loader)) 412 | epoch = 0 413 | steps = 0 414 | save_iterate = 0 415 | 416 | if args.evaluate: 417 | validate(val_loader, model, criterion, args, steps) 418 | return 419 | 420 | 421 | while True: 422 | if steps > args.steps: 423 | break 424 | 425 | 
steps, save_iterate = train( 426 | train_loader, 427 | val_loader, 428 | model, 429 | criterion, 430 | optimizer, 431 | linear_optimizer, 432 | epoch, 433 | args, 434 | steps, 435 | save_iterate, 436 | featurizer=featurizer, 437 | mean_encoders=mean_encoders, 438 | var_encoders=var_encoders, 439 | ) 440 | epoch = epoch + 1 441 | 442 | acc1 = validate(val_loader, model, criterion, args, steps) 443 | 444 | # remember best acc@1 and save checkpoint 445 | is_best = acc1 > best_acc1 446 | best_acc1 = max(acc1, best_acc1) 447 | 448 | save_name = "{}_{}".format(args.save_name, save_iterate) 449 | save_checkpoint( 450 | { 451 | "epoch": epoch, 452 | "arch": args.arch, 453 | "state_dict": model.network_sma.state_dict(), 454 | "best_acc1": best_acc1, 455 | "optimizer": optimizer.state_dict(), 456 | }, 457 | is_best, 458 | save_name=save_name, 459 | save_dir=args.save_dir, 460 | save_iterate=save_iterate, 461 | args=args, 462 | ) 463 | 464 | 465 | def train( 466 | train_loader, 467 | val_loader, 468 | model, 469 | criterion, 470 | optimizer, 471 | linear_optimizer, 472 | epoch, 473 | args, 474 | steps, 475 | save_iterate, 476 | featurizer, 477 | mean_encoders, 478 | var_encoders, 479 | ): 480 | global best_acc1 481 | global best_steps 482 | batch_time = AverageMeter("Time", ":6.3f") 483 | data_time = AverageMeter("Data", ":6.3f") 484 | losses = AverageMeter("Loss", ":.4e") 485 | 486 | top1 = AverageMeter("Acc@1", ":6.2f") 487 | top5 = AverageMeter("Acc@5", ":6.2f") 488 | 489 | # switch to train mode 490 | model.network.train() 491 | model.network_sma.train() 492 | 493 | train_loader_epoch = train_loader.copy() 494 | 495 | train_loader_main_idx = np.argmax([len(d) for d in train_loader_epoch]) 496 | train_loader_main = train_loader_epoch.pop(train_loader_main_idx) 497 | 498 | aux_iter_list = [iter(aux_loader) for aux_loader in train_loader_epoch] 499 | end = time.time() 500 | 501 | progress = ProgressMeter( 502 | len(train_loader_main), 503 | [batch_time, data_time, losses, top1, top5], 504 | prefix="Epoch: [{}]".format(epoch), 505 | ) 506 | 507 | for i, (images, target) in enumerate(train_loader_main): 508 | if steps > args.steps: 509 | return steps, save_iterate 510 | if steps > args.linear_steps: 511 | selected_optimizer = optimizer 512 | else: 513 | selected_optimizer = linear_optimizer 514 | 515 | steps = steps + 1 516 | # measure data loading time 517 | aux_images_list = [] 518 | aux_target_list = [] 519 | for idx, aux_iter in enumerate(aux_iter_list): 520 | try: 521 | aux_images, aux_target = next(aux_iter) 522 | except StopIteration: 523 | aux_iter_list[idx] = iter(train_loader_epoch[idx]) 524 | aux_images, aux_target = next(aux_iter_list[idx]) 525 | aux_images_list.append(aux_images) 526 | aux_target_list.append(aux_target) 527 | 528 | images = images.to(device, non_blocking=True) 529 | for idx in range(len(aux_iter_list)): 530 | aux_images_list[idx] = aux_images_list[idx].to(device, non_blocking=True) 531 | target = target.cuda(non_blocking=True) 532 | for idx in range(len(aux_iter_list)): 533 | aux_target_list[idx] = aux_target_list[idx].to(device, non_blocking=True) 534 | 535 | images = torch.concat([images] + aux_images_list, dim=0) 536 | target = torch.concat([target] + aux_target_list) 537 | 538 | data_time.update(time.time() - end) 539 | 540 | # compute output 541 | if args.miro: 542 | with torch.no_grad(): 543 | _, pre_feats = featurizer.forward_features(images) 544 | output, inter_feats = model.network.module.forward_features(images) 545 | loss = criterion(output, target) 546 | 
reg_loss = 0.0 547 | for f, pre_f, mean_enc, var_enc in zip( 548 | inter_feats, pre_feats, mean_encoders, var_encoders 549 | ): 550 | # mutual information regularization 551 | mean = mean_enc(f) 552 | var = var_enc(f) 553 | vlb = (mean - pre_f).pow(2).div(var) + var.log() 554 | reg_loss += vlb.mean() / 2.0 555 | 556 | loss += reg_loss * args.miro_weight 557 | 558 | else: 559 | output = model.network(images) 560 | loss = criterion(output, target) 561 | # measure accuracy and record loss 562 | max_k = 5 563 | acc1, acc5 = accuracy(output, target, topk=(1, max_k)) 564 | losses.update(loss.item(), images.size(0)) 565 | top1.update(acc1[0], images.size(0)) 566 | top5.update(acc5[0], images.size(0)) 567 | 568 | # compute gradient and do SGD step 569 | loss = loss / args.accum_iter 570 | 571 | loss.backward() 572 | 573 | if ((i + 1) % args.accum_iter == 0) or (i + 1 == len(train_loader_main)): 574 | selected_optimizer.step() 575 | selected_optimizer.zero_grad() 576 | model.update_sma() 577 | if not args.freeze_bn: 578 | model.network_sma(images) 579 | 580 | # measure elapsed time 581 | batch_time.update(time.time() - end) 582 | end = time.time() 583 | 584 | if i % args.print_freq == 0: 585 | progress.display(i) 586 | 587 | if args.save_freq > 0 and (steps % args.save_freq == 0): 588 | 589 | # evaluate on validation set 590 | acc1 = validate(val_loader, model, criterion, args,steps) 591 | # switch to train mode 592 | model.network.train() 593 | model.network_sma.train() 594 | 595 | # remember best acc@1 and save checkpoint 596 | is_best = acc1 > best_acc1 597 | best_acc1 = max(acc1, best_acc1) 598 | if is_best: 599 | best_steps = steps 600 | save_checkpoint( 601 | { 602 | "epoch": epoch, 603 | "arch": args.arch, 604 | "state_dict": model.network_sma.state_dict(), 605 | "best_acc1": best_acc1, 606 | "optimizer": optimizer.state_dict(), 607 | }, 608 | is_best, 609 | save_name=args.save_name, 610 | save_dir=args.save_dir, 611 | args=args, 612 | ) 613 | 614 | 615 | return steps, save_iterate 616 | 617 | 618 | def validate(val_loader, model, criterion, args, steps): 619 | global best_steps 620 | global best_acc1 621 | batch_time = AverageMeter("Time", ":6.3f") 622 | losses = AverageMeter("Loss", ":.4e") 623 | 624 | top1 = AverageMeter("Acc@1", ":6.2f") 625 | top5 = AverageMeter("Acc@5", ":6.2f") 626 | progress = ProgressMeter( 627 | len(val_loader), [batch_time, losses, top1, top5], prefix="Test: " 628 | ) 629 | 630 | # switch to evaluate mode 631 | model.network.eval() 632 | model.network_sma.eval() 633 | end = time.time() 634 | for i, (images, target) in enumerate(val_loader): 635 | images = images.to(device, non_blocking=True) 636 | target = target.to(device, non_blocking=True) 637 | 638 | # compute output 639 | output = model.network_sma(images) 640 | loss = criterion(output, target) 641 | 642 | # measure accuracy and record loss 643 | max_k = 5 644 | acc1, acc5 = accuracy(output, target, topk=(1, max_k)) 645 | 646 | losses.update(loss.item(), images.size(0)) 647 | 648 | top1.update(acc1[0], images.size(0)) 649 | top5.update(acc5[0], images.size(0)) 650 | 651 | # measure elapsed time 652 | batch_time.update(time.time() - end) 653 | end = time.time() 654 | 655 | if i % args.print_freq == 0: 656 | progress.display(i) 657 | 658 | # remember best acc@1 and save checkpoint 659 | is_best = top1.avg > best_acc1 660 | best_acc1 = max(top1.avg, best_acc1) 661 | if is_best: 662 | best_steps = steps 663 | 664 | 665 | progress.display_summary() 666 | print("Best acc steps: {}".format(best_steps)) 667 | 668 | 
return top1.avg 669 | 670 | 671 | def save_checkpoint(state, is_best, save_name, save_dir, args=None): 672 | filename = save_dir + "checkpoint_" + str(save_name) + ".pth.tar" 673 | torch.save(state, filename) 674 | if is_best: 675 | best_filename = save_dir + "model_best_" + str(args.save_name) + ".pth.tar" 676 | shutil.copyfile(filename, best_filename) 677 | 678 | 679 | if __name__ == "__main__": 680 | main() 681 | -------------------------------------------------------------------------------- /models/miro_nets.py: -------------------------------------------------------------------------------- 1 | # https://github.com/kakaobrain/miro/blob/52384553e9745b1dbd974aaf8db3908e5757f245/domainbed/algorithms/miro.py#L141 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class MeanEncoder(nn.Module): 9 | """Identity function""" 10 | 11 | def __init__(self, shape): 12 | super().__init__() 13 | self.shape = shape 14 | 15 | def forward(self, x): 16 | return x 17 | 18 | 19 | class VarianceEncoder(nn.Module): 20 | """Bias-only model with diagonal covariance""" 21 | 22 | def __init__(self, shape, init=0.1, channelwise=True, eps=1e-5): 23 | super().__init__() 24 | self.shape = shape 25 | self.eps = eps 26 | 27 | init = (torch.as_tensor(init - eps).exp() - 1.0).log() 28 | b_shape = shape 29 | if channelwise: 30 | if len(shape) == 4: 31 | # [B, C, H, W] 32 | b_shape = (1, shape[1], 1, 1) 33 | elif len(shape) == 3: 34 | # CLIP-ViT: [H*W+1, B, C] 35 | b_shape = (1, 1, shape[2]) 36 | else: 37 | raise ValueError() 38 | 39 | self.b = nn.Parameter(torch.full(b_shape, init)) 40 | 41 | def forward(self, x): 42 | return F.softplus(self.b) + self.eps 43 | 44 | 45 | def get_shapes(model, input_shape): 46 | # get shape of intermediate features 47 | with torch.no_grad(): 48 | dummy = torch.rand(1, *input_shape).to(next(model.parameters()).device) 49 | _, feats = model.forward_features(dummy) 50 | shapes = [f.shape for f in feats] 51 | 52 | return shapes 53 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | from torch import Tensor 4 | import torch.nn.functional as F 5 | import torch.nn as nn 6 | 7 | # from .utils import load_state_dict_from_url 8 | from torch.hub import load_state_dict_from_url 9 | from typing import Type, Any, Callable, Union, List, Optional 10 | 11 | 12 | __all__ = [ 13 | "ResNet", 14 | "resnet18", 15 | "resnet34", 16 | "resnet50", 17 | "resnet101", 18 | "resnet152", 19 | "resnext50_32x4d", 20 | "resnext101_32x8d", 21 | "wide_resnet50_2", 22 | "wide_resnet101_2", 23 | ] 24 | 25 | 26 | model_urls = { 27 | "resnet18": "https://download.pytorch.org/models/resnet18-5c106cde.pth", 28 | "resnet34": "https://download.pytorch.org/models/resnet34-333f7ec4.pth", 29 | "resnet50": "https://download.pytorch.org/models/resnet50-19c8e357.pth", 30 | "resnet101": "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth", 31 | "resnet152": "https://download.pytorch.org/models/resnet152-b121ed2d.pth", 32 | "resnext50_32x4d": "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth", 33 | "resnext101_32x8d": "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth", 34 | "wide_resnet50_2": "https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth", 35 | "wide_resnet101_2": "https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth", 36 | } 37 | 38 | 39 | def conv3x3( 40 | 
in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1 41 | ) -> nn.Conv2d: 42 | """3x3 convolution with padding""" 43 | return nn.Conv2d( 44 | in_planes, 45 | out_planes, 46 | kernel_size=3, 47 | stride=stride, 48 | padding=dilation, 49 | groups=groups, 50 | bias=False, 51 | dilation=dilation, 52 | ) 53 | 54 | 55 | def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: 56 | """1x1 convolution""" 57 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 58 | 59 | 60 | class BasicBlock(nn.Module): 61 | expansion: int = 1 62 | 63 | def __init__( 64 | self, 65 | inplanes: int, 66 | planes: int, 67 | stride: int = 1, 68 | downsample: Optional[nn.Module] = None, 69 | groups: int = 1, 70 | base_width: int = 64, 71 | dilation: int = 1, 72 | norm_layer: Optional[Callable[..., nn.Module]] = None, 73 | ) -> None: 74 | super(BasicBlock, self).__init__() 75 | if norm_layer is None: 76 | norm_layer = nn.BatchNorm2d 77 | if groups != 1 or base_width != 64: 78 | raise ValueError("BasicBlock only supports groups=1 and base_width=64") 79 | if dilation > 1: 80 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 81 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 82 | self.conv1 = conv3x3(inplanes, planes, stride) 83 | self.bn1 = norm_layer(planes) 84 | self.relu = nn.ReLU(inplace=True) 85 | self.conv2 = conv3x3(planes, planes) 86 | self.bn2 = norm_layer(planes) 87 | self.downsample = downsample 88 | self.stride = stride 89 | 90 | def forward(self, x: Tensor) -> Tensor: 91 | identity = x 92 | 93 | out = self.conv1(x) 94 | out = self.bn1(out) 95 | out = self.relu(out) 96 | 97 | out = self.conv2(out) 98 | out = self.bn2(out) 99 | 100 | if self.downsample is not None: 101 | identity = self.downsample(x) 102 | 103 | out += identity 104 | out = self.relu(out) 105 | 106 | return out 107 | 108 | 109 | class Bottleneck(nn.Module): 110 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 111 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 112 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 113 | # This variant is also known as ResNet V1.5 and improves accuracy according to 114 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 
115 | 116 | expansion: int = 4 117 | 118 | def __init__( 119 | self, 120 | inplanes: int, 121 | planes: int, 122 | stride: int = 1, 123 | downsample: Optional[nn.Module] = None, 124 | groups: int = 1, 125 | base_width: int = 64, 126 | dilation: int = 1, 127 | norm_layer: Optional[Callable[..., nn.Module]] = None, 128 | ) -> None: 129 | super(Bottleneck, self).__init__() 130 | if norm_layer is None: 131 | norm_layer = nn.BatchNorm2d 132 | width = int(planes * (base_width / 64.0)) * groups 133 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 134 | self.conv1 = conv1x1(inplanes, width) 135 | self.bn1 = norm_layer(width) 136 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 137 | self.bn2 = norm_layer(width) 138 | self.conv3 = conv1x1(width, planes * self.expansion) 139 | self.bn3 = norm_layer(planes * self.expansion) 140 | self.relu = nn.ReLU(inplace=True) 141 | self.downsample = downsample 142 | self.stride = stride 143 | 144 | def forward(self, x: Tensor) -> Tensor: 145 | identity = x 146 | 147 | out = self.conv1(x) 148 | out = self.bn1(out) 149 | out = self.relu(out) 150 | 151 | out = self.conv2(out) 152 | out = self.bn2(out) 153 | out = self.relu(out) 154 | 155 | out = self.conv3(out) 156 | out = self.bn3(out) 157 | 158 | if self.downsample is not None: 159 | identity = self.downsample(x) 160 | 161 | out += identity 162 | out = self.relu(out) 163 | 164 | return out 165 | 166 | 167 | # https://github.com/yaoxufeng/PCL-Proxy-based-Contrastive-Learning-for-Domain-Generalization/blob/86b1d1bdddb2e93f234faff128d805a2f1931d7c/domainbed/networks.py 168 | class Identity(nn.Module): 169 | """An identity layer""" 170 | 171 | def __init__(self): 172 | super(Identity, self).__init__() 173 | 174 | def forward(self, x): 175 | return x 176 | 177 | 178 | class PCLResNet(nn.Module): 179 | def __init__(self, network, proj_dim=512, n_outputs=2048): 180 | super(PCLResNet, self).__init__() 181 | self.network = network 182 | self.network.fc = Identity() 183 | self.fea_proj = nn.Sequential( 184 | nn.Linear(proj_dim, proj_dim), 185 | ) 186 | self.fc_proj = nn.Parameter(torch.FloatTensor(proj_dim, proj_dim)) 187 | 188 | dropout = nn.Dropout(0.25) 189 | self.encoder = nn.Sequential( 190 | nn.Linear(n_outputs, proj_dim), 191 | nn.BatchNorm1d(proj_dim), 192 | nn.ReLU(inplace=True), 193 | dropout, 194 | nn.Linear(proj_dim, proj_dim), 195 | nn.BatchNorm1d(proj_dim), 196 | nn.ReLU(inplace=True), 197 | dropout, 198 | nn.Linear(proj_dim, proj_dim), 199 | ) 200 | self.classifier = nn.Parameter(torch.FloatTensor(network.num_classes, proj_dim)) 201 | nn.init.kaiming_uniform_(self.fc_proj, mode="fan_out", a=math.sqrt(5)) 202 | nn.init.kaiming_uniform_(self.classifier, mode="fan_out", a=math.sqrt(5)) 203 | self._initialize_weights(self.encoder) 204 | 205 | def forward(self, x): 206 | x = self.network(x) 207 | x = self.encoder(x) 208 | rep = self.fea_proj(x) 209 | pred = F.linear(x, self.classifier) 210 | 211 | return rep, pred 212 | 213 | def _initialize_weights(self, modules): 214 | for m in modules: 215 | if isinstance(m, nn.Conv2d): 216 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 217 | m.weight.data.normal_(0, math.sqrt(2.0 / n)) 218 | if m.bias is not None: 219 | m.bias.data.zero_() 220 | elif isinstance(m, nn.BatchNorm2d): 221 | m.weight.data.fill_(1) 222 | m.bias.data.zero_() 223 | elif isinstance(m, nn.Linear): 224 | n = m.weight.size(1) 225 | m.weight.data.normal_(0, 0.01) 226 | m.bias.data.zero_() 227 | 228 | 229 | class ResNet(nn.Module): 230 | def 
__init__( 231 | self, 232 | block: Type[Union[BasicBlock, Bottleneck]], 233 | layers: List[int], 234 | num_classes: int = 1000, 235 | zero_init_residual: bool = False, 236 | groups: int = 1, 237 | width_per_group: int = 64, 238 | width_multiplier: int = 1, 239 | second_order: bool = False, 240 | replace_stride_with_dilation: Optional[List[bool]] = None, 241 | norm_layer: Optional[Callable[..., nn.Module]] = None, 242 | freeze_bn: bool = False, 243 | projection_head: bool = False, 244 | ) -> None: 245 | super(ResNet, self).__init__() 246 | if norm_layer is None: 247 | norm_layer = nn.BatchNorm2d 248 | self._norm_layer = norm_layer 249 | self.num_classes = num_classes 250 | self.second_order = second_order 251 | self.freeze_bn = freeze_bn 252 | self.projection_head = projection_head 253 | 254 | self.inplanes = 64 * width_multiplier 255 | self.dilation = 1 256 | if replace_stride_with_dilation is None: 257 | # each element in the tuple indicates if we should replace 258 | # the 2x2 stride with a dilated convolution instead 259 | replace_stride_with_dilation = [False, False, False] 260 | if len(replace_stride_with_dilation) != 3: 261 | raise ValueError( 262 | "replace_stride_with_dilation should be None " 263 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation) 264 | ) 265 | self.groups = groups 266 | self.base_width = width_per_group 267 | self.conv1 = nn.Conv2d( 268 | 3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False 269 | ) 270 | self.bn1 = norm_layer(self.inplanes) 271 | self.relu = nn.ReLU(inplace=True) 272 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 273 | self.layer1 = self._make_layer(block, 64 * width_multiplier, layers[0]) 274 | self.layer2 = self._make_layer( 275 | block, 276 | 128 * width_multiplier, 277 | layers[1], 278 | stride=2, 279 | dilate=replace_stride_with_dilation[0], 280 | ) 281 | self.layer3 = self._make_layer( 282 | block, 283 | 256 * width_multiplier, 284 | layers[2], 285 | stride=2, 286 | dilate=replace_stride_with_dilation[1], 287 | ) 288 | self.layer4 = self._make_layer( 289 | block, 290 | 512 * width_multiplier, 291 | layers[3], 292 | stride=2, 293 | dilate=replace_stride_with_dilation[2], 294 | ) 295 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 296 | if projection_head: 297 | self.fc1 = nn.Linear( 298 | 512 * width_multiplier * block.expansion, 299 | 512 * width_multiplier * block.expansion, 300 | ) 301 | self.bn_last = nn.BatchNorm1d(512 * width_multiplier * block.expansion) 302 | self.fc2 = nn.Linear( 303 | 512 * width_multiplier * block.expansion, 304 | 512 * width_multiplier * block.expansion, 305 | ) 306 | self.fc_final = nn.Linear( 307 | 512 * width_multiplier * block.expansion, self.num_classes 308 | ) 309 | else: 310 | self.fc = nn.Linear( 311 | 512 * width_multiplier * block.expansion, self.num_classes 312 | ) 313 | 314 | if self.freeze_bn: 315 | self.freeze_batchnorm() 316 | 317 | for m in self.modules(): 318 | if isinstance(m, nn.Conv2d): 319 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 320 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 321 | nn.init.constant_(m.weight, 1) 322 | nn.init.constant_(m.bias, 0) 323 | 324 | # Zero-initialize the last BN in each residual branch, 325 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
326 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 327 | if zero_init_residual: 328 | for m in self.modules(): 329 | if isinstance(m, Bottleneck): 330 | nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] 331 | elif isinstance(m, BasicBlock): 332 | nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] 333 | 334 | def _make_layer( 335 | self, 336 | block: Type[Union[BasicBlock, Bottleneck]], 337 | planes: int, 338 | blocks: int, 339 | stride: int = 1, 340 | dilate: bool = False, 341 | ) -> nn.Sequential: 342 | norm_layer = self._norm_layer 343 | downsample = None 344 | previous_dilation = self.dilation 345 | if dilate: 346 | self.dilation *= stride 347 | stride = 1 348 | if stride != 1 or self.inplanes != planes * block.expansion: 349 | downsample = nn.Sequential( 350 | conv1x1(self.inplanes, planes * block.expansion, stride), 351 | norm_layer(planes * block.expansion), 352 | ) 353 | 354 | layers = [] 355 | layers.append( 356 | block( 357 | self.inplanes, 358 | planes, 359 | stride, 360 | downsample, 361 | self.groups, 362 | self.base_width, 363 | previous_dilation, 364 | norm_layer, 365 | ) 366 | ) 367 | self.inplanes = planes * block.expansion 368 | for _ in range(1, blocks): 369 | layers.append( 370 | block( 371 | self.inplanes, 372 | planes, 373 | groups=self.groups, 374 | base_width=self.base_width, 375 | dilation=self.dilation, 376 | norm_layer=norm_layer, 377 | ) 378 | ) 379 | 380 | return nn.Sequential(*layers) 381 | 382 | def _forward_impl(self, x: Tensor) -> Tensor: 383 | # See note [TorchScript super()] 384 | x = self.conv1(x) 385 | x = self.bn1(x) 386 | x = self.relu(x) 387 | x = self.maxpool(x) 388 | 389 | x = self.layer1(x) 390 | x = self.layer2(x) 391 | x = self.layer3(x) 392 | x = self.layer4(x) 393 | 394 | x = self.avgpool(x) 395 | x = torch.flatten(x, 1) 396 | if self.projection_head: 397 | x = self.fc1(x) 398 | x = self.bn_last(x) 399 | x = self.relu(x) 400 | x = self.fc2(x) 401 | return self.fc_final(x) 402 | else: 403 | return self.fc(x) 404 | 405 | def forward(self, x: Tensor) -> Tensor: 406 | return self._forward_impl(x) 407 | 408 | def train(self, mode=True): 409 | """ 410 | Override the default train() to freeze the BN parameters 411 | """ 412 | super().train(mode) 413 | if self.freeze_bn: 414 | self.freeze_batchnorm() 415 | 416 | def freeze_batchnorm(self): 417 | for m in self.modules(): 418 | if isinstance(m, nn.BatchNorm2d): 419 | m.eval() 420 | 421 | 422 | def _resnet( 423 | arch: str, 424 | block: Type[Union[BasicBlock, Bottleneck]], 425 | layers: List[int], 426 | pretrained: bool, 427 | progress: bool, 428 | **kwargs: Any 429 | ) -> ResNet: 430 | model = ResNet(block, layers, **kwargs) 431 | if pretrained: 432 | state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) 433 | del state_dict["fc.weight"] 434 | del state_dict["fc.bias"] 435 | msg = model.load_state_dict(state_dict, strict=False) 436 | # assert set(msg.missing_keys) == {"fc.weight", "fc.bias"} 437 | return model 438 | 439 | 440 | def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 441 | r"""ResNet-18 model from 442 | `"Deep Residual Learning for Image Recognition" `_. 
443 | Args: 444 | pretrained (bool): If True, returns a model pre-trained on ImageNet 445 | progress (bool): If True, displays a progress bar of the download to stderr 446 | """ 447 | return _resnet("resnet18", BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs) 448 | 449 | 450 | def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 451 | r"""ResNet-34 model from 452 | `"Deep Residual Learning for Image Recognition" `_. 453 | Args: 454 | pretrained (bool): If True, returns a model pre-trained on ImageNet 455 | progress (bool): If True, displays a progress bar of the download to stderr 456 | """ 457 | return _resnet("resnet34", BasicBlock, [3, 4, 6, 3], pretrained, progress, **kwargs) 458 | 459 | 460 | def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 461 | r"""ResNet-50 model from 462 | `"Deep Residual Learning for Image Recognition" `_. 463 | Args: 464 | pretrained (bool): If True, returns a model pre-trained on ImageNet 465 | progress (bool): If True, displays a progress bar of the download to stderr 466 | """ 467 | return _resnet("resnet50", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs) 468 | 469 | 470 | def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 471 | r"""ResNet-101 model from 472 | `"Deep Residual Learning for Image Recognition" `_. 473 | Args: 474 | pretrained (bool): If True, returns a model pre-trained on ImageNet 475 | progress (bool): If True, displays a progress bar of the download to stderr 476 | """ 477 | return _resnet( 478 | "resnet101", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs 479 | ) 480 | 481 | 482 | def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 483 | r"""ResNet-152 model from 484 | `"Deep Residual Learning for Image Recognition" `_. 485 | Args: 486 | pretrained (bool): If True, returns a model pre-trained on ImageNet 487 | progress (bool): If True, displays a progress bar of the download to stderr 488 | """ 489 | return _resnet( 490 | "resnet152", Bottleneck, [3, 8, 36, 3], pretrained, progress, **kwargs 491 | ) 492 | 493 | 494 | def resnext50_32x4d( 495 | pretrained: bool = False, progress: bool = True, **kwargs: Any 496 | ) -> ResNet: 497 | r"""ResNeXt-50 32x4d model from 498 | `"Aggregated Residual Transformation for Deep Neural Networks" `_. 499 | Args: 500 | pretrained (bool): If True, returns a model pre-trained on ImageNet 501 | progress (bool): If True, displays a progress bar of the download to stderr 502 | """ 503 | kwargs["groups"] = 32 504 | kwargs["width_per_group"] = 4 505 | return _resnet( 506 | "resnext50_32x4d", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs 507 | ) 508 | 509 | 510 | def resnext101_32x8d( 511 | pretrained: bool = False, progress: bool = True, **kwargs: Any 512 | ) -> ResNet: 513 | r"""ResNeXt-101 32x8d model from 514 | `"Aggregated Residual Transformation for Deep Neural Networks" `_. 515 | Args: 516 | pretrained (bool): If True, returns a model pre-trained on ImageNet 517 | progress (bool): If True, displays a progress bar of the download to stderr 518 | """ 519 | kwargs["groups"] = 32 520 | kwargs["width_per_group"] = 8 521 | return _resnet( 522 | "resnext101_32x8d", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs 523 | ) 524 | 525 | 526 | def wide_resnet50_2( 527 | pretrained: bool = False, progress: bool = True, **kwargs: Any 528 | ) -> ResNet: 529 | r"""Wide ResNet-50-2 model from 530 | `"Wide Residual Networks" `_. 
531 | The model is the same as ResNet except for the bottleneck number of channels 532 | which is twice larger in every block. The number of channels in outer 1x1 533 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 534 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 535 | Args: 536 | pretrained (bool): If True, returns a model pre-trained on ImageNet 537 | progress (bool): If True, displays a progress bar of the download to stderr 538 | """ 539 | kwargs["width_per_group"] = 64 * 2 540 | return _resnet( 541 | "wide_resnet50_2", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs 542 | ) 543 | 544 | 545 | def wide_resnet101_2( 546 | pretrained: bool = False, progress: bool = True, **kwargs: Any 547 | ) -> ResNet: 548 | r"""Wide ResNet-101-2 model from 549 | `"Wide Residual Networks" `_. 550 | The model is the same as ResNet except for the bottleneck number of channels 551 | which is twice larger in every block. The number of channels in outer 1x1 552 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 553 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 554 | Args: 555 | pretrained (bool): If True, returns a model pre-trained on ImageNet 556 | progress (bool): If True, displays a progress bar of the download to stderr 557 | """ 558 | kwargs["width_per_group"] = 64 * 2 559 | return _resnet( 560 | "wide_resnet101_2", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs 561 | ) 562 | -------------------------------------------------------------------------------- /models/timm_model_wrapper.py: -------------------------------------------------------------------------------- 1 | import timm 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class TimmWrapper(nn.Module): 7 | def __init__( 8 | self, model, freeze_bn=False, miro=False, num_classes=100, freeze_all=False 9 | ): 10 | super(TimmWrapper, self).__init__() 11 | self.model = model 12 | self.freeze_bn = freeze_bn 13 | self.miro = miro 14 | self.num_classes = num_classes 15 | self.freeze_all = freeze_all 16 | 17 | if self.freeze_all: 18 | for p in model.parameters(): 19 | p.requires_grad_(False) 20 | 21 | if self.freeze_bn: 22 | self.freeze_batchnorm() 23 | 24 | if self.miro: 25 | self.global_pool, self.fc = timm.models.resnet.create_classifier( 26 | timm.models.resnet.Bottleneck.expansion * 512, self.num_classes, "avg" 27 | ) 28 | 29 | def forward(self, x): 30 | if self.miro: 31 | x = self.model(x)[-1] 32 | x = self.global_pool(x) 33 | return self.fc(x) 34 | else: 35 | return self.model(x) 36 | 37 | def forward_features(self, x): 38 | if self.miro: 39 | x = self.model(x) 40 | x_pool = self.global_pool(x[-1]) 41 | out = self.fc(x_pool) 42 | x[0] = self.model.maxpool(x[0]) 43 | return (out, x) 44 | else: 45 | return self.model(x) 46 | 47 | def train(self, mode=True): 48 | """ 49 | Override the default train() to freeze the BN parameters 50 | """ 51 | super().train(mode) 52 | if self.freeze_bn: 53 | self.freeze_batchnorm() 54 | 55 | def freeze_batchnorm(self): 56 | for m in self.model.modules(): 57 | if isinstance(m, nn.BatchNorm2d): 58 | m.eval() 59 | -------------------------------------------------------------------------------- /requirements.yml: -------------------------------------------------------------------------------- 1 | name: erm_plus_plus 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=conda_forge 8 | - _openmp_mutex=4.5=2_gnu 9 | - aom=3.4.0=h27087fc_1 10 | - backports=1.0=py_2 11 | - 
backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0 12 | - blas=1.0=mkl 13 | - brotli=1.0.9=h166bdaf_8 14 | - brotli-bin=1.0.9=h166bdaf_8 15 | - bzip2=1.0.8=h7f98852_4 16 | - ca-certificates=2022.12.7=ha878542_0 17 | - cudatoolkit=11.3.1=h2bc3f7f_2 18 | - dbus=1.13.18=hb2f20db_0 19 | - expat=2.4.8=h27087fc_0 20 | - ffmpeg=5.1.0=gpl_hb2553f0_100 21 | - font-ttf-dejavu-sans-mono=2.37=hab24e00_0 22 | - font-ttf-inconsolata=3.000=h77eed37_0 23 | - font-ttf-source-code-pro=2.038=h77eed37_0 24 | - font-ttf-ubuntu=0.83=hab24e00_0 25 | - fontconfig=2.14.0=h8e229c2_0 26 | - fonts-conda-ecosystem=1=0 27 | - fonts-conda-forge=1=0 28 | - freetype=2.10.4=hca18f0e_2 29 | - gettext=0.19.8.1=h73d1719_1008 30 | - giflib=5.2.1=h36c2ea0_2 31 | - glib=2.69.1=he621ea3_2 32 | - gmp=6.2.1=h58526e2_0 33 | - gnutls=3.7.6=hf3e180e_5 34 | - gst-plugins-base=1.14.0=h8213a91_2 35 | - gstreamer=1.14.0=h28cd5cc_2 36 | - icu=58.2=hf484d3e_1000 37 | - importlib_resources=5.10.1=pyhd8ed1ab_0 38 | - intel-openmp=2021.4.0=h06a4308_3561 39 | - ipython_genutils=0.2.0=py_1 40 | - jpeg=9e=h166bdaf_2 41 | - jupyter_client=7.4.8=pyhd8ed1ab_0 42 | - jupyter_core=5.1.0=py310hff52083_0 43 | - jupyter_server=1.23.3=pyhd8ed1ab_0 44 | - jupyterlab_pygments=0.2.2=pyhd8ed1ab_0 45 | - keyutils=1.6.1=h166bdaf_0 46 | - krb5=1.19.3=h3790be6_0 47 | - lame=3.100=h7f98852_1001 48 | - lcms2=2.12=hddcbb42_0 49 | - ld_impl_linux-64=2.36.1=hea4e1c9_2 50 | - lerc=4.0.0=h27087fc_0 51 | - libblas=3.9.0=12_linux64_mkl 52 | - libbrotlicommon=1.0.9=h166bdaf_8 53 | - libbrotlidec=1.0.9=h166bdaf_8 54 | - libbrotlienc=1.0.9=h166bdaf_8 55 | - libcblas=3.9.0=12_linux64_mkl 56 | - libclang=10.0.1=default_hb85057a_2 57 | - libdeflate=1.13=h166bdaf_0 58 | - libdrm=2.4.112=h166bdaf_0 59 | - libedit=3.1.20191231=he28a2e2_2 60 | - libevent=2.1.12=h8f2d780_0 61 | - libffi=3.4.2=h7f98852_5 62 | - libgcc-ng=12.1.0=h8d9b700_16 63 | - libgfortran-ng=12.1.0=h69a702a_16 64 | - libgfortran5=12.1.0=hdcd56e2_16 65 | - libgomp=12.1.0=h8d9b700_16 66 | - libiconv=1.16=h516909a_0 67 | - libidn2=2.3.3=h166bdaf_0 68 | - liblapack=3.9.0=12_linux64_mkl 69 | - libllvm10=10.0.1=he513fc3_3 70 | - libnsl=2.0.0=h7f98852_0 71 | - libpciaccess=0.16=h516909a_0 72 | - libpng=1.6.37=h753d276_3 73 | - libpq=12.9=h16c4e8d_3 74 | - libsodium=1.0.18=h36c2ea0_1 75 | - libsqlite=3.40.0=h753d276_0 76 | - libstdcxx-ng=12.1.0=ha89aaad_16 77 | - libtasn1=4.18.0=h166bdaf_1 78 | - libtiff=4.4.0=h0e0dad5_3 79 | - libunistring=0.9.10=h7f98852_0 80 | - libuuid=2.32.1=h7f98852_1000 81 | - libva=2.15.0=h166bdaf_0 82 | - libvpx=1.11.0=h9c3ff4c_3 83 | - libwebp=1.2.4=h522a892_0 84 | - libwebp-base=1.2.4=h166bdaf_0 85 | - libxcb=1.15=h7f8727e_0 86 | - libxkbcommon=1.0.3=he3ba5ed_0 87 | - libxml2=2.9.14=h74e7548_0 88 | - libxslt=1.1.35=h4e12654_0 89 | - libzlib=1.2.13=h166bdaf_4 90 | - lz4-c=1.9.3=h9c3ff4c_1 91 | - matplotlib-base=3.6.2=py310h945d387_0 92 | - mkl=2021.4.0=h06a4308_640 93 | - mkl_fft=1.3.1=py310h2b4bcf5_1 94 | - mkl_random=1.2.2=py310h00e6091_0 95 | - nbconvert-core=7.2.7=pyhd8ed1ab_0 96 | - nbconvert-pandoc=7.2.7=pyhd8ed1ab_0 97 | - ncurses=6.3=h27087fc_1 98 | - nettle=3.8.1=hc379101_1 99 | - nspr=4.35=h27087fc_0 100 | - nss=3.82=he02c5a1_0 101 | - numpy-base=1.23.1=py310hcba007f_0 102 | - openh264=2.2.0=h27087fc_2 103 | - openjpeg=2.4.0=hb52868f_1 104 | - openssl=1.1.1s=h0b41bf4_1 105 | - p11-kit=0.24.1=hc5aa10d_0 106 | - pandoc=2.19.2=ha770c72_0 107 | - pcre=8.45=h9c3ff4c_0 108 | - prometheus_client=0.15.0=pyhd8ed1ab_0 109 | - pthread-stubs=0.4=h36c2ea0_1001 110 | - pure_eval=0.2.2=pyhd8ed1ab_0 111 
| - pyqt=5.15.7=py310h6a678d5_1 112 | - python=3.10.8=h257c98d_0_cpython 113 | - python-fastjsonschema=2.16.2=pyhd8ed1ab_0 114 | - python_abi=3.10=2_cp310 115 | - pytorch=1.12.1=py3.10_cuda11.3_cudnn8.3.2_0 116 | - pytorch-mutex=1.0=cuda 117 | - qt-main=5.15.2=h327a75a_7 118 | - qt-webengine=5.15.9=hd2b0992_4 119 | - qtwebkit=5.212=h4eab89a_4 120 | - readline=8.1.2=h0f457ee_0 121 | - sqlite=3.39.2=h4ff8645_0 122 | - stack_data=0.3.0=pyhd8ed1ab_0 123 | - svt-av1=1.2.0=h27087fc_0 124 | - tk=8.6.12=h27826a3_0 125 | - typing_extensions=4.3.0=pyha770c72_0 126 | - tzdata=2022b=h191b570_0 127 | - x264=1!164.3095=h166bdaf_2 128 | - x265=3.5=h924138e_3 129 | - xorg-fixesproto=5.0=h7f98852_1002 130 | - xorg-kbproto=1.0.7=h7f98852_1002 131 | - xorg-libx11=1.7.2=h7f98852_0 132 | - xorg-libxau=1.0.9=h7f98852_0 133 | - xorg-libxdmcp=1.1.3=h7f98852_0 134 | - xorg-libxext=1.3.4=h7f98852_1 135 | - xorg-libxfixes=5.0.3=h7f98852_1004 136 | - xorg-xextproto=7.3.0=h7f98852_1002 137 | - xorg-xproto=7.0.31=h7f98852_1007 138 | - xz=5.2.6=h166bdaf_0 139 | - zeromq=4.3.4=h9c3ff4c_1 140 | - zlib=1.2.13=h166bdaf_4 141 | - zstd=1.5.2=h8a70e8d_4 142 | - pip: 143 | - anyio==3.6.2 144 | - argon2-cffi==21.3.0 145 | - argon2-cffi-bindings==21.2.0 146 | - asttokens==2.0.5 147 | - attrs==22.1.0 148 | - backcall==0.2.0 149 | - backports-functools-lru-cache==1.6.4 150 | - beautifulsoup4==4.11.1 151 | - black==22.12.0 152 | - bleach==5.0.1 153 | - brotlipy==0.7.0 154 | - certifi==2022.12.7 155 | - cffi==1.15.1 156 | - charset-normalizer==2.1.0 157 | - click==8.1.3 158 | - comm==0.1.2 159 | - contourpy==1.0.6 160 | - cryptography==37.0.1 161 | - cycler==0.11.0 162 | - debugpy==1.6.4 163 | - decorator==5.1.1 164 | - defusedxml==0.7.1 165 | - easydict==1.9 166 | - entrypoints==0.4 167 | - executing==0.9.1 168 | - fastjsonschema==2.16.2 169 | - filelock==3.8.0 170 | - flit-core==3.8.0 171 | - fonttools==4.38.0 172 | - gdown==4.5.1 173 | - huggingface-hub==0.12.0 174 | - idna==3.3 175 | - importlib-metadata==5.2.0 176 | - importlib-resources==5.10.1 177 | - ipykernel==6.19.3 178 | - ipython==8.4.0 179 | - ipython-genutils==0.2.0 180 | - jedi==0.18.1 181 | - jinja2==3.1.2 182 | - joblib==1.1.0 183 | - jsonschema==4.17.3 184 | - jupyter-client==7.4.8 185 | - jupyter-core==5.1.0 186 | - jupyter-server==1.23.3 187 | - jupyterlab-pygments==0.2.2 188 | - kiwisolver==1.4.4 189 | - littleutils==0.2.2 190 | - markupsafe==2.1.1 191 | - matplotlib==3.6.2 192 | - matplotlib-inline==0.1.3 193 | - mistune==2.0.4 194 | - mkl-fft==1.3.1 195 | - mkl-random==1.2.2 196 | - mkl-service==2.4.0 197 | - munkres==1.1.4 198 | - mypy-extensions==0.4.3 199 | - nbclassic==0.4.8 200 | - nbclient==0.7.2 201 | - nbconvert==7.2.7 202 | - nbformat==5.7.1 203 | - nest-asyncio==1.5.6 204 | - notebook==6.5.2 205 | - notebook-shim==0.2.2 206 | - numpy==1.23.1 207 | - ogb==1.3.5 208 | - opencv-python==4.6.0.66 209 | - outdated==0.2.2 210 | - packaging==22.0 211 | - pandas==1.5.2 212 | - pandocfilters==1.5.0 213 | - parso==0.8.3 214 | - pathspec==0.10.3 215 | - pexpect==4.8.0 216 | - pickleshare==0.7.5 217 | - pillow==9.3.0 218 | - pip==22.2.2 219 | - pkgutil-resolve-name==1.3.10 220 | - platformdirs==2.6.0 221 | - ply==3.11 222 | - prometheus-client==0.15.0 223 | - prompt-toolkit==3.0.30 224 | - psutil==5.9.4 225 | - ptflops==0.6.9 226 | - ptyprocess==0.7.0 227 | - pure-eval==0.2.2 228 | - pycls==0.1.1 229 | - pycparser==2.21 230 | - pygments==2.12.0 231 | - pyopenssl==22.0.0 232 | - pyparsing==3.0.9 233 | - pyqt5-sip==12.11.0 234 | - pyrsistent==0.19.2 235 | - 
pysocks==1.7.1 236 | - python-dateutil==2.8.2 237 | - pytz==2022.6 238 | - pyyaml==6.0 239 | - pyzmq==24.0.1 240 | - requests==2.28.1 241 | - scikit-learn==1.1.2 242 | - scipy==1.9.0 243 | - send2trash==1.8.0 244 | - setuptools==64.0.2 245 | - simplejson==3.17.6 246 | - sip==6.6.2 247 | - six==1.16.0 248 | - sniffio==1.3.0 249 | - soupsieve==2.3.2.post1 250 | - stack-data==0.3.0 251 | - terminado==0.17.1 252 | - threadpoolctl==3.1.0 253 | - timm==0.6.12 254 | - tinycss2==1.2.1 255 | - toml==0.10.2 256 | - tomli==2.0.1 257 | - torch==1.12.1 258 | - torchaudio==0.12.1 259 | - torchvision==0.13.1 260 | - tornado==6.2 261 | - tqdm==4.64.0 262 | - traitlets==5.4.0 263 | - typing-extensions==4.3.0 264 | - unicodedata2==15.0.0 265 | - urllib3==1.26.11 266 | - wcwidth==0.2.5 267 | - webencodings==0.5.1 268 | - websocket-client==1.4.2 269 | - wheel==0.37.1 270 | - wilds==2.0.0 271 | - yacs==0.1.8 272 | - zipp==3.11.0 273 | 274 | -------------------------------------------------------------------------------- /scripts/download.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | from torchvision.datasets import MNIST 4 | import xml.etree.ElementTree as ET 5 | from zipfile import ZipFile 6 | import argparse 7 | import tarfile 8 | import shutil 9 | import gdown 10 | import uuid 11 | import json 12 | import os 13 | 14 | from wilds.datasets.camelyon17_dataset import Camelyon17Dataset 15 | from wilds.datasets.fmow_dataset import FMoWDataset 16 | 17 | 18 | # utils ####################################################################### 19 | 20 | 21 | def stage_path(data_dir, name): 22 | full_path = os.path.join(data_dir, name) 23 | 24 | if not os.path.exists(full_path): 25 | os.makedirs(full_path) 26 | 27 | return full_path 28 | 29 | 30 | def download_and_extract(url, dst, remove=True): 31 | gdown.download(url, dst, quiet=False) 32 | 33 | if dst.endswith(".tar.gz"): 34 | tar = tarfile.open(dst, "r:gz") 35 | tar.extractall(os.path.dirname(dst)) 36 | tar.close() 37 | 38 | if dst.endswith(".tar"): 39 | tar = tarfile.open(dst, "r:") 40 | tar.extractall(os.path.dirname(dst)) 41 | tar.close() 42 | 43 | if dst.endswith(".zip"): 44 | zf = ZipFile(dst, "r") 45 | zf.extractall(os.path.dirname(dst)) 46 | zf.close() 47 | 48 | if remove: 49 | os.remove(dst) 50 | 51 | 52 | # VLCS ######################################################################## 53 | 54 | # Slower, but builds dataset from the original sources 55 | # 56 | # def download_vlcs(data_dir): 57 | # full_path = stage_path(data_dir, "VLCS") 58 | # 59 | # tmp_path = os.path.join(full_path, "tmp/") 60 | # if not os.path.exists(tmp_path): 61 | # os.makedirs(tmp_path) 62 | # 63 | # with open("domainbed/misc/vlcs_files.txt", "r") as f: 64 | # lines = f.readlines() 65 | # files = [line.strip().split() for line in lines] 66 | # 67 | # download_and_extract("http://pjreddie.com/media/files/VOCtrainval_06-Nov-2007.tar", 68 | # os.path.join(tmp_path, "voc2007_trainval.tar")) 69 | # 70 | # download_and_extract("https://drive.google.com/uc?id=1I8ydxaAQunz9R_qFFdBFtw6rFTUW9goz", 71 | # os.path.join(tmp_path, "caltech101.tar.gz")) 72 | # 73 | # download_and_extract("http://groups.csail.mit.edu/vision/Hcontext/data/sun09_hcontext.tar", 74 | # os.path.join(tmp_path, "sun09_hcontext.tar")) 75 | # 76 | # tar = tarfile.open(os.path.join(tmp_path, "sun09.tar"), "r:") 77 | # tar.extractall(tmp_path) 78 | # tar.close() 79 | # 80 | # for src, dst in files: 81 | # 
class_folder = os.path.join(data_dir, dst) 82 | # 83 | # if not os.path.exists(class_folder): 84 | # os.makedirs(class_folder) 85 | # 86 | # dst = os.path.join(class_folder, uuid.uuid4().hex + ".jpg") 87 | # 88 | # if "labelme" in src: 89 | # # download labelme from the web 90 | # gdown.download(src, dst, quiet=False) 91 | # else: 92 | # src = os.path.join(tmp_path, src) 93 | # shutil.copyfile(src, dst) 94 | # 95 | # shutil.rmtree(tmp_path) 96 | 97 | 98 | def download_vlcs(data_dir): 99 | # Original URL: http://www.eecs.qmul.ac.uk/~dl307/project_iccv2017 100 | full_path = stage_path(data_dir, "VLCS") 101 | 102 | download_and_extract( 103 | "https://drive.google.com/uc?id=1skwblH1_okBwxWxmRsp9_qi15hyPpxg8", 104 | os.path.join(data_dir, "VLCS.tar.gz"), 105 | ) 106 | 107 | 108 | 109 | 110 | # Office-Home ################################################################# 111 | 112 | 113 | def download_office_home(data_dir): 114 | # Original URL: http://hemanthdv.org/OfficeHome-Dataset/ 115 | full_path = stage_path(data_dir, "office_home") 116 | 117 | download_and_extract( 118 | "https://wjdcloud.blob.core.windows.net/dataset/OfficeHome.zip", 119 | os.path.join(data_dir, "OfficeHome.zip"), 120 | ) 121 | 122 | os.rename(os.path.join(data_dir, "OfficeHomeDataset_10072016"), full_path) 123 | 124 | 125 | # DomainNET ################################################################### 126 | 127 | 128 | def download_domain_net(data_dir): 129 | # Original URL: http://ai.bu.edu/M3SDA/ 130 | full_path = stage_path(data_dir, "domain_net") 131 | 132 | urls = [ 133 | "http://csr.bu.edu/ftp/visda/2019/multi-source/groundtruth/clipart.zip", 134 | "http://csr.bu.edu/ftp/visda/2019/multi-source/infograph.zip", 135 | "http://csr.bu.edu/ftp/visda/2019/multi-source/groundtruth/painting.zip", 136 | "http://csr.bu.edu/ftp/visda/2019/multi-source/quickdraw.zip", 137 | "http://csr.bu.edu/ftp/visda/2019/multi-source/real.zip", 138 | "http://csr.bu.edu/ftp/visda/2019/multi-source/sketch.zip", 139 | ] 140 | 141 | for url in urls: 142 | download_and_extract(url, os.path.join(full_path, url.split("/")[-1])) 143 | 144 | with open("scripts/domain_net_duplicates.txt", "r") as f: 145 | for line in f.readlines(): 146 | try: 147 | os.remove(os.path.join(full_path, line.strip())) 148 | except OSError: 149 | pass 150 | 151 | 152 | # TerraIncognita ############################################################## 153 | 154 | 155 | def download_terra_incognita(data_dir): 156 | # Original URL: https://beerys.github.io/CaltechCameraTraps/ 157 | # New URL: http://lila.science/datasets/caltech-camera-traps 158 | 159 | full_path = stage_path(data_dir, "terra_incognita") 160 | 161 | download_and_extract( 162 | "https://lilablobssc.blob.core.windows.net/caltechcameratraps/eccv_18_all_images_sm.tar.gz", 163 | os.path.join(full_path, "terra_incognita_images.tar.gz"), 164 | ) 165 | 166 | download_and_extract( 167 | "https://lilablobssc.blob.core.windows.net/caltechcameratraps/labels/caltech_camera_traps.json.zip", 168 | os.path.join(full_path, "caltech_camera_traps.json.zip"), 169 | ) 170 | 171 | include_locations = ["38", "46", "100", "43"] 172 | 173 | include_categories = [ 174 | "bird", 175 | "bobcat", 176 | "cat", 177 | "coyote", 178 | "dog", 179 | "empty", 180 | "opossum", 181 | "rabbit", 182 | "raccoon", 183 | "squirrel", 184 | ] 185 | 186 | images_folder = os.path.join(full_path, "eccv_18_all_images_sm/") 187 | annotations_file = os.path.join(full_path, "caltech_images_20210113.json") 188 | destination_folder = full_path 189 | 190 | 
stats = {} 191 | 192 | if not os.path.exists(destination_folder): 193 | os.mkdir(destination_folder) 194 | 195 | with open(annotations_file, "r") as f: 196 | data = json.load(f) 197 | 198 | category_dict = {} 199 | for item in data["categories"]: 200 | category_dict[item["id"]] = item["name"] 201 | 202 | for image in data["images"]: 203 | image_location = image["location"] 204 | 205 | if image_location not in include_locations: 206 | continue 207 | 208 | loc_folder = os.path.join( 209 | destination_folder, "location_" + str(image_location) + "/" 210 | ) 211 | 212 | if not os.path.exists(loc_folder): 213 | os.mkdir(loc_folder) 214 | 215 | image_id = image["id"] 216 | image_fname = image["file_name"] 217 | 218 | for annotation in data["annotations"]: 219 | if annotation["image_id"] == image_id: 220 | if image_location not in stats: 221 | stats[image_location] = {} 222 | 223 | category = category_dict[annotation["category_id"]] 224 | 225 | if category not in include_categories: 226 | continue 227 | 228 | if category not in stats[image_location]: 229 | stats[image_location][category] = 0 230 | else: 231 | stats[image_location][category] += 1 232 | 233 | loc_cat_folder = os.path.join(loc_folder, category + "/") 234 | 235 | if not os.path.exists(loc_cat_folder): 236 | os.mkdir(loc_cat_folder) 237 | 238 | dst_path = os.path.join(loc_cat_folder, image_fname) 239 | src_path = os.path.join(images_folder, image_fname) 240 | 241 | shutil.copyfile(src_path, dst_path) 242 | 243 | shutil.rmtree(images_folder) 244 | os.remove(annotations_file) 245 | 246 | 247 | if __name__ == "__main__": 248 | parser = argparse.ArgumentParser(description="Download datasets") 249 | parser.add_argument("--data_dir", type=str, required=True) 250 | args = parser.parse_args() 251 | 252 | download_office_home(args.data_dir) 253 | download_domain_net(args.data_dir) 254 | download_vlcs(args.data_dir) 255 | download_terra_incognita(args.data_dir) 256 | FMoWDataset(root_dir=args.data_dir, download=True) 257 | -------------------------------------------------------------------------------- /scripts/download_models.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | wget -P ../model_checkpoints2/augmix https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet50_ram-a26f946b.pth 3 | wget -P ../model_checkpoints2/resnet_a1 https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a1_0-14fe96d1.pth 4 | 5 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | from enum import Enum 4 | 5 | 6 | class MovingAvg: 7 | def __init__(self, network, ema=False, sma_start_iter=100): 8 | self.network = network 9 | self.network_sma = copy.deepcopy(network) 10 | self.sma_start_iter = sma_start_iter 11 | self.global_iter = 0 12 | self.sma_count = 0 13 | self.ema = ema 14 | 15 | def update_sma(self): 16 | self.global_iter += 1 17 | if self.global_iter >= self.sma_start_iter and self.ema: 18 | # if False: 19 | self.sma_count += 1 20 | for param_q, param_k in zip( 21 | self.network.parameters(), self.network_sma.parameters() 22 | ): 23 | param_k.data = (param_k.data * self.sma_count + param_q.data) / ( 24 | 1.0 + self.sma_count 25 | ) 26 | else: 27 | for param_q, param_k in zip( 28 | self.network.parameters(), self.network_sma.parameters() 29 | ): 30 | param_k.data = param_q.data 31 | 
32 | 33 | class Summary(Enum): 34 | NONE = 0 35 | AVERAGE = 1 36 | SUM = 2 37 | COUNT = 3 38 | 39 | 40 | class AverageMeter(object): 41 | """Computes and stores the average and current value""" 42 | 43 | def __init__(self, name, fmt=":f", summary_type=Summary.AVERAGE): 44 | self.name = name 45 | self.fmt = fmt 46 | self.summary_type = summary_type 47 | self.reset() 48 | 49 | def reset(self): 50 | self.val = 0 51 | self.avg = 0 52 | self.sum = 0 53 | self.count = 0 54 | 55 | def update(self, val, n=1): 56 | self.val = val 57 | self.sum += val * n 58 | self.count += n 59 | self.avg = self.sum / self.count 60 | 61 | def __str__(self): 62 | fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})" 63 | return fmtstr.format(**self.__dict__) 64 | 65 | def summary(self): 66 | fmtstr = "" 67 | if self.summary_type is Summary.NONE: 68 | fmtstr = "" 69 | elif self.summary_type is Summary.AVERAGE: 70 | fmtstr = "{name} {avg:.3f}" 71 | elif self.summary_type is Summary.SUM: 72 | fmtstr = "{name} {sum:.3f}" 73 | elif self.summary_type is Summary.COUNT: 74 | fmtstr = "{name} {count:.3f}" 75 | else: 76 | raise ValueError("invalid summary type %r" % self.summary_type) 77 | 78 | return fmtstr.format(**self.__dict__) 79 | 80 | 81 | class ProgressMeter(object): 82 | def __init__(self, num_batches, meters, prefix=""): 83 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 84 | self.meters = meters 85 | self.prefix = prefix 86 | 87 | def display(self, batch): 88 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 89 | entries += [str(meter) for meter in self.meters] 90 | print("\t".join(entries)) 91 | 92 | def display_summary(self): 93 | entries = [" *"] 94 | entries += [meter.summary() for meter in self.meters] 95 | print(" ".join(entries)) 96 | 97 | def _get_batch_fmtstr(self, num_batches): 98 | num_digits = len(str(num_batches // 1)) 99 | fmt = "{:" + str(num_digits) + "d}" 100 | return "[" + fmt + "/" + fmt.format(num_batches) + "]" 101 | 102 | 103 | def accuracy(output, target, topk=(1,)): 104 | """Computes the accuracy over the k top predictions for the specified values of k""" 105 | with torch.no_grad(): 106 | maxk = max(topk) 107 | batch_size = target.size(0) 108 | 109 | _, pred = output.topk(maxk, 1, True, True) 110 | pred = pred.t() 111 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 112 | 113 | res = [] 114 | for k in topk: 115 | correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True) 116 | res.append(correct_k.mul_(100.0 / batch_size)) 117 | return res 118 | --------------------------------------------------------------------------------
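For readers piecing the files above together, here is a minimal, self-contained sketch of how ```TimmWrapper```, ```MovingAvg```, and the metering utilities from ```util.py``` can be combined in a training loop. The backbone choice, optimizer settings, and random batches below are illustrative assumptions only; they are not the configuration used by ```main_erm.py```, which is the real entry point and wires these pieces to the datasets and CLI flags.

```python
# Minimal usage sketch (assumptions: resnet18 backbone, Adam, random batches;
# main_erm.py is the actual entry point and uses different settings).
import timm
import torch
import torch.nn as nn

from models.timm_model_wrapper import TimmWrapper
from util import AverageMeter, MovingAvg, ProgressMeter, accuracy

num_classes = 10
backbone = timm.create_model("resnet18", pretrained=False, num_classes=num_classes)
model = TimmWrapper(backbone, freeze_bn=True, num_classes=num_classes)

# MovingAvg keeps a second copy of the network (network_sma). With ema=True,
# parameters are averaged into that copy once global_iter reaches sma_start_iter;
# before that, the copy simply mirrors the live weights.
averager = MovingAvg(model, ema=True, sma_start_iter=5)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()
top1 = AverageMeter("Acc@1", ":6.2f")
progress = ProgressMeter(20, [top1], prefix="Train: ")

model.train()  # frozen BatchNorm layers stay in eval mode via the train() override
for step in range(20):
    # Stand-in batch; a real run would draw from the loaders built in data/dataset.py.
    images = torch.randn(8, 3, 224, 224)
    targets = torch.randint(0, num_classes, (8,))

    outputs = model(images)
    loss = criterion(outputs, targets)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    averager.update_sma()  # keep the averaged copy in sync after every step

    acc1 = accuracy(outputs, targets, topk=(1,))[0]
    top1.update(acc1.item(), images.size(0))
    if step % 10 == 0:
        progress.display(step)

# Evaluate with the weight-averaged network, which is what the --sma flag
# appears to enable in main_erm.py.
averager.network_sma.eval()
with torch.no_grad():
    val_acc = accuracy(averager.network_sma(images), targets, topk=(1,))[0]
print(f"SMA top-1 on the last batch: {val_acc.item():.2f}")
```

Note that ```MovingAvg.update_sma``` implements an incremental arithmetic mean: each call after ```sma_start_iter``` folds the current weights into the running average, so ```network_sma``` tracks a simple moving average of the parameter trajectory rather than an exponential one, despite the name of the ```ema``` flag.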