├── fig
│   ├── overview.jpg
│   └── results_PACS.png
├── Replay
│   ├── Finetune.py
│   ├── alg.py
│   ├── utils.py
│   └── iCaRL.py
├── modelopera.py
├── network
│   ├── Adver_network.py
│   ├── util.py
│   ├── common_network.py
│   └── img_network.py
├── datautil
│   ├── mydataloader.py
│   ├── util.py
│   ├── imgdata
│   │   ├── imgdataload.py
│   │   └── util.py
│   └── getdataloader.py
├── opt.py
├── README.md
├── scripts
│   ├── PACS.sh
│   ├── subdomain_net.sh
│   └── dg5.sh
├── RandMix.py
├── utils
│   ├── util.py
│   └── visual.py
├── PCA.py
├── main.py
├── train.py
├── arguments.py
├── RaTP.py
└── pLabel.py

/fig/overview.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SonyResearch/RaTP/HEAD/fig/overview.jpg
--------------------------------------------------------------------------------
/fig/results_PACS.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SonyResearch/RaTP/HEAD/fig/results_PACS.png
--------------------------------------------------------------------------------
/Replay/Finetune.py:
--------------------------------------------------------------------------------
1 | class Finetune:
2 |     def __init__(self, args):
3 |         self.comment = 'Do nothing'
4 |
5 |     def update_dataloader(self, dataloader=None):
6 |         return None
7 |
8 |     def update(self, model, task_id, dataloader):
9 |         pass
--------------------------------------------------------------------------------
/Replay/alg.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | from Replay.iCaRL import iCaRL
3 | from Replay.Finetune import Finetune
4 | # from Replay.LDAuCID_buff import LDAuCID_buff
5 |
6 | ALGORITHMS = [
7 |     'iCaRL',
8 |     'Finetune'
9 | ]
10 |
11 |
12 | def get_algorithm_class(algorithm_name):
13 |     if algorithm_name not in globals():
14 |         raise NotImplementedError(
15 |             "Algorithm not found: {}".format(algorithm_name))
16 |     return globals()[algorithm_name]
--------------------------------------------------------------------------------
/modelopera.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | import torch
3 | from network import img_network
4 |
5 |
6 | def get_fea(args):
7 |     if args.net.startswith('vgg'):
8 |         net = img_network.VGGBase(args)
9 |     elif args.net == 'LeNet':
10 |         net = img_network.LeNetBase()
11 |     elif args.net == 'DTN':
12 |         net = img_network.DTNBase()
13 |     elif args.net.startswith('res'):
14 |         net = img_network.ResBase(args)
15 |     else:
16 |         net = img_network.VGGBase(args)
17 |     return net
18 |
19 |
20 | def accuracy(network, loader):
21 |     correct = 0
22 |     total = 0
23 |
24 |     network.eval()
25 |     with torch.no_grad():
26 |         for data in loader:
27 |             x = data[0].cuda().float()
28 |             y = data[1].cuda().long()
29 |             p = network(x)
30 |
31 |             if p.size(1) == 1:
32 |                 correct += (p.gt(0).eq(y).float()).sum().item()
33 |             else:
34 |                 correct += (p.argmax(1).eq(y).float()).sum().item()
35 |             total += len(x)
36 |     network.train()
37 |     return correct / total
--------------------------------------------------------------------------------
/network/Adver_network.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd import Function
4 |
5 |
6 | class ReverseLayerF(Function):
7 |     @staticmethod
8 |     def forward(ctx, x, alpha):
9 |         ctx.alpha = alpha
10 |         return x.view_as(x)
11 |
12 |     @staticmethod
13 |     def backward(ctx, grad_output):
14 |         output = grad_output.neg() * ctx.alpha
15 |         return output, None
16 |
17 |
18 | class Discriminator(nn.Module):
19 |     def __init__(self, input_dim=256, hidden_dim=256, num_domains=4):
20 |         super(Discriminator, self).__init__()
21 |         self.input_dim = input_dim
22 |         self.hidden_dim = hidden_dim
23 |         layers = [
24 |             nn.Linear(input_dim, hidden_dim),
25 |             nn.BatchNorm1d(hidden_dim),
26 |             nn.ReLU(),
27 |             nn.Linear(hidden_dim, hidden_dim),
28 |             nn.BatchNorm1d(hidden_dim),
29 |             nn.ReLU(),
30 |             nn.Linear(hidden_dim, num_domains),
31 |         ]
32 |         self.layers = torch.nn.Sequential(*layers)
33 |
34 |     def forward(self, x):
35 |         return self.layers(x)
--------------------------------------------------------------------------------
/network/util.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | import torch.nn as nn
3 | import numpy as np
4 |
5 |
6 | def calc_coeff(iter_num, high=1.0, low=0.0, alpha=10.0, max_iter=10000.0):
7 |     return float(2.0 * (high - low) / (1.0 + np.exp(-alpha*iter_num / max_iter)) - (high - low) + low)  # builtin float; np.float is deprecated and removed in NumPy >= 1.24
8 |
9 |
10 | def init_weights(m):
11 |     classname = m.__class__.__name__
12 |     if classname.find('Conv2d') != -1 or classname.find('ConvTranspose2d') != -1:
13 |         nn.init.kaiming_uniform_(m.weight)
14 |         nn.init.zeros_(m.bias)
15 |     elif classname.find('BatchNorm') != -1:
16 |         nn.init.normal_(m.weight, 1.0, 0.02)
17 |         nn.init.zeros_(m.bias)
18 |     elif classname.find('Linear') != -1:
19 |         nn.init.xavier_normal_(m.weight)
20 |         nn.init.zeros_(m.bias)
21 |
22 | def freeze_proxy(model):
23 |     '''
24 |     freeze the PCL proxy and classifier weights in the adaptation step.
25 |     '''
26 |     model.fc_proj.requires_grad_(False)  # requires_grad_() freezes the underlying parameters; plain `.requires_grad = False` on an nn.Module only sets an unused attribute
27 |     model.classifier.requires_grad_(False)
28 |
29 | def freeze_classifier(model):
30 |     '''
31 |     freeze the classifier of the model in the adaptation step.
32 | ''' 33 | for k, v in model.classifier.named_parameters(): 34 | v.requires_grad = False 35 | 36 | 37 | 38 | -------------------------------------------------------------------------------- /datautil/mydataloader.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | import torch 4 | 5 | 6 | class _InfiniteSampler(torch.utils.data.Sampler): 7 | """Wraps another Sampler to yield an infinite stream.""" 8 | 9 | def __init__(self, sampler): 10 | self.sampler = sampler 11 | 12 | def __iter__(self): 13 | while True: 14 | for batch in self.sampler: 15 | yield batch 16 | 17 | 18 | class InfiniteDataLoader: 19 | def __init__(self, dataset, weights, batch_size, num_workers): 20 | super().__init__() 21 | 22 | self.dataset = dataset 23 | 24 | if weights: 25 | sampler = torch.utils.data.WeightedRandomSampler(weights, 26 | replacement=True, 27 | num_samples=batch_size) 28 | else: 29 | sampler = torch.utils.data.RandomSampler(dataset, 30 | replacement=True) 31 | 32 | if weights == None: 33 | weights = torch.ones(len(dataset)) 34 | 35 | batch_sampler = torch.utils.data.BatchSampler( 36 | sampler, 37 | batch_size=batch_size, 38 | drop_last=True) 39 | 40 | self._infinite_iterator = iter(torch.utils.data.DataLoader( 41 | dataset, 42 | num_workers=num_workers, 43 | batch_sampler=_InfiniteSampler(batch_sampler) 44 | )) 45 | 46 | def __iter__(self): 47 | while True: 48 | yield next(self._infinite_iterator) 49 | 50 | def __len__(self): 51 | raise ValueError 52 | -------------------------------------------------------------------------------- /Replay/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | import torch 4 | from torch.utils.data import Dataset, DataLoader 5 | 6 | import datautil.imgdata.util as imgutil 7 | from datautil.mydataloader import InfiniteDataLoader 8 | 9 | class ExemplarDataset(Dataset): 10 | ''' 11 | Used for compute_class_mean 12 | input: imgs should be PIL image. 13 | ''' 14 | def __init__(self, imgs, transform): 15 | self.imgs = imgs 16 | self.transform = transform 17 | def __len__(self): 18 | return len(self.imgs) 19 | def __getitem__(self, index): 20 | return self.transform(self.imgs[index]) 21 | # return self.transform(Image.fromarray(self.imgs[index])) 22 | 23 | class ReplayDataset(Dataset): 24 | ''' 25 | construct replay dataset 26 | input: imgs should be PIL image. 27 | ''' 28 | def __init__(self, images, class_labels, domain_labels, transform=None, target_transform=None): 29 | self.images = images 30 | self.labels = class_labels 31 | self.dlabels = domain_labels 32 | self.transform = transform 33 | 34 | def __len__(self): 35 | return len(self.labels) 36 | 37 | def __getitem__(self, index): 38 | imgs = self.transform(self.images[index]) if self.transform is not None else self.images[index] 39 | return imgs, self.labels[index], self.dlabels[index] 40 | 41 | def get_raw_data(self): 42 | return self.images, self.labels, self.dlabels 43 | 44 | def concat_list(data_list): 45 | ''' 46 | flatten list 47 | input: list of list [[..], .., [..]] 48 | return list [..] 
49 | ''' 50 | datas = [] 51 | for l in data_list: 52 | for i in l: 53 | datas.append(i) 54 | return datas 55 | -------------------------------------------------------------------------------- /datautil/util.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import numpy as np 3 | import torch 4 | 5 | 6 | def Nmax(test_envs, d): 7 | for i in range(len(test_envs)): 8 | if d < test_envs[i]: 9 | return i 10 | return len(test_envs) 11 | 12 | 13 | def random_pairs_of_minibatches_by_domainperm(minibatches): 14 | perm = torch.randperm(len(minibatches)).tolist() 15 | pairs = [] 16 | 17 | for i in range(len(minibatches)): 18 | j = i + 1 if i < (len(minibatches) - 1) else 0 19 | 20 | xi, yi = minibatches[perm[i]][0], minibatches[perm[i]][1] 21 | xj, yj = minibatches[perm[j]][0], minibatches[perm[j]][1] 22 | 23 | min_n = min(len(xi), len(xj)) 24 | 25 | pairs.append(((xi[:min_n], yi[:min_n]), (xj[:min_n], yj[:min_n]))) 26 | 27 | return pairs 28 | 29 | 30 | def random_pairs_of_minibatches(args, minibatches): 31 | ld = len(minibatches) 32 | pairs = [] 33 | tdlist = np.arange(ld) 34 | txlist = np.arange(args.batch_size) 35 | for i in range(ld): 36 | for j in range(args.batch_size): 37 | (tdi, tdj), (txi, txj) = np.random.choice(tdlist, 2, 38 | replace=False), np.random.choice(txlist, 2, replace=True) 39 | if j == 0: 40 | xi, yi, di = torch.unsqueeze( 41 | minibatches[tdi][0][txi], dim=0), minibatches[tdi][1][txi], minibatches[tdi][2][txi] 42 | xj, yj, dj = torch.unsqueeze( 43 | minibatches[tdj][0][txj], dim=0), minibatches[tdj][1][txj], minibatches[tdj][2][txj] 44 | else: 45 | xi, yi, di = torch.vstack((xi, torch.unsqueeze(minibatches[tdi][0][txi], dim=0))), torch.hstack( 46 | (yi, minibatches[tdi][1][txi])), torch.hstack((di, minibatches[tdi][2][txi])) 47 | xj, yj, dj = torch.vstack((xj, torch.unsqueeze(minibatches[tdj][0][txj], dim=0))), torch.hstack( 48 | (yj, minibatches[tdj][1][txj])), torch.hstack((dj, minibatches[tdj][2][txj])) 49 | pairs.append(((xi, yi, di), (xj, yj, dj))) 50 | return pairs 51 | -------------------------------------------------------------------------------- /network/common_network.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import torch.nn as nn 3 | from network.util import init_weights 4 | import torch.nn.utils.weight_norm as weightNorm 5 | 6 | def feat_encoder(args, in_dim, out_dim): 7 | hidden_size = 256 if args.dataset == 'dg5' else 512 8 | encoder = nn.Sequential( 9 | nn.Linear(in_dim, hidden_size), 10 | nn.BatchNorm1d(hidden_size), 11 | nn.ReLU(inplace=True), 12 | nn.Linear(hidden_size, out_dim) 13 | ) 14 | return encoder 15 | 16 | class feat_bottleneck(nn.Module): 17 | def __init__(self, feature_dim, bottleneck_dim=256, type="ori"): 18 | super(feat_bottleneck, self).__init__() 19 | self.bn = nn.BatchNorm1d(bottleneck_dim, affine=True) 20 | self.relu = nn.ReLU(inplace=True) 21 | self.dropout = nn.Dropout(p=0.5) 22 | self.bottleneck = nn.Linear(feature_dim, bottleneck_dim) 23 | # self.bottleneck.apply(init_weights) 24 | self.type = type 25 | 26 | def forward(self, x): 27 | x = self.bottleneck(x) 28 | if self.type == "bn": 29 | x = self.bn(x) 30 | return x 31 | 32 | 33 | class feat_classifier(nn.Module): 34 | def __init__(self, class_num, bottleneck_dim=256, type="linear"): 35 | super(feat_classifier, self).__init__() 36 | self.type = type 37 | if type == 'wn': 38 | self.fc = weightNorm( 39 | nn.Linear(bottleneck_dim, class_num), name="weight") 40 | # 
self.fc.apply(init_weights)
41 |         else:
42 |             self.fc = nn.Linear(bottleneck_dim, class_num)
43 |             # self.fc.apply(init_weights)
44 |
45 |     def forward(self, x):
46 |         x = self.fc(x)
47 |         return x
48 |
49 |
50 | class feat_classifier_two(nn.Module):
51 |     def __init__(self, class_num, input_dim, bottleneck_dim=256):
52 |         super(feat_classifier_two, self).__init__()
53 |         self.type = type
54 |         self.fc0 = nn.Linear(input_dim, bottleneck_dim)
55 |         # self.fc0.apply(init_weights)
56 |         self.fc1 = nn.Linear(bottleneck_dim, class_num)
57 |         # self.fc1.apply(init_weights)
58 |
59 |     def forward(self, x):
60 |         x = self.fc0(x)
61 |         x = self.fc1(x)
62 |         return x
63 |
--------------------------------------------------------------------------------
/opt.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | import torch
3 |
4 |
5 | def get_params(alg, args, alg_name, inner=False, alias=True):
6 |     if args.schuse:
7 |         if args.schusech == 'cos':
8 |             initlr = args.lr
9 |         else:
10 |             initlr = 1.0
11 |     else:
12 |         if inner:
13 |             initlr = args.inner_lr
14 |         else:
15 |             initlr = args.lr
16 |     if inner:
17 |         params = [
18 |             {'params': alg[0].parameters(), 'lr': args.lr_decay1 *
19 |              initlr},
20 |             {'params': alg[1].parameters(), 'lr': args.lr_decay2 *
21 |              initlr}
22 |         ]
23 |     elif alias:
24 |         params = [
25 |             {'params': alg.featurizer.parameters(), 'lr': args.lr_decay1 * initlr},
26 |             {'params': alg.classifier.parameters(), 'lr': args.lr_decay2 * initlr}
27 |         ]
28 |     else:
29 |         params = [
30 |             {'params': alg[0].parameters(), 'lr': args.lr_decay1 * initlr},
31 |             {'params': alg[1].parameters(), 'lr': args.lr_decay2 * initlr}
32 |         ]
33 |     if ('DANN' in alg_name) or ('CDANN' in alg_name):
34 |         params.append({'params': alg.discriminator.parameters(),
35 |                        'lr': args.lr_decay2 * initlr})
36 |     if ('CDANN' in alg_name):
37 |         params.append({'params': alg.class_embeddings.parameters(),
38 |                        'lr': args.lr_decay2 * initlr})
39 |     return params
40 |
41 |
42 | def get_optimizer(alg, args, inner=False, alias=True):
43 |     params = get_params(alg, args, args.DGalgorithm, inner, alias)
44 |     optimizer = torch.optim.SGD(
45 |         params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)
46 |     return optimizer
47 |
48 |
49 | def get_scheduler(optimizer, args):
50 |     if not args.schuse:
51 |         return None
52 |     if args.schusech == 'cos':
53 |         scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
54 |             optimizer, args.max_epoch * args.steps_per_epoch)
55 |     else:
56 |         scheduler = torch.optim.lr_scheduler.LambdaLR(
57 |             optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))
58 |     return scheduler
59 |
60 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Deja Vu: Continual Model Generalization for Unseen Domains
2 | Official Implementation for ICLR 2023 paper: [Deja Vu: Continual Model Generalization for Unseen Domains](https://arxiv.org/pdf/2301.10418.pdf)
3 |
4 | ![Overview](./fig/overview.jpg)
5 | RaTP first starts with a labeled source domain, applies RandMix on the full set of source data to generate augmentation data, and uses a simplified version of PCA for model optimization. Then, for continually arriving target domains, RaTP uses T2PL to generate pseudo labels for all unlabeled samples, applies RandMix on a top subset of these samples based on their softmax confidence, and optimizes the model by PCA.
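For orientation, the per-task flow above can be sketched in a few lines. This is a minimal sketch, not the actual implementation: `t2pl_pseudo_label` and `pca_loss` are hypothetical stand-ins for the real logic in `pLabel.py` and `PCA.py`, and `args`/`train_loaders` come from `arguments.py` and `datautil`; see `main.py` and `train.py` for the real loop.
```
import torch
from RaTP import RaTP
from RandMix import RandMix

model = RaTP(args).cuda()               # featurizer + classifier + PCL proxies
randmix = RandMix(noise_lv=1.0).cuda()  # random augmentation module

for task_id, loader in enumerate(train_loaders):  # continually arriving domains
    for x, y, _ in loader:
        x = x.cuda()
        if task_id == 0:
            # Labeled source domain: RandMix on the full batch.
            x_train = torch.cat([x, randmix(x)])
            y_train = y.cuda().repeat(2)
        else:
            # Unlabeled target domain: T2PL pseudo labels first, then RandMix
            # only on samples whose softmax confidence exceeds aug_tau.
            y_train, conf = t2pl_pseudo_label(model, x)  # hypothetical helper
            top = conf > args.aug_tau
            x_train = torch.cat([x, randmix(x[top])])
            y_train = torch.cat([y_train, y_train[top]])
        loss = pca_loss(model, x_train, y_train)         # hypothetical helper
        model.optimizer.zero_grad()
        loss.backward()
        model.optimizer.step()
```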
6 |
7 | # Dependencies:
8 | pytorch==1.11.0
9 | torchvision==0.12.0
10 | numpy==1.20.3
11 | scikit-learn==0.24.2
12 |
13 | # Datasets
14 | Download **Digit-Five** and **PACS** from https://github.com/jindongwang/transferlearning/tree/master/code/DeepDG. Rename them as `dg5` and `PACS` and place them in `./Dataset`.
15 | Download the subset of **DomainNet** used in our paper from https://drive.google.com/file/d/1LDnU3el-nHoqTgnvxEZP_PxdbdBapNKP/view?usp=sharing, and place it in `./Dataset`.
16 |
17 | # Usage
18 | ## Quick Start
19 | After installing all dependency packages, you can use the following command to run the code on PACS:
20 | ```
21 | python main.py --gpu $gpu_id --order 2 0 1 3 --seed 2022 \
22 | --aug_tau 0.5 --topk_alpha 20 --lr 0.005 --MPCL_alpha 0.5 \
23 | --output result_mainpaper
24 | ```
25 | ## Reproduce the experiment results in the paper
26 | Please run the bash files in `scripts` one after another with the following commands.
27 | ```
28 | cd ./scripts
29 | bash PACS.sh && bash dg5.sh && bash subdomain_net.sh
30 | ```
31 |
32 | # Performance
33 | The visualization of results will be saved in `result_*` (you can customize the directory name) after training. The following table reports the experiment results of running RaTP on PACS.
34 | ![Results](./fig/results_PACS.png)
35 |
36 | # Citation
37 | ```
38 | @article{liu2023deja,
39 |   title={DEJA VU: Continual Model Generalization For Unseen Domains},
40 |   author={Liu, Chenxi and Wang, Lixu and Lyu, Lingjuan and Sun, Chen and Wang, Xiao and Zhu, Qi},
41 |   journal={arXiv preprint arXiv:2301.10418},
42 |   year={2023}
43 | }
44 | ```
45 |
46 | ## Contact
47 |
48 | If you have any questions regarding the code, please feel free to contact Lixu Wang (lixuwang2025@u.northwestern.edu) or Lingjuan Lyu (Lingjuan.Lv@sony.com).
49 |
50 | ###### Copyright 2023, Sony AI, Sony Corporation of America, All rights reserved.
51 | -------------------------------------------------------------------------------- /scripts/PACS.sh: -------------------------------------------------------------------------------- 1 | gpu_id=2 2 | 3 | # --------------------------- main paper order ----------------------------- # 4 | python main.py --gpu $gpu_id --order 2 0 1 3 --seed 2022 \ 5 | --aug_tau 0.5 --topk_alpha 20 --lr 0.005 --MPCL_alpha 0.5 \ 6 | --output result_mainpaper 7 | 8 | python main.py --gpu $gpu_id --order 2 0 1 3 --seed 2023 \ 9 | --aug_tau 0.5 --topk_alpha 20 --lr 0.005 --MPCL_alpha 0.5 \ 10 | --output result_mainpaper 11 | 12 | python main.py --gpu $gpu_id --order 2 0 1 3 --seed 2024 \ 13 | --aug_tau 0.5 --topk_alpha 20 --lr 0.005 --MPCL_alpha 0.5 \ 14 | --output result_mainpaper 15 | 16 | 17 | python main.py --gpu $gpu_id --order 3 1 0 2 --seed 2022 \ 18 | --aug_tau 0.5 --topk_alpha 20 --lr 0.005 --MPCL_alpha 0.5 \ 19 | --output result_mainpaper 20 | 21 | python main.py --gpu $gpu_id --order 3 1 0 2 --seed 2023 \ 22 | --aug_tau 0.5 --topk_alpha 20 --lr 0.005 --MPCL_alpha 0.5 \ 23 | --output result_mainpaper 24 | 25 | python main.py --gpu $gpu_id --order 3 1 0 2 --seed 2024 \ 26 | --aug_tau 0.5 --topk_alpha 20 --lr 0.005 --MPCL_alpha 0.5 \ 27 | --output result_mainpaper 28 | 29 | 30 | # --------------------------- additional order ----------------------------- # 31 | 32 | python main.py --gpu $gpu_id --order 0 1 2 3 --seed 2022 \ 33 | --aug_tau 0.5 --topk_alpha 20 --lr 0.005 --MPCL_alpha 0.5 \ 34 | --output result_additional_order 35 | 36 | python main.py --gpu $gpu_id --order 0 1 3 2 --seed 2022 \ 37 | --aug_tau 0.5 --topk_alpha 20 --lr 0.005 --MPCL_alpha 0.5 \ 38 | --output result_additional_order 39 | 40 | python main.py --gpu $gpu_id --order 0 2 1 3 --seed 2022 \ 41 | --aug_tau 0.5 --topk_alpha 20 --lr 0.005 --MPCL_alpha 0.5 \ 42 | --output result_additional_order 43 | 44 | python main.py --gpu $gpu_id --order 1 0 3 2 --seed 2022 \ 45 | --aug_tau 0.5 --topk_alpha 20 --lr 0.005 --MPCL_alpha 0.5 \ 46 | --output result_additional_order 47 | 48 | python main.py --gpu $gpu_id --order 1 3 2 0 --seed 2022 \ 49 | --aug_tau 0.5 --topk_alpha 20 --lr 0.005 --MPCL_alpha 0.5 \ 50 | --output result_additional_order 51 | 52 | python main.py --gpu $gpu_id --order 2 3 0 1 --seed 2022 \ 53 | --aug_tau 0.5 --topk_alpha 20 --lr 0.005 --MPCL_alpha 0.5 \ 54 | --output result_additional_order 55 | 56 | python main.py --gpu $gpu_id --order 2 3 1 0 --seed 2022 \ 57 | --aug_tau 0.5 --topk_alpha 20 --lr 0.005 --MPCL_alpha 0.5 \ 58 | --output result_additional_order 59 | 60 | python main.py --gpu $gpu_id --order 3 2 1 0 --seed 2022 \ 61 | --aug_tau 0.5 --topk_alpha 20 --lr 0.005 --MPCL_alpha 0.5 \ 62 | --output result_additional_order -------------------------------------------------------------------------------- /scripts/subdomain_net.sh: -------------------------------------------------------------------------------- 1 | gpu_id=5 2 | 3 | # --------------------------- main paper order ----------------------------- # 4 | python main.py --gpu $gpu_id --dataset subdomain_net --order 3 5 0 1 2 4 --seed 2022 \ 5 | --aug_tau 0.8 --topk_alpha 10 --topk_beta 1 --lr 0.005 --MPCL_alpha 1.0 \ 6 | --output result_mainpaper 7 | 8 | python main.py --gpu $gpu_id --dataset subdomain_net --order 3 5 0 1 2 4 --seed 2023 \ 9 | --aug_tau 0.8 --topk_alpha 10 --topk_beta 1 --lr 0.005 --MPCL_alpha 1.0 \ 10 | --output result_mainpaper 11 | 12 | python main.py --gpu $gpu_id --dataset subdomain_net --order 3 5 0 1 2 4 --seed 2024 \ 13 | --aug_tau 0.8 
--topk_alpha 10 --topk_beta 1 --lr 0.005 --MPCL_alpha 1.0 \ 14 | --output result_mainpaper 15 | 16 | 17 | # --------------------------- additional order ----------------------------- # 18 | python main.py --gpu $gpu_id --dataset subdomain_net --order 4 2 1 0 5 3 --seed 2022 \ 19 | --aug_tau 0.8 --topk_alpha 10 --topk_beta 1 --lr 0.005 --MPCL_alpha 1.0 \ 20 | --output result_additional_order 21 | 22 | python main.py --gpu $gpu_id --dataset subdomain_net --order 0 1 2 3 4 5 --seed 2022 \ 23 | --aug_tau 0.8 --topk_alpha 10 --topk_beta 1 --lr 0.005 --MPCL_alpha 1.0 \ 24 | --output result_additional_order 25 | 26 | python main.py --gpu $gpu_id --dataset subdomain_net --order 5 4 3 2 1 0 --seed 2022 \ 27 | --aug_tau 0.8 --topk_alpha 10 --topk_beta 1 --lr 0.005 --MPCL_alpha 1.0 \ 28 | --output result_additional_order 29 | 30 | python main.py --gpu $gpu_id --dataset subdomain_net --order 2 5 3 1 4 0 --seed 2022 \ 31 | --aug_tau 0.8 --topk_alpha 10 --topk_beta 1 --lr 0.005 --MPCL_alpha 1.0 \ 32 | --output result_additional_order 33 | 34 | python main.py --gpu $gpu_id --dataset subdomain_net --order 0 4 1 3 5 2 --seed 2022 \ 35 | --aug_tau 0.8 --topk_alpha 10 --topk_beta 1 --lr 0.005 --MPCL_alpha 1.0 \ 36 | --output result_additional_order 37 | 38 | python main.py --gpu $gpu_id --dataset subdomain_net --order 3 4 0 2 1 5 --seed 2022 \ 39 | --aug_tau 0.8 --topk_alpha 10 --topk_beta 1 --lr 0.005 --MPCL_alpha 1.0 \ 40 | --output result_additional_order 41 | 42 | python main.py --gpu $gpu_id --dataset subdomain_net --order 5 1 2 0 4 3 --seed 2022 \ 43 | --aug_tau 0.8 --topk_alpha 10 --topk_beta 1 --lr 0.005 --MPCL_alpha 1.0 \ 44 | --output result_additional_order 45 | 46 | python main.py --gpu $gpu_id --dataset subdomain_net --order 1 3 0 2 4 5 --seed 2022 \ 47 | --aug_tau 0.8 --topk_alpha 10 --topk_beta 1 --lr 0.005 --MPCL_alpha 1.0 \ 48 | --output result_additional_order 49 | 50 | python main.py --gpu $gpu_id --dataset subdomain_net --order 5 4 2 0 3 1 --seed 2022 \ 51 | --aug_tau 0.8 --topk_alpha 10 --topk_beta 1 --lr 0.005 --MPCL_alpha 1.0 \ 52 | --output result_additional_order 53 | -------------------------------------------------------------------------------- /RandMix.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torchvision import transforms 5 | 6 | class AdaIN2d(nn.Module): 7 | def __init__(self, style_dim, num_features): 8 | super().__init__() 9 | self.norm = nn.InstanceNorm2d(num_features, affine=False) 10 | self.fc = nn.Linear(style_dim, num_features*2) 11 | 12 | def forward(self, x, s): 13 | h = self.fc(s) 14 | h = h.view(h.size(0), h.size(1), 1, 1) 15 | gamma, beta = torch.chunk(h, chunks=2, dim=1) 16 | return (1 + gamma) * self.norm(x) + beta 17 | 18 | 19 | class RandMix(nn.Module): 20 | def __init__(self, noise_lv): 21 | super(RandMix, self).__init__() 22 | ############# Trainable Parameters 23 | self.zdim = zdim = 10 24 | self.noise_lv = noise_lv 25 | self.adain_1 = AdaIN2d(zdim, 3) 26 | self.adain_2 = AdaIN2d(zdim, 3) 27 | self.adain_3 = AdaIN2d(zdim, 3) 28 | self.adain_4 = AdaIN2d(zdim, 3) 29 | 30 | 31 | self.tran = transforms.Normalize([0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 32 | 33 | def forward(self, x, estimation=False, ratio=0): 34 | data = x 35 | 36 | spatial1 = nn.Conv2d(3, 3, 5).cuda() 37 | spatial_up1 = nn.ConvTranspose2d(3, 3, 5).cuda() 38 | 39 | spatial2 = nn.Conv2d(3, 3, 9).cuda() 40 | spatial_up2 = nn.ConvTranspose2d(3, 3, 9).cuda() 41 
| 42 | spatial3 = nn.Conv2d(3, 3, 13).cuda() 43 | spatial_up3 = nn.ConvTranspose2d(3, 3, 13).cuda() 44 | 45 | spatial4 = nn.Conv2d(3, 3, 17).cuda() 46 | spatial_up4 = nn.ConvTranspose2d(3, 3, 17).cuda() 47 | 48 | color = nn.Conv2d(3, 3, 1).cuda() 49 | weight = torch.randn(6) 50 | 51 | x = x + torch.randn_like(x) * self.noise_lv * 0.001 52 | x_c = torch.tanh(F.dropout(color(x), p=.2)) 53 | 54 | x_s1down = spatial1(x) 55 | s = torch.randn(len(x_s1down), self.zdim).cuda() 56 | x_s1down = self.adain_1(x_s1down, s) 57 | x_s1 = torch.tanh(spatial_up1(x_s1down)) 58 | 59 | x_s2down = spatial2(x) 60 | s = torch.randn(len(x_s2down), self.zdim).cuda() 61 | x_s2down = self.adain_2(x_s2down, s) 62 | x_s2 = torch.tanh(spatial_up2(x_s2down)) 63 | 64 | x_s3down = spatial3(x) 65 | s = torch.randn(len(x_s3down), self.zdim).cuda() 66 | x_s3down = self.adain_3(x_s3down, s) 67 | x_s3 = torch.tanh(spatial_up3(x_s3down)) 68 | 69 | x_s4down = spatial4(x) 70 | s = torch.randn(len(x_s4down), self.zdim).cuda() 71 | x_s4down = self.adain_4(x_s4down, s) 72 | x_s4 = torch.tanh(spatial_up4(x_s4down)) 73 | 74 | output = (weight[0] * x_c + weight[1] * x_s1 + weight[2] * x_s2 + weight[3] * x_s3 + weight[4] * x_s4 + weight[5] * data) / weight.sum() 75 | return output 76 | -------------------------------------------------------------------------------- /datautil/imgdata/imgdataload.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from torch.utils.data import Dataset 3 | import numpy as np 4 | from datautil.util import Nmax 5 | from datautil.imgdata.util import rgb_loader, l_loader 6 | from torchvision.datasets import ImageFolder 7 | from torchvision.datasets.folder import default_loader 8 | import torch 9 | 10 | 11 | class ImageDataset(object): 12 | def __init__(self, args, task, root_dir, domain_name, domain_label=-1, labels=None, transform=None, target_transform=None, indices=None, test_envs=[], mode='Default'): 13 | self.args = args 14 | self.imgs = ImageFolder(root_dir+domain_name).imgs 15 | self.domain_num = 0 16 | self.task = task 17 | self.dataset = args.dataset 18 | imgs = [item[0] for item in self.imgs] 19 | labels = [item[1] for item in self.imgs] 20 | self.labels = np.array(labels) # np.array 21 | self.x = imgs # list of file dir 22 | self.transform = transform 23 | self.target_transform = target_transform 24 | if indices is None: 25 | self.indices = np.arange(len(imgs)) 26 | else: 27 | self.indices = indices 28 | if mode == 'Default': 29 | self.loader = default_loader 30 | elif mode == 'RGB': 31 | self.loader = rgb_loader 32 | elif mode == 'L': 33 | self.loader = l_loader 34 | self.dlabels = np.ones(self.labels.shape) * \ 35 | (domain_label-Nmax(test_envs, domain_label)) # np.array 36 | 37 | def set_labels(self, tlabels=None, label_type='domain_label'): 38 | assert len(tlabels) == len(self.x) 39 | if label_type == 'domain_label': 40 | self.dlabels = tlabels 41 | elif label_type == 'class_label': 42 | self.labels = tlabels 43 | 44 | def target_trans(self, y): 45 | if self.target_transform is not None: 46 | return self.target_transform(y) 47 | else: 48 | return y 49 | 50 | def input_trans(self, x): 51 | if self.transform is not None: 52 | return self.transform(x) 53 | else: 54 | return x 55 | 56 | def __getitem__(self, index): 57 | index = self.indices[index] 58 | img = self.input_trans(self.loader(self.x[index])) 59 | ctarget = self.target_trans(self.labels[index]) 60 | dtarget = self.target_trans(self.dlabels[index]) 61 | return img, ctarget, dtarget 62 | 
63 | def __len__(self): 64 | return len(self.indices) 65 | 66 | def get_raw_data(self): 67 | img_dict = [] 68 | clabel = [] 69 | dlabel = [] 70 | for i in self.indices: 71 | img_dict.append(self.x[i]) 72 | clabel.append(self.labels[i]) 73 | dlabel.append(self.dlabels[i]) 74 | return img_dict, np.array(clabel), np.array(dlabel) 75 | -------------------------------------------------------------------------------- /scripts/dg5.sh: -------------------------------------------------------------------------------- 1 | gpu_id=3 2 | 3 | # --------------------------- main paper order ----------------------------- # 4 | python main.py --gpu $gpu_id --dataset dg5 --order 0 1 2 3 4 --net DTN --seed 2022 \ 5 | --aug_tau 0.8 --topk_alpha 20 --pseudo_fre 2 --lr 0.01 --MPCL_alpha 1 \ 6 | --output result_mainpaper 7 | 8 | python main.py --gpu $gpu_id --dataset dg5 --order 0 1 2 3 4 --net DTN --seed 2023 \ 9 | --aug_tau 0.8 --topk_alpha 20 --pseudo_fre 2 --lr 0.01 --MPCL_alpha 1 \ 10 | --output result_mainpaper 11 | 12 | python main.py --gpu $gpu_id --dataset dg5 --order 0 1 2 3 4 --net DTN --seed 2024 \ 13 | --aug_tau 0.8 --topk_alpha 20 --pseudo_fre 2 --lr 0.01 --MPCL_alpha 1 \ 14 | --output result_mainpaper 15 | 16 | 17 | python main.py --gpu $gpu_id --dataset dg5 --order 4 3 2 1 0 --net DTN --seed 2022 \ 18 | --aug_tau 0.8 --topk_alpha 20 --pseudo_fre 2 --lr 0.01 --MPCL_alpha 1 \ 19 | --output result_mainpaper 20 | 21 | python main.py --gpu $gpu_id --dataset dg5 --order 4 3 2 1 0 --net DTN --seed 2023 \ 22 | --aug_tau 0.8 --topk_alpha 20 --pseudo_fre 2 --lr 0.01 --MPCL_alpha 1 \ 23 | --output result_mainpaper 24 | 25 | python main.py --gpu $gpu_id --dataset dg5 --order 4 3 2 1 0 --net DTN --seed 2024 \ 26 | --aug_tau 0.8 --topk_alpha 20 --pseudo_fre 2 --lr 0.01 --MPCL_alpha 1 \ 27 | --output result_mainpaper 28 | 29 | 30 | # --------------------------- additional order ----------------------------- # 31 | python main.py --gpu $gpu_id --dataset dg5 --order 0 1 4 2 3 --net DTN --seed 2022 \ 32 | --aug_tau 0.8 --topk_alpha 20 --pseudo_fre 2 --lr 0.01 --MPCL_alpha 1 \ 33 | --output result_additional_order 34 | 35 | python main.py --gpu $gpu_id --dataset dg5 --order 1 4 0 3 2 --net DTN --seed 2022 \ 36 | --aug_tau 0.8 --topk_alpha 20 --pseudo_fre 2 --lr 0.01 --MPCL_alpha 1 \ 37 | --output result_additional_order 38 | 39 | python main.py --gpu $gpu_id --dataset dg5 --order 2 0 1 3 4 --net DTN --seed 2022 \ 40 | --aug_tau 0.8 --topk_alpha 20 --pseudo_fre 2 --lr 0.01 --MPCL_alpha 1 \ 41 | --output result_additional_order 42 | 43 | python main.py --gpu $gpu_id --dataset dg5 --order 2 3 0 4 1 --net DTN --seed 2022 \ 44 | --aug_tau 0.8 --topk_alpha 20 --pseudo_fre 2 --lr 0.01 --MPCL_alpha 1 \ 45 | --output result_additional_order 46 | 47 | python main.py --gpu $gpu_id --dataset dg5 --order 3 1 2 0 4 --net DTN --seed 2022 \ 48 | --aug_tau 0.8 --topk_alpha 20 --pseudo_fre 2 --lr 0.01 --MPCL_alpha 1 \ 49 | --output result_additional_order 50 | 51 | python main.py --gpu $gpu_id --dataset dg5 --order 3 2 4 1 0 --net DTN --seed 2022 \ 52 | --aug_tau 0.8 --topk_alpha 20 --pseudo_fre 2 --lr 0.01 --MPCL_alpha 1 \ 53 | --output result_additional_order 54 | 55 | python main.py --gpu $gpu_id --dataset dg5 --order 3 4 1 2 0 --net DTN --seed 2022 \ 56 | --aug_tau 0.8 --topk_alpha 20 --pseudo_fre 2 --lr 0.01 --MPCL_alpha 1 \ 57 | --output result_additional_order 58 | 59 | python main.py --gpu $gpu_id --dataset dg5 --order 4 0 2 1 3 --net DTN --seed 2022 \ 60 | --aug_tau 0.8 --topk_alpha 20 --pseudo_fre 2 --lr 0.01 --MPCL_alpha 1 
\ 61 | --output result_additional_order 62 | -------------------------------------------------------------------------------- /datautil/imgdata/util.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from torchvision import transforms 3 | from PIL import Image, ImageFile 4 | ImageFile.LOAD_TRUNCATED_IMAGES = True 5 | import torch 6 | 7 | def image_train_source(args, resize_size=256, crop_size=224): 8 | if args.dataset == 'dg5': 9 | return transforms.Compose([ 10 | transforms.Resize((32, 32)), 11 | transforms.ToTensor(), 12 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 13 | ]) 14 | 15 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 16 | std=[0.229, 0.224, 0.225]) 17 | 18 | transform = transforms.Compose([ 19 | transforms.RandomResizedCrop(224, scale=(0.7, 1.0)), 20 | transforms.RandomHorizontalFlip(), 21 | transforms.ColorJitter(0.3, 0.3, 0.3, 0.3), 22 | transforms.RandomGrayscale(), 23 | transforms.ToTensor(), 24 | normalize 25 | ]) 26 | 27 | return transform 28 | 29 | def image_train(args, resize_size=256, crop_size=224): 30 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 31 | std=[0.229, 0.224, 0.225]) 32 | if args.dataset == 'dg5': 33 | return transforms.Compose([ 34 | transforms.Resize((32, 32)), 35 | transforms.ToTensor(), 36 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 37 | ]) 38 | 39 | elif "domain_net" in args.dataset: 40 | transform = transforms.Compose([ 41 | transforms.RandomResizedCrop(224, scale=(0.7, 1.0)), 42 | transforms.RandomHorizontalFlip(), 43 | # transforms.ColorJitter(0.3, 0.3, 0.3, 0.3), 44 | # transforms.RandomGrayscale(), 45 | transforms.ToTensor(), 46 | normalize 47 | ]) 48 | 49 | else: 50 | transform = transforms.Compose([ 51 | transforms.RandomResizedCrop(224, scale=(0.7, 1.0)), 52 | transforms.RandomHorizontalFlip(), 53 | transforms.ColorJitter(0.3, 0.3, 0.3, 0.3), 54 | transforms.RandomGrayscale(), 55 | transforms.ToTensor(), 56 | normalize 57 | ]) 58 | 59 | return transform 60 | 61 | 62 | def image_test(args, resize_size=256, crop_size=224): 63 | if args.dataset == 'dg5': 64 | return transforms.Compose([ 65 | transforms.Resize((32, 32)), 66 | transforms.ToTensor(), 67 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 68 | ]) 69 | 70 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 71 | std=[0.229, 0.224, 0.225]) 72 | return transforms.Compose([ 73 | transforms.Resize((224, 224)), 74 | transforms.ToTensor(), 75 | normalize 76 | ]) 77 | 78 | def rgb_loader(path): 79 | with open(path, 'rb') as f: 80 | with Image.open(f) as img: 81 | return img.convert('RGB') 82 | 83 | 84 | def l_loader(path): 85 | with open(path, 'rb') as f: 86 | with Image.open(f) as img: 87 | return img.convert('L') 88 | -------------------------------------------------------------------------------- /utils/util.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import random 3 | import numpy as np 4 | import torch 5 | import sys 6 | import os 7 | import torchvision 8 | import PIL 9 | 10 | 11 | def set_random_seed(seed=0): 12 | # seed setting 13 | random.seed(seed) 14 | np.random.seed(seed) 15 | torch.manual_seed(seed) 16 | torch.cuda.manual_seed(seed) 17 | torch.backends.cudnn.deterministic = True 18 | torch.backends.cudnn.benchmark = False 19 | 20 | 21 | def save_checkpoint(filename, alg, args): 22 | save_dict = { 23 | "args": vars(args), 24 | "model_dict": alg.cpu().state_dict() 25 | } 26 | 
torch.save(save_dict, filename) 27 | 28 | 29 | def train_valid_target_eval_names(args): 30 | eval_name_dict = {'train': [], 'valid': [], 'target': []} 31 | t = 0 32 | for i in range(args.domain_num): 33 | if i not in args.test_envs: 34 | eval_name_dict['train'].append(t) 35 | t += 1 36 | for i in range(args.domain_num): 37 | if i not in args.test_envs: 38 | eval_name_dict['valid'].append(t) 39 | else: 40 | eval_name_dict['target'].append(t) 41 | t += 1 42 | return eval_name_dict 43 | 44 | 45 | def alg_loss_dict(args): 46 | loss_dict = {'ANDMask': ['total'], 47 | 'CORAL': ['class', 'coral', 'total'], 48 | 'DANN': ['class', 'dis', 'total'], 49 | 'ERM': ['class'], 50 | 'Mixup': ['class'], 51 | 'MLDG': ['total'], 52 | 'MMD': ['class', 'mmd', 'total'], 53 | 'GroupDRO': ['group'], 54 | 'RSC': ['class'], 55 | 'VREx': ['loss', 'nll', 'penalty'] 56 | } 57 | return loss_dict[args.DGalgorithm] 58 | 59 | 60 | def print_args(args, print_list): 61 | s = "==========================================\n" 62 | l = len(print_list) 63 | for arg, content in args.__dict__.items(): 64 | if l == 0 or arg in print_list: 65 | s += "{}:{}\n".format(arg, content) 66 | return s 67 | 68 | 69 | def print_environ(): 70 | print("Environment:") 71 | print("\tPython: {}".format(sys.version.split(" ")[0])) 72 | print("\tPyTorch: {}".format(torch.__version__)) 73 | print("\tTorchvision: {}".format(torchvision.__version__)) 74 | print("\tCUDA: {}".format(torch.version.cuda)) 75 | print("\tCUDNN: {}".format(torch.backends.cudnn.version())) 76 | print("\tNumPy: {}".format(np.__version__)) 77 | print("\tPIL: {}".format(PIL.__version__)) 78 | 79 | class Tee: 80 | def __init__(self, fname, mode="a"): 81 | self.stdout = sys.stdout 82 | self.file = open(fname, mode) 83 | 84 | def write(self, message): 85 | self.stdout.write(message) 86 | self.file.write(message) 87 | self.flush() 88 | 89 | def flush(self): 90 | self.stdout.flush() 91 | self.file.flush() 92 | 93 | def log_print(message, file, p=True, l=True): 94 | if p == True: 95 | print(message) 96 | 97 | if l == True: 98 | f = open(file, "a") 99 | f.write(message+'\n') 100 | f.close() 101 | 102 | -------------------------------------------------------------------------------- /PCA.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | class PCLoss(nn.Module): 7 | def __init__(self, num_classes, scale): 8 | super(PCLoss, self).__init__() 9 | self.soft_plus = nn.Softplus() 10 | self.label = torch.LongTensor([i for i in range(num_classes)]).cuda() 11 | self.scale = scale 12 | 13 | def forward(self, feature, target, proxy): 14 | ''' 15 | feature: (N, dim) 16 | proxy: (C, dim) 17 | ''' 18 | feature = F.normalize(feature, p=2, dim=1) 19 | pred = F.linear(feature, F.normalize(proxy, p=2, dim=1)) 20 | 21 | label = (self.label.unsqueeze(1) == target.unsqueeze(0)) 22 | pred_p = torch.masked_select(pred, label.transpose(1, 0)) # (N) positive pair 23 | pred_p = pred_p.unsqueeze(1) 24 | pred_n = torch.masked_select(pred, ~label.transpose(1, 0)).view(feature.size(0), -1) # (N, C-1) negative pair of anchor and proxy 25 | 26 | feature = torch.matmul(feature, feature.transpose(1, 0)) 27 | label_matrix = target.unsqueeze(1) == target.unsqueeze(0) 28 | 29 | feature = feature * ~label_matrix 30 | feature = feature.masked_fill(feature < 1e-6, -np.inf) 31 | 32 | logits = torch.cat([pred_p, pred_n, feature], dim=1) 33 | label = torch.zeros(logits.size(0), 
dtype=torch.long).cuda() 34 | loss = F.nll_loss(F.log_softmax(self.scale * logits, dim=1), label) 35 | return loss 36 | 37 | class PCALoss(nn.Module): 38 | def __init__(self, num_classes, scale): 39 | super(PCALoss, self).__init__() 40 | self.soft_plus = nn.Softplus() 41 | self.label = torch.LongTensor([i for i in range(num_classes)]).cuda() 42 | self.scale = scale 43 | 44 | def forward(self, feature, target, proxy, Mproxy, mweight=1): 45 | ''' 46 | feature: (N, dim) 47 | proxy: (C, dim) 48 | Mproxy: (C, dim) 49 | ''' 50 | feature = F.normalize(feature, p=2, dim=1) 51 | pred = F.linear(feature, F.normalize(proxy, p=2, dim=1)) # (N, C) similarity between sample and proxy 52 | Mpred = F.linear(feature, F.normalize(Mproxy, p=2, dim=1)) # (N, C) similarity between sample and old proxy 53 | 54 | label = (self.label.unsqueeze(1) == target.unsqueeze(0)) 55 | pred_p = torch.masked_select(pred, label.transpose(1, 0)) # (N) positive pair 56 | pred_p = pred_p.unsqueeze(1) 57 | pred_n = torch.masked_select(pred, ~label.transpose(1, 0)).view(feature.size(0), -1) # (N, C-1) negative pair of anchor and proxy 58 | Mpred_p = torch.masked_select(Mpred, label.transpose(1, 0)) 59 | Mpred_p = Mpred_p.unsqueeze(1) 60 | Mpred_n = torch.masked_select(Mpred, ~label.transpose(1, 0)).view(feature.size(0), -1) 61 | 62 | feature = torch.matmul(feature, feature.transpose(1, 0)) # (N, N) sample wise similarity 63 | label_matrix = target.unsqueeze(1) == target.unsqueeze(0) 64 | 65 | feature = feature * ~label_matrix 66 | feature = feature.masked_fill(feature < 1e-6, -np.inf) 67 | 68 | loss = -torch.log( ( torch.exp(self.scale*pred_p.squeeze()) + mweight*torch.exp(self.scale*Mpred_p.squeeze()) ) / 69 | ( torch.exp(self.scale*pred_p.squeeze()) + mweight*torch.exp(self.scale*Mpred_p.squeeze()) + 70 | torch.exp(self.scale*pred_n).sum(dim=1) + mweight*torch.exp(self.scale*Mpred_n).sum(dim=1) + torch.exp(self.scale*feature).sum(dim=1)) ).mean() 71 | return loss 72 | 73 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import copy 3 | 4 | from arguments import get_args 5 | from opt import * 6 | from RaTP import RaTP 7 | import modelopera 8 | import Replay.alg as ReplayAlg 9 | from datautil.getdataloader import get_img_dataloader 10 | from utils.util import set_random_seed, save_checkpoint, log_print 11 | from train import train 12 | from utils.visual import save_plot_acc_epochs, fit_tSNE, visual_tSNE 13 | 14 | if __name__ == '__main__': 15 | args = get_args() 16 | set_random_seed(args.seed) 17 | log_print('################################################', args.log_file) 18 | log_print('############### Attention: arguments steps_per_epoch should be changed with batch_size and dataset ! 
####################', args.log_file)
19 |     log_print('command args: {}'.format(sys.argv[1:]), args.log_file)
20 |     log_print('arguments: {}\n'.format(args), args.log_file, p=False)
21 |
22 |     # Get Data
23 |     train_loaders, eval_loaders, eval_name_dict, task_sequence_name = get_img_dataloader(args)
24 |
25 |     # Model
26 |     model = RaTP(args).cuda()
27 |     old_model = None  # used for knowledge distillation algorithms
28 |     Replay_algorithm_class = ReplayAlg.get_algorithm_class(args.replay)
29 |     Replay_algorithm = Replay_algorithm_class(args)
30 |     model.train()
31 |
32 |     # initial statistics metrics
33 |     target_domain_acc_list = []
34 |     source_domain_acc_list = []
35 |     all_val_acc_record = {}  # list of record lists for each task, e.g. 'task0': [initial acc, [acc along training of task0], [acc along training of task1], ...]
36 |     for tid in range(len(eval_name_dict['valid'])):
37 |         all_val_acc_record['task{}'.format(tid)] = [[modelopera.accuracy(model, eval_loaders[eval_name_dict['valid'][tid]])]]
38 |     if args.tsne:
39 |         tSNE_dict = {'features':[], 'clabels':[], 'dlabels':[]}
40 |         tSNE_dict = fit_tSNE(args, model, eval_loaders, tSNE_dict)
41 |
42 |
43 |     # incrementally train on different domains
44 |     for task_id, dataloader in enumerate(train_loaders):
45 |
46 |         # construct replay exemplars
47 |         replay_dataset = Replay_algorithm.update_dataloader()
48 |
49 |         # main training
50 |         model, val_acc_record, pseudo_dataloader = train(args, model, old_model, task_id, dataloader, replay_dataset, eval_loaders, eval_name_dict)
51 |         for tid in range(len(eval_name_dict['valid'])):
52 |             all_val_acc_record['task{}'.format(tid)].append(val_acc_record['task{}'.format(tid)])
53 |
54 |         # show intermediate results
55 |         for tid in range(task_id+1):
56 |             log_print('after task {}: {}'.format(tid, [all_val_acc_record['task{}'.format(i)][tid+1][-1] for i in range(len(eval_name_dict['valid']))]), args.log_file)
57 |
58 |         # finish task
59 |         Replay_algorithm.update(model, task_id, pseudo_dataloader)
60 |
61 |         if args.tsne:
62 |             tSNE_dict = fit_tSNE(args, model, eval_loaders, tSNE_dict)
63 |
64 |         # save model after finishing a task.
It will be used for knowledge distill algorithms 65 | save_checkpoint(args.saved_model_name, model, args) 66 | old_model = copy.deepcopy(model) 67 | model.cuda() 68 | old_model.cuda().eval() 69 | 70 | save_plot_acc_epochs(args, all_val_acc_record, task_sequence_name) 71 | if args.tsne: 72 | visual_tSNE(args, tSNE_dict) 73 | 74 | log_print('\nDGaccuracy matrix: ', args.log_file) 75 | log_print('at start: {}'.format([all_val_acc_record['task{}'.format(tid)][0][0] for tid in range(len(eval_name_dict['valid']))]), args.log_file) 76 | for tid in range(len(eval_name_dict['valid'])): 77 | log_print('after task {}: {}'.format(tid, [all_val_acc_record['task{}'.format(i)][tid+1][-1] for i in range(len(eval_name_dict['valid']))]), args.log_file) 78 | 79 | log_print('', args.log_file) 80 | 81 | 82 | 83 | 84 | -------------------------------------------------------------------------------- /network/img_network.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import torch.nn as nn 3 | from torchvision import models 4 | # from DataAug.Mixup.EFDMix import EFDMix 5 | 6 | vgg_dict = {"vgg11": models.vgg11, "vgg13": models.vgg13, "vgg16": models.vgg16, "vgg19": models.vgg19, 7 | "vgg11bn": models.vgg11_bn, "vgg13bn": models.vgg13_bn, "vgg16bn": models.vgg16_bn, "vgg19bn": models.vgg19_bn} 8 | 9 | 10 | class VGGBase(nn.Module): 11 | def __init__(self, args): 12 | super(VGGBase, self).__init__() 13 | model_vgg = vgg_dict[args.net](pretrained=True) 14 | self.features = model_vgg.features 15 | # self.classifier = nn.Sequential() 16 | # for i in range(6): # remove the final classifier layer. now classifiers sequential in_dim is 25088, out_dim is 4096 17 | # self.classifier.add_module( 18 | # "classifier"+str(i), model_vgg.classifier[i]) 19 | # self.in_features = model_vgg.classifier[6].in_features 20 | self.in_features = 512 # input image shape should be (3, 32, 32) 21 | 22 | def forward(self, x): 23 | x = self.features(x) 24 | x = x.view(x.size(0), -1) 25 | # x = self.classifier(x) 26 | return x 27 | 28 | 29 | res_dict = {"resnet18": models.resnet18, "resnet34": models.resnet34, "resnet50": models.resnet50, 30 | "resnet101": models.resnet101, "resnet152": models.resnet152, "resnext50": models.resnext50_32x4d, "resnext101": models.resnext101_32x8d} 31 | 32 | 33 | class ResBase(nn.Module): 34 | def __init__(self, args): 35 | super(ResBase, self).__init__() 36 | model_resnet = res_dict[args.net](pretrained=True) 37 | self.conv1 = model_resnet.conv1 38 | self.bn1 = model_resnet.bn1 39 | self.relu = model_resnet.relu 40 | self.maxpool = model_resnet.maxpool 41 | self.layer1 = model_resnet.layer1 42 | self.layer2 = model_resnet.layer2 43 | self.layer3 = model_resnet.layer3 44 | self.layer4 = model_resnet.layer4 45 | self.avgpool = model_resnet.avgpool 46 | self.in_features = model_resnet.fc.in_features 47 | 48 | def forward(self, x): 49 | x = self.conv1(x) 50 | x = self.bn1(x) 51 | x = self.relu(x) 52 | x = self.maxpool(x) 53 | x = self.layer1(x) 54 | x = self.layer2(x) 55 | x = self.layer3(x) 56 | x = self.layer4(x) 57 | x = self.avgpool(x) 58 | x = x.view(x.size(0), -1) 59 | return x 60 | 61 | class DTNBase(nn.Module): 62 | def __init__(self): 63 | super(DTNBase, self).__init__() 64 | self.conv_params = nn.Sequential( 65 | nn.Conv2d(3, 64, kernel_size=5, stride=2, padding=2), 66 | nn.BatchNorm2d(64), 67 | nn.Dropout2d(0.1), 68 | nn.ReLU(), 69 | nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2), 70 | nn.BatchNorm2d(128), 71 | nn.Dropout2d(0.3), 72 | 
nn.ReLU(), 73 | nn.Conv2d(128, 256, kernel_size=5, stride=2, padding=2), 74 | nn.BatchNorm2d(256), 75 | nn.Dropout2d(0.5), 76 | nn.ReLU() 77 | ) 78 | self.in_features = 256*4*4 79 | 80 | def forward(self, x): 81 | x = self.conv_params(x) 82 | x = x.view(x.size(0), -1) 83 | return x 84 | 85 | 86 | class LeNetBase(nn.Module): 87 | def __init__(self): 88 | super(LeNetBase, self).__init__() 89 | self.conv_params = nn.Sequential( 90 | nn.Conv2d(3, 20, kernel_size=5), 91 | nn.MaxPool2d(2), 92 | nn.ReLU(), 93 | nn.Conv2d(20, 50, kernel_size=5), 94 | nn.Dropout2d(p=0.5), 95 | nn.MaxPool2d(2), 96 | nn.ReLU(), 97 | ) 98 | self.in_features = 50*4*4 99 | 100 | def forward(self, x): 101 | x = self.conv_params(x) 102 | x = x.view(x.size(0), -1) 103 | return x 104 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm import tqdm 3 | import torch 4 | 5 | import modelopera 6 | from utils.util import log_print 7 | from pLabel import assign_pseudo_label 8 | from datautil.mydataloader import InfiniteDataLoader 9 | 10 | def train(args, model, old_model, task_id, dataloader, replay_dataset, eval_loaders, eval_name_dict): 11 | acc_record = {} 12 | all_val_acc_record = {} 13 | for tid in range(len(eval_name_dict['valid'])): 14 | all_val_acc_record['task{}'.format(tid)] = [] 15 | best_valid_acc, target_acc = 0, 0 16 | 17 | max_epoch = args.max_epoch 18 | model.get_optimizer(lr_decay=args.lr_decay1 if task_id > 0 else 1.0) 19 | model.optimizer = op_copy(model.optimizer) 20 | 21 | with tqdm(range(max_epoch)) as tepoch: 22 | tepoch.set_description(f"Task {task_id}") 23 | for epoch in tepoch: 24 | 25 | # progressly assign pseudo label 26 | if epoch % args.pseudo_fre == 0: 27 | pseudo_dataloader, plabel_sc = assign_pseudo_label(args, dataloader, replay_dataset, task_id, model, epoch) 28 | curr_dataloader = cat_pseudo_replay(args, pseudo_dataloader, replay_dataset) 29 | replay_dataloader = None 30 | 31 | model.naug = 0 if task_id > 0 else args.batch_size*args.steps_per_epoch 32 | for iter_ in range(args.steps_per_epoch): # make sure each tasks has the same training iters. 33 | minibatches = [(data) for data in next(iter(curr_dataloader))] 34 | if minibatches[0].size(0) == 1: 35 | continue 36 | 37 | model.train() 38 | if task_id == 0: 39 | step_vals = model.train_source(minibatches, task_id, epoch) 40 | else: 41 | step_vals = model.adapt(minibatches, task_id, epoch, replay_dataloader, old_model) 42 | 43 | model.optimizer = lr_scheduler(model.optimizer, epoch, max_epoch) 44 | 45 | # only calculate accuracy of current domain 46 | for item in ['train', 'valid']: 47 | acc_record[item] = np.mean(np.array([modelopera.accuracy(model, eval_loaders[eval_name_dict[item][task_id]])])) 48 | if plabel_sc is None: 49 | tepoch.set_postfix(**step_vals, **acc_record, naug=model.naug/(args.batch_size*args.steps_per_epoch)) 50 | else: 51 | tepoch.set_postfix(**step_vals, **acc_record, naug=model.naug/(args.batch_size*args.steps_per_epoch)) 52 | 53 | # record accuracy of validation data of all tasks along epochs. 
54 | for tid in range(len(eval_name_dict['valid'])): 55 | all_val_acc_record['task{}'.format(tid)].append(modelopera.accuracy(model, eval_loaders[eval_name_dict['valid'][tid]])) 56 | 57 | if acc_record['valid'] > best_valid_acc: 58 | best_valid_acc = acc_record['valid'] 59 | 60 | log_print('task{} training result on max_epoch{}: {} {}'.format(task_id, max_epoch, step_vals, acc_record), args.log_file, p=False) 61 | 62 | return model, all_val_acc_record, pseudo_dataloader 63 | 64 | 65 | def cat_pseudo_replay(args, dataloader, replay_dataset): 66 | if replay_dataset is not None: 67 | dataset = torch.utils.data.ConcatDataset([dataloader.dataset, replay_dataset]) 68 | dataloader = InfiniteDataLoader(dataset=dataset, weights=None, batch_size=args.batch_size, num_workers=args.N_WORKERS) 69 | return dataloader 70 | 71 | def op_copy(optimizer): 72 | for param_group in optimizer.param_groups: 73 | param_group['lr0'] = param_group['lr'] 74 | return optimizer 75 | 76 | def lr_scheduler(optimizer, iter_num, max_iter, gamma=10, power=0.75): 77 | decay = (1 + gamma * iter_num / max_iter) ** (-power) 78 | for param_group in optimizer.param_groups: 79 | param_group['lr'] = param_group['lr0'] * decay 80 | # param_group['weight_decay'] = 1e-3 81 | # param_group['momentum'] = 0.9 82 | # param_group['nesterov'] = True 83 | return optimizer -------------------------------------------------------------------------------- /Replay/iCaRL.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from torchvision import transforms 4 | import numpy as np 5 | from torch.utils.data import DataLoader 6 | from torch.nn import functional as F 7 | 8 | import Replay.utils as utils 9 | import datautil.imgdata.util as imgutil 10 | from utils.util import log_print 11 | 12 | class iCaRL: 13 | def __init__(self, args): 14 | self.args = args 15 | self.exemplar_set = [] # list of list[PIL image] : [[exemplar1 PIL image], [exemplar2 PIL image]...] 16 | self.exemplar_label_set = [] # list of np.array : [array(exemplar1 labels), array(exemplar2 labels)...] 
17 |         self.exemplar_dlabel_set = []
18 |         self.replay_dataset = None
19 |
20 |     def update_dataloader(self, dataloader=None):
21 |         exemplar_set = self.exemplar_set
22 |         exemplar_label_set = self.exemplar_label_set
23 |         exemplar_dlabel_set = self.exemplar_dlabel_set
24 |         log_print('exemplar_set size: {}'.format(len(exemplar_set[0]) if len(exemplar_set)>0 else 0), self.args.log_file)
25 |         replay_dataloader = None
26 |
27 |         if len(exemplar_set) > 0:
28 |             imgs = utils.concat_list(exemplar_set)
29 |             labels = utils.concat_list(exemplar_label_set)
30 |             dlabels = utils.concat_list(exemplar_dlabel_set)
31 |             self.replay_dataset = utils.ReplayDataset(imgs, labels, dlabels, transform=imgutil.image_train(self.args))
32 |
33 |         return self.replay_dataset
34 |
35 |     def update(self, model, task_id, dataloader):
36 |         if self.args.replay_mode == 'class':  # exemplars for each class and domain
37 |             m=int(self.args.memory_size / (self.args.num_classes * (task_id+1)))
38 |         elif self.args.replay_mode == 'domain':  # exemplars for each domain
39 |             m=int(self.args.memory_size / (task_id+1))
40 |         self._reduce_exemplar_sets(m)
41 |
42 |         image_dict, class_label, domain_label = dataloader.dataset.get_raw_data()
43 |         images = [dataloader.dataset.loader(dict) for dict in image_dict]  # list of PIL image
44 |
45 |         if self.args.replay_mode == 'class':  # each exemplar contains data of one class in one specific domain
46 |             for c in range(self.args.num_classes):
47 |                 indices = np.where(class_label == c)[0]
48 |                 if len(indices) == 0:
49 |                     log_print('No class {} pseudo labels!!!'.format(c), self.args.log_file)
50 |                     continue
51 |                 imgs = [images[i] for i in indices]  # list of PIL image
52 |                 clabel = class_label[class_label == c]
53 |                 dlabel = domain_label[class_label == c]
54 |                 self._construct_exemplar_set(model, imgs, clabel, dlabel, m)
55 |         elif self.args.replay_mode == 'domain':  # each exemplar contains data of all classes in one specific domain
56 |             self._construct_exemplar_set(model, images, class_label, domain_label, m)
57 |
58 |     def _construct_exemplar_set(self, model, images, class_label, domain_label, m):
59 |         '''
60 |         construct exemplar for each class in each domain
61 |         input images should be one class in one specific domain
62 |         '''
63 |         class_mean, feature_extractor_output = self.compute_class_mean(model, images, transform=imgutil.image_test(self.args))
64 |         exemplar = []
65 |         exemplar_index = []
66 |
67 |         now_class_mean = np.zeros((1, model.featurizer.in_features))  # feature extractor output dimension
68 |
69 |         for i in range(m):
70 |
71 |             # icarl code
72 |             # shape: batch_size*256
73 |             x = class_mean - (now_class_mean + feature_extractor_output) / (i + 1)
74 |             # shape: batch_size
75 |             x = np.linalg.norm(x, axis=1)
76 |             index = np.argmin(x)
77 |             now_class_mean += feature_extractor_output[index]
78 |
79 |             # make sure the selected example won't be selected again
80 |             # if index in exemplar_index:
81 |             #     raise ValueError("Exemplars should not be repeated!!!!")
82 |             exemplar.append(images[index])
83 |             exemplar_index.append(index)
84 |             feature_extractor_output[index] += 10000
85 |
86 |         self.exemplar_set.append(exemplar)
87 |         self.exemplar_label_set.append(class_label[exemplar_index])
88 |         self.exemplar_dlabel_set.append(domain_label[exemplar_index])
89 |
90 |
91 |     def _reduce_exemplar_sets(self, m):
92 |         for index in range(len(self.exemplar_set)):
93 |             self.exemplar_set[index] = self.exemplar_set[index][:m]
94 |         for index in range(len(self.exemplar_label_set)):
95 |             self.exemplar_label_set[index] = self.exemplar_label_set[index][:m]
| for index in range(len(self.exemplar_dlabel_set)): 97 | self.exemplar_dlabel_set[index] = self.exemplar_dlabel_set[index][:m] 98 | 99 | 100 | def compute_class_mean(self, model, images, transform): 101 | exemplar_dataset = utils.ExemplarDataset(images, transform) 102 | exemplar_dataloader = DataLoader(dataset=exemplar_dataset, 103 | shuffle=False, 104 | batch_size=self.args.batch_size, 105 | num_workers=self.args.N_WORKERS) 106 | model.eval() # if not use this, it will affect evaluation steps after this evaluation, even they call model.eval(). 107 | feature_extractor_outputs = [] 108 | for i, x in enumerate(exemplar_dataloader): 109 | x = x.cuda() 110 | with torch.no_grad(): 111 | feature_extractor_outputs.append(model.featurizer(x)) 112 | feature_extractor_outputs = torch.cat(feature_extractor_outputs, dim=0) 113 | model.train() 114 | feature_extractor_outputs = F.normalize(feature_extractor_outputs.detach()).cpu().numpy() 115 | class_mean = np.mean(feature_extractor_outputs, axis=0) 116 | return class_mean, feature_extractor_outputs 117 | -------------------------------------------------------------------------------- /arguments.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | def get_args(): 5 | parser = argparse.ArgumentParser(description='DG') 6 | # Data 7 | parser.add_argument('--data_dir', type=str, default='./Dataset', help='root data dir') 8 | parser.add_argument('--dataset', type=str, default='PACS', choices=['PACS', 'subdomain_net', 'dg5']) 9 | parser.add_argument('--order', type=int, nargs='+', help='training domain order') 10 | parser.add_argument('--test_envs', type=int, nargs='+', 11 | default=[], help='no fixed target domains') 12 | parser.add_argument('--split_style', type=str, default='strat',help="the style to split the train and eval datasets") 13 | 14 | #training algorithm 15 | parser.add_argument('--loss_alpha1', type=float, default=1.0, help='loss weight') 16 | parser.add_argument('--PCL_scale', default=12, type=float, help='scale of cross entropy in PCL') 17 | parser.add_argument('--pLabelAlg', type=str, default="T2PL", choices=['T2PL', 'ground'], help='pesudo label assigning algorithm in target domain. 
/arguments.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | 
4 | def get_args():
5 | parser = argparse.ArgumentParser(description='DG')
6 | # Data
7 | parser.add_argument('--data_dir', type=str, default='./Dataset', help='root data dir')
8 | parser.add_argument('--dataset', type=str, default='PACS', choices=['PACS', 'subdomain_net', 'dg5'])
9 | parser.add_argument('--order', type=int, nargs='+', help='training domain order')
10 | parser.add_argument('--test_envs', type=int, nargs='+',
11 | default=[], help='no fixed target domains')
12 | parser.add_argument('--split_style', type=str, default='strat', help="the style to split the train and eval datasets")
13 | 
14 | # Training algorithm
15 | parser.add_argument('--loss_alpha1', type=float, default=1.0, help='loss weight')
16 | parser.add_argument('--PCL_scale', default=12, type=float, help='scale of cross entropy in PCL')
17 | parser.add_argument('--pLabelAlg', type=str, default="T2PL", choices=['T2PL', 'ground'], help='pseudo label assignment algorithm for the target domain; ground uses ground-truth labels')
18 | parser.add_argument('--pseudo_fre', default=1, type=int, help='assign new pseudo labels every pseudo_fre epochs')
19 | parser.add_argument('--replay', type=str, default='icarl', choices=['icarl', 'Finetune'], help='data replay algorithm')
20 | parser.add_argument('--replay_mode', type=str, default='class', choices=['class', 'domain'])
21 | parser.add_argument('--memory_size', type=int, help="replay exemplar size")
22 | parser.add_argument('--aug_tau', type=float, default=0.8, help='augment samples whose pseudo label confidence is larger than this value')
23 | parser.add_argument('--distance', type=str, default='cosine', choices=['cosine', 'euclidean'])
24 | parser.add_argument('--distill', type=str, default='KL', choices=['CE', 'KL', 'feaKL'])
25 | parser.add_argument('--distill_alpha', type=float, default=0.5)
26 | parser.add_argument('--topk_alpha', default=20, type=int, help='k nearest neighbors in kNN pseudo labeling.')
27 | parser.add_argument('--topk_beta', default=2, type=int, help='top-k fitting samples in kNN pseudo labeling.')
28 | parser.add_argument('--MPCL_alpha', type=float, default=0.5, help='MPCL weight')
29 | 
30 | # Utils
31 | parser.add_argument('--seed', type=int, default=2022)
32 | parser.add_argument('--output', type=str,
33 | default="result_develop", help='result output path')
34 | parser.add_argument('--log_file', type=str, help="logging file name under output dir")
35 | parser.add_argument('--tsne', action='store_true', help='visualize embedding space using tSNE')
36 | 
37 | # Model
38 | parser.add_argument('--net', type=str, default='resnet50',
39 | help="featurizer: vgg16, resnet50, resnet101, DTNBase")
40 | parser.add_argument('--classifier', type=str,
41 | default="linear", choices=["linear", "wn"])
42 | 
43 | # Training
44 | parser.add_argument('--lr', type=float, default=5e-3, help="learning rate")
45 | parser.add_argument('--lr_decay1', type=float, default=1.0, help='feature extractor lr scheduler')
46 | parser.add_argument('--max_epoch', type=int,
47 | default=30, help="max epoch")
48 | parser.add_argument('--steps_per_epoch', type=int, help='training steps in each epoch; the total number of samples trained per epoch is steps_per_epoch*batch_size')
49 | parser.add_argument('--batch_size', type=int,
50 | default=64, help='batch_size')
51 | parser.add_argument('--gpu', type=int, default=0, help="device id to run")
52 | parser.add_argument('--N_WORKERS', type=int, default=4)
53 | parser.add_argument('--weight_decay', type=float, default=5e-4)
54 | parser.add_argument('--momentum', type=float,
55 | default=0.9, help='for optimizer')
56 | 
57 | # Don't need to change
58 | parser.add_argument('--data_file', type=str, default='',
59 | help='root_dir')
60 | parser.add_argument('--task', type=str, default="img_dg",
61 | choices=["img_dg"], help='now only supports image tasks')
62 | 
63 | args = parser.parse_args()
64 | 
65 | # I/O
66 | args.data_dir = os.path.join(args.data_dir, args.dataset, '')
67 | args.result_dir = os.path.join(args.output, args.dataset)
68 | args.tSNE_dir = os.path.join(args.result_dir, 'tSNE')
69 | os.makedirs(args.output, exist_ok=True)
70 | os.makedirs(args.result_dir, exist_ok=True)
71 | os.makedirs(args.tSNE_dir, exist_ok=True)
72 | 
73 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
74 | args = img_param_init(args)
75 | args = set_default_args(args)
76 | args.num_task = len(args.domains) - len(args.test_envs)
77 | 
78 | args.saved_model_name = os.path.join(args.result_dir, 'source{}.pt'.format(args.order[0]))
79 | 
80 | return args
81 | 
82 | def set_default_args(args):
83 | args.order = [i for i in range(len(args.domains)-len(args.test_envs))] if args.order is None else args.order
84 | args.log_file = os.path.join(args.result_dir, 'order{}.log'.format(''.join(str(i) for i in args.order))) if args.log_file is None else os.path.join(args.result_dir, args.log_file)
85 | if args.replay == 'icarl':
86 | args.replay = 'iCaRL'
87 | 
88 | memory_size = {'PACS':200, 'subdomain_net':200, 'dg5':200}
89 | steps_per_epoch = {'PACS':50, 'subdomain_net':70, 'dg5':800}
90 | args.memory_size = memory_size[args.dataset] if args.memory_size is None else args.memory_size
91 | args.steps_per_epoch = steps_per_epoch[args.dataset] if args.steps_per_epoch is None else args.steps_per_epoch
92 | 
93 | return args
94 | 
95 | def img_param_init(args):
96 | dataset = args.dataset
97 | if dataset == 'PACS':
98 | domains = ['art_painting', 'cartoon', 'photo', 'sketch']
99 | elif dataset == 'subdomain_net':
100 | domains = ['clipart', 'infograph', 'painting', 'quickdraw', 'real', 'sketch']
101 | elif dataset == 'dg5':
102 | domains = ['mnist', 'mnist_m', 'svhn', 'syn', 'usps']
103 | else:
104 | raise ValueError('No such dataset: {}'.format(dataset)) # fail fast instead of hitting a NameError on `domains` below
105 | args.domains = domains
106 | args.img_dataset = {
107 | 'PACS': ['art_painting', 'cartoon', 'photo', 'sketch'],
108 | 'subdomain_net': ['clipart', 'infograph', 'painting', 'quickdraw', 'real', 'sketch'],
109 | 'dg5': ['mnist', 'mnist_m', 'svhn', 'syn', 'usps'],
110 | }
111 | if dataset == 'dg5':
112 | args.input_shape = (3, 32, 32)
113 | args.num_classes = 10
114 | else:
115 | args.input_shape = (3, 224, 224)
116 | if args.dataset == 'PACS':
117 | args.num_classes = 7
118 | elif args.dataset == 'subdomain_net':
119 | args.num_classes = 10
120 | 
121 | args.proj_dim = {'dg5':128, 'PACS':256, 'subdomain_net':512} # projection dim for the contrastive loss.
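# proj_dim is consumed downstream as fea_dim = args.proj_dim[args.dataset] when RaTP
# builds its feat_encoder head; with the defaults above, PACS resolves to
# num_classes=7, proj_dim=256, and (via set_default_args) memory_size=200.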
122 | 123 | return args -------------------------------------------------------------------------------- /RaTP.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import math 6 | from torchvision import transforms 7 | from modelopera import get_fea 8 | from opt import * 9 | from network.common_network import feat_encoder 10 | import PCA 11 | from RandMix import RandMix 12 | 13 | def Entropy_(input_): 14 | bs = input_.size(0) 15 | epsilon = 1e-5 16 | entropy = -input_ * torch.log(input_ + epsilon) 17 | entropy = torch.sum(entropy, dim=1) 18 | return entropy 19 | 20 | class RaTP(torch.nn.Module): 21 | 22 | def __init__(self, args): 23 | super(RaTP, self).__init__() 24 | self.args = args 25 | self.task_id = 0 26 | self.naug = 0 27 | self.fea_rep = None 28 | self.featurizer = get_fea(args) 29 | 30 | # training algorithm model 31 | fea_dim = args.proj_dim[args.dataset] 32 | self.encoder = feat_encoder(args, self.featurizer.in_features, fea_dim) 33 | self._initialize_weights(self.encoder) 34 | 35 | self.classifier = nn.Parameter(torch.FloatTensor(args.num_classes, fea_dim)) 36 | nn.init.kaiming_uniform_(self.classifier, mode='fan_out', a=math.sqrt(5)) 37 | 38 | # Data augment algorithm 39 | self.data_aug = RandMix(1).cuda() 40 | if args.dataset == 'dg5': 41 | self.aug_tran = transforms.Normalize([0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 42 | else: 43 | self.aug_tran = transforms.Normalize([0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 44 | 45 | def forward(self, x): 46 | x = self.featurizer(x) 47 | x = self.encoder(x) 48 | self.fea_rep = x 49 | pred = F.linear(x, self.classifier) 50 | return pred 51 | 52 | def get_optimizer(self, lr_decay=1.0): 53 | self.optimizer = torch.optim.SGD([ 54 | {'params': self.featurizer.parameters(), 'lr': lr_decay * self.args.lr}, 55 | {'params': self.encoder.parameters()}, 56 | {'params': self.classifier}, 57 | ], lr=self.args.lr, weight_decay=self.args.weight_decay) 58 | 59 | 60 | ################################################## train source and adapt ###################################################################### 61 | 62 | def train_source(self, minibatches, task_id, epoch): 63 | self.task_id = task_id 64 | all_x = minibatches[0].cuda().float() 65 | all_y = minibatches[1].cuda().long() 66 | 67 | # Data Augmentation using RandMix 68 | ratio = epoch / self.args.max_epoch 69 | data_fore = self.aug_tran(torch.sigmoid(self.data_aug(all_x, ratio=ratio))) 70 | all_x = torch.cat([all_x, data_fore]) # [original, aug] 71 | all_y = torch.cat([all_y, all_y]) 72 | 73 | loss, loss_dict = self.PCAupdate(all_x, all_y) 74 | 75 | self.optimizer.zero_grad() 76 | loss.backward() 77 | self.optimizer.step() 78 | return {'loss': loss.item()} 79 | 80 | 81 | def adapt(self, minibatches, task_id, epoch, replay_dataloader=None, old_model=None): 82 | self.task_id = task_id 83 | all_x = minibatches[0].cuda().float() 84 | all_y = minibatches[1].cuda().long() 85 | 86 | # Data Augmentation using RandMix 87 | all_x, all_y = self.select_aug(all_x, all_y, epoch) 88 | 89 | loss, loss_dict = self.PCAupdate(all_x, all_y, old_model) 90 | 91 | self.optimizer.zero_grad() 92 | loss.backward() 93 | self.optimizer.step() 94 | # self.scheduler.step() 95 | return {'loss': loss.item()} 96 | 97 | 98 | ################################################################ Algorithms #################################################### 99 | 100 | def PCAupdate(self, all_x, 
all_y, old_model=None): 101 | pred = self(all_x) 102 | 103 | # cross entropy loss 104 | loss_cls = F.nll_loss(F.log_softmax(pred, dim=1), all_y) 105 | 106 | # pca loss 107 | proxy = self.classifier 108 | features = self.fea_rep 109 | if self.task_id > 0: 110 | old_proxy = old_model.classifier 111 | loss_pcl = PCA.PCALoss(self.args.num_classes, self.args.PCL_scale)(features, all_y, proxy, old_proxy, mweight=self.args.MPCL_alpha) 112 | else: 113 | loss_pcl = PCA.PCLoss(num_classes=self.args.num_classes, scale=self.args.PCL_scale)(features, all_y, proxy) 114 | 115 | loss_dict = {'ce': loss_cls.item(), 'pcl': (self.args.loss_alpha1 * loss_pcl).item()} 116 | loss = loss_cls + self.args.loss_alpha1 * loss_pcl 117 | 118 | # distill loss 119 | if old_model is not None: 120 | distill_loss = self.args.distill_alpha * self.distill_loss(pred, all_x, old_model) 121 | loss += distill_loss 122 | loss_dict['distill'] = distill_loss.item() 123 | 124 | return loss, loss_dict 125 | 126 | def distill_loss(self, pred, all_x, old_model): 127 | old_model.cuda().eval() 128 | with torch.no_grad(): 129 | old_logist = nn.Softmax(dim=1)(old_model(all_x)) 130 | 131 | if self.args.distill == 'CE': 132 | loss = F.cross_entropy(pred, old_logist) 133 | elif self.args.distill == 'KL': 134 | loss = nn.KLDivLoss(reduction="batchmean")(nn.LogSoftmax(dim=1)(pred), old_logist) 135 | elif self.args.distill == 'feaKL': 136 | loss = nn.KLDivLoss(reduction="batchmean")(nn.LogSoftmax(dim=1)(self.fea_rep), nn.Softmax(dim=1)(old_model.fea_rep)) 137 | return loss 138 | 139 | 140 | ################################################################ Utils #################################################### 141 | def _initialize_weights(self, modules): 142 | for m in modules: 143 | if isinstance(m, nn.Conv2d): 144 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 145 | m.weight.data.normal_(0, math.sqrt(2. / n)) 146 | if m.bias is not None: 147 | m.bias.data.zero_() 148 | elif isinstance(m, nn.BatchNorm2d): 149 | m.weight.data.fill_(1) 150 | m.bias.data.zero_() 151 | elif isinstance(m, nn.Linear): 152 | n = m.weight.size(1) 153 | m.weight.data.normal_(0, 0.01) 154 | m.bias.data.zero_() 155 | 156 | def select_aug(self, all_x, all_y, epoch): 157 | ratio = epoch / self.args.max_epoch 158 | if self.args.aug_tau > 0: 159 | self.eval() 160 | with torch.no_grad(): 161 | pred = nn.Softmax(dim=1)(self(all_x)) 162 | ov, idx = torch.max(pred, 1) 163 | bool_index = ov > self.args.aug_tau 164 | data_fore = all_x[bool_index] 165 | y_fore = all_y[bool_index] 166 | data_fore = self.aug_tran(torch.sigmoid(self.data_aug(data_fore, ratio=ratio))) 167 | self.train() 168 | else: 169 | data_fore = self.aug_tran(torch.sigmoid(self.data_aug(all_x, ratio=ratio))) 170 | y_fore = all_y 171 | all_x = torch.cat([all_x, data_fore]) 172 | all_y = torch.cat([all_y, y_fore]) 173 | self.naug += len(y_fore) 174 | return all_x, all_y 175 | -------------------------------------------------------------------------------- /utils/visual.py: -------------------------------------------------------------------------------- 1 | from cProfile import label 2 | import numpy as np 3 | import os 4 | import matplotlib.pyplot as plt 5 | import pickle 6 | from sklearn.manifold import TSNE 7 | import torch 8 | import seaborn as sns 9 | import pandas as pd 10 | from statistics import mean 11 | 12 | from Replay.utils import concat_list 13 | 14 | def save_plot_acc_epochs(args, all_val_acc_record, task_sequence_name): 15 | ''' 16 | all_val_acc_record: dict 17 | e.g. 
'task0': [initial acc, [acc along training of task0], [acc along training of task1]...] 18 | task_sequence_name: list 19 | domain sequence name 20 | ''' 21 | 22 | # save result dictionary 23 | with open(os.path.join(args.result_dir, 'order{}_lr{}_seed{}.pkl'.format( 24 | ''.join(str(i) for i in args.order), args.lr, args.seed)), 'wb') as f: 25 | pickle.dump(all_val_acc_record, f) 26 | 27 | # plot result 28 | num_task = len(all_val_acc_record) 29 | plt.clf() 30 | fig, ax = plt.subplots() #figsize=(5,4) 31 | x = range(len(concat_list(all_val_acc_record['task0']))) 32 | 33 | for tid in range(num_task): 34 | acc_end = [a[-1] for a in all_val_acc_record['task{}'.format(tid)]] 35 | acc_end.pop(0) 36 | 37 | if tid == 0: 38 | label_name = '{}_{}_{}_fa{}'.format(task_sequence_name[tid], [round(100*elem,1) for elem in acc_end], round(100*mean([elem for elem in acc_end]),1), 39 | round(100*mean([elem for elem in acc_end[1:]]),1)) 40 | elif tid == (num_task-1): 41 | label_name = '{}_{}_{}_dg{}'.format(task_sequence_name[tid], [round(100*elem,1) for elem in acc_end], round(100*mean([elem for elem in acc_end]),1), 42 | round(100*mean([elem for elem in acc_end[:-1]]),1)) 43 | else: 44 | label_name = '{}_{}_{}_g{}_f{}'.format(task_sequence_name[tid], [round(100*elem,1) for elem in acc_end], round(100*mean([elem for elem in acc_end]),1), 45 | round(100*mean([elem for elem in acc_end[:tid]]),1), round(100*mean([elem for elem in acc_end[tid+1:]]),1)) 46 | 47 | ax.plot(x, concat_list(all_val_acc_record['task{}'.format(tid)]), label=label_name) 48 | ax.set_ylabel('accuracy') 49 | ax.legend() 50 | 51 | # add grid at the begining of tasks 52 | ax.set_xticks([(len(all_val_acc_record['task0'][i])*i) for i in range(1, num_task)], minor=False) 53 | ax.xaxis.grid(True, which='major') 54 | 55 | # calculate metrics 56 | da_acc, dg_acc, forget_acc = calculate_metrics(all_val_acc_record) 57 | 58 | ax.set_title('{}_{}_da{}_dg{}_fg{}'.format( 59 | args.dataset, args.seed, da_acc, dg_acc, forget_acc)) 60 | 61 | plt.savefig(os.path.join(args.result_dir, 'order{}_lr{}_seed{}.jpg'.format( 62 | ''.join(str(i) for i in args.order), args.lr, args.seed))) 63 | 64 | 65 | def calculate_metrics(all_val_acc_record): # for ablation study. First average in each domain, then average all domain. 66 | num_task = len(all_val_acc_record) 67 | da, dg_av, fa_av = [], [], [] 68 | for tid in range(num_task): 69 | dg, forget = [], [] 70 | acc_end = [a[-1] for a in all_val_acc_record['task{}'.format(tid)]] 71 | acc_end.pop(0) # (num_task, num_task) 72 | da.append(acc_end[tid]) 73 | for i in range(0, tid): 74 | dg.append(acc_end[i]) 75 | if len(dg) > 0: 76 | dg_av.append(mean(dg)) 77 | for i in range(tid+1, num_task): 78 | forget.append(acc_end[i]) 79 | if len(forget) > 0: 80 | fa_av.append(mean(forget)) 81 | return round(100*mean(da),1), round(100*mean(dg_av),1), round(100*mean(fa_av),1) 82 | 83 | def fit_tSNE(args, net, eval_loaders, tSNE_dict): 84 | ''' 85 | fit a tSNE using eval data from all domain 86 | netF: feature extractor 87 | return: 88 | tsne_results: 2-D array 89 | clabels and dlabels: 1-D array 90 | ''' 91 | # get embedding features using model feature extractor 92 | features = [] 93 | clabels, dlabels = [], [] 94 | net.eval() 95 | with torch.no_grad(): 96 | for i in range(args.num_task): 97 | loader = eval_loaders[args.eval_name_dict['valid'][i]] 98 | for data in loader: # this line will change the performance!! ?? 
99 | x = data[0].cuda().float() 100 | clabel = data[1] 101 | dlabel = data[2] 102 | feature = net.featurizer(x) 103 | features.append(feature.tolist()) 104 | clabels.append(clabel.tolist()) 105 | dlabels.append(dlabel.tolist()) 106 | features = concat_list(features) 107 | clabels = concat_list(clabels) 108 | dlabels = concat_list(dlabels) 109 | 110 | tsne = TSNE(n_components=2) #, perplexity=40, n_iter=300) 111 | tsne_results = tsne.fit_transform(features) 112 | tSNE_dict['features'].append(tsne_results) 113 | tSNE_dict['clabels'].append(clabels) 114 | tSNE_dict['dlabels'].append(dlabels) 115 | net.train() 116 | # visual_tSNE(args, tsne_results, clabels, dlabels, task_id) 117 | return tSNE_dict 118 | 119 | def visual_tSNE(args, tSNE_dict): 120 | plt.clf() 121 | fig, axes = plt.subplots(args.num_task+1, 2+len(args.domains), figsize=(5*(2+len(args.domains)), 5*(args.num_task+1))) 122 | for i in range(args.num_task+1): 123 | df = pd.DataFrame() 124 | df['x'] = tSNE_dict['features'][i][:,0] 125 | df['y'] = tSNE_dict['features'][i][:,1] 126 | df['class'] = tSNE_dict['clabels'][i] 127 | df['domain'] = tSNE_dict['dlabels'][i] 128 | 129 | sns.scatterplot(ax = axes[i][0], 130 | x="x", y="y", 131 | hue=df.domain, 132 | palette=sns.color_palette("hls", len(args.domains)), 133 | data=df, 134 | legend="full", 135 | alpha=0.3 136 | ) 137 | sns.scatterplot(ax = axes[i][1], 138 | x="x", y="y", 139 | hue=df['class'], 140 | palette=sns.color_palette("hls", args.num_classes), 141 | data=df, 142 | legend="full", 143 | alpha=0.3 144 | ) 145 | 146 | for j in range(len(args.domains)): 147 | df['xd'] = df['x'][df['domain']==j] 148 | df['yd'] = df['y'][df['domain']==j] 149 | df['classd'] = df['class'][df['domain']==j] 150 | 151 | sns.scatterplot(ax = axes[i][2+j], 152 | x="xd", y="yd", 153 | hue=df['classd'], 154 | palette=sns.color_palette("hls", args.num_classes), 155 | data=df, 156 | legend="full", 157 | alpha=0.3 158 | ) 159 | axes[i][2+j].set_title('domain{}'.format(j)) 160 | 161 | plt.savefig(os.path.join(args.tSNE_dir, 'order{}.jpg'.format(''.join(str(i) for i in args.order)))) 162 | # save result dictionary 163 | with open(os.path.join(args.tSNE_dir, 'order{}_lr{}_seed{}.pkl'.format( 164 | ''.join(str(i) for i in args.order), args.lr, args.seed)), 'wb') as f: 165 | pickle.dump(tSNE_dict, f) 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | pass -------------------------------------------------------------------------------- /pLabel.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | from torch.utils.data import DataLoader, Dataset 6 | import torch.nn.functional as F 7 | from scipy.spatial.distance import cdist 8 | from sklearn.neighbors import KNeighborsClassifier 9 | import datautil.imgdata.util as imgutil 10 | from datautil.mydataloader import InfiniteDataLoader 11 | from utils.util import log_print 12 | import Replay.utils as RPutils 13 | 14 | def assign_pseudo_label(args, dataloader, replay_dataset, taskid, model, epoch, cur=False): 15 | pseudo_tau = 0 16 | if taskid == 0 or args.pLabelAlg == 'ground': 17 | return dataloader, None 18 | 19 | else: 20 | image_dict, clabel, dlabel = dataloader.dataset.get_raw_data() 21 | images = [dataloader.dataset.loader(dict) for dict in image_dict] # list of PIL image 22 | 23 | pseudo_image_dict = [] 24 | pseudo_clabel = [] 25 | pseudo_dlabel = [] 26 
| 27 | curr_dataset = RPutils.ReplayDataset(images, clabel, dlabel, transform=imgutil.image_test(args)) 28 | curr_dataloader = DataLoader(dataset=curr_dataset, 29 | shuffle=False, 30 | batch_size=args.batch_size, 31 | num_workers=args.N_WORKERS) 32 | model.eval().cuda() 33 | pseudo_clabel, pacc_dict, bool_index = T2PL(args, curr_dataloader, model, pseudo_tau) 34 | for i, v in enumerate(bool_index): 35 | if v: 36 | pseudo_image_dict.append(image_dict[i]) 37 | pseudo_dlabel.append(dlabel[i]) 38 | model.train() 39 | pseudo_dataset = PseudoDataset(pseudo_image_dict, np.array(pseudo_clabel), np.array(pseudo_dlabel), loader=dataloader.dataset.loader, transform=imgutil.image_train(args)) 40 | pseudo_dataloader = InfiniteDataLoader(dataset=pseudo_dataset, weights=None, batch_size=args.batch_size, num_workers=args.N_WORKERS) 41 | 42 | return pseudo_dataloader, pacc_dict #{'ps':len(pseudo_image_dict), 'pc':correct} 43 | 44 | def T2PL(args, loader, model, pseudo_tau): 45 | start_test = True 46 | with torch.no_grad(): 47 | for i, data in enumerate(loader): 48 | inputs = data[0].cuda() 49 | labels = data[1] 50 | 51 | feas = model.encoder(model.featurizer(inputs)) 52 | outputs = F.linear(feas, model.classifier) 53 | 54 | if start_test: 55 | all_fea = [feas.float().cpu()] 56 | all_output = [outputs.float().cpu()] 57 | all_label = [labels.float()] 58 | start_test = False 59 | else: 60 | all_fea.append(feas.float().cpu()) 61 | all_output.append(outputs.float().cpu()) 62 | all_label.append(labels.float()) 63 | all_fea = torch.cat(all_fea, dim=0) 64 | all_output = torch.cat(all_output, dim=0) 65 | all_label = torch.cat(all_label, dim=0) 66 | 67 | all_output = nn.Softmax(dim=1)(all_output) 68 | ov, idx = torch.max(all_output, 1) 69 | bool_index = ov > pseudo_tau 70 | all_output = all_output[bool_index] 71 | all_fea = all_fea[bool_index] 72 | all_label = all_label[bool_index] 73 | 74 | acc_list = [] 75 | 76 | # softmax predict 77 | _, predict = torch.max(all_output, 1) 78 | accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0]) 79 | acc_list.append(accuracy) 80 | 81 | all_fea = all_fea / torch.norm(all_fea, p=2, dim=1, keepdim=True) 82 | 83 | all_fea = all_fea.float().cpu() # (N, dim) 84 | K = all_output.size(1) 85 | aff = all_output.float().cpu() # (N, C) 86 | 87 | # top k features for SHOT 88 | topk_num = max(all_fea.shape[0] // (args.num_classes * args.topk_beta), 1) 89 | top_aff, top_fea = [], [] 90 | 91 | for cls_idx in range(args.num_classes): 92 | feat_samp_idx = torch.topk(aff[:, cls_idx], topk_num)[1] 93 | top_fea.append(all_fea[feat_samp_idx, :]) 94 | top_aff.append(aff[feat_samp_idx, :]) 95 | 96 | top_aff = torch.cat(top_aff, dim=0).numpy() 97 | top_fea = torch.cat(top_fea, dim=0).numpy() 98 | _, top_predict = torch.max(torch.from_numpy(top_aff), 1) 99 | 100 | # SHOT 101 | for _ in range(2): 102 | initc = top_aff.transpose().dot(top_fea) 103 | initc = initc / (1e-8 + top_aff.sum(axis=0)[:,None]) 104 | 105 | cls_count = np.eye(K)[predict].sum(axis=0) 106 | labelset = np.where(cls_count>0) 107 | labelset = labelset[0] 108 | 109 | dd = cdist(all_fea, initc[labelset], args.distance) 110 | pred_label = dd.argmin(axis=1) 111 | predict = labelset[pred_label] 112 | 113 | top_cls_count = np.eye(K)[top_predict].sum(axis=0) 114 | top_labelset = np.where(top_cls_count>0) 115 | top_labelset = top_labelset[0] 116 | 117 | top_dd = cdist(top_fea, initc[top_labelset], args.distance) 118 | top_pred_label = top_dd.argmin(axis=1) 119 | top_predict = 
top_labelset[top_pred_label]
120 | 
121 | top_aff = np.eye(K)[top_predict]
122 | acc_list.append(np.sum(predict == all_label.float().numpy()) / len(all_fea))
123 | 
124 | # kNN on the distances between each feature and the cluster centers
125 | top_sample = []
126 | top_label = []
127 | topk_fit_num = max(all_fea.shape[0] // (args.num_classes * args.topk_beta), 1)
128 | topk_num = max(all_fea.shape[0] // (args.num_classes * args.topk_alpha), 1)
129 | 
130 | for cls_idx in range(len(labelset)):
131 | feat_samp_idx = torch.topk(torch.from_numpy(dd)[:, cls_idx], topk_fit_num, largest=False)[1]
132 | 
133 | feat_cls_sample = all_fea[feat_samp_idx, :]
134 | feat_cls_label = torch.zeros([len(feat_samp_idx)]).fill_(cls_idx)
135 | 
136 | top_sample.append(feat_cls_sample)
137 | top_label.append(feat_cls_label)
138 | top_sample = torch.cat(top_sample, dim=0).cpu().numpy()
139 | top_label = torch.cat(top_label, dim=0).cpu().numpy()
140 | 
141 | knn = KNeighborsClassifier(n_neighbors=topk_num)
142 | knn.fit(top_sample, top_label)
143 | 
144 | knn_predict = knn.predict(all_fea.cpu().numpy()).tolist()
145 | knn_predict = [int(i) for i in knn_predict]
146 | 
147 | predict = labelset[knn_predict]
148 | acc_list.append(np.sum(predict == all_label.float().numpy()) / len(all_fea))
149 | 
150 | # log_print("acc:" + " --> ".join("{:.3f}".format(acc) for acc in acc_list), args.log_file, p=False)
151 | acc_dict = {}
152 | for i in range(len(acc_list)):
153 | acc_dict['pa{}'.format(i)] = round(acc_list[i], 3)
154 | 
155 | return predict.astype('int'), acc_dict, bool_index
156 | 
157 | 
158 | class PseudoDataset(Dataset):
159 | '''
160 | construct a pseudo-labeled dataset
161 | input: image paths.
162 | '''
163 | def __init__(self, images_dict, class_labels, domain_labels, loader, transform=None, target_transform=None):
164 | self.x = images_dict # list of image paths
165 | self.labels = class_labels # numpy array
166 | self.dlabels = domain_labels # numpy array
167 | self.loader = loader
168 | self.transform = transform
169 | 
170 | def __len__(self):
171 | return len(self.labels)
172 | 
173 | def __getitem__(self, index):
174 | imgs = self.transform(self.loader(self.x[index])) if self.transform is not None else self.loader(self.x[index])
175 | return imgs, self.labels[index], self.dlabels[index]
176 | 
177 | def get_raw_data(self):
178 | return self.x, self.labels, self.dlabels
--------------------------------------------------------------------------------
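A condensed, self-contained sketch of the T2PL refinement above (illustrative only; t2pl_sketch, feats, and probs are hypothetical names, and the single refinement round here stands in for the repo's two rounds): SHOT-style centers are computed as probability-weighted feature means, then a kNN fit on the samples nearest each center relabels the whole set.

import numpy as np
from scipy.spatial.distance import cdist
from sklearn.neighbors import KNeighborsClassifier

def t2pl_sketch(feats, probs, k_fit=8, k_nn=3, metric='cosine'):
    # feats: (N, d) L2-normalized features; probs: (N, C) softmax outputs
    centers = probs.T.dot(feats) / (1e-8 + probs.sum(axis=0)[:, None])  # (C, d) weighted means
    dists = cdist(feats, centers, metric)  # (N, C)
    fit_x, fit_y = [], []
    for c in range(centers.shape[0]):
        idx = np.argsort(dists[:, c])[:k_fit]  # the k_fit samples nearest center c form a trusted set
        fit_x.append(feats[idx])
        fit_y.append(np.full(len(idx), c))
    knn = KNeighborsClassifier(n_neighbors=k_nn)
    knn.fit(np.concatenate(fit_x), np.concatenate(fit_y))
    return knn.predict(feats)  # pseudo labels for all samples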
/datautil/getdataloader.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import sklearn.model_selection as ms
6 | from torch.utils.data import DataLoader, Dataset
7 | 
8 | import datautil.imgdata.util as imgutil
9 | from datautil.imgdata.imgdataload import ImageDataset
10 | from datautil.mydataloader import InfiniteDataLoader
11 | from utils.util import log_print, train_valid_target_eval_names
12 | 
13 | 
14 | def get_img_dataloader(args):
15 | '''
16 | Outputs:
17 | train_loaders: list. Each element is a dataloader for a source domain's training data.
18 | val_loaders: list. [source domain train dataloaders + target domain dataloaders + source domain test dataloaders]
19 | eval_name_dict: dictionary. keys: ['train', 'valid', 'target'], storing the indices of the corresponding data in val_loaders
20 | 
21 | e.g. PACS data. test_envs = []
22 | train_loaders: [training dataloader of 'Art', training dataloader of 'cartoon', training dataloader of 'photo', training dataloader of 'sketch']
23 | val_loaders: [training dataloader of 'Art', training dataloader of 'cartoon', training dataloader of 'photo', training dataloader of 'sketch',
24 | test dataloader of 'Art', test dataloader of 'cartoon', test dataloader of 'photo', test dataloader of 'sketch']
25 | eval_name_dict: ['train': [0,1,2,3], 'valid':[4,5,6,7], 'target':[]]
26 | task_sequence_name: ['Art', 'cartoon', 'photo', 'sketch']
27 | 
28 | e.g. PACS data. test_envs = [0]
29 | train_loaders: [training dataloader of 'cartoon', training dataloader of 'photo', training dataloader of 'sketch']
30 | val_loaders: [training dataloader of 'cartoon', training dataloader of 'photo', training dataloader of 'sketch', dataloader of 'art painting', test dataloader of 'cartoon', test dataloader of 'photo', test dataloader of 'sketch']
31 | 
32 | dataloader return: images, class_label, domain_label (datautil.imgdata.imgdataload.ImageDataset)
33 | images: torch tensor (batch, 3, 224, 224)
34 | class_label: torch tensor (batch,)
35 | domain_label: torch tensor (batch,)
36 | Note that when alg is consup and forAug is None (the case of the original supervised contrastive loss), the returned images have shape [batch_size*2, C, H, W]: the concatenation of two imgutil.image_train transforms of the same original image.
37 | '''
38 | rate = 0.2 # test data rate
39 | trdatalist, tedatalist = [], []
40 | train_name_list, target_name_list = [], []
41 | 
42 | names = args.img_dataset[args.dataset]
43 | args.domain_num = len(names)
44 | 
45 | eval_name_dict = train_valid_target_eval_names(args) # keys: train, valid, target
46 | args.eval_name_dict = eval_name_dict
47 | args.test_envs = args.order[1:]
48 | 
49 | for i, domain_id in enumerate(args.order):
50 | if i == 0:
51 | tmpdatay = ImageDataset(args, args.task, args.data_dir,
52 | names[domain_id], domain_id, transform=imgutil.image_train(args), test_envs=args.test_envs).labels
53 | l = len(tmpdatay)
54 | if args.split_style == 'strat':
55 | indexall = np.arange(l)
56 | stsplit = ms.StratifiedShuffleSplit(
57 | 2, test_size=rate, train_size=1-rate, random_state=args.seed)
58 | stsplit.get_n_splits(indexall, tmpdatay)
59 | indextr, indexte = next(stsplit.split(indexall, tmpdatay))
60 | else:
61 | indexall = np.arange(l)
62 | np.random.seed(args.seed)
63 | np.random.shuffle(indexall)
64 | ted = int(l*rate)
65 | indextr, indexte = indexall[:-ted], indexall[-ted:]
66 | 
67 | trdatalist.append(ImageDataset(args, args.task, args.data_dir,
68 | names[domain_id], domain_id, transform=imgutil.image_train_source(args), indices=indextr, test_envs=args.test_envs))
69 | tedatalist.append(ImageDataset(args, args.task, args.data_dir,
70 | names[domain_id], domain_id, transform=imgutil.image_test(args), indices=indexte, test_envs=args.test_envs))
71 | 
72 | else:
73 | trdatalist.append(ImageDataset(args, args.task, args.data_dir,
74 | names[domain_id], domain_id, transform=imgutil.image_train(args), test_envs=args.test_envs))
75 | tedatalist.append(ImageDataset(args, args.task, args.data_dir,
76 | names[domain_id], domain_id, transform=imgutil.image_test(args), test_envs=args.test_envs))
77 | train_name_list.append(names[domain_id])
78 | 
79 | # If used with InfiniteDataLoader, data is fetched recurrently (the stream never ends).
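# A minimal sketch of the recurrent-fetch idea (the actual InfiniteDataLoader in
# datautil/mydataloader.py may differ in detail; treat this example as an assumption):
#
#   import itertools
#   from torch.utils.data import DataLoader
#
#   def infinite_batches(dataset, batch_size, num_workers):
#       loader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
#                           num_workers=num_workers, drop_last=True)
#       # chain fresh epochs forever; callers draw a fixed number of steps per epoch
#       for batch in itertools.chain.from_iterable(itertools.repeat(loader)):
#           yield batch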
80 | train_loaders = [InfiniteDataLoader( 81 | dataset=env, 82 | weights=None, 83 | batch_size=args.batch_size, 84 | num_workers=args.N_WORKERS) 85 | for env in trdatalist] 86 | 87 | eval_loaders = [DataLoader( 88 | dataset=env, 89 | batch_size=args.batch_size*2, 90 | num_workers=args.N_WORKERS, 91 | drop_last=False, 92 | shuffle=False) 93 | for env in trdatalist+tedatalist] 94 | 95 | log_print('domain training tasks sequence: {}, corresponding data size: {}'.format(train_name_list, [len(d.dataset) for d in train_loaders]), args.log_file) 96 | log_print('domain validation data size: {}\n'.format([len(eval_loaders[i].dataset) for i in eval_name_dict['valid']]), args.log_file) 97 | 98 | return train_loaders, eval_loaders, eval_name_dict, train_name_list 99 | 100 | # def get_img_dataloader(args): 101 | # ''' 102 | # Outputs: 103 | # train_loaders: list. Each element is a dataloader for a source domain's training data. 104 | # val_loaders: list. [source domain train dataloaders + target domain dataloaders + source domain test dataloaders] 105 | # eval_name_dict: dictinonaty. keys: ['train', 'valid', 'target'], store the index of corresponding data in val_loaders 106 | 107 | # e.g. PACS data. test_envs = [] 108 | # train_loaders: [training dataloader of 'Art', training dataloader of 'cartoon', training dataloader of 'photo', training dataloader of 'sketch'] 109 | # val_loaders: [training dataloader of 'Art', training dataloader of 'cartoon', training dataloader of 'photo', training dataloader of 'sketch', 110 | # test dataloader of 'Art', test dataloader of 'cartoon', test dataloader of 'photo', test dataloader of 'sketch'] 111 | # eval_name_dict: ['train': [0,1,2,3], 'valid':[4,5,6,7], 'target':[]] 112 | # task_sequence_name: ['Art', 'cartoon', 'photo', 'sketch'] 113 | 114 | # e.g. PACS data. test_envs = [0] 115 | # train_loaders: [training dataloader of 'cartoon', training dataloader of 'photo', training dataloader of 'sketch'] 116 | # val_loaders: [training dataloader of 'cartoon', training dataloader of 'photo', training dataloader of 'sketch', dataloader of 'art painting', test dataloader of 'cartoon', test dataloader of 'photo', test dataloader of 'sketch'] 117 | 118 | # dataloader return: images, class_label, domain_label (datautil.imgdata.imgdataload.ImageDataset) 119 | # images: torch tensor (batch, 3, 224, 224) 120 | # class_label: torch tensor (batch,) 121 | # domain_label: torch tensor (batch,) 122 | # Note that when alg is consup and forAug is None(the case of using original supervised contrastive loss, images is return images is [batch_size*2, C, H, W], batch_size*2 is concatenate of two imgutil.image_train transform of the same original image. 
123 | # ''' 124 | # rate = 0.2 # test data rate 125 | # trdatalist, tedatalist = [], [] 126 | # train_name_list, target_name_list = [], [] 127 | 128 | # names = args.img_dataset[args.dataset] 129 | # args.domain_num = len(names) 130 | # for i in range(len(names)): 131 | # if i in args.test_envs: 132 | # tedatalist.append(ImageDataset(args, args.task, args.data_dir, 133 | # names[i], i, transform=imgutil.image_test(args.dataset), test_envs=args.test_envs)) 134 | # target_name_list.append(names[i]) 135 | # else: 136 | # tmpdatay = ImageDataset(args, args.task, args.data_dir, 137 | # names[i], i, transform=imgutil.image_train(args), test_envs=args.test_envs).labels 138 | # l = len(tmpdatay) 139 | # if args.split_style == 'strat': 140 | # indexall = np.arange(l) 141 | # stsplit = ms.StratifiedShuffleSplit( 142 | # 2, test_size=rate, train_size=1-rate, random_state=args.seed) 143 | # stsplit.get_n_splits(indexall, tmpdatay) 144 | # indextr, indexte = next(stsplit.split(indexall, tmpdatay)) 145 | # else: 146 | # indexall = np.arange(l) 147 | # np.random.seed(args.seed) 148 | # np.random.shuffle(indexall) 149 | # ted = int(l*rate) 150 | # indextr, indexte = indexall[:-ted], indexall[-ted:] 151 | 152 | # if i != args.order[0]: # use all target domain data for training and testing 153 | # all_index = np.append(indextr, indexte) 154 | # trdatalist.append(ImageDataset(args, args.task, args.data_dir, 155 | # names[i], i, transform=imgutil.image_train(args), indices=all_index, test_envs=args.test_envs)) 156 | # tedatalist.append(ImageDataset(args, args.task, args.data_dir, 157 | # names[i], i, transform=imgutil.image_test(args), indices=all_index, test_envs=args.test_envs)) 158 | # else: 159 | # trdatalist.append(ImageDataset(args, args.task, args.data_dir, 160 | # names[i], i, transform=imgutil.image_train(args), indices=indextr, test_envs=args.test_envs)) 161 | # tedatalist.append(ImageDataset(args, args.task, args.data_dir, 162 | # names[i], i, transform=imgutil.image_test(args), indices=indexte, test_envs=args.test_envs)) 163 | # train_name_list.append(names[i]) 164 | # # test_name_list.append(names[i]) 165 | 166 | # # If use for InfiniteDataloader, it will fetch data recurrently. 167 | # train_loaders = [InfiniteDataLoader( 168 | # dataset=env, 169 | # weights=None, 170 | # batch_size=args.batch_size, 171 | # num_workers=args.N_WORKERS) 172 | # for env in trdatalist] 173 | 174 | # # if use DataLoader instead of InfiniteDataLoader, accuracy will decrease and training time will largely increase. 
175 | # # train_loaders = [DataLoader( 176 | # # dataset=env, 177 | # # batch_size=args.batch_size, 178 | # # shuffle=True, 179 | # # num_workers=args.N_WORKERS) 180 | # # for env in trdatalist] 181 | 182 | # eval_loaders = [DataLoader( 183 | # dataset=env, 184 | # batch_size=args.batch_size*2, 185 | # num_workers=args.N_WORKERS, 186 | # drop_last=False, 187 | # shuffle=False) 188 | # for env in trdatalist+tedatalist] 189 | 190 | # eval_name_dict = train_valid_target_eval_names(args) # keys: train, valid, target 191 | # train_loaders = change_order(args.order, train_loaders) 192 | # train_name_list = change_order(args.order, train_name_list) 193 | # eval_name_dict = change_eval_order(args.order, eval_name_dict) 194 | # args.eval_name_dict = eval_name_dict 195 | 196 | # log_print('domain training tasks sequence: {}, corresponding data size: {}'.format(train_name_list, [len(d.dataset) for d in train_loaders]), args.log_file) 197 | # log_print('domain validation data size: {}\n'.format([len(eval_loaders[i].dataset) for i in eval_name_dict['valid']]), args.log_file) 198 | # # log_print('target domain data: {}, corresponding data size: {}'.format(target_name_list, [len(eval_loaders[i].dataset) for i in eval_name_dict['target']]), args.log_file) 199 | 200 | # return train_loaders, eval_loaders, eval_name_dict, train_name_list 201 | 202 | # def change_order(order, original_list): 203 | # ''' 204 | # change training domain order based on args.order 205 | # e.g. 206 | # original_list = [a,b,c], order = [2,1,0] 207 | # new_original_list = [c,b,a] 208 | # ''' 209 | # new_list = [] 210 | # for i in order: 211 | # new_list.append(original_list[i]) 212 | # return new_list 213 | 214 | # def change_eval_order(order, eval_name_dict): 215 | # eval_name_dict['train'] = change_order(order, eval_name_dict['train']) 216 | # eval_name_dict['valid'] = change_order(order, eval_name_dict['valid']) 217 | # return eval_name_dict 218 | 219 | 220 | # class utilDataset(Dataset): 221 | # ''' 222 | # construct pseudo dataset 223 | # input: images_dict. 224 | # ''' 225 | # def __init__(self, images_dict, class_labels, domain_labels, loader, transform=None, target_transform=None): 226 | # self.x = images_dict # list of [PIL image] 227 | # self.labels = class_labels # numpy array 228 | # self.dlabels = domain_labels # numpy array 229 | # self.loader = loader 230 | # self.transform = transform 231 | 232 | # def __len__(self): 233 | # return len(self.labels) 234 | 235 | # def __getitem__(self, index): 236 | # imgs = self.transform(self.loader(self.x[index])) if self.transform is not None else self.loader(self.x[index]) 237 | # return imgs, self.labels[index], self.dlabels[index] 238 | 239 | # def get_raw_data(self): 240 | # return self.x, self.labels, self.dlabels --------------------------------------------------------------------------------
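To show how the pieces above fit together, here is a hypothetical end-to-end driver assembled only from functions defined in this repository (the real entry point is main.py/train.py; the loop structure and replay wiring below are editorial assumptions, not the authors' exact procedure).

import copy
from arguments import get_args
from datautil.getdataloader import get_img_dataloader
from modelopera import accuracy
from Replay.alg import get_algorithm_class
from pLabel import assign_pseudo_label
from RaTP import RaTP

args = get_args()
train_loaders, eval_loaders, eval_name_dict, names = get_img_dataloader(args)
model = RaTP(args).cuda()
replay = get_algorithm_class(args.replay)(args)

old_model = None
for task_id, loader in enumerate(train_loaders):
    model.get_optimizer(lr_decay=args.lr_decay1)
    for epoch in range(args.max_epoch):
        if task_id > 0 and epoch % args.pseudo_fre == 0:
            # re-assign pseudo labels on the current target domain with T2PL
            loader, _ = assign_pseudo_label(args, loader, replay.update_dataloader(),
                                            task_id, model, epoch)
        for _, batch in zip(range(args.steps_per_epoch), loader):
            if task_id == 0:
                model.train_source(batch, task_id, epoch)
            else:
                model.adapt(batch, task_id, epoch, old_model=old_model)
    replay.update(model, task_id, loader)  # refresh the exemplar memory after each task
    old_model = copy.deepcopy(model)

print({i: round(accuracy(model, eval_loaders[i]), 3) for i in eval_name_dict['valid']})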