├── LICENSE
├── README.md
├── config.yaml
├── domainbed
├── algorithms
│ ├── __init__.py
│ ├── algorithms.py
│ └── miro.py
├── datasets
│ ├── __init__.py
│ ├── datasets.py
│ └── transforms.py
├── evaluator.py
├── hparams_registry.py
├── lib
│ ├── fast_data_loader.py
│ ├── logger.py
│ ├── misc.py
│ ├── query.py
│ ├── swa_utils.py
│ ├── wide_resnet.py
│ └── writers.py
├── misc
│ └── domain_net_duplicates.txt
├── models
│ ├── mixstyle.py
│ ├── resnet_mixstyle.py
│ └── resnet_mixstyle2.py
├── networks
│ ├── __init__.py
│ ├── backbones.py
│ ├── networks.py
│ └── ur_networks.py
├── optimizers.py
├── scripts
│ └── download.py
├── swad.py
├── trainer.py
└── trainer_DN.py
├── media
├── DART_pic.png
├── DG_combined_results.png
├── DG_main_results.png
├── ID_results.png
└── model_optimization_trajectory.gif
├── requirements.txt
└── train_all.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Video Analytics Lab -- IISc
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DART: Diversify-Aggregate-Repeat Training
2 | This repository contains the code for training and evaluating our CVPR 2023 paper DART: Diversify-Aggregate-Repeat Training Improves Generalization of Neural Networks ([main paper](https://openaccess.thecvf.com/content/CVPR2023/papers/Jain_DART_Diversify-Aggregate-Repeat_Training_Improves_Generalization_of_Neural_Networks_CVPR_2023_paper.pdf) and [supplementary](https://openaccess.thecvf.com/content/CVPR2023/supplemental/Jain_DART_Diversify-Aggregate-Repeat_Training_CVPR_2023_supplemental.pdf)). The arXiv version of the paper is also [available](https://arxiv.org/pdf/2302.14685.pdf).
3 |
4 |
5 |
6 |
7 |
8 |
9 | # Environment Settings
10 | * Python 3.6.9
11 | * PyTorch 1.8
12 | * Torchvision 0.8.0
13 | * Numpy 1.19.2
14 |
15 |
16 |
17 | # Training
18 | For training DART on Domain Generalization task:
19 | ```
20 | python train_all.py [name_of_exp] --data_dir ./path/to/data --algorithm ERM --dataset PACS --inter_freq 1000 --steps 10001
21 | ```
22 | ### Combine with SWAD
23 | set `swad: True` in config.yaml file or pass `--swad True` in the python command.
24 | ### Changing Model & Hyperparams
25 | Similarly, to change the model (e.g., ViT), the SWAD hyperparameters, or the MIRO hyperparameters, you can update the ```config.yaml``` file or pass them as arguments in the python command.
26 | ```
27 | python train_all.py [name_of_exp] --data_dir ./path/to/data \
28 | --lr 3e-5 \
29 | --inter_freq 600 \
30 | --steps 8001 \
31 | --dataset OfficeHome \
32 | --algorithm MIRO \
33 | --ld 0.1 \
34 | --weight_decay 1e-6 \
35 | --swad True \
36 | --model clip_vit-b16
37 | ```
38 |
39 | # Results
40 | ## In-Domain Generalization of DART:
41 |
42 |
43 |
44 |
45 | ## Domain Generalization of DART:
46 |
47 |
48 |
49 |
50 | ## Combining DART with other DG methods on Office-Home:
51 |
52 |
53 |
54 |
55 |
56 |
57 | # Citing this work
58 | ```
59 | @inproceedings{jain2023dart,
60 | title={DART: Diversify-Aggregate-Repeat Training Improves Generalization of Neural Networks},
61 | author={Jain, Samyak and Addepalli, Sravanti and Sahu, Pawan Kumar and Dey, Priyam and Babu, R Venkatesh},
62 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
63 | pages={16048--16059},
64 | year={2023}
65 | }
66 | ```
67 |
--------------------------------------------------------------------------------
/config.yaml:
--------------------------------------------------------------------------------
1 | # Default update config
2 | # Config order: hparams_registry -> config.yaml -> CLI
3 | swad: True # True / False
4 | swad_kwargs:
5 | n_converge: 3
6 | n_tolerance: 6
7 | tolerance_ratio: 0.3
8 | test_batchsize: 128
9 |
10 | # resnet50, resnet50_barlowtwins, resnet50_moco, clip_resnet50, clip_vit-b16, swag_regnety_16gf
11 | model: resnet50
12 | feat_layers: stem_block
13 |
14 | # MIRO params
15 | ld: 0.1 # lambda
16 | lr_mult: 10.
17 |
--------------------------------------------------------------------------------
/domainbed/algorithms/__init__.py:
--------------------------------------------------------------------------------
1 | from .algorithms import *
2 | from .miro import MIRO
3 |
4 |
def get_algorithm_class(algorithm_name):
    """Return the algorithm class registered under ``algorithm_name``.

    Looks the name up in this module's namespace (populated by the star
    import above), so any class exported by ``algorithms``/``miro`` works.
    """
    cls = globals().get(algorithm_name)
    if cls is None:
        raise NotImplementedError("Algorithm not found: {}".format(algorithm_name))
    return cls
10 |
--------------------------------------------------------------------------------
/domainbed/algorithms/algorithms.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 |
3 | import copy
4 | from typing import List
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | import torch.autograd as autograd
10 | import numpy as np
11 |
12 | # import higher
13 |
14 | from domainbed import networks
15 | from domainbed.lib.misc import random_pairs_of_minibatches
16 | from domainbed.optimizers import get_optimizer
17 |
18 | from domainbed.models.resnet_mixstyle import (
19 | resnet18_mixstyle_L234_p0d5_a0d1,
20 | resnet50_mixstyle_L234_p0d5_a0d1,
21 | )
22 | from domainbed.models.resnet_mixstyle2 import (
23 | resnet18_mixstyle2_L234_p0d5_a0d1,
24 | resnet50_mixstyle2_L234_p0d5_a0d1,
25 | )
26 |
27 |
def to_minibatch(x, y):
    """Pair per-domain inputs with per-domain labels into a minibatch list."""
    return [(xi, yi) for xi, yi in zip(x, y)]
31 |
32 |
class Algorithm(torch.nn.Module):
    """Base class for domain-generalization algorithms.

    Subclasses must implement:
    - update(): one optimization step over per-environment (x, y) lists.
    - predict(): forward pass producing model outputs.
    """

    # Extra data views an algorithm may request (consumed by the dataset code).
    transforms = {}

    def __init__(self, input_shape, num_classes, num_domains, hparams):
        super().__init__()
        self.input_shape = input_shape
        self.num_classes = num_classes
        self.num_domains = num_domains
        self.hparams = hparams

    def update(self, x, y, **kwargs):
        """Perform one update step, given per-environment input/label lists."""
        raise NotImplementedError

    def predict(self, x):
        """Return model outputs for a batch of inputs."""
        raise NotImplementedError

    def forward(self, x):
        # nn.Module entry point simply delegates to predict().
        return self.predict(x)

    def new_optimizer(self, parameters):
        """Build a fresh optimizer from the stored hyperparameters."""
        return get_optimizer(
            self.hparams["optimizer"],
            parameters,
            lr=self.hparams["lr"],
            weight_decay=self.hparams["weight_decay"],
        )

    def clone(self):
        """Deep-copy the algorithm, rebuilding the optimizer with copied state."""
        copied = copy.deepcopy(self)
        copied.optimizer = self.new_optimizer(copied.network.parameters())
        copied.optimizer.load_state_dict(self.optimizer.state_dict())
        return copied
78 |
79 |
class ERM(Algorithm):
    """
    Empirical Risk Minimization (ERM)
    """

    def __init__(self, input_shape, num_classes, num_domains, hparams):
        super().__init__(input_shape, num_classes, num_domains, hparams)
        # Backbone featurizer + linear head; together they form self.network.
        self.featurizer = networks.Featurizer(input_shape, self.hparams)
        self.classifier = nn.Linear(self.featurizer.n_outputs, num_classes)
        self.network = nn.Sequential(self.featurizer, self.classifier)
        self.optimizer = get_optimizer(
            hparams["optimizer"],
            self.network.parameters(),
            lr=self.hparams["lr"],
            weight_decay=self.hparams["weight_decay"],
        )

    def update(self, x, y, **kwargs):
        """One SGD step on cross-entropy over all domains pooled together."""
        inputs = torch.cat(x)
        targets = torch.cat(y)
        step_loss = F.cross_entropy(self.predict(inputs), targets)

        self.optimizer.zero_grad()
        step_loss.backward()
        self.optimizer.step()

        return {"loss": step_loss.item()}

    def predict(self, x):
        """Forward the pooled batch through featurizer + classifier."""
        return self.network(x)
110 |
111 |
class Mixstyle(Algorithm):
    """MixStyle w/o domain label (random shuffle)"""

    def __init__(self, input_shape, num_classes, num_domains, hparams):
        assert input_shape[1:3] == (224, 224), "Mixstyle support R18 and R50 only"
        super().__init__(input_shape, num_classes, num_domains, hparams)
        # Pick backbone depth from the hyperparameters.
        if hparams["resnet18"]:
            backbone = resnet18_mixstyle_L234_p0d5_a0d1()
        else:
            backbone = resnet50_mixstyle_L234_p0d5_a0d1()
        self.featurizer = networks.ResNet(input_shape, self.hparams, backbone)

        self.classifier = nn.Linear(self.featurizer.n_outputs, num_classes)
        self.network = nn.Sequential(self.featurizer, self.classifier)
        self.optimizer = self.new_optimizer(self.network.parameters())

    def update(self, x, y, **kwargs):
        """ERM-style step; MixStyle happens inside the backbone layers."""
        inputs = torch.cat(x)
        targets = torch.cat(y)
        step_loss = F.cross_entropy(self.predict(inputs), targets)

        self.optimizer.zero_grad()
        step_loss.backward()
        self.optimizer.step()

        return {"loss": step_loss.item()}

    def predict(self, x):
        return self.network(x)
141 |
142 |
class ARM(ERM):
    """Adaptive Risk Minimization (ARM)"""

    def __init__(self, input_shape, num_classes, num_domains, hparams):
        original_input_shape = input_shape
        # The context map is concatenated as one extra input channel, so the
        # network is built for (C + 1)-channel inputs.
        augmented_shape = (1 + original_input_shape[0],) + original_input_shape[1:]
        super().__init__(augmented_shape, num_classes, num_domains, hparams)
        self.context_net = networks.ContextNet(original_input_shape)
        self.support_size = hparams["batch_size"]

    def predict(self, x):
        batch_size, c, h, w = x.shape
        # Split into meta-batches of `support_size` when the batch divides
        # evenly; otherwise treat the whole batch as one support set.
        if batch_size % self.support_size == 0:
            meta_batch_size = batch_size // self.support_size
            support_size = self.support_size
        else:
            meta_batch_size, support_size = 1, batch_size

        # Average the per-example context over each support set, then tile it
        # back so every example receives its set's shared context channel.
        context = self.context_net(x)
        context = context.reshape((meta_batch_size, support_size, 1, h, w))
        context = context.mean(dim=1)
        context = torch.repeat_interleave(context, repeats=support_size, dim=0)
        return self.network(torch.cat([x, context], dim=1))
166 |
167 |
class SAM(ERM):
    """Sharpness-Aware Minimization.

    Perturbs the weights toward the locally sharpest direction before taking
    the gradient step, then restores them and applies the update.
    """

    @staticmethod
    def norm(tensor_list: List[torch.Tensor], p=2):
        """Compute p-norm for tensor list.

        Fix: the annotation previously used ``torch.tensor`` (the factory
        function) instead of the ``torch.Tensor`` type.
        """
        return torch.cat([x.flatten() for x in tensor_list]).norm(p)

    def update(self, x, y, **kwargs):
        # Pool per-domain batches (same idiom as ERM.update).
        all_x = torch.cat(x)
        all_y = torch.cat(y)
        loss = F.cross_entropy(self.predict(all_x), all_y)

        # 1. eps(w) = rho * g(w) / g(w).norm(2)
        #           = (rho / g(w).norm(2)) * g(w)
        grad_w = autograd.grad(loss, self.network.parameters())
        scale = self.hparams["rho"] / self.norm(grad_w)
        eps = [g * scale for g in grad_w]

        # 2. w' = w + eps(w): move to the adversarially perturbed point.
        with torch.no_grad():
            for p, v in zip(self.network.parameters(), eps):
                p.add_(v)

        # 3. w = w - lr * g(w'): gradient evaluated at the perturbed weights.
        loss = F.cross_entropy(self.predict(all_x), all_y)

        self.optimizer.zero_grad()
        loss.backward()
        # Restore the original parameters before the optimizer step so the
        # update is applied to w, not w'.
        with torch.no_grad():
            for p, v in zip(self.network.parameters(), eps):
                p.sub_(v)
        self.optimizer.step()

        return {"loss": loss.item()}
204 |
205 |
class AbstractDANN(Algorithm):
    """Domain-Adversarial Neural Networks (abstract class).

    Trains a featurizer/classifier pair (the "generator") against a domain
    discriminator in alternating steps. The flags select the variant:

    Args:
        conditional: if True, add a learned class embedding to the
            discriminator input (used by CDANN).
        class_balance: if True, reweight the discriminator loss so each class
            contributes equally (used by CDANN).
    """

    def __init__(self, input_shape, num_classes, num_domains, hparams, conditional, class_balance):

        super(AbstractDANN, self).__init__(input_shape, num_classes, num_domains, hparams)

        # Step counter stored as a buffer so it is saved/restored with the model.
        self.register_buffer("update_count", torch.tensor([0]))
        self.conditional = conditional
        self.class_balance = class_balance

        # Algorithms
        self.featurizer = networks.Featurizer(input_shape, self.hparams)
        self.classifier = nn.Linear(self.featurizer.n_outputs, num_classes)
        self.discriminator = networks.MLP(self.featurizer.n_outputs, num_domains, self.hparams)
        self.class_embeddings = nn.Embedding(num_classes, self.featurizer.n_outputs)

        # Optimizers: discriminator and generator alternate updates, each with
        # its own learning rate and weight decay ("_d" vs "_g" hparams).
        self.disc_opt = get_optimizer(
            hparams["optimizer"],
            (list(self.discriminator.parameters()) + list(self.class_embeddings.parameters())),
            lr=self.hparams["lr_d"],
            weight_decay=self.hparams["weight_decay_d"],
            betas=(self.hparams["beta1"], 0.9),
        )

        self.gen_opt = get_optimizer(
            hparams["optimizer"],
            (list(self.featurizer.parameters()) + list(self.classifier.parameters())),
            lr=self.hparams["lr_g"],
            weight_decay=self.hparams["weight_decay_g"],
            betas=(self.hparams["beta1"], 0.9),
        )

    def update(self, x, y, **kwargs):
        """Alternate discriminator / generator steps.

        Out of every (1 + d_steps_per_g) calls, the first d_steps_per_g update
        the discriminator and the remaining call updates the generator.
        """
        self.update_count += 1
        all_x = torch.cat([xi for xi in x])
        all_y = torch.cat([yi for yi in y])
        minibatches = to_minibatch(x, y)
        all_z = self.featurizer(all_x)
        if self.conditional:
            # CDANN: condition the discriminator on the class via embedding.
            disc_input = all_z + self.class_embeddings(all_y)
        else:
            disc_input = all_z
        disc_out = self.discriminator(disc_input)
        # Discriminator targets: the index of the domain each example came from.
        # NOTE(review): device is hardcoded to "cuda" here.
        disc_labels = torch.cat(
            [
                torch.full((x.shape[0],), i, dtype=torch.int64, device="cuda")
                for i, (x, y) in enumerate(minibatches)
            ]
        )

        if self.class_balance:
            # Weight each example inversely to its class frequency so every
            # class contributes equally to the discriminator loss.
            y_counts = F.one_hot(all_y).sum(dim=0)
            weights = 1.0 / (y_counts[all_y] * y_counts.shape[0]).float()
            disc_loss = F.cross_entropy(disc_out, disc_labels, reduction="none")
            disc_loss = (weights * disc_loss).sum()
        else:
            disc_loss = F.cross_entropy(disc_out, disc_labels)

        # Penalize the gradient of the predicted-domain softmax w.r.t. the
        # discriminator input (scaled by the "grad_penalty" hparam).
        disc_softmax = F.softmax(disc_out, dim=1)
        input_grad = autograd.grad(
            disc_softmax[:, disc_labels].sum(), [disc_input], create_graph=True
        )[0]
        grad_penalty = (input_grad ** 2).sum(dim=1).mean(dim=0)
        disc_loss += self.hparams["grad_penalty"] * grad_penalty

        d_steps_per_g = self.hparams["d_steps_per_g_step"]
        if self.update_count.item() % (1 + d_steps_per_g) < d_steps_per_g:
            # Discriminator step.
            self.disc_opt.zero_grad()
            disc_loss.backward()
            self.disc_opt.step()
            return {"disc_loss": disc_loss.item()}
        else:
            # Generator step: classification loss minus the adversarial term.
            all_preds = self.classifier(all_z)
            classifier_loss = F.cross_entropy(all_preds, all_y)
            gen_loss = classifier_loss + (self.hparams["lambda"] * -disc_loss)
            self.disc_opt.zero_grad()
            self.gen_opt.zero_grad()
            gen_loss.backward()
            self.gen_opt.step()
            return {"gen_loss": gen_loss.item()}

    def predict(self, x):
        # Inference uses only the featurizer + classifier path.
        return self.classifier(self.featurizer(x))
292 |
293 |
class DANN(AbstractDANN):
    """Unconditional DANN"""

    def __init__(self, input_shape, num_classes, num_domains, hparams):
        # Plain DANN: unconditioned discriminator, unweighted loss.
        super().__init__(
            input_shape, num_classes, num_domains, hparams,
            conditional=False, class_balance=False,
        )
306 |
307 |
class CDANN(AbstractDANN):
    """Conditional DANN"""

    def __init__(self, input_shape, num_classes, num_domains, hparams):
        # CDANN: class-conditioned discriminator with class-balanced loss.
        super().__init__(
            input_shape, num_classes, num_domains, hparams,
            conditional=True, class_balance=True,
        )
320 |
321 |
class OrgMixup(ERM):
    """
    Original Mixup independent with domains
    """

    def update(self, x, y, **kwargs):
        """Mixup over the pooled batch, ignoring domain boundaries."""
        inputs = torch.cat(x)
        targets = torch.cat(y)

        # Pair every example with a random partner from the same pooled batch.
        perm = torch.randperm(inputs.size(0))
        partner_inputs = inputs[perm]
        partner_targets = targets[perm]

        lam = np.random.beta(self.hparams["mixup_alpha"], self.hparams["mixup_alpha"])

        mixed = lam * inputs + (1 - lam) * partner_inputs
        predictions = self.predict(mixed)

        # Convex combination of the two cross-entropy losses.
        objective = lam * F.cross_entropy(predictions, targets)
        objective += (1 - lam) * F.cross_entropy(predictions, partner_targets)

        self.optimizer.zero_grad()
        objective.backward()
        self.optimizer.step()

        return {"loss": objective.item()}
348 |
349 |
class CutMix(ERM):
    """ERM with CutMix augmentation over the pooled batch."""

    @staticmethod
    def rand_bbox(size, lam):
        """Sample a random box covering about (1 - lam) of the image area.

        Args:
            size: 4-D input size; dims 2 and 3 are used as W and H.
            lam: mixing coefficient in [0, 1].
        Returns:
            (bbx1, bby1, bbx2, bby2) box corners, clipped to the image bounds.
        """
        W = size[2]
        H = size[3]
        cut_rat = np.sqrt(1.0 - lam)
        # Fix: `np.int` was deprecated in NumPy 1.20 and removed in 1.24;
        # the builtin `int` is the drop-in replacement.
        cut_w = int(W * cut_rat)
        cut_h = int(H * cut_rat)

        # uniform
        cx = np.random.randint(W)
        cy = np.random.randint(H)

        bbx1 = np.clip(cx - cut_w // 2, 0, W)
        bby1 = np.clip(cy - cut_h // 2, 0, H)
        bbx2 = np.clip(cx + cut_w // 2, 0, W)
        bby2 = np.clip(cy + cut_h // 2, 0, H)

        return bbx1, bby1, bbx2, bby2

    def update(self, x, y, **kwargs):
        # cutmix_prob is set to 1.0 for ImageNet and 0.5 for CIFAR100 in the original paper.
        x = torch.cat(x)
        y = torch.cat(y)

        r = np.random.rand(1)
        if self.hparams["beta"] > 0 and r < self.hparams["cutmix_prob"]:
            # generate mixed sample
            beta = self.hparams["beta"]
            lam = np.random.beta(beta, beta)
            rand_index = torch.randperm(x.size()[0]).cuda()
            target_a = y
            target_b = y[rand_index]
            bbx1, bby1, bbx2, bby2 = self.rand_bbox(x.size(), lam)
            # Paste the partner example's patch into each image in-place.
            x[:, :, bbx1:bbx2, bby1:bby2] = x[rand_index, :, bbx1:bbx2, bby1:bby2]
            # adjust lambda to exactly match pixel ratio
            lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (x.size()[-1] * x.size()[-2]))
            # compute output
            output = self.predict(x)
            objective = F.cross_entropy(output, target_a) * lam + F.cross_entropy(
                output, target_b
            ) * (1.0 - lam)
        else:
            # No mixing this step: plain cross-entropy.
            output = self.predict(x)
            objective = F.cross_entropy(output, y)

        self.optimizer.zero_grad()
        objective.backward()
        self.optimizer.step()

        return {"loss": objective.item()}
401 |
402 |
class SagNet(Algorithm):
    """
    Style Agnostic Network
    Algorithm 1 from: https://arxiv.org/abs/1910.11645

    Trains a shared featurizer with two heads: a content head (trained on
    style-randomized features) and a style head (trained on content-randomized
    features). The featurizer is additionally trained adversarially against
    the style head so it becomes style-agnostic.
    """

    def __init__(self, input_shape, num_classes, num_domains, hparams):
        super(SagNet, self).__init__(input_shape, num_classes, num_domains, hparams)
        # featurizer network
        self.network_f = networks.Featurizer(input_shape, self.hparams)
        # content network
        self.network_c = nn.Linear(self.network_f.n_outputs, num_classes)
        # style network
        self.network_s = nn.Linear(self.network_f.n_outputs, num_classes)

        # # This commented block of code implements something closer to the
        # # original paper, but is specific to ResNet and puts in disadvantage
        # # the other algorithms.
        # resnet_c = networks.Featurizer(input_shape, self.hparams)
        # resnet_s = networks.Featurizer(input_shape, self.hparams)
        # # featurizer network
        # self.network_f = torch.nn.Sequential(
        #     resnet_c.network.conv1,
        #     resnet_c.network.bn1,
        #     resnet_c.network.relu,
        #     resnet_c.network.maxpool,
        #     resnet_c.network.layer1,
        #     resnet_c.network.layer2,
        #     resnet_c.network.layer3)
        # # content network
        # self.network_c = torch.nn.Sequential(
        #     resnet_c.network.layer4,
        #     resnet_c.network.avgpool,
        #     networks.Flatten(),
        #     resnet_c.network.fc)
        # # style network
        # self.network_s = torch.nn.Sequential(
        #     resnet_s.network.layer4,
        #     resnet_s.network.avgpool,
        #     networks.Flatten(),
        #     resnet_s.network.fc)

        def opt(p):
            # All three sub-networks share the same optimizer settings.
            return get_optimizer(
                hparams["optimizer"], p, lr=hparams["lr"], weight_decay=hparams["weight_decay"]
            )

        self.optimizer_f = opt(self.network_f.parameters())
        self.optimizer_c = opt(self.network_c.parameters())
        self.optimizer_s = opt(self.network_s.parameters())
        self.weight_adv = hparams["sag_w_adv"]

    def forward_c(self, x):
        # learning content network on randomized style
        return self.network_c(self.randomize(self.network_f(x), "style"))

    def forward_s(self, x):
        # learning style network on randomized content
        return self.network_s(self.randomize(self.network_f(x), "content"))

    def randomize(self, x, what="style", eps=1e-5):
        """Mix feature statistics ("style") or swap normalized features
        ("content") across a random permutation of the batch.

        Args:
            x: feature tensor, 2-D or 4-D (4-D is flattened over H*W).
            what: "style" mixes mean/var with a random partner; anything else
                replaces the normalized features with the partner's (detached).
            eps: numerical floor added to the variance before sqrt.
        """
        sizes = x.size()
        # Per-example interpolation weight in [0, 1).
        # NOTE(review): device hardcoded to CUDA.
        alpha = torch.rand(sizes[0], 1).cuda()

        if len(sizes) == 4:
            # Flatten spatial dims so statistics are computed per channel.
            x = x.view(sizes[0], sizes[1], -1)
            alpha = alpha.unsqueeze(-1)

        mean = x.mean(-1, keepdim=True)
        var = x.var(-1, keepdim=True)

        # Normalize to zero-mean / unit-variance ("content" representation).
        x = (x - mean) / (var + eps).sqrt()

        idx_swap = torch.randperm(sizes[0])
        if what == "style":
            # Interpolate statistics with those of a random partner example.
            mean = alpha * mean + (1 - alpha) * mean[idx_swap]
            var = alpha * var + (1 - alpha) * var[idx_swap]
        else:
            # Replace normalized content with the partner's (no grad through it).
            x = x[idx_swap].detach()

        # De-normalize using the (possibly mixed) statistics.
        x = x * (var + eps).sqrt() + mean
        return x.view(*sizes)

    def update(self, x, y, **kwargs):
        """Three-phase step: content head, style head, adversarial featurizer."""
        all_x = torch.cat([xi for xi in x])
        all_y = torch.cat([yi for yi in y])

        # learn content
        self.optimizer_f.zero_grad()
        self.optimizer_c.zero_grad()
        loss_c = F.cross_entropy(self.forward_c(all_x), all_y)
        loss_c.backward()
        self.optimizer_f.step()
        self.optimizer_c.step()

        # learn style
        self.optimizer_s.zero_grad()
        loss_s = F.cross_entropy(self.forward_s(all_x), all_y)
        loss_s.backward()
        self.optimizer_s.step()

        # learn adversary: push the featurizer to make the style head's
        # predictions uniform (only optimizer_f is stepped here).
        self.optimizer_f.zero_grad()
        loss_adv = -F.log_softmax(self.forward_s(all_x), dim=1).mean(1).mean()
        loss_adv = loss_adv * self.weight_adv
        loss_adv.backward()
        self.optimizer_f.step()

        return {
            "loss_c": loss_c.item(),
            "loss_s": loss_s.item(),
            "loss_adv": loss_adv.item(),
        }

    def predict(self, x):
        # Inference uses only the featurizer + content head.
        return self.network_c(self.network_f(x))
519 |
520 |
class RSC(ERM):
    """RSC: mutes the feature activations with the largest class gradients and
    trains on the remainder (equation numbers below presumably refer to the
    paper "Self-Challenging Improves Cross-Domain Generalization" — confirm).
    """

    def __init__(self, input_shape, num_classes, num_domains, hparams):
        super(RSC, self).__init__(input_shape, num_classes, num_domains, hparams)
        # Percentile thresholds in [0, 100] derived from the drop factors:
        # drop_f selects which feature dims to mute, drop_b which examples.
        self.drop_f = (1 - hparams["rsc_f_drop_factor"]) * 100
        self.drop_b = (1 - hparams["rsc_b_drop_factor"]) * 100
        self.num_classes = num_classes

    def update(self, x, y, **kwargs):
        # inputs
        all_x = torch.cat([xi for xi in x])
        # labels
        all_y = torch.cat([yi for yi in y])
        # one-hot labels
        all_o = torch.nn.functional.one_hot(all_y, self.num_classes)
        # features
        all_f = self.featurizer(all_x)
        # predictions
        all_p = self.classifier(all_f)

        # Equation (1): compute gradients with respect to representation
        all_g = autograd.grad((all_p * all_o).sum(), all_f)[0]

        # Equation (2): compute top-gradient-percentile mask
        # (per-example threshold; 1 keeps the dim, 0 mutes a top-gradient dim)
        percentiles = np.percentile(all_g.cpu(), self.drop_f, axis=1)
        percentiles = torch.Tensor(percentiles)
        percentiles = percentiles.unsqueeze(1).repeat(1, all_g.size(1))
        # NOTE(review): device hardcoded to CUDA.
        mask_f = all_g.lt(percentiles.cuda()).float()

        # Equation (3): mute top-gradient-percentile activations
        all_f_muted = all_f * mask_f

        # Equation (4): compute muted predictions
        all_p_muted = self.classifier(all_f_muted)

        # Section 3.3: Batch Percentage — measure how much the true-class
        # softmax score dropped for each example when its features were muted.
        all_s = F.softmax(all_p, dim=1)
        all_s_muted = F.softmax(all_p_muted, dim=1)
        changes = (all_s * all_o).sum(1) - (all_s_muted * all_o).sum(1)
        percentile = np.percentile(changes.detach().cpu(), self.drop_b)
        # mask_b = 1 for examples with a small drop; the logical_or below then
        # leaves those rows fully unmuted, so muting is only kept for the
        # examples whose confidence dropped the most.
        mask_b = changes.lt(percentile).float().view(-1, 1)
        mask = torch.logical_or(mask_f, mask_b).float()

        # Equations (3) and (4) again, this time mutting over examples
        all_p_muted_again = self.classifier(all_f * mask)

        # Equation (5): update
        loss = F.cross_entropy(all_p_muted_again, all_y)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return {"loss": loss.item()}
573 |
--------------------------------------------------------------------------------
/domainbed/algorithms/miro.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Kakao Brain. All Rights Reserved.
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from domainbed.optimizers import get_optimizer
8 | from domainbed.networks.ur_networks import URFeaturizer
9 | from domainbed.lib import misc
10 | from domainbed.algorithms import Algorithm
11 |
12 |
class ForwardModel(nn.Module):
    """Inference-only wrapper around a network.

    Forward model is used to reduce gpu memory usage of SWAD.
    """

    def __init__(self, network):
        super().__init__()
        self.network = network

    def predict(self, x):
        # Delegate straight to the wrapped network.
        return self.network(x)

    def forward(self, x):
        return self.predict(x)
25 |
26 |
class MeanEncoder(nn.Module):
    """Mean encoder that is simply the identity map."""

    def __init__(self, shape):
        super().__init__()
        # Stored for interface symmetry with VarianceEncoder; not used here.
        self.shape = shape

    def forward(self, x):
        return x
35 |
36 |
class VarianceEncoder(nn.Module):
    """Bias-only model with diagonal covariance.

    Ignores its input and returns a learned, always-positive variance tensor
    that initially equals ``init`` everywhere.
    """

    def __init__(self, shape, init=0.1, channelwise=True, eps=1e-5):
        super().__init__()
        self.shape = shape
        self.eps = eps

        # Inverse softplus so that forward() returns exactly `init` at start.
        raw_init = (torch.as_tensor(init - eps).exp() - 1.0).log()

        b_shape = shape
        if channelwise:
            if len(shape) == 4:
                # Conv features [B, C, H, W]: one parameter per channel.
                b_shape = (1, shape[1], 1, 1)
            elif len(shape) == 3:
                # CLIP-ViT tokens: [H*W+1, B, C]
                b_shape = (1, 1, shape[2])
            else:
                raise ValueError()

        self.b = nn.Parameter(torch.full(b_shape, raw_init))

    def forward(self, x):
        # Softplus keeps the variance positive; eps bounds it away from zero.
        return F.softplus(self.b) + self.eps
60 |
61 |
def get_shapes(model, input_shape):
    """Probe `model` with a dummy batch and return intermediate-feature shapes.

    The model must accept ``ret_feats=True`` and return (output, feats).
    """
    device = next(model.parameters()).device
    with torch.no_grad():
        probe = torch.rand(1, *input_shape).to(device)
        _, feats = model(probe, ret_feats=True)
    return [f.shape for f in feats]
70 |
71 |
class MIRO(Algorithm):
    """Mutual-Information Regularization with Oracle.

    Regularizes the trainable featurizer's intermediate features toward those
    of a frozen pre-trained copy via a variational lower bound, weighted by
    the ``ld`` (lambda) hyperparameter.
    """
    def __init__(self, input_shape, num_classes, num_domains, hparams, **kwargs):
        super().__init__(input_shape, num_classes, num_domains, hparams)
        # Frozen pre-trained "oracle" featurizer; used only to produce targets.
        # NOTE(review): hparams is accessed both with attribute syntax
        # (hparams.feat_layers, hparams.ld, hparams.lr) and item syntax
        # (hparams["lr"]) — the hparams object presumably supports both.
        self.pre_featurizer = URFeaturizer(
            input_shape, self.hparams, freeze="all", feat_layers=hparams.feat_layers
        )
        # Trainable featurizer with the same architecture (not frozen).
        self.featurizer = URFeaturizer(
            input_shape, self.hparams, feat_layers=hparams.feat_layers
        )
        self.classifier = nn.Linear(self.featurizer.n_outputs, num_classes)
        self.network = nn.Sequential(self.featurizer, self.classifier)
        self.ld = hparams.ld

        # build mean/var encoders: one (mean, variance) pair per intermediate
        # feature map reported by the frozen featurizer.
        shapes = get_shapes(self.pre_featurizer, self.input_shape)
        self.mean_encoders = nn.ModuleList([
            MeanEncoder(shape) for shape in shapes
        ])
        self.var_encoders = nn.ModuleList([
            VarianceEncoder(shape) for shape in shapes
        ])

        # optimizer: the mean/var encoders train with a larger learning rate
        # (lr * lr_mult) than the main network.
        parameters = [
            {"params": self.network.parameters()},
            {"params": self.mean_encoders.parameters(), "lr": hparams.lr * hparams.lr_mult},
            {"params": self.var_encoders.parameters(), "lr": hparams.lr * hparams.lr_mult},
        ]
        self.optimizer = get_optimizer(
            hparams["optimizer"],
            parameters,
            lr=self.hparams["lr"],
            weight_decay=self.hparams["weight_decay"],
        )

    def update(self, x, y, **kwargs):
        """One step: cross-entropy plus ld-weighted MI regularizer."""
        all_x = torch.cat(x)
        all_y = torch.cat(y)
        feat, inter_feats = self.featurizer(all_x, ret_feats=True)
        logit = self.classifier(feat)
        loss = F.cross_entropy(logit, all_y)

        # MIRO: target features come from the frozen featurizer (no gradients).
        with torch.no_grad():
            _, pre_feats = self.pre_featurizer(all_x, ret_feats=True)

        reg_loss = 0.
        for f, pre_f, mean_enc, var_enc in misc.zip_strict(
            inter_feats, pre_feats, self.mean_encoders, self.var_encoders
        ):
            # mutual information regularization: Gaussian log-density of the
            # oracle feature under N(mean_enc(f), var_enc(f)), up to constants.
            mean = mean_enc(f)
            var = var_enc(f)
            vlb = (mean - pre_f).pow(2).div(var) + var.log()
            reg_loss += vlb.mean() / 2.

        loss += reg_loss * self.ld

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return {"loss": loss.item(), "reg_loss": reg_loss.item()}

    def predict(self, x):
        return self.network(x)

    def get_forward_model(self):
        """Wrap the network for inference-only use (reduces SWAD GPU memory)."""
        forward_model = ForwardModel(self.network)
        return forward_model
143 |
--------------------------------------------------------------------------------
/domainbed/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 | from domainbed.datasets import datasets
5 | from domainbed.lib import misc
6 | from domainbed.datasets import transforms as DBT
7 |
8 |
def set_transfroms(dset, data_type, hparams, algorithm_class=None):
    """Attach input transforms to a split dataset, by split type.

    Args:
        data_type: ['train', 'valid', 'test', 'mnist']
        algorithm_class: optional; its class-level `transforms` are added as
            extra views, but only for training data.

    NOTE: the misspelled name is kept — callers use it as-is.
    """
    assert hparams["data_augmentation"]

    attach_extra = False
    if data_type == "train":
        dset.transforms = {"x": DBT.aug}
        attach_extra = True
    elif data_type == "valid":
        # Originally, DomainBed applies the training augmentation policy to
        # validation too. We default to no augmentation but keep the option
        # for reproducibility.
        if hparams["val_augment"] is False:
            dset.transforms = {"x": DBT.basic}
        else:
            dset.transforms = {"x": DBT.aug}
    elif data_type == "test":
        dset.transforms = {"x": DBT.basic}
    elif data_type == "mnist":
        # No augmentation for mnist
        dset.transforms = {"x": lambda x: x}
    else:
        raise ValueError(data_type)

    if attach_extra and algorithm_class is not None:
        for key, transform in algorithm_class.transforms.items():
            dset.transforms[key] = transform
39 |
40 |
def get_dataset(test_envs, args, hparams, algorithm_class=None):
    """Load `args.dataset` and split every environment into in/out parts.

    Returns:
        (dataset, in_splits, out_splits) where each split list holds
        (split_dataset, class_weights_or_None) pairs, one per environment.
    """
    is_mnist = "MNIST" in args.dataset
    dataset = vars(datasets)[args.dataset](args.data_dir)

    in_splits = []
    out_splits = []
    for env_i, env in enumerate(dataset):
        # The split depends only on seed_hash(trial_seed, env_i), so it is
        # identical across runs/machines for a fixed trial seed.
        out, in_ = split_dataset(
            env,
            int(len(env) * args.holdout_fraction),
            misc.seed_hash(args.trial_seed, env_i),
        )

        # Held-out test environments get test transforms on both parts.
        if env_i in test_envs:
            in_type = out_type = "test"
        else:
            in_type, out_type = "train", "valid"

        if is_mnist:
            in_type = out_type = "mnist"

        set_transfroms(in_, in_type, hparams, algorithm_class)
        set_transfroms(out, out_type, hparams, algorithm_class)

        if hparams["class_balanced"]:
            in_weights = misc.make_weights_for_balanced_classes(in_)
            out_weights = misc.make_weights_for_balanced_classes(out)
        else:
            in_weights = out_weights = None

        in_splits.append((in_, in_weights))
        out_splits.append((out, out_weights))

    return dataset, in_splits, out_splits
82 |
83 |
84 | class _SplitDataset(torch.utils.data.Dataset):
85 | """Used by split_dataset"""
86 |
87 | def __init__(self, underlying_dataset, keys):
88 | super(_SplitDataset, self).__init__()
89 | self.underlying_dataset = underlying_dataset
90 | self.keys = keys
91 | self.transforms = {}
92 |
93 | self.direct_return = isinstance(underlying_dataset, _SplitDataset)
94 |
95 | def __getitem__(self, key):
96 | if self.direct_return:
97 | return self.underlying_dataset[self.keys[key]]
98 |
99 | x, y = self.underlying_dataset[self.keys[key]]
100 | ret = {"y": y}
101 |
102 | for key, transform in self.transforms.items():
103 | ret[key] = transform(x)
104 |
105 | return ret
106 |
107 | def __len__(self):
108 | return len(self.keys)
109 |
110 |
def split_dataset(dataset, n, seed=0):
    """Randomly split `dataset` into two disjoint views.

    The first returned dataset holds `n` datapoints and the second holds the
    rest. The permutation depends only on `seed`, so the split is
    reproducible.
    """
    assert n <= len(dataset)
    indices = list(range(len(dataset)))
    np.random.RandomState(seed).shuffle(indices)
    return _SplitDataset(dataset, indices[:n]), _SplitDataset(dataset, indices[n:])
123 |
--------------------------------------------------------------------------------
/domainbed/datasets/datasets.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 |
3 | import os
4 | import torch
5 | from PIL import Image, ImageFile
6 | from torchvision import transforms as T
7 | from torch.utils.data import TensorDataset
8 | from torchvision.datasets import MNIST, ImageFolder
9 | from torchvision.transforms.functional import rotate
10 |
# Let PIL load truncated image files instead of raising mid-training.
ImageFile.LOAD_TRUNCATED_IMAGES = True
12 |
# Registry of selectable dataset names; each must match a class in this module.
DATASETS = [
    # Debug
    "Debug28",
    "Debug224",
    # Small images
    "ColoredMNIST",
    "RotatedMNIST",
    # Big images
    "VLCS",
    "PACS",
    "OfficeHome",
    "TerraIncognita",
    "DomainNet",
]
27 |
28 |
def get_dataset_class(dataset_name):
    """Look up a dataset class in this module by name.

    Raises:
        NotImplementedError: if no object of that name exists here.
    """
    try:
        return globals()[dataset_name]
    except KeyError:
        raise NotImplementedError("Dataset not found: {}".format(dataset_name)) from None
34 |
35 |
def num_environments(dataset_name):
    """Return how many environments the named dataset declares."""
    dataset_class = get_dataset_class(dataset_name)
    return len(dataset_class.ENVIRONMENTS)
38 |
39 |
class MultipleDomainDataset:
    """Base class: a collection of per-domain sub-datasets in `self.datasets`."""

    N_STEPS = 5001  # Default, subclasses may override
    CHECKPOINT_FREQ = 100  # Default, subclasses may override
    N_WORKERS = 4  # Default, subclasses may override
    ENVIRONMENTS = None  # Subclasses should override
    INPUT_SHAPE = None  # Subclasses should override

    def __getitem__(self, index):
        """Return the sub-dataset for the domain at `index`."""
        return self.datasets[index]

    def __len__(self):
        """Return the number of sub-datasets (domains)."""
        return len(self.datasets)
58 |
59 |
class Debug(MultipleDomainDataset):
    """Tiny random-tensor dataset (three identical envs) for smoke tests."""

    def __init__(self, root):
        super().__init__()
        self.input_shape = self.INPUT_SHAPE
        self.num_classes = 2
        # 16 random examples per environment; labels drawn uniformly.
        self.datasets = [
            TensorDataset(
                torch.randn(16, *self.INPUT_SHAPE),
                torch.randint(0, self.num_classes, (16,)),
            )
            for _ in range(3)
        ]
73 |
74 |
class Debug28(Debug):
    # Small-image debug variant: 3x28x28 random tensors.
    INPUT_SHAPE = (3, 28, 28)
    ENVIRONMENTS = ["0", "1", "2"]
78 |
79 |
class Debug224(Debug):
    # Big-image debug variant: 3x224x224 random tensors.
    INPUT_SHAPE = (3, 224, 224)
    ENVIRONMENTS = ["0", "1", "2"]
83 |
84 |
class MultipleEnvironmentMNIST(MultipleDomainDataset):
    def __init__(self, root, environments, dataset_transform, input_shape, num_classes):
        """Build one sub-dataset per environment from the full MNIST data.

        Args:
            root: root dir for saving the MNIST dataset.
            environments: per-environment property (e.g. angle or color prob).
            dataset_transform: callable (images, labels, env) -> Dataset.
            input_shape: shape of a single input tensor.
            num_classes: number of target classes.
        """
        super().__init__()
        if root is None:
            raise ValueError("Data directory not specified!")

        train_part = MNIST(root, train=True, download=True)
        test_part = MNIST(root, train=False, download=True)

        all_images = torch.cat((train_part.data, test_part.data))
        all_labels = torch.cat((train_part.targets, test_part.targets))

        # Shuffle once, then deal examples round-robin across environments.
        order = torch.randperm(len(all_images))
        all_images = all_images[order]
        all_labels = all_labels[order]

        self.environments = environments
        n_envs = len(environments)
        self.datasets = [
            dataset_transform(all_images[idx::n_envs], all_labels[idx::n_envs], env)
            for idx, env in enumerate(environments)
        ]

        self.input_shape = input_shape
        self.num_classes = num_classes
119 |
120 |
class ColoredMNIST(MultipleEnvironmentMNIST):
    ENVIRONMENTS = ["+90%", "+80%", "-90%"]

    def __init__(self, root):
        super().__init__(
            root,
            [0.1, 0.2, 0.9],
            self.color_dataset,
            (2, 28, 28),
            2,
        )

    def color_dataset(self, images, labels, environment):
        """Binarize digit labels and color images with env-dependent noise."""
        # Binary label: digit < 5, then flipped with probability 0.25.
        labels = (labels < 5).float()
        labels = self.torch_xor_(labels, self.torch_bernoulli_(0.25, len(labels)))

        # Color follows the label, flipped with probability `environment`.
        colors = self.torch_xor_(labels, self.torch_bernoulli_(environment, len(labels)))

        # Two color channels; zero out the channel opposite the chosen color.
        images = torch.stack([images, images], dim=1)
        rows = torch.tensor(range(len(images)))
        images[rows, (1 - colors).long(), :, :] *= 0

        x = images.float().div_(255.0)
        y = labels.view(-1).long()

        return TensorDataset(x, y)

    def torch_bernoulli_(self, p, size):
        """Float tensor of `size` Bernoulli(p) samples in {0.0, 1.0}."""
        return (torch.rand(size) < p).float()

    def torch_xor_(self, a, b):
        """Elementwise XOR for {0,1}-valued float tensors."""
        return (a - b).abs()
157 |
158 |
class RotatedMNIST(MultipleEnvironmentMNIST):
    """MNIST with one environment per rotation angle (degrees)."""
    ENVIRONMENTS = ["0", "15", "30", "45", "60", "75"]

    def __init__(self, root):
        super(RotatedMNIST, self).__init__(
            root,
            [0, 15, 30, 45, 60, 75],
            self.rotate_dataset,
            (1, 28, 28),
            10,
        )

    def rotate_dataset(self, images, labels, angle):
        """Return a TensorDataset of `images` rotated by `angle` degrees."""
        # NOTE(review): `resample=` was deprecated in newer torchvision in
        # favor of `interpolation=` -- confirm against the pinned version.
        rotation = T.Compose(
            [
                T.ToPILImage(),
                T.Lambda(lambda x: rotate(x, angle, fill=(0,), resample=Image.BICUBIC)),
                T.ToTensor(),
            ]
        )

        x = torch.zeros(len(images), 1, 28, 28)
        for i in range(len(images)):
            x[i] = rotation(images[i])

        y = labels.view(-1)

        return TensorDataset(x, y)
187 |
188 |
class MultipleEnvironmentImageFolder(MultipleDomainDataset):
    """One ImageFolder per immediate subdirectory of `root` (one per env)."""

    def __init__(self, root):
        super().__init__()
        # Environments are the sorted subdirectory names under root.
        self.environments = sorted(
            entry.name for entry in os.scandir(root) if entry.is_dir()
        )

        self.datasets = [
            ImageFolder(os.path.join(root, environment))
            for environment in self.environments
        ]

        self.input_shape = (3, 224, 224)
        self.num_classes = len(self.datasets[-1].classes)
205 |
206 |
class VLCS(MultipleEnvironmentImageFolder):
    # Environment folders are discovered under <root>/VLCS/.
    CHECKPOINT_FREQ = 200
    ENVIRONMENTS = ["C", "L", "S", "V"]

    def __init__(self, root):
        self.dir = os.path.join(root, "VLCS/")
        super().__init__(self.dir)
214 |
215 |
class PACS(MultipleEnvironmentImageFolder):
    # Environment folders are discovered under <root>/PACS/.
    CHECKPOINT_FREQ = 200
    ENVIRONMENTS = ["A", "C", "P", "S"]

    def __init__(self, root):
        self.dir = os.path.join(root, "PACS/")
        super().__init__(self.dir)
223 |
224 |
class DomainNet(MultipleEnvironmentImageFolder):
    # Larger dataset: more steps and less frequent checkpoints than the others.
    CHECKPOINT_FREQ = 1000
    N_STEPS = 15001
    ENVIRONMENTS = ["clip", "info", "paint", "quick", "real", "sketch"]

    def __init__(self, root):
        self.dir = os.path.join(root, "domain_net/")
        super().__init__(self.dir)
233 |
234 |
class OfficeHome(MultipleEnvironmentImageFolder):
    # Environment folders are discovered under <root>/office_home/.
    CHECKPOINT_FREQ = 200
    ENVIRONMENTS = ["A", "C", "P", "R"]

    def __init__(self, root):
        self.dir = os.path.join(root, "office_home/")
        super().__init__(self.dir)
242 |
243 |
class TerraIncognita(MultipleEnvironmentImageFolder):
    # Environment folders are discovered under <root>/terra_incognita/.
    CHECKPOINT_FREQ = 200
    ENVIRONMENTS = ["L100", "L38", "L43", "L46"]

    def __init__(self, root):
        self.dir = os.path.join(root, "terra_incognita/")
        super().__init__(self.dir)
251 |
--------------------------------------------------------------------------------
/domainbed/datasets/transforms.py:
--------------------------------------------------------------------------------
1 | from torchvision import transforms as T
2 |
3 |
# Shared ImageNet normalization applied by both pipelines.
_normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# Deterministic eval-time preprocessing: resize + normalize only.
basic = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    _normalize,
])

# Train-time augmentation policy (DomainBed default).
aug = T.Compose([
    T.RandomResizedCrop(224, scale=(0.7, 1.0)),
    T.RandomHorizontalFlip(),
    T.ColorJitter(0.3, 0.3, 0.3, 0.3),
    T.RandomGrayscale(p=0.1),
    T.ToTensor(),
    _normalize,
])
21 |
--------------------------------------------------------------------------------
/domainbed/evaluator.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from domainbed.lib.fast_data_loader import FastDataLoader
6 |
# Run on GPU when available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
11 |
12 |
def accuracy_from_loader(algorithm, loader, weights, debug=False):
    """Compute (weighted) accuracy and mean cross-entropy loss over a loader.

    Args:
        algorithm: model exposing .predict(x), .eval() and .train().
        loader: iterable of batches; each batch is a dict with "x" and "y".
        weights: optional 1-D tensor of per-example weights laid out in
            loader order; None means uniform weights.
        debug: if True, evaluate only the first batch.

    Returns:
        (acc, loss): weighted accuracy and weighted mean CE loss.
    """
    correct = 0
    total = 0
    losssum = 0.0
    weights_offset = 0

    # Resolve the device locally so the function does not depend on module
    # state; this matches the module-level default (cuda when available).
    device = "cuda" if torch.cuda.is_available() else "cpu"

    algorithm.eval()

    for batch in loader:
        x = batch["x"].to(device)
        y = batch["y"].to(device)

        with torch.no_grad():
            logits = algorithm.predict(x)
            loss = F.cross_entropy(logits, y).item()

        B = len(x)
        losssum += loss * B

        if weights is None:
            batch_weights = torch.ones(len(x))
        else:
            # Consume the next len(x) per-example weights in loader order.
            batch_weights = weights[weights_offset : weights_offset + len(x)]
            weights_offset += len(x)
        batch_weights = batch_weights.to(device)

        if logits.size(1) == 1:
            # Single-logit binary head: positive logit means class 1.
            # BUGFIX: squeeze to (B,) before comparing with y; the previous
            # (B, 1)-vs-(B,) comparison broadcast to a (B, B) matrix and
            # inflated `correct`.
            preds = logits.squeeze(1).gt(0)
        else:
            preds = logits.argmax(1)
        correct += (preds.eq(y).float() * batch_weights).sum().item()
        total += batch_weights.sum().item()

        if debug:
            break

    algorithm.train()

    acc = correct / total
    loss = losssum / total
    return acc, loss
52 |
53 |
def accuracy(algorithm, loader_kwargs, weights, **kwargs):
    """Dispatch to accuracy_from_loader, building a loader from kwargs if needed."""
    if isinstance(loader_kwargs, FastDataLoader):
        loader = loader_kwargs
    elif isinstance(loader_kwargs, dict):
        loader = FastDataLoader(**loader_kwargs)
    else:
        raise ValueError(loader_kwargs)
    return accuracy_from_loader(algorithm, loader, weights, **kwargs)
62 |
63 |
class Evaluator:
    """Evaluates an algorithm over environment splits.

    Produces test_in/test_out (accuracy on the single test env's splits) and
    comb_val (mean validation accuracy over training envs).
    """

    def __init__(
        self, test_envs, eval_meta, n_envs, logger, evalmode="fast", debug=False, target_env=None
    ):
        self.test_envs = test_envs
        self.train_envs = sorted(set(range(n_envs)) - set(test_envs))
        self.eval_meta = eval_meta
        self.n_envs = n_envs
        self.logger = logger
        self.evalmode = evalmode
        self.debug = debug

        if target_env is not None:
            self.set_target_env(target_env)

    def set_target_env(self, target_env):
        """When len(test_envs) == 2, you can specify target env for computing exact test acc."""
        self.test_envs = [target_env]

    def evaluate(self, algorithm, suffix):
        n_train_envs = len(self.train_envs)
        n_test_envs = len(self.test_envs)
        assert n_test_envs == 1
        summaries = collections.defaultdict(float)
        # Seed these keys so they appear first (fixed key order).
        for base in ("test_in", "test_out", "comb_val"):
            summaries[base + suffix] = 0.0

        # eval_meta is ordered in_splits then out_splits; names look like
        # "env<NUM>_<in|out>".
        for name, loader_kwargs, weights in self.eval_meta:
            env_name, inout = name.split("_")
            env_num = int(env_name[3:])

            # "fast" mode skips the (large) in-splits of training envs.
            if self.evalmode == "fast" and inout == "in" and env_num not in self.test_envs:
                continue

            acc, loss = accuracy(algorithm, loader_kwargs, weights, debug=self.debug)

            if env_num in self.train_envs:
                summaries["comb_val" + suffix] += acc / n_train_envs
                if inout == "out":
                    summaries["comb_val_loss" + suffix] += loss / n_train_envs
            elif env_num in self.test_envs:
                summaries["test_" + inout + suffix] += acc / n_test_envs

        return summaries
--------------------------------------------------------------------------------
/domainbed/hparams_registry.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 |
3 | import numpy as np
4 |
5 |
def _hparams(algorithm, dataset, random_state):
    """
    Global registry of hyperparams. Each entry is a (default, random) tuple.
    New algorithms / networks / etc. should add entries here.

    NOTE: the sequence of `random_state` draws defines the seeded random
    search; keep the draw order stable so random_hparams stays reproducible.
    """
    SMALL_IMAGES = ["Debug28", "RotatedMNIST", "ColoredMNIST"]

    hparams = {}

    # Common hyperparameters (algorithm- and dataset-independent).
    hparams["data_augmentation"] = (True, True)
    hparams["val_augment"] = (False, False)  # augmentation for in-domain validation set
    hparams["resnet18"] = (False, False)
    hparams["resnet_dropout"] = (0.0, random_state.choice([0.0, 0.1, 0.5]))
    hparams["class_balanced"] = (False, False)
    hparams["optimizer"] = ("adam", "adam")

    hparams["freeze_bn"] = (True, True)
    hparams["pretrained"] = (True, True)  # only for ResNet

    # lr / batch_size depend on image size (and DomainNet's memory cost).
    if dataset not in SMALL_IMAGES:
        hparams["lr"] = (5e-5, 10 ** random_state.uniform(-5, -3.5))
        if dataset == "DomainNet":
            hparams["batch_size"] = (32, int(2 ** random_state.uniform(3, 5)))
        else:
            hparams["batch_size"] = (32, int(2 ** random_state.uniform(3, 5.5)))
        if algorithm == "ARM":
            hparams["batch_size"] = (8, 8)
    else:
        hparams["lr"] = (1e-3, 10 ** random_state.uniform(-4.5, -2.5))
        hparams["batch_size"] = (64, int(2 ** random_state.uniform(3, 9)))

    if dataset in SMALL_IMAGES:
        hparams["weight_decay"] = (0.0, 0.0)
    else:
        hparams["weight_decay"] = (0.0, 10 ** random_state.uniform(-6, -2))

    # Algorithm-specific hyperparameters.
    if algorithm in ["DANN", "CDANN"]:
        if dataset not in SMALL_IMAGES:
            hparams["lr_g"] = (5e-5, 10 ** random_state.uniform(-5, -3.5))
            hparams["lr_d"] = (5e-5, 10 ** random_state.uniform(-5, -3.5))
        else:
            hparams["lr_g"] = (1e-3, 10 ** random_state.uniform(-4.5, -2.5))
            hparams["lr_d"] = (1e-3, 10 ** random_state.uniform(-4.5, -2.5))

        if dataset in SMALL_IMAGES:
            hparams["weight_decay_g"] = (0.0, 0.0)
        else:
            hparams["weight_decay_g"] = (0.0, 10 ** random_state.uniform(-6, -2))

        hparams["lambda"] = (1.0, 10 ** random_state.uniform(-2, 2))
        hparams["weight_decay_d"] = (0.0, 10 ** random_state.uniform(-6, -2))
        hparams["d_steps_per_g_step"] = (1, int(2 ** random_state.uniform(0, 3)))
        hparams["grad_penalty"] = (0.0, 10 ** random_state.uniform(-2, 1))
        hparams["beta1"] = (0.5, random_state.choice([0.0, 0.5]))
        hparams["mlp_width"] = (256, int(2 ** random_state.uniform(6, 10)))
        hparams["mlp_depth"] = (3, int(random_state.choice([3, 4, 5])))
        hparams["mlp_dropout"] = (0.0, random_state.choice([0.0, 0.1, 0.5]))
    elif algorithm == "RSC":
        hparams["rsc_f_drop_factor"] = (1 / 3, random_state.uniform(0, 0.5))
        hparams["rsc_b_drop_factor"] = (1 / 3, random_state.uniform(0, 0.5))
    elif algorithm == "SagNet":
        hparams["sag_w_adv"] = (0.1, 10 ** random_state.uniform(-2, 1))
    elif algorithm == "IRM":
        hparams["irm_lambda"] = (1e2, 10 ** random_state.uniform(-1, 5))
        hparams["irm_penalty_anneal_iters"] = (
            500,
            int(10 ** random_state.uniform(0, 4)),
        )
    elif algorithm in ["Mixup", "OrgMixup"]:
        # NOTE(review): uniform(-1, -1) is degenerate (always 0.1) --
        # confirm the intended search range.
        hparams["mixup_alpha"] = (0.2, 10 ** random_state.uniform(-1, -1))
    elif algorithm == "GroupDRO":
        hparams["groupdro_eta"] = (1e-2, 10 ** random_state.uniform(-3, -1))
    elif algorithm in ("MMD", "CORAL"):
        hparams["mmd_gamma"] = (1.0, 10 ** random_state.uniform(-1, 1))
    elif algorithm in ("MLDG", "SOMLDG"):
        hparams["mldg_beta"] = (1.0, 10 ** random_state.uniform(-1, 1))
    elif algorithm == "MTL":
        hparams["mtl_ema"] = (0.99, random_state.choice([0.5, 0.9, 0.99, 1.0]))
    elif algorithm == "VREx":
        hparams["vrex_lambda"] = (1e1, 10 ** random_state.uniform(-1, 5))
        hparams["vrex_penalty_anneal_iters"] = (
            500,
            int(10 ** random_state.uniform(0, 4)),
        )
    elif algorithm == "SAM":
        hparams["rho"] = (0.05, random_state.choice([0.01, 0.02, 0.05, 0.1]))
    elif algorithm == "CutMix":
        hparams["beta"] = (1.0, 1.0)
        # cutmix_prob is set to 1.0 for ImageNet and 0.5 for CIFAR100 in the original paper.
        hparams["cutmix_prob"] = (1.0, 1.0)

    return hparams
98 |
99 |
def default_hparams(algorithm, dataset):
    """Return the default value for every registered hyperparameter."""
    registry = _hparams(algorithm, dataset, np.random.RandomState(0))
    return {name: default for name, (default, _) in registry.items()}
103 |
104 |
def random_hparams(algorithm, dataset, seed):
    """Return a seeded random draw for every registered hyperparameter."""
    registry = _hparams(algorithm, dataset, np.random.RandomState(seed))
    return {name: drawn for name, (_, drawn) in registry.items()}
108 |
--------------------------------------------------------------------------------
/domainbed/lib/fast_data_loader.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 |
3 | import torch
4 |
5 |
6 | class _InfiniteSampler(torch.utils.data.Sampler):
7 | """Wraps another Sampler to yield an infinite stream."""
8 |
9 | def __init__(self, sampler):
10 | self.sampler = sampler
11 |
12 | def __iter__(self):
13 | while True:
14 | for batch in self.sampler:
15 | yield batch
16 |
17 |
class InfiniteDataLoader:
    """Endless batch stream: re-iterates the sampler forever."""

    def __init__(self, dataset, weights, batch_size, num_workers):
        super().__init__()

        # With per-example weights, sample (with replacement) proportionally;
        # otherwise sample uniformly with replacement.
        if weights:
            sampler = torch.utils.data.WeightedRandomSampler(
                weights, replacement=True, num_samples=batch_size
            )
        else:
            sampler = torch.utils.data.RandomSampler(dataset, replacement=True)

        batch_sampler = torch.utils.data.BatchSampler(
            sampler, batch_size=batch_size, drop_last=True
        )

        loader = torch.utils.data.DataLoader(
            dataset,
            num_workers=num_workers,
            batch_sampler=_InfiniteSampler(batch_sampler),
        )
        self._infinite_iterator = iter(loader)

    def __iter__(self):
        while True:
            yield next(self._infinite_iterator)

    def __len__(self):
        # Length is undefined for an infinite loader.
        raise ValueError
47 |
48 |
class FastDataLoader:
    """
    DataLoader wrapper with slightly improved speed by not respawning worker
    processes at every epoch.
    """

    def __init__(self, dataset, batch_size, num_workers, shuffle=False):
        super().__init__()

        if shuffle:
            sampler = torch.utils.data.RandomSampler(dataset, replacement=False)
        else:
            sampler = torch.utils.data.SequentialSampler(dataset)

        batch_sampler = torch.utils.data.BatchSampler(
            sampler,
            batch_size=batch_size,
            drop_last=False,
        )

        # Keep a single persistent iterator over an infinite sampler so the
        # workers survive across epochs; __iter__ slices one epoch from it.
        self._infinite_iterator = iter(
            torch.utils.data.DataLoader(
                dataset,
                num_workers=num_workers,
                batch_sampler=_InfiniteSampler(batch_sampler),
            )
        )

        self._length = len(batch_sampler)

    def __iter__(self):
        for _ in range(self._length):
            yield next(self._infinite_iterator)

    def __len__(self):
        return self._length
85 |
--------------------------------------------------------------------------------
/domainbed/lib/logger.py:
--------------------------------------------------------------------------------
1 | """ Singleton Logger """
2 | import sys
3 | import logging
4 |
5 |
def levelize(levelname):
    """Convert a level name (str) to its numeric level; pass numbers through."""
    return logging.getLevelName(levelname) if isinstance(levelname, str) else levelname
12 |
13 |
class ColorFormatter(logging.Formatter):
    """Formatter that wraps the level name in an ANSI color escape."""

    # levelname -> ANSI SGR code
    color_dic = {
        "DEBUG": 37,  # white
        "INFO": 36,  # cyan
        "WARNING": 33,  # yellow
        "ERROR": 31,  # red
        "CRITICAL": 41,  # white on red bg
    }

    def format(self, record):
        code = self.color_dic.get(record.levelname, 37)  # default white
        record.levelname = "\033[{}m{}\033[0m".format(code, record.levelname)
        return super().format(record)
27 |
28 |
class Logger(logging.Logger):
    """Singleton logger with colorized stdout and optional file output."""

    NAME = "SingletonLogger"

    @classmethod
    def get(cls, file_path=None, level="INFO", colorize=True, track_code=False):
        # Temporarily install this class so getLogger builds a Logger instance.
        logging.setLoggerClass(cls)
        logger = logging.getLogger(cls.NAME)
        logging.setLoggerClass(logging.Logger)  # restore
        logger.setLevel(level)

        if logger.hasHandlers():
            # A fully-configured logger has 2 handlers (stream + file):
            # reuse it. Otherwise rebuild the handlers from scratch.
            if len(logger.handlers) == 2:
                return logger
            logger.handlers.clear()

        if track_code:
            log_format = (
                "%(levelname)s::%(asctime)s | [%(filename)s] [%(funcName)s:%(lineno)d] "
                "%(message)s"
            )
        else:
            log_format = "%(levelname)s %(asctime)s | %(message)s"
        date_format = "%m/%d %H:%M:%S"
        formatter_cls = ColorFormatter if colorize else logging.Formatter
        formatter = formatter_cls(log_format, date_format)

        # Stream handler bound to stdout (a bare StreamHandler uses stderr).
        stream_handler = logging.StreamHandler(sys.stdout)
        stream_handler.setFormatter(formatter)
        logger.addHandler(stream_handler)

        if file_path:
            file_handler = logging.FileHandler(file_path)
            file_handler.setFormatter(formatter)
            logger.addHandler(file_handler)

        logger.propagate = False
        return logger

    def nofmt(self, msg, *args, level="INFO", **kwargs):
        """Log `msg` once without any format decoration."""
        saved = self.remove_formats()
        super().log(levelize(level), msg, *args, **kwargs)
        self.set_formats(saved)

    def remove_formats(self):
        """Strip formatting from all handlers; return the removed formatters."""
        saved = [handler.formatter for handler in self.handlers]
        for handler in self.handlers:
            handler.setFormatter(logging.Formatter("%(message)s"))
        return saved

    def set_formats(self, formatters):
        """Restore formatters onto handlers, pairwise."""
        for handler, formatter in zip(self.handlers, formatters):
            handler.setFormatter(formatter)

    def set_file_handler(self, file_path):
        """Add a file handler mirroring the first handler's format."""
        handler = logging.FileHandler(file_path)
        handler.setFormatter(self.handlers[0].formatter)
        self.addHandler(handler)
102 |
--------------------------------------------------------------------------------
/domainbed/lib/misc.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 |
3 | """
4 | Things that don't belong anywhere else
5 | """
6 |
7 | import hashlib
8 | import sys
9 | import random
10 | import os
11 | import shutil
12 | import errno
13 | from itertools import chain
14 | from datetime import datetime
15 | from collections import Counter
16 | from typing import List
17 | from contextlib import contextmanager
18 | from subprocess import call
19 |
20 | import numpy as np
21 | import torch
22 | import torch.nn as nn
23 | import torch.nn.functional as F
24 |
25 |
def make_weights_for_balanced_classes(dataset):
    """Per-example sampling weights that equalize class frequencies.

    Each class gets total weight 1/n_classes, split evenly among its
    examples, so weighted sampling draws classes uniformly.
    """
    labels = [int(y) for _, y in dataset]
    counts = Counter(labels)
    n_classes = len(counts)

    weight_per_class = {y: 1 / (count * n_classes) for y, count in counts.items()}

    return torch.tensor(
        [weight_per_class[y] for y in labels], dtype=torch.float32
    )
45 |
46 |
def seed_hash(*args):
    """
    Derive an integer hash from all args, for use as a random seed.
    """
    digest = hashlib.md5(str(args).encode("utf-8")).hexdigest()
    return int(digest, 16) % (2 ** 31)
53 |
54 |
def to_row(row, colwidth=10, latex=False):
    """Convert a list of values to a fixed-width row string."""
    sep, end_ = (" & ", "\\\\") if latex else ("  ", "")

    def fmt(value):
        # Floats get 6 decimal places; everything is padded/truncated to width.
        if np.issubdtype(type(value), np.floating):
            value = "{:.6f}".format(value)
        return str(value).ljust(colwidth)[:colwidth]

    return sep.join(fmt(value) for value in row) + " " + end_
70 |
71 |
def random_pairs_of_minibatches(minibatches):
    """Pair each (randomly permuted) minibatch with its cyclic successor.

    Both members of a pair are truncated to the shorter batch length.
    """
    perm = torch.randperm(len(minibatches)).tolist()
    pairs = []

    for i in range(len(minibatches)):
        j = (i + 1) % len(minibatches)  # cyclic next index

        xi, yi = minibatches[perm[i]][0], minibatches[perm[i]][1]
        xj, yj = minibatches[perm[j]][0], minibatches[perm[j]][1]

        n = min(len(xi), len(xj))
        pairs.append(((xi[:n], yi[:n]), (xj[:n], yj[:n])))

    return pairs
89 |
90 |
91 | ###########################################################
92 | # Custom utils
93 | ###########################################################
94 |
95 |
def index_conditional_iterate(skip_condition, iterable, index):
    """Yield items (optionally with index) whose position fails skip_condition."""
    for i, item in enumerate(iterable):
        if skip_condition(i):
            continue
        yield (i, item) if index else item
105 |
106 |
class SplitIterator:
    """Iterate environments, skipping either the test or the train indices."""

    def __init__(self, test_envs):
        self.test_envs = test_envs

    def train(self, iterable, index=False):
        """Yield items whose index is NOT a test environment."""
        skip = lambda i: i in self.test_envs
        return index_conditional_iterate(skip, iterable, index)

    def test(self, iterable, index=False):
        """Yield items whose index IS a test environment."""
        skip = lambda i: i not in self.test_envs
        return index_conditional_iterate(skip, iterable, index)
116 |
117 |
class AverageMeter():
    """ Computes and stores the average and current value """
    def __init__(self):
        self.reset()

    def reset(self):
        """ Zero all running statistics """
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0

    def update(self, val, n=1):
        """ Record `val` observed `n` times """
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __repr__(self):
        return "{:.3f} (val={:.3f}, count={})".format(self.avg, self.val, self.count)
139 |
140 |
class AverageMeters():
    """A named collection of AverageMeter objects."""

    def __init__(self, *keys):
        self.keys = keys
        for key in keys:
            setattr(self, key, AverageMeter())

    def resets(self):
        """Reset every meter."""
        for key in self.keys:
            getattr(self, key).reset()

    def updates(self, dic, n=1):
        """Update the meters named by the keys of `dic` with its values."""
        for key, value in dic.items():
            getattr(self, key).update(value, n)

    def __repr__(self):
        return " ".join("{}: {}".format(key, getattr(self, key)) for key in self.keys)

    def get_averages(self):
        """Return {key: running average}."""
        return {key: getattr(self, key).avg for key in self.keys}
161 |
162 |
def timestamp(fmt="%y%m%d_%H-%M-%S"):
    """Current local time formatted as a filesystem-friendly string."""
    return datetime.now().strftime(fmt)
165 |
166 |
def makedirs(path):
    """Create `path` (and parents) if missing; a no-op if it already exists.

    `exist_ok=True` covers both the pre-check and the EEXIST race that the
    previous try/except handled; any other OSError still propagates.
    """
    os.makedirs(path, exist_ok=True)
174 |
175 |
def rm(path):
    """Remove a file, or a directory tree, if it exists (best effort)."""
    if os.path.isdir(path):
        shutil.rmtree(path, ignore_errors=True)
        return
    if os.path.exists(path):
        os.remove(path)
182 |
183 |
def cp(src, dst):
    """Copy a file; copy2 also preserves metadata (mtime etc.)."""
    shutil.copy2(src, dst)
186 |
187 |
def set_seed(seed):
    """Seed the python, numpy and torch RNGs; keep cudnn autotuner enabled."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # benchmark=True: faster convolutions, at the cost of bitwise determinism.
    torch.backends.cudnn.benchmark = True
195 |
196 |
def get_lr(optimizer):
    """Assume that the optimizer has single lr"""
    return optimizer.param_groups[0]["lr"]
202 |
203 |
def entropy(logits):
    """Mean Shannon entropy (nats) of the softmax distributions in `logits`."""
    probs = F.softmax(logits, -1)
    log_probs = F.log_softmax(logits, -1)
    per_example = -(probs * log_probs).sum(1)  # batch-wise
    return per_example.mean()
208 |
209 |
@torch.no_grad()
def hash_bn(module):
    """Fingerprint BN layers: (mean of weight/bias means, mean of running stats).

    Returns (0.0, 0.0) when the module contains no batch-norm layers.
    """
    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
    summary = [
        (
            m.weight.detach().mean().item(),
            m.bias.detach().mean().item(),
            m.running_mean.detach().mean().item(),
            m.running_var.detach().mean().item(),
        )
        for m in module.modules()
        if isinstance(m, bn_types)
    ]

    if not summary:
        return 0., 0.

    w, b, rm, rv = (np.mean(col) for col in zip(*summary))
    return np.mean([w, b]), np.mean([rm, rv])
229 |
230 |
@torch.no_grad()
def hash_params(module):
    """Cheap fingerprint of parameter values: mean of per-parameter means."""
    means = [p.mean() for p in module.parameters()]
    return torch.as_tensor(means).mean().item()
234 |
235 |
@torch.no_grad()
def hash_module(module):
    """Return (parameter fingerprint, BN running-stat fingerprint)."""
    params_hash = hash_params(module)
    _, bn_stats = hash_bn(module)
    return params_hash, bn_stats
242 |
243 |
def merge_dictlist(dictlist):
    """Merge list of dicts into dict of lists, by grouping same key."""
    merged = {key: [] for key in dictlist[0]}
    for dic in dictlist:
        for key, value in dic.items():
            merged[key].append(value)
    return merged
255 |
256 |
def zip_strict(*iterables):
    """strict version of zip. The length of iterables should be same.

    NOTE yield looks non-reachable, but they are required.
    """
    # For trivial cases, use pure zip.
    if len(iterables) < 2:
        return zip(*iterables)

    # Tail for the first iterable
    # The dead `yield` makes first_tail a zero-item generator; its only
    # effect is flipping the flag when the first iterable is exhausted.
    first_stopped = False
    def first_tail():
        nonlocal first_stopped
        first_stopped = True
        return
        yield

    # Tail for the zip
    # Runs after zip stops: if the first iterable never finished it was
    # longer; if any other iterable still has items it was shorter.
    def zip_tail():
        if not first_stopped:
            raise ValueError('zip_equal: first iterable is longer')
        for _ in chain.from_iterable(rest):
            raise ValueError('zip_equal: first iterable is shorter')
        yield

    # Put the pieces together
    iterables = iter(iterables)
    first = chain(next(iterables), first_tail())
    rest = list(map(iter, iterables))
    return chain(zip(first, *rest), zip_tail())
287 |
288 |
def freeze_(module):
    """Disable gradients for all parameters and switch to eval mode."""
    for param in module.parameters():
        param.requires_grad_(False)
    module.eval()
293 |
294 |
def unfreeze_(module):
    """Enable gradients for all parameters and switch to train mode."""
    for param in module.parameters():
        param.requires_grad_(True)
    module.train()
299 |
--------------------------------------------------------------------------------
/domainbed/lib/query.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 |
3 | """Small query library."""
4 |
5 | import inspect
6 | import json
7 | import types
8 | import warnings
9 |
10 | import numpy as np
11 |
12 |
def make_selector_fn(selector):
    """
    If selector is callable, return it unchanged.
    Otherwise, return a function corresponding to the selector string. Examples
    of valid selector strings and the corresponding functions:
        x       lambda obj: obj['x']
        x.y     lambda obj: obj['x']['y']
        x,y     lambda obj: (obj['x'], obj['y'])

    Raises:
        TypeError: if selector is neither a string nor callable.
    """
    if isinstance(selector, str):
        if "," in selector:
            # Comma-separated: a tuple of sub-selectors.
            parts = selector.split(",")
            part_selectors = [make_selector_fn(part) for part in parts]
            return lambda obj: tuple(sel(obj) for sel in part_selectors)
        elif "." in selector:
            # Dot-separated: nested key access, applied left to right.
            parts = selector.split(".")
            part_selectors = [make_selector_fn(part) for part in parts]

            def f(obj):
                for sel in part_selectors:
                    obj = sel(obj)
                return obj

            return f
        else:
            key = selector.strip()
            return lambda obj: obj[key]
    elif callable(selector):
        # Generalized from the previous `types.FunctionType` check, which
        # wrongly rejected builtins, functools.partial and bound methods.
        return selector
    else:
        raise TypeError(
            "selector must be a str or a callable, got {!r}".format(type(selector))
        )
44 |
45 |
def hashable(obj):
    """Return obj if it is hashable, else a canonical JSON string standing in for it."""
    try:
        hash(obj)
    except TypeError:
        # Unhashable (lists, dicts, ...): fall back to a deterministic string key.
        return json.dumps({"_": obj}, sort_keys=True)
    return obj
52 |
53 |
class Q(object):
    """A thin list wrapper exposing chainable query operations.

    Most methods return a new Q, so calls compose: q.filter(...).select(...).mean().
    """

    def __init__(self, list_):
        super(Q, self).__init__()
        self._list = list_

    def __len__(self):
        return len(self._list)

    def __getitem__(self, key):
        return self._list[key]

    def __eq__(self, other):
        # A Q compares equal to another Q or to a plain list with equal content.
        if isinstance(other, self.__class__):
            other = other._list
        return self._list == other

    def __str__(self):
        return str(self._list)

    def __repr__(self):
        return repr(self._list)

    def _append(self, item):
        """Unsafe, be careful you know what you're doing."""
        self._list.append(item)

    def group(self, selector):
        """
        Group elements by selector and return a Q of (group, group_records)
        tuples, ordered by the hashable form of the group key.
        """
        selector = make_selector_fn(selector)
        groups = {}
        for record in self._list:
            group = selector(record)
            key = hashable(group)
            if key not in groups:
                groups[key] = (group, Q([]))
            groups[key][1]._append(record)
        return Q([groups[key] for key in sorted(groups.keys())])

    def group_map(self, selector, fn):
        """Group elements by selector, then apply fn to each (group, records) pair."""
        return self.group(selector).map(fn)

    def map(self, fn):
        """Map fn over the elements; tuple elements are unpacked if fn takes >1 arg."""
        if len(inspect.signature(fn).parameters) > 1:
            return Q([fn(*item) for item in self._list])
        return Q([fn(item) for item in self._list])

    def select(self, selector):
        """Project each element through selector (see make_selector_fn)."""
        selector = make_selector_fn(selector)
        return Q([selector(item) for item in self._list])

    def min(self):
        return min(self._list)

    def max(self):
        return max(self._list)

    def sum(self):
        return sum(self._list)

    def len(self):
        return len(self._list)

    def mean(self):
        # numpy warns on empty / NaN input; silence it to keep logs clean.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            return float(np.mean(self._list))

    def std(self):
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            return float(np.std(self._list))

    def mean_std(self):
        return (self.mean(), self.std())

    def argmax(self, selector):
        """Return the element maximizing selector (not its index)."""
        selector = make_selector_fn(selector)
        return max(self._list, key=selector)

    def filter(self, fn):
        return Q([item for item in self._list if fn(item)])

    def filter_equals(self, selector, value):
        """like [x for x in y if x.selector == value]"""
        selector = make_selector_fn(selector)
        return self.filter(lambda record: selector(record) == value)

    def filter_not_none(self):
        return self.filter(lambda record: record is not None)

    def filter_not_nan(self):
        return self.filter(lambda record: not np.isnan(record))

    def flatten(self):
        return Q([item for sub in self._list for item in sub])

    def unique(self):
        """Keep the first occurrence of each (hashable form of an) element."""
        seen = set()
        uniques = []
        for item in self._list:
            key = hashable(item)
            if key not in seen:
                seen.add(key)
                uniques.append(item)
        return Q(uniques)

    def sorted(self, key=None, reverse=False):
        if key is None:
            key = lambda x: x

        def nan_safe_key(item):
            # NaNs are not orderable; map them to -inf so sorting never fails.
            value = key(item)
            if isinstance(value, (np.floating, float)) and np.isnan(value):
                return float("-inf")
            return value

        return Q(sorted(self._list, key=nan_safe_key, reverse=reverse))
186 |
--------------------------------------------------------------------------------
/domainbed/lib/swa_utils.py:
--------------------------------------------------------------------------------
1 | # Modified from https://github.com/pytorch/pytorch/blob/master/torch/optim/swa_utils.py
2 | import copy
3 | import warnings
4 | import math
5 | from copy import deepcopy
6 |
7 | import torch
8 | from torch.nn import Module
9 | from torch.optim.lr_scheduler import _LRScheduler
10 |
11 | from domainbed.networks.ur_networks import URResNet
12 |
13 |
class AveragedModel(Module):
    """Maintains a running average of a model's parameters (SWA-style).

    Wraps a deepcopy of the given model and moves its parameters toward each
    newly observed parameter set via `avg_fn` (default: incremental mean).
    `start_step`/`end_step` record the training-step span the average covers.
    """

    def filter(self, model):
        # Extract the plain network that should actually be averaged.
        if isinstance(model, AveragedModel):
            # prevent nested averagedmodel
            model = model.module

        if hasattr(model, "get_forward_model"):
            model = model.get_forward_model()
        # URERM models use URNetwork, which manages features internally.
        for m in model.modules():
            if isinstance(m, URResNet):
                m.clear_features()

        return model

    def __init__(self, model, device=None, avg_fn=None, rm_optimizer=False):
        """
        Args:
            model: model to average; it is deep-copied, the original is untouched.
            device: optional device the averaged copy is moved to.
            avg_fn: callable (avg_param, new_param, num_averaged) -> new avg_param.
                Defaults to the incremental running mean.
            rm_optimizer: if True, drop any optimizer attributes found on the
                copied model (their state is useless on the averaged copy).
        """
        super(AveragedModel, self).__init__()
        # Step bookkeeping: -1 means "no update recorded yet".
        self.start_step = -1
        self.end_step = -1
        model = self.filter(model)
        self.module = deepcopy(model)
        self.module.zero_grad()
        if rm_optimizer:
            for k, v in vars(self.module).items():
                if isinstance(v, torch.optim.Optimizer):
                    setattr(self.module, k, None)
                    # print(f"{k} -> {getattr(self.module, k)}")
        if device is not None:
            self.module = self.module.to(device)
        # Registered as a buffer so it is saved/restored with the state dict.
        self.register_buffer('n_averaged', torch.tensor(0, dtype=torch.long, device=device))
        if avg_fn is None:
            # Incremental mean: avg += (new - avg) / (n + 1)
            def avg_fn(averaged_model_parameter, model_parameter, num_averaged):
                return averaged_model_parameter + \
                    (model_parameter - averaged_model_parameter) / (num_averaged + 1)
        self.avg_fn = avg_fn

    def forward(self, *args, **kwargs):
        # return self.predict(*args, **kwargs)
        return self.module(*args, **kwargs)

    def predict(self, *args, **kwargs):
        return self.module.predict(*args, **kwargs)

    @property
    def network(self):
        # Expose the wrapped model's network for callers expecting that attribute.
        return self.module.network

    def update_parameters(self, model, step=None, start_step=None, end_step=None):
        """Fold `model`'s current parameters into the running average.

        Args:
            model: source model (filtered like in __init__).
            step: current training step; used for both start/end when the
                explicit arguments are not given.
            start_step / end_step: explicit span bounds for bookkeeping.
        """
        model = self.filter(model)
        for p_swa, p_model in zip(self.parameters(), model.parameters()):
            device = p_swa.device
            p_model_ = p_model.detach().to(device)
            if self.n_averaged == 0:
                # First update: just copy the parameters.
                p_swa.detach().copy_(p_model_)
            else:
                p_swa.detach().copy_(self.avg_fn(p_swa.detach(), p_model_,
                                                 self.n_averaged.to(device)))
        self.n_averaged += 1

        if step is not None:
            if start_step is None:
                start_step = step
            if end_step is None:
                end_step = step

        if start_step is not None:
            # Only the very first update fixes the span's start.
            if self.n_averaged == 1:
                self.start_step = start_step

        if end_step is not None:
            self.end_step = end_step

    def clone(self):
        # Deep-copy the averaged model and give it a fresh optimizer.
        clone = copy.deepcopy(self.module)
        clone.optimizer = clone.new_optimizer(clone.network.parameters())
        return clone
90 |
91 |
@torch.no_grad()
def update_bn(iterator, model, n_steps, device='cuda'):
    """Re-estimate BatchNorm running statistics by streaming `n_steps` batches.

    Resets every BN layer's running mean/var, switches it to cumulative
    averaging (momentum=None), forwards `n_steps` batches drawn from
    `iterator`, then restores the original momenta and training flag.
    """
    momenta = {}
    for submodule in model.modules():
        if isinstance(submodule, torch.nn.modules.batchnorm._BatchNorm):
            # Wipe the old statistics so they are rebuilt from scratch.
            submodule.running_mean = torch.zeros_like(submodule.running_mean)
            submodule.running_var = torch.ones_like(submodule.running_var)
            momenta[submodule] = submodule.momentum

    if not momenta:
        # No BatchNorm layers: nothing to update.
        return

    was_training = model.training
    model.train()
    for bn in momenta.keys():
        # momentum=None makes BN accumulate a cumulative moving average.
        bn.momentum = None
        bn.num_batches_tracked *= 0

    for _ in range(n_steps):
        # batches_dictlist: [{env0_data_key: tensor, env0_...}, env1_..., ...]
        batches_dictlist = next(iterator)
        inputs = torch.cat([dic["x"] for dic in batches_dictlist])
        model(inputs.to(device))

    for bn, momentum in momenta.items():
        bn.momentum = momentum
    model.train(was_training)
124 |
125 |
class SWALR(_LRScheduler):
    r"""Anneals the learning rate in each parameter group to a fixed value.
    This learning rate scheduler is meant to be used with Stochastic Weight
    Averaging (SWA) method (see `torch.optim.swa_utils.AveragedModel`).
    Arguments:
        optimizer (torch.optim.Optimizer): wrapped optimizer
        swa_lrs (float or list): the learning rate value for all param groups
            together or separately for each group.
        annealing_epochs (int): number of epochs in the annealing phase
            (default: 10)
        annealing_strategy (str): "cos" or "linear"; specifies the annealing
            strategy: "cos" for cosine annealing, "linear" for linear annealing
            (default: "cos")
        last_epoch (int): the index of the last epoch (default: -1)
    The :class:`SWALR` scheduler can be used together with other
    schedulers to switch to a constant learning rate late in the training
    as in the example below.
    Example:
        >>> loader, optimizer, model = ...
        >>> lr_lambda = lambda epoch: 0.9
        >>> scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer,
        >>>        lr_lambda=lr_lambda)
        >>> swa_scheduler = torch.optim.swa_utils.SWALR(optimizer,
        >>>        anneal_strategy="linear", anneal_epochs=20, swa_lr=0.05)
        >>> swa_start = 160
        >>> for i in range(300):
        >>>      for input, target in loader:
        >>>          optimizer.zero_grad()
        >>>          loss_fn(model(input), target).backward()
        >>>          optimizer.step()
        >>>      if i > swa_start:
        >>>          swa_scheduler.step()
        >>>      else:
        >>>          scheduler.step()
    .. _Averaging Weights Leads to Wider Optima and Better Generalization:
        https://arxiv.org/abs/1803.05407
    """
    def __init__(self, optimizer, swa_lr, anneal_epochs=10, anneal_strategy='cos', last_epoch=-1):
        # Record the target SWA learning rate on each param group.
        swa_lrs = self._format_param(optimizer, swa_lr)
        for swa_lr, group in zip(swa_lrs, optimizer.param_groups):
            group['swa_lr'] = swa_lr
        if anneal_strategy not in ['cos', 'linear']:
            raise ValueError("anneal_strategy must by one of 'cos' or 'linear', "
                             "instead got {}".format(anneal_strategy))
        elif anneal_strategy == 'cos':
            self.anneal_func = self._cosine_anneal
        elif anneal_strategy == 'linear':
            self.anneal_func = self._linear_anneal
        if not isinstance(anneal_epochs, int) or anneal_epochs < 1:
            raise ValueError("anneal_epochs must be a positive integer, got {}".format(
                anneal_epochs))
        self.anneal_epochs = anneal_epochs

        super(SWALR, self).__init__(optimizer, last_epoch)

    @staticmethod
    def _format_param(optimizer, swa_lrs):
        # Normalize swa_lr into one value per param group.
        if isinstance(swa_lrs, (list, tuple)):
            if len(swa_lrs) != len(optimizer.param_groups):
                raise ValueError("swa_lr must have the same length as "
                                 "optimizer.param_groups: swa_lr has {}, "
                                 "optimizer.param_groups has {}".format(
                                     len(swa_lrs), len(optimizer.param_groups)))
            return swa_lrs
        else:
            return [swa_lrs] * len(optimizer.param_groups)

    @staticmethod
    def _linear_anneal(t):
        # Mixing factor grows linearly from 0 to 1 over the annealing phase.
        return t

    @staticmethod
    def _cosine_anneal(t):
        # Cosine ramp from 0 (t=0) to 1 (t=1).
        return (1 - math.cos(math.pi * t)) / 2

    @staticmethod
    def _get_initial_lr(lr, swa_lr, alpha):
        # Invert lr = swa_lr * alpha + lr0 * (1 - alpha) to recover lr0.
        if alpha == 1:
            return swa_lr
        return (lr - alpha * swa_lr) / (1 - alpha)

    def get_lr(self):
        # `_get_lr_called_within_step` is set by the base class during .step().
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)
        step = self._step_count - 1
        # Recover each group's pre-annealing lr from the previous step's
        # annealed value ...
        prev_t = max(0, min(1, (step - 1) / self.anneal_epochs))
        prev_alpha = self.anneal_func(prev_t)
        prev_lrs = [self._get_initial_lr(group['lr'], group['swa_lr'], prev_alpha)
                    for group in self.optimizer.param_groups]
        # ... then blend it toward swa_lr using the current mixing factor.
        t = max(0, min(1, step / self.anneal_epochs))
        alpha = self.anneal_func(t)
        return [group['swa_lr'] * alpha + lr * (1 - alpha)
                for group, lr in zip(self.optimizer.param_groups, prev_lrs)]
220 |
--------------------------------------------------------------------------------
/domainbed/lib/wide_resnet.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 |
3 | """
4 | From https://github.com/meliketoy/wide-resnet.pytorch
5 | """
6 |
7 | import sys
8 |
9 | import numpy as np
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 | import torch.nn.init as init
14 | from torch.autograd import Variable
15 |
16 |
def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution, padding 1, with bias (wide-resnet convention)."""
    return nn.Conv2d(
        in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True
    )
19 |
20 |
def conv_init(m):
    """Weight-init hook: Xavier for conv layers, unit scale / zero bias for BN."""
    name = m.__class__.__name__
    if "Conv" in name:
        init.xavier_uniform_(m.weight, gain=np.sqrt(2))
        init.constant_(m.bias, 0)
    elif "BatchNorm" in name:
        init.constant_(m.weight, 1)
        init.constant_(m.bias, 0)
29 |
30 |
class wide_basic(nn.Module):
    """Pre-activation wide-resnet basic block: (BN-ReLU-Conv) x 2 plus shortcut."""

    def __init__(self, in_planes, planes, dropout_rate, stride=1):
        super(wide_basic, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, bias=True)
        self.dropout = nn.Dropout(p=dropout_rate)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True)

        # Projection shortcut only when the spatial size or channel count changes.
        if stride != 1 or in_planes != planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=True),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        h = self.conv1(F.relu(self.bn1(x)))
        h = self.dropout(h)
        h = self.conv2(F.relu(self.bn2(h)))
        return h + self.shortcut(x)
52 |
53 |
class Wide_ResNet(nn.Module):
    """Wide Resnet with the softmax layer chopped off.

    Produces a flat feature vector of size ``n_outputs`` (= 64 * widen_factor).
    """

    def __init__(self, input_shape, depth, widen_factor, dropout_rate):
        super(Wide_ResNet, self).__init__()
        self.in_planes = 16

        assert (depth - 4) % 6 == 0, "Wide-resnet depth should be 6n+4"
        blocks_per_stage = (depth - 4) // 6
        k = widen_factor

        # Stage widths: stem, then three widened stages.
        widths = [16, 16 * k, 32 * k, 64 * k]

        self.conv1 = conv3x3(input_shape[0], widths[0])
        self.layer1 = self._wide_layer(wide_basic, widths[1], blocks_per_stage, dropout_rate, stride=1)
        self.layer2 = self._wide_layer(wide_basic, widths[2], blocks_per_stage, dropout_rate, stride=2)
        self.layer3 = self._wide_layer(wide_basic, widths[3], blocks_per_stage, dropout_rate, stride=2)
        self.bn1 = nn.BatchNorm2d(widths[3], momentum=0.9)

        self.n_outputs = widths[3]

    def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride):
        # Only the first block of a stage may downsample.
        blocks = []
        for s in [stride] + [1] * (int(num_blocks) - 1):
            blocks.append(block(self.in_planes, planes, dropout_rate, s))
            self.in_planes = planes
        return nn.Sequential(*blocks)

    def forward(self, x):
        h = self.conv1(x)
        h = self.layer3(self.layer2(self.layer1(h)))
        h = F.relu(self.bn1(h))
        # 8x8 average pool, then drop the trailing 1x1 spatial dims.
        h = F.avg_pool2d(h, 8)
        return h[:, :, 0, 0]
94 |
--------------------------------------------------------------------------------
/domainbed/lib/writers.py:
--------------------------------------------------------------------------------
class Writer:
    """Abstract scalar-logging interface."""

    def add_scalars(self, tag_scalar_dic, global_step):
        """Log a dict of {tag: scalar} at `global_step`. Must be overridden."""
        raise NotImplementedError()

    def add_scalars_with_prefix(self, tag_scalar_dic, global_step, prefix):
        """Same as add_scalars, with `prefix` prepended to every tag."""
        prefixed = {prefix + tag: value for tag, value in tag_scalar_dic.items()}
        self.add_scalars(prefixed, global_step)
8 |
9 |
class TBWriter(Writer):
    """Writer backed by a tensorboardX SummaryWriter."""

    def __init__(self, dir_path):
        # Imported lazily so tensorboardX is only required when TB logging is used.
        from tensorboardX import SummaryWriter

        self.writer = SummaryWriter(dir_path, flush_secs=30)

    def add_scalars(self, tag_scalar_dic, global_step):
        for tag, value in tag_scalar_dic.items():
            self.writer.add_scalar(tag, value, global_step)
19 |
20 |
def get_writer(dir_path):
    """
    Args:
        dir_path: tb dir
    """
    # Tensorboard is currently the only backend.
    return TBWriter(dir_path)
29 |
--------------------------------------------------------------------------------
/domainbed/models/mixstyle.py:
--------------------------------------------------------------------------------
1 | """
2 | https://github.com/KaiyangZhou/mixstyle-release/blob/master/imcls/models/mixstyle.py
3 | """
4 | import random
5 | import torch
6 | import torch.nn as nn
7 |
8 |
class MixStyle(nn.Module):
    """MixStyle: mixes per-instance feature statistics across a shuffled batch.
    Reference:
      Zhou et al. Domain Generalization with MixStyle. ICLR 2021.
    """

    def __init__(self, p=0.5, alpha=0.3, eps=1e-6):
        """
        Args:
            p (float): probability of using MixStyle.
            alpha (float): parameter of the Beta distribution.
            eps (float): scaling parameter to avoid numerical issues.
        """
        super().__init__()
        self.p = p
        self.beta = torch.distributions.Beta(alpha, alpha)
        self.eps = eps
        self.alpha = alpha

        print("* MixStyle params")
        print(f"- p: {p}")
        print(f"- alpha: {alpha}")

    def __repr__(self):
        return f"MixStyle(p={self.p}, alpha={self.alpha}, eps={self.eps})"

    def forward(self, x):
        # Only active during training, and only with probability p.
        if not self.training:
            return x
        if random.random() > self.p:
            return x

        batch = x.size(0)

        # Per-instance channel statistics; detached so style mixing is not
        # back-propagated through the statistics themselves.
        mu = x.mean(dim=[2, 3], keepdim=True)
        sig = (x.var(dim=[2, 3], keepdim=True) + self.eps).sqrt()
        mu, sig = mu.detach(), sig.detach()
        normalized = (x - mu) / sig

        lmda = self.beta.sample((batch, 1, 1, 1)).to(x.device)

        # Mix each instance's statistics with those of a random other instance.
        perm = torch.randperm(batch)
        mu_mix = mu * lmda + mu[perm] * (1 - lmda)
        sig_mix = sig * lmda + sig[perm] * (1 - lmda)

        return normalized * sig_mix + mu_mix
59 |
60 |
class MixStyle2(nn.Module):
    """MixStyle (w/ domain prior).
    The input should contain two equal-sized mini-batches from two distinct domains.
    Reference:
      Zhou et al. Domain Generalization with MixStyle. ICLR 2021.
    """

    def __init__(self, p=0.5, alpha=0.3, eps=1e-6):
        """
        Args:
            p (float): probability of using MixStyle.
            alpha (float): parameter of the Beta distribution.
            eps (float): scaling parameter to avoid numerical issues.
        """
        super().__init__()
        self.p = p
        self.beta = torch.distributions.Beta(alpha, alpha)
        self.eps = eps
        self.alpha = alpha

        print("* MixStyle params")
        print(f"- p: {p}")
        print(f"- alpha: {alpha}")

    def __repr__(self):
        # Bug fix: previously reported "MixStyle(...)" (copy-paste from MixStyle),
        # which made the two modules indistinguishable in printed model summaries.
        return f"MixStyle2(p={self.p}, alpha={self.alpha}, eps={self.eps})"

    def forward(self, x):
        """
        For the input x, the first half comes from one domain,
        while the second half comes from the other domain.
        """
        if not self.training:
            return x

        if random.random() > self.p:
            return x

        B = x.size(0)

        # Per-instance channel statistics; detached so the mixing is not
        # back-propagated through the statistics themselves.
        mu = x.mean(dim=[2, 3], keepdim=True)
        var = x.var(dim=[2, 3], keepdim=True)
        sig = (var + self.eps).sqrt()
        mu, sig = mu.detach(), sig.detach()
        x_normed = (x - mu) / sig

        lmda = self.beta.sample((B, 1, 1, 1))
        lmda = lmda.to(x.device)

        # Cross-domain pairing: reverse the batch so each first-half sample is
        # paired with a second-half sample, then shuffle within each half.
        perm = torch.arange(B - 1, -1, -1)  # inverse index
        perm_b, perm_a = perm.chunk(2)
        perm_b = perm_b[torch.randperm(B // 2)]
        perm_a = perm_a[torch.randperm(B // 2)]
        perm = torch.cat([perm_b, perm_a], 0)

        mu2, sig2 = mu[perm], sig[perm]
        mu_mix = mu * lmda + mu2 * (1 - lmda)
        sig_mix = sig * lmda + sig2 * (1 - lmda)

        return x_normed * sig_mix + mu_mix
121 |
--------------------------------------------------------------------------------
/domainbed/models/resnet_mixstyle.py:
--------------------------------------------------------------------------------
1 | """MixStyle w/ random shuffle
2 | https://github.com/KaiyangZhou/mixstyle-release/blob/master/imcls/models/resnet_mixstyle.py
3 | """
4 | import torch
5 | import torch.nn as nn
6 | import torch.utils.model_zoo as model_zoo
7 |
8 | from .mixstyle import MixStyle
9 |
# Torchvision pretrained-weight URLs, keyed by architecture name.
model_urls = {
    "resnet18": "https://download.pytorch.org/models/resnet18-5c106cde.pth",
    "resnet34": "https://download.pytorch.org/models/resnet34-333f7ec4.pth",
    "resnet50": "https://download.pytorch.org/models/resnet50-19c8e357.pth",
    "resnet101": "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth",
    "resnet152": "https://download.pytorch.org/models/resnet152-b121ed2d.pth",
}
17 |
18 |
def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,  # BN follows, so the bias would be redundant
    )
22 |
23 |
class BasicBlock(nn.Module):
    """Standard two-conv residual block (ResNet-18/34 style)."""

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Identity skip, projected when shapes differ.
        identity = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        return self.relu(out + identity)
54 |
55 |
class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual bottleneck block (ResNet-50/101/152 style)."""

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Identity skip, projected when shapes differ.
        identity = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        return self.relu(out + identity)
92 |
93 |
class ResNet(nn.Module):
    """ResNet backbone with MixStyle optionally inserted after chosen stages.

    Returns a flat feature vector; `self.fc` is an Identity for DomainBed.
    """

    def __init__(
        self, block, layers, mixstyle_layers=[], mixstyle_p=0.5, mixstyle_alpha=0.3, **kwargs
    ):
        self.inplanes = 64
        super().__init__()

        # backbone network
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.global_avgpool = nn.AdaptiveAvgPool2d(1)

        self.mixstyle = None
        if mixstyle_layers:
            self.mixstyle = MixStyle(p=mixstyle_p, alpha=mixstyle_alpha)
            for layer_name in mixstyle_layers:
                assert layer_name in ["conv1", "conv2_x", "conv3_x", "conv4_x", "conv5_x"]
            print("Insert MixStyle after the following layers: {}".format(mixstyle_layers))
        self.mixstyle_layers = mixstyle_layers

        self._out_features = 512 * block.expansion
        self.fc = nn.Identity()  # for DomainBed compatibility

        self._init_params()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            # Projection shortcut so the residual's shape matches.
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(planes * block.expansion),
            )

        stage = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        stage.extend(block(self.inplanes, planes) for _ in range(1, blocks))
        return nn.Sequential(*stage)

    def _init_params(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def compute_style(self, x):
        """Concatenate per-channel mean and std: the 'style' vector of x."""
        return torch.cat([x.mean(dim=[2, 3]), x.std(dim=[2, 3])], 1)

    def featuremaps(self, x):
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        # Stage names follow the conv2_x..conv5_x convention of the paper;
        # MixStyle is applied right after each requested stage.
        for stage, tag in (
            (self.layer1, "conv2_x"),
            (self.layer2, "conv3_x"),
            (self.layer3, "conv4_x"),
            (self.layer4, "conv5_x"),
        ):
            x = stage(x)
            if tag in self.mixstyle_layers:
                x = self.mixstyle(x)

        return x

    def forward(self, x):
        feats = self.global_avgpool(self.featuremaps(x))
        return feats.view(feats.size(0), -1)
197 |
198 |
def init_pretrained_weights(model, model_url):
    """Load pretrained weights from `model_url`, ignoring keys absent in `model`."""
    state_dict = model_zoo.load_url(model_url)
    model.load_state_dict(state_dict, strict=False)
202 |
203 |
def resnet18_mixstyle_L234_p0d5_a0d1(pretrained=True, **kwargs):
    """ResNet-18 with MixStyle (p=0.5, alpha=0.1) after stages conv2_x..conv4_x."""
    model = ResNet(
        block=BasicBlock,
        layers=[2, 2, 2, 2],
        mixstyle_layers=["conv2_x", "conv3_x", "conv4_x"],
        mixstyle_p=0.5,
        mixstyle_alpha=0.1,
    )
    if pretrained:
        init_pretrained_weights(model, model_urls["resnet18"])
    return model
217 |
218 |
def resnet50_mixstyle_L234_p0d5_a0d1(pretrained=True, **kwargs):
    """ResNet-50 with MixStyle (p=0.5, alpha=0.1) after stages conv2_x..conv4_x."""
    model = ResNet(
        block=Bottleneck,
        layers=[3, 4, 6, 3],
        mixstyle_layers=["conv2_x", "conv3_x", "conv4_x"],
        mixstyle_p=0.5,
        mixstyle_alpha=0.1,
    )
    if pretrained:
        init_pretrained_weights(model, model_urls["resnet50"])
    return model
232 |
--------------------------------------------------------------------------------
/domainbed/models/resnet_mixstyle2.py:
--------------------------------------------------------------------------------
1 | """MixStyle w/ domain label
2 | https://github.com/KaiyangZhou/mixstyle-release/blob/master/imcls/models/resnet_mixstyle2.py
3 | """
4 | import random
5 | import torch
6 | import torch.nn as nn
7 | import torch.utils.model_zoo as model_zoo
8 |
9 | from .mixstyle import MixStyle2 as MixStyle
10 |
# Torchvision pretrained-weight URLs, keyed by architecture name.
model_urls = {
    "resnet18": "https://download.pytorch.org/models/resnet18-5c106cde.pth",
    "resnet34": "https://download.pytorch.org/models/resnet34-333f7ec4.pth",
    "resnet50": "https://download.pytorch.org/models/resnet50-19c8e357.pth",
    "resnet101": "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth",
    "resnet152": "https://download.pytorch.org/models/resnet152-b121ed2d.pth",
}
18 |
19 |
def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    # bias=False: each conv is followed by BatchNorm, making a bias redundant.
    return nn.Conv2d(in_planes, out_planes,
                     kernel_size=3, stride=stride, padding=1, bias=False)
23 |
24 |
class BasicBlock(nn.Module):
    """Two 3x3 convs plus identity (or projected) skip connection."""

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        out = self.bn1(self.conv1(x))
        out = self.relu(out)
        out = self.bn2(self.conv2(out))

        # Project the skip path only when a downsample module was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)
        out = out + shortcut
        return self.relu(out)
55 |
56 |
class Bottleneck(nn.Module):
    """Bottleneck residual block: 1x1 reduce, 3x3, 1x1 expand (x4)."""

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        # Project the skip path only when a downsample module was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)
        return self.relu(out + shortcut)
93 |
94 |
class ResNet(nn.Module):
    """ResNet backbone with optional MixStyle modules inserted between stages.

    The classifier head is an identity (`self.fc`), so the module returns
    pooled features of size ``512 * block.expansion``.
    """

    def __init__(
        self, block, layers, mixstyle_layers=[], mixstyle_p=0.5, mixstyle_alpha=0.3, **kwargs
    ):
        self.inplanes = 64
        super().__init__()

        # Backbone: stem followed by four residual stages.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.global_avgpool = nn.AdaptiveAvgPool2d(1)

        self.mixstyle = None
        if mixstyle_layers:
            self.mixstyle = MixStyle(p=mixstyle_p, alpha=mixstyle_alpha)
            for layer_name in mixstyle_layers:
                assert layer_name in ["conv1", "conv2_x", "conv3_x", "conv4_x", "conv5_x"]
            print("Insert MixStyle after the following layers: {}".format(mixstyle_layers))
        self.mixstyle_layers = mixstyle_layers

        self._out_features = 512 * block.expansion
        self.fc = nn.Identity()  # for DomainBed compatibility

        self._init_params()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack `blocks` residual units, projecting the identity when needed."""
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            # 1x1 projection to match channels/stride on the identity branch.
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(planes * block.expansion),
            )

        units = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        units.extend(block(self.inplanes, planes) for _ in range(1, blocks))
        return nn.Sequential(*units)

    def _init_params(self):
        """Kaiming-init convs, unit-init norm layers, Gaussian-init linears."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def compute_style(self, x):
        """Per-channel spatial mean and std, concatenated as (N, 2C)."""
        return torch.cat([x.mean(dim=[2, 3]), x.std(dim=[2, 3])], 1)

    def featuremaps(self, x):
        """Run stem + stages, applying MixStyle after the requested stages.

        NOTE: "conv1" passes the constructor assert but is never applied here.
        """
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        for stage, tag in (
            (self.layer1, "conv2_x"),
            (self.layer2, "conv3_x"),
            (self.layer3, "conv4_x"),
            (self.layer4, "conv5_x"),
        ):
            x = stage(x)
            if tag in self.mixstyle_layers:
                x = self.mixstyle(x)
        return x

    def forward(self, x):
        feats = self.featuremaps(x)
        pooled = self.global_avgpool(feats)
        return pooled.view(pooled.size(0), -1)
198 |
199 |
def init_pretrained_weights(model, model_url):
    """Load ImageNet weights from `model_url` into `model`.

    strict=False tolerates keys that do not line up (e.g. the checkpoint's
    classifier head, since `model.fc` is an identity here).
    """
    state_dict = model_zoo.load_url(model_url)
    model.load_state_dict(state_dict, strict=False)
203 |
204 |
205 | """
206 | Residual network configurations:
207 | --
208 | resnet18: block=BasicBlock, layers=[2, 2, 2, 2]
209 | resnet34: block=BasicBlock, layers=[3, 4, 6, 3]
210 | resnet50: block=Bottleneck, layers=[3, 4, 6, 3]
211 | resnet101: block=Bottleneck, layers=[3, 4, 23, 3]
212 | resnet152: block=Bottleneck, layers=[3, 8, 36, 3]
213 | """
214 |
215 |
def resnet18_mixstyle2_L234_p0d5_a0d1(pretrained=True, **kwargs):
    """ResNet-18 with MixStyle after stages 2-4 (p=0.5, alpha=0.1)."""
    model = ResNet(
        block=BasicBlock,
        layers=[2, 2, 2, 2],
        mixstyle_layers=["conv2_x", "conv3_x", "conv4_x"],
        mixstyle_p=0.5,
        mixstyle_alpha=0.1,
    )
    if pretrained:
        # Load ImageNet weights; MixStyle adds no parameters of its own.
        init_pretrained_weights(model, model_urls["resnet18"])
    return model
229 |
230 |
def resnet50_mixstyle2_L234_p0d5_a0d1(pretrained=True, **kwargs):
    """ResNet-50 (bottleneck) with MixStyle after stages 2-4 (p=0.5, alpha=0.1)."""
    model = ResNet(
        block=Bottleneck,
        layers=[3, 4, 6, 3],
        mixstyle_layers=["conv2_x", "conv3_x", "conv4_x"],
        mixstyle_p=0.5,
        mixstyle_alpha=0.1,
    )
    if pretrained:
        # Load ImageNet weights; MixStyle adds no parameters of its own.
        init_pretrained_weights(model, model_urls["resnet50"])
    return model
244 |
--------------------------------------------------------------------------------
/domainbed/networks/__init__.py:
--------------------------------------------------------------------------------
1 | from .networks import *
2 | from .ur_networks import URFeaturizer, URResNet
3 |
--------------------------------------------------------------------------------
/domainbed/networks/backbones.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Kakao Brain. All Rights Reserved.
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torchvision.models
6 | import clip
7 |
8 |
def clip_imageencoder(name):
    """Return the visual (image) tower of a CLIP model, loaded on CPU."""
    clip_model, _ = clip.load(name, device="cpu")
    return clip_model.visual
14 |
15 |
class Identity(nn.Module):
    """No-op module: forward returns its input unchanged."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x
24 |
25 |
def torchhub_load(repo, model, **kwargs):
    """Load `model` from a torch.hub `repo`, coping with old torch versions."""
    try:
        # torch >= 1.10 supports skip_validation
        return torch.hub.load(repo, model=model, skip_validation=True, **kwargs)
    except TypeError:
        # torch 1.7.1: no skip_validation keyword
        return torch.hub.load(repo, model=model, **kwargs)
35 |
36 |
def get_backbone(name, preserve_readout, pretrained):
    """Build a feature-extractor backbone and report its feature dimension.

    Args:
        name: backbone identifier; one of resnet18, resnet50,
            resnet50_barlowtwins, resnet50_moco, clip_resnet*, clip_vit-b16,
            swag_regnety_16gf.
        preserve_readout: if False, the ResNet fc readout is replaced with an
            identity so the network returns pooled features of size n_outputs.
        pretrained: whether to load pretrained weights; only resnet50 and
            swag_regnety_16gf support random initialization (see assert).

    Returns:
        (network, n_outputs): the backbone module and its feature size.
    """
    if not pretrained:
        assert name in ["resnet50", "swag_regnety_16gf"], "Only RN50/RegNet supports non-pretrained network"

    if name == "resnet18":
        # NOTE(review): pretrained=True is hard-coded here; the assert above
        # already guarantees `pretrained` is True on this branch.
        network = torchvision.models.resnet18(pretrained=True)
        n_outputs = 512
    elif name == "resnet50":
        network = torchvision.models.resnet50(pretrained=pretrained)
        n_outputs = 2048
    elif name == "resnet50_barlowtwins":
        network = torch.hub.load('facebookresearch/barlowtwins:main', 'resnet50')
        n_outputs = 2048
    elif name == "resnet50_moco":
        network = torchvision.models.resnet50()

        # download pretrained model of MoCo v3: https://dl.fbaipublicfiles.com/moco-v3/r-50-1000ep/r-50-1000ep.pth.tar
        ckpt_path = "./r-50-1000ep.pth.tar"

        # https://github.com/facebookresearch/moco-v3/blob/main/main_lincls.py#L172
        print("=> loading checkpoint '{}'".format(ckpt_path))
        checkpoint = torch.load(ckpt_path, map_location="cpu")

        # rename moco pre-trained keys
        state_dict = checkpoint['state_dict']
        linear_keyword = "fc"  # resnet linear keyword
        for k in list(state_dict.keys()):
            # retain only base_encoder up to before the embedding layer
            if k.startswith('module.base_encoder') and not k.startswith('module.base_encoder.%s' % linear_keyword):
                # remove prefix
                state_dict[k[len("module.base_encoder."):]] = state_dict[k]
            # delete renamed or unused k
            del state_dict[k]

        msg = network.load_state_dict(state_dict, strict=False)
        # Only the classifier (left randomly initialized) may be missing.
        assert set(msg.missing_keys) == {"%s.weight" % linear_keyword, "%s.bias" % linear_keyword}

        print("=> loaded pre-trained model '{}'".format(ckpt_path))

        n_outputs = 2048
    elif name.startswith("clip_resnet"):
        # e.g. "clip_resnet50" -> CLIP model name "RN50"
        name = "RN" + name[11:]
        network = clip_imageencoder(name)
        n_outputs = network.output_dim
    elif name == "clip_vit-b16":
        network = clip_imageencoder("ViT-B/16")
        n_outputs = network.output_dim
    elif name == "swag_regnety_16gf":
        # No readout layer as default
        network = torchhub_load("facebookresearch/swag", model="regnety_16gf", pretrained=pretrained)

        # Replace the head with global-average-pool + flatten only.
        network.head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(1),
        )
        n_outputs = 3024
    else:
        raise ValueError(name)

    if not preserve_readout:
        # remove readout layer (but left GAP and flatten)
        # final output shape: [B, n_outputs]
        if name.startswith("resnet"):
            del network.fc
            network.fc = Identity()

    return network, n_outputs
104 |
--------------------------------------------------------------------------------
/domainbed/networks/networks.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from domainbed.lib import wide_resnet
8 | from domainbed.networks.backbones import get_backbone
9 |
10 |
class SqueezeLastTwo(nn.Module):
    """Reshape a (N, C, 1, 1) tensor to (N, C).

    Unlike a plain squeeze, this never drops the batch dimension when the
    batch size is 1.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        batch, channels = x.shape[0], x.shape[1]
        return x.view(batch, channels)
22 |
23 |
class MLP(nn.Module):
    """Plain multi-layer perceptron: Linear -> Dropout -> ReLU per layer.

    Width, depth and dropout come from `hparams` ("mlp_width", "mlp_depth",
    "mlp_dropout"); the final Linear has no activation.
    """

    def __init__(self, n_inputs, n_outputs, hparams):
        super().__init__()
        width = hparams["mlp_width"]
        self.input = nn.Linear(n_inputs, width)
        self.dropout = nn.Dropout(hparams["mlp_dropout"])
        # depth counts input and output layers, hence "- 2" hidden layers.
        self.hiddens = nn.ModuleList(
            nn.Linear(width, width) for _ in range(hparams["mlp_depth"] - 2)
        )
        self.output = nn.Linear(width, n_outputs)
        self.n_outputs = n_outputs

    def forward(self, x):
        x = F.relu(self.dropout(self.input(x)))
        for layer in self.hiddens:
            x = F.relu(self.dropout(layer(x)))
        return self.output(x)
50 |
51 |
class ResNet(torch.nn.Module):
    """ResNet with the softmax chopped off and the batchnorm frozen"""

    def __init__(self, input_shape, hparams):
        super().__init__()
        self.network, self.n_outputs = get_backbone(
            hparams.model,
            preserve_readout=False,
            pretrained=hparams.pretrained,
        )

        # Adapt the stem to inputs with other than 3 channels by tiling the
        # existing RGB filters across the new channel dimension.
        n_channels = input_shape[0]
        if n_channels != 3:
            rgb_weights = self.network.conv1.weight.data.clone()
            self.network.conv1 = nn.Conv2d(
                n_channels, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
            )
            for ch in range(n_channels):
                self.network.conv1.weight.data[:, ch, :, :] = rgb_weights[:, ch % 3, :, :]

        self.hparams = hparams
        self.dropout = nn.Dropout(hparams["resnet_dropout"])
        self.freeze_bn()

    def forward(self, x):
        """Encode x into a feature vector of size n_outputs."""
        return self.dropout(self.network(x))

    def train(self, mode=True):
        """Override the default train() so BN layers stay in eval mode."""
        super().train(mode)
        self.freeze_bn()

    def freeze_bn(self):
        # Put every BatchNorm2d in eval mode (running stats are not updated).
        for module in self.network.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.eval()
94 |
95 |
class MNIST_CNN(nn.Module):
    """
    Hand-tuned architecture for MNIST.
    Weirdness I've noticed so far with this architecture:
    - adding a linear layer after the mean-pool in features hurts
      RotatedMNIST-100 generalization severely.
    """

    n_outputs = 128

    def __init__(self, input_shape):
        super().__init__()
        # Four 3x3 convolutions; only conv2 downsamples (stride 2).
        self.conv1 = nn.Conv2d(input_shape[0], 64, 3, 1, padding=1)
        self.conv2 = nn.Conv2d(64, 128, 3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(128, 128, 3, 1, padding=1)
        self.conv4 = nn.Conv2d(128, 128, 3, 1, padding=1)

        # GroupNorm (8 groups) rather than BatchNorm.
        self.bn0 = nn.GroupNorm(8, 64)
        self.bn1 = nn.GroupNorm(8, 128)
        self.bn2 = nn.GroupNorm(8, 128)
        self.bn3 = nn.GroupNorm(8, 128)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.squeezeLastTwo = SqueezeLastTwo()

    def forward(self, x):
        # conv -> ReLU -> GroupNorm, four times, then mean-pool to (N, 128).
        for conv, norm in (
            (self.conv1, self.bn0),
            (self.conv2, self.bn1),
            (self.conv3, self.bn2),
            (self.conv4, self.bn3),
        ):
            x = norm(F.relu(conv(x)))
        return self.squeezeLastTwo(self.avgpool(x))
141 |
142 |
class ContextNet(nn.Module):
    """Small conv net mapping an image to a single-channel context map of the
    same spatial size."""

    def __init__(self, input_shape):
        super().__init__()
        pad = (5 - 1) // 2  # "same" padding for 5x5 kernels
        self.context_net = nn.Sequential(
            nn.Conv2d(input_shape[0], 64, 5, padding=pad),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, 5, padding=pad),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 1, 5, padding=pad),
        )

    def forward(self, x):
        return self.context_net(x)
161 |
162 |
def Featurizer(input_shape, hparams):
    """Auto-select an appropriate featurizer for the given input shape."""
    if len(input_shape) == 1:
        return MLP(input_shape[0], 128, hparams)

    spatial = input_shape[1:3]
    if spatial == (28, 28):
        return MNIST_CNN(input_shape)
    if spatial == (32, 32):
        return wide_resnet.Wide_ResNet(input_shape, 16, 2, 0.0)
    if spatial == (224, 224):
        return ResNet(input_shape, hparams)
    raise NotImplementedError(f"Input shape {input_shape} is not supported")
175 |
--------------------------------------------------------------------------------
/domainbed/networks/ur_networks.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Kakao Brain. All Rights Reserved.
2 |
3 | import torch
4 | import torch.nn as nn
5 | from .backbones import get_backbone
6 |
7 |
# Per-backbone mapping from logical block names ("stem", "block1".."block4")
# to the qualified module names inside that backbone. URResNet uses this to
# group modules into freezable units and to choose feature-hook locations.
BLOCKNAMES = {
    "resnet": {
        "stem": ["conv1", "bn1", "relu", "maxpool"],
        "block1": ["layer1"],
        "block2": ["layer2"],
        "block3": ["layer3"],
        "block4": ["layer4"],
    },
    "clipresnet": {
        "stem": ["conv1", "bn1", "conv2", "bn2", "conv3", "bn3", "relu", "avgpool"],
        "block1": ["layer1"],
        "block2": ["layer2"],
        "block3": ["layer3"],
        "block4": ["layer4"],
    },
    "clipvit": {  # vit-base
        "stem": ["conv1"],
        "block1": ["transformer.resblocks.0", "transformer.resblocks.1", "transformer.resblocks.2"],
        "block2": ["transformer.resblocks.3", "transformer.resblocks.4", "transformer.resblocks.5"],
        "block3": ["transformer.resblocks.6", "transformer.resblocks.7", "transformer.resblocks.8"],
        "block4": ["transformer.resblocks.9", "transformer.resblocks.10", "transformer.resblocks.11"],
    },
    "regnety": {
        "stem": ["stem"],
        "block1": ["trunk_output.block1"],
        "block2": ["trunk_output.block2"],
        "block3": ["trunk_output.block3"],
        "block4": ["trunk_output.block4"]
    },
}
38 |
39 |
def get_module(module, name):
    """Return the submodule of `module` whose qualified name is `name`.

    Returns None when no submodule has that name ("" is the root module).
    """
    return dict(module.named_modules()).get(name)
44 |
45 |
def build_blocks(model, block_name_dict):
    """Group `model`'s submodules into one nn.ModuleList per logical block.

    Returns a plain Python list rather than an nn.ModuleList so the grouping
    does not alter the parent module's saved state (see original note:
    "saved model can be broken").
    """
    blocks = []
    for _key, name_list in block_name_dict.items():
        blocks.append(nn.ModuleList(get_module(model, n) for n in name_list))
    return blocks
57 |
58 |
def freeze_(model):
    """Disable gradients for every parameter of `model`.

    Note that this function does not control BN (running statistics may
    still update while the module is in train mode).
    """
    for param in model.parameters():
        param.requires_grad_(False)
65 |
66 |
67 | class URResNet(torch.nn.Module):
68 | """ResNet + FrozenBN + IntermediateFeatures
69 | """
70 |
71 | def __init__(self, input_shape, hparams, preserve_readout=False, freeze=None, feat_layers=None):
72 | assert input_shape == (3, 224, 224), input_shape
73 | super().__init__()
74 |
75 | self.network, self.n_outputs = get_backbone(hparams.model, preserve_readout, hparams.pretrained)
76 |
77 | if hparams.model == "resnet18":
78 | block_names = BLOCKNAMES["resnet"]
79 | elif hparams.model.startswith("resnet50"):
80 | block_names = BLOCKNAMES["resnet"]
81 | elif hparams.model.startswith("clip_resnet"):
82 | block_names = BLOCKNAMES["clipresnet"]
83 | elif hparams.model.startswith("clip_vit"):
84 | block_names = BLOCKNAMES["clipvit"]
85 | elif hparams.model == "swag_regnety_16gf":
86 | block_names = BLOCKNAMES["regnety"]
87 | elif hparams.model.startswith("vit"):
88 | block_names = BLOCKNAMES["vit"]
89 | else:
90 | raise ValueError(hparams.model)
91 |
92 | self._features = []
93 | self.feat_layers = self.build_feature_hooks(feat_layers, block_names)
94 | self.blocks = build_blocks(self.network, block_names)
95 |
96 | self.freeze(freeze)
97 |
98 | if not preserve_readout:
99 | self.dropout = nn.Dropout(hparams["resnet_dropout"])
100 | else:
101 | self.dropout = nn.Identity()
102 | assert hparams["resnet_dropout"] == 0.0
103 |
104 | self.hparams = hparams
105 | self.freeze_bn()
106 |
107 | def freeze(self, freeze):
108 | if freeze is not None:
109 | if freeze == "all":
110 | freeze_(self.network)
111 | else:
112 | for block in self.blocks[:freeze+1]:
113 | freeze_(block)
114 |
115 | def hook(self, module, input, output):
116 | self._features.append(output)
117 |
118 | def build_feature_hooks(self, feats, block_names):
119 | assert feats in ["stem_block", "block"]
120 |
121 | if feats is None:
122 | return []
123 |
124 | # build feat layers
125 | if feats.startswith("stem"):
126 | last_stem_name = block_names["stem"][-1]
127 | feat_layers = [last_stem_name]
128 | else:
129 | feat_layers = []
130 |
131 | for name, module_names in block_names.items():
132 | if name == "stem":
133 | continue
134 |
135 | module_name = module_names[-1]
136 | feat_layers.append(module_name)
137 |
138 | # print(f"feat layers = {feat_layers}")
139 |
140 | for n, m in self.network.named_modules():
141 | if n in feat_layers:
142 | m.register_forward_hook(self.hook)
143 |
144 | return feat_layers
145 |
146 | def forward(self, x, ret_feats=False):
147 | """Encode x into a feature vector of size n_outputs."""
148 | self.clear_features()
149 | out = self.dropout(self.network(x))
150 | if ret_feats:
151 | return out, self._features
152 | else:
153 | return out
154 |
155 | def clear_features(self):
156 | self._features.clear()
157 |
158 | def train(self, mode=True):
159 | """
160 | Override the default train() to freeze the BN parameters
161 | """
162 | super().train(mode)
163 | self.freeze_bn()
164 |
165 | def freeze_bn(self):
166 | for m in self.network.modules():
167 | if isinstance(m, nn.BatchNorm2d):
168 | m.eval()
169 |
170 |
def URFeaturizer(input_shape, hparams, **kwargs):
    """Auto-select an appropriate featurizer for the given input shape."""
    if input_shape[1:3] != (224, 224):
        raise NotImplementedError(f"Input shape {input_shape} is not supported")
    return URResNet(input_shape, hparams, **kwargs)
177 |
--------------------------------------------------------------------------------
/domainbed/optimizers.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
def get_optimizer(name, params, **kwargs):
    """Instantiate a torch optimizer by (case-insensitive) name.

    Supported names: "adam", "sgd", "adamw". Extra kwargs (lr, weight_decay,
    ...) are forwarded to the optimizer constructor.
    """
    registry = {
        "adam": torch.optim.Adam,
        "sgd": torch.optim.SGD,
        "adamw": torch.optim.AdamW,
    }
    return registry[name.lower()](params, **kwargs)
10 |
--------------------------------------------------------------------------------
/domainbed/scripts/download.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 |
3 | from torchvision.datasets import MNIST
4 | import xml.etree.ElementTree as ET
5 | from zipfile import ZipFile
6 | import argparse
7 | import tarfile
8 | import shutil
9 | import gdown
10 | import uuid
11 | import json
12 | import os
13 |
14 |
15 | # utils #######################################################################
16 |
17 |
def stage_path(data_dir, name):
    """Return `data_dir`/`name`, creating the directory if needed.

    Uses exist_ok=True instead of an exists()-then-makedirs check, which was
    racy (TOCTOU) and could crash if the directory appeared between the two
    calls.
    """
    full_path = os.path.join(data_dir, name)
    os.makedirs(full_path, exist_ok=True)
    return full_path
25 |
26 |
def download_and_extract(url, dst, remove=True):
    """Download `url` to `dst` via gdown and unpack .tar.gz/.tar/.zip in place.

    Archives are extracted into dst's directory. With remove=True (default),
    the downloaded archive is deleted afterwards.

    Fix: archives are opened with context managers, so file handles are
    closed even if extraction fails (the original leaked them on error).
    """
    gdown.download(url, dst, quiet=False)

    # The suffixes are mutually exclusive, so if/elif is equivalent to the
    # original chain of ifs.
    # NOTE(review): extractall trusts archive member paths; these archives
    # come from fixed project URLs, not arbitrary user input.
    if dst.endswith(".tar.gz"):
        with tarfile.open(dst, "r:gz") as tar:
            tar.extractall(os.path.dirname(dst))
    elif dst.endswith(".tar"):
        with tarfile.open(dst, "r:") as tar:
            tar.extractall(os.path.dirname(dst))
    elif dst.endswith(".zip"):
        with ZipFile(dst, "r") as zf:
            zf.extractall(os.path.dirname(dst))

    if remove:
        os.remove(dst)
47 |
48 |
49 | # VLCS ########################################################################
50 |
51 | # Slower, but builds dataset from the original sources
52 | #
53 | # def download_vlcs(data_dir):
54 | # full_path = stage_path(data_dir, "VLCS")
55 | #
56 | # tmp_path = os.path.join(full_path, "tmp/")
57 | # if not os.path.exists(tmp_path):
58 | # os.makedirs(tmp_path)
59 | #
60 | # with open("domainbed/misc/vlcs_files.txt", "r") as f:
61 | # lines = f.readlines()
62 | # files = [line.strip().split() for line in lines]
63 | #
64 | # download_and_extract("http://pjreddie.com/media/files/VOCtrainval_06-Nov-2007.tar",
65 | # os.path.join(tmp_path, "voc2007_trainval.tar"))
66 | #
67 | # download_and_extract("https://drive.google.com/uc?id=1I8ydxaAQunz9R_qFFdBFtw6rFTUW9goz",
68 | # os.path.join(tmp_path, "caltech101.tar.gz"))
69 | #
70 | # download_and_extract("http://groups.csail.mit.edu/vision/Hcontext/data/sun09_hcontext.tar",
71 | # os.path.join(tmp_path, "sun09_hcontext.tar"))
72 | #
73 | # tar = tarfile.open(os.path.join(tmp_path, "sun09.tar"), "r:")
74 | # tar.extractall(tmp_path)
75 | # tar.close()
76 | #
77 | # for src, dst in files:
78 | # class_folder = os.path.join(data_dir, dst)
79 | #
80 | # if not os.path.exists(class_folder):
81 | # os.makedirs(class_folder)
82 | #
83 | # dst = os.path.join(class_folder, uuid.uuid4().hex + ".jpg")
84 | #
85 | # if "labelme" in src:
86 | # # download labelme from the web
87 | # gdown.download(src, dst, quiet=False)
88 | # else:
89 | # src = os.path.join(tmp_path, src)
90 | # shutil.copyfile(src, dst)
91 | #
92 | # shutil.rmtree(tmp_path)
93 |
94 |
def download_vlcs(data_dir):
    """Download and extract the VLCS dataset into <data_dir>/VLCS."""
    # Original URL: http://www.eecs.qmul.ac.uk/~dl307/project_iccv2017
    stage_path(data_dir, "VLCS")

    download_and_extract(
        "https://drive.google.com/uc?id=1skwblH1_okBwxWxmRsp9_qi15hyPpxg8",
        os.path.join(data_dir, "VLCS.tar.gz"),
    )
103 |
104 |
105 | # MNIST #######################################################################
106 |
107 |
def download_mnist(data_dir):
    """Download MNIST via torchvision into <data_dir>/MNIST."""
    # Original URL: http://yann.lecun.com/exdb/mnist/
    MNIST(stage_path(data_dir, "MNIST"), download=True)
112 |
113 |
114 | # PACS ########################################################################
115 |
116 |
def download_pacs(data_dir):
    """Download and stage the PACS dataset under <data_dir>/PACS."""
    # Original URL: http://www.eecs.qmul.ac.uk/~dl307/project_iccv2017
    full_path = stage_path(data_dir, "PACS")

    download_and_extract(
        "https://drive.google.com/uc?id=1JFr8f805nMUelQWWmfnJR3y4_SYoN5Pd",
        os.path.join(data_dir, "PACS.zip"),
    )

    # The archive extracts to a "kfold" folder; move it into place.
    os.rename(os.path.join(data_dir, "kfold"), full_path)
127 |
128 |
129 | # Office-Home #################################################################
130 |
131 |
def download_office_home(data_dir):
    """Download and stage Office-Home under <data_dir>/office_home."""
    # Original URL: http://hemanthdv.org/OfficeHome-Dataset/
    full_path = stage_path(data_dir, "office_home")

    download_and_extract(
        "https://drive.google.com/uc?id=1uY0pj7oFsjMxRwaD3Sxy0jgel0fsYXLC",
        os.path.join(data_dir, "office_home.zip"),
    )

    # The archive extracts to its original folder name; move it into place.
    os.rename(os.path.join(data_dir, "OfficeHomeDataset_10072016"), full_path)
142 |
143 |
144 | # DomainNET ###################################################################
145 |
146 |
def download_domain_net(data_dir):
    """Download the six DomainNet domain archives and prune known duplicates."""
    # Original URL: http://ai.bu.edu/M3SDA/
    full_path = stage_path(data_dir, "domain_net")

    urls = [
        "http://csr.bu.edu/ftp/visda/2019/multi-source/groundtruth/clipart.zip",
        "http://csr.bu.edu/ftp/visda/2019/multi-source/infograph.zip",
        "http://csr.bu.edu/ftp/visda/2019/multi-source/groundtruth/painting.zip",
        "http://csr.bu.edu/ftp/visda/2019/multi-source/quickdraw.zip",
        "http://csr.bu.edu/ftp/visda/2019/multi-source/real.zip",
        "http://csr.bu.edu/ftp/visda/2019/multi-source/sketch.zip",
    ]

    for url in urls:
        archive_name = url.split("/")[-1]
        download_and_extract(url, os.path.join(full_path, archive_name))

    # Remove files listed as cross-domain duplicates; missing files are fine.
    with open("domainbed/misc/domain_net_duplicates.txt", "r") as f:
        duplicates = [line.strip() for line in f.readlines()]
    for rel_path in duplicates:
        try:
            os.remove(os.path.join(full_path, rel_path))
        except OSError:
            pass
169 |
170 |
171 | # TerraIncognita ##############################################################
172 |
173 |
def download_terra_incognita(data_dir):
    """Download TerraIncognita and reorganize it into location/category folders.

    Fixes over the original implementation:
    - Annotations are indexed by image id once (O(images + annotations))
      instead of re-scanning the full annotation list for every image
      (accidental O(images * annotations)).
    - The per-(location, category) counter now counts the first sighting as 1;
      previously the first one was recorded as 0. (`stats` is only built
      locally and never returned, so external behavior is unchanged.)
    """
    # Original URL: https://beerys.github.io/CaltechCameraTraps/
    # New URL: http://lila.science/datasets/caltech-camera-traps

    full_path = stage_path(data_dir, "terra_incognita")

    download_and_extract(
        "https://lilablobssc.blob.core.windows.net/caltechcameratraps/eccv_18_all_images_sm.tar.gz",
        os.path.join(full_path, "terra_incognita_images.tar.gz"),
    )

    download_and_extract(
        "https://lilablobssc.blob.core.windows.net/caltechcameratraps/labels/caltech_camera_traps.json.zip",
        os.path.join(full_path, "caltech_camera_traps.json.zip"),
    )

    include_locations = ["38", "46", "100", "43"]

    include_categories = [
        "bird",
        "bobcat",
        "cat",
        "coyote",
        "dog",
        "empty",
        "opossum",
        "rabbit",
        "raccoon",
        "squirrel",
    ]

    images_folder = os.path.join(full_path, "eccv_18_all_images_sm/")
    annotations_file = os.path.join(full_path, "caltech_images_20210113.json")
    destination_folder = full_path

    stats = {}

    if not os.path.exists(destination_folder):
        os.mkdir(destination_folder)

    with open(annotations_file, "r") as f:
        data = json.load(f)

    category_dict = {}
    for item in data["categories"]:
        category_dict[item["id"]] = item["name"]

    # Index annotations by image id once, instead of scanning the whole
    # annotation list for every image.
    annotations_by_image = {}
    for annotation in data["annotations"]:
        annotations_by_image.setdefault(annotation["image_id"], []).append(annotation)

    for image in data["images"]:
        image_location = image["location"]

        if image_location not in include_locations:
            continue

        loc_folder = os.path.join(destination_folder, "location_" + str(image_location) + "/")

        if not os.path.exists(loc_folder):
            os.mkdir(loc_folder)

        image_id = image["id"]
        image_fname = image["file_name"]

        for annotation in annotations_by_image.get(image_id, []):
            if image_location not in stats:
                stats[image_location] = {}

            category = category_dict[annotation["category_id"]]

            if category not in include_categories:
                continue

            # Count every sighting (the original recorded the first as 0).
            stats[image_location][category] = stats[image_location].get(category, 0) + 1

            loc_cat_folder = os.path.join(loc_folder, category + "/")

            if not os.path.exists(loc_cat_folder):
                os.mkdir(loc_cat_folder)

            dst_path = os.path.join(loc_cat_folder, image_fname)
            src_path = os.path.join(images_folder, image_fname)

            shutil.copyfile(src_path, dst_path)

    shutil.rmtree(images_folder)
    os.remove(annotations_file)
262 |
263 |
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Download datasets")
    parser.add_argument("--data_dir", type=str, required=True)
    args = parser.parse_args()

    # Fetch every supported dataset into the requested directory,
    # in the same order as before.
    for downloader in (
        download_mnist,
        download_pacs,
        download_vlcs,
        download_domain_net,
        download_office_home,
        download_terra_incognita,
    ):
        downloader(args.data_dir)
275 |
--------------------------------------------------------------------------------
/domainbed/swad.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from collections import deque
3 | import numpy as np
4 | from domainbed.lib import swa_utils
5 |
6 |
class SWADBase:
    """Interface for SWAD model-selection strategies."""

    def update_and_evaluate(self, segment_swa, val_acc, val_loss, prt_fn):
        """Consume one averaging segment; subclasses must override."""
        raise NotImplementedError()

    def get_final_model(self):
        """Return the selected averaged model; subclasses must override."""
        raise NotImplementedError()
13 |
14 |
class IIDMax(SWADBase):
    """SWAD start from iid max acc and select last by iid max swa acc"""

    def __init__(self, evaluator, **kwargs):
        self.evaluator = evaluator
        self.iid_max_acc = 0.0
        self.swa_max_acc = 0.0
        self.avgmodel = None
        self.final_model = None

    def update_and_evaluate(self, segment_swa, val_acc, val_loss, prt_fn):
        """Restart averaging at each new IID-accuracy peak; keep the averaged
        model with the best SWA validation accuracy seen so far."""
        if val_acc > self.iid_max_acc:
            # New best IID accuracy: restart the running average here.
            self.iid_max_acc = val_acc
            self.avgmodel = swa_utils.AveragedModel(segment_swa.module, rm_optimizer=True)
            self.avgmodel.start_step = segment_swa.start_step

        self.avgmodel.update_parameters(segment_swa.module)
        self.avgmodel.end_step = segment_swa.end_step

        # Evaluate the current running average and report it.
        accuracies, summaries = self.evaluator.evaluate(self.avgmodel)
        results = {**summaries, **accuracies}
        prt_fn(results, self.avgmodel)

        swa_val_acc = results["train_out"]
        if swa_val_acc > self.swa_max_acc:
            self.swa_max_acc = swa_val_acc
            self.final_model = copy.deepcopy(self.avgmodel)

    def get_final_model(self):
        return self.final_model
46 |
47 |
48 | class LossValley(SWADBase):
49 | """IIDMax has a potential problem that bias to validation dataset.
50 | LossValley choose SWAD range by detecting loss valley.
51 | """
52 |
53 | def __init__(self, evaluator, n_converge, n_tolerance, tolerance_ratio, **kwargs):
54 | """
55 | Args:
56 | evaluator
57 | n_converge: converge detector window size.
58 | n_tolerance: loss min smoothing window size
59 | tolerance_ratio: decision ratio for dead loss valley
60 | """
61 | self.evaluator = evaluator
62 | self.n_converge = n_converge
63 | self.n_tolerance = n_tolerance
64 | self.tolerance_ratio = tolerance_ratio
65 |
66 | self.converge_Q = deque(maxlen=n_converge)
67 | self.smooth_Q = deque(maxlen=n_tolerance)
68 |
69 | self.final_model = None
70 |
71 | self.converge_step = None
72 | self.dead_valley = False
73 | self.threshold = None
74 |
75 | def get_smooth_loss(self, idx):
76 | smooth_loss = min([model.end_loss for model in list(self.smooth_Q)[idx:]])
77 | return smooth_loss
78 |
79 | @property
80 | def is_converged(self):
81 | return self.converge_step is not None
82 |
    def update_and_evaluate(self, segment_swa, val_acc, val_loss, prt_fn):
        """Push the latest segment-SWA snapshot and advance the loss-valley state machine.

        Phase 1 (not converged): once converge_Q is full, declare convergence when
        the oldest snapshot in the window has the minimum validation loss; seed
        ``final_model`` and the loss ``threshold`` from that window.
        Phase 2 (converged): accumulate snapshots into ``final_model`` until the
        smoothed loss exceeds the threshold, at which point the valley is "dead".

        NOTE(review): ``val_acc`` and ``prt_fn`` are accepted but never used in this
        body — convergence info is printed via ``print`` below instead of ``prt_fn``.
        """
        # Once the valley is dead, all further snapshots are ignored.
        if self.dead_valley:
            return

        # Freeze a CPU copy of the snapshot and tag it with its validation loss.
        frozen = copy.deepcopy(segment_swa.cpu())
        frozen.end_loss = val_loss
        self.converge_Q.append(frozen)
        self.smooth_Q.append(frozen)

        if not self.is_converged:
            # Wait until the convergence window is full.
            if len(self.converge_Q) < self.n_converge:
                return

            min_idx = np.argmin([model.end_loss for model in self.converge_Q])
            untilmin_segment_swa = self.converge_Q[min_idx]  # until-min segment swa.
            # Converged iff the OLDEST snapshot in the window is the best one
            # (loss has been non-improving for the whole window).
            if min_idx == 0:
                self.converge_step = self.converge_Q[0].end_step
                self.final_model = swa_utils.AveragedModel(untilmin_segment_swa)

                # Tolerance threshold: window-mean loss inflated by tolerance_ratio.
                th_base = np.mean([model.end_loss for model in self.converge_Q])
                self.threshold = th_base * (1.0 + self.tolerance_ratio)

                if self.n_tolerance < self.n_converge:
                    # Smaller smoothing window: fold the extra converge_Q snapshots
                    # (beyond the seed) into the averaged model right away.
                    for i in range(self.n_converge - self.n_tolerance):
                        model = self.converge_Q[1 + i]
                        self.final_model.update_parameters(
                            model, start_step=model.start_step, end_step=model.end_step
                        )
                elif self.n_tolerance > self.n_converge:
                    # Larger smoothing window: scan backwards through the older
                    # smooth_Q snapshots and include the contiguous run whose loss
                    # stayed under the threshold.
                    converge_idx = self.n_tolerance - self.n_converge
                    Q = list(self.smooth_Q)[: converge_idx + 1]
                    start_idx = 0
                    for i in reversed(range(len(Q))):
                        model = Q[i]
                        if model.end_loss > self.threshold:
                            start_idx = i + 1
                            break
                    for model in Q[start_idx + 1 :]:
                        self.final_model.update_parameters(
                            model, start_step=model.start_step, end_step=model.end_step
                        )
                print(
                    f"Model converged at step {self.converge_step}, "
                    f"Start step = {self.final_model.start_step}; "
                    f"Threshold = {self.threshold:.6f}, "
                )
            return

        # Converged: skip snapshots taken before the recorded convergence step.
        if self.smooth_Q[0].end_step < self.converge_step:
            return

        # converged -> loss valley
        min_vloss = self.get_smooth_loss(0)
        if min_vloss > self.threshold:
            # Even the best smoothed loss exceeds tolerance: stop averaging forever.
            self.dead_valley = True
            print(f"Valley is dead at step {self.final_model.end_step}")
            return

        # Oldest smoothed snapshot is still inside the valley: average it in.
        model = self.smooth_Q[0]
        self.final_model.update_parameters(
            model, start_step=model.start_step, end_step=model.end_step
        )
145 |
    def get_final_model(self):
        """Return the final averaged model (on CUDA), flushing any pending snapshots.

        If convergence was never reached, logs an error and falls back to the most
        recent raw snapshot instead of an average. Otherwise, while the valley is
        still alive, drains smooth_Q (oldest first, skipping the head already
        averaged in update_and_evaluate) into ``final_model`` until the smoothed
        loss crosses the threshold.
        """
        if not self.is_converged:
            self.evaluator.logger.error(
                "Requested final model, but model is not yet converged; return last model instead"
            )
            return self.converge_Q[-1].cuda()

        if not self.dead_valley:
            # Head of smooth_Q was already folded in during update_and_evaluate.
            self.smooth_Q.popleft()
            while self.smooth_Q:
                smooth_loss = self.get_smooth_loss(0)
                if smooth_loss > self.threshold:
                    break
                segment_swa = self.smooth_Q.popleft()
                self.final_model.update_parameters(segment_swa, step=segment_swa.end_step)

        return self.final_model.cuda()
163 |
--------------------------------------------------------------------------------
/domainbed/trainer.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import json
3 | import time
4 | import copy
5 | from pathlib import Path
6 |
7 | import numpy as np
8 | import torch
9 | import torch.utils.data
10 |
11 | from domainbed.datasets import get_dataset, split_dataset
12 | from domainbed import algorithms
13 | from domainbed.evaluator import Evaluator
14 | from domainbed.lib import misc
15 | from domainbed.lib import swa_utils
16 | from domainbed.lib.query import Q
17 | from domainbed.lib.fast_data_loader import InfiniteDataLoader, FastDataLoader
18 | from domainbed import swad as swad_module
19 |
20 | # if torch.cuda.is_available():
21 | # device = "cuda"
22 | # else:
23 | # device = "cpu"
24 |
25 |
def json_handler(v):
    """``json.dumps`` fallback serializer: stringify Path and range objects."""
    if not isinstance(v, (Path, range)):
        raise TypeError(f"`{type(v)}` is not JSON Serializable")
    return str(v)
30 |
def interpolate_algos(*sds):
    """Element-wise uniform average of state dicts sharing the same keys.

    Backward compatible with the original 4-positional-dict signature; now
    accepts any number (>= 1) of state dicts.

    Args:
        *sds: state dicts (mapping key -> tensor/number), all with identical keys.

    Returns:
        dict: key -> mean of the corresponding values across all inputs.

    Raises:
        TypeError: if called with no state dicts.
    """
    if not sds:
        raise TypeError("interpolate_algos expects at least one state dict")
    n = len(sds)
    # sum() starts from 0, which broadcasts fine with tensors and numbers alike.
    return {key: sum(sd[key] for sd in sds) / n for key in sds[0].keys()}
33 |
def train(test_envs, args, hparams, n_steps, checkpoint_freq, logger, writer, target_env=None):
    """Train four parallel instances of the algorithm (DART-style) and interpolate them.

    Branches 1-3 each see a skewed domain mixture (50% of one training domain,
    25% of each of the other two); branch 4 (CombERM) sees the uniform mixture.
    Every ``args.inter_freq`` steps the four branches' weights are averaged and
    broadcast back. Optionally maintains a SWAD average per branch and reports a
    final interpolation of the four SWAD models.

    Args:
        test_envs: indices of held-out test environments.
        args / hparams: experiment configuration (argparse namespace / hparam dict).
        n_steps: number of optimization steps.
        checkpoint_freq: evaluation period in steps.
        logger, writer: logging sinks.
        target_env: optional single target env used only for naming/checkpoints.

    Returns:
        (ret, records): dict of summary accuracies and the Q-wrapped step records.

    NOTE(review): assumes exactly 3 training domains — the batches1/2/3
    construction indexes domains 0..2 explicitly.
    """
    logger.info("")
    # n_steps = 1
    #######################################################
    # setup dataset & loader
    #######################################################
    args.real_test_envs = test_envs  # for log
    algorithm_class = algorithms.get_algorithm_class(args.algorithm)
    dataset, in_splits, out_splits = get_dataset(test_envs, args, hparams, algorithm_class)
    test_splits = []
    # if hparams.indomain_test > 0.0:
    #     logger.info("!!! In-domain test mode On !!!")
    #     assert hparams["val_augment"] is False, (
    #         "indomain_test split the val set into val/test sets. "
    #         "Therefore, the val set should be not augmented."
    #     )
    #     val_splits = []
    #     for env_i, (out_split, _weights) in enumerate(out_splits):
    #         n = len(out_split) // 2
    #         seed = misc.seed_hash(args.trial_seed, env_i)
    #         val_split, test_split = split_dataset(out_split, n, seed=seed)
    #         val_splits.append((val_split, None))
    #         test_splits.append((test_split, None))
    #         logger.info(
    #             "env %d: out (#%d) -> val (#%d) / test (#%d)"
    #             % (env_i, len(out_split), len(val_split), len(test_split))
    #         )
    #     out_splits = val_splits

    if target_env is not None:
        testenv_name = f"te_{dataset.environments[target_env]}"
        logger.info(f"Target env = {target_env}")
    else:
        testenv_properties = [str(dataset.environments[i]) for i in test_envs]
        testenv_name = "te_" + "_".join(testenv_properties)

    logger.info(
        "Testenv name escaping {} -> {}".format(testenv_name, testenv_name.replace(".", ""))
    )
    testenv_name = testenv_name.replace(".", "")
    logger.info(f"Test envs = {test_envs}, name = {testenv_name}")

    n_envs = len(dataset)
    train_envs = sorted(set(range(n_envs)) - set(test_envs))
    iterator = misc.SplitIterator(test_envs)
    # FIX: np.int was removed in NumPy 1.24+; the builtin int is the documented
    # replacement and is what np.int aliased anyway.
    batch_sizes = np.full([n_envs], hparams["batch_size"], dtype=int)
    batch_sizes50 = np.full([n_envs], int(hparams["batch_size"] * 3 * 0.5), dtype=int)
    batch_sizes25 = np.full([n_envs], int(hparams["batch_size"] * 3 * 0.25), dtype=int)

    # Test envs contribute no training batches.
    batch_sizes[test_envs] = 0
    batch_sizes = batch_sizes.tolist()
    batch_sizes50[test_envs] = 0
    batch_sizes50 = batch_sizes50.tolist()
    batch_sizes25[test_envs] = 0
    batch_sizes25 = batch_sizes25.tolist()

    logger.info(f"Batch sizes for CombERM branch: {batch_sizes} (total={sum(batch_sizes)})")
    logger.info(f"Own domain Batch sizes for each domain: {batch_sizes50} (total={sum(batch_sizes50)})")
    logger.info(f"Other domain Batch sizes for each domain: {batch_sizes25} (total={sum(batch_sizes25)})")

    # calculate steps per epoch
    steps_per_epochs = [
        len(env) / batch_size for (env, _), batch_size in iterator.train(zip(in_splits, batch_sizes))
    ]
    steps_per_epochs50 = [
        len(env) / batch_size50 for (env, _), batch_size50 in iterator.train(zip(in_splits, batch_sizes50))
    ]
    steps_per_epochs25 = [
        len(env) / batch_size25 for (env, _), batch_size25 in iterator.train(zip(in_splits, batch_sizes25))
    ]
    steps_per_epoch = min(steps_per_epochs)
    steps_per_epoch50 = min(steps_per_epochs50)
    steps_per_epoch25 = min(steps_per_epochs25)
    # epoch is computed by steps_per_epoch
    prt_steps = ", ".join([f"{step:.2f}" for step in steps_per_epochs])
    prt_steps50 = ", ".join([f"{step:.2f}" for step in steps_per_epochs50])
    prt_steps25 = ", ".join([f"{step:.2f}" for step in steps_per_epochs25])
    logger.info(f"steps-per-epoch for CombERM : {prt_steps} -> min = {steps_per_epoch:.2f}")
    logger.info(f"steps-per-epoch for own domain: {prt_steps50} -> min = {steps_per_epoch50:.2f}")
    logger.info(f"steps-per-epoch for other domain: {prt_steps25} -> min = {steps_per_epoch25:.2f}")

    # setup loaders: one uniform loader set plus 50%/25%/25% sets for the skewed branches
    train_loaders = [
        InfiniteDataLoader(
            dataset=env,
            weights=env_weights,
            batch_size=batch_size,
            num_workers=dataset.N_WORKERS
        )
        for (env, env_weights), batch_size in iterator.train(zip(in_splits, batch_sizes))
    ]
    train_loaders50 = [
        InfiniteDataLoader(
            dataset=env,
            weights=env_weights,
            batch_size=batch_size50,
            num_workers=dataset.N_WORKERS
        )
        for (env, env_weights), batch_size50 in iterator.train(zip(in_splits, batch_sizes50))
    ]
    train_loaders25a = [
        InfiniteDataLoader(
            dataset=env,
            weights=env_weights,
            batch_size=batch_size25,
            num_workers=dataset.N_WORKERS
        )
        for (env, env_weights), batch_size25 in iterator.train(zip(in_splits, batch_sizes25))
    ]
    train_loaders25b = [
        InfiniteDataLoader(
            dataset=env,
            weights=env_weights,
            batch_size=batch_size25,
            num_workers=dataset.N_WORKERS
        )
        for (env, env_weights), batch_size25 in iterator.train(zip(in_splits, batch_sizes25))
    ]

    # setup eval loaders
    eval_loaders_kwargs = []
    for i, (env, _) in enumerate(in_splits + out_splits + test_splits):
        batchsize = hparams["test_batchsize"]
        loader_kwargs = {"dataset": env, "batch_size": batchsize, "num_workers": dataset.N_WORKERS}
        if args.prebuild_loader:
            loader_kwargs = FastDataLoader(**loader_kwargs)
        eval_loaders_kwargs.append(loader_kwargs)

    eval_weights = [None for _, weights in (in_splits + out_splits + test_splits)]
    eval_loader_names = ["env{}_in".format(i) for i in range(len(in_splits))]
    eval_loader_names += ["env{}_out".format(i) for i in range(len(out_splits))]
    eval_loader_names += ["env{}_inTE".format(i) for i in range(len(test_splits))]
    eval_meta = list(zip(eval_loader_names, eval_loaders_kwargs, eval_weights))

    #######################################################
    # setup algorithm (model): four independent branches
    #######################################################
    algorithmCE1 = algorithm_class(
        dataset.input_shape,
        dataset.num_classes,
        len(dataset) - len(test_envs),
        hparams,
    )
    algorithmCE2 = algorithm_class(
        dataset.input_shape,
        dataset.num_classes,
        len(dataset) - len(test_envs),
        hparams,
    )
    algorithmCE3 = algorithm_class(
        dataset.input_shape,
        dataset.num_classes,
        len(dataset) - len(test_envs),
        hparams,
    )
    algorithmCE4 = algorithm_class(
        dataset.input_shape,
        dataset.num_classes,
        len(dataset) - len(test_envs),
        hparams,
    )

    algorithmCE1.cuda()
    algorithmCE2.cuda()
    algorithmCE3.cuda()
    algorithmCE4.cuda()

    n_params = sum([p.numel() for p in algorithmCE1.parameters()])
    logger.info("# of params = %d" % n_params)

    train_minibatches_iterator = zip(*train_loaders)
    train_minibatches_iterator50 = zip(*train_loaders50)
    train_minibatches_iterator25a = zip(*train_loaders25a)
    train_minibatches_iterator25b = zip(*train_loaders25b)

    checkpoint_vals = collections.defaultdict(list)

    #######################################################
    # start training loop
    #######################################################
    evaluator = Evaluator(
        test_envs,
        eval_meta,
        n_envs,
        logger,
        evalmode=args.evalmode,
        debug=args.debug,
        target_env=target_env,
    )

    # One SWAD tracker per branch (all use the LossValley strategy).
    swad1 = None
    if hparams["swad"]:
        swad_algorithm1 = swa_utils.AveragedModel(algorithmCE1)
        swad_cls1 = getattr(swad_module, "LossValley")
        swad1 = swad_cls1(evaluator, **hparams.swad_kwargs)
    swad2 = None
    if hparams["swad"]:
        swad_algorithm2 = swa_utils.AveragedModel(algorithmCE2)
        swad_cls2 = getattr(swad_module, "LossValley")
        swad2 = swad_cls2(evaluator, **hparams.swad_kwargs)
    swad3 = None
    if hparams["swad"]:
        swad_algorithm3 = swa_utils.AveragedModel(algorithmCE3)
        swad_cls3 = getattr(swad_module, "LossValley")
        swad3 = swad_cls3(evaluator, **hparams.swad_kwargs)
    swad4 = None
    if hparams["swad"]:
        swad_algorithm4 = swa_utils.AveragedModel(algorithmCE4)
        swad_cls4 = getattr(swad_module, "LossValley")
        swad4 = swad_cls4(evaluator, **hparams.swad_kwargs)

    last_results_keys = None
    records = []
    records_inter = []
    epochs_path = args.out_dir / "results.jsonl"

    for step in range(n_steps):
        step_start_time = time.time()

        # batches_dictlist: [ {x: ,y: }, {x: ,y: }, {x: ,y: } ]
        batches_dictlist = next(train_minibatches_iterator)
        batches_dictlist50 = next(train_minibatches_iterator50)
        batches_dictlist25a = next(train_minibatches_iterator25a)
        batches_dictlist25b = next(train_minibatches_iterator25b)

        # batches: {x: [ ,, ] ,y: [ ,, ] }
        # Branch k gets 50% of domain k-1 and 25% of each other domain.
        batchesCE = misc.merge_dictlist(batches_dictlist)
        batches1 = {'x': [ batches_dictlist50[0]['x'],batches_dictlist25a[1]['x'],batches_dictlist25b[2]['x'] ],
                    'y': [ batches_dictlist50[0]['y'],batches_dictlist25a[1]['y'],batches_dictlist25b[2]['y'] ] }
        batches2 = {'x': [ batches_dictlist25b[0]['x'],batches_dictlist50[1]['x'],batches_dictlist25a[2]['x'] ],
                    'y': [ batches_dictlist25b[0]['y'],batches_dictlist50[1]['y'],batches_dictlist25a[2]['y'] ] }
        batches3 = {'x': [ batches_dictlist25a[0]['x'],batches_dictlist25b[1]['x'],batches_dictlist50[2]['x'] ],
                    'y': [ batches_dictlist25a[0]['y'],batches_dictlist25b[1]['y'],batches_dictlist50[2]['y'] ] }

        # to device
        batchesCE = {key: [tensor.cuda() for tensor in tensorlist] for key, tensorlist in batchesCE.items()}
        batches1 = {key: [tensor.cuda() for tensor in tensorlist] for key, tensorlist in batches1.items()}
        batches2 = {key: [tensor.cuda() for tensor in tensorlist] for key, tensorlist in batches2.items()}
        batches3 = {key: [tensor.cuda() for tensor in tensorlist] for key, tensorlist in batches3.items()}

        inputsCE = {**batchesCE, "step": step}
        inputs1 = {**batches1, "step": step}
        inputs2 = {**batches2, "step": step}
        inputs3 = {**batches3, "step": step}

        step_valsCE1 = algorithmCE1.update(**inputs1)
        step_valsCE2 = algorithmCE2.update(**inputs2)
        step_valsCE3 = algorithmCE3.update(**inputs3)
        step_valsCE4 = algorithmCE4.update(**inputsCE)

        # Accumulate per-branch training metrics (prefixed 1_..4_).
        for key, val in step_valsCE1.items():
            checkpoint_vals['1_' + key].append(val)
        for key, val in step_valsCE2.items():
            checkpoint_vals['2_' + key].append(val)
        for key, val in step_valsCE3.items():
            checkpoint_vals['3_' + key].append(val)
        for key, val in step_valsCE4.items():
            checkpoint_vals['4_' + key].append(val)
        checkpoint_vals["step_time"].append(time.time() - step_start_time)

        if swad1:
            # swad_algorithm is segment_swa for swad
            swad_algorithm1.update_parameters(algorithmCE1, step=step)
            swad_algorithm2.update_parameters(algorithmCE2, step=step)
            swad_algorithm3.update_parameters(algorithmCE3, step=step)
            swad_algorithm4.update_parameters(algorithmCE4, step=step)

        if step % checkpoint_freq == 0:
            results = {
                "step": step,
                "epoch": step / steps_per_epoch,
            }

            for key, val in checkpoint_vals.items():
                results[key] = np.mean(val)

            eval_start_time = time.time()
            summaries1 = evaluator.evaluate(algorithmCE1, suffix='_1')
            summaries2 = evaluator.evaluate(algorithmCE2, suffix='_2')
            summaries3 = evaluator.evaluate(algorithmCE3, suffix='_3')
            summaries4 = evaluator.evaluate(algorithmCE4, suffix='_4')
            results["eval_time"] = time.time() - eval_start_time

            # results = (epochs, loss, step, step_time)
            results_keys = list(summaries1.keys()) + list(summaries2.keys()) + list(summaries3.keys()) + list(summaries4.keys()) + list(results.keys())
            # merge results
            results.update(summaries1)
            results.update(summaries2)
            results.update(summaries3)
            results.update(summaries4)

            # print (re-log the header row whenever the key set changes)
            if results_keys != last_results_keys:
                logger.info(misc.to_row(results_keys))
                last_results_keys = results_keys
            logger.info(misc.to_row([results[key] for key in results_keys]))
            records.append(copy.deepcopy(results))

            # update results to record
            results.update({"hparams": dict(hparams), "args": vars(args)})

            with open(epochs_path, "a") as f:
                f.write(json.dumps(results, sort_keys=True, default=json_handler) + "\n")

            checkpoint_vals = collections.defaultdict(list)

            writer.add_scalars_with_prefix(summaries1, step, f"{testenv_name}/summary1/")
            writer.add_scalars_with_prefix(summaries2, step, f"{testenv_name}/summary2/")
            writer.add_scalars_with_prefix(summaries3, step, f"{testenv_name}/summary3/")
            writer.add_scalars_with_prefix(summaries4, step, f"{testenv_name}/summary4/")

            if args.model_save and step >= args.model_save:
                ckpt_dir = args.out_dir / "checkpoints"
                ckpt_dir.mkdir(exist_ok=True)

                test_env_str = ",".join(map(str, test_envs))
                filename = "TE{}_{}.pth".format(test_env_str, step)
                if len(test_envs) > 1 and target_env is not None:
                    train_env_str = ",".join(map(str, train_envs))
                    filename = f"TE{target_env}_TR{train_env_str}_{step}.pth"
                path = ckpt_dir / filename

                save_dict = {
                    "args": vars(args),
                    "model_hparams": dict(hparams),
                    "test_envs": test_envs,
                    "model_dict1": algorithmCE1.cpu().state_dict(),
                    "model_dict2": algorithmCE2.cpu().state_dict(),
                    "model_dict3": algorithmCE3.cpu().state_dict(),
                    "model_dict4": algorithmCE4.cpu().state_dict(),
                }
                # Move back to GPU after the CPU state_dict snapshot.
                algorithmCE1.cuda()
                algorithmCE2.cuda()
                algorithmCE3.cuda()
                algorithmCE4.cuda()
                if not args.debug:
                    torch.save(save_dict, path)
                else:
                    logger.debug("DEBUG Mode -> no save (org path: %s)" % path)

            # swad
            if swad1:
                def prt_results_fn(results, avgmodel):
                    step_str = f" [{avgmodel.start_step}-{avgmodel.end_step}]"
                    row = misc.to_row([results[key] for key in results_keys if key in results])
                    logger.info(row + step_str)

                swad1.update_and_evaluate(
                    swad_algorithm1, results["comb_val_1"], results["comb_val_loss_1"], prt_results_fn
                )
                swad2.update_and_evaluate(
                    swad_algorithm2, results["comb_val_2"], results["comb_val_loss_2"], prt_results_fn
                )
                swad3.update_and_evaluate(
                    swad_algorithm3, results["comb_val_3"], results["comb_val_loss_3"], prt_results_fn
                )
                swad4.update_and_evaluate(
                    swad_algorithm4, results["comb_val_4"], results["comb_val_loss_4"], prt_results_fn
                )

                # NOTE: unlike single-model SWAD, a dead valley only logs here and
                # does NOT break the loop — the other branches keep training.
                if hasattr(swad1, "dead_valley") and swad1.dead_valley:
                    logger.info("SWAD valley is dead for 1 -> early stop !")
                if hasattr(swad2, "dead_valley") and swad2.dead_valley:
                    logger.info("SWAD valley is dead for 2 -> early stop !")
                if hasattr(swad3, "dead_valley") and swad3.dead_valley:
                    logger.info("SWAD valley is dead for 3 -> early stop !")
                if hasattr(swad4, "dead_valley") and swad4.dead_valley:
                    logger.info("SWAD valley is dead for 4 -> early stop !")

                if (hparams["model"] == 'clip_vit-b16') and (step % 1500 == 0):
                    swad_algorithm1 = swa_utils.AveragedModel(algorithmCE1)  # reset
                    swad_algorithm2 = swa_utils.AveragedModel(algorithmCE2)
                    swad_algorithm3 = swa_utils.AveragedModel(algorithmCE3)
                    swad_algorithm4 = swa_utils.AveragedModel(algorithmCE4)

        if step % args.tb_freq == 0:
            # add step values only for tb log
            writer.add_scalars_with_prefix(step_valsCE1, step, f"{testenv_name}/summary1/")
            writer.add_scalars_with_prefix(step_valsCE2, step, f"{testenv_name}/summary2/")
            writer.add_scalars_with_prefix(step_valsCE3, step, f"{testenv_name}/summary3/")
            writer.add_scalars_with_prefix(step_valsCE4, step, f"{testenv_name}/summary4/")

        # Periodic weight interpolation: average the four branches and broadcast back.
        if step % args.inter_freq == 0 and step != 0:
            if args.algorithm in ['DANN', 'CDANN']:
                inter_state_dict = interpolate_algos(algorithmCE1.featurizer.state_dict(), algorithmCE2.featurizer.state_dict(), algorithmCE3.featurizer.state_dict(), algorithmCE4.featurizer.state_dict())
                algorithmCE1.featurizer.load_state_dict(inter_state_dict)
                algorithmCE2.featurizer.load_state_dict(inter_state_dict)
                algorithmCE3.featurizer.load_state_dict(inter_state_dict)
                algorithmCE4.featurizer.load_state_dict(inter_state_dict)
                inter_state_dict2 = interpolate_algos(algorithmCE1.classifier.state_dict(), algorithmCE2.classifier.state_dict(), algorithmCE3.classifier.state_dict(), algorithmCE4.classifier.state_dict())
                algorithmCE1.classifier.load_state_dict(inter_state_dict2)
                algorithmCE2.classifier.load_state_dict(inter_state_dict2)
                algorithmCE3.classifier.load_state_dict(inter_state_dict2)
                algorithmCE4.classifier.load_state_dict(inter_state_dict2)
                inter_state_dict3 = interpolate_algos(algorithmCE1.discriminator.state_dict(), algorithmCE2.discriminator.state_dict(), algorithmCE3.discriminator.state_dict(), algorithmCE4.discriminator.state_dict())
                algorithmCE1.discriminator.load_state_dict(inter_state_dict3)
                algorithmCE2.discriminator.load_state_dict(inter_state_dict3)
                algorithmCE3.discriminator.load_state_dict(inter_state_dict3)
                algorithmCE4.discriminator.load_state_dict(inter_state_dict3)

            elif args.algorithm in ['SagNet']:
                inter_state_dict = interpolate_algos(algorithmCE1.network_f.state_dict(), algorithmCE2.network_f.state_dict(), algorithmCE3.network_f.state_dict(), algorithmCE4.network_f.state_dict())
                algorithmCE1.network_f.load_state_dict(inter_state_dict)
                algorithmCE2.network_f.load_state_dict(inter_state_dict)
                algorithmCE3.network_f.load_state_dict(inter_state_dict)
                algorithmCE4.network_f.load_state_dict(inter_state_dict)
                inter_state_dict2 = interpolate_algos(algorithmCE1.network_c.state_dict(), algorithmCE2.network_c.state_dict(), algorithmCE3.network_c.state_dict(), algorithmCE4.network_c.state_dict())
                algorithmCE1.network_c.load_state_dict(inter_state_dict2)
                algorithmCE2.network_c.load_state_dict(inter_state_dict2)
                algorithmCE3.network_c.load_state_dict(inter_state_dict2)
                algorithmCE4.network_c.load_state_dict(inter_state_dict2)
                inter_state_dict3 = interpolate_algos(algorithmCE1.network_s.state_dict(), algorithmCE2.network_s.state_dict(), algorithmCE3.network_s.state_dict(), algorithmCE4.network_s.state_dict())
                algorithmCE1.network_s.load_state_dict(inter_state_dict3)
                algorithmCE2.network_s.load_state_dict(inter_state_dict3)
                algorithmCE3.network_s.load_state_dict(inter_state_dict3)
                algorithmCE4.network_s.load_state_dict(inter_state_dict3)

            else:
                inter_state_dict = interpolate_algos(algorithmCE1.network.state_dict(), algorithmCE2.network.state_dict(), algorithmCE3.network.state_dict(), algorithmCE4.network.state_dict())
                algorithmCE1.network.load_state_dict(inter_state_dict)
                algorithmCE2.network.load_state_dict(inter_state_dict)
                algorithmCE3.network.load_state_dict(inter_state_dict)
                algorithmCE4.network.load_state_dict(inter_state_dict)

            logger.info(f"Evaluating interpolated model at {step} step")
            # After broadcasting the average, all four branches are identical,
            # so evaluating branch 1 evaluates the interpolated model.
            summaries_inter = evaluator.evaluate(algorithmCE1, suffix='_from_inter')
            inter_results = {"inter_step": step, "inter_epoch": step / steps_per_epoch}
            inter_results_keys = list(summaries_inter.keys()) + list(inter_results.keys())
            inter_results.update(summaries_inter)
            logger.info(misc.to_row([inter_results[key] for key in inter_results_keys]))
            records_inter.append(copy.deepcopy(inter_results))
            writer.add_scalars_with_prefix(summaries_inter, step, f"{testenv_name}/summary_inter/")

    # find best
    logger.info("---")
    records = Q(records)
    records_inter = Q(records_inter)

    # 1
    oracle_best1 = records.argmax("test_out_1")["test_in_1"]
    iid_best1 = records.argmax("comb_val_1")["test_in_1"]
    inDom1 = records.argmax("comb_val_1")["comb_val_1"]
    last1 = records[-1]["test_in_1"]
    # 2
    oracle_best2 = records.argmax("test_out_2")["test_in_2"]
    iid_best2 = records.argmax("comb_val_2")["test_in_2"]
    inDom2 = records.argmax("comb_val_2")["comb_val_2"]
    last2 = records[-1]["test_in_2"]
    # 3
    oracle_best3 = records.argmax("test_out_3")["test_in_3"]
    iid_best3 = records.argmax("comb_val_3")["test_in_3"]
    inDom3 = records.argmax("comb_val_3")["comb_val_3"]
    last3 = records[-1]["test_in_3"]
    # CE
    oracle_best4 = records.argmax("test_out_4")["test_in_4"]
    iid_best4 = records.argmax("comb_val_4")["test_in_4"]
    inDom4 = records.argmax("comb_val_4")["comb_val_4"]
    last4 = records[-1]["test_in_4"]
    # inter
    oracle_best_inter = records_inter.argmax("test_out_from_inter")["test_in_from_inter"]
    iid_best_inter = records_inter.argmax("comb_val_from_inter")["test_in_from_inter"]
    inDom_inter = records_inter.argmax("comb_val_from_inter")["comb_val_from_inter"]

    ret = {
        "oracle_1": oracle_best1,
        "iid_1": iid_best1,
        "inDom1": inDom1,
        "last_1": last1,
        "oracle_2": oracle_best2,
        "iid_2": iid_best2,
        "inDom2": inDom2,
        "last_2": last2,
        "oracle_3": oracle_best3,
        "iid_3": iid_best3,
        "inDom3": inDom3,
        "last_3": last3,
        "oracle_4": oracle_best4,
        "iid_4": iid_best4,
        "inDom4": inDom4,
        "last_4": last4,
        "oracle_inter": oracle_best_inter,
        "iid_inter": iid_best_inter,
        "inDom_inter": inDom_inter,
    }

    # Evaluate SWAD
    if swad1:
        swad_algorithm1 = swad1.get_final_model()
        swad_algorithm2 = swad2.get_final_model()
        swad_algorithm3 = swad3.get_final_model()
        swad_algorithm4 = swad4.get_final_model()
        if hparams["freeze_bn"] is False:
            n_steps = 500 if not args.debug else 10
            logger.warning(f"Update SWAD BN statistics for {n_steps} steps ...")
            swa_utils.update_bn(train_minibatches_iterator, swad_algorithm1, n_steps)
            swa_utils.update_bn(train_minibatches_iterator, swad_algorithm2, n_steps)
            swa_utils.update_bn(train_minibatches_iterator, swad_algorithm3, n_steps)
            swa_utils.update_bn(train_minibatches_iterator, swad_algorithm4, n_steps)

        logger.warning("Evaluate SWAD ...")
        summaries_swad1 = evaluator.evaluate(swad_algorithm1, suffix='_s1')
        summaries_swad2 = evaluator.evaluate(swad_algorithm2, suffix='_s2')
        summaries_swad3 = evaluator.evaluate(swad_algorithm3, suffix='_s3')
        summaries_swad4 = evaluator.evaluate(swad_algorithm4, suffix='_s4')

        swad_results = {**summaries_swad1, **summaries_swad2, **summaries_swad3, **summaries_swad4}
        step_str = f" [{swad_algorithm1.start_step}-{swad_algorithm1.end_step}] (N={swad_algorithm1.n_averaged}) || [{swad_algorithm2.start_step}-{swad_algorithm2.end_step}] (N={swad_algorithm2.n_averaged}) || [{swad_algorithm3.start_step}-{swad_algorithm3.end_step}] (N={swad_algorithm3.n_averaged}) || [{swad_algorithm4.start_step}-{swad_algorithm4.end_step}] (N={swad_algorithm4.n_averaged})"
        row = misc.to_row([swad_results[key] for key in list(swad_results.keys())]) + step_str
        logger.info(row)

        ret["SWAD 1"] = swad_results["test_in_s1"]
        ret["SWAD 1 (inDom)"] = swad_results["comb_val_s1"]
        ret["SWAD 2"] = swad_results["test_in_s2"]
        ret["SWAD 2 (inDom)"] = swad_results["comb_val_s2"]
        ret["SWAD 3"] = swad_results["test_in_s3"]
        ret["SWAD 3 (inDom)"] = swad_results["comb_val_s3"]
        ret["SWAD 4"] = swad_results["test_in_s4"]
        ret["SWAD 4 (inDom)"] = swad_results["comb_val_s4"]

        save_dict = {
            "args": vars(args),
            "model_hparams": dict(hparams),
            "test_envs": test_envs,
            "SWAD_1": swad_algorithm1.state_dict(),
            "SWAD_2": swad_algorithm2.state_dict(),
            "SWAD_3": swad_algorithm3.state_dict(),
            "SWAD_4": swad_algorithm4.state_dict(),
        }

        # Interpolate the four SWAD models in-place into swad_algorithm1 and evaluate.
        if args.algorithm in ['DANN', 'CDANN']:
            inter_state_dict = interpolate_algos(swad_algorithm1.module.featurizer.state_dict(), swad_algorithm2.module.featurizer.state_dict(), swad_algorithm3.module.featurizer.state_dict(), swad_algorithm4.module.featurizer.state_dict())
            swad_algorithm1.module.featurizer.load_state_dict(inter_state_dict)
            inter_state_dict2 = interpolate_algos(swad_algorithm1.module.classifier.state_dict(), swad_algorithm2.module.classifier.state_dict(), swad_algorithm3.module.classifier.state_dict(), swad_algorithm4.module.classifier.state_dict())
            swad_algorithm1.module.classifier.load_state_dict(inter_state_dict2)
            inter_state_dict3 = interpolate_algos(swad_algorithm1.module.discriminator.state_dict(), swad_algorithm2.module.discriminator.state_dict(), swad_algorithm3.module.discriminator.state_dict(), swad_algorithm4.module.discriminator.state_dict())
            swad_algorithm1.module.discriminator.load_state_dict(inter_state_dict3)

        elif args.algorithm in ['SagNet']:
            inter_state_dict = interpolate_algos(swad_algorithm1.module.network_f.state_dict(), swad_algorithm2.module.network_f.state_dict(), swad_algorithm3.module.network_f.state_dict(), swad_algorithm4.module.network_f.state_dict())
            swad_algorithm1.module.network_f.load_state_dict(inter_state_dict)
            inter_state_dict2 = interpolate_algos(swad_algorithm1.module.network_c.state_dict(), swad_algorithm2.module.network_c.state_dict(), swad_algorithm3.module.network_c.state_dict(), swad_algorithm4.module.network_c.state_dict())
            swad_algorithm1.module.network_c.load_state_dict(inter_state_dict2)
            inter_state_dict3 = interpolate_algos(swad_algorithm1.module.network_s.state_dict(), swad_algorithm2.module.network_s.state_dict(), swad_algorithm3.module.network_s.state_dict(), swad_algorithm4.module.network_s.state_dict())
            swad_algorithm1.module.network_s.load_state_dict(inter_state_dict3)

        else:
            # NOTE(review): this branch accesses `.network` directly (not `.module.network`)
            # unlike the branches above — presumably relies on AveragedModel attribute
            # delegation; confirm against swa_utils.AveragedModel.
            inter_state_dict = interpolate_algos(swad_algorithm1.network.state_dict(), swad_algorithm2.network.state_dict(), swad_algorithm3.network.state_dict(), swad_algorithm4.network.state_dict())
            swad_algorithm1.network.load_state_dict(inter_state_dict)

        logger.info("Evaluating interpolated model of SWAD models")
        summaries_swadinter = evaluator.evaluate(swad_algorithm1, suffix='_from_swadinter')
        swadinter_results = {**summaries_swadinter}
        logger.info(misc.to_row([swadinter_results[key] for key in list(swadinter_results.keys())]))
        ret["SWAD INTER"] = swadinter_results["test_in_from_swadinter"]
        ret["SWAD INTER (inDom)"] = swadinter_results["comb_val_from_swadinter"]
        # NOTE(review): only the last-computed component average is stored here
        # (e.g. the featurizer one for DANN) — verify this is intended.
        save_dict["SWAD_INTER"] = inter_state_dict

    ckpt_dir = args.out_dir / "checkpoints"
    ckpt_dir.mkdir(exist_ok=True)
    test_env_str = ",".join(map(str, test_envs))
    filename = f"TE{test_env_str}.pth"
    if len(test_envs) > 1 and target_env is not None:
        train_env_str = ",".join(map(str, train_envs))
        filename = f"TE{target_env}_TR{train_env_str}.pth"
    path = ckpt_dir / filename
    if swad1:
        torch.save(save_dict, path)

    for k, acc in ret.items():
        logger.info(f"{k} = {acc:.3%}")

    return ret, records
643 |
644 |
--------------------------------------------------------------------------------
/domainbed/trainer_DN.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import json
3 | import time
4 | import copy
5 | from pathlib import Path
6 |
7 | import numpy as np
8 | import torch
9 | import torch.utils.data
10 |
11 | from domainbed.datasets import get_dataset, split_dataset
12 | from domainbed import algorithms
13 | from domainbed.evaluator import Evaluator
14 | from domainbed.lib import misc
15 | from domainbed.lib import swa_utils
16 | from domainbed.lib.query import Q
17 | from domainbed.lib.fast_data_loader import InfiniteDataLoader, FastDataLoader
18 | from domainbed import swad as swad_module
19 |
20 | # if torch.cuda.is_available():
21 | # device = "cuda"
22 | # else:
23 | # device = "cpu"
24 |
25 |
def json_handler(v):
    """``default=`` hook for ``json.dumps``: render Path/range via ``str``.

    Any other type is rejected with a TypeError, as json.dumps expects
    from its fallback serializer.
    """
    if not isinstance(v, (Path, range)):
        raise TypeError(f"`{type(v)}` is not JSON Serializable")
    return str(v)
30 |
31 |
def interpolate_algos(*state_dicts):
    """Uniformly average several model state dicts, key by key.

    Generalizes the previous fixed 6-argument form: any number (>= 1) of
    state dicts may be passed positionally, and the divisor follows the
    argument count instead of being hard-coded to 6. All dicts are assumed
    to share the key set of the first one (true for state dicts of
    identically-constructed networks).

    Args:
        *state_dicts: mappings from parameter name to tensor (or any
            value supporting `+` and `/`).

    Returns:
        dict: key -> elementwise mean over the given state dicts.
    """
    n = len(state_dicts)
    first = state_dicts[0]
    # sum() starts at 0, which broadcasts fine against tensors/numbers;
    # true division matches the original (integer buffers become floats,
    # exactly as (a+b+...)/6 did before).
    return {key: sum(sd[key] for sd in state_dicts) / n for key in first}
34 |
35 | def train(test_envs, args, hparams, n_steps, checkpoint_freq, logger, writer, target_env=None):
36 | logger.info("")
37 | # n_steps = 1
38 | #######################################################
39 | # setup dataset & loader
40 | #######################################################
41 | args.real_test_envs = test_envs # for log
42 | algorithm_class = algorithms.get_algorithm_class(args.algorithm)
43 | dataset, in_splits, out_splits = get_dataset(test_envs, args, hparams, algorithm_class)
44 | test_splits = []
45 | # if hparams.indomain_test > 0.0:
46 | # logger.info("!!! In-domain test mode On !!!")
47 | # assert hparams["val_augment"] is False, (
48 | # "indomain_test split the val set into val/test sets. "
49 | # "Therefore, the val set should be not augmented."
50 | # )
51 | # val_splits = []
52 | # for env_i, (out_split, _weights) in enumerate(out_splits):
53 | # n = len(out_split) // 2
54 | # seed = misc.seed_hash(args.trial_seed, env_i)
55 | # val_split, test_split = split_dataset(out_split, n, seed=seed)
56 | # val_splits.append((val_split, None))
57 | # test_splits.append((test_split, None))
58 | # logger.info(
59 | # "env %d: out (#%d) -> val (#%d) / test (#%d)"
60 | # % (env_i, len(out_split), len(val_split), len(test_split))
61 | # )
62 | # out_splits = val_splits
63 |
64 | if target_env is not None:
65 | testenv_name = f"te_{dataset.environments[target_env]}"
66 | logger.info(f"Target env = {target_env}")
67 | else:
68 | testenv_properties = [str(dataset.environments[i]) for i in test_envs]
69 | testenv_name = "te_" + "_".join(testenv_properties)
70 |
71 | logger.info(
72 | "Testenv name escaping {} -> {}".format(testenv_name, testenv_name.replace(".", ""))
73 | )
74 | testenv_name = testenv_name.replace(".", "")
75 | logger.info(f"Test envs = {test_envs}, name = {testenv_name}")
76 |
77 | n_envs = len(dataset)
78 | train_envs = sorted(set(range(n_envs)) - set(test_envs))
79 | iterator = misc.SplitIterator(test_envs)
80 | batch_sizes = np.full([n_envs], hparams["batch_size"], dtype=np.int)
81 | batch_sizes50 = np.full([n_envs], int(hparams["batch_size"]*5*0.4), dtype=np.int)
82 | batch_sizes25 = np.full([n_envs], int(hparams["batch_size"]*5*0.15), dtype=np.int)
83 |
84 |
85 | batch_sizes[test_envs] = 0
86 | batch_sizes = batch_sizes.tolist()
87 | batch_sizes50[test_envs] = 0
88 | batch_sizes50 = batch_sizes50.tolist()
89 | batch_sizes25[test_envs] = 0
90 | batch_sizes25 = batch_sizes25.tolist()
91 |
92 | logger.info(f"Batch sizes for CombERM branch: {batch_sizes} (total={sum(batch_sizes)})")
93 | logger.info(f"Own domain Batch sizes for each domain: {batch_sizes50} (total={sum(batch_sizes50)})")
94 | logger.info(f"Other domain Batch sizes for each domain: {batch_sizes25} (total={sum(batch_sizes25)})")
95 |
96 | # calculate steps per epoch
97 | steps_per_epochs = [
98 | len(env) / batch_size for (env, _), batch_size in iterator.train(zip(in_splits, batch_sizes))
99 | ]
100 | steps_per_epochs50 = [
101 | len(env) / batch_size50 for (env, _), batch_size50 in iterator.train(zip(in_splits, batch_sizes50))
102 | ]
103 | steps_per_epochs25 = [
104 | len(env) / batch_size25 for (env, _), batch_size25 in iterator.train(zip(in_splits, batch_sizes25))
105 | ]
106 | steps_per_epoch = min(steps_per_epochs)
107 | steps_per_epoch50 = min(steps_per_epochs50)
108 | steps_per_epoch25 = min(steps_per_epochs25)
109 |
110 | # epoch is computed by steps_per_epoch
111 | prt_steps = ", ".join([f"{step:.2f}" for step in steps_per_epochs])
112 | prt_steps50 = ", ".join([f"{step:.2f}" for step in steps_per_epochs50])
113 | prt_steps25 = ", ".join([f"{step:.2f}" for step in steps_per_epochs25])
114 | logger.info(f"steps-per-epoch for CombERM : {prt_steps} -> min = {steps_per_epoch:.2f}")
115 | logger.info(f"steps-per-epoch for own domain: {prt_steps50} -> min = {steps_per_epoch50:.2f}")
116 | logger.info(f"steps-per-epoch for other domain: {prt_steps25} -> min = {steps_per_epoch25:.2f}")
117 |
118 | # setup loaders
119 | train_loaders = [
120 | InfiniteDataLoader(
121 | dataset=env,
122 | weights=env_weights,
123 | batch_size=batch_size,
124 | num_workers=dataset.N_WORKERS
125 | )
126 | for (env, env_weights), batch_size in iterator.train(zip(in_splits, batch_sizes))
127 | ]
128 | train_loaders50 = [
129 | InfiniteDataLoader(
130 | dataset=env,
131 | weights=env_weights,
132 | batch_size=batch_size50,
133 | num_workers=dataset.N_WORKERS
134 | )
135 | for (env, env_weights), batch_size50 in iterator.train(zip(in_splits, batch_sizes50))
136 | ]
137 | train_loaders25 = [
138 | InfiniteDataLoader(
139 | dataset=env,
140 | weights=env_weights,
141 | batch_size=batch_size25,
142 | num_workers=dataset.N_WORKERS
143 | )
144 | for (env, env_weights), batch_size25 in iterator.train(zip(in_splits, batch_sizes25))
145 | ]
146 |
147 | # setup eval loaders
148 | eval_loaders_kwargs = []
149 | for i, (env, _) in enumerate(in_splits + out_splits + test_splits):
150 | batchsize = hparams["test_batchsize"]
151 | loader_kwargs = {"dataset": env, "batch_size": batchsize, "num_workers": dataset.N_WORKERS}
152 | if args.prebuild_loader:
153 | loader_kwargs = FastDataLoader(**loader_kwargs)
154 | eval_loaders_kwargs.append(loader_kwargs)
155 |
156 | eval_weights = [None for _, weights in (in_splits + out_splits + test_splits)]
157 | eval_loader_names = ["env{}_in".format(i) for i in range(len(in_splits))]
158 | eval_loader_names += ["env{}_out".format(i) for i in range(len(out_splits))]
159 | eval_loader_names += ["env{}_inTE".format(i) for i in range(len(test_splits))]
160 | eval_meta = list(zip(eval_loader_names, eval_loaders_kwargs, eval_weights))
161 |
162 | #######################################################
163 | # setup algorithm (model)
164 | #######################################################
165 | algorithmCE1 = algorithm_class(
166 | dataset.input_shape,
167 | dataset.num_classes,
168 | len(dataset) - len(test_envs),
169 | hparams,
170 | )
171 | algorithmCE2 = algorithm_class(
172 | dataset.input_shape,
173 | dataset.num_classes,
174 | len(dataset) - len(test_envs),
175 | hparams,
176 | )
177 | algorithmCE3 = algorithm_class(
178 | dataset.input_shape,
179 | dataset.num_classes,
180 | len(dataset) - len(test_envs),
181 | hparams,
182 | )
183 | algorithmCE4 = algorithm_class(
184 | dataset.input_shape,
185 | dataset.num_classes,
186 | len(dataset) - len(test_envs),
187 | hparams,
188 | )
189 | algorithmCE5 = algorithm_class(
190 | dataset.input_shape,
191 | dataset.num_classes,
192 | len(dataset) - len(test_envs),
193 | hparams,
194 | )
195 | algorithmCE6 = algorithm_class(
196 | dataset.input_shape,
197 | dataset.num_classes,
198 | len(dataset) - len(test_envs),
199 | hparams,
200 | )
201 |
202 | algorithmCE1.cuda()
203 | algorithmCE2.cuda()
204 | algorithmCE3.cuda()
205 | algorithmCE4.cuda()
206 | algorithmCE5.cuda()
207 | algorithmCE6.cuda()
208 |
209 | n_params = sum([p.numel() for p in algorithmCE1.parameters()])
210 | logger.info("# of params = %d" % n_params)
211 |
212 | train_minibatches_iterator = zip(*train_loaders)
213 | train_minibatches_iterator50 = zip(*train_loaders50)
214 | train_minibatches_iterator25 = zip(*train_loaders25)
215 | # train_minibatches_iterator25b = zip(*train_loaders25b)
216 |
217 | checkpoint_vals = collections.defaultdict(lambda: [])
218 |
219 | #######################################################
220 | # start training loop
221 | #######################################################
222 | evaluator = Evaluator(
223 | test_envs,
224 | eval_meta,
225 | n_envs,
226 | logger,
227 | evalmode=args.evalmode,
228 | debug=args.debug,
229 | target_env=target_env,
230 | )
231 |
232 | # swad = None
233 | # if hparams["swad"]:
234 | # swad_algorithm = swa_utils.AveragedModel(algorithm)
235 | # swad_cls = getattr(swad_module, hparams["swad"])
236 | # swad = swad_cls(evaluator, **hparams.swad_kwargs)
237 |
238 | swad1 = None
239 | if hparams["swad"]:
240 | swad_algorithm1 = swa_utils.AveragedModel(algorithmCE1)
241 | swad_cls1 = getattr(swad_module, "LossValley")
242 | swad1 = swad_cls1(evaluator, **hparams.swad_kwargs)
243 | swad2 = None
244 | if hparams["swad"]:
245 | swad_algorithm2 = swa_utils.AveragedModel(algorithmCE2)
246 | swad_cls2 = getattr(swad_module, "LossValley")
247 | swad2 = swad_cls2(evaluator, **hparams.swad_kwargs)
248 | swad3 = None
249 | if hparams["swad"]:
250 | swad_algorithm3 = swa_utils.AveragedModel(algorithmCE3)
251 | swad_cls3 = getattr(swad_module, "LossValley")
252 | swad3 = swad_cls3(evaluator, **hparams.swad_kwargs)
253 | swad4 = None
254 | if hparams["swad"]:
255 | swad_algorithm4 = swa_utils.AveragedModel(algorithmCE4)
256 | swad_cls4 = getattr(swad_module, "LossValley")
257 | swad4 = swad_cls4(evaluator, **hparams.swad_kwargs)
258 | swad5 = None
259 | if hparams["swad"]:
260 | swad_algorithm5 = swa_utils.AveragedModel(algorithmCE5)
261 | swad_cls5 = getattr(swad_module, "LossValley")
262 | swad5 = swad_cls5(evaluator, **hparams.swad_kwargs)
263 | swad6 = None
264 | if hparams["swad"]:
265 | swad_algorithm6 = swa_utils.AveragedModel(algorithmCE6)
266 | swad_cls6 = getattr(swad_module,"LossValley")
267 | swad6 = swad_cls6(evaluator, **hparams.swad_kwargs)
268 |
269 | last_results_keys = None
270 | records = []
271 | records_inter = []
272 | epochs_path = args.out_dir / "results.jsonl"
273 |
274 | for step in range(n_steps):
275 | step_start_time = time.time()
276 |
277 | # batches_dictlist: [ {x: ,y: }, {x: ,y: }, {x: ,y: } ]
278 | batches_dictlist = next(train_minibatches_iterator)
279 | batches_dictlist50 = next(train_minibatches_iterator50)
280 | batches_dictlist25 = next(train_minibatches_iterator25)
281 |
282 | # batches: {x: [ ,, ] ,y: [ ,, ] }
283 | batchesCE = misc.merge_dictlist(batches_dictlist)
284 | batches1 = {'x': [ batches_dictlist50[0]['x'],batches_dictlist25[1]['x'],batches_dictlist25[2]['x'],batches_dictlist25[3]['x'],batches_dictlist25[4]['x'] ],
285 | 'y': [ batches_dictlist50[0]['y'],batches_dictlist25[1]['y'],batches_dictlist25[2]['y'],batches_dictlist25[3]['y'],batches_dictlist25[4]['y'] ] }
286 | batches2 = {'x': [ batches_dictlist25[0]['x'],batches_dictlist50[1]['x'],batches_dictlist25[2]['x'],batches_dictlist25[3]['x'],batches_dictlist25[4]['x'] ],
287 | 'y': [ batches_dictlist25[0]['y'],batches_dictlist50[1]['y'],batches_dictlist25[2]['y'],batches_dictlist25[3]['y'],batches_dictlist25[4]['y'] ] }
288 | batches3 = {'x': [ batches_dictlist25[0]['x'],batches_dictlist25[1]['x'],batches_dictlist50[2]['x'],batches_dictlist25[3]['x'],batches_dictlist25[4]['x'] ],
289 | 'y': [ batches_dictlist25[0]['y'],batches_dictlist25[1]['y'],batches_dictlist50[2]['y'],batches_dictlist25[3]['y'],batches_dictlist25[4]['y'] ] }
290 | batches4 = {'x': [ batches_dictlist25[0]['x'],batches_dictlist25[1]['x'],batches_dictlist25[2]['x'],batches_dictlist50[3]['x'],batches_dictlist25[4]['x'] ],
291 | 'y': [ batches_dictlist25[0]['y'],batches_dictlist25[1]['y'],batches_dictlist25[2]['y'],batches_dictlist50[3]['y'],batches_dictlist25[4]['y'] ] }
292 | batches5 = {'x': [ batches_dictlist25[0]['x'],batches_dictlist25[1]['x'],batches_dictlist25[2]['x'],batches_dictlist25[3]['x'],batches_dictlist50[4]['x'] ],
293 | 'y': [ batches_dictlist25[0]['y'],batches_dictlist25[1]['y'],batches_dictlist25[2]['y'],batches_dictlist25[3]['y'],batches_dictlist50[4]['y'] ] }
294 |
295 | # to device
296 | batchesCE = {key: [tensor.to(device) for tensor in tensorlist] for key, tensorlist in batchesCE.items()}
297 | batches1 = {key: [tensor.to(device) for tensor in tensorlist] for key, tensorlist in batches1.items()}
298 | batches2 = {key: [tensor.to(device) for tensor in tensorlist] for key, tensorlist in batches2.items()}
299 | batches3 = {key: [tensor.to(device) for tensor in tensorlist] for key, tensorlist in batches3.items()}
300 | batches4 = {key: [tensor.to(device) for tensor in tensorlist] for key, tensorlist in batches4.items()}
301 | batches5 = {key: [tensor.to(device) for tensor in tensorlist] for key, tensorlist in batches5.items()}
302 |
303 | inputsCE = {**batchesCE, "step":step}
304 | inputs1 = {**batches1, "step": step}
305 | inputs2 = {**batches2, "step": step}
306 | inputs3 = {**batches3, "step": step}
307 | inputs4 = {**batches4, "step": step}
308 | inputs5 = {**batches5, "step": step}
309 |
310 | step_valsCE1 = algorithmCE1.update(**inputs1)
311 | step_valsCE2 = algorithmCE2.update(**inputs2)
312 | step_valsCE3 = algorithmCE3.update(**inputs3)
313 | step_valsCE4 = algorithmCE4.update(**inputs4)
314 | step_valsCE5 = algorithmCE5.update(**inputs5)
315 | step_valsCE6 = algorithmCE6.update(**inputsCE)
316 |
317 |
318 | for key, val in step_valsCE1.items():
319 | checkpoint_vals['1_'+key].append(val)
320 | for key, val in step_valsCE2.items():
321 | checkpoint_vals['2_'+key].append(val)
322 | for key, val in step_valsCE3.items():
323 | checkpoint_vals['3_'+key].append(val)
324 | for key, val in step_valsCE4.items():
325 | checkpoint_vals['4_'+key].append(val)
326 | for key, val in step_valsCE5.items():
327 | checkpoint_vals['5_'+key].append(val)
328 | for key, val in step_valsCE6.items():
329 | checkpoint_vals['6_'+key].append(val)
330 | checkpoint_vals["step_time"].append(time.time() - step_start_time)
331 |
332 | if swad1:
333 | # swad_algorithm is segment_swa for swad
334 | swad_algorithm1.update_parameters(algorithmCE1, step=step)
335 | swad_algorithm2.update_parameters(algorithmCE2, step=step)
336 | swad_algorithm3.update_parameters(algorithmCE3, step=step)
337 | swad_algorithm4.update_parameters(algorithmCE4, step=step)
338 | swad_algorithm5.update_parameters(algorithmCE5, step=step)
339 | swad_algorithm6.update_parameters(algorithmCE6, step=step)
340 |
341 | if step % checkpoint_freq == 0:
342 | results = {
343 | "step": step,
344 | "epoch": step / steps_per_epoch,
345 | }
346 |
347 | for key, val in checkpoint_vals.items():
348 | results[key] = np.mean(val)
349 |
350 | eval_start_time = time.time()
351 | summaries1 = evaluator.evaluate(algorithmCE1, suffix='_1')
352 | summaries2 = evaluator.evaluate(algorithmCE2, suffix='_2')
353 | summaries3 = evaluator.evaluate(algorithmCE3, suffix='_3')
354 | summaries4 = evaluator.evaluate(algorithmCE4, suffix='_4')
355 | summaries5 = evaluator.evaluate(algorithmCE5, suffix='_5')
356 | summaries6 = evaluator.evaluate(algorithmCE6, suffix='_6')
357 | results["eval_time"] = time.time() - eval_start_time
358 |
359 | # results = (epochs, loss, step, step_time)
360 | results_keys = list(summaries1.keys()) + list(summaries2.keys()) + list(summaries3.keys()) + list(summaries4.keys()) + list(summaries5.keys()) + list(summaries6.keys()) + list(results.keys())
361 | # merge results
362 | results.update(summaries1)
363 | results.update(summaries2)
364 | results.update(summaries3)
365 | results.update(summaries4)
366 | results.update(summaries5)
367 | results.update(summaries6)
368 |
369 | # print
370 | if results_keys != last_results_keys:
371 | logger.info(misc.to_row(results_keys))
372 | last_results_keys = results_keys
373 | logger.info(misc.to_row([results[key] for key in results_keys]))
374 | records.append(copy.deepcopy(results))
375 |
376 | # update results to record
377 | results.update({"hparams": dict(hparams), "args": vars(args)})
378 |
379 | with open(epochs_path, "a") as f:
380 | f.write(json.dumps(results, sort_keys=True, default=json_handler) + "\n")
381 |
382 | checkpoint_vals = collections.defaultdict(lambda: [])
383 |
384 | writer.add_scalars_with_prefix(summaries1, step, f"{testenv_name}/summary1/")
385 | writer.add_scalars_with_prefix(summaries2, step, f"{testenv_name}/summary2/")
386 | writer.add_scalars_with_prefix(summaries3, step, f"{testenv_name}/summary3/")
387 | writer.add_scalars_with_prefix(summaries4, step, f"{testenv_name}/summary4/")
388 | writer.add_scalars_with_prefix(summaries5, step, f"{testenv_name}/summary5/")
389 | writer.add_scalars_with_prefix(summaries6, step, f"{testenv_name}/summary6/")
390 | # writer.add_scalars_with_prefix(accuracies, step, f"{testenv_name}/all/")
391 |
392 | if args.model_save and step >= args.model_save:
393 | ckpt_dir = args.out_dir / "checkpoints"
394 | ckpt_dir.mkdir(exist_ok=True)
395 |
396 | test_env_str = ",".join(map(str, test_envs))
397 | filename = "TE{}_{}.pth".format(test_env_str, step)
398 | if len(test_envs) > 1 and target_env is not None:
399 | train_env_str = ",".join(map(str, train_envs))
400 | filename = f"TE{target_env}_TR{train_env_str}_{step}.pth"
401 | path = ckpt_dir / filename
402 |
403 | save_dict = {
404 | "args": vars(args),
405 | "model_hparams": dict(hparams),
406 | "test_envs": test_envs,
407 | "model_dict1": algorithmCE1.cpu().state_dict(),
408 | "model_dict2": algorithmCE2.cpu().state_dict(),
409 | "model_dict3": algorithmCE3.cpu().state_dict(),
410 | "model_dict4": algorithmCE4.cpu().state_dict(),
411 | "model_dict5": algorithmCE5.cpu().state_dict(),
412 | "model_dict6": algorithmCE6.cpu().state_dict(),
413 | }
414 | algorithmCE1.cuda()
415 | algorithmCE2.cuda()
416 | algorithmCE3.cuda()
417 | algorithmCE4.cuda()
418 | algorithmCE5.cuda()
419 | algorithmCE6.cuda()
420 | if not args.debug:
421 | torch.save(save_dict, path)
422 | else:
423 | logger.debug("DEBUG Mode -> no save (org path: %s)" % path)
424 |
425 | # swad
426 | if swad1:
427 | def prt_results_fn(results, avgmodel):
428 | step_str = f" [{avgmodel.start_step}-{avgmodel.end_step}]"
429 | row = misc.to_row([results[key] for key in results_keys if key in results])
430 | logger.info(row + step_str)
431 |
432 | swad1.update_and_evaluate(
433 | swad_algorithm1, results["comb_val_1"], results["comb_val_loss_1"], prt_results_fn
434 | )
435 | swad2.update_and_evaluate(
436 | swad_algorithm2, results["comb_val_2"], results["comb_val_loss_2"], prt_results_fn
437 | )
438 | swad3.update_and_evaluate(
439 | swad_algorithm3, results["comb_val_3"], results["comb_val_loss_3"], prt_results_fn
440 | )
441 | swad4.update_and_evaluate(
442 | swad_algorithm4, results["comb_val_4"], results["comb_val_loss_4"], prt_results_fn
443 | )
444 | swad5.update_and_evaluate(
445 | swad_algorithm5, results["comb_val_5"], results["comb_val_loss_5"], prt_results_fn
446 | )
447 | swad6.update_and_evaluate(
448 | swad_algorithm6, results["comb_val_6"], results["comb_val_loss_6"], prt_results_fn
449 | )
450 |
451 | # if hasattr(swad, "dead_valley") and swad.dead_valley:
452 | # logger.info("SWAD valley is dead -> early stop !")
453 | # break
454 | if hasattr(swad1, "dead_valley") and swad1.dead_valley:
455 | logger.info("SWAD valley is dead for 1 -> early stop !")
456 | if hasattr(swad2, "dead_valley") and swad2.dead_valley:
457 | logger.info("SWAD valley is dead for 2 -> early stop !")
458 | if hasattr(swad3, "dead_valley") and swad3.dead_valley:
459 | logger.info("SWAD valley is dead for 3 -> early stop !")
460 | if hasattr(swad4, "dead_valley") and swad4.dead_valley:
461 | logger.info("SWAD valley is dead for 4 -> early stop !")
462 | if hasattr(swad5, "dead_valley") and swad5.dead_valley:
463 | logger.info("SWAD valley is dead for 5 -> early stop !")
464 | if hasattr(swad6, "dead_valley") and swad6.dead_valley:
465 | logger.info("SWAD valley is dead for 6 -> early stop !")
466 |
467 |
468 | if (hparams["model"]=='clip_vit-b16') and (step % 2000 == 0):
469 | swad_algorithm1 = swa_utils.AveragedModel(algorithmCE1) # reset
470 | swad_algorithm2 = swa_utils.AveragedModel(algorithmCE2)
471 | swad_algorithm3 = swa_utils.AveragedModel(algorithmCE3)
472 | swad_algorithm4 = swa_utils.AveragedModel(algorithmCE4)
473 | swad_algorithm5 = swa_utils.AveragedModel(algorithmCE5)
474 | swad_algorithm6 = swa_utils.AveragedModel(algorithmCE6)
475 |
476 | if step % args.tb_freq == 0:
477 | # add step values only for tb log
478 | writer.add_scalars_with_prefix(step_valsCE1, step, f"{testenv_name}/summary1/")
479 | writer.add_scalars_with_prefix(step_valsCE2, step, f"{testenv_name}/summary2/")
480 | writer.add_scalars_with_prefix(step_valsCE3, step, f"{testenv_name}/summary3/")
481 | writer.add_scalars_with_prefix(step_valsCE4, step, f"{testenv_name}/summary4/")
482 | writer.add_scalars_with_prefix(step_valsCE5, step, f"{testenv_name}/summary5/")
483 | writer.add_scalars_with_prefix(step_valsCE6, step, f"{testenv_name}/summary6/")
484 |
485 | if step%args.inter_freq==0 and step!=0:
486 | if args.algorithm in ['DANN', 'CDANN']:
487 | inter_state_dict = interpolate_algos(algorithmCE1.featurizer.state_dict(), algorithCE2.featurizer.state_dict(), algorithmCE3.featurizer.state_dict(), algorithmCE4.featurizer.state_dict(), algorithmCE5.featurizer.state_dict(), algorithCE6.featurizer.state_dict())
488 | algorithmCE1.featurizer.load_state_dict(inter_state_dict)
489 | algorithmCE2.featurizer.load_state_dict(inter_state_dict)
490 | algorithmCE3.featurizer.load_state_dict(inter_state_dict)
491 | algorithmCE4.featurizer.load_state_dict(inter_state_dict)
492 | algorithmCE5.featurizer.load_state_dict(inter_state_dict)
493 | algorithmCE6.featurizer.load_state_dict(inter_state_dict)
494 | inter_state_dict2 = interpolate_algos(algorithmCE1.classifier.state_dict(), algorithCE2.classifier.state_dict(), algorithmCE3.classifier.state_dict(), algorithmCE4.classifier.state_dict(), algorithmCE5.classifier.state_dict(), algorithCE6.classifier.state_dict())
495 | algorithmCE1.classifier.load_state_dict(inter_state_dict)
496 | algorithmCE2.classifier.load_state_dict(inter_state_dict)
497 | algorithmCE3.classifier.load_state_dict(inter_state_dict)
498 | algorithmCE4.classifier.load_state_dict(inter_state_dict)
499 | algorithmCE5.classifier.load_state_dict(inter_state_dict)
500 | algorithmCE6.classifier.load_state_dict(inter_state_dict)
501 | inter_state_dict3 = interpolate_algos(algorithmCE1.discriminator.state_dict(), algorithCE2.discriminator.state_dict(), algorithmCE3.discriminator.state_dict(), algorithmCE4.discriminator.state_dict(), algorithmCE5.discriminator.state_dict(), algorithCE6.discriminator.state_dict())
502 | algorithmCE1.discriminator.load_state_dict(inter_state_dict)
503 | algorithmCE2.discriminator.load_state_dict(inter_state_dict)
504 | algorithmCE3.discriminator.load_state_dict(inter_state_dict)
505 | algorithmCE4.discriminator.load_state_dict(inter_state_dict)
506 | algorithmCE5.discriminator.load_state_dict(inter_state_dict)
507 | algorithmCE6.discriminator.load_state_dict(inter_state_dict)
508 | elif args.algorithm in ['SagNet']:
509 | inter_state_dict = interpolate_algos(algorithmCE1.network_f.state_dict(), algorithmCE2.network_f.state_dict(), algorithmCE3.network_f.state_dict(), algorithmCE4.network_f.state_dict(), algorithmCE5.network_f.state_dict(), algorithmCE6.network_f.state_dict())
510 | algorithmCE1.network_f.load_state_dict(inter_state_dict)
511 | algorithmCE2.network_f.load_state_dict(inter_state_dict)
512 | algorithmCE3.network_f.load_state_dict(inter_state_dict)
513 | algorithmCE4.network_f.load_state_dict(inter_state_dict)
514 | algorithmCE5.network_f.load_state_dict(inter_state_dict)
515 | algorithmCE6.network_f.load_state_dict(inter_state_dict)
516 | inter_state_dict2 = interpolate_algos(algorithmCE1.network_c.state_dict(), algorithmCE2.network_c.state_dict(), algorithmCE3.network_c.state_dict(), algorithmCE4.network_c.state_dict(), algorithmCE5.network_c.state_dict(), algorithmCE6.network_c.state_dict())
517 | algorithmCE1.network_c.load_state_dict(inter_state_dict)
518 | algorithmCE2.network_c.load_state_dict(inter_state_dict)
519 | algorithmCE3.network_c.load_state_dict(inter_state_dict)
520 | algorithmCE4.network_c.load_state_dict(inter_state_dict)
521 | algorithmCE5.network_c.load_state_dict(inter_state_dict)
522 | algorithmCE6.network_c.load_state_dict(inter_state_dict)
523 | inter_state_dict3 = interpolate_algos(algorithmCE1.network_s.state_dict(), algorithmCE2.network_s.state_dict(), algorithmCE3.network_s.state_dict(), algorithmCE4.network_s.state_dict(), algorithmCE5.network_s.state_dict(), algorithmCE6.network_s.state_dict())
524 | algorithmCE1.network_s.load_state_dict(inter_state_dict)
525 | algorithmCE2.network_s.load_state_dict(inter_state_dict)
526 | algorithmCE3.network_s.load_state_dict(inter_state_dict)
527 | algorithmCE4.network_s.load_state_dict(inter_state_dict)
528 | algorithmCE5.network_s.load_state_dict(inter_state_dict)
529 | algorithmCE6.network_s.load_state_dict(inter_state_dict)
530 | else:
531 | inter_state_dict = interpolate_algos(algorithmCE1.network.state_dict(), algorithmCE2.network.state_dict(), algorithmCE3.network.state_dict(), algorithmCE4.network.state_dict(), algorithmCE5.network.state_dict(), algorithmCE6.network.state_dict())
532 | algorithmCE1.network.load_state_dict(inter_state_dict)
533 | algorithmCE2.network.load_state_dict(inter_state_dict)
534 | algorithmCE3.network.load_state_dict(inter_state_dict)
535 | algorithmCE4.network.load_state_dict(inter_state_dict)
536 | algorithmCE5.network.load_state_dict(inter_state_dict)
537 | algorithmCE6.network.load_state_dict(inter_state_dict)
538 |
539 | logger.info(f"Evaluating interpolated model at {step} step")
540 | summaries_inter = evaluator.evaluate(algorithmCE1, suffix='_from_inter')
541 | inter_results = {"inter_step": step, "inter_epoch": step / steps_per_epoch}
542 | inter_results_keys = list(summaries_inter.keys()) + list(inter_results.keys())
543 | inter_results.update(summaries_inter)
544 | logger.info(misc.to_row([inter_results[key] for key in inter_results_keys]))
545 | records_inter.append(copy.deepcopy(inter_results))
546 | writer.add_scalars_with_prefix(summaries_inter, step, f"{testenv_name}/summary_inter/")
547 |
548 | # find best
549 | logger.info("---")
550 | # print(records)
551 | records = Q(records)
552 | records_inter = Q(records_inter)
553 |
554 | # print(len(records))
555 | # print(records)
556 |
557 | # 1
558 | oracle_best1 = records.argmax("test_out_1")["test_in_1"]
559 | iid_best1 = records.argmax("comb_val_1")["test_in_1"]
560 | inDom1 = records.argmax("comb_val_1")["comb_val_1"]
561 | # own_best1 = records.argmax("own_val_from_first")["test_in_from_first"]
562 | last1 = records[-1]["test_in_1"]
563 | # 2
564 | oracle_best2 = records.argmax("test_out_2")["test_in_2"]
565 | iid_best2 = records.argmax("comb_val_2")["test_in_2"]
566 | inDom2 = records.argmax("comb_val_2")["comb_val_2"]
567 | # own_best2 = records.argmax("own_val_from_second")["test_in_from_second"]
568 | last2 = records[-1]["test_in_2"]
569 | # 3
570 | oracle_best3 = records.argmax("test_out_3")["test_in_3"]
571 | iid_best3 = records.argmax("comb_val_3")["test_in_3"]
572 | inDom3 = records.argmax("comb_val_3")["comb_val_3"]
573 | # own_best3 = records.argmax("own_val_from_third")["test_in_from_third"]
574 | last3 = records[-1]["test_in_3"]
575 | # CE
576 | oracle_best4 = records.argmax("test_out_4")["test_in_4"]
577 | iid_best4 = records.argmax("comb_val_4")["test_in_4"]
578 | inDom4 = records.argmax("comb_val_4")["comb_val_4"]
579 | last4 = records[-1]["test_in_4"]
580 |
581 | oracle_best5 = records.argmax("test_out_5")["test_in_5"]
582 | iid_best5 = records.argmax("comb_val_5")["test_in_5"]
583 | inDom5 = records.argmax("comb_val_5")["comb_val_5"]
584 | last5 = records[-1]["test_in_5"]
585 |
586 | oracle_best6 = records.argmax("test_out_6")["test_in_6"]
587 | iid_best6 = records.argmax("comb_val_6")["test_in_6"]
588 | inDom6 = records.argmax("comb_val_6")["comb_val_6"]
589 | last6 = records[-1]["test_in_6"]
590 | # inter
591 | oracle_best_inter = records_inter.argmax("test_out_from_inter")["test_in_from_inter"]
592 | iid_best_inter = records_inter.argmax("comb_val_from_inter")["test_in_from_inter"]
593 | inDom_inter = records_inter.argmax("comb_val_from_inter")["comb_val_from_inter"]
594 |
595 | # if hparams.indomain_test:
596 | # # if test set exist, use test set for indomain results
597 | # in_key = "train_inTE"
598 | # else:
599 | # in_key = "train_out"
600 |
601 | # iid_best_indomain = records.argmax("train_out")[in_key]
602 | # last_indomain = records[-1][in_key]
603 |
604 | ret = {
605 | "oracle_1": oracle_best1,
606 | "iid_1": iid_best1,
607 | # "own_1": own_best1,
608 | "inDom1": inDom1,
609 | "last_1": last1,
610 | "oracle_2": oracle_best2,
611 | "iid_2": iid_best2,
612 | # "own_2": own_best2,
613 | "inDom2":inDom2,
614 | "last_2": last2,
615 | "oracle_3": oracle_best3,
616 | "iid_3": iid_best3,
617 | # "own_3": own_best3,
618 | "inDom3":inDom3,
619 | "last_3": last3,
620 | "oracle_4": oracle_best4,
621 | "iid_4": iid_best4,
622 | "inDom4": inDom4,
623 | "last_4": last4,
624 |
625 | "oracle_5": oracle_best5,
626 | "iid_5": iid_best5,
627 | "inDom5": inDom5,
628 | "last_5": last5,
629 |
630 | "oracle_6": oracle_best6,
631 | "iid_6": iid_best6,
632 | "inDom6": inDom6,
633 | "last_6": last6,
634 | # "last (inD)": last_indomain,
635 | # "iid (inD)": iid_best_indomain,
636 | "oracle_inter": oracle_best_inter,
637 | "iid_inter": iid_best_inter,
638 | "inDom_inter":inDom_inter,
639 | }
640 |
641 | # Evaluate SWAD
642 | if swad1:
643 | swad_algorithm1 = swad1.get_final_model()
644 | swad_algorithm2 = swad2.get_final_model()
645 | swad_algorithm3 = swad3.get_final_model()
646 | swad_algorithm4 = swad4.get_final_model()
647 | swad_algorithm5 = swad5.get_final_model()
648 | swad_algorithm6 = swad6.get_final_model()
649 | if hparams["freeze_bn"] is False:
650 | n_steps = 500 if not args.debug else 10
651 | logger.warning(f"Update SWAD BN statistics for {n_steps} steps ...")
652 | swa_utils.update_bn(train_minibatches_iterator, swad_algorithm1, n_steps)
653 | swa_utils.update_bn(train_minibatches_iterator, swad_algorithm2, n_steps)
654 | swa_utils.update_bn(train_minibatches_iterator, swad_algorithm3, n_steps)
655 | swa_utils.update_bn(train_minibatches_iterator, swad_algorithm4, n_steps)
656 | swa_utils.update_bn(train_minibatches_iterator, swad_algorithm5, n_steps)
657 | swa_utils.update_bn(train_minibatches_iterator, swad_algorithm6, n_steps)
658 |
659 | logger.warning("Evaluate SWAD ...")
660 | summaries_swad1 = evaluator.evaluate(swad_algorithm1, suffix='_s1')
661 | summaries_swad2 = evaluator.evaluate(swad_algorithm2, suffix='_s2')
662 | summaries_swad3 = evaluator.evaluate(swad_algorithm3, suffix='_s3')
663 | summaries_swad4 = evaluator.evaluate(swad_algorithm4, suffix='_s4')
664 | summaries_swad5 = evaluator.evaluate(swad_algorithm5, suffix='_s5')
665 | summaries_swad6 = evaluator.evaluate(swad_algorithm6, suffix='_s6')
666 | # accuracies, summaries = evaluator.evaluate(swad_algorithm)
667 |
668 | # results = {**summaries, **accuracies}
669 | # start = swad_algorithm.start_step
670 | # end = swad_algorithm.end_step
671 | # step_str = f" [{start}-{end}] (N={swad_algorithm.n_averaged})"
672 | # row = misc.to_row([results[key] for key in results_keys if key in results]) + step_str
673 | # logger.info(row)
674 |
675 | swad_results = {**summaries_swad1, **summaries_swad2, **summaries_swad3, **summaries_swad4, **summaries_swad5, **summaries_swad6}
676 | step_str = f" [{swad_algorithm1.start_step}-{swad_algorithm1.end_step}] (N={swad_algorithm1.n_averaged}) || [{swad_algorithm2.start_step}-{swad_algorithm2.end_step}] (N={swad_algorithm2.n_averaged}) || [{swad_algorithm3.start_step}-{swad_algorithm3.end_step}] (N={swad_algorithm3.n_averaged}) || [{swad_algorithm4.start_step}-{swad_algorithm4.end_step}] (N={swad_algorithm4.n_averaged}) || [{swad_algorithm5.start_step}-{swad_algorithm5.end_step}] (N={swad_algorithm5.n_averaged}) || [{swad_algorithm6.start_step}-{swad_algorithm6.end_step}] (N={swad_algorithm6.n_averaged})"
677 | row = misc.to_row([swad_results[key] for key in list(swad_results.keys())]) + step_str
678 | logger.info(row)
679 |
680 | ret["SWAD 1"] = swad_results["test_in_s1"]
681 | ret["SWAD 1 (inDom)"] = swad_results["comb_val_s1"]
682 | ret["SWAD 2"] = swad_results["test_in_s2"]
683 | ret["SWAD 2 (inDom)"] = swad_results["comb_val_s2"]
684 | ret["SWAD 3"] = swad_results["test_in_s3"]
685 | ret["SWAD 3 (inDom)"] = swad_results["comb_val_s3"]
686 | ret["SWAD 4"] = swad_results["test_in_s4"]
687 | ret["SWAD 4 (inDom)"] = swad_results["comb_val_s4"]
688 | ret["SWAD 5"] = swad_results["test_in_s5"]
689 | ret["SWAD 5 (inDom)"] = swad_results["comb_val_s5"]
690 | ret["SWAD 6"] = swad_results["test_in_s6"]
691 | ret["SWAD 6 (inDom)"] = swad_results["comb_val_s6"]
692 |
693 | save_dict = {
694 | "args": vars(args),
695 | "model_hparams": dict(hparams),
696 | "test_envs": test_envs,
697 | "SWAD_1": swad_algorithm1.network.state_dict(),
698 | "SWAD_2": swad_algorithm2.network.state_dict(),
699 | "SWAD_3": swad_algorithm3.network.state_dict(),
700 | "SWAD_4": swad_algorithm4.network.state_dict(),
701 | "SWAD_5": swad_algorithm5.network.state_dict(),
702 | "SWAD_6": swad_algorithm6.network.state_dict(),
703 | }
704 |
705 | if args.algorithm in ['DANN', 'CDANN']:
706 | inter_state_dict = interpolate_algos(swad_algorithm1.module.featurizer.state_dict(), swad_algorithm2.module.featurizer.state_dict(), swad_algorithm3.module.featurizer.state_dict(), swad_algorithm4.module.featurizer.state_dict(), swad_algorithm5.module.featurizer.state_dict(), swad_algorithm6.module.featurizer.state_dict())
707 | swad_algorithm1.module.featurizer.load_state_dict(inter_state_dict)
708 | inter_state_dict2 = interpolate_algos(swad_algorithm1.module.classifier.state_dict(), swad_algorithm2.module.classifier.state_dict(), swad_algorithm3.module.classifier.state_dict(), swad_algorithm4.module.classifier.state_dict(), swad_algorithm5.module.classifier.state_dict(), swad_algorithm6.module.classifier.state_dict())
709 | swad_algorithm1.module.classifier.load_state_dict(inter_state_dict2)
710 | inter_state_dict3 = interpolate_algos(swad_algorithm1.module.discriminator.state_dict(), swad_algorithm2.module.discriminator.state_dict(), swad_algorithm3.module.discriminator.state_dict(), swad_algorithm4.module.discriminator.state_dict(), swad_algorithm5.module.discriminator.state_dict(), swad_algorithm6.module.discriminator.state_dict())
711 | swad_algorithm1.module.discriminator.load_state_dict(inter_state_dict3)
712 |
713 | elif args.algorithm in ['SagNet']:
714 | inter_state_dict = interpolate_algos(swad_algorithm1.module.network_f.state_dict(), swad_algorithm2.module.network_f.state_dict(), swad_algorithm3.module.network_f.state_dict(), swad_algorithm4.module.network_f.state_dict(), swad_algorithm5.module.network_f.state_dict(), swad_algorithm6.module.network_f.state_dict())
715 | swad_algorithm1.module.network_f.load_state_dict(inter_state_dict)
716 | inter_state_dict2 = interpolate_algos(swad_algorithm1.module.network_c.state_dict(), swad_algorithm2.module.network_c.state_dict(), swad_algorithm3.module.network_c.state_dict(), swad_algorithm4.module.network_c.state_dict(), swad_algorithm5.module.network_c.state_dict(), swad_algorithm6.module.network_c.state_dict())
717 | swad_algorithm1.module.network_c.load_state_dict(inter_state_dict2)
718 | inter_state_dict3 = interpolate_algos(swad_algorithm1.module.network_s.state_dict(), swad_algorithm2.module.network_s.state_dict(), swad_algorithm3.module.network_s.state_dict(), swad_algorithm4.module.network_s.state_dict(), swad_algorithm5.module.network_s.state_dict(), swad_algorithm6.module.network_s.state_dict())
719 | swad_algorithm1.module.network_s.load_state_dict(inter_state_dict3)
720 | else:
721 | inter_state_dict = interpolate_algos(swad_algorithm1.network.state_dict(), swad_algorithm2.network.state_dict(), swad_algorithm3.network.state_dict(), swad_algorithm4.network.state_dict(), swad_algorithm5.network.state_dict(), swad_algorithm6.network.state_dict())
722 | swad_algorithm1.network.load_state_dict(inter_state_dict)
723 |
724 | logger.info(f"Evaluating interpolated model of SWAD models")
725 | summaries_swadinter = evaluator.evaluate(swad_algorithm1, suffix='_from_swadinter')
726 | swadinter_results = {**summaries_swadinter}
727 | logger.info(misc.to_row([swadinter_results[key] for key in list(swadinter_results.keys())]))
728 | ret["SWAD INTER"] = swadinter_results["test_in_from_swadinter"]
729 | ret["SWAD INTER (inDom)"] = swadinter_results["comb_val_from_swadinter"]
730 | save_dict["SWAD_INTER"] = inter_state_dict
731 |
732 |
733 | ckpt_dir = args.out_dir / "checkpoints"
734 | ckpt_dir.mkdir(exist_ok=True)
735 | test_env_str = ",".join(map(str, test_envs))
736 | filename = f"TE{test_env_str}.pth"
737 | if len(test_envs) > 1 and target_env is not None:
738 | train_env_str = ",".join(map(str, train_envs))
739 | filename = f"TE{target_env}_TR{train_env_str}.pth"
740 | path = ckpt_dir / filename
741 | if not args.debug:
742 | torch.save(save_dict, path)
743 |
744 |
745 | for k, acc in ret.items():
746 | logger.info(f"{k} = {acc:.3%}")
747 |
748 | return ret, records
749 |
750 |
--------------------------------------------------------------------------------
/media/DART_pic.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/DART/62274e83d2f08eb416db61d0957476c53fde9361/media/DART_pic.png
--------------------------------------------------------------------------------
/media/DG_combined_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/DART/62274e83d2f08eb416db61d0957476c53fde9361/media/DG_combined_results.png
--------------------------------------------------------------------------------
/media/DG_main_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/DART/62274e83d2f08eb416db61d0957476c53fde9361/media/DG_main_results.png
--------------------------------------------------------------------------------
/media/ID_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/DART/62274e83d2f08eb416db61d0957476c53fde9361/media/ID_results.png
--------------------------------------------------------------------------------
/media/model_optimization_trajectory.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/DART/62274e83d2f08eb416db61d0957476c53fde9361/media/model_optimization_trajectory.gif
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | gdown==4.2.0
2 | numpy==1.21.4
3 | Pillow==9.0.1
4 | prettytable==2.1.0
5 | sconf==0.2.3
6 | tensorboardX==2.5
7 | torch==1.7.1
8 | torchvision==0.8.2
9 | git+https://github.com/openai/CLIP.git
10 |
--------------------------------------------------------------------------------
/train_all.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import collections
3 | import random
4 | import sys
5 | from pathlib import Path
6 |
7 | import numpy as np
8 | import PIL
9 | import torch
10 | import torchvision
11 | from sconf import Config
12 | from prettytable import PrettyTable
13 |
14 | from domainbed.datasets import get_dataset
15 | from domainbed import hparams_registry
16 | from domainbed.lib import misc
17 | from domainbed.lib.writers import get_writer
18 | from domainbed.lib.logger import Logger
19 | from domainbed.trainer import train
20 | from domainbed.trainer_DN import train as train_dn
21 |
22 |
def main():
    """Entry point: parse CLI args, configure logging/seeding, and train.

    For each target test environment, dispatches to the DomainNet-specific
    trainer (``train_dn``) or the generic trainer (``train``), collects the
    per-env results, and finally logs a PrettyTable summary with a row per
    model-selection method and a trailing average column.
    """
    parser = argparse.ArgumentParser(description="Domain generalization", allow_abbrev=False)
    parser.add_argument("name", type=str)
    parser.add_argument("configs", nargs="*")
    parser.add_argument("--data_dir", type=str, default="datadir/")
    parser.add_argument("--dataset", type=str, default="PACS")
    parser.add_argument("--algorithm", type=str, default="ERM")
    parser.add_argument(
        "--trial_seed",
        type=int,
        default=0,
        help="Trial number (used for seeding split_dataset and random_hparams).",
    )
    parser.add_argument("--seed", type=int, default=0, help="Seed for everything else")
    parser.add_argument(
        "--steps", type=int, default=None, help="Number of steps. Default is dataset-dependent."
    )
    parser.add_argument(
        "--checkpoint_freq",
        type=int,
        default=None,
        help="Checkpoint every N steps. Default is dataset-dependent.",
    )
    parser.add_argument("--test_envs", type=int, nargs="+", default=None)
    parser.add_argument("--holdout_fraction", type=float, default=0.2)
    parser.add_argument("--model_save", default=None, type=int, help="Model save start step")
    # BUG FIX: --tb_freq previously had no type=, so a value given on the
    # command line arrived as str while the default was int.
    parser.add_argument("--tb_freq", type=int, default=10)
    parser.add_argument("--debug", action="store_true", help="Run w/ debug mode")
    parser.add_argument("--show", action="store_true", help="Show args and hparams w/o run")
    parser.add_argument(
        "--evalmode",
        default="fast",
        help="[fast, all]. if fast, ignore train_in datasets in evaluation time.",
    )
    parser.add_argument("--prebuild_loader", action="store_true", help="Pre-build eval loaders")
    parser.add_argument("--inter_freq", type=int, default=600, help="interpolate after inter_freq steps")
    args, left_argv = parser.parse_known_args()
    # Determinism is always enforced (the old --deterministic flag is gone).
    args.deterministic = True

    # setup hparams: registry defaults, overridden by config files, then by
    # any leftover "key=value" CLI tokens via argv_update.
    hparams = hparams_registry.default_hparams(args.algorithm, args.dataset)

    # BUG FIX: config file handles were previously opened and never closed.
    config_files = [open(path, encoding="utf8") for path in ["config.yaml"] + args.configs]
    try:
        hparams = Config(*config_files, default=hparams)
    finally:
        for fp in config_files:
            fp.close()
    hparams.argv_update(left_argv)

    # setup debug: shrink the run so a full pipeline pass finishes quickly.
    if args.debug:
        args.checkpoint_freq = 5
        args.steps = 10
        args.name += "_debug"

    timestamp = misc.timestamp()
    args.unique_name = f"{timestamp}_{args.name}"

    # path setup
    args.work_dir = Path(".")
    args.data_dir = Path(args.data_dir)

    args.out_root = args.work_dir / Path("train_output") / args.dataset
    args.out_dir = args.out_root / args.unique_name
    args.out_dir.mkdir(exist_ok=True, parents=True)

    writer = get_writer(args.out_root / "runs" / args.unique_name)
    logger = Logger.get(args.out_dir / "log.txt")
    if args.debug:
        logger.setLevel("DEBUG")
    cmd = " ".join(sys.argv)
    logger.info(f"Command :: {cmd}")

    # Log library versions so runs can be reproduced later.
    logger.nofmt("Environment:")
    logger.nofmt("\tPython: {}".format(sys.version.split(" ")[0]))
    logger.nofmt("\tPyTorch: {}".format(torch.__version__))
    logger.nofmt("\tTorchvision: {}".format(torchvision.__version__))
    logger.nofmt("\tCUDA: {}".format(torch.version.cuda))
    logger.nofmt("\tCUDNN: {}".format(torch.backends.cudnn.version()))
    logger.nofmt("\tNumPy: {}".format(np.__version__))
    logger.nofmt("\tPIL: {}".format(PIL.__version__))

    # Different to DomainBed, we support CUDA only.
    assert torch.cuda.is_available(), "CUDA is not available"

    logger.nofmt("Args:")
    for k, v in sorted(vars(args).items()):
        logger.nofmt("\t{}: {}".format(k, v))

    logger.nofmt("HParams:")
    for line in hparams.dumps().split("\n"):
        logger.nofmt("\t" + line)

    if args.show:
        # Dry run: configuration has been printed above, stop here.
        # IDIOM FIX: sys.exit() instead of the interactive-only exit().
        sys.exit()

    # seed everything for reproducibility
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = args.deterministic
    torch.backends.cudnn.benchmark = not args.deterministic

    # Dummy datasets for logging information.
    # Real dataset will be re-assigned in train function.
    # test_envs only decide transforms; simply set to zero.
    dataset, _in_splits, _out_splits = get_dataset([0], args, hparams)

    # print dataset information
    logger.nofmt("Dataset:")
    logger.nofmt(f"\t[{args.dataset}] #envs={len(dataset)}, #classes={dataset.num_classes}")
    for i, env_property in enumerate(dataset.environments):
        logger.nofmt(f"\tenv{i}: {env_property} (#{len(dataset[i])})")
    logger.nofmt("")

    n_steps = args.steps or dataset.N_STEPS
    checkpoint_freq = args.checkpoint_freq or dataset.CHECKPOINT_FREQ
    logger.info(f"n_steps = {n_steps}")
    logger.info(f"checkpoint_freq = {checkpoint_freq}")

    # Round n_steps to a checkpoint boundary (+1 so the final step is
    # evaluated) to guarantee the last checkpoint actually happens.
    org_n_steps = n_steps
    n_steps = (n_steps // checkpoint_freq) * checkpoint_freq + 1
    logger.info(f"n_steps is updated to {org_n_steps} => {n_steps} for checkpointing")

    # Default: leave-one-domain-out over every environment.
    if not args.test_envs:
        args.test_envs = [[te] for te in range(len(dataset))]
    logger.info(f"Target test envs = {args.test_envs}")

    ###########################################################################
    # Run
    ###########################################################################
    all_records = []
    results = collections.defaultdict(list)

    for test_env in args.test_envs:
        # DomainNet has a dedicated trainer; everything else shares one.
        if args.dataset == "DomainNet":
            print("===== DN ======")
            trainer_fn = train_dn
        else:
            print("===== others ======")
            trainer_fn = train
        res, records = trainer_fn(
            test_env,
            args=args,
            hparams=hparams,
            n_steps=n_steps,
            checkpoint_freq=checkpoint_freq,
            logger=logger,
            writer=writer,
        )
        all_records.append(records)
        for k, v in res.items():
            results[k].append(v)

    # log summary table
    logger.info("=== Summary ===")
    logger.info(f"Command: {' '.join(sys.argv)}")
    logger.info("Unique name: %s" % args.unique_name)
    logger.info("Out path: %s" % args.out_dir)
    logger.info("Algorithm: %s" % args.algorithm)
    logger.info("Dataset: %s" % args.dataset)

    table = PrettyTable(["Selection"] + dataset.environments + ["Avg."])
    for key, row in results.items():
        # Append the mean BEFORE formatting so the average column exists.
        row.append(np.mean(row))
        row = [f"{acc:.3%}" for acc in row]
        table.add_row([key] + row)
    logger.nofmt(table)
197 |
198 |
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()
201 |
--------------------------------------------------------------------------------