├── LICENSE ├── README.md ├── config.yaml ├── domainbed ├── algorithms │ ├── __init__.py │ ├── algorithms.py │ └── miro.py ├── datasets │ ├── __init__.py │ ├── datasets.py │ └── transforms.py ├── evaluator.py ├── hparams_registry.py ├── lib │ ├── fast_data_loader.py │ ├── logger.py │ ├── misc.py │ ├── query.py │ ├── swa_utils.py │ ├── wide_resnet.py │ └── writers.py ├── misc │ └── domain_net_duplicates.txt ├── models │ ├── mixstyle.py │ ├── resnet_mixstyle.py │ └── resnet_mixstyle2.py ├── networks │ ├── __init__.py │ ├── backbones.py │ ├── networks.py │ └── ur_networks.py ├── optimizers.py ├── scripts │ └── download.py ├── swad.py ├── trainer.py └── trainer_DN.py ├── media ├── DART_pic.png ├── DG_combined_results.png ├── DG_main_results.png ├── ID_results.png └── model_optimization_trajectory.gif ├── requirements.txt └── train_all.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Video Analytics Lab -- IISc 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DART: Diversify-Aggregate-Repeat Training 2 | This repository contains codes for the training and evaluation of our CVPR-23 paper DART:Diversify-Aggregate-Repeat Training Improves Generalization of Neural Networks [main](https://openaccess.thecvf.com/content/CVPR2023/papers/Jain_DART_Diversify-Aggregate-Repeat_Training_Improves_Generalization_of_Neural_Networks_CVPR_2023_paper.pdf) and [supplementary](https://openaccess.thecvf.com/content/CVPR2023/supplemental/Jain_DART_Diversify-Aggregate-Repeat_Training_CVPR_2023_supplemental.pdf). The arxiv link for the paper is also [available](https://arxiv.org/pdf/2302.14685.pdf). 3 |

4 | 5 | 6 |

7 | 8 | 9 | # Environment Settings 10 | * Python 3.6.9 11 | * PyTorch 1.8 12 | * Torchvision 0.8.0 13 | * Numpy 1.19.2 14 | 15 | 16 | 17 | # Training 18 | For training DART on Domain Generalization task: 19 | ``` 20 | python train_all.py [name_of_exp] --data_dir ./path/to/data --algorithm ERM --dataset PACS --inter_freq 1000 --steps 10001 21 | ``` 22 | ### Combine with SWAD 23 | set `swad: True` in config.yaml file or pass `--swad True` in the python command. 24 | ### Changing Model & Hyperparams 25 | Similarly, to change the model (eg- VIT), swad hyperparameters or MIRO hyperparams, you can update ```config.yaml``` file or pass it as argument in the python command. 26 | ``` 27 | python train_all.py [name_of_exp] --data_dir ./path/to/data \ 28 | --lr 3e-5 \ 29 | --inter_freq 600 \ 30 | --steps 8001 \ 31 | --dataset OfficeHome \ 32 | --algorithm MIRO \ 33 | --ld 0.1 \ 34 | --weight_decay 1e-6 \ 35 | --swad True \ 36 | --model clip_vit-b16 37 | ``` 38 | 39 | # Results 40 | ## In-Domain Generalization of DART: 41 |

42 | 43 |

44 | 45 | ## Domain Generalization of DART: 46 |

47 | 48 |

49 | 50 | ## Combining DART with other DG methods on Office-Home: 51 |

52 | 53 |

54 | 55 | 56 | 57 | # Citing this work 58 | ``` 59 | @inproceedings{jain2023dart, 60 | title={DART: Diversify-Aggregate-Repeat Training Improves Generalization of Neural Networks}, 61 | author={Jain, Samyak and Addepalli, Sravanti and Sahu, Pawan Kumar and Dey, Priyam and Babu, R Venkatesh}, 62 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 63 | pages={16048--16059}, 64 | year={2023} 65 | } 66 | ``` 67 | -------------------------------------------------------------------------------- /config.yaml: -------------------------------------------------------------------------------- 1 | # Default update config 2 | # Config order: hparams_registry -> config.yaml -> CLI 3 | swad: True # True / False 4 | swad_kwargs: 5 | n_converge: 3 6 | n_tolerance: 6 7 | tolerance_ratio: 0.3 8 | test_batchsize: 128 9 | 10 | # resnet50, resnet50_barlowtwins, resnet50_moco, clip_resnet50, clip_vit-b16, swag_regnety_16gf 11 | model: resnet50 12 | feat_layers: stem_block 13 | 14 | # MIRO params 15 | ld: 0.1 # lambda 16 | lr_mult: 10. 17 | -------------------------------------------------------------------------------- /domainbed/algorithms/__init__.py: -------------------------------------------------------------------------------- 1 | from .algorithms import * 2 | from .miro import MIRO 3 | 4 | 5 | def get_algorithm_class(algorithm_name): 6 | """Return the algorithm class with the given name.""" 7 | if algorithm_name not in globals(): 8 | raise NotImplementedError("Algorithm not found: {}".format(algorithm_name)) 9 | return globals()[algorithm_name] 10 | -------------------------------------------------------------------------------- /domainbed/algorithms/algorithms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
# All Rights Reserved

import copy
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd
import numpy as np

# import higher

from domainbed import networks
from domainbed.lib.misc import random_pairs_of_minibatches
from domainbed.optimizers import get_optimizer

from domainbed.models.resnet_mixstyle import (
    resnet18_mixstyle_L234_p0d5_a0d1,
    resnet50_mixstyle_L234_p0d5_a0d1,
)
from domainbed.models.resnet_mixstyle2 import (
    resnet18_mixstyle2_L234_p0d5_a0d1,
    resnet50_mixstyle2_L234_p0d5_a0d1,
)


def to_minibatch(x, y):
    """Pair per-domain inputs and labels into a list of (x_i, y_i) tuples."""
    minibatches = list(zip(x, y))
    return minibatches


class Algorithm(torch.nn.Module):
    """
    A subclass of Algorithm implements a domain generalization algorithm.
    Subclasses should implement the following:
    - update()
    - predict()
    """

    # Per-algorithm extra input transforms, keyed by batch-dict key.
    transforms = {}

    def __init__(self, input_shape, num_classes, num_domains, hparams):
        super(Algorithm, self).__init__()
        self.input_shape = input_shape
        self.num_classes = num_classes
        self.num_domains = num_domains
        self.hparams = hparams

    def update(self, x, y, **kwargs):
        """
        Perform one update step, given a list of (x, y) tuples for all
        environments.
        """
        raise NotImplementedError

    def predict(self, x):
        raise NotImplementedError

    def forward(self, x):
        return self.predict(x)

    def new_optimizer(self, parameters):
        """Build a fresh optimizer over `parameters` from the shared hparams."""
        optimizer = get_optimizer(
            self.hparams["optimizer"],
            parameters,
            lr=self.hparams["lr"],
            weight_decay=self.hparams["weight_decay"],
        )
        return optimizer

    def clone(self):
        """Deep-copy the algorithm, including a copy of the optimizer state."""
        clone = copy.deepcopy(self)
        clone.optimizer = self.new_optimizer(clone.network.parameters())
        clone.optimizer.load_state_dict(self.optimizer.state_dict())

        return clone


class ERM(Algorithm):
    """
    Empirical Risk Minimization (ERM)
    """

    def __init__(self, input_shape, num_classes, num_domains, hparams):
        super(ERM, self).__init__(input_shape, num_classes, num_domains, hparams)
        self.featurizer = networks.Featurizer(input_shape, self.hparams)
        self.classifier = nn.Linear(self.featurizer.n_outputs, num_classes)
        self.network = nn.Sequential(self.featurizer, self.classifier)
        self.optimizer = get_optimizer(
            hparams["optimizer"],
            self.network.parameters(),
            lr=self.hparams["lr"],
            weight_decay=self.hparams["weight_decay"],
        )

    def update(self, x, y, **kwargs):
        # Pool all domains into one batch and take a plain cross-entropy step.
        all_x = torch.cat(x)
        all_y = torch.cat(y)
        loss = F.cross_entropy(self.predict(all_x), all_y)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return {"loss": loss.item()}

    def predict(self, x):
        return self.network(x)


class Mixstyle(Algorithm):
    """MixStyle w/o domain label (random shuffle)"""

    def __init__(self, input_shape, num_classes, num_domains, hparams):
        assert input_shape[1:3] == (224, 224), "Mixstyle support R18 and R50 only"
        super().__init__(input_shape, num_classes, num_domains, hparams)
        if hparams["resnet18"]:
            network = resnet18_mixstyle_L234_p0d5_a0d1()
        else:
            network = resnet50_mixstyle_L234_p0d5_a0d1()
        self.featurizer = networks.ResNet(input_shape, self.hparams, network)

        self.classifier = nn.Linear(self.featurizer.n_outputs, num_classes)
        self.network = nn.Sequential(self.featurizer, self.classifier)
        self.optimizer = self.new_optimizer(self.network.parameters())

    def update(self, x, y, **kwargs):
        all_x = torch.cat(x)
        all_y = torch.cat(y)
        loss = F.cross_entropy(self.predict(all_x), all_y)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return {"loss": loss.item()}

    def predict(self, x):
        return self.network(x)


class ARM(ERM):
    """Adaptive Risk Minimization (ARM)"""

    def __init__(self, input_shape, num_classes, num_domains, hparams):
        original_input_shape = input_shape
        # The context net's output is appended as one extra input channel.
        input_shape = (1 + original_input_shape[0],) + original_input_shape[1:]
        super(ARM, self).__init__(input_shape, num_classes, num_domains, hparams)
        self.context_net = networks.ContextNet(original_input_shape)
        self.support_size = hparams["batch_size"]

    def predict(self, x):
        batch_size, c, h, w = x.shape
        if batch_size % self.support_size == 0:
            meta_batch_size = batch_size // self.support_size
            support_size = self.support_size
        else:
            meta_batch_size, support_size = 1, batch_size
        # Average the context over each support set, then broadcast it back.
        context = self.context_net(x)
        context = context.reshape((meta_batch_size, support_size, 1, h, w))
        context = context.mean(dim=1)
        context = torch.repeat_interleave(context, repeats=support_size, dim=0)
        x = torch.cat([x, context], dim=1)
        return self.network(x)


class SAM(ERM):
    """Sharpness-Aware Minimization
    """
    @staticmethod
    def norm(tensor_list: List[torch.Tensor], p=2):
        """Compute p-norm for tensor list"""
        return torch.cat([x.flatten() for x in tensor_list]).norm(p)

    def update(self, x, y, **kwargs):
        all_x = torch.cat(x)
        all_y = torch.cat(y)
        loss = F.cross_entropy(self.predict(all_x), all_y)

        # 1. eps(w) = rho * g(w) / g(w).norm(2)
        #           = (rho / g(w).norm(2)) * g(w)
        grad_w = autograd.grad(loss, self.network.parameters())
        scale = self.hparams["rho"] / self.norm(grad_w)
        eps = [g * scale for g in grad_w]

        # 2. w' = w + eps(w)
        with torch.no_grad():
            for p, v in zip(self.network.parameters(), eps):
                p.add_(v)

        # 3. w = w - lr * g(w')
        loss = F.cross_entropy(self.predict(all_x), all_y)

        self.optimizer.zero_grad()
        loss.backward()
        # restore original network params before stepping, so the gradient
        # taken at w' is applied to the original weights w
        with torch.no_grad():
            for p, v in zip(self.network.parameters(), eps):
                p.sub_(v)
        self.optimizer.step()

        return {"loss": loss.item()}


class AbstractDANN(Algorithm):
    """Domain-Adversarial Neural Networks (abstract class)"""

    def __init__(self, input_shape, num_classes, num_domains, hparams, conditional, class_balance):

        super(AbstractDANN, self).__init__(input_shape, num_classes, num_domains, hparams)

        self.register_buffer("update_count", torch.tensor([0]))
        self.conditional = conditional
        self.class_balance = class_balance

        # Algorithms
        self.featurizer = networks.Featurizer(input_shape, self.hparams)
        self.classifier = nn.Linear(self.featurizer.n_outputs, num_classes)
        self.discriminator = networks.MLP(self.featurizer.n_outputs, num_domains, self.hparams)
        self.class_embeddings = nn.Embedding(num_classes, self.featurizer.n_outputs)

        # Optimizers
        self.disc_opt = get_optimizer(
            hparams["optimizer"],
            (list(self.discriminator.parameters()) + list(self.class_embeddings.parameters())),
            lr=self.hparams["lr_d"],
            weight_decay=self.hparams["weight_decay_d"],
            betas=(self.hparams["beta1"], 0.9),
        )

        self.gen_opt = get_optimizer(
            hparams["optimizer"],
            (list(self.featurizer.parameters()) + list(self.classifier.parameters())),
            lr=self.hparams["lr_g"],
            weight_decay=self.hparams["weight_decay_g"],
            betas=(self.hparams["beta1"], 0.9),
        )

    def update(self, x, y, **kwargs):
        self.update_count += 1
        all_x = torch.cat(x)
        all_y = torch.cat(y)
        minibatches = to_minibatch(x, y)
        all_z = self.featurizer(all_x)
        if self.conditional:
            disc_input = all_z + self.class_embeddings(all_y)
        else:
            disc_input = all_z
        disc_out = self.discriminator(disc_input)
        # Domain index is the discriminator's target label.
        disc_labels = torch.cat(
            [
                torch.full((x.shape[0],), i, dtype=torch.int64, device="cuda")
                for i, (x, y) in enumerate(minibatches)
            ]
        )

        if self.class_balance:
            y_counts = F.one_hot(all_y).sum(dim=0)
            weights = 1.0 / (y_counts[all_y] * y_counts.shape[0]).float()
            disc_loss = F.cross_entropy(disc_out, disc_labels, reduction="none")
            disc_loss = (weights * disc_loss).sum()
        else:
            disc_loss = F.cross_entropy(disc_out, disc_labels)

        # NOTE(review): indexing softmax with `[:, disc_labels]` selects a full
        # (N, N) block rather than per-sample entries; this matches upstream
        # DomainBed, so it is kept as-is — confirm against the reference impl.
        disc_softmax = F.softmax(disc_out, dim=1)
        input_grad = autograd.grad(
            disc_softmax[:, disc_labels].sum(), [disc_input], create_graph=True
        )[0]
        grad_penalty = (input_grad ** 2).sum(dim=1).mean(dim=0)
        disc_loss += self.hparams["grad_penalty"] * grad_penalty

        # Alternate d discriminator steps per generator step.
        d_steps_per_g = self.hparams["d_steps_per_g_step"]
        if self.update_count.item() % (1 + d_steps_per_g) < d_steps_per_g:

            self.disc_opt.zero_grad()
            disc_loss.backward()
            self.disc_opt.step()
            return {"disc_loss": disc_loss.item()}
        else:
            all_preds = self.classifier(all_z)
            classifier_loss = F.cross_entropy(all_preds, all_y)
            gen_loss = classifier_loss + (self.hparams["lambda"] * -disc_loss)
            self.disc_opt.zero_grad()
            self.gen_opt.zero_grad()
            gen_loss.backward()
            self.gen_opt.step()
            return {"gen_loss": gen_loss.item()}

    def predict(self, x):
        return self.classifier(self.featurizer(x))


class DANN(AbstractDANN):
    """Unconditional DANN"""

    def __init__(self, input_shape, num_classes, num_domains, hparams):
        super(DANN, self).__init__(
            input_shape,
            num_classes,
            num_domains,
            hparams,
            conditional=False,
            class_balance=False,
        )


class CDANN(AbstractDANN):
    """Conditional DANN"""

    def __init__(self, input_shape, num_classes, num_domains, hparams):
        super(CDANN, self).__init__(
            input_shape,
            num_classes,
            num_domains,
            hparams,
            conditional=True,
            class_balance=True,
        )


class OrgMixup(ERM):
    """
    Original Mixup independent with domains
    """

    def update(self, x, y, **kwargs):
        x = torch.cat(x)
        y = torch.cat(y)

        indices = torch.randperm(x.size(0))
        x2 = x[indices]
        y2 = y[indices]

        lam = np.random.beta(self.hparams["mixup_alpha"], self.hparams["mixup_alpha"])

        x = lam * x + (1 - lam) * x2
        predictions = self.predict(x)

        objective = lam * F.cross_entropy(predictions, y)
        objective += (1 - lam) * F.cross_entropy(predictions, y2)

        self.optimizer.zero_grad()
        objective.backward()
        self.optimizer.step()

        return {"loss": objective.item()}


class CutMix(ERM):
    @staticmethod
    def rand_bbox(size, lam):
        """Sample a random box covering (1 - lam) of the area of `size`.

        Args:
            size: input tensor shape (B, C, W, H).
            lam: mixing coefficient in [0, 1].
        Returns:
            (bbx1, bby1, bbx2, bby2) box corners, clipped to the image.
        """
        W = size[2]
        H = size[3]
        cut_rat = np.sqrt(1.0 - lam)
        # FIX: `np.int` was deprecated in NumPy 1.20 and removed in 1.24;
        # use the builtin `int` (same truncation semantics).
        cut_w = int(W * cut_rat)
        cut_h = int(H * cut_rat)

        # uniform
        cx = np.random.randint(W)
        cy = np.random.randint(H)

        bbx1 = np.clip(cx - cut_w // 2, 0, W)
        bby1 = np.clip(cy - cut_h // 2, 0, H)
        bbx2 = np.clip(cx + cut_w // 2, 0, W)
        bby2 = np.clip(cy + cut_h // 2, 0, H)

        return bbx1, bby1, bbx2, bby2

    def update(self, x, y, **kwargs):
        # cutmix_prob is set to 1.0 for ImageNet and 0.5 for CIFAR100 in the original paper.
        x = torch.cat(x)
        y = torch.cat(y)

        r = np.random.rand(1)
        if self.hparams["beta"] > 0 and r < self.hparams["cutmix_prob"]:
            # generate mixed sample
            beta = self.hparams["beta"]
            lam = np.random.beta(beta, beta)
            rand_index = torch.randperm(x.size()[0]).cuda()
            target_a = y
            target_b = y[rand_index]
            bbx1, bby1, bbx2, bby2 = self.rand_bbox(x.size(), lam)
            x[:, :, bbx1:bbx2, bby1:bby2] = x[rand_index, :, bbx1:bbx2, bby1:bby2]
            # adjust lambda to exactly match pixel ratio
            lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (x.size()[-1] * x.size()[-2]))
            # compute output
            output = self.predict(x)
            objective = F.cross_entropy(output, target_a) * lam + F.cross_entropy(
                output, target_b
            ) * (1.0 - lam)
        else:
            output = self.predict(x)
            objective = F.cross_entropy(output, y)

        self.optimizer.zero_grad()
        objective.backward()
        self.optimizer.step()

        return {"loss": objective.item()}


class SagNet(Algorithm):
    """
    Style Agnostic Network
    Algorithm 1 from: https://arxiv.org/abs/1910.11645
    """

    def __init__(self, input_shape, num_classes, num_domains, hparams):
        super(SagNet, self).__init__(input_shape, num_classes, num_domains, hparams)
        # featurizer network
        self.network_f = networks.Featurizer(input_shape, self.hparams)
        # content network
        self.network_c = nn.Linear(self.network_f.n_outputs, num_classes)
        # style network
        self.network_s = nn.Linear(self.network_f.n_outputs, num_classes)

        # NOTE: the original paper uses a ResNet-specific split (shared stem +
        # separate layer4 heads for content/style); that variant is omitted
        # here to keep the comparison fair across backbone-agnostic algorithms.

        def opt(p):
            return get_optimizer(
                hparams["optimizer"], p, lr=hparams["lr"], weight_decay=hparams["weight_decay"]
            )

        self.optimizer_f = opt(self.network_f.parameters())
        self.optimizer_c = opt(self.network_c.parameters())
        self.optimizer_s = opt(self.network_s.parameters())
        self.weight_adv = hparams["sag_w_adv"]

    def forward_c(self, x):
        # learning content network on randomized style
        return self.network_c(self.randomize(self.network_f(x), "style"))

    def forward_s(self, x):
        # learning style network on randomized content
        return self.network_s(self.randomize(self.network_f(x), "content"))

    def randomize(self, x, what="style", eps=1e-5):
        """Instance-normalize x, then mix ("style") or swap ("content")
        the per-sample statistics across a random permutation of the batch."""
        sizes = x.size()
        alpha = torch.rand(sizes[0], 1).cuda()

        if len(sizes) == 4:
            x = x.view(sizes[0], sizes[1], -1)
            alpha = alpha.unsqueeze(-1)

        mean = x.mean(-1, keepdim=True)
        var = x.var(-1, keepdim=True)

        x = (x - mean) / (var + eps).sqrt()

        idx_swap = torch.randperm(sizes[0])
        if what == "style":
            mean = alpha * mean + (1 - alpha) * mean[idx_swap]
            var = alpha * var + (1 - alpha) * var[idx_swap]
        else:
            x = x[idx_swap].detach()

        x = x * (var + eps).sqrt() + mean
        return x.view(*sizes)

    def update(self, x, y, **kwargs):
        all_x = torch.cat(x)
        all_y = torch.cat(y)

        # learn content
        self.optimizer_f.zero_grad()
        self.optimizer_c.zero_grad()
        loss_c = F.cross_entropy(self.forward_c(all_x), all_y)
        loss_c.backward()
        self.optimizer_f.step()
        self.optimizer_c.step()

        # learn style
        self.optimizer_s.zero_grad()
        loss_s = F.cross_entropy(self.forward_s(all_x), all_y)
        loss_s.backward()
        self.optimizer_s.step()

        # learn adversary
        self.optimizer_f.zero_grad()
        loss_adv = -F.log_softmax(self.forward_s(all_x), dim=1).mean(1).mean()
        loss_adv = loss_adv * self.weight_adv
        loss_adv.backward()
        self.optimizer_f.step()

        return {
            "loss_c": loss_c.item(),
            "loss_s": loss_s.item(),
            "loss_adv": loss_adv.item(),
        }

    def predict(self, x):
        return self.network_c(self.network_f(x))


class RSC(ERM):
    def __init__(self, input_shape, num_classes, num_domains, hparams):
        super(RSC, self).__init__(input_shape, num_classes, num_domains, hparams)
        # Percentile thresholds (in [0, 100]) for feature/batch muting.
        self.drop_f = (1 - hparams["rsc_f_drop_factor"]) * 100
        self.drop_b = (1 - hparams["rsc_b_drop_factor"]) * 100
        self.num_classes = num_classes

    def update(self, x, y, **kwargs):
        # inputs
        all_x = torch.cat(x)
        # labels
        all_y = torch.cat(y)
        # one-hot labels
        all_o = torch.nn.functional.one_hot(all_y, self.num_classes)
        # features
        all_f = self.featurizer(all_x)
        # predictions
        all_p = self.classifier(all_f)

        # Equation (1): compute gradients with respect to representation
        all_g = autograd.grad((all_p * all_o).sum(), all_f)[0]

        # Equation (2): compute top-gradient-percentile mask
        percentiles = np.percentile(all_g.cpu(), self.drop_f, axis=1)
        percentiles = torch.Tensor(percentiles)
        percentiles = percentiles.unsqueeze(1).repeat(1, all_g.size(1))
        mask_f = all_g.lt(percentiles.cuda()).float()

        # Equation (3): mute top-gradient-percentile activations
        all_f_muted = all_f * mask_f

        # Equation (4): compute muted predictions
        all_p_muted = self.classifier(all_f_muted)

        # Section 3.3: Batch Percentage
        all_s = F.softmax(all_p, dim=1)
        all_s_muted = F.softmax(all_p_muted, dim=1)
        changes = (all_s * all_o).sum(1) - (all_s_muted * all_o).sum(1)
        percentile = np.percentile(changes.detach().cpu(), self.drop_b)
        mask_b = changes.lt(percentile).float().view(-1, 1)
        mask = torch.logical_or(mask_f, mask_b).float()

        # Equations (3) and (4) again, this time mutting over examples
        all_p_muted_again = self.classifier(all_f * mask)

        # Equation (5): update
        loss = F.cross_entropy(all_p_muted_again, all_y)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return {"loss": loss.item()}

# --------------------------------------------------------------------------
# /domainbed/algorithms/miro.py:
# --------------------------------------------------------------------------
# Copyright (c) Kakao Brain. All Rights Reserved.
2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from domainbed.optimizers import get_optimizer 8 | from domainbed.networks.ur_networks import URFeaturizer 9 | from domainbed.lib import misc 10 | from domainbed.algorithms import Algorithm 11 | 12 | 13 | class ForwardModel(nn.Module): 14 | """Forward model is used to reduce gpu memory usage of SWAD. 15 | """ 16 | def __init__(self, network): 17 | super().__init__() 18 | self.network = network 19 | 20 | def forward(self, x): 21 | return self.predict(x) 22 | 23 | def predict(self, x): 24 | return self.network(x) 25 | 26 | 27 | class MeanEncoder(nn.Module): 28 | """Identity function""" 29 | def __init__(self, shape): 30 | super().__init__() 31 | self.shape = shape 32 | 33 | def forward(self, x): 34 | return x 35 | 36 | 37 | class VarianceEncoder(nn.Module): 38 | """Bias-only model with diagonal covariance""" 39 | def __init__(self, shape, init=0.1, channelwise=True, eps=1e-5): 40 | super().__init__() 41 | self.shape = shape 42 | self.eps = eps 43 | 44 | init = (torch.as_tensor(init - eps).exp() - 1.0).log() 45 | b_shape = shape 46 | if channelwise: 47 | if len(shape) == 4: 48 | # [B, C, H, W] 49 | b_shape = (1, shape[1], 1, 1) 50 | elif len(shape ) == 3: 51 | # CLIP-ViT: [H*W+1, B, C] 52 | b_shape = (1, 1, shape[2]) 53 | else: 54 | raise ValueError() 55 | 56 | self.b = nn.Parameter(torch.full(b_shape, init)) 57 | 58 | def forward(self, x): 59 | return F.softplus(self.b) + self.eps 60 | 61 | 62 | def get_shapes(model, input_shape): 63 | # get shape of intermediate features 64 | with torch.no_grad(): 65 | dummy = torch.rand(1, *input_shape).to(next(model.parameters()).device) 66 | _, feats = model(dummy, ret_feats=True) 67 | shapes = [f.shape for f in feats] 68 | 69 | return shapes 70 | 71 | 72 | class MIRO(Algorithm): 73 | """Mutual-Information Regularization with Oracle""" 74 | def __init__(self, input_shape, num_classes, num_domains, hparams, **kwargs): 75 | 
super().__init__(input_shape, num_classes, num_domains, hparams) 76 | self.pre_featurizer = URFeaturizer( 77 | input_shape, self.hparams, freeze="all", feat_layers=hparams.feat_layers 78 | ) 79 | self.featurizer = URFeaturizer( 80 | input_shape, self.hparams, feat_layers=hparams.feat_layers 81 | ) 82 | self.classifier = nn.Linear(self.featurizer.n_outputs, num_classes) 83 | self.network = nn.Sequential(self.featurizer, self.classifier) 84 | self.ld = hparams.ld 85 | 86 | # build mean/var encoders 87 | shapes = get_shapes(self.pre_featurizer, self.input_shape) 88 | self.mean_encoders = nn.ModuleList([ 89 | MeanEncoder(shape) for shape in shapes 90 | ]) 91 | self.var_encoders = nn.ModuleList([ 92 | VarianceEncoder(shape) for shape in shapes 93 | ]) 94 | 95 | # optimizer 96 | parameters = [ 97 | {"params": self.network.parameters()}, 98 | {"params": self.mean_encoders.parameters(), "lr": hparams.lr * hparams.lr_mult}, 99 | {"params": self.var_encoders.parameters(), "lr": hparams.lr * hparams.lr_mult}, 100 | ] 101 | self.optimizer = get_optimizer( 102 | hparams["optimizer"], 103 | parameters, 104 | lr=self.hparams["lr"], 105 | weight_decay=self.hparams["weight_decay"], 106 | ) 107 | 108 | def update(self, x, y, **kwargs): 109 | all_x = torch.cat(x) 110 | all_y = torch.cat(y) 111 | feat, inter_feats = self.featurizer(all_x, ret_feats=True) 112 | logit = self.classifier(feat) 113 | loss = F.cross_entropy(logit, all_y) 114 | 115 | # MIRO 116 | with torch.no_grad(): 117 | _, pre_feats = self.pre_featurizer(all_x, ret_feats=True) 118 | 119 | reg_loss = 0. 120 | for f, pre_f, mean_enc, var_enc in misc.zip_strict( 121 | inter_feats, pre_feats, self.mean_encoders, self.var_encoders 122 | ): 123 | # mutual information regularization 124 | mean = mean_enc(f) 125 | var = var_enc(f) 126 | vlb = (mean - pre_f).pow(2).div(var) + var.log() 127 | reg_loss += vlb.mean() / 2. 
128 | 129 | loss += reg_loss * self.ld 130 | 131 | self.optimizer.zero_grad() 132 | loss.backward() 133 | self.optimizer.step() 134 | 135 | return {"loss": loss.item(), "reg_loss": reg_loss.item()} 136 | 137 | def predict(self, x): 138 | return self.network(x) 139 | 140 | def get_forward_model(self): 141 | forward_model = ForwardModel(self.network) 142 | return forward_model 143 | -------------------------------------------------------------------------------- /domainbed/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from domainbed.datasets import datasets 5 | from domainbed.lib import misc 6 | from domainbed.datasets import transforms as DBT 7 | 8 | 9 | def set_transfroms(dset, data_type, hparams, algorithm_class=None): 10 | """ 11 | Args: 12 | data_type: ['train', 'valid', 'test', 'mnist'] 13 | """ 14 | assert hparams["data_augmentation"] 15 | 16 | additional_data = False 17 | if data_type == "train": 18 | dset.transforms = {"x": DBT.aug} 19 | additional_data = True 20 | elif data_type == "valid": 21 | if hparams["val_augment"] is False: 22 | dset.transforms = {"x": DBT.basic} 23 | else: 24 | # Originally, DomainBed use same training augmentation policy to validation. 25 | # We turn off the augmentation for validation as default, 26 | # but left the option to reproducibility. 
27 | dset.transforms = {"x": DBT.aug} 28 | elif data_type == "test": 29 | dset.transforms = {"x": DBT.basic} 30 | elif data_type == "mnist": 31 | # No augmentation for mnist 32 | dset.transforms = {"x": lambda x: x} 33 | else: 34 | raise ValueError(data_type) 35 | 36 | if additional_data and algorithm_class is not None: 37 | for key, transform in algorithm_class.transforms.items(): 38 | dset.transforms[key] = transform 39 | 40 | 41 | def get_dataset(test_envs, args, hparams, algorithm_class=None): 42 | """Get dataset and split.""" 43 | is_mnist = "MNIST" in args.dataset 44 | dataset = vars(datasets)[args.dataset](args.data_dir) 45 | # if not isinstance(dataset, MultipleEnvironmentImageFolder): 46 | # raise ValueError("SMALL image datasets are not implemented (corrupted), for transform.") 47 | 48 | in_splits = [] 49 | out_splits = [] 50 | for env_i, env in enumerate(dataset): 51 | # The split only depends on seed_hash (= trial_seed). 52 | # It means that the split is always identical only if use same trial_seed, 53 | # independent to run the code where, when, or how many times. 
54 | out, in_ = split_dataset( 55 | env, 56 | int(len(env) * args.holdout_fraction), 57 | misc.seed_hash(args.trial_seed, env_i), 58 | ) 59 | if env_i in test_envs: 60 | in_type = "test" 61 | out_type = "test" 62 | else: 63 | in_type = "train" 64 | out_type = "valid" 65 | 66 | if is_mnist: 67 | in_type = "mnist" 68 | out_type = "mnist" 69 | 70 | set_transfroms(in_, in_type, hparams, algorithm_class) 71 | set_transfroms(out, out_type, hparams, algorithm_class) 72 | 73 | if hparams["class_balanced"]: 74 | in_weights = misc.make_weights_for_balanced_classes(in_) 75 | out_weights = misc.make_weights_for_balanced_classes(out) 76 | else: 77 | in_weights, out_weights = None, None 78 | in_splits.append((in_, in_weights)) 79 | out_splits.append((out, out_weights)) 80 | 81 | return dataset, in_splits, out_splits 82 | 83 | 84 | class _SplitDataset(torch.utils.data.Dataset): 85 | """Used by split_dataset""" 86 | 87 | def __init__(self, underlying_dataset, keys): 88 | super(_SplitDataset, self).__init__() 89 | self.underlying_dataset = underlying_dataset 90 | self.keys = keys 91 | self.transforms = {} 92 | 93 | self.direct_return = isinstance(underlying_dataset, _SplitDataset) 94 | 95 | def __getitem__(self, key): 96 | if self.direct_return: 97 | return self.underlying_dataset[self.keys[key]] 98 | 99 | x, y = self.underlying_dataset[self.keys[key]] 100 | ret = {"y": y} 101 | 102 | for key, transform in self.transforms.items(): 103 | ret[key] = transform(x) 104 | 105 | return ret 106 | 107 | def __len__(self): 108 | return len(self.keys) 109 | 110 | 111 | def split_dataset(dataset, n, seed=0): 112 | """ 113 | Return a pair of datasets corresponding to a random split of the given 114 | dataset, with n datapoints in the first dataset and the rest in the last, 115 | using the given random seed 116 | """ 117 | assert n <= len(dataset) 118 | keys = list(range(len(dataset))) 119 | np.random.RandomState(seed).shuffle(keys) 120 | keys_1 = keys[:n] 121 | keys_2 = keys[n:] 122 | return 
_SplitDataset(dataset, keys_1), _SplitDataset(dataset, keys_2) 123 | -------------------------------------------------------------------------------- /domainbed/datasets/datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | import os 4 | import torch 5 | from PIL import Image, ImageFile 6 | from torchvision import transforms as T 7 | from torch.utils.data import TensorDataset 8 | from torchvision.datasets import MNIST, ImageFolder 9 | from torchvision.transforms.functional import rotate 10 | 11 | ImageFile.LOAD_TRUNCATED_IMAGES = True 12 | 13 | DATASETS = [ 14 | # Debug 15 | "Debug28", 16 | "Debug224", 17 | # Small images 18 | "ColoredMNIST", 19 | "RotatedMNIST", 20 | # Big images 21 | "VLCS", 22 | "PACS", 23 | "OfficeHome", 24 | "TerraIncognita", 25 | "DomainNet", 26 | ] 27 | 28 | 29 | def get_dataset_class(dataset_name): 30 | """Return the dataset class with the given name.""" 31 | if dataset_name not in globals(): 32 | raise NotImplementedError("Dataset not found: {}".format(dataset_name)) 33 | return globals()[dataset_name] 34 | 35 | 36 | def num_environments(dataset_name): 37 | return len(get_dataset_class(dataset_name).ENVIRONMENTS) 38 | 39 | 40 | class MultipleDomainDataset: 41 | N_STEPS = 5001 # Default, subclasses may override 42 | CHECKPOINT_FREQ = 100 # Default, subclasses may override 43 | N_WORKERS = 4 # Default, subclasses may override 44 | ENVIRONMENTS = None # Subclasses should override 45 | INPUT_SHAPE = None # Subclasses should override 46 | 47 | def __getitem__(self, index): 48 | """ 49 | Return: sub-dataset for specific domain 50 | """ 51 | return self.datasets[index] 52 | 53 | def __len__(self): 54 | """ 55 | Return: # of sub-datasets 56 | """ 57 | return len(self.datasets) 58 | 59 | 60 | class Debug(MultipleDomainDataset): 61 | def __init__(self, root): 62 | super().__init__() 63 | self.input_shape = self.INPUT_SHAPE 64 | 
self.num_classes = 2 65 | self.datasets = [] 66 | for _ in [0, 1, 2]: 67 | self.datasets.append( 68 | TensorDataset( 69 | torch.randn(16, *self.INPUT_SHAPE), 70 | torch.randint(0, self.num_classes, (16,)), 71 | ) 72 | ) 73 | 74 | 75 | class Debug28(Debug): 76 | INPUT_SHAPE = (3, 28, 28) 77 | ENVIRONMENTS = ["0", "1", "2"] 78 | 79 | 80 | class Debug224(Debug): 81 | INPUT_SHAPE = (3, 224, 224) 82 | ENVIRONMENTS = ["0", "1", "2"] 83 | 84 | 85 | class MultipleEnvironmentMNIST(MultipleDomainDataset): 86 | def __init__(self, root, environments, dataset_transform, input_shape, num_classes): 87 | """ 88 | Args: 89 | root: root dir for saving MNIST dataset 90 | environments: env properties for each dataset 91 | dataset_transform: dataset generator function 92 | """ 93 | super().__init__() 94 | if root is None: 95 | raise ValueError("Data directory not specified!") 96 | 97 | original_dataset_tr = MNIST(root, train=True, download=True) 98 | original_dataset_te = MNIST(root, train=False, download=True) 99 | 100 | original_images = torch.cat((original_dataset_tr.data, original_dataset_te.data)) 101 | 102 | original_labels = torch.cat((original_dataset_tr.targets, original_dataset_te.targets)) 103 | 104 | shuffle = torch.randperm(len(original_images)) 105 | 106 | original_images = original_images[shuffle] 107 | original_labels = original_labels[shuffle] 108 | 109 | self.datasets = [] 110 | self.environments = environments 111 | 112 | for i in range(len(environments)): 113 | images = original_images[i :: len(environments)] 114 | labels = original_labels[i :: len(environments)] 115 | self.datasets.append(dataset_transform(images, labels, environments[i])) 116 | 117 | self.input_shape = input_shape 118 | self.num_classes = num_classes 119 | 120 | 121 | class ColoredMNIST(MultipleEnvironmentMNIST): 122 | ENVIRONMENTS = ["+90%", "+80%", "-90%"] 123 | 124 | def __init__(self, root): 125 | super(ColoredMNIST, self).__init__( 126 | root, 127 | [0.1, 0.2, 0.9], 128 | self.color_dataset, 
129 | (2, 28, 28), 130 | 2, 131 | ) 132 | 133 | def color_dataset(self, images, labels, environment): 134 | # # Subsample 2x for computational convenience 135 | # images = images.reshape((-1, 28, 28))[:, ::2, ::2] 136 | # Assign a binary label based on the digit 137 | labels = (labels < 5).float() 138 | # Flip label with probability 0.25 139 | labels = self.torch_xor_(labels, self.torch_bernoulli_(0.25, len(labels))) 140 | 141 | # Assign a color based on the label; flip the color with probability e 142 | colors = self.torch_xor_(labels, self.torch_bernoulli_(environment, len(labels))) 143 | images = torch.stack([images, images], dim=1) 144 | # Apply the color to the image by zeroing out the other color channel 145 | images[torch.tensor(range(len(images))), (1 - colors).long(), :, :] *= 0 146 | 147 | x = images.float().div_(255.0) 148 | y = labels.view(-1).long() 149 | 150 | return TensorDataset(x, y) 151 | 152 | def torch_bernoulli_(self, p, size): 153 | return (torch.rand(size) < p).float() 154 | 155 | def torch_xor_(self, a, b): 156 | return (a - b).abs() 157 | 158 | 159 | class RotatedMNIST(MultipleEnvironmentMNIST): 160 | ENVIRONMENTS = ["0", "15", "30", "45", "60", "75"] 161 | 162 | def __init__(self, root): 163 | super(RotatedMNIST, self).__init__( 164 | root, 165 | [0, 15, 30, 45, 60, 75], 166 | self.rotate_dataset, 167 | (1, 28, 28), 168 | 10, 169 | ) 170 | 171 | def rotate_dataset(self, images, labels, angle): 172 | rotation = T.Compose( 173 | [ 174 | T.ToPILImage(), 175 | T.Lambda(lambda x: rotate(x, angle, fill=(0,), resample=Image.BICUBIC)), 176 | T.ToTensor(), 177 | ] 178 | ) 179 | 180 | x = torch.zeros(len(images), 1, 28, 28) 181 | for i in range(len(images)): 182 | x[i] = rotation(images[i]) 183 | 184 | y = labels.view(-1) 185 | 186 | return TensorDataset(x, y) 187 | 188 | 189 | class MultipleEnvironmentImageFolder(MultipleDomainDataset): 190 | def __init__(self, root): 191 | super().__init__() 192 | environments = [f.name for f in os.scandir(root) 
if f.is_dir()] 193 | environments = sorted(environments) 194 | self.environments = environments 195 | 196 | self.datasets = [] 197 | for environment in environments: 198 | path = os.path.join(root, environment) 199 | env_dataset = ImageFolder(path) 200 | 201 | self.datasets.append(env_dataset) 202 | 203 | self.input_shape = (3, 224, 224) 204 | self.num_classes = len(self.datasets[-1].classes) 205 | 206 | 207 | class VLCS(MultipleEnvironmentImageFolder): 208 | CHECKPOINT_FREQ = 200 209 | ENVIRONMENTS = ["C", "L", "S", "V"] 210 | 211 | def __init__(self, root): 212 | self.dir = os.path.join(root, "VLCS/") 213 | super().__init__(self.dir) 214 | 215 | 216 | class PACS(MultipleEnvironmentImageFolder): 217 | CHECKPOINT_FREQ = 200 218 | ENVIRONMENTS = ["A", "C", "P", "S"] 219 | 220 | def __init__(self, root): 221 | self.dir = os.path.join(root, "PACS/") 222 | super().__init__(self.dir) 223 | 224 | 225 | class DomainNet(MultipleEnvironmentImageFolder): 226 | CHECKPOINT_FREQ = 1000 227 | N_STEPS = 15001 228 | ENVIRONMENTS = ["clip", "info", "paint", "quick", "real", "sketch"] 229 | 230 | def __init__(self, root): 231 | self.dir = os.path.join(root, "domain_net/") 232 | super().__init__(self.dir) 233 | 234 | 235 | class OfficeHome(MultipleEnvironmentImageFolder): 236 | CHECKPOINT_FREQ = 200 237 | ENVIRONMENTS = ["A", "C", "P", "R"] 238 | 239 | def __init__(self, root): 240 | self.dir = os.path.join(root, "office_home/") 241 | super().__init__(self.dir) 242 | 243 | 244 | class TerraIncognita(MultipleEnvironmentImageFolder): 245 | CHECKPOINT_FREQ = 200 246 | ENVIRONMENTS = ["L100", "L38", "L43", "L46"] 247 | 248 | def __init__(self, root): 249 | self.dir = os.path.join(root, "terra_incognita/") 250 | super().__init__(self.dir) 251 | -------------------------------------------------------------------------------- /domainbed/datasets/transforms.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms as T 2 | 3 | 
4 | basic = T.Compose( 5 | [ 6 | T.Resize((224, 224)), 7 | T.ToTensor(), 8 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 9 | ] 10 | ) 11 | aug = T.Compose( 12 | [ 13 | T.RandomResizedCrop(224, scale=(0.7, 1.0)), 14 | T.RandomHorizontalFlip(), 15 | T.ColorJitter(0.3, 0.3, 0.3, 0.3), 16 | T.RandomGrayscale(p=0.1), 17 | T.ToTensor(), 18 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 19 | ] 20 | ) 21 | -------------------------------------------------------------------------------- /domainbed/evaluator.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from domainbed.lib.fast_data_loader import FastDataLoader 6 | 7 | if torch.cuda.is_available(): 8 | device = "cuda" 9 | else: 10 | device = "cpu" 11 | 12 | 13 | def accuracy_from_loader(algorithm, loader, weights, debug=False): 14 | correct = 0 15 | total = 0 16 | losssum = 0.0 17 | weights_offset = 0 18 | 19 | algorithm.eval() 20 | 21 | for i, batch in enumerate(loader): 22 | x = batch["x"].to(device) 23 | y = batch["y"].to(device) 24 | 25 | with torch.no_grad(): 26 | logits = algorithm.predict(x) 27 | loss = F.cross_entropy(logits, y).item() 28 | 29 | B = len(x) 30 | losssum += loss * B 31 | 32 | if weights is None: 33 | batch_weights = torch.ones(len(x)) 34 | else: 35 | batch_weights = weights[weights_offset : weights_offset + len(x)] 36 | weights_offset += len(x) 37 | batch_weights = batch_weights.to(device) 38 | if logits.size(1) == 1: 39 | correct += (logits.gt(0).eq(y).float() * batch_weights).sum().item() 40 | else: 41 | correct += (logits.argmax(1).eq(y).float() * batch_weights).sum().item() 42 | total += batch_weights.sum().item() 43 | 44 | if debug: 45 | break 46 | 47 | algorithm.train() 48 | 49 | acc = correct / total 50 | loss = losssum / total 51 | return acc, loss 52 | 53 | 54 | def accuracy(algorithm, loader_kwargs, 
weights, **kwargs): 55 | if isinstance(loader_kwargs, dict): 56 | loader = FastDataLoader(**loader_kwargs) 57 | elif isinstance(loader_kwargs, FastDataLoader): 58 | loader = loader_kwargs 59 | else: 60 | raise ValueError(loader_kwargs) 61 | return accuracy_from_loader(algorithm, loader, weights, **kwargs) 62 | 63 | 64 | class Evaluator: 65 | def __init__( 66 | self, test_envs, eval_meta, n_envs, logger, evalmode="fast", debug=False, target_env=None 67 | ): 68 | all_envs = list(range(n_envs)) 69 | train_envs = sorted(set(all_envs) - set(test_envs)) 70 | self.test_envs = test_envs 71 | self.train_envs = train_envs 72 | self.eval_meta = eval_meta 73 | self.n_envs = n_envs 74 | self.logger = logger 75 | self.evalmode = evalmode 76 | self.debug = debug 77 | 78 | if target_env is not None: 79 | self.set_target_env(target_env) 80 | 81 | def set_target_env(self, target_env): 82 | """When len(test_envs) == 2, you can specify target env for computing exact test acc.""" 83 | self.test_envs = [target_env] 84 | 85 | 86 | def evaluate(self, algorithm, suffix): 87 | n_train_envs = len(self.train_envs) 88 | n_test_envs = len(self.test_envs) 89 | assert n_test_envs == 1 90 | summaries = collections.defaultdict(float) 91 | # for key order 92 | summaries["test_in"+suffix] = 0.0 93 | summaries["test_out"+suffix] = 0.0 94 | summaries["comb_val"+suffix] = 0.0 95 | # order: in_splits + out_splits. 
96 | for name, loader_kwargs, weights in self.eval_meta: 97 | # env\d_[in|out] 98 | env_name, inout = name.split("_") 99 | env_num = int(env_name[3:]) 100 | 101 | skip_eval = self.evalmode == "fast" and inout == "in" and env_num not in self.test_envs 102 | if skip_eval: 103 | continue ## removing env_in of train envs 104 | 105 | is_test = env_num in self.test_envs 106 | acc, loss = accuracy(algorithm, loader_kwargs, weights, debug=self.debug) 107 | 108 | if env_num in self.train_envs: 109 | summaries["comb_val" + suffix] += acc / n_train_envs 110 | if inout == "out": 111 | summaries["comb_val_loss"+suffix] += loss / n_train_envs 112 | elif is_test: 113 | summaries["test_" + inout + suffix] += acc / n_test_envs 114 | 115 | return summaries -------------------------------------------------------------------------------- /domainbed/hparams_registry.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | import numpy as np 4 | 5 | 6 | def _hparams(algorithm, dataset, random_state): 7 | """ 8 | Global registry of hyperparams. Each entry is a (default, random) tuple. 9 | New algorithms / networks / etc. should add entries here. 
10 | """ 11 | SMALL_IMAGES = ["Debug28", "RotatedMNIST", "ColoredMNIST"] 12 | 13 | hparams = {} 14 | 15 | hparams["data_augmentation"] = (True, True) 16 | hparams["val_augment"] = (False, False) # augmentation for in-domain validation set 17 | hparams["resnet18"] = (False, False) 18 | hparams["resnet_dropout"] = (0.0, random_state.choice([0.0, 0.1, 0.5])) 19 | hparams["class_balanced"] = (False, False) 20 | hparams["optimizer"] = ("adam", "adam") 21 | 22 | hparams["freeze_bn"] = (True, True) 23 | hparams["pretrained"] = (True, True) # only for ResNet 24 | 25 | if dataset not in SMALL_IMAGES: 26 | hparams["lr"] = (5e-5, 10 ** random_state.uniform(-5, -3.5)) 27 | if dataset == "DomainNet": 28 | hparams["batch_size"] = (32, int(2 ** random_state.uniform(3, 5))) 29 | else: 30 | hparams["batch_size"] = (32, int(2 ** random_state.uniform(3, 5.5))) 31 | if algorithm == "ARM": 32 | hparams["batch_size"] = (8, 8) 33 | else: 34 | hparams["lr"] = (1e-3, 10 ** random_state.uniform(-4.5, -2.5)) 35 | hparams["batch_size"] = (64, int(2 ** random_state.uniform(3, 9))) 36 | 37 | if dataset in SMALL_IMAGES: 38 | hparams["weight_decay"] = (0.0, 0.0) 39 | else: 40 | hparams["weight_decay"] = (0.0, 10 ** random_state.uniform(-6, -2)) 41 | 42 | if algorithm in ["DANN", "CDANN"]: 43 | if dataset not in SMALL_IMAGES: 44 | hparams["lr_g"] = (5e-5, 10 ** random_state.uniform(-5, -3.5)) 45 | hparams["lr_d"] = (5e-5, 10 ** random_state.uniform(-5, -3.5)) 46 | else: 47 | hparams["lr_g"] = (1e-3, 10 ** random_state.uniform(-4.5, -2.5)) 48 | hparams["lr_d"] = (1e-3, 10 ** random_state.uniform(-4.5, -2.5)) 49 | 50 | if dataset in SMALL_IMAGES: 51 | hparams["weight_decay_g"] = (0.0, 0.0) 52 | else: 53 | hparams["weight_decay_g"] = (0.0, 10 ** random_state.uniform(-6, -2)) 54 | 55 | hparams["lambda"] = (1.0, 10 ** random_state.uniform(-2, 2)) 56 | hparams["weight_decay_d"] = (0.0, 10 ** random_state.uniform(-6, -2)) 57 | hparams["d_steps_per_g_step"] = (1, int(2 ** random_state.uniform(0, 3))) 58 | 
hparams["grad_penalty"] = (0.0, 10 ** random_state.uniform(-2, 1)) 59 | hparams["beta1"] = (0.5, random_state.choice([0.0, 0.5])) 60 | hparams["mlp_width"] = (256, int(2 ** random_state.uniform(6, 10))) 61 | hparams["mlp_depth"] = (3, int(random_state.choice([3, 4, 5]))) 62 | hparams["mlp_dropout"] = (0.0, random_state.choice([0.0, 0.1, 0.5])) 63 | elif algorithm == "RSC": 64 | hparams["rsc_f_drop_factor"] = (1 / 3, random_state.uniform(0, 0.5)) 65 | hparams["rsc_b_drop_factor"] = (1 / 3, random_state.uniform(0, 0.5)) 66 | elif algorithm == "SagNet": 67 | hparams["sag_w_adv"] = (0.1, 10 ** random_state.uniform(-2, 1)) 68 | elif algorithm == "IRM": 69 | hparams["irm_lambda"] = (1e2, 10 ** random_state.uniform(-1, 5)) 70 | hparams["irm_penalty_anneal_iters"] = ( 71 | 500, 72 | int(10 ** random_state.uniform(0, 4)), 73 | ) 74 | elif algorithm in ["Mixup", "OrgMixup"]: 75 | hparams["mixup_alpha"] = (0.2, 10 ** random_state.uniform(-1, -1)) 76 | elif algorithm == "GroupDRO": 77 | hparams["groupdro_eta"] = (1e-2, 10 ** random_state.uniform(-3, -1)) 78 | elif algorithm in ("MMD", "CORAL"): 79 | hparams["mmd_gamma"] = (1.0, 10 ** random_state.uniform(-1, 1)) 80 | elif algorithm in ("MLDG", "SOMLDG"): 81 | hparams["mldg_beta"] = (1.0, 10 ** random_state.uniform(-1, 1)) 82 | elif algorithm == "MTL": 83 | hparams["mtl_ema"] = (0.99, random_state.choice([0.5, 0.9, 0.99, 1.0])) 84 | elif algorithm == "VREx": 85 | hparams["vrex_lambda"] = (1e1, 10 ** random_state.uniform(-1, 5)) 86 | hparams["vrex_penalty_anneal_iters"] = ( 87 | 500, 88 | int(10 ** random_state.uniform(0, 4)), 89 | ) 90 | elif algorithm == "SAM": 91 | hparams["rho"] = (0.05, random_state.choice([0.01, 0.02, 0.05, 0.1])) 92 | elif algorithm == "CutMix": 93 | hparams["beta"] = (1.0, 1.0) 94 | # cutmix_prob is set to 1.0 for ImageNet and 0.5 for CIFAR100 in the original paper. 
95 | hparams["cutmix_prob"] = (1.0, 1.0) 96 | 97 | return hparams 98 | 99 | 100 | def default_hparams(algorithm, dataset): 101 | dummy_random_state = np.random.RandomState(0) 102 | return {a: b for a, (b, c) in _hparams(algorithm, dataset, dummy_random_state).items()} 103 | 104 | 105 | def random_hparams(algorithm, dataset, seed): 106 | random_state = np.random.RandomState(seed) 107 | return {a: c for a, (b, c) in _hparams(algorithm, dataset, random_state).items()} 108 | -------------------------------------------------------------------------------- /domainbed/lib/fast_data_loader.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | import torch 4 | 5 | 6 | class _InfiniteSampler(torch.utils.data.Sampler): 7 | """Wraps another Sampler to yield an infinite stream.""" 8 | 9 | def __init__(self, sampler): 10 | self.sampler = sampler 11 | 12 | def __iter__(self): 13 | while True: 14 | for batch in self.sampler: 15 | yield batch 16 | 17 | 18 | class InfiniteDataLoader: 19 | def __init__(self, dataset, weights, batch_size, num_workers): 20 | super().__init__() 21 | 22 | if weights: 23 | sampler = torch.utils.data.WeightedRandomSampler( 24 | weights, replacement=True, num_samples=batch_size 25 | ) 26 | else: 27 | sampler = torch.utils.data.RandomSampler(dataset, replacement=True) 28 | 29 | batch_sampler = torch.utils.data.BatchSampler( 30 | sampler, batch_size=batch_size, drop_last=True 31 | ) 32 | 33 | self._infinite_iterator = iter( 34 | torch.utils.data.DataLoader( 35 | dataset, 36 | num_workers=num_workers, 37 | batch_sampler=_InfiniteSampler(batch_sampler), 38 | ) 39 | ) 40 | 41 | def __iter__(self): 42 | while True: 43 | yield next(self._infinite_iterator) 44 | 45 | def __len__(self): 46 | raise ValueError 47 | 48 | 49 | class FastDataLoader: 50 | """ 51 | DataLoader wrapper with slightly improved speed by not respawning worker 52 | processes at every 
epoch. 53 | """ 54 | 55 | def __init__(self, dataset, batch_size, num_workers, shuffle=False): 56 | super().__init__() 57 | 58 | if shuffle: 59 | sampler = torch.utils.data.RandomSampler(dataset, replacement=False) 60 | else: 61 | sampler = torch.utils.data.SequentialSampler(dataset) 62 | 63 | batch_sampler = torch.utils.data.BatchSampler( 64 | sampler, 65 | batch_size=batch_size, 66 | drop_last=False, 67 | ) 68 | 69 | self._infinite_iterator = iter( 70 | torch.utils.data.DataLoader( 71 | dataset, 72 | num_workers=num_workers, 73 | batch_sampler=_InfiniteSampler(batch_sampler), 74 | ) 75 | ) 76 | 77 | self._length = len(batch_sampler) 78 | 79 | def __iter__(self): 80 | for _ in range(len(self)): 81 | yield next(self._infinite_iterator) 82 | 83 | def __len__(self): 84 | return self._length 85 | -------------------------------------------------------------------------------- /domainbed/lib/logger.py: -------------------------------------------------------------------------------- 1 | """ Singleton Logger """ 2 | import sys 3 | import logging 4 | 5 | 6 | def levelize(levelname): 7 | """Convert levelname to level only if it is levelname""" 8 | if isinstance(levelname, str): 9 | return logging.getLevelName(levelname) 10 | else: 11 | return levelname # already level 12 | 13 | 14 | class ColorFormatter(logging.Formatter): 15 | color_dic = { 16 | "DEBUG": 37, # white 17 | "INFO": 36, # cyan 18 | "WARNING": 33, # yellow 19 | "ERROR": 31, # red 20 | "CRITICAL": 41, # white on red bg 21 | } 22 | 23 | def format(self, record): 24 | color = self.color_dic.get(record.levelname, 37) # default white 25 | record.levelname = "\033[{}m{}\033[0m".format(color, record.levelname) 26 | return logging.Formatter.format(self, record) 27 | 28 | 29 | class Logger(logging.Logger): 30 | NAME = "SingletonLogger" 31 | 32 | @classmethod 33 | def get(cls, file_path=None, level="INFO", colorize=True, track_code=False): 34 | logging.setLoggerClass(cls) 35 | logger = logging.getLogger(cls.NAME) 36 | 
logging.setLoggerClass(logging.Logger) # restore 37 | logger.setLevel(level) 38 | 39 | if logger.hasHandlers(): 40 | # If logger already got all handlers (# handlers == 2), use the logger. 41 | # else, re-set handlers. 42 | if len(logger.handlers) == 2: 43 | return logger 44 | 45 | logger.handlers.clear() 46 | 47 | log_format = "%(levelname)s %(asctime)s | %(message)s" 48 | # log_format = '%(asctime)s | %(message)s' 49 | if track_code: 50 | log_format = ( 51 | "%(levelname)s::%(asctime)s | [%(filename)s] [%(funcName)s:%(lineno)d] " 52 | "%(message)s" 53 | ) 54 | date_format = "%m/%d %H:%M:%S" 55 | if colorize: 56 | formatter = ColorFormatter(log_format, date_format) 57 | else: 58 | formatter = logging.Formatter(log_format, date_format) 59 | 60 | # standard output handler 61 | # NOTE as default, StreamHandler use stderr stream instead of stdout stream. 62 | # Use StreamHandler(sys.stdout) for stdout stream. 63 | stream_handler = logging.StreamHandler(sys.stdout) 64 | stream_handler.setFormatter(formatter) 65 | logger.addHandler(stream_handler) 66 | 67 | if file_path: 68 | # file output handler 69 | file_handler = logging.FileHandler(file_path) 70 | file_handler.setFormatter(formatter) 71 | logger.addHandler(file_handler) 72 | 73 | logger.propagate = False 74 | 75 | return logger 76 | 77 | def nofmt(self, msg, *args, level="INFO", **kwargs): 78 | level = levelize(level) 79 | formatters = self.remove_formats() 80 | super().log(level, msg, *args, **kwargs) 81 | self.set_formats(formatters) 82 | 83 | def remove_formats(self): 84 | """Remove all formats from logger""" 85 | formatters = [] 86 | for handler in self.handlers: 87 | formatters.append(handler.formatter) 88 | handler.setFormatter(logging.Formatter("%(message)s")) 89 | 90 | return formatters 91 | 92 | def set_formats(self, formatters): 93 | """Set formats to every handler of logger""" 94 | for handler, formatter in zip(self.handlers, formatters): 95 | handler.setFormatter(formatter) 96 | 97 | def 
set_file_handler(self, file_path): 98 | file_handler = logging.FileHandler(file_path) 99 | formatter = self.handlers[0].formatter 100 | file_handler.setFormatter(formatter) 101 | self.addHandler(file_handler) 102 | -------------------------------------------------------------------------------- /domainbed/lib/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | """ 4 | Things that don't belong anywhere else 5 | """ 6 | 7 | import hashlib 8 | import sys 9 | import random 10 | import os 11 | import shutil 12 | import errno 13 | from itertools import chain 14 | from datetime import datetime 15 | from collections import Counter 16 | from typing import List 17 | from contextlib import contextmanager 18 | from subprocess import call 19 | 20 | import numpy as np 21 | import torch 22 | import torch.nn as nn 23 | import torch.nn.functional as F 24 | 25 | 26 | def make_weights_for_balanced_classes(dataset): 27 | counts = Counter() 28 | classes = [] 29 | for _, y in dataset: 30 | y = int(y) 31 | counts[y] += 1 32 | classes.append(y) 33 | 34 | n_classes = len(counts) 35 | 36 | weight_per_class = {} 37 | for y in counts: 38 | weight_per_class[y] = 1 / (counts[y] * n_classes) 39 | 40 | weights = torch.zeros(len(dataset)) 41 | for i, y in enumerate(classes): 42 | weights[i] = weight_per_class[int(y)] 43 | 44 | return weights 45 | 46 | 47 | def seed_hash(*args): 48 | """ 49 | Derive an integer hash from all args, for use as a random seed. 
50 | """ 51 | args_str = str(args) 52 | return int(hashlib.md5(args_str.encode("utf-8")).hexdigest(), 16) % (2 ** 31) 53 | 54 | 55 | def to_row(row, colwidth=10, latex=False): 56 | """Convert value list to row string""" 57 | if latex: 58 | sep = " & " 59 | end_ = "\\\\" 60 | else: 61 | sep = " " 62 | end_ = "" 63 | 64 | def format_val(x): 65 | if np.issubdtype(type(x), np.floating): 66 | x = "{:.6f}".format(x) 67 | return str(x).ljust(colwidth)[:colwidth] 68 | 69 | return sep.join([format_val(x) for x in row]) + " " + end_ 70 | 71 | 72 | def random_pairs_of_minibatches(minibatches): 73 | # n_tr_envs = len(minibatches) 74 | perm = torch.randperm(len(minibatches)).tolist() 75 | pairs = [] 76 | 77 | for i in range(len(minibatches)): 78 | # j = cyclic(i + 1) 79 | j = i + 1 if i < (len(minibatches) - 1) else 0 80 | 81 | xi, yi = minibatches[perm[i]][0], minibatches[perm[i]][1] 82 | xj, yj = minibatches[perm[j]][0], minibatches[perm[j]][1] 83 | 84 | min_n = min(len(xi), len(xj)) 85 | 86 | pairs.append(((xi[:min_n], yi[:min_n]), (xj[:min_n], yj[:min_n]))) 87 | 88 | return pairs 89 | 90 | 91 | ########################################################### 92 | # Custom utils 93 | ########################################################### 94 | 95 | 96 | def index_conditional_iterate(skip_condition, iterable, index): 97 | for i, x in enumerate(iterable): 98 | if skip_condition(i): 99 | continue 100 | 101 | if index: 102 | yield i, x 103 | else: 104 | yield x 105 | 106 | 107 | class SplitIterator: 108 | def __init__(self, test_envs): 109 | self.test_envs = test_envs 110 | 111 | def train(self, iterable, index=False): 112 | return index_conditional_iterate(lambda idx: idx in self.test_envs, iterable, index) 113 | 114 | def test(self, iterable, index=False): 115 | return index_conditional_iterate(lambda idx: idx not in self.test_envs, iterable, index) 116 | 117 | 118 | class AverageMeter(): 119 | """ Computes and stores the average and current value """ 120 | def __init__(self): 
121 | self.reset() 122 | 123 | def reset(self): 124 | """ Reset all statistics """ 125 | self.val = 0 126 | self.avg = 0 127 | self.sum = 0 128 | self.count = 0 129 | 130 | def update(self, val, n=1): 131 | """ Update statistics """ 132 | self.val = val 133 | self.sum += val * n 134 | self.count += n 135 | self.avg = self.sum / self.count 136 | 137 | def __repr__(self): 138 | return "{:.3f} (val={:.3f}, count={})".format(self.avg, self.val, self.count) 139 | 140 | 141 | class AverageMeters(): 142 | def __init__(self, *keys): 143 | self.keys = keys 144 | for k in keys: 145 | setattr(self, k, AverageMeter()) 146 | 147 | def resets(self): 148 | for k in self.keys: 149 | getattr(self, k).reset() 150 | 151 | def updates(self, dic, n=1): 152 | for k, v in dic.items(): 153 | getattr(self, k).update(v, n) 154 | 155 | def __repr__(self): 156 | return " ".join(["{}: {}".format(k, str(getattr(self, k))) for k in self.keys]) 157 | 158 | def get_averages(self): 159 | dic = {k: getattr(self, k).avg for k in self.keys} 160 | return dic 161 | 162 | 163 | def timestamp(fmt="%y%m%d_%H-%M-%S"): 164 | return datetime.now().strftime(fmt) 165 | 166 | 167 | def makedirs(path): 168 | if not os.path.exists(path): 169 | try: 170 | os.makedirs(path) 171 | except OSError as exc: 172 | if exc.errno != errno.EEXIST: 173 | raise 174 | 175 | 176 | def rm(path): 177 | """ remove dir recursively """ 178 | if os.path.isdir(path): 179 | shutil.rmtree(path, ignore_errors=True) 180 | elif os.path.exists(path): 181 | os.remove(path) 182 | 183 | 184 | def cp(src, dst): 185 | shutil.copy2(src, dst) 186 | 187 | 188 | def set_seed(seed): 189 | random.seed(seed) 190 | # os.environ['PYTHONHASHSEED'] = str(seed) 191 | np.random.seed(seed) 192 | torch.manual_seed(seed) 193 | # torch.backends.cudnn.deterministic = True 194 | torch.backends.cudnn.benchmark = True 195 | 196 | 197 | def get_lr(optimizer): 198 | """Assume that the optimizer has single lr""" 199 | lr = optimizer.param_groups[0]['lr'] 200 | 201 | 
return lr 202 | 203 | 204 | def entropy(logits): 205 | ent = F.softmax(logits, -1) * F.log_softmax(logits, -1) 206 | ent = -ent.sum(1) # batch-wise 207 | return ent.mean() 208 | 209 | 210 | @torch.no_grad() 211 | def hash_bn(module): 212 | summary = [] 213 | for m in module.modules(): 214 | if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)): 215 | w = m.weight.detach().mean().item() 216 | b = m.bias.detach().mean().item() 217 | rm = m.running_mean.detach().mean().item() 218 | rv = m.running_var.detach().mean().item() 219 | summary.append((w, b, rm, rv)) 220 | 221 | if not summary: 222 | return 0., 0. 223 | 224 | w, b, rm, rv = [np.mean(col) for col in zip(*summary)] 225 | p = np.mean([w, b]) 226 | s = np.mean([rm, rv]) 227 | 228 | return p, s 229 | 230 | 231 | @torch.no_grad() 232 | def hash_params(module): 233 | return torch.as_tensor([p.mean() for p in module.parameters()]).mean().item() 234 | 235 | 236 | @torch.no_grad() 237 | def hash_module(module): 238 | p = hash_params(module) 239 | _, s = hash_bn(module) 240 | 241 | return p, s 242 | 243 | 244 | def merge_dictlist(dictlist): 245 | """Merge list of dicts into dict of lists, by grouping same key. 246 | """ 247 | ret = { 248 | k: [] 249 | for k in dictlist[0].keys() 250 | } 251 | for dic in dictlist: 252 | for data_key, v in dic.items(): 253 | ret[data_key].append(v) 254 | return ret 255 | 256 | 257 | def zip_strict(*iterables): 258 | """strict version of zip. The length of iterables should be same. 259 | 260 | NOTE yield looks non-reachable, but they are required. 261 | """ 262 | # For trivial cases, use pure zip. 
263 | if len(iterables) < 2: 264 | return zip(*iterables) 265 | 266 | # Tail for the first iterable 267 | first_stopped = False 268 | def first_tail(): 269 | nonlocal first_stopped 270 | first_stopped = True 271 | return 272 | yield 273 | 274 | # Tail for the zip 275 | def zip_tail(): 276 | if not first_stopped: 277 | raise ValueError('zip_equal: first iterable is longer') 278 | for _ in chain.from_iterable(rest): 279 | raise ValueError('zip_equal: first iterable is shorter') 280 | yield 281 | 282 | # Put the pieces together 283 | iterables = iter(iterables) 284 | first = chain(next(iterables), first_tail()) 285 | rest = list(map(iter, iterables)) 286 | return chain(zip(first, *rest), zip_tail()) 287 | 288 | 289 | def freeze_(module): 290 | for p in module.parameters(): 291 | p.requires_grad_(False) 292 | module.eval() 293 | 294 | 295 | def unfreeze_(module): 296 | for p in module.parameters(): 297 | p.requires_grad_(True) 298 | module.train() 299 | -------------------------------------------------------------------------------- /domainbed/lib/query.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | """Small query library.""" 4 | 5 | import inspect 6 | import json 7 | import types 8 | import warnings 9 | 10 | import numpy as np 11 | 12 | 13 | def make_selector_fn(selector): 14 | """ 15 | If selector is a function, return selector. 16 | Otherwise, return a function corresponding to the selector string. Examples 17 | of valid selector strings and the corresponding functions: 18 | x lambda obj: obj['x'] 19 | x.y lambda obj: obj['x']['y'] 20 | x,y lambda obj: (obj['x'], obj['y']) 21 | """ 22 | if isinstance(selector, str): 23 | if "," in selector: 24 | parts = selector.split(",") 25 | part_selectors = [make_selector_fn(part) for part in parts] 26 | return lambda obj: tuple(sel(obj) for sel in part_selectors) 27 | elif "." 
in selector: 28 | parts = selector.split(".") 29 | part_selectors = [make_selector_fn(part) for part in parts] 30 | 31 | def f(obj): 32 | for sel in part_selectors: 33 | obj = sel(obj) 34 | return obj 35 | 36 | return f 37 | else: 38 | key = selector.strip() 39 | return lambda obj: obj[key] 40 | elif isinstance(selector, types.FunctionType): 41 | return selector 42 | else: 43 | raise TypeError 44 | 45 | 46 | def hashable(obj): 47 | try: 48 | hash(obj) 49 | return obj 50 | except TypeError: 51 | return json.dumps({"_": obj}, sort_keys=True) 52 | 53 | 54 | class Q(object): 55 | def __init__(self, list_): 56 | super(Q, self).__init__() 57 | self._list = list_ 58 | 59 | def __len__(self): 60 | return len(self._list) 61 | 62 | def __getitem__(self, key): 63 | return self._list[key] 64 | 65 | def __eq__(self, other): 66 | if isinstance(other, self.__class__): 67 | return self._list == other._list 68 | else: 69 | return self._list == other 70 | 71 | def __str__(self): 72 | return str(self._list) 73 | 74 | def __repr__(self): 75 | return repr(self._list) 76 | 77 | def _append(self, item): 78 | """Unsafe, be careful you know what you're doing.""" 79 | self._list.append(item) 80 | 81 | def group(self, selector): 82 | """ 83 | Group elements by selector and return a list of (group, group_records) 84 | tuples. 85 | """ 86 | selector = make_selector_fn(selector) 87 | groups = {} 88 | for x in self._list: 89 | group = selector(x) 90 | group_key = hashable(group) 91 | if group_key not in groups: 92 | groups[group_key] = (group, Q([])) 93 | groups[group_key][1]._append(x) 94 | results = [groups[key] for key in sorted(groups.keys())] 95 | return Q(results) 96 | 97 | def group_map(self, selector, fn): 98 | """ 99 | Group elements by selector, apply fn to each group, and return a list 100 | of the results. 101 | """ 102 | return self.group(selector).map(fn) 103 | 104 | def map(self, fn): 105 | """ 106 | map self onto fn. If fn takes multiple args, tuple-unpacking 107 | is applied. 
108 | """ 109 | if len(inspect.signature(fn).parameters) > 1: 110 | return Q([fn(*x) for x in self._list]) 111 | else: 112 | return Q([fn(x) for x in self._list]) 113 | 114 | def select(self, selector): 115 | selector = make_selector_fn(selector) 116 | return Q([selector(x) for x in self._list]) 117 | 118 | def min(self): 119 | return min(self._list) 120 | 121 | def max(self): 122 | return max(self._list) 123 | 124 | def sum(self): 125 | return sum(self._list) 126 | 127 | def len(self): 128 | return len(self._list) 129 | 130 | def mean(self): 131 | with warnings.catch_warnings(): 132 | warnings.simplefilter("ignore") 133 | return float(np.mean(self._list)) 134 | 135 | def std(self): 136 | with warnings.catch_warnings(): 137 | warnings.simplefilter("ignore") 138 | return float(np.std(self._list)) 139 | 140 | def mean_std(self): 141 | return (self.mean(), self.std()) 142 | 143 | def argmax(self, selector): 144 | selector = make_selector_fn(selector) 145 | return max(self._list, key=selector) 146 | 147 | def filter(self, fn): 148 | return Q([x for x in self._list if fn(x)]) 149 | 150 | def filter_equals(self, selector, value): 151 | """like [x for x in y if x.selector == value]""" 152 | selector = make_selector_fn(selector) 153 | return self.filter(lambda r: selector(r) == value) 154 | 155 | def filter_not_none(self): 156 | return self.filter(lambda r: r is not None) 157 | 158 | def filter_not_nan(self): 159 | return self.filter(lambda r: not np.isnan(r)) 160 | 161 | def flatten(self): 162 | return Q([y for x in self._list for y in x]) 163 | 164 | def unique(self): 165 | result = [] 166 | result_set = set() 167 | for x in self._list: 168 | hashable_x = hashable(x) 169 | if hashable_x not in result_set: 170 | result_set.add(hashable_x) 171 | result.append(x) 172 | return Q(result) 173 | 174 | def sorted(self, key=None, reverse=False): 175 | if key is None: 176 | key = lambda x: x 177 | 178 | def key2(x): 179 | x = key(x) 180 | if isinstance(x, (np.floating, float)) and 
np.isnan(x): 181 | return float("-inf") 182 | else: 183 | return x 184 | 185 | return Q(sorted(self._list, key=key2, reverse=reverse)) 186 | -------------------------------------------------------------------------------- /domainbed/lib/swa_utils.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/pytorch/pytorch/blob/master/torch/optim/swa_utils.py 2 | import copy 3 | import warnings 4 | import math 5 | from copy import deepcopy 6 | 7 | import torch 8 | from torch.nn import Module 9 | from torch.optim.lr_scheduler import _LRScheduler 10 | 11 | from domainbed.networks.ur_networks import URResNet 12 | 13 | 14 | class AveragedModel(Module): 15 | def filter(self, model): 16 | if isinstance(model, AveragedModel): 17 | # prevent nested averagedmodel 18 | model = model.module 19 | 20 | if hasattr(model, "get_forward_model"): 21 | model = model.get_forward_model() 22 | # URERM models use URNetwork, which manages features internally. 
class AveragedModel(Module):
    # Maintains a running (by default cumulative/uniform) average of another
    # model's parameters, as used by SWA-style weight averaging.

    def filter(self, model):
        # Normalize the incoming model before copying / averaging.
        if isinstance(model, AveragedModel):
            # prevent nested averagedmodel
            model = model.module

        if hasattr(model, "get_forward_model"):
            model = model.get_forward_model()

        # URERM models use URNetwork, which manages features internally.
        for m in model.modules():
            if isinstance(m, URResNet):
                m.clear_features()

        return model

    def __init__(self, model, device=None, avg_fn=None, rm_optimizer=False):
        super(AveragedModel, self).__init__()
        # [start_step, end_step] records which training steps contributed
        # to the current average (set in update_parameters).
        self.start_step = -1
        self.end_step = -1
        model = self.filter(model)
        self.module = deepcopy(model)
        self.module.zero_grad()
        if rm_optimizer:
            # Drop optimizer references so the deep copy stays lightweight.
            for k, v in vars(self.module).items():
                if isinstance(v, torch.optim.Optimizer):
                    setattr(self.module, k, None)
                    # print(f"{k} -> {getattr(self.module, k)}")
        if device is not None:
            self.module = self.module.to(device)
        # Registered as a buffer so the averaging count travels with state_dict.
        self.register_buffer('n_averaged', torch.tensor(0, dtype=torch.long, device=device))
        if avg_fn is None:
            # Default: incremental (cumulative) mean of the parameters.
            def avg_fn(averaged_model_parameter, model_parameter, num_averaged):
                return averaged_model_parameter + \
                    (model_parameter - averaged_model_parameter) / (num_averaged + 1)
        self.avg_fn = avg_fn

    def forward(self, *args, **kwargs):
        # return self.predict(*args, **kwargs)
        return self.module(*args, **kwargs)

    def predict(self, *args, **kwargs):
        return self.module.predict(*args, **kwargs)

    @property
    def network(self):
        return self.module.network

    def update_parameters(self, model, step=None, start_step=None, end_step=None):
        # Fold `model`'s current parameters into the running average.
        model = self.filter(model)
        for p_swa, p_model in zip(self.parameters(), model.parameters()):
            device = p_swa.device
            p_model_ = p_model.detach().to(device)
            if self.n_averaged == 0:
                # First update: straight copy.
                p_swa.detach().copy_(p_model_)
            else:
                p_swa.detach().copy_(self.avg_fn(p_swa.detach(), p_model_,
                                                 self.n_averaged.to(device)))
        self.n_averaged += 1

        # `step` is a shorthand for start_step == end_step == step.
        if step is not None:
            if start_step is None:
                start_step = step
            if end_step is None:
                end_step = step

        if start_step is not None:
            if self.n_averaged == 1:
                self.start_step = start_step

        if end_step is not None:
            self.end_step = end_step

    def clone(self):
        # NOTE(review): assumes self.module exposes new_optimizer()/network —
        # true for this project's Algorithm classes; confirm for new callers.
        clone = copy.deepcopy(self.module)
        clone.optimizer = clone.new_optimizer(clone.network.parameters())
        return clone
@torch.no_grad()
def update_bn(iterator, model, n_steps, device='cuda'):
    """Recompute BatchNorm running statistics by streaming `n_steps` batches
    through `model` with momentum disabled (i.e. cumulative averaging)."""
    saved_momenta = {}
    for module in model.modules():
        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
            module.running_mean = torch.zeros_like(module.running_mean)
            module.running_var = torch.ones_like(module.running_var)
            saved_momenta[module] = module.momentum

    # Nothing to do if the model has no batchnorm layers.
    if not saved_momenta:
        return

    was_training = model.training
    model.train()
    for bn in saved_momenta:
        bn.momentum = None  # momentum=None -> cumulative moving average
        bn.num_batches_tracked *= 0

    for _ in range(n_steps):
        # batches_dictlist: [{env0_data_key: tensor, env0_...}, env1_..., ...]
        batches_dictlist = next(iterator)
        inputs = torch.cat([dic["x"] for dic in batches_dictlist]).to(device)
        model(inputs)

    # Restore the original momenta and training mode.
    for bn, momentum in saved_momenta.items():
        bn.momentum = momentum
    model.train(was_training)
class SWALR(_LRScheduler):
    r"""Anneals the learning rate in each parameter group to a fixed value.

    This learning rate scheduler is meant to be used with Stochastic Weight
    Averaging (SWA) (see `torch.optim.swa_utils.AveragedModel`).

    Arguments:
        optimizer (torch.optim.Optimizer): wrapped optimizer
        swa_lr (float or list): the learning rate value for all param groups
            together or separately for each group.
        anneal_epochs (int): number of epochs in the annealing phase
            (default: 10)
        anneal_strategy (str): "cos" or "linear"; specifies the annealing
            strategy: "cos" for cosine annealing, "linear" for linear annealing
            (default: "cos")
        last_epoch (int): the index of the last epoch (default: -1)

    The :class:`SWALR` scheduler can be used together with other schedulers
    to switch to a constant learning rate late in the training as in the
    example below.

    Example:
        >>> loader, optimizer, model = ...
        >>> lr_lambda = lambda epoch: 0.9
        >>> scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer,
        >>>     lr_lambda=lr_lambda)
        >>> swa_scheduler = torch.optim.swa_utils.SWALR(optimizer,
        >>>     anneal_strategy="linear", anneal_epochs=20, swa_lr=0.05)
        >>> swa_start = 160
        >>> for i in range(300):
        >>>     for input, target in loader:
        >>>         optimizer.zero_grad()
        >>>         loss_fn(model(input), target).backward()
        >>>         optimizer.step()
        >>>     if i > swa_start:
        >>>         swa_scheduler.step()
        >>>     else:
        >>>         scheduler.step()

    .. _Averaging Weights Leads to Wider Optima and Better Generalization:
        https://arxiv.org/abs/1803.05407
    """

    def __init__(self, optimizer, swa_lr, anneal_epochs=10, anneal_strategy='cos', last_epoch=-1):
        swa_lrs = self._format_param(optimizer, swa_lr)
        for swa_lr, group in zip(swa_lrs, optimizer.param_groups):
            group['swa_lr'] = swa_lr
        if anneal_strategy not in ['cos', 'linear']:
            # Typo fix: "must by" -> "must be".
            raise ValueError("anneal_strategy must be one of 'cos' or 'linear', "
                             "instead got {}".format(anneal_strategy))
        elif anneal_strategy == 'cos':
            self.anneal_func = self._cosine_anneal
        elif anneal_strategy == 'linear':
            self.anneal_func = self._linear_anneal
        if not isinstance(anneal_epochs, int) or anneal_epochs < 1:
            raise ValueError("anneal_epochs must be a positive integer, got {}".format(
                anneal_epochs))
        self.anneal_epochs = anneal_epochs

        super(SWALR, self).__init__(optimizer, last_epoch)

    @staticmethod
    def _format_param(optimizer, swa_lrs):
        """Broadcast a scalar swa_lr to all param groups; validate list lengths."""
        if isinstance(swa_lrs, (list, tuple)):
            if len(swa_lrs) != len(optimizer.param_groups):
                raise ValueError("swa_lr must have the same length as "
                                 "optimizer.param_groups: swa_lr has {}, "
                                 "optimizer.param_groups has {}".format(
                                     len(swa_lrs), len(optimizer.param_groups)))
            return swa_lrs
        else:
            return [swa_lrs] * len(optimizer.param_groups)

    @staticmethod
    def _linear_anneal(t):
        return t

    @staticmethod
    def _cosine_anneal(t):
        return (1 - math.cos(math.pi * t)) / 2

    @staticmethod
    def _get_initial_lr(lr, swa_lr, alpha):
        """Invert the anneal mixture to recover the pre-anneal lr."""
        if alpha == 1:
            return swa_lr
        return (lr - alpha * swa_lr) / (1 - alpha)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)
        step = self._step_count - 1
        # Recover each group's initial lr from the previous anneal point,
        # then re-mix toward swa_lr at the current anneal point.
        prev_t = max(0, min(1, (step - 1) / self.anneal_epochs))
        prev_alpha = self.anneal_func(prev_t)
        prev_lrs = [self._get_initial_lr(group['lr'], group['swa_lr'], prev_alpha)
                    for group in self.optimizer.param_groups]
        t = max(0, min(1, step / self.anneal_epochs))
        alpha = self.anneal_func(t)
        return [group['swa_lr'] * alpha + lr * (1 - alpha)
                for group, lr in zip(self.optimizer.param_groups, prev_lrs)]
def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding (bias enabled, as in the original WRN code)."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True)


def conv_init(m):
    """Xavier init for conv layers, constant init for batchnorm."""
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        init.xavier_uniform_(m.weight, gain=np.sqrt(2))
        init.constant_(m.bias, 0)
    elif classname.find("BatchNorm") != -1:
        init.constant_(m.weight, 1)
        init.constant_(m.bias, 0)


class wide_basic(nn.Module):
    """Pre-activation wide-resnet basic block (BN-ReLU-Conv twice + shortcut)."""

    def __init__(self, in_planes, planes, dropout_rate, stride=1):
        super(wide_basic, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, bias=True)
        self.dropout = nn.Dropout(p=dropout_rate)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            # Projection shortcut when the spatial size or the width changes.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=True),
            )

    def forward(self, x):
        out = self.dropout(self.conv1(F.relu(self.bn1(x))))
        out = self.conv2(F.relu(self.bn2(out)))
        out += self.shortcut(x)

        return out


class Wide_ResNet(nn.Module):
    """Wide Resnet with the softmax layer chopped off.

    NOTE(review): the final `avg_pool2d(out, 8)` assumes 32x32 inputs
    (three stride-2 reductions leave an 8x8 map) — confirm for other sizes.
    """

    def __init__(self, input_shape, depth, widen_factor, dropout_rate):
        super(Wide_ResNet, self).__init__()
        self.in_planes = 16

        assert (depth - 4) % 6 == 0, "Wide-resnet depth should be 6n+4"
        # Fix: use integer division — n is a block count (was float `/`
        # patched downstream with an int() cast in _wide_layer).
        n = (depth - 4) // 6
        k = widen_factor

        # print('| Wide-Resnet %dx%d' % (depth, k))
        nStages = [16, 16 * k, 32 * k, 64 * k]

        self.conv1 = conv3x3(input_shape[0], nStages[0])
        self.layer1 = self._wide_layer(wide_basic, nStages[1], n, dropout_rate, stride=1)
        self.layer2 = self._wide_layer(wide_basic, nStages[2], n, dropout_rate, stride=2)
        self.layer3 = self._wide_layer(wide_basic, nStages[3], n, dropout_rate, stride=2)
        self.bn1 = nn.BatchNorm2d(nStages[3], momentum=0.9)

        # Feature dimension exposed to downstream featurizers.
        self.n_outputs = nStages[3]

    def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride):
        # Only the first block of a stage may downsample; the rest use stride 1.
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []

        for stride in strides:
            layers.append(block(self.in_planes, planes, dropout_rate, stride))
            self.in_planes = planes

        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = F.relu(self.bn1(out))
        out = F.avg_pool2d(out, 8)
        return out[:, :, 0, 0]
class Writer:
    """Abstract scalar-logging interface."""

    def add_scalars(self, tag_scalar_dic, global_step):
        raise NotImplementedError()

    def add_scalars_with_prefix(self, tag_scalar_dic, global_step, prefix):
        """Log the same scalars with every tag prefixed by `prefix`."""
        prefixed = {prefix + tag: value for tag, value in tag_scalar_dic.items()}
        self.add_scalars(prefixed, global_step)


class TBWriter(Writer):
    """TensorBoard-backed writer."""

    def __init__(self, dir_path):
        # Imported lazily so tensorboardX is required only when actually used.
        from tensorboardX import SummaryWriter

        self.writer = SummaryWriter(dir_path, flush_secs=30)

    def add_scalars(self, tag_scalar_dic, global_step):
        for tag, value in tag_scalar_dic.items():
            self.writer.add_scalar(tag, value, global_step)


def get_writer(dir_path):
    """
    Args:
        dir_path: tb dir
    """
    return TBWriter(dir_path)
21 | """ 22 | super().__init__() 23 | self.p = p 24 | self.beta = torch.distributions.Beta(alpha, alpha) 25 | self.eps = eps 26 | self.alpha = alpha 27 | 28 | print("* MixStyle params") 29 | print(f"- p: {p}") 30 | print(f"- alpha: {alpha}") 31 | 32 | def __repr__(self): 33 | return f"MixStyle(p={self.p}, alpha={self.alpha}, eps={self.eps})" 34 | 35 | def forward(self, x): 36 | if not self.training: 37 | return x 38 | 39 | if random.random() > self.p: 40 | return x 41 | 42 | B = x.size(0) 43 | 44 | mu = x.mean(dim=[2, 3], keepdim=True) 45 | var = x.var(dim=[2, 3], keepdim=True) 46 | sig = (var + self.eps).sqrt() 47 | mu, sig = mu.detach(), sig.detach() 48 | x_normed = (x - mu) / sig 49 | 50 | lmda = self.beta.sample((B, 1, 1, 1)) 51 | lmda = lmda.to(x.device) 52 | 53 | perm = torch.randperm(B) 54 | mu2, sig2 = mu[perm], sig[perm] 55 | mu_mix = mu * lmda + mu2 * (1 - lmda) 56 | sig_mix = sig * lmda + sig2 * (1 - lmda) 57 | 58 | return x_normed * sig_mix + mu_mix 59 | 60 | 61 | class MixStyle2(nn.Module): 62 | """MixStyle (w/ domain prior). 63 | The input should contain two equal-sized mini-batches from two distinct domains. 64 | Reference: 65 | Zhou et al. Domain Generalization with MixStyle. ICLR 2021. 66 | """ 67 | 68 | def __init__(self, p=0.5, alpha=0.3, eps=1e-6): 69 | """ 70 | Args: 71 | p (float): probability of using MixStyle. 72 | alpha (float): parameter of the Beta distribution. 73 | eps (float): scaling parameter to avoid numerical issues. 74 | """ 75 | super().__init__() 76 | self.p = p 77 | self.beta = torch.distributions.Beta(alpha, alpha) 78 | self.eps = eps 79 | self.alpha = alpha 80 | 81 | print("* MixStyle params") 82 | print(f"- p: {p}") 83 | print(f"- alpha: {alpha}") 84 | 85 | def __repr__(self): 86 | return f"MixStyle(p={self.p}, alpha={self.alpha}, eps={self.eps})" 87 | 88 | def forward(self, x): 89 | """ 90 | For the input x, the first half comes from one domain, 91 | while the second half comes from the other domain. 
92 | """ 93 | if not self.training: 94 | return x 95 | 96 | if random.random() > self.p: 97 | return x 98 | 99 | B = x.size(0) 100 | 101 | mu = x.mean(dim=[2, 3], keepdim=True) 102 | var = x.var(dim=[2, 3], keepdim=True) 103 | sig = (var + self.eps).sqrt() 104 | mu, sig = mu.detach(), sig.detach() 105 | x_normed = (x - mu) / sig 106 | 107 | lmda = self.beta.sample((B, 1, 1, 1)) 108 | lmda = lmda.to(x.device) 109 | 110 | perm = torch.arange(B - 1, -1, -1) # inverse index 111 | perm_b, perm_a = perm.chunk(2) 112 | perm_b = perm_b[torch.randperm(B // 2)] 113 | perm_a = perm_a[torch.randperm(B // 2)] 114 | perm = torch.cat([perm_b, perm_a], 0) 115 | 116 | mu2, sig2 = mu[perm], sig[perm] 117 | mu_mix = mu * lmda + mu2 * (1 - lmda) 118 | sig_mix = sig * lmda + sig2 * (1 - lmda) 119 | 120 | return x_normed * sig_mix + mu_mix 121 | -------------------------------------------------------------------------------- /domainbed/models/resnet_mixstyle.py: -------------------------------------------------------------------------------- 1 | """MixStyle w/ random shuffle 2 | https://github.com/KaiyangZhou/mixstyle-release/blob/master/imcls/models/resnet_mixstyle.py 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | import torch.utils.model_zoo as model_zoo 7 | 8 | from .mixstyle import MixStyle 9 | 10 | model_urls = { 11 | "resnet18": "https://download.pytorch.org/models/resnet18-5c106cde.pth", 12 | "resnet34": "https://download.pytorch.org/models/resnet34-333f7ec4.pth", 13 | "resnet50": "https://download.pytorch.org/models/resnet50-19c8e357.pth", 14 | "resnet101": "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth", 15 | "resnet152": "https://download.pytorch.org/models/resnet152-b121ed2d.pth", 16 | } 17 | 18 | 19 | def conv3x3(in_planes, out_planes, stride=1): 20 | """3x3 convolution with padding""" 21 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 22 | 23 | 24 | class BasicBlock(nn.Module): 25 | expansion = 1 26 | 27 | 
def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)


class BasicBlock(nn.Module):
    """Standard ResNet basic block (two 3x3 convs + identity/projection shortcut)."""

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    """Standard ResNet bottleneck block (1x1 -> 3x3 -> 1x1, 4x expansion)."""

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):
    """ResNet backbone with MixStyle optionally inserted after selected stages.

    forward() returns globally pooled features of size 512 * block.expansion;
    `fc` is an Identity for DomainBed compatibility.
    """

    def __init__(
        self, block, layers, mixstyle_layers=None, mixstyle_p=0.5, mixstyle_alpha=0.3, **kwargs
    ):
        self.inplanes = 64
        super().__init__()

        # Fix: avoid a mutable default argument; None behaves like the old [].
        mixstyle_layers = [] if mixstyle_layers is None else mixstyle_layers

        # backbone network
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.global_avgpool = nn.AdaptiveAvgPool2d(1)

        self.mixstyle = None
        if mixstyle_layers:
            self.mixstyle = MixStyle(p=mixstyle_p, alpha=mixstyle_alpha)
            for layer_name in mixstyle_layers:
                assert layer_name in ["conv1", "conv2_x", "conv3_x", "conv4_x", "conv5_x"]
            print("Insert MixStyle after the following layers: {}".format(mixstyle_layers))
        self.mixstyle_layers = mixstyle_layers

        self._out_features = 512 * block.expansion
        self.fc = nn.Identity()  # for DomainBed compatibility

        self._init_params()

    def _make_layer(self, block, planes, blocks, stride=1):
        # Projection shortcut for the first block when shape/width changes.
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def _init_params(self):
        # Kaiming init for convs, constant init for norms, normal for linears.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def compute_style(self, x):
        """Concatenate per-channel spatial mean and std as a style descriptor."""
        mu = x.mean(dim=[2, 3])
        sig = x.std(dim=[2, 3])
        return torch.cat([mu, sig], 1)

    def featuremaps(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        if "conv2_x" in self.mixstyle_layers:
            x = self.mixstyle(x)

        x = self.layer2(x)
        if "conv3_x" in self.mixstyle_layers:
            x = self.mixstyle(x)

        x = self.layer3(x)
        if "conv4_x" in self.mixstyle_layers:
            x = self.mixstyle(x)

        x = self.layer4(x)
        if "conv5_x" in self.mixstyle_layers:
            x = self.mixstyle(x)

        return x

    def forward(self, x):
        f = self.featuremaps(x)
        v = self.global_avgpool(f)
        return v.view(v.size(0), -1)


def init_pretrained_weights(model, model_url):
    """Load ImageNet weights, ignoring keys that do not match (e.g. fc)."""
    pretrain_dict = model_zoo.load_url(model_url)
    model.load_state_dict(pretrain_dict, strict=False)


def resnet18_mixstyle_L234_p0d5_a0d1(pretrained=True, **kwargs):
    """ResNet-18 with MixStyle(p=0.5, alpha=0.1) after conv2_x..conv4_x."""
    model = ResNet(
        block=BasicBlock,
        layers=[2, 2, 2, 2],
        mixstyle_layers=["conv2_x", "conv3_x", "conv4_x"],
        mixstyle_p=0.5,
        mixstyle_alpha=0.1,
    )

    if pretrained:
        init_pretrained_weights(model, model_urls["resnet18"])

    return model


def resnet50_mixstyle_L234_p0d5_a0d1(pretrained=True, **kwargs):
    """ResNet-50 with MixStyle(p=0.5, alpha=0.1) after conv2_x..conv4_x."""
    model = ResNet(
        block=Bottleneck,
        layers=[3, 4, 6, 3],
        mixstyle_layers=["conv2_x", "conv3_x", "conv4_x"],
        mixstyle_p=0.5,
        mixstyle_alpha=0.1,
    )

    if pretrained:
        init_pretrained_weights(model, model_urls["resnet50"])

    return model
def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)


class BasicBlock(nn.Module):
    """Standard ResNet basic block (two 3x3 convs + identity/projection shortcut)."""

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    """Standard ResNet bottleneck block (1x1 -> 3x3 -> 1x1, 4x expansion)."""

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):
    """ResNet backbone with domain-prior MixStyle after selected stages.

    forward() returns globally pooled features of size 512 * block.expansion;
    `fc` is an Identity for DomainBed compatibility.
    """

    def __init__(
        self, block, layers, mixstyle_layers=None, mixstyle_p=0.5, mixstyle_alpha=0.3, **kwargs
    ):
        self.inplanes = 64
        super().__init__()

        # Fix: avoid a mutable default argument; None behaves like the old [].
        mixstyle_layers = [] if mixstyle_layers is None else mixstyle_layers

        # backbone network
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.global_avgpool = nn.AdaptiveAvgPool2d(1)

        self.mixstyle = None
        if mixstyle_layers:
            self.mixstyle = MixStyle(p=mixstyle_p, alpha=mixstyle_alpha)
            for layer_name in mixstyle_layers:
                assert layer_name in ["conv1", "conv2_x", "conv3_x", "conv4_x", "conv5_x"]
            print("Insert MixStyle after the following layers: {}".format(mixstyle_layers))
        self.mixstyle_layers = mixstyle_layers

        self._out_features = 512 * block.expansion
        self.fc = nn.Identity()  # for DomainBed compatibility

        self._init_params()

    def _make_layer(self, block, planes, blocks, stride=1):
        # Projection shortcut for the first block when shape/width changes.
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def _init_params(self):
        # Kaiming init for convs, constant init for norms, normal for linears.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def compute_style(self, x):
        """Concatenate per-channel spatial mean and std as a style descriptor."""
        mu = x.mean(dim=[2, 3])
        sig = x.std(dim=[2, 3])
        return torch.cat([mu, sig], 1)

    def featuremaps(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        if "conv2_x" in self.mixstyle_layers:
            x = self.mixstyle(x)

        x = self.layer2(x)
        if "conv3_x" in self.mixstyle_layers:
            x = self.mixstyle(x)

        x = self.layer3(x)
        if "conv4_x" in self.mixstyle_layers:
            x = self.mixstyle(x)

        x = self.layer4(x)
        if "conv5_x" in self.mixstyle_layers:
            x = self.mixstyle(x)

        return x

    def forward(self, x):
        f = self.featuremaps(x)
        v = self.global_avgpool(f)
        return v.view(v.size(0), -1)


def init_pretrained_weights(model, model_url):
    """Load ImageNet weights, ignoring keys that do not match (e.g. fc)."""
    pretrain_dict = model_zoo.load_url(model_url)
    model.load_state_dict(pretrain_dict, strict=False)


"""
Residual network configurations:
--
resnet18: block=BasicBlock, layers=[2, 2, 2, 2]
resnet34: block=BasicBlock, layers=[3, 4, 6, 3]
resnet50: block=Bottleneck, layers=[3, 4, 6, 3]
resnet101: block=Bottleneck, layers=[3, 4, 23, 3]
resnet152: block=Bottleneck, layers=[3, 8, 36, 3]
"""


def resnet18_mixstyle2_L234_p0d5_a0d1(pretrained=True, **kwargs):
    """ResNet-18 with domain-prior MixStyle(p=0.5, alpha=0.1) after conv2_x..conv4_x."""
    model = ResNet(
        block=BasicBlock,
        layers=[2, 2, 2, 2],
        mixstyle_layers=["conv2_x", "conv3_x", "conv4_x"],
        mixstyle_p=0.5,
        mixstyle_alpha=0.1,
    )

    if pretrained:
        init_pretrained_weights(model, model_urls["resnet18"])

    return model


def resnet50_mixstyle2_L234_p0d5_a0d1(pretrained=True, **kwargs):
    """ResNet-50 with domain-prior MixStyle(p=0.5, alpha=0.1) after conv2_x..conv4_x."""
    model = ResNet(
        block=Bottleneck,
        layers=[3, 4, 6, 3],
        mixstyle_layers=["conv2_x", "conv3_x", "conv4_x"],
        mixstyle_p=0.5,
        mixstyle_alpha=0.1,
    )

    if pretrained:
        init_pretrained_weights(model, model_urls["resnet50"])

    return model
class Identity(nn.Module):
    """Pass-through module: returns its input unchanged.

    Used to replace readout heads (e.g. ``network.fc``) while keeping the
    module graph intact.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x
https://github.com/facebookresearch/moco-v3/blob/main/main_lincls.py#L172 57 | print("=> loading checkpoint '{}'".format(ckpt_path)) 58 | checkpoint = torch.load(ckpt_path, map_location="cpu") 59 | 60 | # rename moco pre-trained keys 61 | state_dict = checkpoint['state_dict'] 62 | linear_keyword = "fc" # resnet linear keyword 63 | for k in list(state_dict.keys()): 64 | # retain only base_encoder up to before the embedding layer 65 | if k.startswith('module.base_encoder') and not k.startswith('module.base_encoder.%s' % linear_keyword): 66 | # remove prefix 67 | state_dict[k[len("module.base_encoder."):]] = state_dict[k] 68 | # delete renamed or unused k 69 | del state_dict[k] 70 | 71 | msg = network.load_state_dict(state_dict, strict=False) 72 | assert set(msg.missing_keys) == {"%s.weight" % linear_keyword, "%s.bias" % linear_keyword} 73 | 74 | print("=> loaded pre-trained model '{}'".format(ckpt_path)) 75 | 76 | n_outputs = 2048 77 | elif name.startswith("clip_resnet"): 78 | name = "RN" + name[11:] 79 | network = clip_imageencoder(name) 80 | n_outputs = network.output_dim 81 | elif name == "clip_vit-b16": 82 | network = clip_imageencoder("ViT-B/16") 83 | n_outputs = network.output_dim 84 | elif name == "swag_regnety_16gf": 85 | # No readout layer as default 86 | network = torchhub_load("facebookresearch/swag", model="regnety_16gf", pretrained=pretrained) 87 | 88 | network.head = nn.Sequential( 89 | nn.AdaptiveAvgPool2d(1), 90 | nn.Flatten(1), 91 | ) 92 | n_outputs = 3024 93 | else: 94 | raise ValueError(name) 95 | 96 | if not preserve_readout: 97 | # remove readout layer (but left GAP and flatten) 98 | # final output shape: [B, n_outputs] 99 | if name.startswith("resnet"): 100 | del network.fc 101 | network.fc = Identity() 102 | 103 | return network, n_outputs 104 | -------------------------------------------------------------------------------- /domainbed/networks/networks.py: -------------------------------------------------------------------------------- 1 | # 
class MLP(nn.Module):
    """Configurable fully-connected network.

    Layout: input Linear -> (mlp_depth - 2) hidden Linear layers -> output
    Linear. Dropout followed by ReLU is applied after every layer except the
    output. Width, depth and dropout rate come from ``hparams``.
    """

    def __init__(self, n_inputs, n_outputs, hparams):
        super(MLP, self).__init__()
        width = hparams["mlp_width"]
        self.input = nn.Linear(n_inputs, width)
        self.dropout = nn.Dropout(hparams["mlp_dropout"])
        self.hiddens = nn.ModuleList(
            nn.Linear(width, width) for _ in range(hparams["mlp_depth"] - 2)
        )
        self.output = nn.Linear(width, n_outputs)
        self.n_outputs = n_outputs

    def forward(self, x):
        x = F.relu(self.dropout(self.input(x)))
        for layer in self.hiddens:
            x = F.relu(self.dropout(layer(x)))
        return self.output(x)
class MNIST_CNN(nn.Module):
    """Small hand-tuned CNN featurizer for 28x28 inputs (128-d output).

    NOTE: empirically, adding a linear layer after the mean-pool hurts
    RotatedMNIST-100 generalization severely, so none is used here.
    """

    n_outputs = 128

    def __init__(self, input_shape):
        super(MNIST_CNN, self).__init__()
        in_channels = input_shape[0]
        self.conv1 = nn.Conv2d(in_channels, 64, 3, 1, padding=1)
        self.conv2 = nn.Conv2d(64, 128, 3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(128, 128, 3, 1, padding=1)
        self.conv4 = nn.Conv2d(128, 128, 3, 1, padding=1)

        # GroupNorm (8 groups) rather than BatchNorm: independent of batch size.
        self.bn0 = nn.GroupNorm(8, 64)
        self.bn1 = nn.GroupNorm(8, 128)
        self.bn2 = nn.GroupNorm(8, 128)
        self.bn3 = nn.GroupNorm(8, 128)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.squeezeLastTwo = SqueezeLastTwo()

    def forward(self, x):
        # Each stage is conv -> ReLU -> norm, same order as the original.
        stages = (
            (self.conv1, self.bn0),
            (self.conv2, self.bn1),
            (self.conv3, self.bn2),
            (self.conv4, self.bn3),
        )
        for conv, norm in stages:
            x = norm(F.relu(conv(x)))
        x = self.avgpool(x)
        return self.squeezeLastTwo(x)
def Featurizer(input_shape, hparams):
    """Auto-select an appropriate featurizer for the given input shape."""
    if len(input_shape) == 1:
        # Flat vector input -> plain MLP with a fixed 128-d embedding.
        return MLP(input_shape[0], 128, hparams)
    spatial = input_shape[1:3]
    if spatial == (28, 28):
        return MNIST_CNN(input_shape)
    if spatial == (32, 32):
        return wide_resnet.Wide_ResNet(input_shape, 16, 2, 0.0)
    if spatial == (224, 224):
        return ResNet(input_shape, hparams)
    raise NotImplementedError(f"Input shape {input_shape} is not supported")
def get_module(module, name):
    """Return the submodule of ``module`` whose qualified name equals ``name``.

    Returns None when no submodule matches (same as the implicit fall-through
    of a plain loop).
    """
    matches = (child for qualname, child in module.named_modules() if qualname == name)
    return next(matches, None)
def freeze_(model):
    """Disable gradient tracking for every parameter of ``model`` (in place).

    Note: this only stops gradient updates; it does not put BatchNorm layers
    into eval mode (running stats may still update).
    """
    for param in model.parameters():
        param.requires_grad = False
113 | freeze_(block) 114 | 115 | def hook(self, module, input, output): 116 | self._features.append(output) 117 | 118 | def build_feature_hooks(self, feats, block_names): 119 | assert feats in ["stem_block", "block"] 120 | 121 | if feats is None: 122 | return [] 123 | 124 | # build feat layers 125 | if feats.startswith("stem"): 126 | last_stem_name = block_names["stem"][-1] 127 | feat_layers = [last_stem_name] 128 | else: 129 | feat_layers = [] 130 | 131 | for name, module_names in block_names.items(): 132 | if name == "stem": 133 | continue 134 | 135 | module_name = module_names[-1] 136 | feat_layers.append(module_name) 137 | 138 | # print(f"feat layers = {feat_layers}") 139 | 140 | for n, m in self.network.named_modules(): 141 | if n in feat_layers: 142 | m.register_forward_hook(self.hook) 143 | 144 | return feat_layers 145 | 146 | def forward(self, x, ret_feats=False): 147 | """Encode x into a feature vector of size n_outputs.""" 148 | self.clear_features() 149 | out = self.dropout(self.network(x)) 150 | if ret_feats: 151 | return out, self._features 152 | else: 153 | return out 154 | 155 | def clear_features(self): 156 | self._features.clear() 157 | 158 | def train(self, mode=True): 159 | """ 160 | Override the default train() to freeze the BN parameters 161 | """ 162 | super().train(mode) 163 | self.freeze_bn() 164 | 165 | def freeze_bn(self): 166 | for m in self.network.modules(): 167 | if isinstance(m, nn.BatchNorm2d): 168 | m.eval() 169 | 170 | 171 | def URFeaturizer(input_shape, hparams, **kwargs): 172 | """Auto-select an appropriate featurizer for the given input shape.""" 173 | if input_shape[1:3] == (224, 224): 174 | return URResNet(input_shape, hparams, **kwargs) 175 | else: 176 | raise NotImplementedError(f"Input shape {input_shape} is not supported") 177 | -------------------------------------------------------------------------------- /domainbed/optimizers.py: -------------------------------------------------------------------------------- 1 | 
def get_optimizer(name, params, **kwargs):
    """Build a torch optimizer by (case-insensitive) name.

    Args:
        name: one of "adam", "sgd", "adamw" (any case).
        params: iterable of parameters (or param groups) to optimize.
        **kwargs: forwarded to the optimizer constructor (lr, weight_decay, ...).

    Returns:
        The constructed torch.optim.Optimizer.

    Raises:
        ValueError: if ``name`` is not a supported optimizer (the original
            raised an opaque KeyError on a typo'd config value).
    """
    optimizers = {"adam": torch.optim.Adam, "sgd": torch.optim.SGD, "adamw": torch.optim.AdamW}
    key = name.lower()
    try:
        optim_cls = optimizers[key]
    except KeyError:
        raise ValueError(
            f"Unsupported optimizer: {name!r} (expected one of {sorted(optimizers)})"
        ) from None

    return optim_cls(params, **kwargs)
open("domainbed/misc/vlcs_files.txt", "r") as f: 61 | # lines = f.readlines() 62 | # files = [line.strip().split() for line in lines] 63 | # 64 | # download_and_extract("http://pjreddie.com/media/files/VOCtrainval_06-Nov-2007.tar", 65 | # os.path.join(tmp_path, "voc2007_trainval.tar")) 66 | # 67 | # download_and_extract("https://drive.google.com/uc?id=1I8ydxaAQunz9R_qFFdBFtw6rFTUW9goz", 68 | # os.path.join(tmp_path, "caltech101.tar.gz")) 69 | # 70 | # download_and_extract("http://groups.csail.mit.edu/vision/Hcontext/data/sun09_hcontext.tar", 71 | # os.path.join(tmp_path, "sun09_hcontext.tar")) 72 | # 73 | # tar = tarfile.open(os.path.join(tmp_path, "sun09.tar"), "r:") 74 | # tar.extractall(tmp_path) 75 | # tar.close() 76 | # 77 | # for src, dst in files: 78 | # class_folder = os.path.join(data_dir, dst) 79 | # 80 | # if not os.path.exists(class_folder): 81 | # os.makedirs(class_folder) 82 | # 83 | # dst = os.path.join(class_folder, uuid.uuid4().hex + ".jpg") 84 | # 85 | # if "labelme" in src: 86 | # # download labelme from the web 87 | # gdown.download(src, dst, quiet=False) 88 | # else: 89 | # src = os.path.join(tmp_path, src) 90 | # shutil.copyfile(src, dst) 91 | # 92 | # shutil.rmtree(tmp_path) 93 | 94 | 95 | def download_vlcs(data_dir): 96 | # Original URL: http://www.eecs.qmul.ac.uk/~dl307/project_iccv2017 97 | full_path = stage_path(data_dir, "VLCS") 98 | 99 | download_and_extract( 100 | "https://drive.google.com/uc?id=1skwblH1_okBwxWxmRsp9_qi15hyPpxg8", 101 | os.path.join(data_dir, "VLCS.tar.gz"), 102 | ) 103 | 104 | 105 | # MNIST ####################################################################### 106 | 107 | 108 | def download_mnist(data_dir): 109 | # Original URL: http://yann.lecun.com/exdb/mnist/ 110 | full_path = stage_path(data_dir, "MNIST") 111 | MNIST(full_path, download=True) 112 | 113 | 114 | # PACS ######################################################################## 115 | 116 | 117 | def download_pacs(data_dir): 118 | # Original URL: 
def download_domain_net(data_dir):
    """Download the six DomainNet domains and remove known duplicate images."""
    # Original URL: http://ai.bu.edu/M3SDA/
    full_path = stage_path(data_dir, "domain_net")

    urls = [
        "http://csr.bu.edu/ftp/visda/2019/multi-source/groundtruth/clipart.zip",
        "http://csr.bu.edu/ftp/visda/2019/multi-source/infograph.zip",
        "http://csr.bu.edu/ftp/visda/2019/multi-source/groundtruth/painting.zip",
        "http://csr.bu.edu/ftp/visda/2019/multi-source/quickdraw.zip",
        "http://csr.bu.edu/ftp/visda/2019/multi-source/real.zip",
        "http://csr.bu.edu/ftp/visda/2019/multi-source/sketch.zip",
    ]

    for url in urls:
        download_and_extract(url, os.path.join(full_path, url.split("/")[-1]))

    # Best-effort cleanup: delete every file listed in the shipped duplicates
    # list; files already absent are silently skipped (OSError).
    # NOTE(review): the path is relative, so this assumes the script runs from
    # the repository root.
    with open("domainbed/misc/domain_net_duplicates.txt", "r") as f:
        for line in f.readlines():
            try:
                os.remove(os.path.join(full_path, line.strip()))
            except OSError:
                pass
def download_terra_incognita(data_dir):
    """Download the Caltech Camera Traps images/labels and build the
    TerraIncognita split: copy images of selected locations/categories into
    <data_dir>/terra_incognita/location_<loc>/<category>/.
    """
    # Original URL: https://beerys.github.io/CaltechCameraTraps/
    # New URL: http://lila.science/datasets/caltech-camera-traps

    full_path = stage_path(data_dir, "terra_incognita")

    download_and_extract(
        "https://lilablobssc.blob.core.windows.net/caltechcameratraps/eccv_18_all_images_sm.tar.gz",
        os.path.join(full_path, "terra_incognita_images.tar.gz"),
    )

    download_and_extract(
        "https://lilablobssc.blob.core.windows.net/caltechcameratraps/labels/caltech_camera_traps.json.zip",
        os.path.join(full_path, "caltech_camera_traps.json.zip"),
    )

    include_locations = ["38", "46", "100", "43"]

    include_categories = [
        "bird",
        "bobcat",
        "cat",
        "coyote",
        "dog",
        "empty",
        "opossum",
        "rabbit",
        "raccoon",
        "squirrel",
    ]

    images_folder = os.path.join(full_path, "eccv_18_all_images_sm/")
    annotations_file = os.path.join(full_path, "caltech_images_20210113.json")
    destination_folder = full_path

    stats = {}

    if not os.path.exists(destination_folder):
        os.mkdir(destination_folder)

    with open(annotations_file, "r") as f:
        data = json.load(f)

    category_dict = {}
    for item in data["categories"]:
        category_dict[item["id"]] = item["name"]

    # Index annotations by image id once, instead of scanning the full
    # annotation list for every image (was O(images * annotations)).
    annotations_by_image = {}
    for annotation in data["annotations"]:
        annotations_by_image.setdefault(annotation["image_id"], []).append(annotation)

    for image in data["images"]:
        image_location = image["location"]

        if image_location not in include_locations:
            continue

        loc_folder = os.path.join(destination_folder, "location_" + str(image_location) + "/")

        if not os.path.exists(loc_folder):
            os.mkdir(loc_folder)

        image_id = image["id"]
        image_fname = image["file_name"]

        for annotation in annotations_by_image.get(image_id, []):
            if image_location not in stats:
                stats[image_location] = {}

            category = category_dict[annotation["category_id"]]

            if category not in include_categories:
                continue

            # BUG FIX: the original set the first occurrence to 0 and only
            # incremented subsequent ones, undercounting every category by 1.
            stats[image_location][category] = stats[image_location].get(category, 0) + 1

            loc_cat_folder = os.path.join(loc_folder, category + "/")

            if not os.path.exists(loc_cat_folder):
                os.mkdir(loc_cat_folder)

            dst_path = os.path.join(loc_cat_folder, image_fname)
            src_path = os.path.join(images_folder, image_fname)

            shutil.copyfile(src_path, dst_path)

    shutil.rmtree(images_folder)
    os.remove(annotations_file)
    def get_final_model(self):
        # Best averaged model so far, selected by SWA validation accuracy in
        # update_and_evaluate (None if update_and_evaluate never ran).
        return self.final_model
58 | n_tolerance: loss min smoothing window size 59 | tolerance_ratio: decision ratio for dead loss valley 60 | """ 61 | self.evaluator = evaluator 62 | self.n_converge = n_converge 63 | self.n_tolerance = n_tolerance 64 | self.tolerance_ratio = tolerance_ratio 65 | 66 | self.converge_Q = deque(maxlen=n_converge) 67 | self.smooth_Q = deque(maxlen=n_tolerance) 68 | 69 | self.final_model = None 70 | 71 | self.converge_step = None 72 | self.dead_valley = False 73 | self.threshold = None 74 | 75 | def get_smooth_loss(self, idx): 76 | smooth_loss = min([model.end_loss for model in list(self.smooth_Q)[idx:]]) 77 | return smooth_loss 78 | 79 | @property 80 | def is_converged(self): 81 | return self.converge_step is not None 82 | 83 | def update_and_evaluate(self, segment_swa, val_acc, val_loss, prt_fn): 84 | if self.dead_valley: 85 | return 86 | 87 | frozen = copy.deepcopy(segment_swa.cpu()) 88 | frozen.end_loss = val_loss 89 | self.converge_Q.append(frozen) 90 | self.smooth_Q.append(frozen) 91 | 92 | if not self.is_converged: 93 | if len(self.converge_Q) < self.n_converge: 94 | return 95 | 96 | min_idx = np.argmin([model.end_loss for model in self.converge_Q]) 97 | untilmin_segment_swa = self.converge_Q[min_idx] # until-min segment swa. 
98 | if min_idx == 0: 99 | self.converge_step = self.converge_Q[0].end_step 100 | self.final_model = swa_utils.AveragedModel(untilmin_segment_swa) 101 | 102 | th_base = np.mean([model.end_loss for model in self.converge_Q]) 103 | self.threshold = th_base * (1.0 + self.tolerance_ratio) 104 | 105 | if self.n_tolerance < self.n_converge: 106 | for i in range(self.n_converge - self.n_tolerance): 107 | model = self.converge_Q[1 + i] 108 | self.final_model.update_parameters( 109 | model, start_step=model.start_step, end_step=model.end_step 110 | ) 111 | elif self.n_tolerance > self.n_converge: 112 | converge_idx = self.n_tolerance - self.n_converge 113 | Q = list(self.smooth_Q)[: converge_idx + 1] 114 | start_idx = 0 115 | for i in reversed(range(len(Q))): 116 | model = Q[i] 117 | if model.end_loss > self.threshold: 118 | start_idx = i + 1 119 | break 120 | for model in Q[start_idx + 1 :]: 121 | self.final_model.update_parameters( 122 | model, start_step=model.start_step, end_step=model.end_step 123 | ) 124 | print( 125 | f"Model converged at step {self.converge_step}, " 126 | f"Start step = {self.final_model.start_step}; " 127 | f"Threshold = {self.threshold:.6f}, " 128 | ) 129 | return 130 | 131 | if self.smooth_Q[0].end_step < self.converge_step: 132 | return 133 | 134 | # converged -> loss valley 135 | min_vloss = self.get_smooth_loss(0) 136 | if min_vloss > self.threshold: 137 | self.dead_valley = True 138 | print(f"Valley is dead at step {self.final_model.end_step}") 139 | return 140 | 141 | model = self.smooth_Q[0] 142 | self.final_model.update_parameters( 143 | model, start_step=model.start_step, end_step=model.end_step 144 | ) 145 | 146 | def get_final_model(self): 147 | if not self.is_converged: 148 | self.evaluator.logger.error( 149 | "Requested final model, but model is not yet converged; return last model instead" 150 | ) 151 | return self.converge_Q[-1].cuda() 152 | 153 | if not self.dead_valley: 154 | self.smooth_Q.popleft() 155 | while self.smooth_Q: 156 
def interpolate_algos(sd1, sd2, sd3, sd4):
    """Element-wise average of four aligned state dicts.

    Keys are taken from ``sd1``; all four dicts are assumed to share them.
    Used to merge the four branch models into one averaged model.
    """
    averaged = {}
    for key in sd1:
        averaged[key] = (sd1[key] + sd2[key] + sd3[key] + sd4[key]) / 4
    return averaged
logger.info("!!! In-domain test mode On !!!") 46 | # assert hparams["val_augment"] is False, ( 47 | # "indomain_test split the val set into val/test sets. " 48 | # "Therefore, the val set should be not augmented." 49 | # ) 50 | # val_splits = [] 51 | # for env_i, (out_split, _weights) in enumerate(out_splits): 52 | # n = len(out_split) // 2 53 | # seed = misc.seed_hash(args.trial_seed, env_i) 54 | # val_split, test_split = split_dataset(out_split, n, seed=seed) 55 | # val_splits.append((val_split, None)) 56 | # test_splits.append((test_split, None)) 57 | # logger.info( 58 | # "env %d: out (#%d) -> val (#%d) / test (#%d)" 59 | # % (env_i, len(out_split), len(val_split), len(test_split)) 60 | # ) 61 | # out_splits = val_splits 62 | 63 | if target_env is not None: 64 | testenv_name = f"te_{dataset.environments[target_env]}" 65 | logger.info(f"Target env = {target_env}") 66 | else: 67 | testenv_properties = [str(dataset.environments[i]) for i in test_envs] 68 | testenv_name = "te_" + "_".join(testenv_properties) 69 | 70 | logger.info( 71 | "Testenv name escaping {} -> {}".format(testenv_name, testenv_name.replace(".", "")) 72 | ) 73 | testenv_name = testenv_name.replace(".", "") 74 | logger.info(f"Test envs = {test_envs}, name = {testenv_name}") 75 | 76 | n_envs = len(dataset) 77 | train_envs = sorted(set(range(n_envs)) - set(test_envs)) 78 | iterator = misc.SplitIterator(test_envs) 79 | batch_sizes = np.full([n_envs], hparams["batch_size"], dtype=np.int) 80 | batch_sizes50 = np.full([n_envs], int(hparams["batch_size"]*3*0.5), dtype=np.int) 81 | batch_sizes25 = np.full([n_envs], int(hparams["batch_size"]*3*0.25), dtype=np.int) 82 | 83 | 84 | batch_sizes[test_envs] = 0 85 | batch_sizes = batch_sizes.tolist() 86 | batch_sizes50[test_envs] = 0 87 | batch_sizes50 = batch_sizes50.tolist() 88 | batch_sizes25[test_envs] = 0 89 | batch_sizes25 = batch_sizes25.tolist() 90 | 91 | logger.info(f"Batch sizes for CombERM branch: {batch_sizes} (total={sum(batch_sizes)})") 92 | 
logger.info(f"Own domain Batch sizes for each domain: {batch_sizes50} (total={sum(batch_sizes50)})") 93 | logger.info(f"Other domain Batch sizes for each domain: {batch_sizes25} (total={sum(batch_sizes25)})") 94 | 95 | # calculate steps per epoch 96 | steps_per_epochs = [ 97 | len(env) / batch_size for (env, _), batch_size in iterator.train(zip(in_splits, batch_sizes)) 98 | ] 99 | steps_per_epochs50 = [ 100 | len(env) / batch_size50 for (env, _), batch_size50 in iterator.train(zip(in_splits, batch_sizes50)) 101 | ] 102 | steps_per_epochs25 = [ 103 | len(env) / batch_size25 for (env, _), batch_size25 in iterator.train(zip(in_splits, batch_sizes25)) 104 | ] 105 | steps_per_epoch = min(steps_per_epochs) 106 | steps_per_epoch50 = min(steps_per_epochs50) 107 | steps_per_epoch25 = min(steps_per_epochs25) 108 | # epoch is computed by steps_per_epoch 109 | prt_steps = ", ".join([f"{step:.2f}" for step in steps_per_epochs]) 110 | prt_steps50 = ", ".join([f"{step:.2f}" for step in steps_per_epochs50]) 111 | prt_steps25 = ", ".join([f"{step:.2f}" for step in steps_per_epochs25]) 112 | logger.info(f"steps-per-epoch for CombERM : {prt_steps} -> min = {steps_per_epoch:.2f}") 113 | logger.info(f"steps-per-epoch for own domain: {prt_steps50} -> min = {steps_per_epoch50:.2f}") 114 | logger.info(f"steps-per-epoch for other domain: {prt_steps25} -> min = {steps_per_epoch25:.2f}") 115 | 116 | # setup loaders 117 | train_loaders = [ 118 | InfiniteDataLoader( 119 | dataset=env, 120 | weights=env_weights, 121 | batch_size=batch_size, 122 | num_workers=dataset.N_WORKERS 123 | ) 124 | for (env, env_weights), batch_size in iterator.train(zip(in_splits, batch_sizes)) 125 | ] 126 | train_loaders50 = [ 127 | InfiniteDataLoader( 128 | dataset=env, 129 | weights=env_weights, 130 | batch_size=batch_size50, 131 | num_workers=dataset.N_WORKERS 132 | ) 133 | for (env, env_weights), batch_size50 in iterator.train(zip(in_splits, batch_sizes50)) 134 | ] 135 | train_loaders25a = [ 136 | 
InfiniteDataLoader( 137 | dataset=env, 138 | weights=env_weights, 139 | batch_size=batch_size25, 140 | num_workers=dataset.N_WORKERS 141 | ) 142 | for (env, env_weights), batch_size25 in iterator.train(zip(in_splits, batch_sizes25)) 143 | ] 144 | train_loaders25b = [ 145 | InfiniteDataLoader( 146 | dataset=env, 147 | weights=env_weights, 148 | batch_size=batch_size25, 149 | num_workers=dataset.N_WORKERS 150 | ) 151 | for (env, env_weights), batch_size25 in iterator.train(zip(in_splits, batch_sizes25)) 152 | ] 153 | 154 | # setup eval loaders 155 | eval_loaders_kwargs = [] 156 | for i, (env, _) in enumerate(in_splits + out_splits + test_splits): 157 | batchsize = hparams["test_batchsize"] 158 | loader_kwargs = {"dataset": env, "batch_size": batchsize, "num_workers": dataset.N_WORKERS} 159 | if args.prebuild_loader: 160 | loader_kwargs = FastDataLoader(**loader_kwargs) 161 | eval_loaders_kwargs.append(loader_kwargs) 162 | 163 | eval_weights = [None for _, weights in (in_splits + out_splits + test_splits)] 164 | eval_loader_names = ["env{}_in".format(i) for i in range(len(in_splits))] 165 | eval_loader_names += ["env{}_out".format(i) for i in range(len(out_splits))] 166 | eval_loader_names += ["env{}_inTE".format(i) for i in range(len(test_splits))] 167 | eval_meta = list(zip(eval_loader_names, eval_loaders_kwargs, eval_weights)) 168 | 169 | ####################################################### 170 | # setup algorithm (model) 171 | ####################################################### 172 | algorithmCE1 = algorithm_class( 173 | dataset.input_shape, 174 | dataset.num_classes, 175 | len(dataset) - len(test_envs), 176 | hparams, 177 | ) 178 | algorithmCE2 = algorithm_class( 179 | dataset.input_shape, 180 | dataset.num_classes, 181 | len(dataset) - len(test_envs), 182 | hparams, 183 | ) 184 | algorithmCE3 = algorithm_class( 185 | dataset.input_shape, 186 | dataset.num_classes, 187 | len(dataset) - len(test_envs), 188 | hparams, 189 | ) 190 | algorithmCE4 = 
algorithm_class( 191 | dataset.input_shape, 192 | dataset.num_classes, 193 | len(dataset) - len(test_envs), 194 | hparams, 195 | ) 196 | 197 | # algorithmCE1.to(device) 198 | # algorithmCE2.to(device) 199 | # algorithmCE3.to(device) 200 | # algorithmCE4.to(device) 201 | algorithmCE1.cuda() 202 | algorithmCE2.cuda() 203 | algorithmCE3.cuda() 204 | algorithmCE4.cuda() 205 | 206 | n_params = sum([p.numel() for p in algorithmCE1.parameters()]) 207 | logger.info("# of params = %d" % n_params) 208 | 209 | train_minibatches_iterator = zip(*train_loaders) 210 | train_minibatches_iterator50 = zip(*train_loaders50) 211 | train_minibatches_iterator25a = zip(*train_loaders25a) 212 | train_minibatches_iterator25b = zip(*train_loaders25b) 213 | 214 | checkpoint_vals = collections.defaultdict(lambda: []) 215 | 216 | ####################################################### 217 | # start training loop 218 | ####################################################### 219 | evaluator = Evaluator( 220 | test_envs, 221 | eval_meta, 222 | n_envs, 223 | logger, 224 | evalmode=args.evalmode, 225 | debug=args.debug, 226 | target_env=target_env, 227 | ) 228 | 229 | # swad = None 230 | # if hparams["swad"]: 231 | # swad_algorithm = swa_utils.AveragedModel(algorithm) 232 | # swad_cls = getattr(swad_module, hparams["swad"]) 233 | # swad = swad_cls(evaluator, **hparams.swad_kwargs) 234 | 235 | swad1 = None 236 | if hparams["swad"]: 237 | swad_algorithm1 = swa_utils.AveragedModel(algorithmCE1) 238 | swad_cls1 = getattr(swad_module, "LossValley") 239 | swad1 = swad_cls1(evaluator, **hparams.swad_kwargs) 240 | swad2 = None 241 | if hparams["swad"]: 242 | swad_algorithm2 = swa_utils.AveragedModel(algorithmCE2) 243 | swad_cls2 = getattr(swad_module, "LossValley") 244 | swad2 = swad_cls2(evaluator, **hparams.swad_kwargs) 245 | swad3 = None 246 | if hparams["swad"]: 247 | swad_algorithm3 = swa_utils.AveragedModel(algorithmCE3) 248 | swad_cls3 = getattr(swad_module, "LossValley") 249 | swad3 = 
swad_cls3(evaluator, **hparams.swad_kwargs) 250 | swad4 = None 251 | if hparams["swad"]: 252 | swad_algorithm4 = swa_utils.AveragedModel(algorithmCE4) 253 | swad_cls4 = getattr(swad_module, "LossValley") 254 | swad4 = swad_cls4(evaluator, **hparams.swad_kwargs) 255 | 256 | last_results_keys = None 257 | records = [] 258 | records_inter = [] 259 | epochs_path = args.out_dir / "results.jsonl" 260 | 261 | for step in range(n_steps): 262 | step_start_time = time.time() 263 | 264 | # batches_dictlist: [ {x: ,y: }, {x: ,y: }, {x: ,y: } ] 265 | batches_dictlist = next(train_minibatches_iterator) 266 | batches_dictlist50 = next(train_minibatches_iterator50) 267 | batches_dictlist25a = next(train_minibatches_iterator25a) 268 | batches_dictlist25b = next(train_minibatches_iterator25b) 269 | 270 | # batches: {x: [ ,, ] ,y: [ ,, ] } 271 | batchesCE = misc.merge_dictlist(batches_dictlist) 272 | batches1 = {'x': [ batches_dictlist50[0]['x'],batches_dictlist25a[1]['x'],batches_dictlist25b[2]['x'] ], 273 | 'y': [ batches_dictlist50[0]['y'],batches_dictlist25a[1]['y'],batches_dictlist25b[2]['y'] ] } 274 | batches2 = {'x': [ batches_dictlist25b[0]['x'],batches_dictlist50[1]['x'],batches_dictlist25a[2]['x'] ], 275 | 'y': [ batches_dictlist25b[0]['y'],batches_dictlist50[1]['y'],batches_dictlist25a[2]['y'] ] } 276 | batches3 = {'x': [ batches_dictlist25a[0]['x'],batches_dictlist25b[1]['x'],batches_dictlist50[2]['x'] ], 277 | 'y': [ batches_dictlist25a[0]['y'],batches_dictlist25b[1]['y'],batches_dictlist50[2]['y'] ] } 278 | 279 | # to device 280 | batchesCE = {key: [tensor.cuda() for tensor in tensorlist] for key, tensorlist in batchesCE.items()} 281 | batches1 = {key: [tensor.cuda() for tensor in tensorlist] for key, tensorlist in batches1.items()} 282 | batches2 = {key: [tensor.cuda() for tensor in tensorlist] for key, tensorlist in batches2.items()} 283 | batches3 = {key: [tensor.cuda() for tensor in tensorlist] for key, tensorlist in batches3.items()} 284 | 285 | inputsCE = 
{**batchesCE, "step":step} 286 | inputs1 = {**batches1, "step": step} 287 | inputs2 = {**batches2, "step": step} 288 | inputs3 = {**batches3, "step": step} 289 | 290 | step_valsCE1 = algorithmCE1.update(**inputs1) 291 | step_valsCE2 = algorithmCE2.update(**inputs2) 292 | step_valsCE3 = algorithmCE3.update(**inputs3) 293 | step_valsCE4 = algorithmCE4.update(**inputsCE) 294 | 295 | 296 | for key, val in step_valsCE1.items(): 297 | checkpoint_vals['1_'+key].append(val) 298 | for key, val in step_valsCE2.items(): 299 | checkpoint_vals['2_'+key].append(val) 300 | for key, val in step_valsCE3.items(): 301 | checkpoint_vals['3_'+key].append(val) 302 | for key, val in step_valsCE4.items(): 303 | checkpoint_vals['4_'+key].append(val) 304 | checkpoint_vals["step_time"].append(time.time() - step_start_time) 305 | 306 | if swad1: 307 | # swad_algorithm is segment_swa for swad 308 | swad_algorithm1.update_parameters(algorithmCE1, step=step) 309 | swad_algorithm2.update_parameters(algorithmCE2, step=step) 310 | swad_algorithm3.update_parameters(algorithmCE3, step=step) 311 | swad_algorithm4.update_parameters(algorithmCE4, step=step) 312 | 313 | if step % checkpoint_freq == 0: 314 | results = { 315 | "step": step, 316 | "epoch": step / steps_per_epoch, 317 | } 318 | 319 | for key, val in checkpoint_vals.items(): 320 | results[key] = np.mean(val) 321 | 322 | eval_start_time = time.time() 323 | summaries1 = evaluator.evaluate(algorithmCE1, suffix='_1') 324 | summaries2 = evaluator.evaluate(algorithmCE2, suffix='_2') 325 | summaries3 = evaluator.evaluate(algorithmCE3, suffix='_3') 326 | summaries4 = evaluator.evaluate(algorithmCE4, suffix='_4') 327 | results["eval_time"] = time.time() - eval_start_time 328 | 329 | # results = (epochs, loss, step, step_time) 330 | results_keys = list(summaries1.keys()) + list(summaries2.keys()) + list(summaries3.keys()) + list(summaries4.keys()) + list(results.keys()) 331 | # merge results 332 | results.update(summaries1) 333 | 
results.update(summaries2) 334 | results.update(summaries3) 335 | results.update(summaries4) 336 | 337 | # print 338 | if results_keys != last_results_keys: 339 | logger.info(misc.to_row(results_keys)) 340 | last_results_keys = results_keys 341 | logger.info(misc.to_row([results[key] for key in results_keys])) 342 | records.append(copy.deepcopy(results)) 343 | 344 | # update results to record 345 | results.update({"hparams": dict(hparams), "args": vars(args)}) 346 | 347 | with open(epochs_path, "a") as f: 348 | f.write(json.dumps(results, sort_keys=True, default=json_handler) + "\n") 349 | 350 | checkpoint_vals = collections.defaultdict(lambda: []) 351 | 352 | writer.add_scalars_with_prefix(summaries1, step, f"{testenv_name}/summary1/") 353 | writer.add_scalars_with_prefix(summaries2, step, f"{testenv_name}/summary2/") 354 | writer.add_scalars_with_prefix(summaries3, step, f"{testenv_name}/summary3/") 355 | writer.add_scalars_with_prefix(summaries4, step, f"{testenv_name}/summary4/") 356 | # writer.add_scalars_with_prefix(accuracies, step, f"{testenv_name}/all/") 357 | 358 | if args.model_save and step >= args.model_save: 359 | ckpt_dir = args.out_dir / "checkpoints" 360 | ckpt_dir.mkdir(exist_ok=True) 361 | 362 | test_env_str = ",".join(map(str, test_envs)) 363 | filename = "TE{}_{}.pth".format(test_env_str, step) 364 | if len(test_envs) > 1 and target_env is not None: 365 | train_env_str = ",".join(map(str, train_envs)) 366 | filename = f"TE{target_env}_TR{train_env_str}_{step}.pth" 367 | path = ckpt_dir / filename 368 | 369 | save_dict = { 370 | "args": vars(args), 371 | "model_hparams": dict(hparams), 372 | "test_envs": test_envs, 373 | "model_dict1": algorithmCE1.cpu().state_dict(), 374 | "model_dict2": algorithmCE2.cpu().state_dict(), 375 | "model_dict3": algorithmCE3.cpu().state_dict(), 376 | "model_dict4": algorithmCE4.cpu().state_dict(), 377 | } 378 | algorithmCE1.cuda() 379 | algorithmCE2.cuda() 380 | algorithmCE3.cuda() 381 | algorithmCE4.cuda() 382 | if 
not args.debug: 383 | torch.save(save_dict, path) 384 | else: 385 | logger.debug("DEBUG Mode -> no save (org path: %s)" % path) 386 | 387 | # swad 388 | if swad1: 389 | def prt_results_fn(results, avgmodel): 390 | step_str = f" [{avgmodel.start_step}-{avgmodel.end_step}]" 391 | row = misc.to_row([results[key] for key in results_keys if key in results]) 392 | logger.info(row + step_str) 393 | 394 | swad1.update_and_evaluate( 395 | swad_algorithm1, results["comb_val_1"], results["comb_val_loss_1"], prt_results_fn 396 | ) 397 | swad2.update_and_evaluate( 398 | swad_algorithm2, results["comb_val_2"], results["comb_val_loss_2"], prt_results_fn 399 | ) 400 | swad3.update_and_evaluate( 401 | swad_algorithm3, results["comb_val_3"], results["comb_val_loss_3"], prt_results_fn 402 | ) 403 | swad4.update_and_evaluate( 404 | swad_algorithm4, results["comb_val_4"], results["comb_val_loss_4"], prt_results_fn 405 | ) 406 | 407 | # if hasattr(swad, "dead_valley") and swad.dead_valley: 408 | # logger.info("SWAD valley is dead -> early stop !") 409 | # break 410 | if hasattr(swad1, "dead_valley") and swad1.dead_valley: 411 | logger.info("SWAD valley is dead for 1 -> early stop !") 412 | if hasattr(swad2, "dead_valley") and swad2.dead_valley: 413 | logger.info("SWAD valley is dead for 2 -> early stop !") 414 | if hasattr(swad3, "dead_valley") and swad3.dead_valley: 415 | logger.info("SWAD valley is dead for 3 -> early stop !") 416 | if hasattr(swad4, "dead_valley") and swad4.dead_valley: 417 | logger.info("SWAD valley is dead for 4 -> early stop !") 418 | 419 | if (hparams["model"]=='clip_vit-b16') and (step % 1500 == 0): 420 | swad_algorithm1 = swa_utils.AveragedModel(algorithmCE1) # reset 421 | swad_algorithm2 = swa_utils.AveragedModel(algorithmCE2) 422 | swad_algorithm3 = swa_utils.AveragedModel(algorithmCE3) 423 | swad_algorithm4 = swa_utils.AveragedModel(algorithmCE4) 424 | 425 | if step % args.tb_freq == 0: 426 | # add step values only for tb log 427 | 
writer.add_scalars_with_prefix(step_valsCE1, step, f"{testenv_name}/summary1/") 428 | writer.add_scalars_with_prefix(step_valsCE2, step, f"{testenv_name}/summary2/") 429 | writer.add_scalars_with_prefix(step_valsCE3, step, f"{testenv_name}/summary3/") 430 | writer.add_scalars_with_prefix(step_valsCE4, step, f"{testenv_name}/summary4/") 431 | 432 | if step%args.inter_freq==0 and step!=0: 433 | if args.algorithm in ['DANN', 'CDANN']: 434 | inter_state_dict = interpolate_algos(algorithmCE1.featurizer.state_dict(), algorithmCE2.featurizer.state_dict(), algorithmCE3.featurizer.state_dict(), algorithmCE4.featurizer.state_dict()) 435 | algorithmCE1.featurizer.load_state_dict(inter_state_dict) 436 | algorithmCE2.featurizer.load_state_dict(inter_state_dict) 437 | algorithmCE3.featurizer.load_state_dict(inter_state_dict) 438 | algorithmCE4.featurizer.load_state_dict(inter_state_dict) 439 | inter_state_dict2 = interpolate_algos(algorithmCE1.classifier.state_dict(), algorithmCE2.classifier.state_dict(), algorithmCE3.classifier.state_dict(), algorithmCE4.classifier.state_dict()) 440 | algorithmCE1.classifier.load_state_dict(inter_state_dict2) 441 | algorithmCE2.classifier.load_state_dict(inter_state_dict2) 442 | algorithmCE3.classifier.load_state_dict(inter_state_dict2) 443 | algorithmCE4.classifier.load_state_dict(inter_state_dict2) 444 | inter_state_dict3 = interpolate_algos(algorithmCE1.discriminator.state_dict(), algorithmCE2.discriminator.state_dict(), algorithmCE3.discriminator.state_dict(), algorithmCE4.discriminator.state_dict()) 445 | algorithmCE1.discriminator.load_state_dict(inter_state_dict3) 446 | algorithmCE2.discriminator.load_state_dict(inter_state_dict3) 447 | algorithmCE3.discriminator.load_state_dict(inter_state_dict3) 448 | algorithmCE4.discriminator.load_state_dict(inter_state_dict3) 449 | 450 | elif args.algorithm in ['SagNet']: 451 | inter_state_dict = interpolate_algos(algorithmCE1.network_f.state_dict(), algorithmCE2.network_f.state_dict(), 
algorithmCE3.network_f.state_dict(), algorithmCE4.network_f.state_dict()) 452 | algorithmCE1.network_f.load_state_dict(inter_state_dict) 453 | algorithmCE2.network_f.load_state_dict(inter_state_dict) 454 | algorithmCE3.network_f.load_state_dict(inter_state_dict) 455 | algorithmCE4.network_f.load_state_dict(inter_state_dict) 456 | inter_state_dict2 = interpolate_algos(algorithmCE1.network_c.state_dict(), algorithmCE2.network_c.state_dict(), algorithmCE3.network_c.state_dict(), algorithmCE4.network_c.state_dict()) 457 | algorithmCE1.network_c.load_state_dict(inter_state_dict2) 458 | algorithmCE2.network_c.load_state_dict(inter_state_dict2) 459 | algorithmCE3.network_c.load_state_dict(inter_state_dict2) 460 | algorithmCE4.network_c.load_state_dict(inter_state_dict2) 461 | inter_state_dict3 = interpolate_algos(algorithmCE1.network_s.state_dict(), algorithmCE2.network_s.state_dict(), algorithmCE3.network_s.state_dict(), algorithmCE4.network_s.state_dict()) 462 | algorithmCE1.network_s.load_state_dict(inter_state_dict3) 463 | algorithmCE2.network_s.load_state_dict(inter_state_dict3) 464 | algorithmCE3.network_s.load_state_dict(inter_state_dict3) 465 | algorithmCE4.network_s.load_state_dict(inter_state_dict3) 466 | 467 | else: 468 | inter_state_dict = interpolate_algos(algorithmCE1.network.state_dict(), algorithmCE2.network.state_dict(), algorithmCE3.network.state_dict(), algorithmCE4.network.state_dict()) 469 | algorithmCE1.network.load_state_dict(inter_state_dict) 470 | algorithmCE2.network.load_state_dict(inter_state_dict) 471 | algorithmCE3.network.load_state_dict(inter_state_dict) 472 | algorithmCE4.network.load_state_dict(inter_state_dict) 473 | 474 | logger.info(f"Evaluating interpolated model at {step} step") 475 | summaries_inter = evaluator.evaluate(algorithmCE1, suffix='_from_inter') 476 | inter_results = {"inter_step": step, "inter_epoch": step / steps_per_epoch} 477 | inter_results_keys = list(summaries_inter.keys()) + list(inter_results.keys()) 478 | 
inter_results.update(summaries_inter) 479 | logger.info(misc.to_row([inter_results[key] for key in inter_results_keys])) 480 | records_inter.append(copy.deepcopy(inter_results)) 481 | writer.add_scalars_with_prefix(summaries_inter, step, f"{testenv_name}/summary_inter/") 482 | 483 | # find best 484 | logger.info("---") 485 | # print(records) 486 | records = Q(records) 487 | records_inter = Q(records_inter) 488 | 489 | # print(len(records)) 490 | # print(records) 491 | 492 | # 1 493 | oracle_best1 = records.argmax("test_out_1")["test_in_1"] 494 | iid_best1 = records.argmax("comb_val_1")["test_in_1"] 495 | inDom1 = records.argmax("comb_val_1")["comb_val_1"] 496 | # own_best1 = records.argmax("own_val_from_first")["test_in_from_first"] 497 | last1 = records[-1]["test_in_1"] 498 | # 2 499 | oracle_best2 = records.argmax("test_out_2")["test_in_2"] 500 | iid_best2 = records.argmax("comb_val_2")["test_in_2"] 501 | inDom2 = records.argmax("comb_val_2")["comb_val_2"] 502 | # own_best2 = records.argmax("own_val_from_second")["test_in_from_second"] 503 | last2 = records[-1]["test_in_2"] 504 | # 3 505 | oracle_best3 = records.argmax("test_out_3")["test_in_3"] 506 | iid_best3 = records.argmax("comb_val_3")["test_in_3"] 507 | inDom3 = records.argmax("comb_val_3")["comb_val_3"] 508 | # own_best3 = records.argmax("own_val_from_third")["test_in_from_third"] 509 | last3 = records[-1]["test_in_3"] 510 | # CE 511 | oracle_best4 = records.argmax("test_out_4")["test_in_4"] 512 | iid_best4 = records.argmax("comb_val_4")["test_in_4"] 513 | inDom4 = records.argmax("comb_val_4")["comb_val_4"] 514 | last4 = records[-1]["test_in_4"] 515 | # inter 516 | oracle_best_inter = records_inter.argmax("test_out_from_inter")["test_in_from_inter"] 517 | iid_best_inter = records_inter.argmax("comb_val_from_inter")["test_in_from_inter"] 518 | inDom_inter = records_inter.argmax("comb_val_from_inter")["comb_val_from_inter"] 519 | 520 | # if hparams.indomain_test: 521 | # # if test set exist, use test set 
for indomain results 522 | # in_key = "train_inTE" 523 | # else: 524 | # in_key = "train_out" 525 | 526 | # iid_best_indomain = records.argmax("train_out")[in_key] 527 | # last_indomain = records[-1][in_key] 528 | 529 | ret = { 530 | "oracle_1": oracle_best1, 531 | "iid_1": iid_best1, 532 | # "own_1": own_best1, 533 | "inDom1": inDom1, 534 | "last_1": last1, 535 | "oracle_2": oracle_best2, 536 | "iid_2": iid_best2, 537 | # "own_2": own_best2, 538 | "inDom2":inDom2, 539 | "last_2": last2, 540 | "oracle_3": oracle_best3, 541 | "iid_3": iid_best3, 542 | # "own_3": own_best3, 543 | "inDom3":inDom3, 544 | "last_3": last3, 545 | "oracle_4": oracle_best4, 546 | "iid_4": iid_best4, 547 | "inDom4": inDom4, 548 | "last_4": last4, 549 | "oracle_inter": oracle_best_inter, 550 | "iid_inter": iid_best_inter, 551 | "inDom_inter":inDom_inter, 552 | } 553 | 554 | # Evaluate SWAD 555 | if swad1: 556 | swad_algorithm1 = swad1.get_final_model() 557 | swad_algorithm2 = swad2.get_final_model() 558 | swad_algorithm3 = swad3.get_final_model() 559 | swad_algorithm4 = swad4.get_final_model() 560 | if hparams["freeze_bn"] is False: 561 | n_steps = 500 if not args.debug else 10 562 | logger.warning(f"Update SWAD BN statistics for {n_steps} steps ...") 563 | swa_utils.update_bn(train_minibatches_iterator, swad_algorithm1, n_steps) 564 | swa_utils.update_bn(train_minibatches_iterator, swad_algorithm2, n_steps) 565 | swa_utils.update_bn(train_minibatches_iterator, swad_algorithm3, n_steps) 566 | swa_utils.update_bn(train_minibatches_iterator, swad_algorithm4, n_steps) 567 | 568 | logger.warning("Evaluate SWAD ...") 569 | summaries_swad1 = evaluator.evaluate(swad_algorithm1, suffix='_s1') 570 | summaries_swad2 = evaluator.evaluate(swad_algorithm2, suffix='_s2') 571 | summaries_swad3 = evaluator.evaluate(swad_algorithm3, suffix='_s3') 572 | summaries_swad4 = evaluator.evaluate(swad_algorithm4, suffix='_s4') 573 | 574 | swad_results = {**summaries_swad1, **summaries_swad2, **summaries_swad3, 
**summaries_swad4} 575 | step_str = f" [{swad_algorithm1.start_step}-{swad_algorithm1.end_step}] (N={swad_algorithm1.n_averaged}) || [{swad_algorithm2.start_step}-{swad_algorithm2.end_step}] (N={swad_algorithm2.n_averaged}) || [{swad_algorithm3.start_step}-{swad_algorithm3.end_step}] (N={swad_algorithm3.n_averaged}) || [{swad_algorithm4.start_step}-{swad_algorithm4.end_step}] (N={swad_algorithm4.n_averaged})" 576 | row = misc.to_row([swad_results[key] for key in list(swad_results.keys())]) + step_str 577 | logger.info(row) 578 | 579 | ret["SWAD 1"] = swad_results["test_in_s1"] 580 | ret["SWAD 1 (inDom)"] = swad_results["comb_val_s1"] 581 | ret["SWAD 2"] = swad_results["test_in_s2"] 582 | ret["SWAD 2 (inDom)"] = swad_results["comb_val_s2"] 583 | ret["SWAD 3"] = swad_results["test_in_s3"] 584 | ret["SWAD 3 (inDom)"] = swad_results["comb_val_s3"] 585 | ret["SWAD 4"] = swad_results["test_in_s4"] 586 | ret["SWAD 4 (inDom)"] = swad_results["comb_val_s4"] 587 | 588 | save_dict = { 589 | "args": vars(args), 590 | "model_hparams": dict(hparams), 591 | "test_envs": test_envs, 592 | "SWAD_1": swad_algorithm1.state_dict(), 593 | "SWAD_2": swad_algorithm2.state_dict(), 594 | "SWAD_3": swad_algorithm3.state_dict(), 595 | "SWAD_4": swad_algorithm4.state_dict(), 596 | } 597 | 598 | if args.algorithm in ['DANN', 'CDANN']: 599 | inter_state_dict = interpolate_algos(swad_algorithm1.module.featurizer.state_dict(), swad_algorithm2.module.featurizer.state_dict(), swad_algorithm3.module.featurizer.state_dict(), swad_algorithm4.module.featurizer.state_dict()) 600 | swad_algorithm1.module.featurizer.load_state_dict(inter_state_dict) 601 | inter_state_dict2 = interpolate_algos(swad_algorithm1.module.classifier.state_dict(), swad_algorithm2.module.classifier.state_dict(), swad_algorithm3.module.classifier.state_dict(), swad_algorithm4.module.classifier.state_dict()) 602 | swad_algorithm1.module.classifier.load_state_dict(inter_state_dict2) 603 | inter_state_dict3 = 
interpolate_algos(swad_algorithm1.module.discriminator.state_dict(), swad_algorithm2.module.discriminator.state_dict(), swad_algorithm3.module.discriminator.state_dict(), swad_algorithm4.module.discriminator.state_dict()) 604 | swad_algorithm1.module.discriminator.load_state_dict(inter_state_dict3) 605 | 606 | elif args.algorithm in ['SagNet']: 607 | inter_state_dict = interpolate_algos(swad_algorithm1.module.network_f.state_dict(), swad_algorithm2.module.network_f.state_dict(), swad_algorithm3.module.network_f.state_dict(), swad_algorithm4.module.network_f.state_dict()) 608 | swad_algorithm1.module.network_f.load_state_dict(inter_state_dict) 609 | inter_state_dict2 = interpolate_algos(swad_algorithm1.module.network_c.state_dict(), swad_algorithm2.module.network_c.state_dict(), swad_algorithm3.module.network_c.state_dict(), swad_algorithm4.module.network_c.state_dict()) 610 | swad_algorithm1.module.network_c.load_state_dict(inter_state_dict2) 611 | inter_state_dict3 = interpolate_algos(swad_algorithm1.module.network_s.state_dict(), swad_algorithm2.module.network_s.state_dict(), swad_algorithm3.module.network_s.state_dict(), swad_algorithm4.module.network_s.state_dict()) 612 | swad_algorithm1.module.network_s.load_state_dict(inter_state_dict3) 613 | 614 | else: 615 | inter_state_dict = interpolate_algos(swad_algorithm1.network.state_dict(), swad_algorithm2.network.state_dict(), swad_algorithm3.network.state_dict(), swad_algorithm4.network.state_dict()) 616 | swad_algorithm1.network.load_state_dict(inter_state_dict) 617 | 618 | logger.info(f"Evaluating interpolated model of SWAD models") 619 | summaries_swadinter = evaluator.evaluate(swad_algorithm1, suffix='_from_swadinter') 620 | swadinter_results = {**summaries_swadinter} 621 | logger.info(misc.to_row([swadinter_results[key] for key in list(swadinter_results.keys())])) 622 | ret["SWAD INTER"] = swadinter_results["test_in_from_swadinter"] 623 | ret["SWAD INTER (inDom)"] = 
swadinter_results["comb_val_from_swadinter"] 624 | save_dict["SWAD_INTER"] = inter_state_dict 625 | 626 | 627 | ckpt_dir = args.out_dir / "checkpoints" 628 | ckpt_dir.mkdir(exist_ok=True) 629 | test_env_str = ",".join(map(str, test_envs)) 630 | filename = f"TE{test_env_str}.pth" 631 | if len(test_envs) > 1 and target_env is not None: 632 | train_env_str = ",".join(map(str, train_envs)) 633 | filename = f"TE{target_env}_TR{train_env_str}.pth" 634 | path = ckpt_dir / filename 635 | if swad1: 636 | torch.save(save_dict, path) 637 | 638 | 639 | for k, acc in ret.items(): 640 | logger.info(f"{k} = {acc:.3%}") 641 | 642 | return ret, records 643 | 644 | -------------------------------------------------------------------------------- /domainbed/trainer_DN.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import json 3 | import time 4 | import copy 5 | from pathlib import Path 6 | 7 | import numpy as np 8 | import torch 9 | import torch.utils.data 10 | 11 | from domainbed.datasets import get_dataset, split_dataset 12 | from domainbed import algorithms 13 | from domainbed.evaluator import Evaluator 14 | from domainbed.lib import misc 15 | from domainbed.lib import swa_utils 16 | from domainbed.lib.query import Q 17 | from domainbed.lib.fast_data_loader import InfiniteDataLoader, FastDataLoader 18 | from domainbed import swad as swad_module 19 | 20 | # if torch.cuda.is_available(): 21 | # device = "cuda" 22 | # else: 23 | # device = "cpu" 24 | 25 | 26 | def json_handler(v): 27 | if isinstance(v, (Path, range)): 28 | return str(v) 29 | raise TypeError(f"`{type(v)}` is not JSON Serializable") 30 | 31 | 32 | def interpolate_algos(sd1, sd2, sd3, sd4, sd5, sd6): 33 | return {key: (sd1[key] + sd2[key] + sd3[key] +sd4[key]+ sd5[key] +sd6[key])/6 for key in sd1.keys()} 34 | 35 | def train(test_envs, args, hparams, n_steps, checkpoint_freq, logger, writer, target_env=None): 36 | logger.info("") 37 | # n_steps = 1 38 | 
####################################################### 39 | # setup dataset & loader 40 | ####################################################### 41 | args.real_test_envs = test_envs # for log 42 | algorithm_class = algorithms.get_algorithm_class(args.algorithm) 43 | dataset, in_splits, out_splits = get_dataset(test_envs, args, hparams, algorithm_class) 44 | test_splits = [] 45 | # if hparams.indomain_test > 0.0: 46 | # logger.info("!!! In-domain test mode On !!!") 47 | # assert hparams["val_augment"] is False, ( 48 | # "indomain_test split the val set into val/test sets. " 49 | # "Therefore, the val set should be not augmented." 50 | # ) 51 | # val_splits = [] 52 | # for env_i, (out_split, _weights) in enumerate(out_splits): 53 | # n = len(out_split) // 2 54 | # seed = misc.seed_hash(args.trial_seed, env_i) 55 | # val_split, test_split = split_dataset(out_split, n, seed=seed) 56 | # val_splits.append((val_split, None)) 57 | # test_splits.append((test_split, None)) 58 | # logger.info( 59 | # "env %d: out (#%d) -> val (#%d) / test (#%d)" 60 | # % (env_i, len(out_split), len(val_split), len(test_split)) 61 | # ) 62 | # out_splits = val_splits 63 | 64 | if target_env is not None: 65 | testenv_name = f"te_{dataset.environments[target_env]}" 66 | logger.info(f"Target env = {target_env}") 67 | else: 68 | testenv_properties = [str(dataset.environments[i]) for i in test_envs] 69 | testenv_name = "te_" + "_".join(testenv_properties) 70 | 71 | logger.info( 72 | "Testenv name escaping {} -> {}".format(testenv_name, testenv_name.replace(".", "")) 73 | ) 74 | testenv_name = testenv_name.replace(".", "") 75 | logger.info(f"Test envs = {test_envs}, name = {testenv_name}") 76 | 77 | n_envs = len(dataset) 78 | train_envs = sorted(set(range(n_envs)) - set(test_envs)) 79 | iterator = misc.SplitIterator(test_envs) 80 | batch_sizes = np.full([n_envs], hparams["batch_size"], dtype=np.int) 81 | batch_sizes50 = np.full([n_envs], int(hparams["batch_size"]*5*0.4), dtype=np.int) 82 | 
batch_sizes25 = np.full([n_envs], int(hparams["batch_size"]*5*0.15), dtype=np.int) 83 | 84 | 85 | batch_sizes[test_envs] = 0 86 | batch_sizes = batch_sizes.tolist() 87 | batch_sizes50[test_envs] = 0 88 | batch_sizes50 = batch_sizes50.tolist() 89 | batch_sizes25[test_envs] = 0 90 | batch_sizes25 = batch_sizes25.tolist() 91 | 92 | logger.info(f"Batch sizes for CombERM branch: {batch_sizes} (total={sum(batch_sizes)})") 93 | logger.info(f"Own domain Batch sizes for each domain: {batch_sizes50} (total={sum(batch_sizes50)})") 94 | logger.info(f"Other domain Batch sizes for each domain: {batch_sizes25} (total={sum(batch_sizes25)})") 95 | 96 | # calculate steps per epoch 97 | steps_per_epochs = [ 98 | len(env) / batch_size for (env, _), batch_size in iterator.train(zip(in_splits, batch_sizes)) 99 | ] 100 | steps_per_epochs50 = [ 101 | len(env) / batch_size50 for (env, _), batch_size50 in iterator.train(zip(in_splits, batch_sizes50)) 102 | ] 103 | steps_per_epochs25 = [ 104 | len(env) / batch_size25 for (env, _), batch_size25 in iterator.train(zip(in_splits, batch_sizes25)) 105 | ] 106 | steps_per_epoch = min(steps_per_epochs) 107 | steps_per_epoch50 = min(steps_per_epochs50) 108 | steps_per_epoch25 = min(steps_per_epochs25) 109 | 110 | # epoch is computed by steps_per_epoch 111 | prt_steps = ", ".join([f"{step:.2f}" for step in steps_per_epochs]) 112 | prt_steps50 = ", ".join([f"{step:.2f}" for step in steps_per_epochs50]) 113 | prt_steps25 = ", ".join([f"{step:.2f}" for step in steps_per_epochs25]) 114 | logger.info(f"steps-per-epoch for CombERM : {prt_steps} -> min = {steps_per_epoch:.2f}") 115 | logger.info(f"steps-per-epoch for own domain: {prt_steps50} -> min = {steps_per_epoch50:.2f}") 116 | logger.info(f"steps-per-epoch for other domain: {prt_steps25} -> min = {steps_per_epoch25:.2f}") 117 | 118 | # setup loaders 119 | train_loaders = [ 120 | InfiniteDataLoader( 121 | dataset=env, 122 | weights=env_weights, 123 | batch_size=batch_size, 124 | 
num_workers=dataset.N_WORKERS 125 | ) 126 | for (env, env_weights), batch_size in iterator.train(zip(in_splits, batch_sizes)) 127 | ] 128 | train_loaders50 = [ 129 | InfiniteDataLoader( 130 | dataset=env, 131 | weights=env_weights, 132 | batch_size=batch_size50, 133 | num_workers=dataset.N_WORKERS 134 | ) 135 | for (env, env_weights), batch_size50 in iterator.train(zip(in_splits, batch_sizes50)) 136 | ] 137 | train_loaders25 = [ 138 | InfiniteDataLoader( 139 | dataset=env, 140 | weights=env_weights, 141 | batch_size=batch_size25, 142 | num_workers=dataset.N_WORKERS 143 | ) 144 | for (env, env_weights), batch_size25 in iterator.train(zip(in_splits, batch_sizes25)) 145 | ] 146 | 147 | # setup eval loaders 148 | eval_loaders_kwargs = [] 149 | for i, (env, _) in enumerate(in_splits + out_splits + test_splits): 150 | batchsize = hparams["test_batchsize"] 151 | loader_kwargs = {"dataset": env, "batch_size": batchsize, "num_workers": dataset.N_WORKERS} 152 | if args.prebuild_loader: 153 | loader_kwargs = FastDataLoader(**loader_kwargs) 154 | eval_loaders_kwargs.append(loader_kwargs) 155 | 156 | eval_weights = [None for _, weights in (in_splits + out_splits + test_splits)] 157 | eval_loader_names = ["env{}_in".format(i) for i in range(len(in_splits))] 158 | eval_loader_names += ["env{}_out".format(i) for i in range(len(out_splits))] 159 | eval_loader_names += ["env{}_inTE".format(i) for i in range(len(test_splits))] 160 | eval_meta = list(zip(eval_loader_names, eval_loaders_kwargs, eval_weights)) 161 | 162 | ####################################################### 163 | # setup algorithm (model) 164 | ####################################################### 165 | algorithmCE1 = algorithm_class( 166 | dataset.input_shape, 167 | dataset.num_classes, 168 | len(dataset) - len(test_envs), 169 | hparams, 170 | ) 171 | algorithmCE2 = algorithm_class( 172 | dataset.input_shape, 173 | dataset.num_classes, 174 | len(dataset) - len(test_envs), 175 | hparams, 176 | ) 177 | algorithmCE3 
= algorithm_class( 178 | dataset.input_shape, 179 | dataset.num_classes, 180 | len(dataset) - len(test_envs), 181 | hparams, 182 | ) 183 | algorithmCE4 = algorithm_class( 184 | dataset.input_shape, 185 | dataset.num_classes, 186 | len(dataset) - len(test_envs), 187 | hparams, 188 | ) 189 | algorithmCE5 = algorithm_class( 190 | dataset.input_shape, 191 | dataset.num_classes, 192 | len(dataset) - len(test_envs), 193 | hparams, 194 | ) 195 | algorithmCE6 = algorithm_class( 196 | dataset.input_shape, 197 | dataset.num_classes, 198 | len(dataset) - len(test_envs), 199 | hparams, 200 | ) 201 | 202 | algorithmCE1.cuda() 203 | algorithmCE2.cuda() 204 | algorithmCE3.cuda() 205 | algorithmCE4.cuda() 206 | algorithmCE5.cuda() 207 | algorithmCE6.cuda() 208 | 209 | n_params = sum([p.numel() for p in algorithmCE1.parameters()]) 210 | logger.info("# of params = %d" % n_params) 211 | 212 | train_minibatches_iterator = zip(*train_loaders) 213 | train_minibatches_iterator50 = zip(*train_loaders50) 214 | train_minibatches_iterator25 = zip(*train_loaders25) 215 | # train_minibatches_iterator25b = zip(*train_loaders25b) 216 | 217 | checkpoint_vals = collections.defaultdict(lambda: []) 218 | 219 | ####################################################### 220 | # start training loop 221 | ####################################################### 222 | evaluator = Evaluator( 223 | test_envs, 224 | eval_meta, 225 | n_envs, 226 | logger, 227 | evalmode=args.evalmode, 228 | debug=args.debug, 229 | target_env=target_env, 230 | ) 231 | 232 | # swad = None 233 | # if hparams["swad"]: 234 | # swad_algorithm = swa_utils.AveragedModel(algorithm) 235 | # swad_cls = getattr(swad_module, hparams["swad"]) 236 | # swad = swad_cls(evaluator, **hparams.swad_kwargs) 237 | 238 | swad1 = None 239 | if hparams["swad"]: 240 | swad_algorithm1 = swa_utils.AveragedModel(algorithmCE1) 241 | swad_cls1 = getattr(swad_module, "LossValley") 242 | swad1 = swad_cls1(evaluator, **hparams.swad_kwargs) 243 | swad2 = None 244 
| if hparams["swad"]: 245 | swad_algorithm2 = swa_utils.AveragedModel(algorithmCE2) 246 | swad_cls2 = getattr(swad_module, "LossValley") 247 | swad2 = swad_cls2(evaluator, **hparams.swad_kwargs) 248 | swad3 = None 249 | if hparams["swad"]: 250 | swad_algorithm3 = swa_utils.AveragedModel(algorithmCE3) 251 | swad_cls3 = getattr(swad_module, "LossValley") 252 | swad3 = swad_cls3(evaluator, **hparams.swad_kwargs) 253 | swad4 = None 254 | if hparams["swad"]: 255 | swad_algorithm4 = swa_utils.AveragedModel(algorithmCE4) 256 | swad_cls4 = getattr(swad_module, "LossValley") 257 | swad4 = swad_cls4(evaluator, **hparams.swad_kwargs) 258 | swad5 = None 259 | if hparams["swad"]: 260 | swad_algorithm5 = swa_utils.AveragedModel(algorithmCE5) 261 | swad_cls5 = getattr(swad_module, "LossValley") 262 | swad5 = swad_cls5(evaluator, **hparams.swad_kwargs) 263 | swad6 = None 264 | if hparams["swad"]: 265 | swad_algorithm6 = swa_utils.AveragedModel(algorithmCE6) 266 | swad_cls6 = getattr(swad_module,"LossValley") 267 | swad6 = swad_cls6(evaluator, **hparams.swad_kwargs) 268 | 269 | last_results_keys = None 270 | records = [] 271 | records_inter = [] 272 | epochs_path = args.out_dir / "results.jsonl" 273 | 274 | for step in range(n_steps): 275 | step_start_time = time.time() 276 | 277 | # batches_dictlist: [ {x: ,y: }, {x: ,y: }, {x: ,y: } ] 278 | batches_dictlist = next(train_minibatches_iterator) 279 | batches_dictlist50 = next(train_minibatches_iterator50) 280 | batches_dictlist25 = next(train_minibatches_iterator25) 281 | 282 | # batches: {x: [ ,, ] ,y: [ ,, ] } 283 | batchesCE = misc.merge_dictlist(batches_dictlist) 284 | batches1 = {'x': [ batches_dictlist50[0]['x'],batches_dictlist25[1]['x'],batches_dictlist25[2]['x'],batches_dictlist25[3]['x'],batches_dictlist25[4]['x'] ], 285 | 'y': [ batches_dictlist50[0]['y'],batches_dictlist25[1]['y'],batches_dictlist25[2]['y'],batches_dictlist25[3]['y'],batches_dictlist25[4]['y'] ] } 286 | batches2 = {'x': [ 
batches_dictlist25[0]['x'],batches_dictlist50[1]['x'],batches_dictlist25[2]['x'],batches_dictlist25[3]['x'],batches_dictlist25[4]['x'] ], 287 | 'y': [ batches_dictlist25[0]['y'],batches_dictlist50[1]['y'],batches_dictlist25[2]['y'],batches_dictlist25[3]['y'],batches_dictlist25[4]['y'] ] } 288 | batches3 = {'x': [ batches_dictlist25[0]['x'],batches_dictlist25[1]['x'],batches_dictlist50[2]['x'],batches_dictlist25[3]['x'],batches_dictlist25[4]['x'] ], 289 | 'y': [ batches_dictlist25[0]['y'],batches_dictlist25[1]['y'],batches_dictlist50[2]['y'],batches_dictlist25[3]['y'],batches_dictlist25[4]['y'] ] } 290 | batches4 = {'x': [ batches_dictlist25[0]['x'],batches_dictlist25[1]['x'],batches_dictlist25[2]['x'],batches_dictlist50[3]['x'],batches_dictlist25[4]['x'] ], 291 | 'y': [ batches_dictlist25[0]['y'],batches_dictlist25[1]['y'],batches_dictlist25[2]['y'],batches_dictlist50[3]['y'],batches_dictlist25[4]['y'] ] } 292 | batches5 = {'x': [ batches_dictlist25[0]['x'],batches_dictlist25[1]['x'],batches_dictlist25[2]['x'],batches_dictlist25[3]['x'],batches_dictlist50[4]['x'] ], 293 | 'y': [ batches_dictlist25[0]['y'],batches_dictlist25[1]['y'],batches_dictlist25[2]['y'],batches_dictlist25[3]['y'],batches_dictlist50[4]['y'] ] } 294 | 295 | # to device 296 | batchesCE = {key: [tensor.to(device) for tensor in tensorlist] for key, tensorlist in batchesCE.items()} 297 | batches1 = {key: [tensor.to(device) for tensor in tensorlist] for key, tensorlist in batches1.items()} 298 | batches2 = {key: [tensor.to(device) for tensor in tensorlist] for key, tensorlist in batches2.items()} 299 | batches3 = {key: [tensor.to(device) for tensor in tensorlist] for key, tensorlist in batches3.items()} 300 | batches4 = {key: [tensor.to(device) for tensor in tensorlist] for key, tensorlist in batches4.items()} 301 | batches5 = {key: [tensor.to(device) for tensor in tensorlist] for key, tensorlist in batches5.items()} 302 | 303 | inputsCE = {**batchesCE, "step":step} 304 | inputs1 = {**batches1, 
"step": step} 305 | inputs2 = {**batches2, "step": step} 306 | inputs3 = {**batches3, "step": step} 307 | inputs4 = {**batches4, "step": step} 308 | inputs5 = {**batches5, "step": step} 309 | 310 | step_valsCE1 = algorithmCE1.update(**inputs1) 311 | step_valsCE2 = algorithmCE2.update(**inputs2) 312 | step_valsCE3 = algorithmCE3.update(**inputs3) 313 | step_valsCE4 = algorithmCE4.update(**inputs4) 314 | step_valsCE5 = algorithmCE5.update(**inputs5) 315 | step_valsCE6 = algorithmCE6.update(**inputsCE) 316 | 317 | 318 | for key, val in step_valsCE1.items(): 319 | checkpoint_vals['1_'+key].append(val) 320 | for key, val in step_valsCE2.items(): 321 | checkpoint_vals['2_'+key].append(val) 322 | for key, val in step_valsCE3.items(): 323 | checkpoint_vals['3_'+key].append(val) 324 | for key, val in step_valsCE4.items(): 325 | checkpoint_vals['4_'+key].append(val) 326 | for key, val in step_valsCE5.items(): 327 | checkpoint_vals['5_'+key].append(val) 328 | for key, val in step_valsCE6.items(): 329 | checkpoint_vals['6_'+key].append(val) 330 | checkpoint_vals["step_time"].append(time.time() - step_start_time) 331 | 332 | if swad1: 333 | # swad_algorithm is segment_swa for swad 334 | swad_algorithm1.update_parameters(algorithmCE1, step=step) 335 | swad_algorithm2.update_parameters(algorithmCE2, step=step) 336 | swad_algorithm3.update_parameters(algorithmCE3, step=step) 337 | swad_algorithm4.update_parameters(algorithmCE4, step=step) 338 | swad_algorithm5.update_parameters(algorithmCE5, step=step) 339 | swad_algorithm6.update_parameters(algorithmCE6, step=step) 340 | 341 | if step % checkpoint_freq == 0: 342 | results = { 343 | "step": step, 344 | "epoch": step / steps_per_epoch, 345 | } 346 | 347 | for key, val in checkpoint_vals.items(): 348 | results[key] = np.mean(val) 349 | 350 | eval_start_time = time.time() 351 | summaries1 = evaluator.evaluate(algorithmCE1, suffix='_1') 352 | summaries2 = evaluator.evaluate(algorithmCE2, suffix='_2') 353 | summaries3 = 
evaluator.evaluate(algorithmCE3, suffix='_3') 354 | summaries4 = evaluator.evaluate(algorithmCE4, suffix='_4') 355 | summaries5 = evaluator.evaluate(algorithmCE5, suffix='_5') 356 | summaries6 = evaluator.evaluate(algorithmCE6, suffix='_6') 357 | results["eval_time"] = time.time() - eval_start_time 358 | 359 | # results = (epochs, loss, step, step_time) 360 | results_keys = list(summaries1.keys()) + list(summaries2.keys()) + list(summaries3.keys()) + list(summaries4.keys()) + list(summaries5.keys()) + list(summaries6.keys()) + list(results.keys()) 361 | # merge results 362 | results.update(summaries1) 363 | results.update(summaries2) 364 | results.update(summaries3) 365 | results.update(summaries4) 366 | results.update(summaries5) 367 | results.update(summaries6) 368 | 369 | # print 370 | if results_keys != last_results_keys: 371 | logger.info(misc.to_row(results_keys)) 372 | last_results_keys = results_keys 373 | logger.info(misc.to_row([results[key] for key in results_keys])) 374 | records.append(copy.deepcopy(results)) 375 | 376 | # update results to record 377 | results.update({"hparams": dict(hparams), "args": vars(args)}) 378 | 379 | with open(epochs_path, "a") as f: 380 | f.write(json.dumps(results, sort_keys=True, default=json_handler) + "\n") 381 | 382 | checkpoint_vals = collections.defaultdict(lambda: []) 383 | 384 | writer.add_scalars_with_prefix(summaries1, step, f"{testenv_name}/summary1/") 385 | writer.add_scalars_with_prefix(summaries2, step, f"{testenv_name}/summary2/") 386 | writer.add_scalars_with_prefix(summaries3, step, f"{testenv_name}/summary3/") 387 | writer.add_scalars_with_prefix(summaries4, step, f"{testenv_name}/summary4/") 388 | writer.add_scalars_with_prefix(summaries5, step, f"{testenv_name}/summary5/") 389 | writer.add_scalars_with_prefix(summaries6, step, f"{testenv_name}/summary6/") 390 | # writer.add_scalars_with_prefix(accuracies, step, f"{testenv_name}/all/") 391 | 392 | if args.model_save and step >= args.model_save: 393 | 
ckpt_dir = args.out_dir / "checkpoints" 394 | ckpt_dir.mkdir(exist_ok=True) 395 | 396 | test_env_str = ",".join(map(str, test_envs)) 397 | filename = "TE{}_{}.pth".format(test_env_str, step) 398 | if len(test_envs) > 1 and target_env is not None: 399 | train_env_str = ",".join(map(str, train_envs)) 400 | filename = f"TE{target_env}_TR{train_env_str}_{step}.pth" 401 | path = ckpt_dir / filename 402 | 403 | save_dict = { 404 | "args": vars(args), 405 | "model_hparams": dict(hparams), 406 | "test_envs": test_envs, 407 | "model_dict1": algorithmCE1.cpu().state_dict(), 408 | "model_dict2": algorithmCE2.cpu().state_dict(), 409 | "model_dict3": algorithmCE3.cpu().state_dict(), 410 | "model_dict4": algorithmCE4.cpu().state_dict(), 411 | "model_dict5": algorithmCE5.cpu().state_dict(), 412 | "model_dict6": algorithmCE6.cpu().state_dict(), 413 | } 414 | algorithmCE1.cuda() 415 | algorithmCE2.cuda() 416 | algorithmCE3.cuda() 417 | algorithmCE4.cuda() 418 | algorithmCE5.cuda() 419 | algorithmCE6.cuda() 420 | if not args.debug: 421 | torch.save(save_dict, path) 422 | else: 423 | logger.debug("DEBUG Mode -> no save (org path: %s)" % path) 424 | 425 | # swad 426 | if swad1: 427 | def prt_results_fn(results, avgmodel): 428 | step_str = f" [{avgmodel.start_step}-{avgmodel.end_step}]" 429 | row = misc.to_row([results[key] for key in results_keys if key in results]) 430 | logger.info(row + step_str) 431 | 432 | swad1.update_and_evaluate( 433 | swad_algorithm1, results["comb_val_1"], results["comb_val_loss_1"], prt_results_fn 434 | ) 435 | swad2.update_and_evaluate( 436 | swad_algorithm2, results["comb_val_2"], results["comb_val_loss_2"], prt_results_fn 437 | ) 438 | swad3.update_and_evaluate( 439 | swad_algorithm3, results["comb_val_3"], results["comb_val_loss_3"], prt_results_fn 440 | ) 441 | swad4.update_and_evaluate( 442 | swad_algorithm4, results["comb_val_4"], results["comb_val_loss_4"], prt_results_fn 443 | ) 444 | swad5.update_and_evaluate( 445 | swad_algorithm5, 
results["comb_val_5"], results["comb_val_loss_5"], prt_results_fn 446 | ) 447 | swad6.update_and_evaluate( 448 | swad_algorithm6, results["comb_val_6"], results["comb_val_loss_6"], prt_results_fn 449 | ) 450 | 451 | # if hasattr(swad, "dead_valley") and swad.dead_valley: 452 | # logger.info("SWAD valley is dead -> early stop !") 453 | # break 454 | if hasattr(swad1, "dead_valley") and swad1.dead_valley: 455 | logger.info("SWAD valley is dead for 1 -> early stop !") 456 | if hasattr(swad2, "dead_valley") and swad2.dead_valley: 457 | logger.info("SWAD valley is dead for 2 -> early stop !") 458 | if hasattr(swad3, "dead_valley") and swad3.dead_valley: 459 | logger.info("SWAD valley is dead for 3 -> early stop !") 460 | if hasattr(swad4, "dead_valley") and swad4.dead_valley: 461 | logger.info("SWAD valley is dead for 4 -> early stop !") 462 | if hasattr(swad5, "dead_valley") and swad5.dead_valley: 463 | logger.info("SWAD valley is dead for 5 -> early stop !") 464 | if hasattr(swad6, "dead_valley") and swad6.dead_valley: 465 | logger.info("SWAD valley is dead for 6 -> early stop !") 466 | 467 | 468 | if (hparams["model"]=='clip_vit-b16') and (step % 2000 == 0): 469 | swad_algorithm1 = swa_utils.AveragedModel(algorithmCE1) # reset 470 | swad_algorithm2 = swa_utils.AveragedModel(algorithmCE2) 471 | swad_algorithm3 = swa_utils.AveragedModel(algorithmCE3) 472 | swad_algorithm4 = swa_utils.AveragedModel(algorithmCE4) 473 | swad_algorithm5 = swa_utils.AveragedModel(algorithmCE5) 474 | swad_algorithm6 = swa_utils.AveragedModel(algorithmCE6) 475 | 476 | if step % args.tb_freq == 0: 477 | # add step values only for tb log 478 | writer.add_scalars_with_prefix(step_valsCE1, step, f"{testenv_name}/summary1/") 479 | writer.add_scalars_with_prefix(step_valsCE2, step, f"{testenv_name}/summary2/") 480 | writer.add_scalars_with_prefix(step_valsCE3, step, f"{testenv_name}/summary3/") 481 | writer.add_scalars_with_prefix(step_valsCE4, step, f"{testenv_name}/summary4/") 482 | 
writer.add_scalars_with_prefix(step_valsCE5, step, f"{testenv_name}/summary5/") 483 | writer.add_scalars_with_prefix(step_valsCE6, step, f"{testenv_name}/summary6/") 484 | 485 | if step%args.inter_freq==0 and step!=0: 486 | if args.algorithm in ['DANN', 'CDANN']: 487 | inter_state_dict = interpolate_algos(algorithmCE1.featurizer.state_dict(), algorithCE2.featurizer.state_dict(), algorithmCE3.featurizer.state_dict(), algorithmCE4.featurizer.state_dict(), algorithmCE5.featurizer.state_dict(), algorithCE6.featurizer.state_dict()) 488 | algorithmCE1.featurizer.load_state_dict(inter_state_dict) 489 | algorithmCE2.featurizer.load_state_dict(inter_state_dict) 490 | algorithmCE3.featurizer.load_state_dict(inter_state_dict) 491 | algorithmCE4.featurizer.load_state_dict(inter_state_dict) 492 | algorithmCE5.featurizer.load_state_dict(inter_state_dict) 493 | algorithmCE6.featurizer.load_state_dict(inter_state_dict) 494 | inter_state_dict2 = interpolate_algos(algorithmCE1.classifier.state_dict(), algorithCE2.classifier.state_dict(), algorithmCE3.classifier.state_dict(), algorithmCE4.classifier.state_dict(), algorithmCE5.classifier.state_dict(), algorithCE6.classifier.state_dict()) 495 | algorithmCE1.classifier.load_state_dict(inter_state_dict) 496 | algorithmCE2.classifier.load_state_dict(inter_state_dict) 497 | algorithmCE3.classifier.load_state_dict(inter_state_dict) 498 | algorithmCE4.classifier.load_state_dict(inter_state_dict) 499 | algorithmCE5.classifier.load_state_dict(inter_state_dict) 500 | algorithmCE6.classifier.load_state_dict(inter_state_dict) 501 | inter_state_dict3 = interpolate_algos(algorithmCE1.discriminator.state_dict(), algorithCE2.discriminator.state_dict(), algorithmCE3.discriminator.state_dict(), algorithmCE4.discriminator.state_dict(), algorithmCE5.discriminator.state_dict(), algorithCE6.discriminator.state_dict()) 502 | algorithmCE1.discriminator.load_state_dict(inter_state_dict) 503 | algorithmCE2.discriminator.load_state_dict(inter_state_dict) 504 | 
algorithmCE3.discriminator.load_state_dict(inter_state_dict) 505 | algorithmCE4.discriminator.load_state_dict(inter_state_dict) 506 | algorithmCE5.discriminator.load_state_dict(inter_state_dict) 507 | algorithmCE6.discriminator.load_state_dict(inter_state_dict) 508 | elif args.algorithm in ['SagNet']: 509 | inter_state_dict = interpolate_algos(algorithmCE1.network_f.state_dict(), algorithmCE2.network_f.state_dict(), algorithmCE3.network_f.state_dict(), algorithmCE4.network_f.state_dict(), algorithmCE5.network_f.state_dict(), algorithmCE6.network_f.state_dict()) 510 | algorithmCE1.network_f.load_state_dict(inter_state_dict) 511 | algorithmCE2.network_f.load_state_dict(inter_state_dict) 512 | algorithmCE3.network_f.load_state_dict(inter_state_dict) 513 | algorithmCE4.network_f.load_state_dict(inter_state_dict) 514 | algorithmCE5.network_f.load_state_dict(inter_state_dict) 515 | algorithmCE6.network_f.load_state_dict(inter_state_dict) 516 | inter_state_dict2 = interpolate_algos(algorithmCE1.network_c.state_dict(), algorithmCE2.network_c.state_dict(), algorithmCE3.network_c.state_dict(), algorithmCE4.network_c.state_dict(), algorithmCE5.network_c.state_dict(), algorithmCE6.network_c.state_dict()) 517 | algorithmCE1.network_c.load_state_dict(inter_state_dict) 518 | algorithmCE2.network_c.load_state_dict(inter_state_dict) 519 | algorithmCE3.network_c.load_state_dict(inter_state_dict) 520 | algorithmCE4.network_c.load_state_dict(inter_state_dict) 521 | algorithmCE5.network_c.load_state_dict(inter_state_dict) 522 | algorithmCE6.network_c.load_state_dict(inter_state_dict) 523 | inter_state_dict3 = interpolate_algos(algorithmCE1.network_s.state_dict(), algorithmCE2.network_s.state_dict(), algorithmCE3.network_s.state_dict(), algorithmCE4.network_s.state_dict(), algorithmCE5.network_s.state_dict(), algorithmCE6.network_s.state_dict()) 524 | algorithmCE1.network_s.load_state_dict(inter_state_dict) 525 | algorithmCE2.network_s.load_state_dict(inter_state_dict) 526 | 
algorithmCE3.network_s.load_state_dict(inter_state_dict) 527 | algorithmCE4.network_s.load_state_dict(inter_state_dict) 528 | algorithmCE5.network_s.load_state_dict(inter_state_dict) 529 | algorithmCE6.network_s.load_state_dict(inter_state_dict) 530 | else: 531 | inter_state_dict = interpolate_algos(algorithmCE1.network.state_dict(), algorithmCE2.network.state_dict(), algorithmCE3.network.state_dict(), algorithmCE4.network.state_dict(), algorithmCE5.network.state_dict(), algorithmCE6.network.state_dict()) 532 | algorithmCE1.network.load_state_dict(inter_state_dict) 533 | algorithmCE2.network.load_state_dict(inter_state_dict) 534 | algorithmCE3.network.load_state_dict(inter_state_dict) 535 | algorithmCE4.network.load_state_dict(inter_state_dict) 536 | algorithmCE5.network.load_state_dict(inter_state_dict) 537 | algorithmCE6.network.load_state_dict(inter_state_dict) 538 | 539 | logger.info(f"Evaluating interpolated model at {step} step") 540 | summaries_inter = evaluator.evaluate(algorithmCE1, suffix='_from_inter') 541 | inter_results = {"inter_step": step, "inter_epoch": step / steps_per_epoch} 542 | inter_results_keys = list(summaries_inter.keys()) + list(inter_results.keys()) 543 | inter_results.update(summaries_inter) 544 | logger.info(misc.to_row([inter_results[key] for key in inter_results_keys])) 545 | records_inter.append(copy.deepcopy(inter_results)) 546 | writer.add_scalars_with_prefix(summaries_inter, step, f"{testenv_name}/summary_inter/") 547 | 548 | # find best 549 | logger.info("---") 550 | # print(records) 551 | records = Q(records) 552 | records_inter = Q(records_inter) 553 | 554 | # print(len(records)) 555 | # print(records) 556 | 557 | # 1 558 | oracle_best1 = records.argmax("test_out_1")["test_in_1"] 559 | iid_best1 = records.argmax("comb_val_1")["test_in_1"] 560 | inDom1 = records.argmax("comb_val_1")["comb_val_1"] 561 | # own_best1 = records.argmax("own_val_from_first")["test_in_from_first"] 562 | last1 = records[-1]["test_in_1"] 563 | # 2 564 | 
oracle_best2 = records.argmax("test_out_2")["test_in_2"] 565 | iid_best2 = records.argmax("comb_val_2")["test_in_2"] 566 | inDom2 = records.argmax("comb_val_2")["comb_val_2"] 567 | # own_best2 = records.argmax("own_val_from_second")["test_in_from_second"] 568 | last2 = records[-1]["test_in_2"] 569 | # 3 570 | oracle_best3 = records.argmax("test_out_3")["test_in_3"] 571 | iid_best3 = records.argmax("comb_val_3")["test_in_3"] 572 | inDom3 = records.argmax("comb_val_3")["comb_val_3"] 573 | # own_best3 = records.argmax("own_val_from_third")["test_in_from_third"] 574 | last3 = records[-1]["test_in_3"] 575 | # CE 576 | oracle_best4 = records.argmax("test_out_4")["test_in_4"] 577 | iid_best4 = records.argmax("comb_val_4")["test_in_4"] 578 | inDom4 = records.argmax("comb_val_4")["comb_val_4"] 579 | last4 = records[-1]["test_in_4"] 580 | 581 | oracle_best5 = records.argmax("test_out_5")["test_in_5"] 582 | iid_best5 = records.argmax("comb_val_5")["test_in_5"] 583 | inDom5 = records.argmax("comb_val_5")["comb_val_5"] 584 | last5 = records[-1]["test_in_5"] 585 | 586 | oracle_best6 = records.argmax("test_out_6")["test_in_6"] 587 | iid_best6 = records.argmax("comb_val_6")["test_in_6"] 588 | inDom6 = records.argmax("comb_val_6")["comb_val_6"] 589 | last6 = records[-1]["test_in_6"] 590 | # inter 591 | oracle_best_inter = records_inter.argmax("test_out_from_inter")["test_in_from_inter"] 592 | iid_best_inter = records_inter.argmax("comb_val_from_inter")["test_in_from_inter"] 593 | inDom_inter = records_inter.argmax("comb_val_from_inter")["comb_val_from_inter"] 594 | 595 | # if hparams.indomain_test: 596 | # # if test set exist, use test set for indomain results 597 | # in_key = "train_inTE" 598 | # else: 599 | # in_key = "train_out" 600 | 601 | # iid_best_indomain = records.argmax("train_out")[in_key] 602 | # last_indomain = records[-1][in_key] 603 | 604 | ret = { 605 | "oracle_1": oracle_best1, 606 | "iid_1": iid_best1, 607 | # "own_1": own_best1, 608 | "inDom1": inDom1, 609 | 
"last_1": last1, 610 | "oracle_2": oracle_best2, 611 | "iid_2": iid_best2, 612 | # "own_2": own_best2, 613 | "inDom2":inDom2, 614 | "last_2": last2, 615 | "oracle_3": oracle_best3, 616 | "iid_3": iid_best3, 617 | # "own_3": own_best3, 618 | "inDom3":inDom3, 619 | "last_3": last3, 620 | "oracle_4": oracle_best4, 621 | "iid_4": iid_best4, 622 | "inDom4": inDom4, 623 | "last_4": last4, 624 | 625 | "oracle_5": oracle_best5, 626 | "iid_5": iid_best5, 627 | "inDom5": inDom5, 628 | "last_5": last5, 629 | 630 | "oracle_6": oracle_best6, 631 | "iid_6": iid_best6, 632 | "inDom6": inDom6, 633 | "last_6": last6, 634 | # "last (inD)": last_indomain, 635 | # "iid (inD)": iid_best_indomain, 636 | "oracle_inter": oracle_best_inter, 637 | "iid_inter": iid_best_inter, 638 | "inDom_inter":inDom_inter, 639 | } 640 | 641 | # Evaluate SWAD 642 | if swad1: 643 | swad_algorithm1 = swad1.get_final_model() 644 | swad_algorithm2 = swad2.get_final_model() 645 | swad_algorithm3 = swad3.get_final_model() 646 | swad_algorithm4 = swad4.get_final_model() 647 | swad_algorithm5 = swad5.get_final_model() 648 | swad_algorithm6 = swad6.get_final_model() 649 | if hparams["freeze_bn"] is False: 650 | n_steps = 500 if not args.debug else 10 651 | logger.warning(f"Update SWAD BN statistics for {n_steps} steps ...") 652 | swa_utils.update_bn(train_minibatches_iterator, swad_algorithm1, n_steps) 653 | swa_utils.update_bn(train_minibatches_iterator, swad_algorithm2, n_steps) 654 | swa_utils.update_bn(train_minibatches_iterator, swad_algorithm3, n_steps) 655 | swa_utils.update_bn(train_minibatches_iterator, swad_algorithm4, n_steps) 656 | swa_utils.update_bn(train_minibatches_iterator, swad_algorithm5, n_steps) 657 | swa_utils.update_bn(train_minibatches_iterator, swad_algorithm6, n_steps) 658 | 659 | logger.warning("Evaluate SWAD ...") 660 | summaries_swad1 = evaluator.evaluate(swad_algorithm1, suffix='_s1') 661 | summaries_swad2 = evaluator.evaluate(swad_algorithm2, suffix='_s2') 662 | summaries_swad3 = 
evaluator.evaluate(swad_algorithm3, suffix='_s3') 663 | summaries_swad4 = evaluator.evaluate(swad_algorithm4, suffix='_s4') 664 | summaries_swad5 = evaluator.evaluate(swad_algorithm5, suffix='_s5') 665 | summaries_swad6 = evaluator.evaluate(swad_algorithm6, suffix='_s6') 666 | # accuracies, summaries = evaluator.evaluate(swad_algorithm) 667 | 668 | # results = {**summaries, **accuracies} 669 | # start = swad_algorithm.start_step 670 | # end = swad_algorithm.end_step 671 | # step_str = f" [{start}-{end}] (N={swad_algorithm.n_averaged})" 672 | # row = misc.to_row([results[key] for key in results_keys if key in results]) + step_str 673 | # logger.info(row) 674 | 675 | swad_results = {**summaries_swad1, **summaries_swad2, **summaries_swad3, **summaries_swad4, **summaries_swad5, **summaries_swad6} 676 | step_str = f" [{swad_algorithm1.start_step}-{swad_algorithm1.end_step}] (N={swad_algorithm1.n_averaged}) || [{swad_algorithm2.start_step}-{swad_algorithm2.end_step}] (N={swad_algorithm2.n_averaged}) || [{swad_algorithm3.start_step}-{swad_algorithm3.end_step}] (N={swad_algorithm3.n_averaged}) || [{swad_algorithm4.start_step}-{swad_algorithm4.end_step}] (N={swad_algorithm4.n_averaged}) || [{swad_algorithm5.start_step}-{swad_algorithm5.end_step}] (N={swad_algorithm5.n_averaged}) || [{swad_algorithm6.start_step}-{swad_algorithm6.end_step}] (N={swad_algorithm6.n_averaged})" 677 | row = misc.to_row([swad_results[key] for key in list(swad_results.keys())]) + step_str 678 | logger.info(row) 679 | 680 | ret["SWAD 1"] = swad_results["test_in_s1"] 681 | ret["SWAD 1 (inDom)"] = swad_results["comb_val_s1"] 682 | ret["SWAD 2"] = swad_results["test_in_s2"] 683 | ret["SWAD 2 (inDom)"] = swad_results["comb_val_s2"] 684 | ret["SWAD 3"] = swad_results["test_in_s3"] 685 | ret["SWAD 3 (inDom)"] = swad_results["comb_val_s3"] 686 | ret["SWAD 4"] = swad_results["test_in_s4"] 687 | ret["SWAD 4 (inDom)"] = swad_results["comb_val_s4"] 688 | ret["SWAD 5"] = swad_results["test_in_s5"] 689 | 
ret["SWAD 5 (inDom)"] = swad_results["comb_val_s5"] 690 | ret["SWAD 6"] = swad_results["test_in_s6"] 691 | ret["SWAD 6 (inDom)"] = swad_results["comb_val_s6"] 692 | 693 | save_dict = { 694 | "args": vars(args), 695 | "model_hparams": dict(hparams), 696 | "test_envs": test_envs, 697 | "SWAD_1": swad_algorithm1.network.state_dict(), 698 | "SWAD_2": swad_algorithm2.network.state_dict(), 699 | "SWAD_3": swad_algorithm3.network.state_dict(), 700 | "SWAD_4": swad_algorithm4.network.state_dict(), 701 | "SWAD_5": swad_algorithm5.network.state_dict(), 702 | "SWAD_6": swad_algorithm6.network.state_dict(), 703 | } 704 | 705 | if args.algorithm in ['DANN', 'CDANN']: 706 | inter_state_dict = interpolate_algos(swad_algorithm1.module.featurizer.state_dict(), swad_algorithm2.module.featurizer.state_dict(), swad_algorithm3.module.featurizer.state_dict(), swad_algorithm4.module.featurizer.state_dict(), swad_algorithm5.module.featurizer.state_dict(), swad_algorithm6.module.featurizer.state_dict()) 707 | swad_algorithm1.module.featurizer.load_state_dict(inter_state_dict) 708 | inter_state_dict2 = interpolate_algos(swad_algorithm1.module.classifier.state_dict(), swad_algorithm2.module.classifier.state_dict(), swad_algorithm3.module.classifier.state_dict(), swad_algorithm4.module.classifier.state_dict(), swad_algorithm5.module.classifier.state_dict(), swad_algorithm6.module.classifier.state_dict()) 709 | swad_algorithm1.module.classifier.load_state_dict(inter_state_dict2) 710 | inter_state_dict3 = interpolate_algos(swad_algorithm1.module.discriminator.state_dict(), swad_algorithm2.module.discriminator.state_dict(), swad_algorithm3.module.discriminator.state_dict(), swad_algorithm4.module.discriminator.state_dict(), swad_algorithm5.module.discriminator.state_dict(), swad_algorithm6.module.discriminator.state_dict()) 711 | swad_algorithm1.module.discriminator.load_state_dict(inter_state_dict3) 712 | 713 | elif args.algorithm in ['SagNet']: 714 | inter_state_dict = 
interpolate_algos(swad_algorithm1.module.network_f.state_dict(), swad_algorithm2.module.network_f.state_dict(), swad_algorithm3.module.network_f.state_dict(), swad_algorithm4.module.network_f.state_dict(), swad_algorithm5.module.network_f.state_dict(), swad_algorithm6.module.network_f.state_dict()) 715 | swad_algorithm1.module.network_f.load_state_dict(inter_state_dict) 716 | inter_state_dict2 = interpolate_algos(swad_algorithm1.module.network_c.state_dict(), swad_algorithm2.module.network_c.state_dict(), swad_algorithm3.module.network_c.state_dict(), swad_algorithm4.module.network_c.state_dict(), swad_algorithm5.module.network_c.state_dict(), swad_algorithm6.module.network_c.state_dict()) 717 | swad_algorithm1.module.network_c.load_state_dict(inter_state_dict2) 718 | inter_state_dict3 = interpolate_algos(swad_algorithm1.module.network_s.state_dict(), swad_algorithm2.module.network_s.state_dict(), swad_algorithm3.module.network_s.state_dict(), swad_algorithm4.module.network_s.state_dict(), swad_algorithm5.module.network_s.state_dict(), swad_algorithm6.module.network_s.state_dict()) 719 | swad_algorithm1.module.network_s.load_state_dict(inter_state_dict3) 720 | else: 721 | inter_state_dict = interpolate_algos(swad_algorithm1.network.state_dict(), swad_algorithm2.network.state_dict(), swad_algorithm3.network.state_dict(), swad_algorithm4.network.state_dict(), swad_algorithm5.network.state_dict(), swad_algorithm6.network.state_dict()) 722 | swad_algorithm1.network.load_state_dict(inter_state_dict) 723 | 724 | logger.info(f"Evaluating interpolated model of SWAD models") 725 | summaries_swadinter = evaluator.evaluate(swad_algorithm1, suffix='_from_swadinter') 726 | swadinter_results = {**summaries_swadinter} 727 | logger.info(misc.to_row([swadinter_results[key] for key in list(swadinter_results.keys())])) 728 | ret["SWAD INTER"] = swadinter_results["test_in_from_swadinter"] 729 | ret["SWAD INTER (inDom)"] = swadinter_results["comb_val_from_swadinter"] 730 | 
save_dict["SWAD_INTER"] = inter_state_dict 731 | 732 | 733 | ckpt_dir = args.out_dir / "checkpoints" 734 | ckpt_dir.mkdir(exist_ok=True) 735 | test_env_str = ",".join(map(str, test_envs)) 736 | filename = f"TE{test_env_str}.pth" 737 | if len(test_envs) > 1 and target_env is not None: 738 | train_env_str = ",".join(map(str, train_envs)) 739 | filename = f"TE{target_env}_TR{train_env_str}.pth" 740 | path = ckpt_dir / filename 741 | if not args.debug: 742 | torch.save(save_dict, path) 743 | 744 | 745 | for k, acc in ret.items(): 746 | logger.info(f"{k} = {acc:.3%}") 747 | 748 | return ret, records 749 | 750 | -------------------------------------------------------------------------------- /media/DART_pic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/DART/62274e83d2f08eb416db61d0957476c53fde9361/media/DART_pic.png -------------------------------------------------------------------------------- /media/DG_combined_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/DART/62274e83d2f08eb416db61d0957476c53fde9361/media/DG_combined_results.png -------------------------------------------------------------------------------- /media/DG_main_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/DART/62274e83d2f08eb416db61d0957476c53fde9361/media/DG_main_results.png -------------------------------------------------------------------------------- /media/ID_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/DART/62274e83d2f08eb416db61d0957476c53fde9361/media/ID_results.png -------------------------------------------------------------------------------- /media/model_optimization_trajectory.gif: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/DART/62274e83d2f08eb416db61d0957476c53fde9361/media/model_optimization_trajectory.gif -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | gdown==4.2.0 2 | numpy==1.21.4 3 | Pillow==9.0.1 4 | prettytable==2.1.0 5 | sconf==0.2.3 6 | tensorboardX==2.5 7 | torch==1.7.1 8 | torchvision==0.8.2 9 | git+https://github.com/openai/CLIP.git 10 | -------------------------------------------------------------------------------- /train_all.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import collections 3 | import random 4 | import sys 5 | from pathlib import Path 6 | 7 | import numpy as np 8 | import PIL 9 | import torch 10 | import torchvision 11 | from sconf import Config 12 | from prettytable import PrettyTable 13 | 14 | from domainbed.datasets import get_dataset 15 | from domainbed import hparams_registry 16 | from domainbed.lib import misc 17 | from domainbed.lib.writers import get_writer 18 | from domainbed.lib.logger import Logger 19 | from domainbed.trainer import train 20 | from domainbed.trainer_DN import train as train_dn 21 | 22 | 23 | def main(): 24 | parser = argparse.ArgumentParser(description="Domain generalization", allow_abbrev=False) 25 | parser.add_argument("name", type=str) 26 | parser.add_argument("configs", nargs="*") 27 | parser.add_argument("--data_dir", type=str, default="datadir/") 28 | parser.add_argument("--dataset", type=str, default="PACS") 29 | parser.add_argument("--algorithm", type=str, default="ERM") 30 | parser.add_argument( 31 | "--trial_seed", 32 | type=int, 33 | default=0, 34 | help="Trial number (used for seeding split_dataset and random_hparams).", 35 | ) 36 | parser.add_argument("--seed", type=int, default=0, help="Seed for 
everything else") 37 | parser.add_argument( 38 | "--steps", type=int, default=None, help="Number of steps. Default is dataset-dependent." 39 | ) 40 | parser.add_argument( 41 | "--checkpoint_freq", 42 | type=int, 43 | default=None, 44 | help="Checkpoint every N steps. Default is dataset-dependent.", 45 | ) 46 | parser.add_argument("--test_envs", type=int, nargs="+", default=None) 47 | parser.add_argument("--holdout_fraction", type=float, default=0.2) 48 | parser.add_argument("--model_save", default=None, type=int, help="Model save start step") 49 | # parser.add_argument("--deterministic", action="store_true") 50 | parser.add_argument("--tb_freq", default=10) 51 | parser.add_argument("--debug", action="store_true", help="Run w/ debug mode") 52 | parser.add_argument("--show", action="store_true", help="Show args and hparams w/o run") 53 | parser.add_argument( 54 | "--evalmode", 55 | default="fast", 56 | help="[fast, all]. if fast, ignore train_in datasets in evaluation time.", 57 | ) 58 | parser.add_argument("--prebuild_loader", action="store_true", help="Pre-build eval loaders") 59 | parser.add_argument("--inter_freq", type=int, default=600, help="interpolate after inter_freq steps") 60 | args, left_argv = parser.parse_known_args() 61 | args.deterministic = True 62 | 63 | # setup hparams 64 | hparams = hparams_registry.default_hparams(args.algorithm, args.dataset) 65 | 66 | keys = ["config.yaml"] + args.configs 67 | keys = [open(key, encoding="utf8") for key in keys] 68 | hparams = Config(*keys, default=hparams) 69 | hparams.argv_update(left_argv) 70 | 71 | # setup debug 72 | if args.debug: 73 | args.checkpoint_freq = 5 74 | args.steps = 10 75 | args.name += "_debug" 76 | 77 | timestamp = misc.timestamp() 78 | args.unique_name = f"{timestamp}_{args.name}" 79 | 80 | # path setup 81 | args.work_dir = Path(".") 82 | args.data_dir = Path(args.data_dir) 83 | 84 | args.out_root = args.work_dir / Path("train_output") / args.dataset 85 | args.out_dir = args.out_root / 
args.unique_name 86 | args.out_dir.mkdir(exist_ok=True, parents=True) 87 | 88 | writer = get_writer(args.out_root / "runs" / args.unique_name) 89 | logger = Logger.get(args.out_dir / "log.txt") 90 | if args.debug: 91 | logger.setLevel("DEBUG") 92 | cmd = " ".join(sys.argv) 93 | logger.info(f"Command :: {cmd}") 94 | 95 | logger.nofmt("Environment:") 96 | logger.nofmt("\tPython: {}".format(sys.version.split(" ")[0])) 97 | logger.nofmt("\tPyTorch: {}".format(torch.__version__)) 98 | logger.nofmt("\tTorchvision: {}".format(torchvision.__version__)) 99 | logger.nofmt("\tCUDA: {}".format(torch.version.cuda)) 100 | logger.nofmt("\tCUDNN: {}".format(torch.backends.cudnn.version())) 101 | logger.nofmt("\tNumPy: {}".format(np.__version__)) 102 | logger.nofmt("\tPIL: {}".format(PIL.__version__)) 103 | 104 | # Different to DomainBed, we support CUDA only. 105 | assert torch.cuda.is_available(), "CUDA is not available" 106 | 107 | logger.nofmt("Args:") 108 | for k, v in sorted(vars(args).items()): 109 | logger.nofmt("\t{}: {}".format(k, v)) 110 | 111 | logger.nofmt("HParams:") 112 | for line in hparams.dumps().split("\n"): 113 | logger.nofmt("\t" + line) 114 | 115 | if args.show: 116 | exit() 117 | 118 | # seed 119 | random.seed(args.seed) 120 | np.random.seed(args.seed) 121 | torch.manual_seed(args.seed) 122 | torch.backends.cudnn.deterministic = args.deterministic 123 | torch.backends.cudnn.benchmark = not args.deterministic 124 | 125 | # Dummy datasets for logging information. 126 | # Real dataset will be re-assigned in train function. 127 | # test_envs only decide transforms; simply set to zero. 
128 | dataset, _in_splits, _out_splits = get_dataset([0], args, hparams) 129 | 130 | # print dataset information 131 | logger.nofmt("Dataset:") 132 | logger.nofmt(f"\t[{args.dataset}] #envs={len(dataset)}, #classes={dataset.num_classes}") 133 | for i, env_property in enumerate(dataset.environments): 134 | logger.nofmt(f"\tenv{i}: {env_property} (#{len(dataset[i])})") 135 | logger.nofmt("") 136 | 137 | n_steps = args.steps or dataset.N_STEPS 138 | checkpoint_freq = args.checkpoint_freq or dataset.CHECKPOINT_FREQ 139 | logger.info(f"n_steps = {n_steps}") 140 | logger.info(f"checkpoint_freq = {checkpoint_freq}") 141 | 142 | org_n_steps = n_steps 143 | n_steps = (n_steps // checkpoint_freq) * checkpoint_freq + 1 144 | logger.info(f"n_steps is updated to {org_n_steps} => {n_steps} for checkpointing") 145 | 146 | if not args.test_envs: 147 | args.test_envs = [[te] for te in range(len(dataset))] 148 | logger.info(f"Target test envs = {args.test_envs}") 149 | 150 | ########################################################################### 151 | # Run 152 | ########################################################################### 153 | all_records = [] 154 | results = collections.defaultdict(list) 155 | 156 | for test_env in args.test_envs: 157 | if args.dataset=="DomainNet": 158 | print("===== DN ======") 159 | res, records = train_dn( 160 | test_env, 161 | args=args, 162 | hparams=hparams, 163 | n_steps=n_steps, 164 | checkpoint_freq=checkpoint_freq, 165 | logger=logger, 166 | writer=writer, 167 | ) 168 | else: 169 | print("===== others ======") 170 | res, records = train( 171 | test_env, 172 | args=args, 173 | hparams=hparams, 174 | n_steps=n_steps, 175 | checkpoint_freq=checkpoint_freq, 176 | logger=logger, 177 | writer=writer, 178 | ) 179 | all_records.append(records) 180 | for k, v in res.items(): 181 | results[k].append(v) 182 | 183 | # log summary table 184 | logger.info("=== Summary ===") 185 | logger.info(f"Command: {' '.join(sys.argv)}") 186 | 
logger.info("Unique name: %s" % args.unique_name) 187 | logger.info("Out path: %s" % args.out_dir) 188 | logger.info("Algorithm: %s" % args.algorithm) 189 | logger.info("Dataset: %s" % args.dataset) 190 | 191 | table = PrettyTable(["Selection"] + dataset.environments + ["Avg."]) 192 | for key, row in results.items(): 193 | row.append(np.mean(row)) 194 | row = [f"{acc:.3%}" for acc in row] 195 | table.add_row([key] + row) 196 | logger.nofmt(table) 197 | 198 | 199 | if __name__ == "__main__": 200 | main() 201 | --------------------------------------------------------------------------------