├── LICENSE ├── README.md ├── data └── office-home │ ├── Art_list.txt │ ├── Clipart_list.txt │ ├── Product_list.txt │ └── RealWorld_list.txt ├── demo_mixmatch.py ├── demo_ssda.py ├── demo_ssda_mixmatch.py ├── demo_uda.py ├── logs └── uda │ └── run1 │ └── mixmatch │ ├── office-home │ ├── AC │ │ └── mixmatch_atdoc_naatdoc_na5.txt │ ├── AP │ │ └── mixmatch_atdoc_naatdoc_na5.txt │ └── AR │ │ └── mixmatch_atdoc_naatdoc_na5.txt │ └── office │ └── AD │ └── mixmatch_atdoc_naatdoc_na5.txt ├── loss.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 tim-learn 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Official implementation for **ATDOC** 2 | 3 | [**[CVPR-2021] Domain Adaptation with Auxiliary Target Domain-Oriented Classifier**](https://arxiv.org/pdf/2007.04171.pdf) 4 | 5 | [Update @ Nov 23 2021] 6 | 7 | 1. **[For Office, please change the *max-epoch* to 100; for VISDA-C, change the *max-epoch* to 1 and change the *net* to resnet101]** 8 | 2. **Add the code associated with SSDA, change the *max-epoch* to 20 for DomainNet-126** 9 | 3. **Thank @lyxok1 for pointing out the typo in Eq.(6), we have corrected it in the new verison of this paper.** 10 | 11 | 12 | 13 | Below is the demo for **ATDOC** on a UDA task of Office-Home [*max_epoch* to 50]: 14 | 15 | 16 | 1. installing packages 17 | 18 | `python == 3.6.8` 19 | `pytorch ==1.1.0` 20 | `torchvision == 0.3.0` 21 | `numpy, scipy, sklearn, PIL, argparse, tqdm` 22 | 23 | 2. download the Office-Home dataset 24 | 25 | `mkdir dataset` 26 | 27 | `cd dataset` 28 | 29 | `pip install gdown` 30 | 31 | `gdown https://drive.google.com/u/0/uc?id=0B81rNlvomiwed0V1YUxQdC1uOTg&export=download` 32 | 33 | `unzip OfficeHomeDataset_10072016.zip` 34 | 35 | `mv ./OfficeHomeDataset_10072016/Real\ World ./OfficeHomeDataset_10072016/RealWorld` 36 | 37 | `cd ../` 38 | 39 | 3. run the main file with '**Source-model-only**' 40 | 41 | `python demo_uda.py --pl none --dset office-home --max_epoch 50 --s 0 --t 1 --gpu_id 0 --method srconly --output logs/uda/run1/` 42 | 43 | 4. run the main file with '**ATDOC-NC**' 44 | 45 | `python demo_uda.py --pl atdoc_nc --tar_par 0.1 --dset office-home --max_epoch 50 --s 0 --t 1 --gpu_id 0 --method srconly --output logs/uda/run1/` 46 | 47 | 5. 
run the main file with '**ATDOC-NA**' 48 | 49 | `python demo_uda.py --pl atdoc_na --tar_par 0.2 --dset office-home --max_epoch 50 --s 0 --t 1 --gpu_id 0 --method srconly --output logs/uda/run1/` 50 | 51 | 6. run the main file with '**ATDOC-NA**' combined with '**CDAN+E**' 52 | 53 | `python demo_uda.py --pl atdoc_na --tar_par 0.2 --dset office-home --max_epoch 50 --s 0 --t 1 --gpu_id 0 --method CDANE --output logs/uda/run1/` 54 | 55 | 7. run the main file with '**ATDOC-NA**' combined with '**MixMatch**' 56 | 57 | `python demo_mixmatch.py --pl none --dset office-home --max_epoch 50 --s 0 --t 1 --gpu_id 0 --output logs/uda/run1/` 58 | 59 | 8. run the main file with '**ATDOC-NA**' combined with '**MixMatch**' 60 | 61 | `python demo_mixmatch.py --pl atdoc_na --dset office-home --max_epoch 50 --s 0 --t 1 --gpu_id 0 --output logs/uda/run1/` 62 | 63 | 64 | 65 | 66 | ### Citation 67 | 68 | If you find this code useful for your research, please cite our paper 69 | 70 | > @inproceedings{liang2021domain, 71 | >     title={Domain Adaptation with Auxiliary Target Domain-Oriented Classifier}, 72 | >     author={Liang, Jian and Hu, Dapeng and Feng, Jiashi}, 73 | >     booktitle={IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, 74 | >     year={2021} 75 | > } 76 | > 77 | ### Contact 78 | 79 | - [liangjian92@gmail.com](mailto:liangjian92@gmail.com) 80 | - [dapeng.hu@u.nus.edu](mailto:dapeng.hu@u.nus.edu) 81 | - [elefjia@nus.edu.sg](mailto:elefjia@nus.edu.sg) -------------------------------------------------------------------------------- /demo_mixmatch.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import os.path as osp 4 | import numpy as np 5 | import torch 6 | import torchvision 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | import loss 10 | import random, pdb, math 11 | import sys, copy 12 | from tqdm import tqdm 13 | import utils, pickle 14 | import scipy.io as sio 15 | 
def data_load(args):
    """Build the source / target / test DataLoaders for MixMatch training.

    The target set yields TWO independently augmented views per image
    (required by MixMatch label guessing) plus the sample index; the test
    loader uses deterministic center-crop preprocessing.
    """
    train_transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize((256, 256)),
        torchvision.transforms.RandomCrop((224, 224)),
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    test_transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize((256, 256)),
        torchvision.transforms.CenterCrop((224, 224)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    source_set = utils.ObjectImage('', args.s_dset_path, train_transform)
    # Two transforms -> ObjectImage_mul returns (view1, view2, index) per sample.
    target_set = utils.ObjectImage_mul('', args.t_dset_path, [train_transform, train_transform])
    test_set = utils.ObjectImage('', args.test_dset_path, test_transform)

    dset_loaders = {}
    dset_loaders["source"] = torch.utils.data.DataLoader(source_set, batch_size=args.batch_size,
        shuffle=True, num_workers=args.worker, drop_last=True)
    dset_loaders["target"] = torch.utils.data.DataLoader(target_set, batch_size=args.batch_size,
        shuffle=True, num_workers=args.worker, drop_last=True)
    dset_loaders["test"] = torch.utils.data.DataLoader(test_set, batch_size=args.batch_size*3,
        shuffle=False, num_workers=args.worker, drop_last=False)

    return dset_loaders

def lr_scheduler(optimizer, init_lr, iter_num, max_iter, gamma=10, power=0.75):
    """Inverse-decay schedule: lr = init_lr * (1 + gamma*p)^(-power), p = iter/max_iter.

    Also (re)pins weight decay / momentum / nesterov on every group each call.
    Returns the optimizer for chaining.
    """
    decay = (1 + gamma * iter_num / max_iter) ** (-power)
    for param_group in optimizer.param_groups:
        param_group['lr'] = init_lr * decay
        param_group['weight_decay'] = 1e-3
        param_group['momentum'] = 0.9
        param_group['nesterov'] = True
    return optimizer

def train(args):
    """Run MixMatch UDA training, optionally with an ATDOC auxiliary classifier.

    args.pl selects the pseudo-label source for the unlabeled target batch:
    'none' (plain MixMatch sharpening), 'npl' (network argmax),
    'atdoc_nc' (class-centroid memory), 'atdoc_na' (K-NN feature/score memory).
    Returns (base_network, last predicted labels on the test set).
    """
    dset_loaders = data_load(args)

    max_len = max(len(dset_loaders["source"]), len(dset_loaders["target"]))
    args.max_iter = args.max_epoch * max_len

    ## set base network
    if args.net == 'resnet101':
        netG = utils.ResBase101().cuda()
    elif args.net == 'resnet50':
        netG = utils.ResBase50().cuda()

    netF = utils.ResClassifier(class_num=args.class_num, feature_dim=netG.in_features,
        bottleneck_dim=args.bottleneck_dim).cuda()

    if len(args.gpu_id.split(',')) > 1:
        netG = nn.DataParallel(netG)

    # Backbone is fine-tuned at 0.1x the classifier learning rate.
    optimizer_g = optim.SGD(netG.parameters(), lr=args.lr * 0.1)
    optimizer_f = optim.SGD(netF.parameters(), lr=args.lr)

    base_network = nn.Sequential(netG, netF)
    source_loader_iter = iter(dset_loaders["source"])
    target_loader_iter = iter(dset_loaders["target"])

    if args.pl == 'atdoc_na':
        # Per-sample memory bank: L2-normalized features + soft class scores,
        # initialized to random directions / uniform class distribution.
        mem_fea = torch.rand(len(dset_loaders["target"].dataset), args.bottleneck_dim).cuda()
        mem_fea = mem_fea / torch.norm(mem_fea, p=2, dim=1, keepdim=True)
        mem_cls = torch.ones(len(dset_loaders["target"].dataset), args.class_num).cuda() / args.class_num

    if args.pl == 'atdoc_nc':
        # Per-class centroid memory.
        mem_fea = torch.rand(args.class_num, args.bottleneck_dim).cuda()
        mem_fea = mem_fea / torch.norm(mem_fea, p=2, dim=1, keepdim=True)

    # Hoisted out of the loop: the criterion is stateless across iterations.
    train_criterion = utils.SemiLoss()
    # Guard against int(...) == 0 (modulo-by-zero) when max_epoch < 10.
    eval_interval = max(1, int(args.eval_epoch * max_len))

    list_acc = []
    best_ent = 100

    for iter_num in range(1, args.max_iter + 1):
        base_network.train()
        lr_scheduler(optimizer_g, init_lr=args.lr * 0.1, iter_num=iter_num, max_iter=args.max_iter)
        lr_scheduler(optimizer_f, init_lr=args.lr, iter_num=iter_num, max_iter=args.max_iter)

        # FIX: iterator.next() does not exist on Python-3 iterators / recent
        # PyTorch DataLoader iterators; use the next() builtin and catch only
        # StopIteration so real errors still surface.
        try:
            inputs_source, labels_source = next(source_loader_iter)
        except StopIteration:
            source_loader_iter = iter(dset_loaders["source"])
            inputs_source, labels_source = next(source_loader_iter)
        try:
            inputs_target, _, target_idx = next(target_loader_iter)
        except StopIteration:
            target_loader_iter = iter(dset_loaders["target"])
            inputs_target, _, target_idx = next(target_loader_iter)

        # One-hot source targets for the MixMatch labeled loss.
        targets_s = torch.zeros(args.batch_size, args.class_num).scatter_(1, labels_source.view(-1, 1), 1)
        inputs_s = inputs_source.cuda()
        targets_s = targets_s.cuda()
        inputs_t = inputs_target[0].cuda()
        inputs_t2 = inputs_target[1].cuda()

        if args.pl == 'atdoc_na':
            # ATDOC-NA: average a K-NN vote over the memory bank for both views.
            targets_u = 0
            for inp in [inputs_t, inputs_t2]:
                with torch.no_grad():
                    features_target, outputs_u = base_network(inp)

                dis = -torch.mm(features_target.detach(), mem_fea.t())
                # Exclude each sample's own memory slot from its neighbor set.
                for di in range(dis.size(0)):
                    dis[di, target_idx[di]] = torch.max(dis)

                _, p1 = torch.sort(dis, dim=1)
                w = torch.zeros(features_target.size(0), mem_fea.size(0)).cuda()
                for wi in range(w.size(0)):
                    for wj in range(args.K):
                        w[wi][p1[wi, wj]] = 1 / args.K

                _, pred = torch.max(w.mm(mem_cls), 1)
                targets_u += 0.5 * torch.eye(outputs_u.size(1))[pred].cuda()

        elif args.pl == 'atdoc_nc':
            # ATDOC-NC: nearest class centroid (cosine via normalized memory).
            targets_u = 0
            mem_fea_norm = mem_fea / torch.norm(mem_fea, p=2, dim=1, keepdim=True)
            for inp in [inputs_t, inputs_t2]:
                with torch.no_grad():
                    features_target, outputs_u = base_network(inp)
                dis = torch.mm(features_target.detach(), mem_fea_norm.t())
                _, pred = torch.max(dis, dim=1)
                targets_u += 0.5 * torch.eye(outputs_u.size(1))[pred].cuda()

        elif args.pl == 'npl':
            # Naive pseudo-labels: network argmax on each view.
            targets_u = 0
            for inp in [inputs_t, inputs_t2]:
                with torch.no_grad():
                    _, outputs_u = base_network(inp)
                    _, pred = torch.max(outputs_u.detach(), 1)
                    targets_u += 0.5 * torch.eye(outputs_u.size(1))[pred].cuda()

        else:
            # Plain MixMatch: average the two views' softmax, sharpen by 1/T.
            with torch.no_grad():
                _, outputs_u = base_network(inputs_t)
                _, outputs_u2 = base_network(inputs_t2)
                p = (torch.softmax(outputs_u, dim=1) + torch.softmax(outputs_u2, dim=1)) / 2
                pt = p ** (1 / args.T)
                targets_u = pt / pt.sum(dim=1, keepdim=True)
                targets_u = targets_u.detach()

        ####################################################################
        # MixUp between the labeled batch and both unlabeled views.
        all_inputs = torch.cat([inputs_s, inputs_t, inputs_t2], dim=0)
        all_targets = torch.cat([targets_s, targets_u, targets_u], dim=0)
        if args.alpha > 0:
            l = np.random.beta(args.alpha, args.alpha)
            l = max(l, 1 - l)  # keep the original sample dominant
        else:
            l = 1
        idx = torch.randperm(all_inputs.size(0))

        input_a, input_b = all_inputs, all_inputs[idx]
        target_a, target_b = all_targets, all_targets[idx]
        mixed_input = l * input_a + (1 - l) * input_b
        mixed_target = l * target_a + (1 - l) * target_b

        # Interleave labeled and unlabeled samples between batches so
        # BatchNorm statistics are computed correctly.
        mixed_input = list(torch.split(mixed_input, args.batch_size))
        mixed_input = utils.interleave(mixed_input, args.batch_size)

        _, logits_first = base_network(mixed_input[0])
        logits = [logits_first]
        for inp_chunk in mixed_input[1:]:  # renamed: do not shadow builtin input()
            _, temp = base_network(inp_chunk)
            logits.append(temp)

        # Put interleaved samples back.
        logits = utils.interleave(logits, args.batch_size)
        logits_x = logits[0]
        logits_u = torch.cat(logits[1:], dim=0)

        Lx, Lu, ramp_w = train_criterion(logits_x, mixed_target[:args.batch_size], logits_u,
            mixed_target[args.batch_size:], iter_num, args.max_iter, args.lambda_u)
        total_loss = Lx + ramp_w * Lu

        optimizer_g.zero_grad()
        optimizer_f.zero_grad()
        total_loss.backward()
        optimizer_g.step()
        optimizer_f.step()

        if args.pl == 'atdoc_na':
            # Momentum update of the per-sample memory with the averaged,
            # normalized features and the sharpened (squared) soft scores.
            base_network.eval()
            with torch.no_grad():
                fea1, outputs1 = base_network(inputs_t)
                fea2, outputs2 = base_network(inputs_t2)
                feat = 0.5 * (fea1 + fea2)
                feat = feat / torch.norm(feat, p=2, dim=1, keepdim=True)
                softmax_out = 0.5 * (nn.Softmax(dim=1)(outputs1) + nn.Softmax(dim=1)(outputs2))
                softmax_out = softmax_out**2 / ((softmax_out**2).sum(dim=0))

            mem_fea[target_idx] = (1.0 - args.momentum) * mem_fea[target_idx] + args.momentum * feat
            mem_cls[target_idx] = (1.0 - args.momentum) * mem_cls[target_idx] + args.momentum * softmax_out

        if args.pl == 'atdoc_nc':
            # Momentum update of the class centroids from both views.
            base_network.eval()
            with torch.no_grad():
                fea1, outputs1 = base_network(inputs_t)
                fea2, outputs2 = base_network(inputs_t2)
                feat = 0.5 * (fea1 + fea2)
                softmax_t = 0.5 * (nn.Softmax(dim=1)(outputs1) + nn.Softmax(dim=1)(outputs2))
                _, pred_t = torch.max(softmax_t, 1)
                onehot_t = torch.eye(args.class_num)[pred_t].cuda()
                center_t = torch.mm(feat.t(), onehot_t) / (onehot_t.sum(dim=0) + 1e-8)

            mem_fea = (1.0 - args.momentum) * mem_fea + args.momentum * center_t.t().clone()

        if iter_num % eval_interval == 0:
            base_network.eval()
            if args.dset == 'VISDA-C':
                acc, py, score, y, tacc = utils.cal_acc_visda(dset_loaders["test"], base_network)
                args.out_file.write(tacc + '\n')
                args.out_file.flush()
                _ent = loss.Entropy(score)
                mean_ent = 0
                for ci in range(args.class_num):
                    mean_ent += _ent[py == ci].mean()
                mean_ent /= args.class_num
            else:
                acc, py, score, y = utils.cal_acc(dset_loaders["test"], base_network)
                mean_ent = torch.mean(loss.Entropy(score))

            list_acc.append(acc * 100)

            # Model selection by minimum mean entropy (no target labels used).
            if best_ent > mean_ent:
                val_acc = acc * 100
                best_ent = mean_ent
                best_y = y       # kept for the optional checkpoint/.mat dump below
                best_py = py
                best_score = score

            log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%; Mean Ent = {:.4f}'.format(
                args.name, iter_num, args.max_iter, acc * 100, mean_ent)
            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')

    idx = np.argmax(np.array(list_acc))
    max_acc = list_acc[idx]
    final_acc = list_acc[-1]

    log_str = '\n==========================================\n'
    log_str += '\nVal Acc = {:.2f}\nMax Acc = {:.2f}\nFin Acc = {:.2f}\n'.format(val_acc, max_acc, final_acc)
    args.out_file.write(log_str + '\n')
    args.out_file.flush()

    # Optional: persist the selected model / predictions.
    # torch.save(base_network.state_dict(), osp.join(args.output_dir, args.log + ".pt"))
    # sio.savemat(osp.join(args.output_dir, args.log + ".mat"), {'y': best_y.cpu().numpy(),
    #     'py': best_py.cpu().numpy(), 'score': best_score.cpu().numpy()})

    return base_network, py

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Mixmatch for Domain Adaptation')
    parser.add_argument('--gpu_id', type=str, nargs='?', default='0', help="device id to run")
    parser.add_argument('--s', type=int, default=0, help="source")
    parser.add_argument('--t', type=int, default=1, help="target")
    parser.add_argument('--output', type=str, default='san')
    parser.add_argument('--seed', type=int, default=0, help="random seed")
    parser.add_argument('--max_epoch', type=int, default=50)
    parser.add_argument('--batch_size', type=int, default=36, help="batch_size")
    parser.add_argument('--worker', type=int, default=4, help="number of workers")
    parser.add_argument('--bottleneck_dim', type=int, default=256)

    parser.add_argument('--net', type=str, default='resnet50', choices=["resnet50", "resnet101"])
    parser.add_argument('--dset', type=str, default='office-home',
        choices=['DomainNet126', 'VISDA-C', 'office', 'office-home'],
        help="The dataset or source dataset used")
    parser.add_argument('--lr', type=float, default=0.01, help="learning rate")
    parser.add_argument('--pl', type=str, default='none', choices=['none', 'npl', 'atdoc_na', 'atdoc_nc'])
    parser.add_argument('--K', type=int, default=5)
    parser.add_argument('--momentum', type=float, default=1.0)

    parser.add_argument('--alpha', default=0.75, type=float)
    parser.add_argument('--lambda_u', default=100, type=float)
    parser.add_argument('--T', default=0.5, type=float)
    parser.add_argument('--ema_decay', default=0.999, type=float)

    args = parser.parse_args()

    # FIX: the original did `args.pl += args.pl + str(args.K)`, turning
    # 'atdoc_na' into 'atdoc_naatdoc_na5'. Besides the garbled log name,
    # every `args.pl == 'atdoc_na'` check inside train() then failed, silently
    # disabling ATDOC-NA. Keep args.pl intact and derive a log tag instead.
    pl_tag = args.pl
    if args.pl == 'atdoc_na':
        pl_tag = args.pl + str(args.K)
        args.momentum = 1.0
    if args.pl == 'atdoc_nc':
        args.momentum = 0.1

    args.eval_epoch = args.max_epoch / 10

    if args.dset == 'office-home':
        names = ['Art', 'Clipart', 'Product', 'RealWorld']
        args.class_num = 65
    if args.dset == 'office':
        names = ['amazon', 'dslr', 'webcam']
        args.class_num = 31
    if args.dset == 'DomainNet126':
        names = ['clipart', 'painting', 'real', 'sketch']
        args.class_num = 126
    if args.dset == 'VISDA-C':
        names = ['train', 'validation']
        args.class_num = 12

    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
    SEED = args.seed
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    np.random.seed(SEED)
    random.seed(SEED)
    # torch.backends.cudnn.deterministic = True

    args.s_dset_path = './data/' + args.dset + '/' + names[args.s] + '_list.txt'
    args.t_dset_path = './data/' + args.dset + '/' + names[args.t] + '_list.txt'
    args.test_dset_path = args.t_dset_path

    args.output_dir = osp.join(args.output, 'mixmatch', args.dset,
        names[args.s][0].upper() + names[args.t][0].upper())
    args.name = names[args.s][0].upper() + names[args.t][0].upper()

    # FIX: replaced `os.system('mkdir -p ' + dir)` (shell-dependent and
    # injectable via --output) with the portable stdlib equivalent.
    os.makedirs(args.output_dir, exist_ok=True)

    args.log = 'mixmatch_' + pl_tag
    args.out_file = open(osp.join(args.output_dir, "{:}.txt".format(args.log)), "w")

    utils.print_args(args)

    train(args)

############################################################################
# demo_ssda.py
############################################################################
import argparse
import os
import os.path as osp
import numpy as np
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import loss
import random, pdb, math
import sys, copy
from tqdm import tqdm
import utils
import scipy.io as sio
def data_load(args):
    """Build source / labeled-target / unlabeled-target / val / test loaders (SSDA).

    The per-iteration batch is split 1:2 between labeled target (batch_size//3)
    and unlabeled target (2*batch_size//3) samples.
    """
    train_transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize((256, 256)),
        torchvision.transforms.RandomCrop((224, 224)),
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    test_transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize((256, 256)),
        torchvision.transforms.CenterCrop((224, 224)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    source_set = utils.ObjectImage('', args.s_dset_path, train_transform)
    ltarget_set = utils.ObjectImage_mul('', args.lt_dset_path, train_transform)
    target_set = utils.ObjectImage_mul('', args.t_dset_path, train_transform)
    val_set = utils.ObjectImage('', args.vt_dset_path, test_transform)
    test_set = utils.ObjectImage('', args.test_dset_path, test_transform)

    dset_loaders = {}
    dset_loaders["source"] = torch.utils.data.DataLoader(source_set, batch_size=args.batch_size,
        shuffle=True, num_workers=args.worker, drop_last=True)
    dset_loaders["ltarget"] = torch.utils.data.DataLoader(ltarget_set, batch_size=args.batch_size//3,
        shuffle=True, num_workers=args.worker, drop_last=True)
    dset_loaders["target"] = torch.utils.data.DataLoader(target_set, batch_size=2*args.batch_size//3,
        shuffle=True, num_workers=args.worker, drop_last=True)
    dset_loaders["val"] = torch.utils.data.DataLoader(val_set, batch_size=args.batch_size*3,
        shuffle=False, num_workers=args.worker, drop_last=False)
    dset_loaders["test"] = torch.utils.data.DataLoader(test_set, batch_size=args.batch_size*3,
        shuffle=False, num_workers=args.worker, drop_last=False)
    return dset_loaders

def lr_scheduler(optimizer, init_lr, iter_num, max_iter, gamma=10, power=0.75):
    """Inverse-decay schedule: lr = init_lr * (1 + gamma*p)^(-power), p = iter/max_iter.

    Also (re)pins weight decay / momentum / nesterov on every group each call.
    Returns the optimizer for chaining.
    """
    decay = (1 + gamma * iter_num / max_iter) ** (-power)
    for param_group in optimizer.param_groups:
        param_group['lr'] = init_lr * decay
        param_group['weight_decay'] = 1e-3
        param_group['momentum'] = 0.9
        param_group['nesterov'] = True
    return optimizer

def bsp_loss(feature):
    """Batch Spectral Penalization: sum of the squared top singular values of
    the source and target halves of the feature batch, scaled by 1e-4.

    Assumes the batch stacks [source; target] with equal halves.
    """
    train_bs = feature.size(0) // 2
    feature_s = feature.narrow(0, 0, train_bs)
    feature_t = feature.narrow(0, train_bs, train_bs)
    _, s_s, _ = torch.svd(feature_s)
    _, s_t, _ = torch.svd(feature_t)
    sigma = torch.pow(s_s[0], 2) + torch.pow(s_t[0], 2)
    sigma *= 0.0001
    return sigma

def train(args):
    """Semi-supervised DA training: source + labeled-target supervision,
    optional adversarial alignment (DANN / CDAN+E), optional target
    regularizer selected by args.pl (ent/bnm/mcc/bsp/npl/atdoc_*).

    Model selection is by best accuracy on the labeled validation split.
    """
    dset_loaders = data_load(args)
    class_num = args.class_num
    class_weight_src = torch.ones(class_num, ).cuda()

    ## set base network
    if args.net == 'resnet34':
        netG = utils.ResBase34().cuda()
    elif args.net == 'vgg16':
        netG = utils.VGG16Base().cuda()

    netF = utils.ResClassifier(class_num=class_num, feature_dim=netG.in_features,
        bottleneck_dim=args.bottleneck_dim).cuda()

    max_len = max(len(dset_loaders["source"]), len(dset_loaders["target"]))
    args.max_iter = args.max_epoch * max_len

    ad_flag = False
    if args.method == 'DANN':
        ad_net = utils.AdversarialNetwork(args.bottleneck_dim, 1024, max_iter=args.max_iter).cuda()
        ad_flag = True
    if args.method == 'CDANE':
        ad_net = utils.AdversarialNetwork(args.bottleneck_dim*class_num, 1024, max_iter=args.max_iter).cuda()
        random_layer = None
        ad_flag = True

    optimizer_g = optim.SGD(netG.parameters(), lr=args.lr * 0.1)
    optimizer_f = optim.SGD(netF.parameters(), lr=args.lr)
    if ad_flag:
        optimizer_d = optim.SGD(ad_net.parameters(), lr=args.lr)

    base_network = nn.Sequential(netG, netF)

    if args.pl.startswith('atdoc_na'):
        # Memory covers unlabeled target first, then labeled target
        # (labeled slots are addressed at lidx + len(target dataset)).
        mem_fea = torch.rand(len(dset_loaders["target"].dataset) + len(dset_loaders["ltarget"].dataset),
            args.bottleneck_dim).cuda()
        mem_fea = mem_fea / torch.norm(mem_fea, p=2, dim=1, keepdim=True)
        mem_cls = torch.ones(len(dset_loaders["target"].dataset) + len(dset_loaders["ltarget"].dataset),
            class_num).cuda() / class_num

    if args.pl == 'atdoc_nc':
        mem_fea = torch.rand(args.class_num, args.bottleneck_dim).cuda()
        mem_fea = mem_fea / torch.norm(mem_fea, p=2, dim=1, keepdim=True)

    source_loader_iter = iter(dset_loaders["source"])
    target_loader_iter = iter(dset_loaders["target"])
    ltarget_loader_iter = iter(dset_loaders["ltarget"])

    # Guard against int(...) == 0 (modulo-by-zero) when max_epoch < 10.
    eval_interval = max(1, int(args.eval_epoch * max_len))

    list_acc = []
    best_val_acc = 0

    for iter_num in range(1, args.max_iter + 1):
        base_network.train()
        lr_scheduler(optimizer_g, init_lr=args.lr * 0.1, iter_num=iter_num, max_iter=args.max_iter)
        lr_scheduler(optimizer_f, init_lr=args.lr, iter_num=iter_num, max_iter=args.max_iter)
        if ad_flag:
            lr_scheduler(optimizer_d, init_lr=args.lr, iter_num=iter_num, max_iter=args.max_iter)

        # FIX: iterator.next() does not exist on Python-3 iterators / recent
        # PyTorch DataLoader iterators; use the next() builtin and catch only
        # StopIteration so real errors still surface.
        try:
            inputs_source, labels_source = next(source_loader_iter)
        except StopIteration:
            source_loader_iter = iter(dset_loaders["source"])
            inputs_source, labels_source = next(source_loader_iter)
        try:
            inputs_target, _, idx = next(target_loader_iter)
        except StopIteration:
            target_loader_iter = iter(dset_loaders["target"])
            inputs_target, _, idx = next(target_loader_iter)
        try:
            inputs_ltarget, labels_ltarget, lidx = next(ltarget_loader_iter)
        except StopIteration:
            ltarget_loader_iter = iter(dset_loaders["ltarget"])
            inputs_ltarget, labels_ltarget, lidx = next(ltarget_loader_iter)

        inputs_ltarget, labels_ltarget = inputs_ltarget.cuda(), labels_ltarget.cuda()
        inputs_source, inputs_target, labels_source = inputs_source.cuda(), inputs_target.cuda(), labels_source.cuda()

        if args.method == 'srconly' and args.pl == 'none':
            # No target regularizer: skip the unlabeled forward pass entirely.
            features_source, outputs_source = base_network(inputs_source)
            features_ltarget, outputs_ltarget = base_network(inputs_ltarget)
        else:
            features_ltarget, outputs_ltarget = base_network(inputs_ltarget)
            features_source, outputs_source = base_network(inputs_source)
            features_target, outputs_target = base_network(inputs_target)

            # Stack as [labeled target; unlabeled target], then [source; target].
            features_target = torch.cat((features_ltarget, features_target), dim=0)
            outputs_target = torch.cat((outputs_ltarget, outputs_target), dim=0)

            features = torch.cat((features_source, features_target), dim=0)
            outputs = torch.cat((outputs_source, outputs_target), dim=0)
            softmax_out = nn.Softmax(dim=1)(outputs)

        eff = utils.calc_coeff(iter_num, max_iter=args.max_iter)

        # Methods ending in 'E' (CDANE) weight the discriminator by entropy.
        if args.method[-1] == 'E':
            entropy = loss.Entropy(softmax_out)
        else:
            entropy = None

        if args.method == 'CDANE':
            transfer_loss = loss.CDAN([features, softmax_out], ad_net, entropy, eff, random_layer)
        elif args.method == 'DANN':
            transfer_loss = loss.DANN(features, ad_net, entropy, eff)
        elif args.method == 'srconly':
            transfer_loss = torch.tensor(0.0).cuda()
        else:
            raise ValueError('Method cannot be recognized.')

        # Label-smoothed CE on source, weighted per class.
        src_ = loss.CrossEntropyLabelSmooth(reduction='none', num_classes=class_num,
            epsilon=args.smooth)(outputs_source, labels_source)
        weight_src = class_weight_src[labels_source].unsqueeze(0)
        classifier_loss = torch.sum(weight_src * src_) / (torch.sum(weight_src).item())
        total_loss = transfer_loss + classifier_loss

        # Same supervised loss on the labeled-target batch.
        ltar_ = loss.CrossEntropyLabelSmooth(reduction='none', num_classes=class_num,
            epsilon=args.smooth)(outputs_ltarget, labels_ltarget)
        weight_src = class_weight_src[labels_ltarget].unsqueeze(0)
        ltar_classifier_loss = torch.sum(weight_src * ltar_) / (torch.sum(weight_src).item())
        total_loss += ltar_classifier_loss

        # Linear ramp for the target regularizer weight.
        eff = iter_num / args.max_iter

        if not args.pl == 'none':
            # NOTE(review): this keeps only the last batch_size//3 rows of the
            # [ltarget; target] stack, while `idx` indexes the full unlabeled
            # batch (2*batch_size//3 samples) — the row/idx alignment in the
            # atdoc_na branch below looks suspect; confirm against upstream.
            outputs_target = outputs_target[-args.batch_size//3:, :]
            features_target = features_target[-args.batch_size//3:, :]

        if args.pl == 'none':
            pass

        elif args.pl == 'square':
            softmax_out = nn.Softmax(dim=1)(outputs_target)
            square_loss = -torch.sqrt((softmax_out**2).sum(dim=1)).mean()
            total_loss += args.tar_par * eff * square_loss

        elif args.pl == 'bsp':
            sigma_loss = bsp_loss(features)
            total_loss += args.tar_par * sigma_loss

        elif args.pl == 'ent':
            softmax_out = nn.Softmax(dim=1)(outputs_target)
            ent_loss = torch.mean(loss.Entropy(softmax_out))
            ent_loss /= torch.log(torch.tensor(class_num + 0.0))  # normalize to [0, 1]
            total_loss += args.tar_par * eff * ent_loss

        elif args.pl == 'bnm':
            softmax_out = nn.Softmax(dim=1)(outputs_target)
            bnm_loss = -torch.norm(softmax_out, 'nuc')
            cof = torch.tensor(np.sqrt(np.min(softmax_out.size())) / softmax_out.size(0))
            bnm_loss *= cof
            total_loss += args.tar_par * eff * bnm_loss

        elif args.pl == 'mcc':
            softmax_out = nn.Softmax(dim=1)(outputs_target)
            ent_weight = 1 + torch.exp(-loss.Entropy(softmax_out)).detach()
            ent_weight /= ent_weight.sum()
            cov_tar = softmax_out.t().mm(torch.diag(softmax_out.size(0)*ent_weight)).mm(softmax_out)
            mcc_loss = (torch.diag(cov_tar) / cov_tar.sum(dim=1)).mean()
            total_loss -= args.tar_par * eff * mcc_loss

        elif args.pl == 'npl':
            softmax_out = nn.Softmax(dim=1)(outputs_target)
            softmax_out = softmax_out**2 / ((softmax_out**2).sum(dim=0))

            weight_, pred = torch.max(softmax_out, 1)
            loss_ = nn.CrossEntropyLoss(reduction='none')(outputs_target, pred)
            classifier_loss = torch.sum(weight_ * loss_) / (torch.sum(weight_).item())
            total_loss += args.tar_par * eff * classifier_loss

        elif args.pl == 'atdoc_nc':
            mem_fea_norm = mem_fea / torch.norm(mem_fea, p=2, dim=1, keepdim=True)
            dis = torch.mm(features_target.detach(), mem_fea_norm.t())
            _, pred = torch.max(dis, dim=1)
            classifier_loss = nn.CrossEntropyLoss()(outputs_target, pred)
            total_loss += args.tar_par * eff * classifier_loss

        elif args.pl.startswith('atdoc_na'):
            # K-NN vote over the memory bank; exclude own memory slot.
            dis = -torch.mm(features_target.detach(), mem_fea.t())
            for di in range(dis.size(0)):
                dis[di, idx[di]] = torch.max(dis)
            _, p1 = torch.sort(dis, dim=1)

            w = torch.zeros(features_target.size(0), mem_fea.size(0)).cuda()
            for wi in range(w.size(0)):
                for wj in range(args.K):
                    w[wi][p1[wi, wj]] = 1 / args.K

            weight_, pred = torch.max(w.mm(mem_cls), 1)

            if args.pl.startswith('atdoc_na_now'):
                # 'now' variant: no confidence weighting.
                classifier_loss = nn.CrossEntropyLoss()(outputs_target, pred)
            else:
                loss_ = nn.CrossEntropyLoss(reduction='none')(outputs_target, pred)
                classifier_loss = torch.sum(weight_ * loss_) / (torch.sum(weight_).item())
            total_loss += args.tar_par * eff * classifier_loss

        optimizer_g.zero_grad()
        optimizer_f.zero_grad()
        if ad_flag:
            optimizer_d.zero_grad()
        total_loss.backward()
        optimizer_g.step()
        optimizer_f.step()
        if ad_flag:
            optimizer_d.step()

        if args.pl.startswith('atdoc_na'):
            # Momentum update of both memory segments (unlabeled, then labeled).
            base_network.eval()
            with torch.no_grad():
                features_target, outputs_target = base_network(inputs_target)
                features_target = features_target / torch.norm(features_target, p=2, dim=1, keepdim=True)
                softmax_out = nn.Softmax(dim=1)(outputs_target)
                if args.pl.startswith('atdoc_na_nos'):
                    # 'nos' variant: no prediction sharpening.
                    outputs_target = softmax_out
                else:
                    outputs_target = softmax_out**2 / ((softmax_out**2).sum(dim=0))

            mem_fea[idx] = (1.0 - args.momentum) * mem_fea[idx] + args.momentum * features_target.clone()
            mem_cls[idx] = (1.0 - args.momentum) * mem_cls[idx] + args.momentum * outputs_target.clone()

            with torch.no_grad():
                features_ltarget, outputs_ltarget = base_network(inputs_ltarget)
                features_ltarget = features_ltarget / torch.norm(features_ltarget, p=2, dim=1, keepdim=True)
                softmax_out = nn.Softmax(dim=1)(outputs_ltarget)
                if args.pl.startswith('atdoc_na_nos'):
                    outputs_ltarget = softmax_out
                else:
                    outputs_ltarget = softmax_out**2 / ((softmax_out**2).sum(dim=0))

            mem_fea[lidx + len(dset_loaders["target"].dataset)] = (1.0 - args.momentum) * \
                mem_fea[lidx + len(dset_loaders["target"].dataset)] + args.momentum * features_ltarget.clone()
            mem_cls[lidx + len(dset_loaders["target"].dataset)] = (1.0 - args.momentum) * \
                mem_cls[lidx + len(dset_loaders["target"].dataset)] + args.momentum * outputs_ltarget.clone()

        if args.pl == 'atdoc_nc':
            # Centroid update from both unlabeled and labeled target batches.
            base_network.eval()
            with torch.no_grad():
                feat_u, outputs_target = base_network(inputs_target)
                softmax_t = nn.Softmax(dim=1)(outputs_target)
                _, pred_t = torch.max(softmax_t, 1)
                onehot_tu = torch.eye(args.class_num)[pred_t].cuda()

                feat_l, outputs_target = base_network(inputs_ltarget)
                softmax_t = nn.Softmax(dim=1)(outputs_target)
                _, pred_t = torch.max(softmax_t, 1)
                onehot_tl = torch.eye(args.class_num)[pred_t].cuda()

                center_t = ((torch.mm(feat_u.t(), onehot_tu) + torch.mm(feat_l.t(), onehot_tl))) / \
                    (onehot_tu.sum(dim=0) + onehot_tl.sum(dim=0) + 1e-8)
                mem_fea = (1.0 - args.momentum) * mem_fea + args.momentum * center_t.t().clone()

        if iter_num % eval_interval == 0:
            base_network.eval()
            acc, py, score, y = utils.cal_acc(dset_loaders["test"], base_network)
            val_acc, _, _, _ = utils.cal_acc(dset_loaders["val"], base_network)

            list_acc.append(acc * 100)
            # Model selection: best accuracy on the labeled validation split.
            if best_val_acc <= val_acc:
                best_val_acc = val_acc
                best_acc = acc
                best_y = y
                best_py = py
                best_score = score

            log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%; Val Acc = {:.2f}%'.format(
                args.name, iter_num, args.max_iter, acc * 100, val_acc * 100)
            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')

    val_acc = best_acc * 100
    idx = np.argmax(np.array(list_acc))
    max_acc = list_acc[idx]
    final_acc = list_acc[-1]

    log_str = '\n==========================================\n'
    log_str += '\nVal Acc = {:.2f}\nMax Acc = {:.2f}\nFin Acc = {:.2f}\n'.format(val_acc, max_acc, final_acc)
    args.out_file.write(log_str + '\n')
    args.out_file.flush()

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Semi-supervised Domain Adaptation')
    parser.add_argument('--method', type=str, default='srconly', choices=['srconly', 'DANN', 'CDANE'])
    parser.add_argument('--pl', type=str, default='none', choices=['none', 'npl', 'ent', 'bsp', 'bnm', 'mcc',
        'atdoc_na', 'atdoc_nc', 'atdoc_na_nos', 'atdoc_na_now'])
    # atdoc_na_nos: atdoc_na without predictions sharpening
    parser.add_argument('--gpu_id', type=str, nargs='?', default='0', help="device id to run")
    # (the remaining argument definitions continue beyond this chunk)
359 | parser.add_argument('--s', type=int, default=0, help="source") 360 | parser.add_argument('--t', type=int, default=1, help="target") 361 | parser.add_argument('--output', type=str, default='san') 362 | parser.add_argument('--seed', type=int, default=0, help="random seed") 363 | parser.add_argument('--batch_size', type=int, default=36, help="batch_size") 364 | parser.add_argument('--worker', type=int, default=4, help="number of workers") 365 | parser.add_argument('--bottleneck_dim', type=int, default=256) 366 | 367 | parser.add_argument('--max_epoch', type=int, default=30) 368 | parser.add_argument('--momentum', type=float, default=1.0) 369 | parser.add_argument('--K', type=int, default=5) 370 | parser.add_argument('--smooth', type=float, default=0.1) 371 | parser.add_argument('--tar_par', type=float, default=1.0) 372 | 373 | parser.add_argument('--net', type=str, default='resnet34', choices=["resnet50", "resnet101", "resnet34", "vgg16"]) 374 | parser.add_argument('--dset', type=str, default='multi', choices=['multi', 'office-home', 'office'], help="The dataset or source dataset used") 375 | parser.add_argument('--lr', type=float, default=0.01, help="learning rate") 376 | parser.add_argument('--shot', type=int, default=3, choices=[1, 3]) 377 | 378 | args = parser.parse_args() 379 | args.output = args.output.strip() 380 | 381 | if args.pl.startswith('atdoc_na'): 382 | args.pl += str(args.K) 383 | if args.pl == 'atdoc_nc': 384 | args.momentum = 0.1 385 | 386 | args.eval_epoch = args.max_epoch / 10 387 | 388 | if args.dset == 'office-home': 389 | names = ['Art', 'Clipart', 'Product', 'Real'] 390 | args.class_num = 65 391 | if args.dset == 'multi': 392 | names = ['clipart', 'painting', 'real', 'sketch'] 393 | args.class_num = 126 394 | if args.dset == 'office': 395 | names = ['amazon', 'dslr', 'webcam'] 396 | args.class_num = 31 397 | 398 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id 399 | SEED = args.seed 400 | torch.manual_seed(SEED) 401 | 
torch.cuda.manual_seed(SEED) 402 | np.random.seed(SEED) 403 | random.seed(SEED) 404 | # torch.backends.cudnn.deterministic = True 405 | 406 | args.s_dset_path = './data/ssda/' + args.dset + '/labeled_source_images_' \ 407 | + names[args.s] + '.txt' 408 | args.lt_dset_path = './data/ssda/' + args.dset + '/labeled_target_images_' \ 409 | + names[args.t] + '_' + str(args.shot) + '.txt' 410 | args.t_dset_path = './data/ssda/' + args.dset + '/unlabeled_target_images_' \ 411 | + names[args.t] + '_' + str(args.shot) + '.txt' 412 | args.vt_dset_path = './data/ssda/' + args.dset + '/validation_target_images_' \ 413 | + names[args.t] + '_3.txt' 414 | 415 | args.test_dset_path = args.t_dset_path 416 | if args.pl == 'none': 417 | args.output_dir = osp.join(args.output, args.pl, args.dset, 418 | names[args.s][0].upper() + names[args.t][0].upper()) 419 | else: 420 | args.output_dir = osp.join(args.output, args.pl + '_' + str(args.tar_par), args.dset, 421 | names[args.s][0].upper() + names[args.t][0].upper()) 422 | 423 | args.name = names[args.s][0].upper() + names[args.t][0].upper() 424 | if not osp.exists(args.output_dir): 425 | os.system('mkdir -p ' + args.output_dir) 426 | if not osp.exists(args.output_dir): 427 | os.mkdir(args.output_dir) 428 | 429 | args.log = args.method + '_' + str(args.shot) 430 | args.out_file = open(osp.join(args.output_dir, "{:}.txt".format(args.log)), "w") 431 | 432 | utils.print_args(args) 433 | train(args) 434 | -------------------------------------------------------------------------------- /demo_ssda_mixmatch.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import os.path as osp 4 | import numpy as np 5 | import torch 6 | import torchvision 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | import loss 10 | import random, pdb, math 11 | import sys, copy 12 | from tqdm import tqdm 13 | import utils, pickle 14 | import scipy.io as sio 15 | 16 | def data_load(args): 
17 | train_transform = torchvision.transforms.Compose([ 18 | torchvision.transforms.Resize((256, 256)), 19 | torchvision.transforms.RandomCrop((224, 224)), 20 | torchvision.transforms.RandomHorizontalFlip(), 21 | torchvision.transforms.ToTensor(), 22 | torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 23 | ]) 24 | test_transform = torchvision.transforms.Compose([ 25 | torchvision.transforms.Resize((256, 256)), 26 | torchvision.transforms.CenterCrop((224, 224)), 27 | torchvision.transforms.ToTensor(), 28 | torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 29 | ]) 30 | 31 | source_set = utils.ObjectImage('', args.s_dset_path, train_transform) 32 | ltarget_set = utils.ObjectImage_mul('', args.lt_dset_path, [train_transform, train_transform]) 33 | target_set = utils.ObjectImage_mul('', args.t_dset_path, [train_transform, train_transform]) 34 | val_set = utils.ObjectImage('', args.vt_dset_path, test_transform) 35 | test_set = utils.ObjectImage('', args.test_dset_path, test_transform) 36 | 37 | dset_loaders = {} 38 | dset_loaders["source"] = torch.utils.data.DataLoader(source_set, batch_size=args.batch_size, 39 | shuffle=True, num_workers=args.worker, drop_last=True) 40 | dset_loaders["ltarget"] = torch.utils.data.DataLoader(ltarget_set, batch_size=args.batch_size//3, 41 | shuffle=True, num_workers=args.worker, drop_last=True) 42 | dset_loaders["target"] = torch.utils.data.DataLoader(target_set, batch_size=2*args.batch_size//3, 43 | shuffle=True, num_workers=args.worker, drop_last=True) 44 | dset_loaders["val"] = torch.utils.data.DataLoader(val_set, batch_size=args.batch_size*3, 45 | shuffle=False, num_workers=args.worker, drop_last=False) 46 | dset_loaders["test"] = torch.utils.data.DataLoader(test_set, batch_size=args.batch_size*3, 47 | shuffle=False, num_workers=args.worker, drop_last=False) 48 | return dset_loaders 49 | 50 | def lr_scheduler(optimizer, init_lr, iter_num, max_iter, gamma=10, power=0.75): 51 | 
decay = (1 + gamma * iter_num / max_iter) ** (-power) 52 | for param_group in optimizer.param_groups: 53 | param_group['lr'] = init_lr * decay 54 | param_group['weight_decay'] = 1e-3 55 | param_group['momentum'] = 0.9 56 | param_group['nesterov'] = True 57 | return optimizer 58 | 59 | def train(args): 60 | ## set pre-process 61 | dset_loaders = data_load(args) 62 | 63 | max_len = max(len(dset_loaders["source"]), len(dset_loaders["target"])) 64 | args.max_iter = args.max_epoch*max_len 65 | 66 | ## set base network 67 | if args.net == 'resnet34': 68 | netG = utils.ResBase34().cuda() 69 | elif args.net == 'vgg16': 70 | netG = utils.VGG16Base().cuda() 71 | 72 | netF = utils.ResClassifier(class_num=args.class_num, feature_dim=netG.in_features, 73 | bottleneck_dim=args.bottleneck_dim).cuda() 74 | 75 | if len(args.gpu_id.split(',')) > 1: 76 | netG = nn.DataParallel(netG) 77 | 78 | optimizer_g = optim.SGD(netG.parameters(), lr = args.lr * 0.1) 79 | optimizer_f = optim.SGD(netF.parameters(), lr = args.lr) 80 | 81 | base_network = nn.Sequential(netG, netF) 82 | source_loader_iter = iter(dset_loaders["source"]) 83 | target_loader_iter = iter(dset_loaders["target"]) 84 | ltarget_loader_iter = iter(dset_loaders["ltarget"]) 85 | 86 | if args.pl.startswith('atdoc_na'): 87 | mem_fea = torch.rand(len(dset_loaders["target"].dataset) + len(dset_loaders["ltarget"].dataset), args.bottleneck_dim).cuda() 88 | mem_fea = mem_fea / torch.norm(mem_fea, p=2, dim=1, keepdim=True) 89 | mem_cls = torch.ones(len(dset_loaders["target"].dataset) + len(dset_loaders["ltarget"].dataset), args.class_num).cuda() / args.class_num 90 | 91 | if args.pl == 'atdoc_nc': 92 | mem_fea = torch.rand(args.class_num, args.bottleneck_dim).cuda() 93 | mem_fea = mem_fea / torch.norm(mem_fea, p=2, dim=1, keepdim=True) 94 | 95 | list_acc = [] 96 | best_val_acc = 0 97 | 98 | for iter_num in range(1, args.max_iter + 1): 99 | base_network.train() 100 | lr_scheduler(optimizer_g, init_lr=args.lr * 0.1, iter_num=iter_num, 
max_iter=args.max_iter) 101 | lr_scheduler(optimizer_f, init_lr=args.lr, iter_num=iter_num, max_iter=args.max_iter) 102 | 103 | try: 104 | inputs_source, labels_source = source_loader_iter.next() 105 | except: 106 | source_loader_iter = iter(dset_loaders["source"]) 107 | inputs_source, labels_source = source_loader_iter.next() 108 | try: 109 | inputs_target, _, target_idx = target_loader_iter.next() 110 | except: 111 | target_loader_iter = iter(dset_loaders["target"]) 112 | inputs_target, _, target_idx = target_loader_iter.next() 113 | 114 | try: 115 | inputs_ltarget, labels_ltarget, lidx = ltarget_loader_iter.next() 116 | except: 117 | ltarget_loader_iter = iter(dset_loaders["ltarget"]) 118 | inputs_ltarget, labels_ltarget, lidx = ltarget_loader_iter.next() 119 | 120 | inputs_lt = inputs_ltarget[0].cuda() 121 | inputs_lt2 = inputs_ltarget[1].cuda() 122 | targets_lt = torch.zeros(args.batch_size//3, args.class_num).scatter_(1, labels_ltarget.view(-1,1), 1) 123 | targets_lt = targets_lt.cuda() 124 | 125 | targets_s = torch.zeros(args.batch_size, args.class_num).scatter_(1, labels_source.view(-1,1), 1) 126 | inputs_s = inputs_source.cuda() 127 | targets_s = targets_s.cuda() 128 | inputs_t = inputs_target[0].cuda() 129 | inputs_t2 = inputs_target[1].cuda() 130 | 131 | if args.pl.startswith('atdoc_na'): 132 | 133 | targets_u = 0 134 | for inp in [inputs_t, inputs_t2]: 135 | with torch.no_grad(): 136 | features_target, outputs_u = base_network(inp) 137 | 138 | dis = -torch.mm(features_target.detach(), mem_fea.t()) 139 | for di in range(dis.size(0)): 140 | dis[di, target_idx[di]] = torch.max(dis) 141 | # dis[di, target_idx[di]+len(dset_loaders["target"].dataset)] = torch.max(dis) 142 | 143 | _, p1 = torch.sort(dis, dim=1) 144 | w = torch.zeros(features_target.size(0), mem_fea.size(0)).cuda() 145 | for wi in range(w.size(0)): 146 | for wj in range(args.K): 147 | w[wi][p1[wi, wj]] = 1/ args.K 148 | 149 | _, pred = torch.max(w.mm(mem_cls), 1) 150 | 151 | targets_u += 
0.5*torch.eye(outputs_u.size(1))[pred].cuda() 152 | 153 | elif args.pl == 'atdoc_nc': 154 | 155 | targets_u = 0 156 | mem_fea_norm = mem_fea / torch.norm(mem_fea, p=2, dim=1, keepdim=True) 157 | for inp in [inputs_t, inputs_t2]: 158 | with torch.no_grad(): 159 | features_target, outputs_u = base_network(inp) 160 | dis = torch.mm(features_target.detach(), mem_fea_norm.t()) 161 | _, pred = torch.max(dis, dim=1) 162 | targets_u += 0.5*torch.eye(outputs_u.size(1))[pred].cuda() 163 | 164 | elif args.pl == 'npl': 165 | 166 | targets_u = 0 167 | for inp in [inputs_t, inputs_t2]: 168 | with torch.no_grad(): 169 | _, outputs_u = base_network(inp) 170 | _, pred = torch.max(outputs_u.detach(), 1) 171 | targets_u += 0.5*torch.eye(outputs_u.size(1))[pred].cuda() 172 | 173 | else: 174 | with torch.no_grad(): 175 | # compute guessed labels of unlabel samples 176 | _, outputs_u = base_network(inputs_t) 177 | _, outputs_u2 = base_network(inputs_t2) 178 | p = (torch.softmax(outputs_u, dim=1) + torch.softmax(outputs_u2, dim=1)) / 2 179 | pt = p**(1/args.T) 180 | targets_u = pt / pt.sum(dim=1, keepdim=True) 181 | targets_u = targets_u.detach() 182 | 183 | #################################################################### 184 | all_inputs = torch.cat([inputs_s, inputs_lt, inputs_t, inputs_lt2, inputs_t2], dim=0) 185 | all_targets = torch.cat([targets_s, targets_lt, targets_u, targets_lt, targets_u], dim=0) 186 | if args.alpha > 0: 187 | l = np.random.beta(args.alpha, args.alpha) 188 | l = max(l, 1-l) 189 | else: 190 | l = 1 191 | idx = torch.randperm(all_inputs.size(0)) 192 | 193 | input_a, input_b = all_inputs, all_inputs[idx] 194 | target_a, target_b = all_targets, all_targets[idx] 195 | mixed_input = l * input_a + (1 - l) * input_b 196 | mixed_target = l * target_a + (1 - l) * target_b 197 | 198 | # interleave labeled and unlabed samples between batches to get correct batchnorm calculation 199 | mixed_input = list(torch.split(mixed_input, args.batch_size)) 200 | mixed_input = 
utils.interleave(mixed_input, args.batch_size) 201 | # s = [sa, sb, sc] 202 | # t1 = [t1a, t1b, t1c] 203 | # t2 = [t2a, t2b, t2c] 204 | # => s' = [sa, t1b, t2c] t1' = [t1a, sb, t1c] t2' = [t2a, t2b, sc] 205 | 206 | # _, logits = base_network(mixed_input[0]) 207 | features, logits = base_network(mixed_input[0]) 208 | logits = [logits] 209 | for input in mixed_input[1:]: 210 | _, temp = base_network(input) 211 | logits.append(temp) 212 | 213 | # put interleaved samples back 214 | # [i[:,0] for i in aa] 215 | logits = utils.interleave(logits, args.batch_size) 216 | logits_x = logits[0] 217 | logits_u = torch.cat(logits[1:], dim=0) 218 | 219 | train_criterion = utils.SemiLoss() 220 | 221 | Lx, Lu, w = train_criterion(logits_x, mixed_target[:args.batch_size], logits_u, mixed_target[args.batch_size:], 222 | iter_num, args.max_iter, args.lambda_u) 223 | loss = Lx + w * Lu 224 | 225 | optimizer_g.zero_grad() 226 | optimizer_f.zero_grad() 227 | loss.backward() 228 | optimizer_g.step() 229 | optimizer_f.step() 230 | 231 | if args.pl.startswith('atdoc_na'): 232 | base_network.eval() 233 | with torch.no_grad(): 234 | fea1, outputs1 = base_network(inputs_t) 235 | fea2, outputs2 = base_network(inputs_t2) 236 | feat = 0.5 * (fea1 + fea2) 237 | feat = feat / torch.norm(feat, p=2, dim=1, keepdim=True) 238 | softmax_out = 0.5*(nn.Softmax(dim=1)(outputs1) + nn.Softmax(dim=1)(outputs2)) 239 | softmax_out = softmax_out**2 / ((softmax_out**2).sum(dim=0)) 240 | 241 | mem_fea[target_idx] = (1.0 - args.momentum)*mem_fea[target_idx] + args.momentum*feat 242 | mem_cls[target_idx] = (1.0 - args.momentum)*mem_cls[target_idx] + args.momentum*softmax_out 243 | 244 | with torch.no_grad(): 245 | fea1, outputs1 = base_network(inputs_lt) 246 | fea2, outputs2 = base_network(inputs_lt2) 247 | feat = 0.5 * (fea1 + fea2) 248 | feat = feat / torch.norm(feat, p=2, dim=1, keepdim=True) 249 | softmax_out = 0.5*(nn.Softmax(dim=1)(outputs1) + nn.Softmax(dim=1)(outputs2)) 250 | softmax_out = softmax_out**2 / 
((softmax_out**2).sum(dim=0)) 251 | 252 | mem_fea[lidx + len(dset_loaders["target"].dataset)] = (1.0 - args.momentum) * \ 253 | mem_fea[lidx + len(dset_loaders["target"].dataset)] + args.momentum*feat 254 | mem_cls[lidx + len(dset_loaders["target"].dataset)] = (1.0 - args.momentum) * \ 255 | mem_cls[lidx + len(dset_loaders["target"].dataset)] + args.momentum*softmax_out 256 | 257 | if args.pl == 'atdoc_nc': 258 | base_network.eval() 259 | with torch.no_grad(): 260 | fea1, outputs1 = base_network(inputs_t) 261 | fea2, outputs2 = base_network(inputs_t2) 262 | feat_u = 0.5 * (fea1 + fea2) 263 | softmax_t = 0.5*(nn.Softmax(dim=1)(outputs1) + nn.Softmax(dim=1)(outputs2)) 264 | _, pred_t = torch.max(softmax_t, 1) 265 | onehot_tu = torch.eye(args.class_num)[pred_t].cuda() 266 | 267 | with torch.no_grad(): 268 | fea1, outputs1 = base_network(inputs_lt) 269 | fea2, outputs2 = base_network(inputs_lt2) 270 | feat_l = 0.5 * (fea1 + fea2) 271 | softmax_t = 0.5*(nn.Softmax(dim=1)(outputs1) + nn.Softmax(dim=1)(outputs2)) 272 | _, pred_t = torch.max(softmax_t, 1) 273 | onehot_tl = torch.eye(args.class_num)[pred_t].cuda() 274 | # onehot_tl = torch.eye(args.class_num)[labels_ltarget].cuda() 275 | 276 | center_t = ((torch.mm(feat_u.t(), onehot_tu) + torch.mm(feat_l.t(), onehot_tl))) / (onehot_tu.sum(dim=0) + onehot_tl.sum(dim=0) + 1e-8) 277 | mem_fea = (1.0 - args.momentum) * mem_fea + args.momentum * center_t.t().clone() 278 | 279 | if iter_num % int(args.eval_epoch * max_len) == 0: 280 | base_network.eval() 281 | if args.dset == 'VISDA-C': 282 | acc, py, score, y, tacc = utils.cal_acc_visda(dset_loaders["test"], base_network) 283 | args.out_file.write(tacc + '\n') 284 | args.out_file.flush() 285 | else: 286 | acc, py, score, y = utils.cal_acc(dset_loaders["test"], base_network) 287 | val_acc, _, _, _ = utils.cal_acc(dset_loaders["val"], base_network) 288 | 289 | list_acc.append(acc*100) 290 | if best_val_acc <= val_acc: 291 | best_val_acc = val_acc 292 | best_acc = acc 293 | best_y 
= y 294 | best_py = py 295 | best_score = score 296 | 297 | log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%; Val Acc = {:.2f}%'.format(args.name, iter_num, args.max_iter, acc*100, val_acc*100) 298 | args.out_file.write(log_str + '\n') 299 | args.out_file.flush() 300 | print(log_str+'\n') 301 | 302 | val_acc = best_acc * 100 303 | idx = np.argmax(np.array(list_acc)) 304 | max_acc = list_acc[idx] 305 | final_acc = list_acc[-1] 306 | 307 | log_str = '\n==========================================\n' 308 | log_str += '\nVal Acc = {:.2f}\nMax Acc = {:.2f}\nFin Acc = {:.2f}\n'.format(val_acc, max_acc, final_acc) 309 | args.out_file.write(log_str + '\n') 310 | args.out_file.flush() 311 | 312 | # torch.save(base_network.state_dict(), osp.join(args.output_dir, args.log + ".pt")) 313 | # sio.savemat(osp.join(args.output_dir, args.log + ".mat"), {'y':best_y.cpu().numpy(), 314 | # 'py':best_py.cpu().numpy(), 'score':best_score.cpu().numpy()}) 315 | 316 | return base_network, py 317 | 318 | if __name__ == "__main__": 319 | parser = argparse.ArgumentParser(description='MixMatch for Domain Adaptation') 320 | parser.add_argument('--gpu_id', type=str, nargs='?', default='0', help="device id to run") 321 | parser.add_argument('--s', type=int, default=0, help="source") 322 | parser.add_argument('--t', type=int, default=1, help="target") 323 | parser.add_argument('--output', type=str, default='san') 324 | parser.add_argument('--seed', type=int, default=0, help="random seed") 325 | parser.add_argument('--max_epoch', type=int, default=50) 326 | parser.add_argument('--batch_size', type=int, default=36, help="batch_size") 327 | parser.add_argument('--worker', type=int, default=4, help="number of workers") 328 | parser.add_argument('--bottleneck_dim', type=int, default=256) 329 | 330 | parser.add_argument('--net', type=str, default='resnet34', choices=["resnet50", "resnet101", "resnet34", "vgg16"]) 331 | parser.add_argument('--dset', type=str, default='multi', choices=['multi', 
'office-home', 'office'], help="The dataset or source dataset used") 332 | parser.add_argument('--lr', type=float, default=0.01, help="learning rate") 333 | parser.add_argument('--pl', type=str, default='none', choices=['none', 'npl', 'atdoc_na', 'atdoc_nc']) 334 | parser.add_argument('--K', type=int, default=5) 335 | parser.add_argument('--momentum', type=float, default=1.0) 336 | 337 | parser.add_argument('--alpha', default=0.75, type=float) 338 | parser.add_argument('--lambda_u', default=100, type=float) 339 | parser.add_argument('--T', default=0.5, type=float) 340 | parser.add_argument('--ema_decay', default=0.999, type=float) 341 | parser.add_argument('--shot', type=int, default=3, choices=[1, 3]) 342 | 343 | args = parser.parse_args() 344 | args.output = args.output.strip() 345 | 346 | if args.pl == 'atdoc_na': 347 | args.pl += str(args.K) 348 | if args.pl == 'atdoc_nc': 349 | args.momentum = 0.1 350 | 351 | args.eval_epoch = args.max_epoch / 10 352 | 353 | if args.dset == 'office-home': 354 | names = ['Art', 'Clipart', 'Product', 'Real'] 355 | args.class_num = 65 356 | if args.dset == 'multi': 357 | names = ['clipart', 'painting', 'real', 'sketch'] 358 | args.class_num = 126 359 | if args.dset == 'office': 360 | names = ['amazon', 'dslr', 'webcam'] 361 | args.class_num = 31 362 | 363 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id 364 | SEED = args.seed 365 | torch.manual_seed(SEED) 366 | torch.cuda.manual_seed(SEED) 367 | np.random.seed(SEED) 368 | random.seed(SEED) 369 | # torch.backends.cudnn.deterministic = True 370 | 371 | args.s_dset_path = './data/ssda/' + args.dset + '/labeled_source_images_' \ 372 | + names[args.s] + '.txt' 373 | args.lt_dset_path = './data/ssda/' + args.dset + '/labeled_target_images_' \ 374 | + names[args.t] + '_' + str(args.shot) + '.txt' 375 | args.t_dset_path = './data/ssda/' + args.dset + '/unlabeled_target_images_' \ 376 | + names[args.t] + '_' + str(args.shot) + '.txt' 377 | args.vt_dset_path = './data/ssda/' + args.dset 
+ '/validation_target_images_' \ 378 | + names[args.t] + '_3.txt' 379 | 380 | args.test_dset_path = args.t_dset_path 381 | 382 | args.output_dir = osp.join(args.output, 'mixmatch', args.dset, 383 | names[args.s][0].upper() + names[args.t][0].upper()) 384 | 385 | args.name = names[args.s][0].upper() + names[args.t][0].upper() 386 | if not osp.exists(args.output_dir): 387 | os.system('mkdir -p ' + args.output_dir) 388 | if not osp.exists(args.output_dir): 389 | os.mkdir(args.output_dir) 390 | 391 | args.log = 'mixmatch_' + args.pl + '_' + str(args.shot) 392 | args.out_file = open(osp.join(args.output_dir, "{:}.txt".format(args.log)), "w") 393 | 394 | utils.print_args(args) 395 | train(args) -------------------------------------------------------------------------------- /demo_uda.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import os.path as osp 4 | import numpy as np 5 | import torch 6 | import torchvision 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | import loss 10 | import random, pdb, math 11 | import sys, copy 12 | from tqdm import tqdm 13 | import utils 14 | import scipy.io as sio 15 | 16 | def data_load(args): 17 | train_transform = torchvision.transforms.Compose([ 18 | torchvision.transforms.Resize((256, 256)), 19 | torchvision.transforms.RandomCrop((224, 224)), 20 | torchvision.transforms.RandomHorizontalFlip(), 21 | torchvision.transforms.ToTensor(), 22 | torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 23 | ]) 24 | test_transform = torchvision.transforms.Compose([ 25 | torchvision.transforms.Resize((256, 256)), 26 | torchvision.transforms.CenterCrop((224, 224)), 27 | torchvision.transforms.ToTensor(), 28 | torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 29 | ]) 30 | 31 | source_set = utils.ObjectImage('', args.s_dset_path, train_transform) 32 | target_set = utils.ObjectImage_mul('', args.t_dset_path, 
train_transform) 33 | test_set = utils.ObjectImage('', args.test_dset_path, test_transform) 34 | 35 | dset_loaders = {} 36 | dset_loaders["source"] = torch.utils.data.DataLoader(source_set, batch_size=args.batch_size, 37 | shuffle=True, num_workers=args.worker, drop_last=True) 38 | dset_loaders["target"] = torch.utils.data.DataLoader(target_set, batch_size=args.batch_size, 39 | shuffle=True, num_workers=args.worker, drop_last=True) 40 | dset_loaders["test"] = torch.utils.data.DataLoader(test_set, batch_size=args.batch_size*3, 41 | shuffle=False, num_workers=args.worker, drop_last=False) 42 | 43 | return dset_loaders 44 | 45 | def data_load_y(args, labels): 46 | train_transform = torchvision.transforms.Compose([ 47 | torchvision.transforms.Resize((256, 256)), 48 | torchvision.transforms.RandomCrop((224, 224)), 49 | torchvision.transforms.RandomHorizontalFlip(), 50 | torchvision.transforms.ToTensor(), 51 | torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 52 | ]) 53 | test_transform = torchvision.transforms.Compose([ 54 | torchvision.transforms.Resize((256, 256)), 55 | torchvision.transforms.CenterCrop((224, 224)), 56 | torchvision.transforms.ToTensor(), 57 | torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 58 | ]) 59 | 60 | # pdb.set_trace() 61 | source_set = utils.ObjectImage_y('', args.t_dset_path, train_transform, labels) 62 | target_set = utils.ObjectImage_mul('', args.s_dset_path, train_transform) 63 | test_set = utils.ObjectImage('', args.s_dset_path, test_transform) 64 | 65 | dset_loaders = {} 66 | dset_loaders["source"] = torch.utils.data.DataLoader(source_set, batch_size=args.batch_size, 67 | shuffle=True, num_workers=args.worker, drop_last=True) 68 | dset_loaders["target"] = torch.utils.data.DataLoader(target_set, batch_size=args.batch_size, 69 | shuffle=True, num_workers=args.worker, drop_last=True) 70 | dset_loaders["test"] = torch.utils.data.DataLoader(test_set, batch_size=args.batch_size*3, 71 
| shuffle=False, num_workers=args.worker, drop_last=False) 72 | 73 | return dset_loaders 74 | 75 | def lr_scheduler(optimizer, init_lr, iter_num, max_iter, gamma=10, power=0.75): 76 | decay = (1 + gamma * iter_num / max_iter) ** (-power) 77 | for param_group in optimizer.param_groups: 78 | param_group['lr'] = init_lr * decay 79 | param_group['weight_decay'] = 1e-3 80 | param_group['momentum'] = 0.9 81 | param_group['nesterov'] = True 82 | return optimizer 83 | 84 | def bsp_loss(feature): 85 | train_bs = feature.size(0) // 2 86 | feature_s = feature.narrow(0, 0, train_bs) 87 | feature_t = feature.narrow(0, train_bs, train_bs) 88 | _, s_s, _ = torch.svd(feature_s) 89 | _, s_t, _ = torch.svd(feature_t) 90 | sigma = torch.pow(s_s[0], 2) + torch.pow(s_t[0], 2) 91 | sigma *= 0.0001 92 | return sigma 93 | 94 | def train(args, validate=False, label=None): 95 | ## set pre-process 96 | if validate: 97 | dset_loaders = data_load_y(args, label) 98 | else: 99 | dset_loaders = data_load(args) 100 | class_num = args.class_num 101 | class_weight_src = torch.ones(class_num, ).cuda() 102 | ################################################################################################## 103 | 104 | ## set base network 105 | if args.net == 'resnet101': 106 | netG = utils.ResBase101().cuda() 107 | elif args.net == 'resnet50': 108 | netG = utils.ResBase50().cuda() 109 | 110 | netF = utils.ResClassifier(class_num=class_num, feature_dim=netG.in_features, 111 | bottleneck_dim=args.bottleneck_dim).cuda() 112 | 113 | max_len = max(len(dset_loaders["source"]), len(dset_loaders["target"])) 114 | args.max_iter = args.max_epoch * max_len 115 | 116 | ad_flag = False 117 | if args.method in {'DANN', 'DANNE'}: 118 | ad_net = utils.AdversarialNetwork(args.bottleneck_dim, 1024, max_iter=args.max_iter).cuda() 119 | ad_flag = True 120 | if args.method in {'CDAN', 'CDANE'}: 121 | ad_net = utils.AdversarialNetwork(args.bottleneck_dim*class_num, 1024, max_iter=args.max_iter).cuda() 122 | random_layer = 
None 123 | ad_flag = True 124 | 125 | optimizer_g = optim.SGD(netG.parameters(), lr = args.lr * 0.1) 126 | optimizer_f = optim.SGD(netF.parameters(), lr = args.lr) 127 | if ad_flag: 128 | optimizer_d = optim.SGD(ad_net.parameters(), lr = args.lr) 129 | 130 | base_network = nn.Sequential(netG, netF) 131 | 132 | if args.pl.startswith('atdoc_na'): 133 | mem_fea = torch.rand(len(dset_loaders["target"].dataset), args.bottleneck_dim).cuda() 134 | mem_fea = mem_fea / torch.norm(mem_fea, p=2, dim=1, keepdim=True) 135 | mem_cls = torch.ones(len(dset_loaders["target"].dataset), class_num).cuda() / class_num 136 | 137 | if args.pl == 'atdoc_nc': 138 | mem_fea = torch.rand(args.class_num, args.bottleneck_dim).cuda() 139 | mem_fea = mem_fea / torch.norm(mem_fea, p=2, dim=1, keepdim=True) 140 | 141 | source_loader_iter = iter(dset_loaders["source"]) 142 | target_loader_iter = iter(dset_loaders["target"]) 143 | 144 | #### 145 | list_acc = [] 146 | best_ent = 100 147 | 148 | for iter_num in range(1, args.max_iter + 1): 149 | base_network.train() 150 | lr_scheduler(optimizer_g, init_lr=args.lr * 0.1, iter_num=iter_num, max_iter=args.max_iter) 151 | lr_scheduler(optimizer_f, init_lr=args.lr, iter_num=iter_num, max_iter=args.max_iter) 152 | if ad_flag: 153 | lr_scheduler(optimizer_d, init_lr=args.lr, iter_num=iter_num, max_iter=args.max_iter) 154 | 155 | try: 156 | inputs_source, labels_source = source_loader_iter.next() 157 | except: 158 | source_loader_iter = iter(dset_loaders["source"]) 159 | inputs_source, labels_source = source_loader_iter.next() 160 | try: 161 | inputs_target, _, idx = target_loader_iter.next() 162 | except: 163 | target_loader_iter = iter(dset_loaders["target"]) 164 | inputs_target, _, idx = target_loader_iter.next() 165 | 166 | inputs_source, inputs_target, labels_source = inputs_source.cuda(), inputs_target.cuda(), labels_source.cuda() 167 | 168 | if args.method == 'srconly' and args.pl == 'none': 169 | features_source, outputs_source = 
base_network(inputs_source) 170 | else: 171 | features_source, outputs_source = base_network(inputs_source) 172 | features_target, outputs_target = base_network(inputs_target) 173 | features = torch.cat((features_source, features_target), dim=0) 174 | outputs = torch.cat((outputs_source, outputs_target), dim=0) 175 | softmax_out = nn.Softmax(dim=1)(outputs) 176 | 177 | eff = utils.calc_coeff(iter_num, max_iter=args.max_iter) 178 | if args.method[-1] == 'E': 179 | entropy = loss.Entropy(softmax_out) 180 | else: 181 | entropy = None 182 | 183 | if args.method in {'CDAN', 'CDANE'}: 184 | transfer_loss = loss.CDAN([features, softmax_out], ad_net, entropy, eff, random_layer) 185 | 186 | elif args.method in {'DANN', 'DANNE'}: 187 | transfer_loss = loss.DANN(features, ad_net, entropy, eff) 188 | 189 | elif args.method == 'DAN': 190 | transfer_loss = eff * loss.DAN(features_source, features_target) 191 | elif args.method == 'DAN_Linear': 192 | transfer_loss = eff * loss.DAN_Linear(features_source, features_target) 193 | 194 | elif args.method == 'JAN': 195 | transfer_loss = eff * loss.JAN([features_source, softmax_out[0:args.batch_size,:]], [features_target, softmax_out[args.batch_size::,:]]) 196 | elif args.method == 'JAN_Linear': 197 | transfer_loss = eff * loss.JAN_Linear([features_source, softmax_out[0:args.batch_size,:]], [features_target, softmax_out[args.batch_size::,:]]) 198 | 199 | elif args.method == 'CORAL': 200 | transfer_loss = eff * loss.CORAL(features_source, features_target) 201 | elif args.method == 'DDC': 202 | transfer_loss = loss.MMD_loss()(features_source, features_target) 203 | 204 | elif args.method == 'srconly': 205 | transfer_loss = torch.tensor(0.0).cuda() 206 | else: 207 | raise ValueError('Method cannot be recognized.') 208 | 209 | src_ = loss.CrossEntropyLabelSmooth(reduction='none',num_classes=class_num, epsilon=args.smooth)(outputs_source, labels_source) 210 | weight_src = class_weight_src[labels_source].unsqueeze(0) 211 | classifier_loss = 
torch.sum(weight_src * src_) / (torch.sum(weight_src).item()) 212 | total_loss = transfer_loss + classifier_loss 213 | 214 | eff = iter_num / args.max_iter 215 | 216 | if args.pl == 'none': 217 | pass 218 | 219 | elif args.pl == 'square': 220 | softmax_out = nn.Softmax(dim=1)(outputs_target) 221 | square_loss = - torch.sqrt((softmax_out**2).sum(dim=1)).mean() 222 | total_loss += args.tar_par * eff * square_loss 223 | 224 | elif args.pl == 'bsp': 225 | sigma_loss = bsp_loss(features) 226 | total_loss += args.tar_par * sigma_loss 227 | 228 | elif args.pl == 'bnm': 229 | softmax_out = nn.Softmax(dim=1)(outputs_target) 230 | bnm_loss = -torch.norm(softmax_out, 'nuc') 231 | cof = torch.tensor(np.sqrt(np.min(softmax_out.size())) / softmax_out.size(0)) 232 | bnm_loss *= cof 233 | total_loss += args.tar_par * eff * bnm_loss 234 | 235 | elif args.pl == "mcc": 236 | softmax_out = nn.Softmax(dim=1)(outputs_target) 237 | ent_weight = 1 + torch.exp(-loss.Entropy(softmax_out)).detach() 238 | ent_weight /= ent_weight.sum() 239 | cov_tar = softmax_out.t().mm(torch.diag(softmax_out.size(0)*ent_weight)).mm(softmax_out) 240 | mcc_loss = (torch.diag(cov_tar)/ cov_tar.sum(dim=1)).mean() 241 | total_loss -= args.tar_par * eff * mcc_loss 242 | 243 | elif args.pl == 'ent': 244 | softmax_out = nn.Softmax(dim=1)(outputs_target) 245 | ent_loss = torch.mean(loss.Entropy(softmax_out)) 246 | ent_loss /= torch.log(torch.tensor(class_num+0.0)) 247 | total_loss += args.tar_par * eff * ent_loss 248 | 249 | elif args.pl[0:3] == 'npl': 250 | softmax_out = nn.Softmax(dim=1)(outputs_target) 251 | softmax_out = softmax_out**2 / ((softmax_out**2).sum(dim=0)) 252 | 253 | weight_, pred = torch.max(softmax_out, 1) 254 | loss_ = nn.CrossEntropyLoss(reduction='none')(outputs_target, pred) 255 | classifier_loss = torch.sum(weight_ * loss_) / (torch.sum(weight_).item()) 256 | total_loss += args.tar_par * eff * classifier_loss 257 | 258 | elif args.pl == 'atdoc_nc': 259 | mem_fea_norm = mem_fea / 
torch.norm(mem_fea, p=2, dim=1, keepdim=True) 260 | dis = torch.mm(features_target.detach(), mem_fea_norm.t()) 261 | _, pred = torch.max(dis, dim=1) 262 | classifier_loss = nn.CrossEntropyLoss()(outputs_target, pred) 263 | total_loss += args.tar_par * eff * classifier_loss 264 | 265 | elif args.pl.startswith('atdoc_na'): 266 | 267 | dis = -torch.mm(features_target.detach(), mem_fea.t()) 268 | for di in range(dis.size(0)): 269 | dis[di, idx[di]] = torch.max(dis) 270 | _, p1 = torch.sort(dis, dim=1) 271 | 272 | w = torch.zeros(features_target.size(0), mem_fea.size(0)).cuda() 273 | for wi in range(w.size(0)): 274 | for wj in range(args.K): 275 | w[wi][p1[wi, wj]] = 1/ args.K 276 | 277 | weight_, pred = torch.max(w.mm(mem_cls), 1) 278 | 279 | if args.pl == 'atdoc_na_now': 280 | classifier_loss = nn.CrossEntropyLoss()(outputs_target, pred) 281 | else: 282 | loss_ = nn.CrossEntropyLoss(reduction='none')(outputs_target, pred) 283 | classifier_loss = torch.sum(weight_ * loss_) / (torch.sum(weight_).item()) 284 | total_loss += args.tar_par * eff * classifier_loss 285 | 286 | optimizer_g.zero_grad() 287 | optimizer_f.zero_grad() 288 | if ad_flag: 289 | optimizer_d.zero_grad() 290 | total_loss.backward() 291 | optimizer_g.step() 292 | optimizer_f.step() 293 | if ad_flag: 294 | optimizer_d.step() 295 | 296 | if args.pl.startswith('atdoc_na'): 297 | base_network.eval() 298 | with torch.no_grad(): 299 | features_target, outputs_target = base_network(inputs_target) 300 | features_target = features_target / torch.norm(features_target, p=2, dim=1, keepdim=True) 301 | softmax_out = nn.Softmax(dim=1)(outputs_target) 302 | if args.pl == 'atdoc_na_nos': 303 | outputs_target = softmax_out 304 | else: 305 | outputs_target = softmax_out**2 / ((softmax_out**2).sum(dim=0)) 306 | 307 | mem_fea[idx] = (1.0 - args.momentum) * mem_fea[idx] + args.momentum * features_target.clone() 308 | mem_cls[idx] = (1.0 - args.momentum) * mem_cls[idx] + args.momentum * outputs_target.clone() 309 | 310 | if 
args.pl == 'atdoc_nc': 311 | base_network.eval() 312 | with torch.no_grad(): 313 | features_target, outputs_target = base_network(inputs_target) 314 | softmax_t = nn.Softmax(dim=1)(outputs_target) 315 | _, pred_t = torch.max(softmax_t, 1) 316 | onehot_t = torch.eye(args.class_num)[pred_t].cuda() 317 | center_t = torch.mm(features_target.t(), onehot_t) / (onehot_t.sum(dim=0) + 1e-8) 318 | 319 | mem_fea = (1.0 - args.momentum) * mem_fea + args.momentum * center_t.t().clone() 320 | 321 | if iter_num % int(args.eval_epoch*max_len) == 0: 322 | base_network.eval() 323 | if args.dset == 'VISDA-C': 324 | acc, py, score, y, tacc = utils.cal_acc_visda(dset_loaders["test"], base_network) 325 | args.out_file.write(tacc + '\n') 326 | args.out_file.flush() 327 | 328 | _ent = loss.Entropy(score) 329 | mean_ent = 0 330 | for ci in range(args.class_num): 331 | mean_ent += _ent[py==ci].mean() 332 | mean_ent /= args.class_num 333 | 334 | else: 335 | acc, py, score, y = utils.cal_acc(dset_loaders["test"], base_network) 336 | mean_ent = torch.mean(loss.Entropy(score)) 337 | 338 | list_acc.append(acc * 100) 339 | if best_ent > mean_ent: 340 | best_ent = mean_ent 341 | val_acc = acc * 100 342 | best_y = y 343 | best_py = py 344 | best_score = score 345 | 346 | log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%; Mean Ent = {:.4f}'.format(args.name, iter_num, args.max_iter, acc*100, mean_ent) 347 | args.out_file.write(log_str + '\n') 348 | args.out_file.flush() 349 | print(log_str+'\n') 350 | 351 | idx = np.argmax(np.array(list_acc)) 352 | max_acc = list_acc[idx] 353 | final_acc = list_acc[-1] 354 | 355 | log_str = '\n==========================================\n' 356 | log_str += '\nVal Acc = {:.2f}\nMax Acc = {:.2f}\nFin Acc = {:.2f}\n'.format(val_acc, max_acc, final_acc) 357 | args.out_file.write(log_str + '\n') 358 | args.out_file.flush() 359 | 360 | # torch.save(base_network.state_dict(), osp.join(args.output_dir, args.log + ".pt")) 361 | # sio.savemat(osp.join(args.output_dir, 
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Domain Adaptation Methods')
    parser.add_argument('--method', type=str, default='srconly',
                        choices=['srconly', 'CDAN', 'CDANE', 'DANN', 'DANNE',
                                 'JAN_Linear', 'JAN', 'DAN_Linear', 'DAN', 'CORAL', 'DDC'])
    parser.add_argument('--pl', type=str, default='none',
                        choices=['none', 'square', 'npl', 'bnm', 'mcc', 'ent', 'bsp',
                                 'atdoc_na', 'atdoc_nc', 'atdoc_na_now', 'atdoc_na_nos'])
    # atdoc_na_now: atdoc_na without instance weights
    # atdoc_na_nos: atdoc_na without predictions sharpening

    parser.add_argument('--gpu_id', type=str, nargs='?', default='0', help="device id to run")
    parser.add_argument('--s', type=int, default=0, help="source")
    parser.add_argument('--t', type=int, default=1, help="target")
    parser.add_argument('--output', type=str, default='san')
    parser.add_argument('--seed', type=int, default=0, help="random seed")
    parser.add_argument('--batch_size', type=int, default=36, help="batch_size")
    parser.add_argument('--worker', type=int, default=4, help="number of workers")
    parser.add_argument('--bottleneck_dim', type=int, default=256)

    parser.add_argument('--max_epoch', type=int, default=30)
    parser.add_argument('--momentum', type=float, default=1.0)
    parser.add_argument('--K', type=int, default=5)
    parser.add_argument('--smooth', type=float, default=0.1)
    parser.add_argument('--tar_par', type=float, default=1.0)
    # BUG FIX: argparse's type=bool converts any non-empty string to True,
    # so "--validate False" used to *enable* validation. Parse the string
    # explicitly; the CLI still accepts a value ("True"/"False"/"1"/"0").
    parser.add_argument('--validate', type=lambda v: str(v).lower() in ('true', '1', 'yes'),
                        default=False)

    parser.add_argument('--net', type=str, default='resnet50', choices=["resnet50", "resnet101"])
    parser.add_argument('--dset', type=str, default='office-home',
                        choices=['DomainNet126', 'VISDA-C', 'office', 'office-home'],
                        help="The dataset or source dataset used")
    parser.add_argument('--lr', type=float, default=0.01, help="learning rate")

    args = parser.parse_args()
    args.output = args.output.strip()

    # atdoc_na variants encode K in the pl tag and force a full memory-bank
    # refresh (momentum 1.0); atdoc_nc uses a slow class-center EMA instead.
    if args.pl.startswith('atdoc_na'):
        args.pl += str(args.K)
        args.momentum = 1.0
    if args.pl == 'atdoc_nc':
        args.momentum = 0.1

    args.eval_epoch = args.max_epoch / 10

    if args.dset == 'office-home':
        names = ['Art', 'Clipart', 'Product', 'RealWorld']
        args.class_num = 65
    if args.dset == 'office':
        names = ['amazon', 'dslr', 'webcam']
        args.class_num = 31
    if args.dset == 'DomainNet126':
        names = ['clipart', 'painting', 'real', 'sketch']
        args.class_num = 126
    if args.dset == 'VISDA-C':
        names = ['train', 'validation']
        args.class_num = 12

    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
    SEED = args.seed
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    np.random.seed(SEED)
    random.seed(SEED)
    # torch.backends.cudnn.deterministic = True

    args.s_dset_path = './data/' + args.dset + '/' + names[args.s] + '_list.txt'
    args.t_dset_path = './data/' + args.dset + '/' + names[args.t] + '_list.txt'
    args.test_dset_path = args.t_dset_path

    if args.pl == 'none':
        args.output_dir = osp.join(args.output, args.pl, args.dset,
                                   names[args.s][0].upper() + names[args.t][0].upper())
    else:
        args.output_dir = osp.join(args.output, args.pl + '_' + str(args.tar_par), args.dset,
                                   names[args.s][0].upper() + names[args.t][0].upper())

    args.name = names[args.s][0].upper() + names[args.t][0].upper()
    # BUG FIX: os.system('mkdir -p ' + dir) shells out with an unquoted path
    # and is not portable; os.makedirs covers both the recursive and the
    # already-exists cases that the two original branches handled.
    os.makedirs(args.output_dir, exist_ok=True)

    args.log = args.method
    args.out_file = open(osp.join(args.output_dir, "{:}.txt".format(args.log)), "w")

    utils.print_args(args)
    label = train(args)
    if args.validate:
        train(args, validate=True, label=label)
lambda_u:100 21 | T:0.5 22 | ema_decay:0.999 23 | eval_epoch:5.0 24 | class_num:65 25 | s_dset_path:./data/office-home/Art_list.txt 26 | t_dset_path:./data/office-home/Product_list.txt 27 | test_dset_path:./data/office-home/Product_list.txt 28 | output_dir:logs/uda/run1/mixmatch/office-home/AP 29 | name:AP 30 | log:mixmatch_atdoc_naatdoc_na5 31 | out_file:<_io.TextIOWrapper name='logs/uda/run1/mixmatch/office-home/AP/mixmatch_atdoc_naatdoc_na5.txt' mode='w' encoding='UTF-8'> 32 | 33 | ========================================== 34 | 35 | -------------------------------------------------------------------------------- /logs/uda/run1/mixmatch/office-home/AR/mixmatch_atdoc_naatdoc_na5.txt: -------------------------------------------------------------------------------- 1 | ========================================== 2 | ========== config ============= 3 | ========================================== 4 | gpu_id:7 5 | s:0 6 | t:3 7 | output:logs/uda/run1/ 8 | seed:0 9 | max_epoch:50 10 | batch_size:36 11 | worker:4 12 | bottleneck_dim:256 13 | net:resnet50 14 | dset:office-home 15 | lr:0.01 16 | pl:atdoc_naatdoc_na5 17 | K:5 18 | momentum:1.0 19 | alpha:0.75 20 | lambda_u:100 21 | T:0.5 22 | ema_decay:0.999 23 | eval_epoch:5.0 24 | class_num:65 25 | s_dset_path:./data/office-home/Art_list.txt 26 | t_dset_path:./data/office-home/RealWorld_list.txt 27 | test_dset_path:./data/office-home/RealWorld_list.txt 28 | output_dir:logs/uda/run1/mixmatch/office-home/AR 29 | name:AR 30 | log:mixmatch_atdoc_naatdoc_na5 31 | out_file:<_io.TextIOWrapper name='logs/uda/run1/mixmatch/office-home/AR/mixmatch_atdoc_naatdoc_na5.txt' mode='w' encoding='UTF-8'> 32 | 33 | ========================================== 34 | 35 | Task: AR, Iter:605/6050; Accuracy = 72.76%; Mean Ent = 1.4166 36 | Task: AR, Iter:1210/6050; Accuracy = 75.37%; Mean Ent = 1.3595 37 | Task: AR, Iter:1815/6050; Accuracy = 76.75%; Mean Ent = 1.1428 38 | Task: AR, Iter:2420/6050; Accuracy = 77.19%; Mean Ent = 1.2178 39 | Task: 
AR, Iter:3025/6050; Accuracy = 78.04%; Mean Ent = 1.0779 40 | Task: AR, Iter:3630/6050; Accuracy = 79.14%; Mean Ent = 1.1267 41 | Task: AR, Iter:4235/6050; Accuracy = 79.30%; Mean Ent = 1.0302 42 | -------------------------------------------------------------------------------- /logs/uda/run1/mixmatch/office/AD/mixmatch_atdoc_naatdoc_na5.txt: -------------------------------------------------------------------------------- 1 | ========================================== 2 | ========== config ============= 3 | ========================================== 4 | gpu_id:7 5 | s:0 6 | t:1 7 | output:logs/uda/run1/ 8 | seed:0 9 | max_epoch:100 10 | batch_size:36 11 | worker:4 12 | bottleneck_dim:256 13 | net:resnet50 14 | dset:office 15 | lr:0.01 16 | pl:atdoc_naatdoc_na5 17 | K:5 18 | momentum:1.0 19 | alpha:0.75 20 | lambda_u:100 21 | T:0.5 22 | ema_decay:0.999 23 | eval_epoch:10.0 24 | class_num:31 25 | s_dset_path:../data/office/amazon_list.txt 26 | t_dset_path:../data/office/dslr_list.txt 27 | test_dset_path:../data/office/dslr_list.txt 28 | output_dir:logs/uda/run1/mixmatch/office/AD 29 | name:AD 30 | log:mixmatch_atdoc_naatdoc_na5 31 | out_file:<_io.TextIOWrapper name='logs/uda/run1/mixmatch/office/AD/mixmatch_atdoc_naatdoc_na5.txt' mode='w' encoding='UTF-8'> 32 | 33 | ========================================== 34 | 35 | Task: AD, Iter:780/7800; Accuracy = 85.94%; Mean Ent = 0.6320 36 | Task: AD, Iter:1560/7800; Accuracy = 88.76%; Mean Ent = 0.4815 37 | Task: AD, Iter:2340/7800; Accuracy = 89.16%; Mean Ent = 0.4171 38 | Task: AD, Iter:3120/7800; Accuracy = 89.16%; Mean Ent = 0.3950 39 | Task: AD, Iter:3900/7800; Accuracy = 89.16%; Mean Ent = 0.4243 40 | -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | import math 6 | import 
class CrossEntropyLabelSmooth(nn.Module):
    """Cross entropy loss with label smoothing regularizer.

    Reference:
        Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016.
    Equation: y = (1 - epsilon) * y + epsilon / K.

    Args:
        num_classes (int): number of classes.
        epsilon (float): smoothing weight.
        use_gpu (bool): move the smoothed targets to CUDA before computing the loss.
        reduction: True (or 'mean') averages over the batch; False or 'none'
            returns the per-sample loss vector.
    """

    def __init__(self, num_classes, epsilon=0.1, use_gpu=True, reduction=True):
        super(CrossEntropyLabelSmooth, self).__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.use_gpu = use_gpu
        self.logsoftmax = nn.LogSoftmax(dim=1)
        self.reduction = reduction

    def forward(self, inputs, targets):
        """
        Args:
            inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)
            targets: ground truth labels with shape (batch_size,)
        """
        log_probs = self.logsoftmax(inputs)
        # scatter_ builds one-hot targets on CPU (index tensor must be CPU here)
        targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).cpu(), 1)
        if self.use_gpu: targets = targets.cuda()
        targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
        loss = (- targets * log_probs).sum(dim=1)

        # BUG FIX: callers pass reduction='none' expecting the per-sample
        # vector, but any non-empty string is truthy, so the original code
        # silently averaged anyway (making instance re-weighting a no-op).
        if self.reduction and self.reduction != 'none':
            return loss.mean()
        return loss


class MMD_loss(nn.Module):
    """Multi-kernel Maximum Mean Discrepancy (used by the DDC baseline).

    kernel_type 'rbf' uses a bank of `kernel_num` Gaussian kernels whose
    bandwidths are spaced by factors of `kernel_mul` around a median-style
    estimate; 'linear' uses the streaming linear-time estimator.
    """

    def __init__(self, kernel_type='rbf', kernel_mul=2.0, kernel_num=5):
        super(MMD_loss, self).__init__()
        self.kernel_num = kernel_num
        self.kernel_mul = kernel_mul
        self.fix_sigma = None
        self.kernel_type = kernel_type

    def guassian_kernel(self, source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
        """Return the summed multi-bandwidth RBF kernel matrix of [source; target]."""
        n_samples = int(source.size()[0]) + int(target.size()[0])
        total = torch.cat([source, target], dim=0)
        total0 = total.unsqueeze(0).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
        total1 = total.unsqueeze(1).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
        L2_distance = ((total0 - total1) ** 2).sum(2)
        if fix_sigma:
            bandwidth = fix_sigma
        else:
            # mean pairwise squared distance, excluding the zero diagonal
            bandwidth = torch.sum(L2_distance.data) / (n_samples ** 2 - n_samples)
        bandwidth /= kernel_mul ** (kernel_num // 2)
        bandwidth_list = [bandwidth * (kernel_mul ** i) for i in range(kernel_num)]
        kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list]
        return sum(kernel_val)

    def linear_mmd2(self, f_of_X, f_of_Y):
        """Linear-time MMD^2 estimate over consecutive sample pairs."""
        loss = 0.0
        delta = f_of_X - f_of_Y
        loss = torch.mean((delta[:-1] * delta[1:]).sum(1))
        return loss

    def forward(self, source, target):
        if self.kernel_type == 'linear':
            return self.linear_mmd2(source, target)
        elif self.kernel_type == 'rbf':
            batch_size = int(source.size()[0])
            kernels = self.guassian_kernel(
                source, target, kernel_mul=self.kernel_mul, kernel_num=self.kernel_num, fix_sigma=self.fix_sigma)
            # NOTE(review): the block means are computed under no_grad, so no
            # gradient flows through this loss — confirm that is intended
            # before using MMD_loss as a training objective.
            with torch.no_grad():
                XX = torch.mean(kernels[:batch_size, :batch_size])
                YY = torch.mean(kernels[batch_size:, batch_size:])
                XY = torch.mean(kernels[:batch_size, batch_size:])
                YX = torch.mean(kernels[batch_size:, :batch_size])
                loss = torch.mean(XX + YY - XY - YX)
            torch.cuda.empty_cache()
            return loss


def bsp_loss(feature):
    """Batch Spectral Penalization: squared top singular value of the source
    and target feature halves, scaled by 1e-4. Expects source rows first."""
    train_bs = feature.size(0) // 2
    feature_s = feature.narrow(0, 0, train_bs)
    feature_t = feature.narrow(0, train_bs, train_bs)
    _, s_s, _ = torch.svd(feature_s)
    _, s_t, _ = torch.svd(feature_t)
    sigma = torch.pow(s_s[0], 2) + torch.pow(s_t[0], 2)
    sigma *= 0.0001
    return sigma


def Entropy(input_):
    """Per-sample Shannon entropy of a row-wise probability matrix."""
    bs = input_.size(0)
    epsilon = 1e-5  # guards log(0)
    entropy = -input_ * torch.log(input_ + epsilon)
    entropy = torch.sum(entropy, dim=1)
    return entropy


def grl_hook(coeff):
    """Gradient reversal hook: flips and scales gradients by `coeff`."""
    def fun1(grad):
        return -coeff * grad.clone()
    return fun1
def CDAN(input_list, ad_net, entropy=None, coeff=None, random_layer=None):
    """Conditional adversarial loss (CDAN / CDAN+E).

    Args:
        input_list: [features, softmax_outputs]; softmax outputs are detached.
        ad_net: domain discriminator producing a sigmoid probability.
        entropy: optional per-sample entropy for CDAN+E re-weighting.
        coeff: gradient-reversal coefficient applied to the entropy branch.
        random_layer: optional random multilinear projection.
    """
    softmax_output = input_list[1].detach()
    feature = input_list[0]
    if random_layer is None:
        # outer product conditions the discriminator on classifier predictions
        op_out = torch.bmm(softmax_output.unsqueeze(2), feature.unsqueeze(1))
        ad_out = ad_net(op_out.view(-1, softmax_output.size(1) * feature.size(1)))
    else:
        random_out = random_layer.forward([feature, softmax_output])
        ad_out = ad_net(random_out.view(-1, random_out.size(1)))
    # first half of the batch is source (label 1), second half target (label 0)
    batch_size = softmax_output.size(0) // 2
    dc_target = torch.from_numpy(np.array([[1]] * batch_size + [[0]] * batch_size)).float().cuda()
    if entropy is not None:
        entropy.register_hook(grl_hook(coeff))
        entropy = 1.0 + torch.exp(-entropy)  # confident samples get larger weight
        source_mask = torch.ones_like(entropy)
        source_mask[feature.size(0) // 2:] = 0
        source_weight = entropy * source_mask
        target_mask = torch.ones_like(entropy)
        target_mask[0:feature.size(0) // 2] = 0
        target_weight = entropy * target_mask
        weight = source_weight / torch.sum(source_weight).detach().item() + \
                 target_weight / torch.sum(target_weight).detach().item()
        return torch.sum(weight.view(-1, 1) * nn.BCELoss(reduction='none')(ad_out, dc_target)) / torch.sum(weight).detach().item()
    else:
        return nn.BCELoss()(ad_out, dc_target)


def DANN(features, ad_net, entropy=None, coeff=None):
    """Unconditional adversarial loss (DANN / DANN+E).

    Args:
        features: concatenated [source; target] features.
        ad_net: domain discriminator producing a sigmoid probability.
        entropy: optional per-sample entropy for DANN+E re-weighting.
        coeff: gradient-reversal coefficient applied to the entropy branch.
    """
    ad_out = ad_net(features)
    batch_size = ad_out.size(0) // 2
    dc_target = torch.from_numpy(np.array([[1]] * batch_size + [[0]] * batch_size)).float().cuda()
    if entropy is not None:
        entropy.register_hook(grl_hook(coeff))
        entropy = 1.0 + torch.exp(-entropy)
        source_mask = torch.ones_like(entropy)
        # BUG FIX: the original referenced the undefined name `feature`
        # (the parameter is `features`), raising NameError whenever the
        # entropy-weighted DANNE path was taken.
        source_mask[features.size(0) // 2:] = 0
        source_weight = entropy * source_mask
        target_mask = torch.ones_like(entropy)
        target_mask[0:features.size(0) // 2] = 0
        target_weight = entropy * target_mask
        weight = source_weight / torch.sum(source_weight).detach().item() + \
                 target_weight / torch.sum(target_weight).detach().item()
        return torch.sum(weight.view(-1, 1) * nn.BCELoss(reduction='none')(ad_out, dc_target)) / torch.sum(weight).detach().item()
    else:
        return nn.BCELoss()(ad_out, dc_target)


def CORAL(source, target):
    """CORrelation ALignment: Frobenius distance between the source and
    target feature covariance matrices, scaled by 1 / (4 d^2)."""
    d = source.size(1)
    ns, nt = source.size(0), target.size(0)

    # source covariance
    tmp_s = torch.ones((1, ns)).cuda() @ source
    cs = (source.t() @ source - (tmp_s.t() @ tmp_s) / ns) / (ns - 1)

    # target covariance
    tmp_t = torch.ones((1, nt)).cuda() @ target
    ct = (target.t() @ target - (tmp_t.t() @ tmp_t) / nt) / (nt - 1)

    # frobenius norm
    loss = (cs - ct).pow(2).sum().sqrt()
    loss = loss / (4 * d * d)
    return loss


def guassian_kernel(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
    """Summed multi-bandwidth RBF kernel matrix over [source; target] rows.

    Bandwidths are `kernel_num` geometric steps of `kernel_mul` around either
    `fix_sigma` or the mean off-diagonal squared distance.
    """
    n_samples = int(source.size()[0]) + int(target.size()[0])
    total = torch.cat([source, target], dim=0)
    total0 = total.unsqueeze(0).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
    total1 = total.unsqueeze(1).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
    # pdb.set_trace()
    L2_distance = ((total0 - total1) ** 2).sum(2)
    if fix_sigma:
        bandwidth = fix_sigma
    else:
        bandwidth = torch.sum(L2_distance.data) / (n_samples ** 2 - n_samples)
    bandwidth /= kernel_mul ** (kernel_num // 2)
    bandwidth_list = [bandwidth * (kernel_mul ** i) for i in range(kernel_num)]
    kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list]
    return sum(kernel_val)  # /len(kernel_val)
batch_size): 199 | t1, t2 = s1+batch_size, s2+batch_size 200 | loss1 += kernels[s1, s2] + kernels[t1, t2] 201 | loss1 = loss1 / float(batch_size * (batch_size - 1) // 2) 202 | 203 | loss2 = 0 204 | for s1 in range(batch_size): 205 | for s2 in range(batch_size): 206 | t1, t2 = s1+batch_size, s2+batch_size 207 | loss2 -= kernels[s1, t2] + kernels[s2, t1] 208 | loss2 = loss2 / float(batch_size * batch_size) 209 | return loss1 + loss2 210 | 211 | def DAN_Linear(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None): 212 | batch_size = int(source.size()[0]) 213 | kernels = guassian_kernel(source, target, 214 | kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma) 215 | 216 | # Linear version 217 | loss = 0 218 | for i in range(batch_size): 219 | s1, s2 = i, (i+1)%batch_size 220 | t1, t2 = s1+batch_size, s2+batch_size 221 | loss += kernels[s1, s2] + kernels[t1, t2] 222 | loss -= kernels[s1, t2] + kernels[s2, t1] 223 | return loss / float(batch_size) 224 | 225 | 226 | def RTN(): 227 | pass 228 | 229 | 230 | def JAN(source_list, target_list, kernel_muls=[2.0, 2.0], kernel_nums=[5, 1], fix_sigma_list=[None, 1.68]): 231 | batch_size = int(source_list[0].size()[0]) 232 | layer_num = len(source_list) 233 | joint_kernels = None 234 | for i in range(layer_num): 235 | source = source_list[i] 236 | target = target_list[i] 237 | kernel_mul = kernel_muls[i] 238 | kernel_num = kernel_nums[i] 239 | fix_sigma = fix_sigma_list[i] 240 | kernels = guassian_kernel(source, target, 241 | kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma) 242 | if joint_kernels is not None: 243 | joint_kernels = joint_kernels * kernels 244 | else: 245 | joint_kernels = kernels 246 | 247 | loss1 = 0 248 | for s1 in range(batch_size): 249 | for s2 in range(s1 + 1, batch_size): 250 | t1, t2 = s1 + batch_size, s2 + batch_size 251 | loss1 += joint_kernels[s1, s2] + joint_kernels[t1, t2] 252 | loss1 = loss1 / float(batch_size * (batch_size - 1) // 2) 253 | 254 | loss2 = 0 255 | 
for s1 in range(batch_size): 256 | for s2 in range(batch_size): 257 | t1, t2 = s1 + batch_size, s2 + batch_size 258 | loss2 -= joint_kernels[s1, t2] + joint_kernels[s2, t1] 259 | loss2 = loss2 / float(batch_size * batch_size) 260 | return loss1 + loss2 261 | 262 | def JAN_Linear(source_list, target_list, kernel_muls=[2.0, 2.0], kernel_nums=[5, 1], fix_sigma_list=[None, 1.68]): 263 | batch_size = int(source_list[0].size()[0]) 264 | layer_num = len(source_list) 265 | joint_kernels = None 266 | for i in range(layer_num): 267 | source = source_list[i] 268 | target = target_list[i] 269 | kernel_mul = kernel_muls[i] 270 | kernel_num = kernel_nums[i] 271 | fix_sigma = fix_sigma_list[i] 272 | kernels = guassian_kernel(source, target, 273 | kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma) 274 | if joint_kernels is not None: 275 | joint_kernels = joint_kernels * kernels 276 | else: 277 | joint_kernels = kernels 278 | 279 | # Linear version 280 | loss = 0 281 | for i in range(batch_size): 282 | s1, s2 = i, (i+1)%batch_size 283 | t1, t2 = s1+batch_size, s2+batch_size 284 | loss += joint_kernels[s1, s2] + joint_kernels[t1, t2] 285 | loss -= joint_kernels[s1, t2] + joint_kernels[s2, t1] 286 | return loss / float(batch_size) 287 | 288 | loss_dict = {"DAN":DAN, "DAN_Linear":DAN_Linear, "RTN":RTN, "JAN":JAN, "JAN_Linear":JAN_Linear} 289 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import tqdm 4 | from itertools import chain 5 | from collections import OrderedDict 6 | import torch 7 | import torchvision 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torch.optim as optim 11 | import math, pdb 12 | from PIL import Image 13 | import numpy as np 14 | from sklearn.metrics import confusion_matrix 15 | 16 | def calc_coeff(iter_num, high=1.0, low=0.0, alpha=10.0, max_iter=10000.0): 
def calc_coeff(iter_num, high=1.0, low=0.0, alpha=10.0, max_iter=10000.0):
    """Sigmoid ramp from `low` (iter 0) to `high` (iter -> max_iter).

    Used to warm up adversarial/transfer loss weights.
    BUG FIX: np.float was deprecated in NumPy 1.20 and removed in 1.24;
    the builtin float is the documented replacement.
    """
    return float(2.0 * (high - low) / (1.0 + np.exp(-alpha * iter_num / max_iter)) - (high - low) + low)


def init_weights(m):
    """Layer-type-dependent initialization, applied via module.apply()."""
    classname = m.__class__.__name__
    if classname.find('Conv2d') != -1 or classname.find('ConvTranspose2d') != -1:
        nn.init.kaiming_uniform_(m.weight)
        nn.init.zeros_(m.bias)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight, 1.0, 0.02)
        nn.init.zeros_(m.bias)
    elif classname.find('Linear') != -1:
        nn.init.xavier_normal_(m.weight)
        nn.init.zeros_(m.bias)


def grl_hook(coeff):
    """Gradient reversal hook: flips and scales gradients by `coeff`."""
    def fun1(grad):
        return -coeff * grad.clone()
    return fun1


class VGG16Base(nn.Module):
    """ImageNet-pretrained VGG-16 trunk (features + first six classifier
    layers), exposing 4096-d activations via `in_features`."""

    def __init__(self):
        super(VGG16Base, self).__init__()
        model_vgg = torchvision.models.vgg16(pretrained=True)
        self.features = model_vgg.features
        self.classifier = nn.Sequential()
        for i in range(6):
            self.classifier.add_module("classifier" + str(i), model_vgg.classifier[i])
        self.feature_layers = nn.Sequential(self.features, self.classifier)
        self.in_features = 4096

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x


class ResBase34(nn.Module):
    """ImageNet-pretrained ResNet-34 trunk up to global average pooling;
    forward returns flattened `in_features`-d vectors."""

    def __init__(self):
        super(ResBase34, self).__init__()
        model_resnet = torchvision.models.resnet34(pretrained=True)
        self.conv1 = model_resnet.conv1
        self.bn1 = model_resnet.bn1
        self.relu = model_resnet.relu
        self.maxpool = model_resnet.maxpool
        self.layer1 = model_resnet.layer1
        self.layer2 = model_resnet.layer2
        self.layer3 = model_resnet.layer3
        self.layer4 = model_resnet.layer4
        self.avgpool = model_resnet.avgpool
        self.in_features = model_resnet.fc.in_features

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        return x


class ResBase50(nn.Module):
    """ImageNet-pretrained ResNet-50 trunk up to global average pooling;
    forward returns flattened `in_features`-d vectors."""

    def __init__(self):
        super(ResBase50, self).__init__()
        model_resnet50 = torchvision.models.resnet50(pretrained=True)
        self.conv1 = model_resnet50.conv1
        self.bn1 = model_resnet50.bn1
        self.relu = model_resnet50.relu
        self.maxpool = model_resnet50.maxpool
        self.layer1 = model_resnet50.layer1
        self.layer2 = model_resnet50.layer2
        self.layer3 = model_resnet50.layer3
        self.layer4 = model_resnet50.layer4
        self.avgpool = model_resnet50.avgpool
        self.in_features = model_resnet50.fc.in_features

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        return x


class ResBase101(nn.Module):
    """ImageNet-pretrained ResNet-101 trunk up to global average pooling;
    forward returns flattened `in_features`-d vectors."""

    def __init__(self):
        super(ResBase101, self).__init__()
        model_resnet101 = torchvision.models.resnet101(pretrained=True)
        self.conv1 = model_resnet101.conv1
        self.bn1 = model_resnet101.bn1
        self.relu = model_resnet101.relu
        self.maxpool = model_resnet101.maxpool
        self.layer1 = model_resnet101.layer1
        self.layer2 = model_resnet101.layer2
        self.layer3 = model_resnet101.layer3
        self.layer4 = model_resnet101.layer4
        self.avgpool = model_resnet101.avgpool
        self.in_features = model_resnet101.fc.in_features

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        return x
class ResClassifier(nn.Module):
    """Bottleneck + linear classification head.

    forward(x) returns (bottleneck_features, logits).
    """

    def __init__(self, class_num, feature_dim, bottleneck_dim=256):
        super(ResClassifier, self).__init__()
        self.bottleneck = nn.Linear(feature_dim, bottleneck_dim)
        self.fc = nn.Linear(bottleneck_dim, class_num)
        self.bottleneck.apply(init_weights)
        self.fc.apply(init_weights)

    def forward(self, x):
        x = self.bottleneck(x)
        y = self.fc(x)
        return x, y


class AdversarialNetwork(nn.Module):
    """Domain discriminator with a built-in gradient reversal layer whose
    coefficient ramps up with the internal training-iteration counter.

    NOTE(review): ad_layer2 / relu2 / dropout1 / dropout2 are created but never
    used in forward(); they still register parameters. Left unchanged so that
    existing checkpoints and optimizer param groups keep working — confirm
    before pruning.
    """

    def __init__(self, in_feature, hidden_size, max_iter=10000):
        super(AdversarialNetwork, self).__init__()
        self.ad_layer1 = nn.Linear(in_feature, hidden_size)
        self.ad_layer2 = nn.Linear(hidden_size, hidden_size)
        self.ad_layer3 = nn.Linear(hidden_size, 1)
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()
        self.dropout1 = nn.Dropout(0.5)
        self.dropout2 = nn.Dropout(0.5)
        self.sigmoid = nn.Sigmoid()
        self.apply(init_weights)
        self.iter_num = 0
        self.alpha = 10
        self.low = 0.0
        self.high = 1.0
        self.max_iter = max_iter

    def forward(self, x):
        if self.training:
            self.iter_num += 1
        coeff = calc_coeff(self.iter_num, self.high, self.low, self.alpha, self.max_iter)
        x = x * 1.0  # clone-like op so the hook attaches to a fresh node
        x.register_hook(grl_hook(coeff))
        x = self.ad_layer1(x)
        x = self.relu1(x)
        y = self.ad_layer3(x)
        y = self.sigmoid(y)
        return y

    def output_num(self):
        return 1

    def get_parameters(self):
        return [{"params": self.parameters(), "lr_mult": 10, 'decay_mult': 2}]


IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',]


def is_image_file(filename):
    """True if `filename` ends with a recognized image extension."""
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


def default_loader(path):
    """Load an image from disk as RGB (PIL)."""
    return Image.open(path).convert('RGB')


def make_dataset(root, label):
    """Parse a `<relative-path> <class-index>` list file into (path, label) pairs.

    Non-image entries are skipped.
    BUG FIX: the original never closed the label file; a context manager
    guarantees the handle is released even if a line fails to parse.
    """
    images = []
    with open(label) as labeltxt:
        for line in labeltxt:
            data = line.strip().split(' ')
            if is_image_file(data[0]):
                path = os.path.join(root, data[0])
                gt = int(data[1])
                item = (path, gt)
                images.append(item)
    return images


class ObjectImage_y(torch.utils.data.Dataset):
    """Image dataset whose labels come from an external array `y`
    (e.g. pseudo-labels) instead of the list file."""

    def __init__(self, root, label, transform=None, y=None, loader=default_loader):
        imgs = make_dataset(root, label)
        self.root = root
        self.label = label
        self.imgs = imgs
        self.transform = transform
        self.loader = loader
        self.y = y

    def __getitem__(self, index):
        path, _ = self.imgs[index]
        target = self.y[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        return img, target

    def __len__(self):
        return len(self.imgs)


class ObjectImage(torch.utils.data.Dataset):
    """Plain (image, label) dataset built from a list file."""

    def __init__(self, root, label, transform=None, loader=default_loader):
        imgs = make_dataset(root, label)
        self.root = root
        self.label = label
        self.imgs = imgs
        self.transform = transform
        self.loader = loader

    def __getitem__(self, index):
        path, target = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        return img, target

    def __len__(self):
        return len(self.imgs)


class ObjectImage_mul(torch.utils.data.Dataset):
    """Dataset that also returns the sample index (used by the memory bank).

    If `transform` is a list, each transform is applied to the same PIL image
    and a list of views is returned (multi-view augmentation).
    """

    def __init__(self, root, label, transform=None, loader=default_loader):
        imgs = make_dataset(root, label)
        self.root = root
        self.label = label
        self.imgs = imgs
        self.transform = transform
        self.loader = loader

    def __getitem__(self, index):
        path, target = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            # print(type(self.transform).__name__)
            if type(self.transform).__name__ == 'list':
                img = [t(img) for t in self.transform]
            else:
                img = self.transform(img)
        return img, target, index

    def __len__(self):
        return len(self.imgs)
266 | def __len__(self): 267 | return len(self.imgs) 268 | 269 | def weights_init(m): 270 | classname = m.__class__.__name__ 271 | if classname.find('Conv') != -1: 272 | m.weight.data.normal_(0.0, 0.01) 273 | m.bias.data.normal_(0.0, 0.01) 274 | elif classname.find('BatchNorm') != -1: 275 | m.weight.data.normal_(1.0, 0.01) 276 | m.bias.data.fill_(0) 277 | elif classname.find('Linear') != -1: 278 | m.weight.data.normal_(0.0, 0.01) 279 | m.bias.data.normal_(0.0, 0.01) 280 | 281 | def print_args(args): 282 | log_str = ("==========================================\n") 283 | log_str += ("========== config =============\n") 284 | log_str += ("==========================================\n") 285 | for arg, content in args.__dict__.items(): 286 | log_str += ("{}:{}\n".format(arg, content)) 287 | log_str += ("\n==========================================\n") 288 | print(log_str) 289 | args.out_file.write(log_str+'\n') 290 | args.out_file.flush() 291 | 292 | def cal_fea(loader, model): 293 | start_test = True 294 | with torch.no_grad(): 295 | iter_test = iter(loader) 296 | for i in range(len(loader)): 297 | inputs, labels = iter_test.next() 298 | inputs = inputs.cuda() 299 | feas, outputs = model(inputs) 300 | if start_test: 301 | all_feas = feas.float().cpu() 302 | all_label = labels.float() 303 | start_test = False 304 | else: 305 | all_feas = torch.cat((all_feas, feas.float().cpu()), 0) 306 | all_label = torch.cat((all_label, labels.float()), 0) 307 | return all_feas, all_label 308 | 309 | def cal_acc(loader, model, flag=True, fc=None): 310 | start_test = True 311 | with torch.no_grad(): 312 | iter_test = iter(loader) 313 | for i in range(len(loader)): 314 | data = iter_test.next() 315 | inputs = data[0] 316 | labels = data[1] 317 | inputs = inputs.cuda() 318 | if flag: 319 | _, outputs = model(inputs) 320 | else: 321 | if fc is not None: 322 | feas, outputs = model(inputs) 323 | outputs = fc(feas) 324 | else: 325 | outputs = model(inputs) 326 | if start_test: 327 | 
all_output = outputs.float().cpu() 328 | all_label = labels.float() 329 | start_test = False 330 | else: 331 | all_output = torch.cat((all_output, outputs.float().cpu()), 0) 332 | all_label = torch.cat((all_label, labels.float()), 0) 333 | all_output = nn.Softmax(dim=1)(all_output) 334 | _, predict = torch.max(all_output, 1) 335 | accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0]) 336 | return accuracy, predict, all_output, all_label 337 | 338 | def cal_acc_visda(loader, model, flag=True, fc=None): 339 | start_test = True 340 | with torch.no_grad(): 341 | iter_test = iter(loader) 342 | for i in range(len(loader)): 343 | data = iter_test.next() 344 | inputs = data[0] 345 | labels = data[1] 346 | inputs = inputs.cuda() 347 | if flag: 348 | _, outputs = model(inputs) 349 | else: 350 | if fc is not None: 351 | feas, outputs = model(inputs) 352 | outputs = fc(feas) 353 | else: 354 | outputs = model(inputs) 355 | if start_test: 356 | all_output = outputs.float().cpu() 357 | all_label = labels.float() 358 | start_test = False 359 | else: 360 | all_output = torch.cat((all_output, outputs.float().cpu()), 0) 361 | all_label = torch.cat((all_label, labels.float()), 0) 362 | all_output = nn.Softmax(dim=1)(all_output) 363 | _, predict = torch.max(all_output, 1) 364 | 365 | matrix = confusion_matrix(all_label, torch.squeeze(predict).float()) 366 | acc = matrix.diagonal()/matrix.sum(axis=1) * 100 367 | aacc = acc.mean() / 100 368 | aa = [str(np.round(i, 2)) for i in acc] 369 | acc = ' '.join(aa) 370 | print(acc) 371 | 372 | # accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0]) 373 | return aacc, predict, all_output, all_label, acc 374 | 375 | def linear_rampup(current, rampup_length): 376 | if rampup_length == 0: 377 | return 1.0 378 | else: 379 | current = np.clip(current / rampup_length, 0.0, 1.0) 380 | return float(current) 381 | 382 | class SemiLoss(object): 383 | def 
__call__(self, outputs_x, targets_x, outputs_u, targets_u, epoch, max_epochs=30, lambda_u=75): 384 | probs_u = torch.softmax(outputs_u, dim=1) 385 | 386 | Lx = -torch.mean(torch.sum(F.log_softmax(outputs_x, dim=1) * targets_x, dim=1)) 387 | Lu = torch.mean((probs_u - targets_u)**2) 388 | 389 | return Lx, Lu, lambda_u * linear_rampup(epoch, max_epochs) 390 | 391 | class WeightEMA(object): 392 | def __init__(self, model, ema_model, alpha=0.999): 393 | self.model = model 394 | self.ema_model = ema_model 395 | self.alpha = alpha 396 | self.params = list(model.state_dict().values()) 397 | self.ema_params = list(ema_model.state_dict().values()) 398 | self.wd = 0.02 * args.lr 399 | 400 | for param, ema_param in zip(self.params, self.ema_params): 401 | param.data.copy_(ema_param.data) 402 | 403 | def step(self): 404 | one_minus_alpha = 1.0 - self.alpha 405 | for param, ema_param in zip(self.params, self.ema_params): 406 | ema_param.mul_(self.alpha) 407 | ema_param.add_(param * one_minus_alpha) 408 | # customized weight decay 409 | param.mul_(1 - self.wd) 410 | 411 | def interleave_offsets(batch, nu): 412 | groups = [batch // (nu + 1)] * (nu + 1) 413 | for x in range(batch - sum(groups)): 414 | groups[-x - 1] += 1 415 | offsets = [0] 416 | for g in groups: 417 | offsets.append(offsets[-1] + g) 418 | assert offsets[-1] == batch 419 | return offsets 420 | 421 | def interleave(xy, batch): 422 | nu = len(xy) - 1 423 | offsets = interleave_offsets(batch, nu) 424 | xy = [[v[offsets[p]:offsets[p + 1]] for p in range(nu + 1)] for v in xy] 425 | for i in range(1, nu + 1): 426 | xy[0][i], xy[i][i] = xy[i][i], xy[0][i] 427 | return [torch.cat(v, dim=0) for v in xy] --------------------------------------------------------------------------------