├── .gitignore ├── DINE ├── DINE_dist.py ├── DINE_dist_kDINE.py ├── DINE_ft.py ├── data_list.py ├── loss.py ├── network.py └── run_all_kDINE.sh ├── LICENSE ├── README.md ├── SHOT ├── __init__.py ├── augmentations.py ├── data_list.py ├── image_source.py ├── image_target.py ├── image_target_kSHOT.py ├── loss.py ├── network.py └── run_all_kSHOT.sh ├── data ├── domainnet40 │ └── image_list │ │ ├── clipart_test_mini.txt │ │ ├── clipart_train_mini.txt │ │ ├── painting_test_mini.txt │ │ ├── painting_train_mini.txt │ │ ├── real_test_mini.txt │ │ ├── real_train_mini.txt │ │ ├── sketch_test_mini.txt │ │ └── sketch_train_mini.txt ├── multi │ └── image_list │ │ ├── clipart.txt │ │ ├── painting.txt │ │ ├── real.txt │ │ └── sketch.txt ├── office-home-rsut │ └── image_list │ │ ├── Clipart_RS.txt │ │ ├── Clipart_UT.txt │ │ ├── Product_RS.txt │ │ ├── Product_UT.txt │ │ ├── Real_World_RS.txt │ │ └── Real_World_UT.txt ├── office-home │ └── image_list │ │ ├── Art.txt │ │ ├── Clipart.txt │ │ ├── Product.txt │ │ └── Real_World.txt ├── office31 │ └── image_list │ │ ├── amazon.txt │ │ ├── dslr.txt │ │ └── webcam.txt ├── setup_data_path.sh └── visda-2017 │ └── image_list │ ├── train.txt │ └── validation.txt ├── fig ├── PK.png └── framework.png ├── pklib └── pksolver.py └── util ├── __init__.py ├── get_time.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | -------------------------------------------------------------------------------- /DINE/DINE_dist.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import os.path as osp 4 | import torchvision 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | from torchvision import transforms 10 | import network 11 | import loss 12 | from torch.utils.data import DataLoader 13 | from data_list import ImageList, ImageList_idx 14 | from loss import 
from loss import CrossEntropyLabelSmooth
from sklearn.metrics import confusion_matrix
import distutils
import distutils.util
import logging

import sys
sys.path.append("../util/")
from utils import resetRNGseed, init_logger, get_hostname, get_pid

import time
timestamp = time.strftime("%Y-%m-%d_%H.%M.%S", time.localtime())

# Reproducibility: deterministic cuDNN kernels, no autotuning.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


def op_copy(optimizer):
    """Stash every param group's initial learning rate under 'lr0'.

    lr_scheduler() decays from 'lr0', so this must be called once right
    after the optimizer is constructed.
    """
    for param_group in optimizer.param_groups:
        param_group['lr0'] = param_group['lr']
    return optimizer


def lr_scheduler(optimizer, iter_num, max_iter, gamma=10, power=0.75):
    """Polynomial ("inv") LR decay: lr = lr0 * (1 + gamma*t/T)^(-power).

    Also (re)applies weight_decay/momentum/nesterov to every group on each
    call, so those SGD hyper-parameters need not be set at construction.
    """
    decay = (1 + gamma * iter_num / max_iter) ** (-power)
    for param_group in optimizer.param_groups:
        param_group['lr'] = param_group['lr0'] * decay
        param_group['weight_decay'] = 1e-3
        param_group['momentum'] = 0.9
        param_group['nesterov'] = True
    return optimizer


def image_train(resize_size=256, crop_size=224, alexnet=False):
    """Training transform: resize -> random crop -> random h-flip -> normalize."""
    if alexnet:
        # BUG fix: the original referenced an undefined `Normalize(meanfile=...)`
        # helper here, which raised NameError at call time. Fail fast instead.
        raise NotImplementedError("alexnet mean-file normalization is not available")
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    return transforms.Compose([
        transforms.Resize((resize_size, resize_size)),
        transforms.RandomCrop(crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize
    ])


def image_test(resize_size=256, crop_size=224, alexnet=False):
    """Evaluation transform: resize -> center crop -> normalize (ImageNet stats)."""
    if alexnet:
        # BUG fix: see image_train — `Normalize(meanfile=...)` was undefined.
        raise NotImplementedError("alexnet mean-file normalization is not available")
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    return transforms.Compose([
        transforms.Resize((resize_size, resize_size)),
        transforms.CenterCrop(crop_size),
        transforms.ToTensor(),
        normalize
    ])
def data_load(args):
    """Build all DataLoaders used by this script.

    The source list is split so the first 3 images of every class form a small
    held-out validation split ("source_te") and the rest the training split
    ("source_tr"). For partial DA (args.da != 'uda') target labels are remapped
    onto the source label space; target-only classes collapse to the single
    extra index len(src_classes).
    """
    dsets = {}
    dset_loaders = {}
    train_bs = args.batch_size
    txt_src = open(args.s_dset_path).readlines()
    txt_tar = open(args.t_dset_path).readlines()
    txt_test = open(args.test_dset_path).readlines()

    # Hold out the first 3 images of each class for source validation.
    count = np.zeros(args.class_num)
    tr_txt = []
    te_txt = []
    for line in txt_src:
        cls = int(line.strip().split(' ')[1])
        if count[cls] < 3:
            count[cls] += 1
            te_txt.append(line)
        else:
            tr_txt.append(line)

    if not args.da == 'uda':
        label_map_s = {c: i for i, c in enumerate(args.src_classes)}
        new_tar = []
        for rec in txt_tar:
            reci = rec.strip().split(' ')
            cls = int(reci[1])
            if cls in args.tar_classes:
                if cls in args.src_classes:
                    new_tar.append(reci[0] + ' ' + str(label_map_s[cls]) + '\n')
                else:
                    # classes unseen in the source collapse to one extra index
                    new_tar.append(reci[0] + ' ' + str(len(label_map_s)) + '\n')
        txt_tar = new_tar.copy()
        txt_test = txt_tar.copy()

    root = "../data/{}/".format(args.dset)
    dsets["source_tr"] = ImageList(tr_txt, root=root, transform=image_train())
    dset_loaders["source_tr"] = DataLoader(dsets["source_tr"], batch_size=train_bs,
                                           shuffle=True, num_workers=args.worker, drop_last=False)
    dsets["source_te"] = ImageList(te_txt, root=root, transform=image_test())
    dset_loaders["source_te"] = DataLoader(dsets["source_te"], batch_size=train_bs,
                                           shuffle=True, num_workers=args.worker, drop_last=False)
    dsets["target"] = ImageList_idx(txt_tar, root=root, transform=image_train())
    dset_loaders["target"] = DataLoader(dsets["target"], batch_size=train_bs,
                                        shuffle=True, num_workers=args.worker, drop_last=False)
    # Unshuffled eval view of the target set; its order must line up with the
    # sample indices returned by dsets["target"] (ImageList_idx).
    dsets["target_te"] = ImageList(txt_tar, root=root, transform=image_test())
    dset_loaders["target_te"] = DataLoader(dsets["target_te"], batch_size=train_bs,
                                           shuffle=False, num_workers=args.worker, drop_last=False)
    dsets["test"] = ImageList(txt_test, root=root, transform=image_test())
    dset_loaders["test"] = DataLoader(dsets["test"], batch_size=train_bs * 2,
                                      shuffle=False, num_workers=args.worker, drop_last=False)

    return dset_loaders


def cal_acc(loader, netF, netB, netC, flag=False):
    """Evaluate netC(netB(netF(x))) over `loader`.

    netB may be None (plain backbone+classifier source model). Returns
    (accuracy%, mean_entropy) or, with flag=True, (mean per-class acc,
    per-class acc string, mean_entropy) — the VisDA-style report.
    """
    outputs_list = []
    labels_list = []
    with torch.no_grad():
        # BUG fix: the original called `iter_test.next()`, which no longer
        # exists on modern DataLoader iterators; iterate the loader directly.
        for data in loader:
            inputs = data[0].cuda()
            labels = data[1]
            if netB is None:
                outputs = netC(netF(inputs))
            else:
                outputs = netC(netB(netF(inputs)))
            outputs_list.append(outputs.float().cpu())
            labels_list.append(labels.float())
    all_output = torch.cat(outputs_list, 0)
    all_label = torch.cat(labels_list, 0)

    all_output = nn.Softmax(dim=1)(all_output)
    _, predict = torch.max(all_output, 1)
    accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0])
    # NOTE(review): entropy is normalized by log(#samples), not log(#classes);
    # kept as-is so logged numbers stay comparable with prior runs.
    mean_ent = torch.mean(loss.Entropy(all_output)).cpu().data.item() / np.log(all_label.size()[0])

    if flag:
        matrix = confusion_matrix(all_label, torch.squeeze(predict).float())
        matrix = matrix[np.unique(all_label).astype(int), :]
        acc = matrix.diagonal() / matrix.sum(axis=1) * 100
        aacc = acc.mean()
        aa = [str(np.round(i, 2)) for i in acc]
        acc = ' '.join(aa)
        return aacc, acc, mean_ent
    else:
        return accuracy * 100, mean_ent
def train_source_simp(args):
    """Train the simple source model (netF backbone + netC head) with label smoothing.

    Evaluates on the held-out source split every max_iter//10 iterations, keeps
    the best checkpoint, and writes it to args.output_dir_src.
    Returns (netF, netC) in their final (last-iteration) state.
    """
    dset_loaders = data_load(args)
    if args.net_src[0:3] == 'res':
        netF = network.ResBase(res_name=args.net_src).cuda()
    netC = network.feat_classifier_simpl(class_num=args.class_num, feat_dim=netF.in_features).cuda()

    # Backbone gets 10x smaller LR than the classifier head.
    param_group = []
    learning_rate = args.lr_src
    for k, v in netF.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate * 0.1}]
    for k, v in netC.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate}]
    optimizer = optim.SGD(param_group)
    optimizer = op_copy(optimizer)

    acc_init = 0
    max_iter = args.max_epoch * len(dset_loaders["source_tr"])
    interval_iter = max_iter // 10
    iter_num = 0

    netF.train()
    netC.train()

    iter_source = iter(dset_loaders["source_tr"])
    while iter_num < max_iter:
        # BUG fix: `.next()` was removed from DataLoader iterators, and the
        # original bare `except:` could mask real errors. Restart the epoch
        # explicitly on StopIteration.
        try:
            inputs_source, labels_source = next(iter_source)
        except StopIteration:
            iter_source = iter(dset_loaders["source_tr"])
            inputs_source, labels_source = next(iter_source)

        if inputs_source.size(0) == 1:
            continue  # BatchNorm cannot run on a batch of one

        iter_num += 1
        lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter)

        inputs_source, labels_source = inputs_source.cuda(), labels_source.cuda()
        outputs_source = netC(netF(inputs_source))
        classifier_loss = CrossEntropyLabelSmooth(num_classes=args.class_num, epsilon=0.1)(outputs_source, labels_source)

        optimizer.zero_grad()
        classifier_loss.backward()
        optimizer.step()

        if iter_num % interval_iter == 0 or iter_num == max_iter:
            netF.eval()
            netC.eval()
            acc_s_te, _ = cal_acc(dset_loaders['source_te'], netF, None, netC, False)
            log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(args.s, iter_num, max_iter, acc_s_te)
            if args.dset == 'visda-2017':
                acc_s_te, acc_list, _ = cal_acc(dset_loaders['source_te'], netF, None, netC, True)
                log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(args.s, iter_num, max_iter,
                                                                            acc_s_te) + '\n' + acc_list
            logging.info(log_str)

            if acc_s_te >= acc_init:
                acc_init = acc_s_te
                # BUG fix: state_dict() returns references to the live tensors,
                # so the "best" snapshot silently tracked the *final* weights.
                # Clone the tensors to freeze the snapshot at this iteration.
                best_netF = {k: v.detach().clone() for k, v in netF.state_dict().items()}
                best_netC = {k: v.detach().clone() for k, v in netC.state_dict().items()}

            netF.train()
            netC.train()

    torch.save(best_netF, osp.join(args.output_dir_src, '{}_{}_source_F.pt'.format(args.s, args.net_src)))
    torch.save(best_netC, osp.join(args.output_dir_src, '{}_{}_source_C.pt'.format(args.s, args.net_src)))

    return netF, netC


def test_target_simp(args):
    """Evaluate the saved source model on the current target domain's test list."""
    dset_loaders = data_load(args)
    if args.net_src[0:3] == 'res':
        netF = network.ResBase(res_name=args.net_src).cuda()
    netC = network.feat_classifier_simpl(class_num=args.class_num, feat_dim=netF.in_features).cuda()

    args.modelpath = args.output_dir_src + '/{}_{}_source_F.pt'.format(args.s, args.net_src)
    netF.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir_src + '/{}_{}_source_C.pt'.format(args.s, args.net_src)
    netC.load_state_dict(torch.load(args.modelpath))
    netF.eval()
    netC.eval()

    acc, _ = cal_acc(dset_loaders['test'], netF, None, netC, False)
    log_str = '\nTask: {}->{}, Accuracy = {:.2f}%'.format(args.s, args.t, acc)
    if args.dset == 'visda-2017':
        acc_s_te, acc_list, _ = cal_acc(dset_loaders['test'], netF, None, netC, True)
        log_str = 'Task: {}, Accuracy = {:.2f}%'.format(args.s, acc_s_te) + '\n' + acc_list

    logging.info(log_str)
def copy_target_simp(args):
    """DINE distillation step.

    Loads the frozen source model, memorizes its top-k-smoothed softmax over the
    whole target set (mem_P), then trains a fresh target model (netF+netB+netC)
    with: KL distillation towards mem_P, entropy minimization + diversity (IM),
    and an interpolation-consistency (mixup) term. mem_P is refreshed by EMA
    with the student's own predictions every interval. Saves the final weights.
    """
    dset_loaders = data_load(args)
    if args.net_src[0:3] == 'res':
        netF = network.ResBase(res_name=args.net_src).cuda()
    netC = network.feat_classifier_simpl(class_num=args.class_num, feat_dim=netF.in_features).cuda()

    args.modelpath = args.output_dir_src + '/{}_{}_source_F.pt'.format(args.s, args.net_src)
    netF.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir_src + '/{}_{}_source_C.pt'.format(args.s, args.net_src)
    netC.load_state_dict(torch.load(args.modelpath))
    source_model = nn.Sequential(netF, netC).cuda()
    source_model.eval()

    # Fresh (ImageNet-pretrained) student model; rebinds netF/netC.
    if args.net[0:3] == 'res':
        netF = network.ResBase(res_name=args.net, pretrain=True).cuda()
    netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features, bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer, class_num=args.class_num, bottleneck_dim=args.bottleneck).cuda()

    param_group = []
    learning_rate = args.lr
    for k, v in netF.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate * 0.1}]
    for k, v in netB.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate}]
    for k, v in netC.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate}]
    optimizer = optim.SGD(param_group)
    optimizer = op_copy(optimizer)

    ent_best = 1.0
    max_iter = args.max_epoch * len(dset_loaders["target"])
    interval_iter = max_iter // 10
    iter_num = 0

    model = nn.Sequential(netF, netB, netC).cuda()
    model.eval()

    # Teacher predictions over the whole (unshuffled) target set, with all but
    # the top-k probabilities flattened to a uniform floor.
    probs_list = []
    with torch.no_grad():
        # BUG fix: `.next()` replaced with direct iteration (see cal_acc).
        for data in dset_loaders["target_te"]:
            inputs = data[0].cuda()
            outputs = nn.Softmax(dim=1)(source_model(inputs))
            _, src_idx = torch.sort(outputs, 1, descending=True)
            if args.topk > 0:
                topk = np.min([args.topk, args.class_num])
                for j in range(outputs.size()[0]):
                    outputs[j, src_idx[j, topk:]] = (1.0 - outputs[j, src_idx[j, :topk]].sum()) / (outputs.size()[1] - topk)
            probs_list.append(outputs.float())
    mem_P = torch.cat(probs_list, 0).detach()

    model.train()
    iter_target = iter(dset_loaders["target"])
    while iter_num < max_iter:

        # Periodic EMA refresh of the teacher with the student's predictions.
        if args.ema < 1.0 and iter_num > 0 and iter_num % interval_iter == 0:
            model.eval()
            refreshed = []
            with torch.no_grad():
                for data in dset_loaders["target_te"]:
                    inputs = data[0].cuda()
                    refreshed.append(nn.Softmax(dim=1)(model(inputs)).float())
            mem_P = mem_P * args.ema + torch.cat(refreshed, 0).detach() * (1 - args.ema)
            model.train()

        # BUG fix: iterator initialized before the loop; the original relied on
        # a bare `except:` catching the NameError of an undefined iterator.
        try:
            inputs_target, y, tar_idx = next(iter_target)
        except StopIteration:
            iter_target = iter(dset_loaders["target"])
            inputs_target, y, tar_idx = next(iter_target)

        if inputs_target.size(0) == 1:
            continue  # BatchNorm cannot run on a batch of one

        iter_num += 1
        lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter, power=1.5)
        inputs_target = inputs_target.cuda()
        with torch.no_grad():
            outputs_target_by_source = mem_P[tar_idx, :]
            _, src_idx = torch.sort(outputs_target_by_source, 1, descending=True)
        outputs_target = model(inputs_target)
        outputs_target = torch.nn.Softmax(dim=1)(outputs_target)
        # Distillation: student softmax -> memorized teacher distribution.
        classifier_loss = nn.KLDivLoss(reduction='batchmean')(outputs_target.log(), outputs_target_by_source)
        optimizer.zero_grad()

        # IM term: minimize per-sample entropy, maximize marginal entropy.
        entropy_loss = torch.mean(loss.Entropy(outputs_target))
        msoftmax = outputs_target.mean(dim=0)
        gentropy_loss = torch.sum(-msoftmax * torch.log(msoftmax + 1e-5))
        entropy_loss -= gentropy_loss
        classifier_loss += entropy_loss

        classifier_loss.backward()

        if args.mix > 0:
            # Interpolation consistency; its gradients accumulate on top of the
            # losses above before the single optimizer.step() below.
            alpha = 0.3
            lam = np.random.beta(alpha, alpha)
            index = torch.randperm(inputs_target.size()[0]).cuda()
            mixed_input = lam * inputs_target + (1 - lam) * inputs_target[index, :]
            mixed_output = (lam * outputs_target + (1 - lam) * outputs_target[index, :]).detach()

            # Freeze BN running stats while forwarding the mixed batch.
            update_batch_stats(model, False)
            outputs_target_m = model(mixed_input)
            update_batch_stats(model, True)
            outputs_target_m = torch.nn.Softmax(dim=1)(outputs_target_m)
            classifier_loss = args.mix * nn.KLDivLoss(reduction='batchmean')(outputs_target_m.log(), mixed_output)
            classifier_loss.backward()
        optimizer.step()

        if iter_num % interval_iter == 0 or iter_num == max_iter:
            model.eval()
            acc_s_te, mean_ent = cal_acc(dset_loaders['test'], netF, netB, netC, False)
            log_str = 'Task: {}->{}, Iter:{}/{}; Accuracy = {:.2f}%, Ent = {:.4f}'.format(args.s, args.t, iter_num, max_iter, acc_s_te, mean_ent)
            if args.dset == 'visda-2017':
                acc_s_te, acc_list, mean_ent = cal_acc(dset_loaders['test'], netF, netB, netC, True)
                log_str = 'Task: {}->{}, Iter:{}/{}; Accuracy = {:.2f}%, Ent = {:.4f}'.format(args.s, args.t, iter_num, max_iter,
                                                                                             acc_s_te, mean_ent) + '\n' + acc_list

            logging.info(log_str)
            model.train()

    torch.save(netF.state_dict(), osp.join(args.output_dir, "{}_{}_{}_{}_target_F".format(args.timestamp, args.s, args.t, args.net) + ".pt"))
    torch.save(netB.state_dict(), osp.join(args.output_dir, "{}_{}_{}_{}_target_B".format(args.timestamp, args.s, args.t, args.net) + ".pt"))
    torch.save(netC.state_dict(), osp.join(args.output_dir, "{}_{}_{}_{}_target_C".format(args.timestamp, args.s, args.t, args.net) + ".pt"))


def update_batch_stats(model, flag):
    """Toggle an `update_batch_stats` attribute on every BatchNorm2d in `model`.

    NOTE(review): stock nn.BatchNorm2d ignores this attribute; it only has an
    effect if `network` installs BN layers that check it — confirm there.
    """
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.update_batch_stats = flag


def print_args(args):
    """Render a namespace as a banner plus one 'key:value' line per attribute."""
    s = "==========================================\n"
    for arg, content in args.__dict__.items():
        s += "{}:{}\n".format(arg, content)
    return s
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='DINE')
    parser.add_argument('--gpu_id', type=str, nargs='?', default='0', help="device id to run")
    parser.add_argument('--s', type=str, default=None, help="source")
    parser.add_argument('--t', type=str, default=None, help="target")
    parser.add_argument('--max_epoch', type=int, default=20, help="max iterations")
    parser.add_argument('--batch_size', type=int, default=64, help="batch_size")
    parser.add_argument('--worker', type=int, default=4, help="number of workers")
    parser.add_argument('--dset', type=str, default='office-home', choices=['visda-2017', 'office31', 'image-clef', 'office-home', 'office-caltech'])
    parser.add_argument('--lr', type=float, default=1e-2, help="learning rate")
    parser.add_argument('--net', type=str, default='resnet50', help="alexnet, vgg16, resnet18, resnet34, resnet50, resnet101")
    parser.add_argument('--output', type=str, default='san')
    parser.add_argument('--lr_src', type=float, default=1e-2, help="learning rate")
    parser.add_argument('--net_src', type=str, default='resnet50', help="alexnet, vgg16, resnet18, resnet34, resnet50, resnet101")
    parser.add_argument('--output_src', type=str, default='san')

    parser.add_argument('--seed', type=int, default=2020, help="random seed")
    parser.add_argument('--bottleneck', type=int, default=256)
    parser.add_argument('--layer', type=str, default="wn", choices=["linear", "wn"])
    parser.add_argument('--classifier', type=str, default="bn", choices=["ori", "bn"])
    parser.add_argument('--da', type=str, default='uda', choices=['uda', 'pda'])
    parser.add_argument('--topk', type=int, default=1)

    parser.add_argument('--distill', action='store_true')
    parser.add_argument('--ema', type=float, default=0.6)
    parser.add_argument('--mix', type=float, default=1.0)

    parser.add_argument('--timestamp', default=timestamp, type=str, help='timestamp')
    parser.add_argument('--use_file_logger', default='True', type=lambda x: bool(distutils.util.strtobool(x)),
                        help='whether use file logger')
    # NOTE(review): type=list would split a command-line string into characters;
    # the value is unconditionally overwritten per-dataset below, so it is kept
    # only for CLI compatibility.
    parser.add_argument('--names', default=[], type=list, help='names of tasks')

    parser.add_argument('--method', type=str, default="dine")

    args = parser.parse_args()
    # Dataset-specific task lists and class counts.
    if args.dset == 'office-home':
        args.names = ['Art', 'Clipart', 'Product', 'Real_World']
        args.class_num = 65
    if args.dset == 'visda-2017':
        args.names = ['train', 'validation']
        args.class_num = 12
    if args.dset == 'office31':
        args.names = ['amazon', 'dslr', 'webcam']
        args.class_num = 31

    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
    resetRNGseed(args.seed)

    # Renamed from `dir`, which shadowed the builtin.
    if not args.distill:
        log_dir = "{}_{}_{}_{}_source".format(args.timestamp, args.s, args.da, args.method)
    else:
        log_dir = "{}_{}_{}_{}".format(args.timestamp, args.s, args.da, args.method)
    if args.use_file_logger:
        init_logger(log_dir, True, '../logs/DINE/{}/'.format(args.method))
    logging.info("{}:{}".format(get_hostname(), get_pid()))

    folder = '../data/'
    args.s_dset_path = folder + args.dset + '/image_list/' + args.s + '.txt'
    # Target/test paths initially point at the source list; they are reset for
    # each target task inside the loops below.
    args.t_dset_path = folder + args.dset + '/image_list/' + args.s + '.txt'
    args.test_dset_path = folder + args.dset + '/image_list/' + args.s + '.txt'

    if args.dset == 'office-home':
        if args.da == 'pda':
            args.class_num = 65
            args.src_classes = [i for i in range(65)]
            args.tar_classes = [33, 32, 36, 15, 19, 2, 46, 49, 48, 53, 47, 54, 4, 18, 57, 23, 0, 45, 1, 38, 5, 13, 50, 11, 58]

    args.output_dir_src = "../checkpoints/DINE/{}/source/{}/".format(args.seed, args.da)

    # Portable, race-free replacement for the original os.system('mkdir -p ...').
    os.makedirs(args.output_dir_src, exist_ok=True)

    if not args.distill:
        logging.info(print_args(args))
        train_source_simp(args)

        for t in args.names:
            if t == args.s:
                continue
            args.t = t
            args.t_dset_path = folder + args.dset + '/image_list/' + args.t + '.txt'
            args.test_dset_path = folder + args.dset + '/image_list/' + args.t + '.txt'

            test_target_simp(args)

    if args.distill:
        for t in args.names:
            if t == args.s:
                continue
            args.t = t

            args.output_dir = "../checkpoints/DINE/{}/target/{}/".format(args.seed, args.da)
            os.makedirs(args.output_dir, exist_ok=True)

            args.t_dset_path = folder + args.dset + '/image_list/' + args.t + '.txt'
            args.test_dset_path = folder + args.dset + '/image_list/' + args.t + '.txt'

            logging.info(print_args(args))

            copy_target_simp(args)

# --------------------------------------------------------------------------
# DINE/DINE_dist_kDINE.py — DINE distillation with prior knowledge (kDINE).
# --------------------------------------------------------------------------
import argparse
import os.path as osp
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
import network
import loss
from torch.utils.data import DataLoader
from data_list import ImageList, ImageList_idx
from loss import CrossEntropyLabelSmooth
from scipy.spatial.distance import cdist
from sklearn.metrics import confusion_matrix
import distutils
import distutils.util
import logging

import sys, os
sys.path.append("../util/")
from utils import resetRNGseed, init_logger, get_hostname, get_pid
sys.path.append("../pklib")
from pksolver import PK_solver

import time
timestamp = time.strftime("%Y-%m-%d_%H.%M.%S", time.localtime())

# Reproducibility: deterministic cuDNN kernels, no autotuning.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def op_copy(optimizer):
    """Stash every param group's initial learning rate under 'lr0' for lr_scheduler."""
    for param_group in optimizer.param_groups:
        param_group['lr0'] = param_group['lr']
    return optimizer


def lr_scheduler(optimizer, iter_num, max_iter, gamma=10, power=0.75):
    """Polynomial ("inv") LR decay: lr = lr0 * (1 + gamma*t/T)^(-power).

    Also (re)applies weight_decay/momentum/nesterov to every group each call.
    """
    decay = (1 + gamma * iter_num / max_iter) ** (-power)
    for param_group in optimizer.param_groups:
        param_group['lr'] = param_group['lr0'] * decay
        param_group['weight_decay'] = 1e-3
        param_group['momentum'] = 0.9
        param_group['nesterov'] = True
    return optimizer


def image_train(resize_size=256, crop_size=224, alexnet=False):
    """Training transform: resize -> random crop -> random h-flip -> normalize."""
    if alexnet:
        # BUG fix: the original referenced an undefined `Normalize(meanfile=...)`
        # helper here, which raised NameError at call time. Fail fast instead.
        raise NotImplementedError("alexnet mean-file normalization is not available")
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    return transforms.Compose([
        transforms.Resize((resize_size, resize_size)),
        transforms.RandomCrop(crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize
    ])


def image_test(resize_size=256, crop_size=224, alexnet=False):
    """Evaluation transform: resize -> center crop -> normalize (ImageNet stats)."""
    if alexnet:
        # BUG fix: see image_train — `Normalize(meanfile=...)` was undefined.
        raise NotImplementedError("alexnet mean-file normalization is not available")
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    return transforms.Compose([
        transforms.Resize((resize_size, resize_size)),
        transforms.CenterCrop(crop_size),
        transforms.ToTensor(),
        normalize
    ])


def data_load(args):
    """Build all DataLoaders (kDINE variant; identical splits to DINE_dist.py).

    First 3 images of every source class form the validation split; for
    args.da != 'uda' target labels are remapped onto the source label space.
    """
    dsets = {}
    dset_loaders = {}
    train_bs = args.batch_size
    txt_src = open(args.s_dset_path).readlines()
    txt_tar = open(args.t_dset_path).readlines()
    txt_test = open(args.test_dset_path).readlines()

    count = np.zeros(args.class_num)
    tr_txt = []
    te_txt = []
    for line in txt_src:
        cls = int(line.strip().split(' ')[1])
        if count[cls] < 3:
            count[cls] += 1
            te_txt.append(line)
        else:
            tr_txt.append(line)

    if not args.da == 'uda':
        label_map_s = {c: i for i, c in enumerate(args.src_classes)}
        new_tar = []
        for rec in txt_tar:
            reci = rec.strip().split(' ')
            cls = int(reci[1])
            if cls in args.tar_classes:
                if cls in args.src_classes:
                    new_tar.append(reci[0] + ' ' + str(label_map_s[cls]) + '\n')
                else:
                    # classes unseen in the source collapse to one extra index
                    new_tar.append(reci[0] + ' ' + str(len(label_map_s)) + '\n')
        txt_tar = new_tar.copy()
        txt_test = txt_tar.copy()

    root = "../data/{}/".format(args.dset)
    dsets["source_tr"] = ImageList(tr_txt, root=root, transform=image_train())
    dset_loaders["source_tr"] = DataLoader(dsets["source_tr"], batch_size=train_bs,
                                           shuffle=True, num_workers=args.worker, drop_last=False)
    dsets["source_te"] = ImageList(te_txt, root=root, transform=image_test())
    dset_loaders["source_te"] = DataLoader(dsets["source_te"], batch_size=train_bs,
                                           shuffle=True, num_workers=args.worker, drop_last=False)
    dsets["target"] = ImageList_idx(txt_tar, root=root, transform=image_train())
    dset_loaders["target"] = DataLoader(dsets["target"], batch_size=train_bs,
                                        shuffle=True, num_workers=args.worker, drop_last=False)
    # Unshuffled eval view; order must match dsets["target"]'s sample indices.
    dsets["target_te"] = ImageList(txt_tar, root=root, transform=image_test())
    dset_loaders["target_te"] = DataLoader(dsets["target_te"], batch_size=train_bs,
                                           shuffle=False, num_workers=args.worker, drop_last=False)
    dsets["test"] = ImageList(txt_test, root=root, transform=image_test())
    dset_loaders["test"] = DataLoader(dsets["test"], batch_size=train_bs * 2,
                                      shuffle=False, num_workers=args.worker, drop_last=False)

    return dset_loaders
def cal_acc(loader, netF, netB, netC, flag=False):
    """Evaluate netC(netB(netF(x))) over `loader` (kDINE variant).

    netB may be None. Returns (accuracy%, mean_entropy) or, with flag=True,
    (mean per-class acc, per-class acc string, mean_entropy).
    """
    outputs_list = []
    labels_list = []
    with torch.no_grad():
        # BUG fix: `iter_test.next()` no longer exists on modern DataLoader
        # iterators; iterate the loader directly.
        for data in loader:
            inputs = data[0].cuda()
            labels = data[1]
            if netB is None:
                outputs = netC(netF(inputs))
            else:
                outputs = netC(netB(netF(inputs)))
            outputs_list.append(outputs.float().cpu())
            labels_list.append(labels.float())
    all_output = torch.cat(outputs_list, 0)
    all_label = torch.cat(labels_list, 0)

    all_output = nn.Softmax(dim=1)(all_output)
    _, predict = torch.max(all_output, 1)
    accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0])
    # NOTE(review): entropy normalized by log(#samples), kept as in the original.
    mean_ent = torch.mean(loss.Entropy(all_output)).cpu().data.item() / np.log(all_label.size()[0])

    if flag:
        matrix = confusion_matrix(all_label, torch.squeeze(predict).float())
        matrix = matrix[np.unique(all_label).astype(int), :]
        acc = matrix.diagonal() / matrix.sum(axis=1) * 100
        aacc = acc.mean()
        aa = [str(np.round(i, 2)) for i in acc]
        acc = ' '.join(aa)
        return aacc, acc, mean_ent
    else:
        return accuracy * 100, mean_ent


def train_source_simp(args):
    """Train the simple source model with label smoothing (kDINE variant).

    Keeps the best checkpoint on the held-out source split and writes it to
    args.output_dir_src. Returns (netF, netC) in their final state.
    """
    dset_loaders = data_load(args)
    if args.net_src[0:3] == 'res':
        netF = network.ResBase(res_name=args.net_src).cuda()
    netC = network.feat_classifier_simpl(class_num=args.class_num, feat_dim=netF.in_features).cuda()

    # Backbone gets 10x smaller LR than the classifier head.
    param_group = []
    learning_rate = args.lr_src
    for k, v in netF.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate * 0.1}]
    for k, v in netC.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate}]
    optimizer = optim.SGD(param_group)
    optimizer = op_copy(optimizer)

    acc_init = 0
    max_iter = args.max_epoch * len(dset_loaders["source_tr"])
    interval_iter = max_iter // 10
    iter_num = 0

    netF.train()
    netC.train()

    iter_source = iter(dset_loaders["source_tr"])
    while iter_num < max_iter:
        # BUG fix: `.next()` -> next(); bare `except:` -> StopIteration only.
        try:
            inputs_source, labels_source = next(iter_source)
        except StopIteration:
            iter_source = iter(dset_loaders["source_tr"])
            inputs_source, labels_source = next(iter_source)

        if inputs_source.size(0) == 1:
            continue  # BatchNorm cannot run on a batch of one

        iter_num += 1
        lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter)

        inputs_source, labels_source = inputs_source.cuda(), labels_source.cuda()
        outputs_source = netC(netF(inputs_source))
        classifier_loss = CrossEntropyLabelSmooth(num_classes=args.class_num, epsilon=0.1)(outputs_source, labels_source)

        optimizer.zero_grad()
        classifier_loss.backward()
        optimizer.step()

        if iter_num % interval_iter == 0 or iter_num == max_iter:
            netF.eval()
            netC.eval()
            acc_s_te, _ = cal_acc(dset_loaders['source_te'], netF, None, netC, False)
            log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(args.s, iter_num, max_iter, acc_s_te)
            if args.dset == 'visda-2017':
                acc_s_te, acc_list, _ = cal_acc(dset_loaders['source_te'], netF, None, netC, True)
                log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(args.s, iter_num, max_iter,
                                                                            acc_s_te) + '\n' + acc_list
            logging.info(log_str)

            if acc_s_te >= acc_init:
                acc_init = acc_s_te
                # BUG fix: state_dict() returns references to live tensors, so
                # the "best" snapshot tracked the final weights; clone to freeze.
                best_netF = {k: v.detach().clone() for k, v in netF.state_dict().items()}
                best_netC = {k: v.detach().clone() for k, v in netC.state_dict().items()}

            netF.train()
            netC.train()

    torch.save(best_netF, osp.join(args.output_dir_src, '{}_{}_source_F.pt'.format(args.s, args.net_src)))
    torch.save(best_netC, osp.join(args.output_dir_src, '{}_{}_source_C.pt'.format(args.s, args.net_src)))

    return netF, netC
def test_target_simp(args):
    """Report the saved source model's accuracy on the current target test list."""
    loaders = data_load(args)

    # Rebuild the source architecture and restore its trained weights.
    if args.net_src[0:3] == 'res':
        netF = network.ResBase(res_name=args.net_src).cuda()
    netC = network.feat_classifier_simpl(class_num=args.class_num, feat_dim=netF.in_features).cuda()

    args.modelpath = args.output_dir_src + '/{}_{}_source_F.pt'.format(args.s, args.net_src)
    netF.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir_src + '/{}_{}_source_C.pt'.format(args.s, args.net_src)
    netC.load_state_dict(torch.load(args.modelpath))
    netF.eval()
    netC.eval()

    acc, _ = cal_acc(loaders['test'], netF, None, netC, False)
    log_str = '\nTask: {}->{}, Accuracy = {:.2f}%'.format(args.s, args.t, acc)
    if args.dset == 'visda-2017':
        # VisDA reporting convention: mean per-class accuracy plus the breakdown.
        acc_s_te, acc_list, _ = cal_acc(loaders['test'], netF, None, netC, True)
        log_str = 'Task: {}, Accuracy = {:.2f}%'.format(args.s, acc_s_te) + '\n' + acc_list

    logging.info(log_str)
netB.named_parameters(): 274 | param_group += [{'params': v, 'lr': learning_rate}] 275 | for k, v in netC.named_parameters(): 276 | param_group += [{'params': v, 'lr': learning_rate}] 277 | optimizer = optim.SGD(param_group) 278 | optimizer = op_copy(optimizer) 279 | 280 | ent_best = 1.0 281 | max_iter = args.max_epoch * len(dset_loaders["target"]) 282 | interval_iter = max_iter // 10 283 | iter_num = 0 284 | 285 | model = nn.Sequential(netF, netB, netC).cuda() 286 | model.eval() 287 | 288 | start_test = True 289 | with torch.no_grad(): 290 | iter_test = iter(dset_loaders["target_te"]) 291 | for i in range(len(dset_loaders["target_te"])): 292 | data = iter_test.next() 293 | inputs, labels = data[0], data[1] 294 | inputs = inputs.cuda() 295 | outputs = source_model(inputs) 296 | outputs = nn.Softmax(dim=1)(outputs) 297 | _, src_idx = torch.sort(outputs, 1, descending=True) 298 | if args.topk > 0: 299 | topk = np.min([args.topk, args.class_num]) 300 | for i in range(outputs.size()[0]): 301 | outputs[i, src_idx[i, topk:]] = (1.0 - outputs[i, src_idx[i, :topk]].sum())/ (outputs.size()[1] - topk) 302 | 303 | if start_test: 304 | all_output = outputs.float() 305 | all_label = labels 306 | start_test = False 307 | else: 308 | all_output = torch.cat((all_output, outputs.float()), 0) 309 | all_label = torch.cat((all_label, labels), 0) 310 | mem_P = all_output.detach() 311 | 312 | # get ground-truth label probabilities of target domain 313 | cls_probs = torch.eye(args.class_num)[all_label].sum(0) 314 | cls_probs = cls_probs / cls_probs.sum() 315 | 316 | pk_solver = PK_solver(all_label.shape[0], args.class_num, pk_prior_weight=args.pk_prior_weight) 317 | if args.pk_type == 'ub': 318 | pk_solver.create_C_ub(cls_probs.cpu().numpy(), args.pk_uconf) 319 | elif args.pk_type == 'br': 320 | pk_solver.create_C_br(cls_probs.cpu().numpy(), args.pk_uconf) 321 | else: 322 | raise NotImplementedError 323 | 324 | mem_label = obtain_label(mem_P.cpu(), all_label.cpu(), None, args, pk_solver) 
325 | mem_label = torch.from_numpy(mem_label).cuda() 326 | mem_label = torch.eye(args.class_num)[mem_label].cuda() 327 | 328 | model.train() 329 | while iter_num < max_iter: 330 | 331 | if args.ema < 1.0 and iter_num > 0 and iter_num % interval_iter == 0: 332 | model.eval() 333 | start_test = True 334 | with torch.no_grad(): 335 | iter_test = iter(dset_loaders["target_te"]) 336 | for i in range(len(dset_loaders["target_te"])): 337 | data = iter_test.next() 338 | inputs = data[0] 339 | inputs = inputs.cuda() 340 | outputs = model(inputs) 341 | feas = model[1](model[0](inputs)) 342 | outputs = nn.Softmax(dim=1)(outputs) 343 | if start_test: 344 | all_fea = feas.float().cpu() 345 | all_output = outputs.float() 346 | start_test = False 347 | else: 348 | all_fea = torch.cat((all_fea, feas.float().cpu()), 0) 349 | all_output = torch.cat((all_output, outputs.float()), 0) 350 | mem_P = mem_P * args.ema + all_output.detach() * (1 - args.ema) 351 | model.train() 352 | 353 | mem_label = obtain_label(mem_P.cpu(), all_label.cpu(), all_fea, args, pk_solver) 354 | mem_label = torch.from_numpy(mem_label).cuda() 355 | mem_label = torch.eye(args.class_num)[mem_label].cuda() 356 | 357 | try: 358 | inputs_target, y, tar_idx = iter_target.next() 359 | except: 360 | iter_target = iter(dset_loaders["target"]) 361 | inputs_target, y, tar_idx = iter_target.next() 362 | 363 | if inputs_target.size(0) == 1: 364 | continue 365 | 366 | iter_num += 1 367 | lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter, power=1.5) 368 | inputs_target = inputs_target.cuda() 369 | with torch.no_grad(): 370 | outputs_target_by_source = mem_P[tar_idx, :] 371 | _, src_idx = torch.sort(outputs_target_by_source, 1, descending=True) 372 | outputs_target = model(inputs_target) 373 | outputs_target = torch.nn.Softmax(dim=1)(outputs_target) 374 | 375 | target = (outputs_target_by_source + mem_label[tar_idx, :]*0.9 + 1/mem_label.shape[-1]*0.1) / 2 376 | if iter_num < interval_iter and args.dset == 
"visda-2017": 377 | target = outputs_target_by_source 378 | 379 | classifier_loss = nn.KLDivLoss(reduction='batchmean')(outputs_target.log(), target) 380 | optimizer.zero_grad() 381 | 382 | entropy_loss = torch.mean(loss.Entropy(outputs_target)) 383 | msoftmax = outputs_target.mean(dim=0) 384 | gentropy_loss = torch.sum(- msoftmax * torch.log(msoftmax + 1e-5)) 385 | entropy_loss -= gentropy_loss 386 | classifier_loss += entropy_loss 387 | 388 | classifier_loss.backward() 389 | 390 | if args.mix > 0: 391 | alpha = 0.3 392 | lam = np.random.beta(alpha, alpha) 393 | index = torch.randperm(inputs_target.size()[0]).cuda() 394 | mixed_input = lam * inputs_target + (1 - lam) * inputs_target[index, :] 395 | mixed_output = (lam * outputs_target + (1 - lam) * outputs_target[index, :]).detach() 396 | 397 | update_batch_stats(model, False) 398 | outputs_target_m = model(mixed_input) 399 | update_batch_stats(model, True) 400 | outputs_target_m = torch.nn.Softmax(dim=1)(outputs_target_m) 401 | classifier_loss = args.mix*nn.KLDivLoss(reduction='batchmean')(outputs_target_m.log(), mixed_output) 402 | classifier_loss.backward() 403 | optimizer.step() 404 | 405 | if iter_num % interval_iter == 0 or iter_num == max_iter: 406 | model.eval() 407 | acc_s_te, mean_ent = cal_acc(dset_loaders['test'], netF, netB, netC, False) 408 | log_str = 'Task: {}->{}, Iter:{}/{}; Accuracy = {:.2f}%, Ent = {:.4f}'.format(args.s, args.t, iter_num, max_iter, acc_s_te, mean_ent) 409 | if args.dset == 'visda-2017': 410 | acc_s_te, acc_list, mean_ent = cal_acc(dset_loaders['test'], netF, netB, netC, True) 411 | log_str = 'Task: {}->{}, Iter:{}/{}; Accuracy = {:.2f}%, Ent = {:.4f}'.format(args.s, args.t, iter_num, max_iter, 412 | acc_s_te, mean_ent) + '\n' + acc_list 413 | logging.info(log_str) 414 | model.train() 415 | 416 | torch.save(netF.state_dict(), osp.join(args.output_dir, "{}_{}_{}_{}_target_F".format(args.timestamp, args.s, args.t, args.net) + ".pt")) 417 | torch.save(netB.state_dict(), 
def obtain_label(mem_P, all_label, all_fea, args, pk_solver):
    """Refine accumulated source predictions into pseudo-labels via the
    prior-knowledge (PK) solver, optionally smoothed with a kNN constraint.

    mem_P: (N, C) accumulated soft predictions (CPU tensor).
    all_label: (N,) ground-truth labels, used only for accuracy logging.
    all_fea: (N, D) features or None; when given and args.pk_knn > 0, samples
        where the PK solution disagrees with the raw argmax are constrained
        to agree with their nearest confident neighbours.
    Returns an int numpy array of pseudo-labels.
    """
    predict = mem_P.argmax(-1)
    accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0])
    matrix = confusion_matrix(all_label, torch.squeeze(predict).float())
    avg_accuracy = (matrix.diagonal() / matrix.sum(axis=1)).mean()

    # update labels with prior knowledge:
    # first solve without the smooth (kNN) regularisation
    probs = mem_P
    pred_label_PK = pk_solver.solve_soft(probs)

    acc_PK = np.sum(pred_label_PK == all_label.float().numpy()) / float(all_label.size()[0])
    matrix_PK = confusion_matrix(all_label.float().numpy(), pred_label_PK)
    avg_acc_PK = (matrix_PK.diagonal() / matrix_PK.sum(axis=1)).mean()
    log_str = 'PK Accuracy = {:.2f}% -> {:.2f}% Per_class_accuracy = {:.2f}% -> {:.2f}%'.format(accuracy * 100, acc_PK * 100, avg_accuracy * 100, avg_acc_PK * 100)
    logging.info(log_str)

    if args.pk_knn > 0 and all_fea is not None:
        # now solve with smooth regularization
        predict = predict.cpu().numpy()
        # append a bias dimension, then L2-normalise each feature row
        all_fea = torch.cat((all_fea, torch.ones(all_fea.size(0), 1)), 1)
        all_fea = (all_fea.t() / torch.norm(all_fea, p=2, dim=1)).t()
        all_fea = all_fea.float().cpu().numpy()

        # "unconfident": PK label disagrees with the raw argmax
        idx_unconf = np.where(pred_label_PK != predict)[0]
        knn_sample_idx = idx_unconf
        idx_conf = np.where(pred_label_PK == predict)[0]

        if len(idx_unconf) > 0 and len(idx_conf) > 0:
            # K nearest confident neighbours of every unconfident sample
            dd_knn = cdist(all_fea[idx_unconf], all_fea[idx_conf], args.distance)
            knn_idx = []
            K = args.pk_knn
            for row in range(dd_knn.shape[0]):
                ind = np.argpartition(dd_knn[row], K)[:K]
                knn_idx.append(idx_conf[ind])

            knn_idx = np.stack(knn_idx, axis=0)
            knn_regs = list(zip(knn_sample_idx, knn_idx))
            pred_label_PK = pk_solver.solve_soft_knn_cst(probs, knn_regs=knn_regs)

        acc_PK = np.sum(pred_label_PK == all_label.float().numpy()) / len(all_fea)
        matrix_PK = confusion_matrix(all_label.float().numpy(), pred_label_PK)
        avg_acc_PK = (matrix_PK.diagonal() / matrix_PK.sum(axis=1)).mean()
        if args.da == 'pda':
            # per-class accuracy is not meaningful under partial DA here
            avg_acc_PK = 0
        log_str = 'PK Accuracy = {:.2f}% -> {:.2f}% Per_class_accuracy = {:.2f}% -> {:.2f}%'.format(accuracy * 100, acc_PK * 100, avg_accuracy * 100, avg_acc_PK * 100)
        logging.info(log_str)

    return pred_label_PK.astype('int')
def update_batch_stats(model, flag):
    """Toggle the custom `update_batch_stats` attribute on every BatchNorm2d
    submodule of `model` (read by the project's BN variant during mixup)."""
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.update_batch_stats = flag

def print_args(args):
    """Render every attribute of the parsed argparse namespace as 'name:value' lines."""
    s = "==========================================\n"
    for arg, content in args.__dict__.items():
        s += "{}:{}\n".format(arg, content)
    return s

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='DINE')
    parser.add_argument('--gpu_id', type=str, nargs='?', default='0', help="device id to run")
    parser.add_argument('--s', type=str, default=None, help="source")
    parser.add_argument('--t', type=str, default=None, help="target")
    parser.add_argument('--max_epoch', type=int, default=20, help="max iterations")
    parser.add_argument('--batch_size', type=int, default=64, help="batch_size")
    parser.add_argument('--worker', type=int, default=4, help="number of workers")
    parser.add_argument('--dset', type=str, default='office-home', choices=['visda-2017', 'office31', 'image-clef', 'office-home', 'office-caltech'])
    parser.add_argument('--lr', type=float, default=1e-2, help="learning rate")
    parser.add_argument('--net', type=str, default='resnet50', help="alexnet, vgg16, resnet18, resnet34, resnet50, resnet101")
    parser.add_argument('--output', type=str, default='san')
    parser.add_argument('--lr_src', type=float, default=1e-2, help="learning rate")
    parser.add_argument('--net_src', type=str, default='resnet50', help="alexnet, vgg16, resnet18, resnet34, resnet50, resnet101")
    parser.add_argument('--output_src', type=str, default='san')

    parser.add_argument('--seed', type=int, default=2020, help="random seed")
    parser.add_argument('--bottleneck', type=int, default=256)
    parser.add_argument('--layer', type=str, default="wn", choices=["linear", "wn"])
    parser.add_argument('--classifier', type=str, default="bn", choices=["ori", "bn"])
    parser.add_argument('--da', type=str, default='uda', choices=['uda', 'pda'])
    parser.add_argument('--topk', type=int, default=1)

    parser.add_argument('--distill', action='store_true')
    parser.add_argument('--ema', type=float, default=0.6)
    parser.add_argument('--mix', type=float, default=1.0)

    parser.add_argument('--timestamp', default=timestamp, type=str, help='timestamp')
    parser.add_argument('--use_file_logger', default='True', type=lambda x: bool(distutils.util.strtobool(x)),
                        help='whether use file logger')
    parser.add_argument('--names', default=[], type=list, help='names of tasks')

    parser.add_argument('--cls_par', type=float, default=0.3)
    parser.add_argument('--distance', type=str, default='cosine', choices=["euclidean", "cosine"])

    # prior-knowledge (kDINE) options
    parser.add_argument('--pk_uconf', type=float, default=0.0)
    parser.add_argument('--pk_type', type=str, default="ub")
    parser.add_argument('--pk_allow', type=int, default=None)
    parser.add_argument('--pk_temp', type=float, default=1.0)
    parser.add_argument('--pk_prior_weight', type=float, default=10.)
    parser.add_argument('--pk_knn', type=int, default=1)
    parser.add_argument('--method', type=str, default="kdine")

    args = parser.parse_args()

    # dataset-specific task names / class counts
    if args.dset == 'office-home':
        args.names = ['Art', 'Clipart', 'Product', 'Real_World']
        args.class_num = 65
    if args.dset == 'visda-2017':
        args.names = ['train', 'validation']
        args.class_num = 12
    if args.dset == 'office31':
        args.names = ['amazon', 'dslr', 'webcam']
        args.class_num = 31

    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
    resetRNGseed(args.seed)

    # (renamed from `dir`, which shadowed the builtin)
    if not args.distill:
        log_dir = "{}_{}_{}_{}_source".format(args.timestamp, args.s, args.da, args.method)
    else:
        log_dir = "{}_{}_{}_{}".format(args.timestamp, args.s, args.da, args.method)
    if args.use_file_logger:
        init_logger(log_dir, True, '../logs/DINE/{}/'.format(args.method))
    logging.info("{}:{}".format(get_hostname(), get_pid()))

    folder = '../data/'
    # NOTE(review): t/test paths start out pointing at the SOURCE list (args.s);
    # they are re-pointed at each target inside the loops below — confirm intended.
    args.s_dset_path = folder + args.dset + '/image_list/' + args.s + '.txt'
    args.t_dset_path = folder + args.dset + '/image_list/' + args.s + '.txt'
    args.test_dset_path = folder + args.dset + '/image_list/' + args.s + '.txt'

    if args.dset == 'office-home':
        if args.da == 'pda':
            args.class_num = 65
            args.src_classes = [i for i in range(65)]
            args.tar_classes = [33, 32, 36, 15, 19, 2, 46, 49, 48, 53, 47, 54, 4, 18, 57, 23, 0, 45, 1, 38, 5, 13, 50, 11, 58]

    args.output_dir_src = "../checkpoints/DINE/{}/source/{}/".format(args.seed, args.da)

    if not osp.exists(args.output_dir_src):
        os.system('mkdir -p ' + args.output_dir_src)
    if not osp.exists(args.output_dir_src):
        os.mkdir(args.output_dir_src)

    if not args.distill:
        # phase 1: train on the source domain, then evaluate on every target
        logging.info(print_args(args))
        train_source_simp(args)

        for t in args.names:
            if t == args.s:
                continue
            args.t = t
            args.t_dset_path = folder + args.dset + '/image_list/' + args.t + '.txt'
            args.test_dset_path = folder + args.dset + '/image_list/' + args.t + '.txt'

            test_target_simp(args)

    if args.distill:
        # phase 2: distill the source model into a target model per task
        for t in args.names:
            if t == args.s:
                continue
            args.t = t
            args.output_dir = "../checkpoints/DINE/{}/target/{}/".format(args.seed, args.da)
            if not osp.exists(args.output_dir):
                os.system('mkdir -p ' + args.output_dir)
            if not osp.exists(args.output_dir):
                os.mkdir(args.output_dir)

            args.t_dset_path = folder + args.dset + '/image_list/' + args.t + '.txt'
            args.test_dset_path = folder + args.dset + '/image_list/' + args.t + '.txt'

            logging.info(print_args(args))

            copy_target_simp(args)
def lr_scheduler(optimizer, iter_num, max_iter, gamma=10, power=0.75):
    """Polynomial learning-rate decay; also (re)pins weight decay / momentum.

    Each param group's lr is set to lr0 * (1 + gamma * t / T) ** (-power),
    where lr0 was stashed by op_copy. Returns the optimizer for chaining.
    """
    # Fix: the base was `(11 + ...)`, which disagrees with the identical
    # scheduler in DINE_dist.py and with the reference line kept below, and
    # made lr != lr0 at iter_num == 0.
    decay = (1 + gamma * iter_num / max_iter) ** (-power)
    # decay = (1 + gamma) ** (-power)
    for param_group in optimizer.param_groups:
        param_group['lr'] = param_group['lr0'] * decay
        param_group['weight_decay'] = 1e-3
        param_group['momentum'] = 0.9
        param_group['nesterov'] = True
    return optimizer

def image_train(resize_size=256, crop_size=224, alexnet=False):
    """Training-time transform: resize, random crop, horizontal flip, normalise."""
    if not alexnet:
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
    else:
        # NOTE(review): `Normalize` is not defined anywhere in this file, so the
        # alexnet branch would raise NameError — confirm against upstream SHOT.
        normalize = Normalize(meanfile='./ilsvrc_2012_mean.npy')
    return transforms.Compose([
        transforms.Resize((resize_size, resize_size)),
        transforms.RandomCrop(crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize
    ])

def image_test(resize_size=256, crop_size=224, alexnet=False):
    """Evaluation transform: resize, deterministic center crop, normalise."""
    if not alexnet:
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
    else:
        # NOTE(review): same undefined `Normalize` as in image_train.
        normalize = Normalize(meanfile='./ilsvrc_2012_mean.npy')
    return transforms.Compose([
        transforms.Resize((resize_size, resize_size)),
        transforms.CenterCrop(crop_size),
        transforms.ToTensor(),
        normalize
    ])
def data_load(args):
    """Build target-domain dataloaders: 'target' (train aug, with indices),
    'test' (eval aug, with indices) and 'target_te' (eval aug, no indices).

    Under partial DA (args.da != 'uda') target labels are remapped to the
    compact source-class range, with target-only classes mapped to one extra
    "unknown" index.
    """
    dsets = {}
    dset_loaders = {}
    train_bs = args.batch_size
    # Fix: the original leaked both file handles via open(...).readlines().
    with open(args.t_dset_path) as f:
        txt_tar = f.readlines()
    with open(args.test_dset_path) as f:
        txt_test = f.readlines()

    if not args.da == 'uda':
        label_map_s = {}
        for i in range(len(args.src_classes)):
            label_map_s[args.src_classes[i]] = i

        new_tar = []
        for i in range(len(txt_tar)):
            rec = txt_tar[i]
            reci = rec.strip().split(' ')
            if int(reci[1]) in args.tar_classes:
                if int(reci[1]) in args.src_classes:
                    line = reci[0] + ' ' + str(label_map_s[int(reci[1])]) + '\n'
                    new_tar.append(line)
                else:
                    # class present in target but not source -> "unknown" bucket
                    line = reci[0] + ' ' + str(len(label_map_s)) + '\n'
                    new_tar.append(line)
        txt_tar = new_tar.copy()
        txt_test = txt_tar.copy()

    dsets["target"] = ImageList_idx(txt_tar, root="../data/{}/".format(args.dset), transform=image_train())
    dset_loaders["target"] = DataLoader(dsets["target"], batch_size=train_bs, shuffle=True, num_workers=args.worker, drop_last=False)
    dsets["test"] = ImageList_idx(txt_test, root="../data/{}/".format(args.dset), transform=image_test())
    dset_loaders["test"] = DataLoader(dsets["test"], batch_size=train_bs*3, shuffle=False, num_workers=args.worker, drop_last=False)
    dsets["target_te"] = ImageList(txt_tar, root="../data/{}/".format(args.dset), transform=image_test())
    dset_loaders["target_te"] = DataLoader(dsets["target_te"], batch_size=train_bs, shuffle=True, num_workers=args.worker, drop_last=False)

    return dset_loaders

def cal_acc(loader, netF, netB, netC, flag=False):
    """Evaluate netC(netB(netF(x))) over `loader`.

    Returns (mean per-class acc, per-class string, predictions, mean entropy)
    when flag is True, else (accuracy%, mean entropy, predictions, mean entropy).
    """
    start_test = True
    with torch.no_grad():
        # Fix: iterate the loader directly; `iter_test.next()` is the
        # Python-2 / legacy-PyTorch spelling and raises AttributeError on
        # Python 3 iterators.
        for data in loader:
            inputs = data[0]
            labels = data[1]
            inputs = inputs.cuda()
            outputs = netC(netB(netF(inputs)))
            if start_test:
                all_output = outputs.float().cpu()
                all_label = labels.float()
                start_test = False
            else:
                all_output = torch.cat((all_output, outputs.float().cpu()), 0)
                all_label = torch.cat((all_label, labels.float()), 0)
    all_output = nn.Softmax(dim=1)(all_output)
    _, predict = torch.max(all_output, 1)
    accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0])
    # NOTE(review): entropy is normalised by log(dataset size), not log(class
    # count) — kept exactly as the original; confirm it is intentional.
    mean_ent = torch.mean(loss.Entropy(all_output)).cpu().data.item() / np.log(all_label.size()[0])

    if flag:
        matrix = confusion_matrix(all_label, torch.squeeze(predict).float())
        matrix = matrix[np.unique(all_label).astype(int), :]
        acc = matrix.diagonal()/matrix.sum(axis=1) * 100
        aacc = acc.mean()
        aa = [str(np.round(i, 2)) for i in acc]
        acc = ' '.join(aa)
        return aacc, acc, predict, mean_ent
    else:
        return accuracy*100, mean_ent, predict, mean_ent
def train_target(args):
    """Fine-tune the distilled target model with information maximisation.

    Loads the netF/netB/netC checkpoints written by the distillation stage,
    then minimises per-sample entropy minus batch-diversity entropy, stopping
    early once the predictions stop changing between evaluation intervals.
    Returns the fine-tuned (netF, netB, netC).

    Fix vs. original: `iter_test.next()` replaced with `next(iter_test)` and
    the bare `except:` replaced by an explicit pre-loop iterator plus
    `except StopIteration`.
    """
    dset_loaders = data_load(args)
    if args.net[0:3] == 'res':
        netF = network.ResBase(res_name=args.net).cuda()

    netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features, bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer, class_num=args.class_num, bottleneck_dim=args.bottleneck).cuda()

    modelpath = osp.join(args.output_dir, "{}_{}_{}_{}_target_F".format(args.timestamp, args.s, args.t, args.net) + ".pt")
    netF.load_state_dict(torch.load(modelpath))
    modelpath = osp.join(args.output_dir, "{}_{}_{}_{}_target_B".format(args.timestamp, args.s, args.t, args.net) + ".pt")
    netB.load_state_dict(torch.load(modelpath))
    modelpath = osp.join(args.output_dir, "{}_{}_{}_{}_target_C".format(args.timestamp, args.s, args.t, args.net) + ".pt")
    netC.load_state_dict(torch.load(modelpath))

    param_group = []
    for k, v in netF.named_parameters():
        param_group += [{'params': v, 'lr': args.lr*0.1}]  # backbone at 1/10 lr
    for k, v in netB.named_parameters():
        param_group += [{'params': v, 'lr': args.lr}]
    for k, v in netC.named_parameters():
        param_group += [{'params': v, 'lr': args.lr}]

    optimizer = optim.SGD(param_group)
    optimizer = op_copy(optimizer)

    max_iter = args.max_epoch * len(dset_loaders["target"])
    interval_iter = max_iter // 10
    iter_num = 0

    # baseline evaluation before any fine-tuning
    netF.eval()
    netB.eval()
    netC.eval()
    acc_s_te, _, pry, mean_ent = cal_acc(dset_loaders['test'], netF, netB, netC, False)
    log_str = 'Task: {}->{}, Iter:{}/{}; Accuracy={:.2f}%, Ent={:.3f}'.format(args.s, args.t, iter_num, max_iter, acc_s_te, mean_ent)
    if args.dset == 'visda-2017':
        acc_s_te, acc_list, pry, mean_ent = cal_acc(dset_loaders['test'], netF, netB, netC, True)
        log_str = 'Task: {}->{}, Iter:{}/{}; Accuracy = {:.2f}%, Ent = {:.4f}'.format(args.s, args.t, iter_num, max_iter, acc_s_te,
                                                                                     mean_ent) + '\n' + acc_list

    logging.info(log_str)
    netF.train()
    netB.train()
    netC.train()

    old_pry = 0
    iter_test = iter(dset_loaders["target"])
    while iter_num < max_iter:
        optimizer.zero_grad()
        try:
            inputs_test, _, tar_idx = next(iter_test)
        except StopIteration:
            iter_test = iter(dset_loaders["target"])
            inputs_test, _, tar_idx = next(iter_test)

        if inputs_test.size(0) == 1:
            continue  # BatchNorm cannot train on a batch of size 1

        inputs_test = inputs_test.cuda()

        iter_num += 1
        lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter, power=0.75)

        features_test = netB(netF(inputs_test))
        outputs_test = netC(features_test)

        # information maximisation: low per-sample entropy, high batch diversity
        softmax_out = nn.Softmax(dim=1)(outputs_test)
        entropy_loss = torch.mean(loss.Entropy(softmax_out))

        msoftmax = softmax_out.mean(dim=0)
        gentropy_loss = -torch.sum(msoftmax * torch.log(msoftmax + 1e-5))
        entropy_loss -= gentropy_loss
        entropy_loss.backward()
        optimizer.step()

        if iter_num % interval_iter == 0 or iter_num == max_iter:
            netF.eval()
            netB.eval()
            netC.eval()
            acc_s_te, _, pry, mean_ent = cal_acc(dset_loaders['test'], netF, netB, netC, False)
            log_str = 'Task: {}->{}, Iter:{}/{}; Accuracy={:.2f}%, Ent={:.3f}'.format(args.s, args.t, iter_num, max_iter, acc_s_te, mean_ent)
            if args.dset == 'visda-2017':
                acc_s_te, acc_list, pry, mean_ent = cal_acc(dset_loaders['test'], netF, netB, netC, True)
                log_str = 'Task: {}->{}, Iter:{}/{}; Accuracy = {:.2f}%, Ent = {:.4f}'.format(args.s, args.t, iter_num, max_iter, acc_s_te, mean_ent) + '\n' + acc_list
            logging.info(log_str)

            netF.train()
            netB.train()
            netC.train()

            # early stop: predictions unchanged since the last interval
            if torch.abs(pry - old_pry).sum() == 0:
                break
            else:
                old_pry = pry.clone()

    return netF, netB, netC

def print_args(args):
    """Render every attribute of the parsed argparse namespace as 'name:value' lines."""
    s = "==========================================\n"
    for arg, content in args.__dict__.items():
        s += "{}:{}\n".format(arg, content)
    return s

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='DINE')
    parser.add_argument('--gpu_id', type=str, nargs='?', default='0', help="device id to run")
    parser.add_argument('--s', type=str, default=None, help="source")
    parser.add_argument('--t', type=str, default=None, help="target")
    parser.add_argument('--max_epoch', type=int, default=30, help="max iterations")
    parser.add_argument('--batch_size', type=int, default=64, help="batch_size")
    parser.add_argument('--worker', type=int, default=4, help="number of workers")
    parser.add_argument('--dset', type=str, default='office-home', choices=['visda-2017', 'office31', 'image-clef', 'office-home', 'office-caltech'])
    parser.add_argument('--lr', type=float, default=1e-2, help="learning rate")
    parser.add_argument('--net', type=str, default='resnet50', help="alexnet, vgg16, resnet18, resnet50, resnext50")
    parser.add_argument('--net_src', type=str, default='resnet50', help="alexnet, vgg16, resnet18, resnet34, resnet50, resnet101")
    parser.add_argument('--seed', type=int, default=2020, help="random seed")

    parser.add_argument('--bottleneck', type=int, default=256)
    parser.add_argument('--layer', type=str, default="wn", choices=["linear", "wn"])
    parser.add_argument('--classifier', type=str, default="bn", choices=["ori", "bn"])
    parser.add_argument('--output', type=str, default='san')
    parser.add_argument('--da', type=str, default='uda', choices=['uda', 'pda'])

    parser.add_argument('--timestamp', default=timestamp, type=str, help='timestamp')
    parser.add_argument('--use_file_logger', default='True', type=lambda x: bool(distutils.util.strtobool(x)),
                        help='whether use file logger')
    parser.add_argument('--names', default=[], type=list, help='names of tasks')
    parser.add_argument('--method', type=str, default=None)

    args = parser.parse_args()
    if args.dset == 'office-home':
        args.names = ['Art', 'Clipart', 'Product', 'Real_World']
        args.class_num = 65
    if args.dset == 'visda-2017':
        args.names = ['train', 'validation']
        args.class_num = 12
    if args.dset == 'office31':
        args.names = ['amazon', 'dslr', 'webcam']
        args.class_num = 31

    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
    resetRNGseed(args.seed)

    if args.dset == 'office-home':
        if args.da == 'pda':
            args.class_num = 65
            args.src_classes = [i for i in range(65)]
            args.tar_classes = [33, 32, 36, 15, 19, 2, 46, 49, 48, 53, 47, 54, 4, 18, 57, 23, 0, 45, 1, 38, 5, 13, 50, 11, 58]

    # (renamed from `dir`, which shadowed the builtin)
    if args.method is not None:
        log_dir = "{}_{}_{}_{}".format(args.timestamp, args.s, args.da, args.method)
        if args.use_file_logger:
            init_logger(log_dir, True, '../logs/DINE/{}/'.format(args.method))
    else:
        log_dir = "{}_{}_{}".format(args.timestamp, args.s, args.da)
        if args.use_file_logger:
            init_logger(log_dir, True, '../logs/DINE/')
    logging.info("{}:{}".format(get_hostname(), get_pid()))

    folder = '../data/'
    for t in args.names:
        if t == args.s:
            continue
        args.t = t
        args.s_dset_path = folder + args.dset + '/image_list/' + args.s + '.txt'
        args.t_dset_path = folder + args.dset + '/image_list/' + args.t + '.txt'
        args.test_dset_path = folder + args.dset + '/image_list/' + args.t + '.txt'

        args.output_dir = "../checkpoints/DINE/{}/target/{}/".format(args.seed, args.da)

        if not osp.exists(args.output_dir):
            os.system('mkdir -p ' + args.output_dir)
        if not osp.exists(args.output_dir):
            os.mkdir(args.output_dir)

        logging.info(print_args(args))

        train_target(args)
def make_dataset(image_list, labels):
    """Turn text records into a list of (path, target) samples.

    image_list: lines of the form "<relpath> <label>" or
        "<relpath> <l1> <l2> ..." (multi-label).
    labels: optional explicit label array aligned with image_list; the
        labels[i, :] indexing implies a 2-D numpy array.
    Returns a list of (path, int) or (path, ndarray) tuples.
    """
    if len(image_list) == 0:
        # Guard: callers check for an empty result and raise a readable
        # RuntimeError; without this, image_list[0] below would IndexError.
        return []
    if labels is not None:
        # Fix: was `if labels:`, which raises "truth value of an array ...
        # is ambiguous" for any multi-element numpy label array.
        len_ = len(image_list)
        images = [(image_list[i].strip(), labels[i, :]) for i in range(len_)]
    elif len(image_list[0].split()) > 2:
        # multi-label record: "path l1 l2 ..." -> (path, int vector)
        images = [(val.split()[0], np.array([int(la) for la in val.split()[1:]])) for val in image_list]
    else:
        images = [(val.split()[0], int(val.split()[1])) for val in image_list]
    return images


def rgb_loader(path):
    """Open an image file and return it as a 3-channel PIL RGB image."""
    with open(path, 'rb') as f:
        with Image.open(f) as img:
            return img.convert('RGB')

def l_loader(path):
    """Open an image file and return it as a single-channel (grayscale) PIL image."""
    with open(path, 'rb') as f:
        with Image.open(f) as img:
            return img.convert('L')
class ImageList(Dataset):
    """Dataset over text records resolved against `root`.

    Each item is (transformed image, target). The image is read with the
    RGB or grayscale loader depending on `mode`.
    """

    def __init__(self, image_list, root, labels=None, transform=None, target_transform=None, mode='RGB'):
        imgs = make_dataset(image_list, labels)
        if len(imgs) == 0:
            raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n"
                               "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))

        self.root = root
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform
        if mode == 'RGB':
            self.loader = rgb_loader
        elif mode == 'L':
            self.loader = l_loader

    def __getitem__(self, index):
        path, target = self.imgs[index]
        path = os.path.join(self.root, path)
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        return len(self.imgs)

class ImageList_idx(ImageList):
    """ImageList variant whose items additionally carry the sample index
    (used by the adaptation code to address per-sample memory banks).

    Deduplicated: the original repeated ImageList's entire body verbatim;
    the constructor signature and runtime behavior are unchanged.
    """

    def __getitem__(self, index):
        img, target = super(ImageList_idx, self).__getitem__(index)
        return img, target, index

def Entropy(input_):
    """Row-wise Shannon entropy of a batch of probability vectors.

    input_: (N, C) tensor of probabilities. Returns an (N,) tensor.
    A small epsilon keeps log() finite for zero entries.
    """
    epsilon = 1e-5
    entropy = -input_ * torch.log(input_ + epsilon)
    return torch.sum(entropy, dim=1)
class CrossEntropyLabelSmooth(nn.Module):
    """Cross entropy loss with label smoothing regularizer.
    Reference:
        Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016.
    Equation: y = (1 - epsilon) * y + epsilon / K.
    Args:
        num_classes (int): number of classes.
        epsilon (float): smoothing weight.
        use_gpu (bool): move the smoothed target onto CUDA before use.
        reduction (bool): if True return the batch mean, else per-sample losses.
    """

    def __init__(self, num_classes, epsilon=0.1, use_gpu=True, reduction=True):
        super(CrossEntropyLabelSmooth, self).__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.use_gpu = use_gpu
        self.reduction = reduction
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, targets):
        """
        Args:
            inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)
            targets: ground truth labels with shape (batch_size,)
        """
        log_probs = self.logsoftmax(inputs)
        # one-hot encode on CPU (scatter_ requires int64 indices), then smooth
        targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).cpu(), 1)
        if self.use_gpu: targets = targets.cuda()
        targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
        loss = (- targets * log_probs).sum(dim=1)
        # Fix: an unreachable duplicated `return loss` after the if/else removed.
        if self.reduction:
            return loss.mean()
        else:
            return loss

def init_weights(m):
    """Layer-type-dispatched weight init: Kaiming for conv, N(1, 0.02) for
    batch norm, Xavier for linear; biases zeroed.

    NOTE(review): assumes every matched layer has a bias term; a bias-free
    Conv2d/Linear would make nn.init.zeros_(m.bias) fail — confirm callers.
    """
    classname = m.__class__.__name__
    if classname.find('Conv2d') != -1 or classname.find('ConvTranspose2d') != -1:
        nn.init.kaiming_uniform_(m.weight)
        nn.init.zeros_(m.bias)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight, 1.0, 0.02)
        nn.init.zeros_(m.bias)
    elif classname.find('Linear') != -1:
        nn.init.xavier_normal_(m.weight)
        nn.init.zeros_(m.bias)
res_dict = {"resnet18":models.resnet18, "resnet34":models.resnet34, "resnet50":models.resnet50, 25 | "resnet101":models.resnet101, "resnet152":models.resnet152, "resnext50":models.resnext50_32x4d, "resnext101":models.resnext101_32x8d} 26 | 27 | class ResBase(nn.Module): 28 | def __init__(self, res_name, pretrain=True): 29 | super(ResBase, self).__init__() 30 | model_resnet = res_dict[res_name](pretrained=pretrain) 31 | self.conv1 = model_resnet.conv1 32 | self.bn1 = model_resnet.bn1 33 | self.relu = model_resnet.relu 34 | self.maxpool = model_resnet.maxpool 35 | self.layer1 = model_resnet.layer1 36 | self.layer2 = model_resnet.layer2 37 | self.layer3 = model_resnet.layer3 38 | self.layer4 = model_resnet.layer4 39 | self.avgpool = model_resnet.avgpool 40 | self.in_features = model_resnet.fc.in_features 41 | 42 | def forward(self, x): 43 | x = self.conv1(x) 44 | x = self.bn1(x) 45 | x = self.relu(x) 46 | x = self.maxpool(x) 47 | x = self.layer1(x) 48 | x = self.layer2(x) 49 | x = self.layer3(x) 50 | x = self.layer4(x) 51 | x = self.avgpool(x) 52 | x = x.view(x.size(0), -1) 53 | return x 54 | 55 | class feat_bootleneck(nn.Module): 56 | def __init__(self, feature_dim, bottleneck_dim=256, type="ori"): 57 | super(feat_bootleneck, self).__init__() 58 | self.bn = nn.BatchNorm1d(bottleneck_dim, affine=True) 59 | self.relu = nn.ReLU(inplace=True) 60 | self.dropout = nn.Dropout(p=0.5) 61 | self.bottleneck = nn.Linear(feature_dim, bottleneck_dim) 62 | self.bottleneck.apply(init_weights) 63 | self.type = type 64 | 65 | def forward(self, x): 66 | x = self.bottleneck(x) 67 | if self.type == "bn" or self.type == "bn_relu" or self.type == "bn_relu_drop": 68 | x = self.bn(x) 69 | if self.type == "bn_relu" or self.type == "bn_relu_drop": 70 | x = self.relu(x) 71 | if self.type == "bn_relu_drop": 72 | x = self.dropout(x) 73 | return x 74 | 75 | class feat_classifier(nn.Module): 76 | def __init__(self, class_num, bottleneck_dim=256, type="linear"): 77 | super(feat_classifier, 
self).__init__() 78 | self.type = type 79 | if type == 'wn': 80 | self.fc = weightNorm(nn.Linear(bottleneck_dim, class_num), name="weight") 81 | self.fc.apply(init_weights) 82 | elif type == 'linear': 83 | self.fc = nn.Linear(bottleneck_dim, class_num) 84 | self.fc.apply(init_weights) 85 | else: 86 | self.fc = nn.Linear(bottleneck_dim, class_num, bias=False) 87 | nn.init.xavier_normal_(self.fc.weight) 88 | 89 | def forward(self, x): 90 | if not self.type in {'wn', 'linear'}: 91 | w = self.fc.weight 92 | w = torch.nn.functional.normalize(w, dim=1, p=2) 93 | 94 | x = torch.nn.functional.normalize(x, dim=1, p=2) 95 | x = torch.nn.functional.linear(x, w) 96 | else: 97 | x = self.fc(x) 98 | return x 99 | 100 | class feat_classifier_simpl(nn.Module): 101 | def __init__(self, class_num, feat_dim): 102 | super(feat_classifier_simpl, self).__init__() 103 | self.fc = nn.Linear(feat_dim, class_num) 104 | nn.init.xavier_normal_(self.fc.weight) 105 | 106 | def forward(self, x): 107 | x = self.fc(x) 108 | return x -------------------------------------------------------------------------------- /DINE/run_all_kDINE.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | gpu_id=0 4 | time=`python ../util/get_time.py` 5 | 6 | # office31 ------------------------------------------------------------------------------------------------------------- 7 | for seed in 2020 2021 2022; do 8 | for src in 'webcam' 'amazon' 'dslr' ; do 9 | echo $src 10 | python DINE_dist.py --gpu_id $gpu_id --seed $seed --dset office31 --s $src --da uda --net_src resnet50 --max_epoch 50 --timestamp $time 11 | for pk_uconf in 0.0 0.1 0.5 1.0 2.0; do 12 | python DINE_dist_kDINE.py --gpu_id $gpu_id --seed $seed --dset office31 --s $src --da uda --net_src resnet50 --max_epoch 30 --net resnet50 --distill --topk 1 --timestamp $time --pk_type ub --pk_uconf $pk_uconf 13 | python DINE_ft.py --gpu_id $gpu_id --seed $seed --dset office31 --s $src --da uda --net_src 
resnet50 --max_epoch 30 --net resnet50 --lr 1e-2 --timestamp $time --method kdine 14 | done 15 | 16 | python DINE_dist_kDINE.py --gpu_id $gpu_id --seed $seed --dset office31 --s $src --da uda --net_src resnet50 --max_epoch 30 --net resnet50 --distill --topk 1 --timestamp $time --pk_type br --pk_uconf 1.0 17 | python DINE_ft.py --gpu_id $gpu_id --seed $seed --dset office31 --s $src --da uda --net_src resnet50 --max_epoch 30 --net resnet50 --lr 1e-2 --timestamp $time --method kdine 18 | done 19 | done 20 | 21 | 22 | # office-home ---------------------------------------------------------------------------------------------------------- 23 | for seed in 2020 2021 2022; do 24 | for src in 'Product' 'Real_World' 'Art' 'Clipart' ; do 25 | echo $src 26 | python DINE_dist.py --gpu_id $gpu_id --seed $seed --dset office-home --s $src --da uda --net_src resnet50 --max_epoch 50 --timestamp $time 27 | for pk_uconf in 0.0 0.1 0.5 1.0 2.0; do 28 | python DINE_dist_kDINE.py --gpu_id $gpu_id --seed $seed --dset office-home --s $src --da uda --net_src resnet50 --max_epoch 30 --net resnet50 --distill --topk 1 --timestamp $time --pk_type ub --pk_uconf $pk_uconf 29 | python DINE_ft.py --gpu_id $gpu_id --seed $seed --dset office-home --s $src --da uda --net_src resnet50 --max_epoch 30 --net resnet50 --lr 1e-2 --timestamp $time --method kdine 30 | done 31 | 32 | python DINE_dist_kDINE.py --gpu_id $gpu_id --seed $seed --dset office-home --s $src --da uda --net_src resnet50 --max_epoch 30 --net resnet50 --distill --topk 1 --timestamp $time --pk_type br --pk_uconf 1.0 33 | python DINE_ft.py --gpu_id $gpu_id --seed $seed --dset office-home --s $src --da uda --net_src resnet50 --max_epoch 30 --net resnet50 --lr 1e-2 --timestamp $time --method kdine 34 | done 35 | done 36 | 37 | # office-home (PDA)----------------------------------------------------------------------------------------------------- 38 | for seed in 2020 2021 2022; do 39 | for src in 'Product' 'Real_World' 'Art' 'Clipart' ; do 40 
| echo $src 41 | python DINE_dist.py --gpu_id $gpu_id --seed $seed --dset office-home --s $src --da pda --net_src resnet50 --max_epoch 50 --timestamp $time 42 | 43 | python DINE_dist_kDINE.py --gpu_id $gpu_id --seed $seed --dset office-home --s $src --da pda --net_src resnet50 --max_epoch 30 --net resnet50 --distill --topk 1 --timestamp $time --pk_type ub --pk_uconf 0.0 44 | python DINE_dist_kDINE.py --gpu_id $gpu_id --seed $seed --dset office-home --s $src --da pda --net_src resnet50 --max_epoch 30 --net resnet50 --distill --topk 1 --timestamp $time --pk_type br --pk_uconf 1.0 45 | done 46 | done -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 tsun 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # KUDA 2 | Pytorch implementation of KUDA. 3 | > [Prior Knowledge Guided Unsupervised Domain Adaptation](https://arxiv.org/abs/2207.08877) 4 | > Tao Sun, Cheng Lu, and Haibin Ling 5 | > *ECCV 2022* 6 | 7 | ## Abstract 8 | The waive of labels in the target domain makes Unsupervised Domain Adaptation (UDA) an attractive technique in many real-world applications, though it also brings great challenges as model adaptation becomes harder without labeled target data. In this paper, we address this issue by seeking compensation from target domain prior knowledge, which is often (partially) available in practice, e.g., from human expertise. This leads to a novel yet practical setting where in addition to the training data, some prior knowledge about the target class distribution are available. We term the setting as Knowledge-guided Unsupervised Domain Adaptation (KUDA). In particular, we consider two specific types of prior knowledge about the class distribution in the target domain: Unary Bound that describes the lower and upper bounds of individual class probabilities, and Binary Relationship that describes the relations between two class probabilities. We propose a general rectification module that uses such prior knowledge to refine model generated pseudo labels. The module is formulated as a Zero-One Programming problem derived from the prior knowledge and a smooth regularizer. It can be easily plugged into self-training based UDA methods, and we combine it with two state-of-the-art methods, SHOT and DINE. Empirical results on four benchmarks confirm that the rectification module clearly improves the quality of pseudo labels, which in turn benefits the self-training stage. With the guidance from prior knowledge, the performances of both methods are substantially boosted. 
We expect our work to inspire further investigations in integrating prior knowledge in UDA. 9 | 10 | ### Knowledge-guided Unsupervised Domain Adaptation (KUDA) 11 | 12 | 13 | ### Integrating rectification module into SHOT and DINE 14 | 15 | 16 | ## Usage 17 | ### Prerequisites 18 | 19 | We experimented with python==3.8, pytorch==1.8.0, cudatoolkit==11.1, gurobi==9.5.0. 20 | 21 | For Zero-One programming, we use [Gurobi Optimizer](https://www.gurobi.com/). A free [academic license](https://www.gurobi.com/academia/academic-program-and-licenses/) can be obtained from its official website. 22 | 23 | 24 | ### Data Preparation 25 | Download the [office31](https://faculty.cc.gatech.edu/~judy/domainadapt/), [Office-Home](https://www.hemanthdv.org/officeHomeDataset.html), [VisDA](https://ai.bu.edu/visda-2017/), [DomainNet](http://ai.bu.edu/M3SDA/) datasets. 26 | 27 | Setup dataset path in ./data 28 | ```shell 29 | bash setup_data_path.sh /Path_to_data/office/domain_adaptation_images office31 30 | bash setup_data_path.sh /Path_to_data/office-home/images office-home 31 | bash setup_data_path.sh /Path_to_data/office-home/images office-home-rsut 32 | bash setup_data_path.sh /Path_to_data/VisDA visda 33 | bash setup_data_path.sh /Path_to_data/DomainNet domainnet40 34 | ``` 35 | 36 | ### kSHOT 37 | Unsupervised Closed-set Domain Adaptation (UDA) on the Office-Home dataset 38 | ```shell 39 | cd SHOT 40 | 41 | time=`python ../util/get_time.py` 42 | gpu_id=0 43 | 44 | # generate source models 45 | for src in "Product" "Clipart" "Art" "Real_World"; do 46 | echo $src 47 | python image_source.py --trte val --da uda --gpu_id $gpu_id --dset office-home --max_epoch 50 --s $src --timestamp $time 48 | done 49 | 50 | # adapt to other target domains with Unary Bound prior knowledge 51 | for seed in 2020 2021 2022; do 52 | for src in "Product" "Clipart" "Art" "Real_World"; do 53 | echo $src 54 | python image_target_kSHOT.py --cls_par 0.3 --da uda --gpu_id $gpu_id --dset office-home --s $src 
--timestamp $time --pk_uconf 0.0 --seed $seed --pk_type ub 55 | done 56 | done 57 | ``` 58 | 59 | ### kDINE 60 | Unsupervised Closed-set Domain Adaptation (UDA) on the Office-Home dataset 61 | ```shell 62 | cd DINE 63 | 64 | time=`python ./get_time.py` 65 | gpu=0 66 | 67 | for seed in 2020 2021 2022; do 68 | for src in 'Product' 'Real_World' 'Art' 'Clipart' ; do 69 | echo $src 70 | # training the source model first 71 | python DINE_dist.py --gpu_id $gpu --seed $seed --dset office-home --s $src --da uda --net_src resnet50 --max_epoch 50 --timestamp $time 72 | # the first step (Distill) with Unary Bound prior knowledge 73 | python DINE_dist_kDINE.py --gpu_id $gpu --seed $seed --dset office-home --s $src --da uda --net_src resnet50 --max_epoch 30 --net resnet50 --distill --topk 1 --timestamp $time --pk_type ub --pk_uconf 0.0 74 | # the second step (Finetune) 75 | python DINE_ft.py --gpu_id $gpu --seed $seed --dset office-home --s $src --da uda --net_src resnet50 --max_epoch 30 --net resnet50 --lr 1e-2 --timestamp $time --method kdine 76 | done 77 | done 78 | ``` 79 | Complete commands are available in ./SHOT/run_all_kSHOT.sh and ./DINE/run_all_kDINE.sh. 80 | 81 | ## Acknowledgements 82 | The implementations are adapted from [SHOT](https://github.com/tim-learn/SHOT) and 83 | [DINE](https://github.com/tim-learn/DINE). 
84 | 85 | 86 | ## Citation 87 | If you find our paper and code useful for your research, please consider citing 88 | ```bibtex 89 | @inproceedings{sun2022prior, 90 | author = {Sun, Tao and Lu, Cheng and Ling, Haibin}, 91 | title = {Prior Knowledge Guided Unsupervised Domain Adaptation}, 92 | booktitle = {Proceedings of the European Conference on Computer Vision (ECCV)}, 93 | year = {2022} 94 | } 95 | ``` -------------------------------------------------------------------------------- /SHOT/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsun/KUDA/beeb839456daf5fb5d263783c79bb6cff38e2375/SHOT/__init__.py -------------------------------------------------------------------------------- /SHOT/augmentations.py: -------------------------------------------------------------------------------- 1 | # code in this file is adpated from rpmcruz/autoaugment 2 | # https://github.com/rpmcruz/autoaugment/blob/master/transformations.py 3 | import random 4 | 5 | import PIL, PIL.ImageOps, PIL.ImageEnhance, PIL.ImageDraw 6 | import numpy as np 7 | from PIL import Image 8 | import torch 9 | 10 | def ShearX(img, v): # [-0.3, 0.3] 11 | assert -0.3 <= v <= 0.3 12 | if random.random() > 0.5: 13 | v = -v 14 | return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0)) 15 | 16 | 17 | def ShearY(img, v): # [-0.3, 0.3] 18 | assert -0.3 <= v <= 0.3 19 | if random.random() > 0.5: 20 | v = -v 21 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0)) 22 | 23 | 24 | def TranslateX(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 25 | assert -0.45 <= v <= 0.45 26 | if random.random() > 0.5: 27 | v = -v 28 | v = v * img.size[0] 29 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) 30 | 31 | 32 | def TranslateXabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 33 | assert 0 <= v 34 | if random.random() > 0.5: 35 | v = -v 36 | return img.transform(img.size, 
PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) 37 | 38 | 39 | def TranslateY(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 40 | assert -0.45 <= v <= 0.45 41 | if random.random() > 0.5: 42 | v = -v 43 | v = v * img.size[1] 44 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) 45 | 46 | 47 | def TranslateYabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 48 | assert 0 <= v 49 | if random.random() > 0.5: 50 | v = -v 51 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) 52 | 53 | 54 | def Rotate(img, v): # [-30, 30] 55 | assert -30 <= v <= 30 56 | if random.random() > 0.5: 57 | v = -v 58 | return img.rotate(v) 59 | 60 | 61 | def AutoContrast(img, _): 62 | return PIL.ImageOps.autocontrast(img) 63 | 64 | 65 | def Invert(img, _): 66 | return PIL.ImageOps.invert(img) 67 | 68 | 69 | def Equalize(img, _): 70 | return PIL.ImageOps.equalize(img) 71 | 72 | 73 | def Flip(img, _): # not from the paper 74 | return PIL.ImageOps.mirror(img) 75 | 76 | 77 | def Solarize(img, v): # [0, 256] 78 | assert 0 <= v <= 256 79 | return PIL.ImageOps.solarize(img, v) 80 | 81 | 82 | def SolarizeAdd(img, addition=0, threshold=128): 83 | img_np = np.array(img).astype(np.int) 84 | img_np = img_np + addition 85 | img_np = np.clip(img_np, 0, 255) 86 | img_np = img_np.astype(np.uint8) 87 | img = Image.fromarray(img_np) 88 | return PIL.ImageOps.solarize(img, threshold) 89 | 90 | 91 | def Posterize(img, v): # [4, 8] 92 | v = int(v) 93 | v = max(1, v) 94 | return PIL.ImageOps.posterize(img, v) 95 | 96 | 97 | def Contrast(img, v): # [0.1,1.9] 98 | assert 0.1 <= v <= 1.9 99 | return PIL.ImageEnhance.Contrast(img).enhance(v) 100 | 101 | 102 | def Color(img, v): # [0.1,1.9] 103 | assert 0.1 <= v <= 1.9 104 | return PIL.ImageEnhance.Color(img).enhance(v) 105 | 106 | 107 | def Brightness(img, v): # [0.1,1.9] 108 | assert 0.1 <= v <= 1.9 109 | return PIL.ImageEnhance.Brightness(img).enhance(v) 110 | 111 | 112 | def Sharpness(img, v): # [0.1,1.9] 113 | assert 0.1 <= v <= 
1.9 114 | return PIL.ImageEnhance.Sharpness(img).enhance(v) 115 | 116 | 117 | def Cutout(img, v): # [0, 60] => percentage: [0, 0.2] 118 | assert 0.0 <= v <= 0.2 119 | if v <= 0.: 120 | return img 121 | 122 | v = v * img.size[0] 123 | return CutoutAbs(img, v) 124 | 125 | 126 | def CutoutAbs(img, v): # [0, 60] => percentage: [0, 0.2] 127 | # assert 0 <= v <= 20 128 | if v < 0: 129 | return img 130 | w, h = img.size 131 | x0 = np.random.uniform(w) 132 | y0 = np.random.uniform(h) 133 | 134 | x0 = int(max(0, x0 - v / 2.)) 135 | y0 = int(max(0, y0 - v / 2.)) 136 | x1 = min(w, x0 + v) 137 | y1 = min(h, y0 + v) 138 | 139 | xy = (x0, y0, x1, y1) 140 | color = (125, 123, 114) 141 | # color = (0, 0, 0) 142 | img = img.copy() 143 | PIL.ImageDraw.Draw(img).rectangle(xy, color) 144 | return img 145 | 146 | 147 | def SamplePairing(imgs): # [0, 0.4] 148 | def f(img1, v): 149 | i = np.random.choice(len(imgs)) 150 | img2 = PIL.Image.fromarray(imgs[i]) 151 | return PIL.Image.blend(img1, img2, v) 152 | 153 | return f 154 | 155 | 156 | def Identity(img, v): 157 | return img 158 | 159 | 160 | def augment_list(): # 16 oeprations and their ranges 161 | # https://github.com/google-research/uda/blob/master/image/randaugment/policies.py#L57 162 | # l = [ 163 | # (Identity, 0., 1.0), 164 | # (ShearX, 0., 0.3), # 0 165 | # (ShearY, 0., 0.3), # 1 166 | # (TranslateX, 0., 0.33), # 2 167 | # (TranslateY, 0., 0.33), # 3 168 | # (Rotate, 0, 30), # 4 169 | # (AutoContrast, 0, 1), # 5 170 | # (Invert, 0, 1), # 6 171 | # (Equalize, 0, 1), # 7 172 | # (Solarize, 0, 110), # 8 173 | # (Posterize, 4, 8), # 9 174 | # # (Contrast, 0.1, 1.9), # 10 175 | # (Color, 0.1, 1.9), # 11 176 | # (Brightness, 0.1, 1.9), # 12 177 | # (Sharpness, 0.1, 1.9), # 13 178 | # # (Cutout, 0, 0.2), # 14 179 | # # (SamplePairing(imgs), 0, 0.4), # 15 180 | # ] 181 | 182 | # https://github.com/tensorflow/tpu/blob/8462d083dd89489a79e3200bcc8d4063bf362186/models/official/efficientnet/autoaugment.py#L505 183 | l = [ 184 | 
(AutoContrast, 0, 1), 185 | (Equalize, 0, 1), 186 | (Invert, 0, 1), 187 | (Rotate, 0, 30), 188 | (Posterize, 0, 4), 189 | (Solarize, 0, 256), 190 | (SolarizeAdd, 0, 110), 191 | (Color, 0.1, 1.9), 192 | (Contrast, 0.1, 1.9), 193 | (Brightness, 0.1, 1.9), 194 | (Sharpness, 0.1, 1.9), 195 | (ShearX, 0., 0.3), 196 | (ShearY, 0., 0.3), 197 | (CutoutAbs, 0, 40), 198 | (TranslateXabs, 0., 100), 199 | (TranslateYabs, 0., 100), 200 | ] 201 | 202 | return l 203 | 204 | 205 | class Lighting(object): 206 | """Lighting noise(AlexNet - style PCA - based noise)""" 207 | 208 | def __init__(self, alphastd, eigval, eigvec): 209 | self.alphastd = alphastd 210 | self.eigval = torch.Tensor(eigval) 211 | self.eigvec = torch.Tensor(eigvec) 212 | 213 | def __call__(self, img): 214 | if self.alphastd == 0: 215 | return img 216 | 217 | alpha = img.new().resize_(3).normal_(0, self.alphastd) 218 | rgb = self.eigvec.type_as(img).clone() \ 219 | .mul(alpha.view(1, 3).expand(3, 3)) \ 220 | .mul(self.eigval.view(1, 3).expand(3, 3)) \ 221 | .sum(1).squeeze() 222 | 223 | return img.add(rgb.view(3, 1, 1).expand_as(img)) 224 | 225 | 226 | class CutoutDefault(object): 227 | """ 228 | Reference : https://github.com/quark0/darts/blob/master/cnn/utils.py 229 | """ 230 | def __init__(self, length): 231 | self.length = length 232 | 233 | def __call__(self, img): 234 | h, w = img.size(1), img.size(2) 235 | mask = np.ones((h, w), np.float32) 236 | y = np.random.randint(h) 237 | x = np.random.randint(w) 238 | 239 | y1 = np.clip(y - self.length // 2, 0, h) 240 | y2 = np.clip(y + self.length // 2, 0, h) 241 | x1 = np.clip(x - self.length // 2, 0, w) 242 | x2 = np.clip(x + self.length // 2, 0, w) 243 | 244 | mask[y1: y2, x1: x2] = 0. 
245 | mask = torch.from_numpy(mask) 246 | mask = mask.expand_as(img) 247 | img *= mask 248 | return img 249 | 250 | 251 | class RandAugment: 252 | def __init__(self, n, m): 253 | self.n = n 254 | self.m = m # [0, 30] 255 | self.augment_list = augment_list() 256 | 257 | def __call__(self, img): 258 | 259 | if self.n == 0: 260 | return img 261 | 262 | ops = random.choices(self.augment_list, k=self.n) 263 | for op, minval, maxval in ops: 264 | val = (float(self.m) / 30) * float(maxval - minval) + minval 265 | img = op(img, val) 266 | 267 | return img -------------------------------------------------------------------------------- /SHOT/data_list.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | from torch.utils.data import Dataset 4 | import os 5 | from augmentations import RandAugment 6 | import copy 7 | 8 | def make_dataset(image_list, labels): 9 | if labels: 10 | len_ = len(image_list) 11 | images = [(image_list[i].strip(), labels[i, :]) for i in range(len_)] 12 | else: 13 | if len(image_list[0].split()) > 2: 14 | images = [(val.split()[0], np.array([int(la) for la in val.split()[1:]])) for val in image_list] 15 | else: 16 | images = [(val.split()[0], int(val.split()[1])) for val in image_list] 17 | return images 18 | 19 | 20 | def rgb_loader(path): 21 | with open(path, 'rb') as f: 22 | with Image.open(f) as img: 23 | return img.convert('RGB') 24 | 25 | def l_loader(path): 26 | with open(path, 'rb') as f: 27 | with Image.open(f) as img: 28 | return img.convert('L') 29 | 30 | class ImageList(Dataset): 31 | def __init__(self, image_list, root, labels=None, transform=None, target_transform=None, mode='RGB'): 32 | imgs = make_dataset(image_list, labels) 33 | if len(imgs) == 0: 34 | raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n" 35 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) 36 | 37 | self.root = root 38 | self.imgs = imgs 39 | 
self.transform = transform 40 | self.target_transform = target_transform 41 | if mode == 'RGB': 42 | self.loader = rgb_loader 43 | elif mode == 'L': 44 | self.loader = l_loader 45 | 46 | def __getitem__(self, index): 47 | path, target = self.imgs[index] 48 | path = os.path.join(self.root, path) 49 | img = self.loader(path) 50 | if self.transform is not None: 51 | img = self.transform(img) 52 | if self.target_transform is not None: 53 | target = self.target_transform(target) 54 | 55 | return img, target 56 | 57 | def __len__(self): 58 | return len(self.imgs) 59 | 60 | class ImageList_idx_aug(Dataset): 61 | def __init__(self, image_list, root, labels=None, transform=None, target_transform=None, mode='RGB', 62 | rand_aug_size=0, rand_aug_n=2, rand_aug_m=2.): 63 | imgs = make_dataset(image_list, labels) 64 | if len(imgs) == 0: 65 | raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n" 66 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) 67 | 68 | self.root = root 69 | self.imgs = imgs 70 | self.transform = transform 71 | self.target_transform = target_transform 72 | if mode == 'RGB': 73 | self.loader = rgb_loader 74 | elif mode == 'L': 75 | self.loader = l_loader 76 | 77 | self.rand_aug_size = rand_aug_size 78 | 79 | if self.rand_aug_size > 0: 80 | self.rand_aug_transform = copy.deepcopy(self.transform) 81 | self.rand_aug_transform.transforms.insert(0, RandAugment(rand_aug_n, rand_aug_m)) 82 | 83 | def __getitem__(self, index): 84 | path, target = self.imgs[index] 85 | path = os.path.join(self.root, path) 86 | img = self.loader(path) 87 | img_ = self.loader(path) 88 | if self.transform is not None: 89 | img = self.transform(img) 90 | if self.target_transform is not None: 91 | target = self.target_transform(target) 92 | 93 | rand_imgs = [self.rand_aug_transform(img_) for _ in range(self.rand_aug_size)] 94 | return img, target, index, rand_imgs 95 | 96 | def __len__(self): 97 | return len(self.imgs) 98 | 99 | 100 | class 
ImageList_idx(Dataset): 101 | def __init__(self, image_list, root, labels=None, transform=None, target_transform=None, mode='RGB'): 102 | imgs = make_dataset(image_list, labels) 103 | if len(imgs) == 0: 104 | raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n" 105 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) 106 | 107 | self.root = root 108 | self.imgs = imgs 109 | self.transform = transform 110 | self.target_transform = target_transform 111 | if mode == 'RGB': 112 | self.loader = rgb_loader 113 | elif mode == 'L': 114 | self.loader = l_loader 115 | 116 | def __getitem__(self, index): 117 | path, target = self.imgs[index] 118 | path = os.path.join(self.root, path) 119 | img = self.loader(path) 120 | if self.transform is not None: 121 | img = self.transform(img) 122 | if self.target_transform is not None: 123 | target = self.target_transform(target) 124 | 125 | return img, target, index 126 | 127 | def __len__(self): 128 | return len(self.imgs) -------------------------------------------------------------------------------- /SHOT/image_source.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import os.path as osp 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.optim as optim 8 | from torchvision import transforms 9 | import network, loss 10 | from torch.utils.data import DataLoader 11 | from data_list import ImageList 12 | from loss import CrossEntropyLabelSmooth 13 | from sklearn.metrics import confusion_matrix 14 | from sklearn.cluster import KMeans 15 | import distutils 16 | import distutils.util 17 | import logging 18 | 19 | import sys 20 | sys.path.append("../util/") 21 | from utils import resetRNGseed, init_logger, get_hostname, get_pid 22 | 23 | import time 24 | timestamp = time.strftime("%Y-%m-%d_%H.%M.%S", time.localtime()) 25 | 26 | torch.backends.cudnn.deterministic = True 27 | torch.backends.cudnn.benchmark = 
False 28 | 29 | 30 | def op_copy(optimizer): 31 | for param_group in optimizer.param_groups: 32 | param_group['lr0'] = param_group['lr'] 33 | return optimizer 34 | 35 | def lr_scheduler(optimizer, iter_num, max_iter, gamma=10, power=0.75): 36 | decay = (1 + gamma * iter_num / max_iter) ** (-power) 37 | for param_group in optimizer.param_groups: 38 | param_group['lr'] = param_group['lr0'] * decay 39 | param_group['weight_decay'] = 1e-3 40 | param_group['momentum'] = 0.9 41 | param_group['nesterov'] = True 42 | return optimizer 43 | 44 | def image_train(resize_size=256, crop_size=224, alexnet=False): 45 | if not alexnet: 46 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 47 | std=[0.229, 0.224, 0.225]) 48 | else: 49 | normalize = Normalize(meanfile='./ilsvrc_2012_mean.npy') 50 | return transforms.Compose([ 51 | transforms.Resize((resize_size, resize_size)), 52 | transforms.RandomCrop(crop_size), 53 | transforms.RandomHorizontalFlip(), 54 | transforms.ToTensor(), 55 | normalize 56 | ]) 57 | 58 | def image_test(resize_size=256, crop_size=224, alexnet=False): 59 | if not alexnet: 60 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 61 | std=[0.229, 0.224, 0.225]) 62 | else: 63 | normalize = Normalize(meanfile='./ilsvrc_2012_mean.npy') 64 | return transforms.Compose([ 65 | transforms.Resize((resize_size, resize_size)), 66 | transforms.CenterCrop(crop_size), 67 | transforms.ToTensor(), 68 | normalize 69 | ]) 70 | 71 | def data_load(args): 72 | ## prepare data 73 | dsets = {} 74 | dset_loaders = {} 75 | train_bs = args.batch_size 76 | txt_src = open(args.s_dset_path).readlines() 77 | txt_test = open(args.test_dset_path).readlines() 78 | 79 | if not args.da == 'uda': 80 | label_map_s = {} 81 | for i in range(len(args.src_classes)): 82 | label_map_s[args.src_classes[i]] = i 83 | 84 | new_src = [] 85 | for i in range(len(txt_src)): 86 | rec = txt_src[i] 87 | reci = rec.strip().split(' ') 88 | if int(reci[1]) in args.src_classes: 89 | line = reci[0] 
+ ' ' + str(label_map_s[int(reci[1])]) + '\n' 90 | new_src.append(line) 91 | txt_src = new_src.copy() 92 | 93 | new_tar = [] 94 | for i in range(len(txt_test)): 95 | rec = txt_test[i] 96 | reci = rec.strip().split(' ') 97 | if int(reci[1]) in args.tar_classes: 98 | if int(reci[1]) in args.src_classes: 99 | line = reci[0] + ' ' + str(label_map_s[int(reci[1])]) + '\n' 100 | new_tar.append(line) 101 | else: 102 | line = reci[0] + ' ' + str(len(label_map_s)) + '\n' 103 | new_tar.append(line) 104 | txt_test = new_tar.copy() 105 | 106 | if args.trte == "val": 107 | dsize = len(txt_src) 108 | tr_size = int(0.9*dsize) 109 | # print(dsize, tr_size, dsize - tr_size) 110 | tr_txt, te_txt = torch.utils.data.random_split(txt_src, [tr_size, dsize - tr_size]) 111 | else: 112 | dsize = len(txt_src) 113 | tr_size = int(0.9*dsize) 114 | _, te_txt = torch.utils.data.random_split(txt_src, [tr_size, dsize - tr_size]) 115 | tr_txt = txt_src 116 | 117 | dsets["source_tr"] = ImageList(tr_txt, root="../data/{}/".format(args.dset), transform=image_train()) 118 | dset_loaders["source_tr"] = DataLoader(dsets["source_tr"], batch_size=train_bs, shuffle=True, num_workers=args.worker, drop_last=False) 119 | dsets["source_te"] = ImageList(te_txt, root="../data/{}/".format(args.dset), transform=image_test()) 120 | dset_loaders["source_te"] = DataLoader(dsets["source_te"], batch_size=train_bs, shuffle=True, num_workers=args.worker, drop_last=False) 121 | dsets["test"] = ImageList(txt_test, root="../data/{}/".format(args.dset), transform=image_test()) 122 | dset_loaders["test"] = DataLoader(dsets["test"], batch_size=train_bs*2, shuffle=True, num_workers=args.worker, drop_last=False) 123 | 124 | return dset_loaders 125 | 126 | def cal_acc(loader, netF, netB, netC, flag=False): 127 | start_test = True 128 | with torch.no_grad(): 129 | iter_test = iter(loader) 130 | for i in range(len(loader)): 131 | data = iter_test.next() 132 | inputs = data[0] 133 | labels = data[1] 134 | inputs = inputs.cuda() 135 | 
outputs = netC(netB(netF(inputs))) 136 | if start_test: 137 | all_output = outputs.float().cpu() 138 | all_label = labels.float() 139 | start_test = False 140 | else: 141 | all_output = torch.cat((all_output, outputs.float().cpu()), 0) 142 | all_label = torch.cat((all_label, labels.float()), 0) 143 | 144 | all_output = nn.Softmax(dim=1)(all_output) 145 | _, predict = torch.max(all_output, 1) 146 | accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0]) 147 | mean_ent = torch.mean(loss.Entropy(all_output)).cpu().data.item() 148 | 149 | if flag: 150 | matrix = confusion_matrix(all_label, torch.squeeze(predict).float()) 151 | acc = matrix.diagonal()/matrix.sum(axis=1) * 100 152 | aacc = acc.mean() 153 | aa = [str(np.round(i, 2)) for i in acc] 154 | acc = ' '.join(aa) 155 | return aacc, acc 156 | else: 157 | return accuracy*100, mean_ent 158 | 159 | def cal_acc_oda(loader, netF, netB, netC): 160 | start_test = True 161 | with torch.no_grad(): 162 | iter_test = iter(loader) 163 | for i in range(len(loader)): 164 | data = iter_test.next() 165 | inputs = data[0] 166 | labels = data[1] 167 | inputs = inputs.cuda() 168 | outputs = netC(netB(netF(inputs))) 169 | if start_test: 170 | all_output = outputs.float().cpu() 171 | all_label = labels.float() 172 | start_test = False 173 | else: 174 | all_output = torch.cat((all_output, outputs.float().cpu()), 0) 175 | all_label = torch.cat((all_label, labels.float()), 0) 176 | 177 | all_output = nn.Softmax(dim=1)(all_output) 178 | _, predict = torch.max(all_output, 1) 179 | ent = torch.sum(-all_output * torch.log(all_output + args.epsilon), dim=1) / np.log(args.class_num) 180 | ent = ent.float().cpu() 181 | initc = np.array([[0], [1]]) 182 | kmeans = KMeans(n_clusters=2, random_state=0, init=initc, n_init=1).fit(ent.reshape(-1,1)) 183 | threshold = (kmeans.cluster_centers_).mean() 184 | 185 | predict[ent>threshold] = args.class_num 186 | matrix = confusion_matrix(all_label, 
torch.squeeze(predict).float()) 187 | matrix = matrix[np.unique(all_label).astype(int),:] 188 | 189 | acc = matrix.diagonal()/matrix.sum(axis=1) * 100 190 | unknown_acc = acc[-1:].item() 191 | 192 | return np.mean(acc[:-1]), np.mean(acc), unknown_acc 193 | # return np.mean(acc), np.mean(acc[:-1]) 194 | 195 | def train_source(args): 196 | dset_loaders = data_load(args) 197 | ## set base network 198 | if args.net[0:3] == 'res': 199 | netF = network.ResBase(res_name=args.net).cuda() 200 | elif args.net[0:3] == 'vgg': 201 | netF = network.VGGBase(vgg_name=args.net).cuda() 202 | 203 | netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features, bottleneck_dim=args.bottleneck).cuda() 204 | netC = network.feat_classifier(type=args.layer, class_num = args.class_num, bottleneck_dim=args.bottleneck).cuda() 205 | 206 | param_group = [] 207 | learning_rate = args.lr 208 | for k, v in netF.named_parameters(): 209 | param_group += [{'params': v, 'lr': learning_rate*0.1}] 210 | for k, v in netB.named_parameters(): 211 | param_group += [{'params': v, 'lr': learning_rate}] 212 | for k, v in netC.named_parameters(): 213 | param_group += [{'params': v, 'lr': learning_rate}] 214 | optimizer = optim.SGD(param_group) 215 | optimizer = op_copy(optimizer) 216 | 217 | acc_init = 0 218 | max_iter = args.max_epoch * len(dset_loaders["source_tr"]) 219 | interval_iter = max_iter // 10 220 | iter_num = 0 221 | 222 | netF.train() 223 | netB.train() 224 | netC.train() 225 | 226 | while iter_num < max_iter: 227 | try: 228 | inputs_source, labels_source = iter_source.next() 229 | except: 230 | iter_source = iter(dset_loaders["source_tr"]) 231 | inputs_source, labels_source = iter_source.next() 232 | 233 | if inputs_source.size(0) == 1: 234 | continue 235 | 236 | iter_num += 1 237 | lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter) 238 | 239 | inputs_source, labels_source = inputs_source.cuda(), labels_source.cuda() 240 | outputs_source = 
netC(netB(netF(inputs_source))) 241 | classifier_loss = CrossEntropyLabelSmooth(num_classes=args.class_num, epsilon=args.smooth)(outputs_source, labels_source) 242 | 243 | optimizer.zero_grad() 244 | classifier_loss.backward() 245 | optimizer.step() 246 | 247 | if iter_num % interval_iter == 0 or iter_num == max_iter: 248 | netF.eval() 249 | netB.eval() 250 | netC.eval() 251 | if args.dset=='visda-2017': 252 | acc_s_te, acc_list = cal_acc(dset_loaders['source_te'], netF, netB, netC, True) 253 | log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(args.s, iter_num, max_iter, acc_s_te) + '\n' + acc_list 254 | else: 255 | acc_s_te, _ = cal_acc(dset_loaders['source_te'], netF, netB, netC, False) 256 | log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(args.s, iter_num, max_iter, acc_s_te) 257 | # args.out_file.write(log_str + '\n') 258 | # args.out_file.flush() 259 | # print(log_str+'\n') 260 | logging.info(log_str) 261 | 262 | if acc_s_te >= acc_init: 263 | acc_init = acc_s_te 264 | best_netF = netF.state_dict() 265 | best_netB = netB.state_dict() 266 | best_netC = netC.state_dict() 267 | 268 | netF.train() 269 | netB.train() 270 | netC.train() 271 | 272 | torch.save(best_netF, osp.join(args.output_dir_src, "{}_{}_source_F.pt".format(args.s, args.net))) 273 | torch.save(best_netB, osp.join(args.output_dir_src, "{}_{}_source_B.pt".format(args.s, args.net))) 274 | torch.save(best_netC, osp.join(args.output_dir_src, "{}_{}_source_C.pt".format(args.s, args.net))) 275 | 276 | return netF, netB, netC 277 | 278 | def test_target(args): 279 | dset_loaders = data_load(args) 280 | ## set base network 281 | if args.net[0:3] == 'res': 282 | netF = network.ResBase(res_name=args.net).cuda() 283 | elif args.net[0:3] == 'vgg': 284 | netF = network.VGGBase(vgg_name=args.net).cuda() 285 | 286 | netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features, bottleneck_dim=args.bottleneck).cuda() 287 | netC = network.feat_classifier(type=args.layer, 
class_num = args.class_num, bottleneck_dim=args.bottleneck).cuda() 288 | 289 | args.modelpath = osp.join(args.output_dir_src, '{}_{}_source_F.pt'.format(args.s, args.net)) 290 | netF.load_state_dict(torch.load(args.modelpath)) 291 | args.modelpath = osp.join(args.output_dir_src, '{}_{}_source_B.pt'.format(args.s, args.net)) 292 | netB.load_state_dict(torch.load(args.modelpath)) 293 | args.modelpath = osp.join(args.output_dir_src, '{}_{}_source_C.pt'.format(args.s, args.net)) 294 | netC.load_state_dict(torch.load(args.modelpath)) 295 | netF.eval() 296 | netB.eval() 297 | netC.eval() 298 | 299 | if args.da == 'oda': 300 | acc_os1, acc_os2, acc_unknown = cal_acc_oda(dset_loaders['test'], netF, netB, netC) 301 | log_str = '\nTraining: {}, Task: {}->{}, Accuracy = {:.2f}% / {:.2f}% / {:.2f}%'.format(args.trte, args.s, args.t, acc_os2, acc_os1, acc_unknown) 302 | else: 303 | if args.dset=='visda-2017': 304 | acc, acc_list = cal_acc(dset_loaders['test'], netF, netB, netC, True) 305 | log_str = '\nTraining: {}, Task: {}->{}, Accuracy = {:.2f}%'.format(args.trte, args.s, args.t, acc) + '\n' + acc_list 306 | else: 307 | acc, _ = cal_acc(dset_loaders['test'], netF, netB, netC, False) 308 | log_str = '\nTraining: {}, Task: {}->{}, Accuracy = {:.2f}%'.format(args.trte, args.s, args.t, acc) 309 | 310 | # args.out_file.write(log_str) 311 | # args.out_file.flush() 312 | # print(log_str) 313 | logging.info(log_str) 314 | 315 | def print_args(args): 316 | s = "==========================================\n" 317 | for arg, content in args.__dict__.items(): 318 | s += "{}:{}\n".format(arg, content) 319 | return s 320 | 321 | if __name__ == "__main__": 322 | parser = argparse.ArgumentParser(description='SHOT') 323 | parser.add_argument('--gpu_id', type=str, nargs='?', default='0', help="device id to run") 324 | parser.add_argument('--s', type=str, default=None, help="source") 325 | parser.add_argument('--t', type=str, default=None, help="target") 326 | parser.add_argument('--max_epoch', 
type=int, default=20, help="max iterations") 327 | parser.add_argument('--batch_size', type=int, default=64, help="batch_size") 328 | parser.add_argument('--worker', type=int, default=4, help="number of workers") 329 | parser.add_argument('--dset', type=str, default='office-home', choices=['visda-2017', 'domainnet40', 'office31', 330 | 'office-home', 'office-home-rsut', 'office-caltech', 'multi']) 331 | parser.add_argument('--lr', type=float, default=1e-2, help="learning rate") 332 | parser.add_argument('--net', type=str, default='resnet50', help="vgg16, resnet50, resnet101") 333 | parser.add_argument('--seed', type=int, default=2020, help="random seed") 334 | parser.add_argument('--bottleneck', type=int, default=256) 335 | parser.add_argument('--epsilon', type=float, default=1e-5) 336 | parser.add_argument('--layer', type=str, default="wn", choices=["linear", "wn"]) 337 | parser.add_argument('--classifier', type=str, default="bn", choices=["ori", "bn"]) 338 | parser.add_argument('--smooth', type=float, default=0.1) 339 | parser.add_argument('--output', type=str, default='san') 340 | parser.add_argument('--da', type=str, default='uda', choices=['uda', 'pda', 'oda']) 341 | parser.add_argument('--trte', type=str, default='val', choices=['full', 'val']) 342 | 343 | parser.add_argument('--timestamp', default=timestamp, type=str, help='timestamp') 344 | parser.add_argument('--use_file_logger', default='True', type=lambda x: bool(distutils.util.strtobool(x)), 345 | help='whether use file logger') 346 | parser.add_argument('--names', default=[], type=list, help='names of tasks') 347 | 348 | 349 | args = parser.parse_args() 350 | 351 | if args.dset == 'office-home': 352 | args.names = ['Art', 'Clipart', 'Product', 'Real_World'] 353 | args.class_num = 65 354 | if args.dset == 'office-home-rsut': 355 | args.names = ['Clipart', 'Product', 'Real_World'] 356 | args.class_num = 65 357 | if args.dset == 'domainnet40': 358 | args.names = ['sketch', 'clipart', 'painting', 'real'] 
359 | args.class_num = 40 360 | if args.dset == 'multi': 361 | args.names = ['real', 'clipart', 'sketch', 'painting'] 362 | args.class_num = 126 363 | if args.dset == 'office31': 364 | args.names = ['amazon', 'dslr', 'webcam'] 365 | args.class_num = 31 366 | if args.dset == 'visda-2017': 367 | args.names = ['train', 'validation'] 368 | args.class_num = 12 369 | if args.dset == 'office-caltech': 370 | args.names = ['amazon', 'caltech', 'dslr', 'webcam'] 371 | args.class_num = 10 372 | 373 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id 374 | resetRNGseed(args.seed) 375 | 376 | if args.dset == 'office-home-rsut': 377 | args.s += '_RS' 378 | 379 | dir = "{}_{}_{}_source".format(args.timestamp, args.s, args.da) 380 | if args.use_file_logger: 381 | init_logger(dir, True, '../logs/SHOT/source/') 382 | logging.info("{}:{}".format(get_hostname(), get_pid())) 383 | 384 | folder = '../data/' 385 | args.s_dset_path = folder + args.dset + '/image_list/' + args.s + '.txt' 386 | args.test_dset_path = folder + args.dset + '/image_list/' + args.s + '.txt' 387 | 388 | if args.dset == 'domainnet40': 389 | args.s_dset_path = folder + args.dset + '/image_list/' + args.s + '_train_mini.txt' 390 | args.test_dset_path = folder + args.dset + '/image_list/' + args.s + '_test_mini.txt' 391 | 392 | if args.dset == 'office-home': 393 | if args.da == 'pda': 394 | args.class_num = 65 395 | args.src_classes = [i for i in range(65)] 396 | args.tar_classes = [33, 32, 36, 15, 19, 2, 46, 49, 48, 53, 47, 54, 4, 18, 57, 23, 0, 45, 1, 38, 5, 13, 50, 11, 58] 397 | if args.da == 'oda': 398 | args.class_num = 25 399 | args.src_classes = [33, 32, 36, 15, 19, 2, 46, 49, 48, 53, 47, 54, 4, 18, 57, 23, 0, 45, 1, 38, 5, 13, 50, 11, 58] 400 | args.tar_classes = [i for i in range(65)] 401 | 402 | # args.output_dir_src = osp.join(args.output, args.da, args.dset, names[args.s][0].upper()) 403 | # args.name_src = names[args.s][0].upper() 404 | args.output_dir_src = 
"../checkpoints/SHOT/source/{}/".format(args.da) 405 | 406 | if not osp.exists(args.output_dir_src): 407 | os.system('mkdir -p ' + args.output_dir_src) 408 | if not osp.exists(args.output_dir_src): 409 | os.mkdir(args.output_dir_src) 410 | 411 | # args.out_file = open(osp.join(args.output_dir_src, 'log.txt'), 'w') 412 | # args.out_file.write(print_args(args)+'\n') 413 | # args.out_file.flush() 414 | 415 | logging.info(print_args(args)) 416 | train_source(args) 417 | 418 | # args.out_file = open(osp.join(args.output_dir_src, 'log_test.txt'), 'w') 419 | for t in args.names: 420 | if t == args.s or t == args.s.split('_RS')[0]: 421 | continue 422 | args.t = t 423 | # args.name = args.names[args.s][0].upper() + args.t[0].upper() 424 | 425 | if args.dset == 'office-home-rsut': 426 | args.t += '_UT' 427 | 428 | folder = '../data/' 429 | args.s_dset_path = folder + args.dset + '/image_list/' + args.s + '.txt' 430 | args.test_dset_path = folder + args.dset + '/image_list/' + args.t + '.txt' 431 | 432 | if args.dset == 'domainnet40': 433 | args.s_dset_path = folder + args.dset + '/image_list/' + args.s + '_train_mini.txt' 434 | args.test_dset_path = folder + args.dset + '/image_list/' + args.t + '_test_mini.txt' 435 | 436 | if args.dset == 'office-home': 437 | if args.da == 'pda': 438 | args.class_num = 65 439 | args.src_classes = [i for i in range(65)] 440 | args.tar_classes = [33, 32, 36, 15, 19, 2, 46, 49, 48, 53, 47, 54, 4, 18, 57, 23, 0, 45, 1, 38, 5, 13, 50, 11, 58] 441 | if args.da == 'oda': 442 | args.class_num = 25 443 | args.src_classes = [33, 32, 36, 15, 19, 2, 46, 49, 48, 53, 47, 54, 4, 18, 57, 23, 0, 45, 1, 38, 5, 13, 50, 11, 58] 444 | args.tar_classes = [i for i in range(65)] 445 | 446 | test_target(args) -------------------------------------------------------------------------------- /SHOT/image_target.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os, sys 3 | import os.path as osp 4 | import 
torchvision 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | from torchvision import transforms 10 | import network, loss 11 | from torch.utils.data import DataLoader 12 | from data_list import ImageList, ImageList_idx 13 | import random, pdb, math, copy 14 | from tqdm import tqdm 15 | from scipy.spatial.distance import cdist 16 | from sklearn.metrics import confusion_matrix 17 | import distutils 18 | import distutils.util 19 | import logging 20 | 21 | import sys 22 | sys.path.append("../util/") 23 | from utils import resetRNGseed, init_logger, get_hostname, get_pid 24 | 25 | import time 26 | timestamp = time.strftime("%Y-%m-%d_%H.%M.%S", time.localtime()) 27 | 28 | torch.backends.cudnn.deterministic = True 29 | torch.backends.cudnn.benchmark = False 30 | 31 | 32 | def op_copy(optimizer): 33 | for param_group in optimizer.param_groups: 34 | param_group['lr0'] = param_group['lr'] 35 | return optimizer 36 | 37 | def lr_scheduler(optimizer, iter_num, max_iter, gamma=10, power=0.75): 38 | decay = (1 + gamma * iter_num / max_iter) ** (-power) 39 | for param_group in optimizer.param_groups: 40 | param_group['lr'] = param_group['lr0'] * decay 41 | param_group['weight_decay'] = 1e-3 42 | param_group['momentum'] = 0.9 43 | param_group['nesterov'] = True 44 | return optimizer 45 | 46 | def image_train(resize_size=256, crop_size=224, alexnet=False): 47 | if not alexnet: 48 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 49 | std=[0.229, 0.224, 0.225]) 50 | else: 51 | normalize = Normalize(meanfile='./ilsvrc_2012_mean.npy') 52 | return transforms.Compose([ 53 | transforms.Resize((resize_size, resize_size)), 54 | transforms.RandomCrop(crop_size), 55 | transforms.RandomHorizontalFlip(), 56 | transforms.ToTensor(), 57 | normalize 58 | ]) 59 | 60 | def image_test(resize_size=256, crop_size=224, alexnet=False): 61 | if not alexnet: 62 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 63 | std=[0.229, 
0.224, 0.225]) 64 | else: 65 | normalize = Normalize(meanfile='./ilsvrc_2012_mean.npy') 66 | return transforms.Compose([ 67 | transforms.Resize((resize_size, resize_size)), 68 | transforms.CenterCrop(crop_size), 69 | transforms.ToTensor(), 70 | normalize 71 | ]) 72 | 73 | def data_load(args): 74 | ## prepare data 75 | dsets = {} 76 | dset_loaders = {} 77 | train_bs = args.batch_size 78 | txt_tar = open(args.t_dset_path).readlines() 79 | txt_test = open(args.test_dset_path).readlines() 80 | 81 | if not args.da == 'uda': 82 | label_map_s = {} 83 | for i in range(len(args.src_classes)): 84 | label_map_s[args.src_classes[i]] = i 85 | 86 | new_tar = [] 87 | for i in range(len(txt_tar)): 88 | rec = txt_tar[i] 89 | reci = rec.strip().split(' ') 90 | if int(reci[1]) in args.tar_classes: 91 | if int(reci[1]) in args.src_classes: 92 | line = reci[0] + ' ' + str(label_map_s[int(reci[1])]) + '\n' 93 | new_tar.append(line) 94 | else: 95 | line = reci[0] + ' ' + str(len(label_map_s)) + '\n' 96 | new_tar.append(line) 97 | txt_tar = new_tar.copy() 98 | txt_test = txt_tar.copy() 99 | 100 | dsets["target"] = ImageList_idx(txt_tar, root="../data/{}/".format(args.dset), transform=image_train()) 101 | dset_loaders["target"] = DataLoader(dsets["target"], batch_size=train_bs, shuffle=True, num_workers=args.worker, drop_last=False) 102 | 103 | dsets["test"] = ImageList_idx(txt_test, root="../data/{}/".format(args.dset), transform=image_test()) 104 | dset_loaders["test"] = DataLoader(dsets["test"], batch_size=train_bs*3, shuffle=False, num_workers=args.worker, drop_last=False) 105 | dsets["valid"] = ImageList_idx(txt_tar, root="../data/{}/".format(args.dset), transform=image_train() if args.use_train_transform else image_test()) 106 | dset_loaders["valid"] = DataLoader(dsets["valid"], batch_size=train_bs * 3, shuffle=False, num_workers=args.worker, drop_last=False) 107 | 108 | return dset_loaders 109 | 110 | def cal_acc(loader, netF, netB, netC, flag=False): 111 | start_test = True 112 | 
with torch.no_grad(): 113 | iter_test = iter(loader) 114 | for i in range(len(loader)): 115 | data = iter_test.next() 116 | inputs = data[0] 117 | labels = data[1] 118 | inputs = inputs.cuda() 119 | outputs = netC(netB(netF(inputs))) 120 | if start_test: 121 | all_output = outputs.float().cpu() 122 | all_label = labels.float() 123 | start_test = False 124 | else: 125 | all_output = torch.cat((all_output, outputs.float().cpu()), 0) 126 | all_label = torch.cat((all_label, labels.float()), 0) 127 | _, predict = torch.max(all_output, 1) 128 | acc = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0]) * 100 129 | mean_ent = torch.mean(loss.Entropy(nn.Softmax(dim=1)(all_output))).cpu().data.item() 130 | 131 | # if flag: 132 | # matrix = confusion_matrix(all_label, torch.squeeze(predict).float()) 133 | # acc = matrix.diagonal()/matrix.sum(axis=1) * 100 134 | # aacc = acc.mean() 135 | # aa = [str(np.round(i, 2)) for i in acc] 136 | # acc = ' '.join(aa) 137 | # return aacc, acc 138 | # else: 139 | # return accuracy*100, mean_ent 140 | 141 | matrix = confusion_matrix(all_label, torch.squeeze(predict).float()) 142 | acc_list = matrix.diagonal() / (matrix.sum(axis=1)+1e-12) * 100 143 | per_class_acc = acc_list.mean() 144 | if args.da == 'pda': 145 | acc_list = '' 146 | per_class_acc = 0 147 | return acc, mean_ent, per_class_acc, acc_list 148 | 149 | 150 | def train_target(args): 151 | dset_loaders = data_load(args) 152 | ## set base network 153 | if args.net[0:3] == 'res': 154 | netF = network.ResBase(res_name=args.net).cuda() 155 | elif args.net[0:3] == 'vgg': 156 | netF = network.VGGBase(vgg_name=args.net).cuda() 157 | 158 | netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features, bottleneck_dim=args.bottleneck).cuda() 159 | netC = network.feat_classifier(type=args.layer, class_num = args.class_num, bottleneck_dim=args.bottleneck).cuda() 160 | 161 | modelpath = osp.join(args.output_dir_src, 
"{}_{}_source_F.pt".format(args.s, args.net)) 162 | netF.load_state_dict(torch.load(modelpath)) 163 | modelpath = osp.join(args.output_dir_src, "{}_{}_source_B.pt".format(args.s, args.net)) 164 | netB.load_state_dict(torch.load(modelpath)) 165 | modelpath = osp.join(args.output_dir_src, "{}_{}_source_C.pt".format(args.s, args.net)) 166 | netC.load_state_dict(torch.load(modelpath)) 167 | netC.eval() 168 | for k, v in netC.named_parameters(): 169 | v.requires_grad = False 170 | 171 | param_group = [] 172 | for k, v in netF.named_parameters(): 173 | if args.lr_decay1 > 0: 174 | param_group += [{'params': v, 'lr': args.lr * args.lr_decay1}] 175 | else: 176 | v.requires_grad = False 177 | for k, v in netB.named_parameters(): 178 | if args.lr_decay2 > 0: 179 | param_group += [{'params': v, 'lr': args.lr * args.lr_decay2}] 180 | else: 181 | v.requires_grad = False 182 | 183 | optimizer = optim.SGD(param_group) 184 | optimizer = op_copy(optimizer) 185 | 186 | max_iter = args.max_epoch * len(dset_loaders["target"]) 187 | interval_iter = max_iter // args.interval 188 | iter_num = 0 189 | 190 | while iter_num < max_iter: 191 | try: 192 | inputs_test, _, tar_idx = iter_test.next() 193 | except: 194 | iter_test = iter(dset_loaders["target"]) 195 | inputs_test, _, tar_idx = iter_test.next() 196 | 197 | if inputs_test.size(0) == 1: 198 | continue 199 | 200 | if iter_num % interval_iter == 0 and args.cls_par > 0: 201 | netF.eval() 202 | netB.eval() 203 | mem_label = obtain_label(dset_loaders['valid'], netF, netB, netC, args) 204 | mem_label = torch.from_numpy(mem_label).cuda() 205 | netF.train() 206 | netB.train() 207 | 208 | if args.use_balanced_sampler: 209 | dset_loaders["target"].sampler.update(mem_label) 210 | 211 | inputs_test = inputs_test.cuda() 212 | 213 | iter_num += 1 214 | lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter) 215 | 216 | features_test = netB(netF(inputs_test)) 217 | outputs_test = netC(features_test) 218 | 219 | if args.cls_par > 0: 220 | pred 
= mem_label[tar_idx] 221 | classifier_loss = nn.CrossEntropyLoss()(outputs_test, pred) 222 | classifier_loss *= args.cls_par 223 | if iter_num < interval_iter and args.dset == "visda-2017": 224 | classifier_loss *= 0 225 | else: 226 | classifier_loss = torch.tensor(0.0).cuda() 227 | 228 | if args.ent: 229 | softmax_out = nn.Softmax(dim=1)(outputs_test) 230 | entropy_loss = torch.mean(loss.Entropy(softmax_out)) 231 | if args.gent: 232 | msoftmax = softmax_out.mean(dim=0) 233 | gentropy_loss = torch.sum(-msoftmax * torch.log(msoftmax + args.epsilon)) 234 | entropy_loss -= gentropy_loss 235 | im_loss = entropy_loss * args.ent_par 236 | classifier_loss += im_loss 237 | 238 | optimizer.zero_grad() 239 | classifier_loss.backward() 240 | optimizer.step() 241 | 242 | if iter_num % interval_iter == 0 or iter_num == max_iter: 243 | netF.eval() 244 | netB.eval() 245 | 246 | if args.dset=='visda-2017': 247 | acc, _, per_class_acc, acc_list = cal_acc(dset_loaders['test'], netF, netB, netC) 248 | aa = [str(np.round(i, 2)) for i in acc_list] 249 | aa = ' '.join(aa) 250 | log_str = 'Task: {}->{}, Iter:{}/{}; Accuracy = {:.2f}% Per_class_accuracy={:.2f}'.format(args.s, args.t, iter_num, max_iter, acc, per_class_acc) + '\n' + aa 251 | else: 252 | acc, _, per_class_acc, acc_list = cal_acc(dset_loaders['test'], netF, netB, netC) 253 | log_str = 'Task: {}->{}, Iter:{}/{}; Accuracy = {:.2f}% Per_class_accuracy={:.2f}'.format(args.s, args.t, iter_num, max_iter, acc, per_class_acc) 254 | 255 | # args.out_file.write(log_str + '\n') 256 | # args.out_file.flush() 257 | # print(log_str+'\n') 258 | logging.info(log_str) 259 | netF.train() 260 | netB.train() 261 | 262 | if args.issave: 263 | torch.save(netF.state_dict(), osp.join(args.output_dir, "{}_{}_{}_{}_target_F_".format(args.timestamp, args.s, args.t, args.net) + args.savename + ".pt")) 264 | torch.save(netB.state_dict(), osp.join(args.output_dir, "{}_{}_{}_{}_target_B_".format(args.timestamp, args.s, args.t, args.net) + args.savename + 
".pt")) 265 | torch.save(netC.state_dict(), osp.join(args.output_dir, "{}_{}_{}_{}_target_C_".format(args.timestamp, args.s, args.t, args.net) + args.savename + ".pt")) 266 | 267 | return netF, netB, netC 268 | 269 | def print_args(args): 270 | s = "==========================================\n" 271 | for arg, content in args.__dict__.items(): 272 | s += "{}:{}\n".format(arg, content) 273 | return s 274 | 275 | def obtain_label(loader, netF, netB, netC, args): 276 | start_test = True 277 | with torch.no_grad(): 278 | iter_test = iter(loader) 279 | for _ in range(len(loader)): 280 | data = iter_test.next() 281 | inputs = data[0] 282 | labels = data[1] 283 | inputs = inputs.cuda() 284 | feas = netB(netF(inputs)) 285 | outputs = netC(feas) 286 | if start_test: 287 | all_fea = feas.float().cpu() 288 | all_output = outputs.float().cpu() 289 | all_label = labels.float() 290 | start_test = False 291 | else: 292 | all_fea = torch.cat((all_fea, feas.float().cpu()), 0) 293 | all_output = torch.cat((all_output, outputs.float().cpu()), 0) 294 | all_label = torch.cat((all_label, labels.float()), 0) 295 | 296 | all_output = nn.Softmax(dim=1)(all_output) 297 | ent = torch.sum(-all_output * torch.log(all_output + args.epsilon), dim=1) 298 | unknown_weight = 1 - ent / np.log(args.class_num) 299 | _, predict = torch.max(all_output, 1) 300 | 301 | accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0]) 302 | matrix = confusion_matrix(all_label, torch.squeeze(predict).float()) 303 | 304 | acc_list = matrix.diagonal() / (matrix.sum(axis=1)+1e-12) 305 | avg_accuracy = (acc_list).mean() 306 | if args.da == 'pda': 307 | acc_list = '' 308 | avg_accuracy = 0 309 | 310 | if args.distance == 'cosine': 311 | all_fea = torch.cat((all_fea, torch.ones(all_fea.size(0), 1)), 1) 312 | all_fea = (all_fea.t() / torch.norm(all_fea, p=2, dim=1)).t() 313 | 314 | all_fea = all_fea.float().cpu().numpy() 315 | K = all_output.size(1) 316 | aff = 
all_output.float().cpu().numpy() 317 | initc = aff.transpose().dot(all_fea) 318 | initc = initc / (1e-8 + aff.sum(axis=0)[:,None]) 319 | cls_count = np.eye(K)[predict].sum(axis=0) 320 | labelset = np.where(cls_count>args.threshold) 321 | labelset = labelset[0] 322 | # print(labelset) 323 | 324 | dd = cdist(all_fea, initc[labelset], args.distance) 325 | pred_label = dd.argmin(axis=1) 326 | pred_label = labelset[pred_label] 327 | 328 | for round in range(1): 329 | aff = np.eye(K)[pred_label] 330 | initc = aff.transpose().dot(all_fea) 331 | initc = initc / (1e-8 + aff.sum(axis=0)[:,None]) 332 | dd = cdist(all_fea, initc[labelset], args.distance) 333 | pred_label = dd.argmin(axis=1) 334 | pred_label = labelset[pred_label] 335 | 336 | acc = np.sum(pred_label == all_label.float().numpy()) / len(all_fea) 337 | matrix = confusion_matrix(all_label.float().numpy(), pred_label) 338 | acc_list = matrix.diagonal() / (matrix.sum(axis=1)+1e-12) 339 | avg_acc = acc_list.mean() 340 | if args.da == 'pda': 341 | acc_list = '' 342 | avg_acc = 0 343 | log_str = 'Accuracy = {:.2f}% -> {:.2f}% Per_class_accuracy = {:.2f}% -> {:.2f}%'.format(accuracy * 100, acc * 100, avg_accuracy * 100, avg_acc * 100) 344 | 345 | # args.out_file.write(log_str + '\n') 346 | # args.out_file.flush() 347 | # print(log_str+'\n') 348 | logging.info(log_str) 349 | 350 | return pred_label.astype('int') 351 | 352 | 353 | if __name__ == "__main__": 354 | parser = argparse.ArgumentParser(description='SHOT') 355 | parser.add_argument('--gpu_id', type=str, nargs='?', default='0', help="device id to run") 356 | parser.add_argument('--s', type=str, default=None, help="source") 357 | parser.add_argument('--t', type=str, default=None, help="target") 358 | parser.add_argument('--max_epoch', type=int, default=15, help="max iterations") 359 | parser.add_argument('--interval', type=int, default=15) 360 | parser.add_argument('--batch_size', type=int, default=64, help="batch_size") 361 | parser.add_argument('--worker', 
type=int, default=4, help="number of workers") 362 | parser.add_argument('--dset', type=str, default='office-home', choices=['visda-2017', 'domainnet40', 'office31', 363 | 'office-home', 'office-home-rsut', 'office-caltech', 'multi']) 364 | parser.add_argument('--lr', type=float, default=1e-2, help="learning rate") 365 | parser.add_argument('--net', type=str, default='resnet50', help="alexnet, vgg16, resnet50, res101") 366 | parser.add_argument('--seed', type=int, default=2020, help="random seed") 367 | 368 | parser.add_argument('--gent', type=bool, default=True) 369 | parser.add_argument('--ent', type=bool, default=True) 370 | parser.add_argument('--threshold', type=int, default=0) 371 | parser.add_argument('--cls_par', type=float, default=0.3) 372 | parser.add_argument('--ent_par', type=float, default=1.0) 373 | parser.add_argument('--lr_decay1', type=float, default=0.1) 374 | parser.add_argument('--lr_decay2', type=float, default=1.0) 375 | 376 | parser.add_argument('--bottleneck', type=int, default=256) 377 | parser.add_argument('--epsilon', type=float, default=1e-5) 378 | parser.add_argument('--layer', type=str, default="wn", choices=["linear", "wn"]) 379 | parser.add_argument('--classifier', type=str, default="bn", choices=["ori", "bn"]) 380 | parser.add_argument('--distance', type=str, default='cosine', choices=["euclidean", "cosine"]) 381 | parser.add_argument('--output', type=str, default='san') 382 | parser.add_argument('--output_src', type=str, default='san') 383 | parser.add_argument('--da', type=str, default='uda', choices=['uda', 'pda']) 384 | parser.add_argument('--issave', type=bool, default=True) 385 | 386 | parser.add_argument('--timestamp', default=timestamp, type=str, help='timestamp') 387 | parser.add_argument('--use_file_logger', default='True', type=lambda x: bool(distutils.util.strtobool(x)), 388 | help='whether use file logger') 389 | parser.add_argument('--names', default=[], type=list, help='names of tasks') 390 | 
parser.add_argument('--use_train_transform', default='False', type=lambda x: bool(distutils.util.strtobool(x)), 391 | help='whether use train transform for label refinement') 392 | parser.add_argument('--use_balanced_sampler', default='False', type=lambda x: bool(distutils.util.strtobool(x)), 393 | help='whether use class balanced sampler') 394 | args = parser.parse_args() 395 | 396 | if args.dset == 'office-home': 397 | args.names = ['Art', 'Clipart', 'Product', 'Real_World'] 398 | args.class_num = 65 399 | if args.dset == 'office-home-rsut': 400 | args.names = ['Clipart', 'Product', 'Real_World'] 401 | args.class_num = 65 402 | if args.dset == 'domainnet40': 403 | args.names = ['sketch', 'clipart', 'painting', 'real'] 404 | args.class_num = 40 405 | if args.dset == 'multi': 406 | args.names = ['real', 'clipart', 'sketch', 'painting'] 407 | args.class_num = 126 408 | if args.dset == 'office31': 409 | args.names = ['amazon', 'dslr', 'webcam'] 410 | args.class_num = 31 411 | if args.dset == 'visda-2017': 412 | args.names = ['train', 'validation'] 413 | args.class_num = 12 414 | if args.dset == 'office-caltech': 415 | args.names = ['amazon', 'caltech', 'dslr', 'webcam'] 416 | args.class_num = 10 417 | 418 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id 419 | resetRNGseed(args.seed) 420 | 421 | if args.dset == 'office-home-rsut': 422 | args.s += '_RS' 423 | 424 | dir = "{}_{}_{}".format(args.timestamp, args.s, args.da) 425 | if args.use_file_logger: 426 | init_logger(dir, True, '../logs/SHOT/shot/') 427 | logging.info("{}:{}".format(get_hostname(), get_pid())) 428 | 429 | for t in args.names: 430 | if t == args.s or t == args.s.split('_RS')[0]: 431 | continue 432 | args.t = t 433 | 434 | if args.dset == 'office-home-rsut': 435 | args.t += '_UT' 436 | 437 | folder = '../data/' 438 | args.s_dset_path = folder + args.dset + '/image_list/' + args.s + '.txt' 439 | args.t_dset_path = folder + args.dset + '/image_list/' + args.t + '.txt' 440 | args.test_dset_path = folder 
+ args.dset + '/image_list/' + args.t + '.txt' 441 | 442 | if args.dset == 'domainnet40': 443 | args.s_dset_path = folder + args.dset + '/image_list/' + args.s + '_train_mini.txt' 444 | args.t_dset_path = folder + args.dset + '/image_list/' + args.t + '_train_mini.txt' 445 | args.test_dset_path = folder + args.dset + '/image_list/' + args.t + '_test_mini.txt' 446 | 447 | if args.dset == 'office-home': 448 | if args.da == 'pda': 449 | args.class_num = 65 450 | args.src_classes = [i for i in range(65)] 451 | args.tar_classes = [33, 32, 36, 15, 19, 2, 46, 49, 48, 53, 47, 54, 4, 18, 57, 23, 0, 45, 1, 38, 5, 13, 50, 11, 58] 452 | 453 | # args.output_dir_src = osp.join(args.output_src, args.da, args.dset, names[args.s][0].upper()) 454 | # args.output_dir = osp.join(args.output, args.da, args.dset, names[args.s][0].upper()+names[args.t][0].upper()) 455 | # args.name = names[args.s][0].upper()+names[args.t][0].upper() 456 | args.output_dir_src = "../checkpoints/SHOT/source/{}/".format(args.da) 457 | args.output_dir = "../checkpoints/SHOT/target/{}/".format(args.da) 458 | 459 | if not osp.exists(args.output_dir): 460 | os.system('mkdir -p ' + args.output_dir) 461 | if not osp.exists(args.output_dir): 462 | os.mkdir(args.output_dir) 463 | 464 | args.savename = 'par_' + str(args.cls_par) 465 | if args.da == 'pda': 466 | args.gent = '' 467 | args.savename = 'par_' + str(args.cls_par) + '_thr' + str(args.threshold) 468 | # args.out_file = open(osp.join(args.output_dir, 'log_' + args.savename + '.txt'), 'w') 469 | # args.out_file.write(print_args(args)+'\n') 470 | # args.out_file.flush() 471 | logging.info(print_args(args)) 472 | train_target(args) -------------------------------------------------------------------------------- /SHOT/image_target_kSHOT.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os.path as osp 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.optim as optim 7 
def op_copy(optimizer):
    """Stash each param group's initial learning rate under 'lr0'.

    lr_scheduler() later rescales 'lr' from this saved base value.
    """
    for group in optimizer.param_groups:
        group['lr0'] = group['lr']
    return optimizer

def lr_scheduler(optimizer, iter_num, max_iter, gamma=10, power=0.75):
    """Polynomial learning-rate decay: lr = lr0 * (1 + gamma*t/T)^(-power).

    Also (re)pins weight_decay/momentum/nesterov on every call, so the
    optimizer can be constructed with bare param groups.
    """
    decay = (1 + gamma * iter_num / max_iter) ** (-power)
    for group in optimizer.param_groups:
        group['lr'] = group['lr0'] * decay
        group['weight_decay'] = 1e-3
        group['momentum'] = 0.9
        group['nesterov'] = True
    return optimizer

def _input_normalize(alexnet):
    """Normalization step shared by the train/test pipelines."""
    if alexnet:
        # NOTE(review): `Normalize` is not defined/imported in this module, so
        # the alexnet branch would raise NameError — confirm upstream.
        return Normalize(meanfile='./ilsvrc_2012_mean.npy')
    return transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225])

def image_train(resize_size=256, crop_size=224, alexnet=False):
    """Training-time augmentation: resize, random crop, random flip, normalize."""
    return transforms.Compose([
        transforms.Resize((resize_size, resize_size)),
        transforms.RandomCrop(crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        _input_normalize(alexnet),
    ])

def image_test(resize_size=256, crop_size=224, alexnet=False):
    """Evaluation-time preprocessing: resize, deterministic center crop, normalize."""
    return transforms.Compose([
        transforms.Resize((resize_size, resize_size)),
        transforms.CenterCrop(crop_size),
        transforms.ToTensor(),
        _input_normalize(alexnet),
    ])

def data_load(args):
    """Build target/test/valid datasets and their DataLoaders.

    For partial DA (args.da != 'uda') the target list is filtered to
    args.tar_classes and labels are remapped into the compact source index
    space; labels outside args.src_classes map to the extra "unknown" index.
    Returns a dict with 'target' (shuffled), 'test' and 'valid' loaders.
    """
    train_bs = args.batch_size
    with open(args.t_dset_path) as fh:
        txt_tar = fh.readlines()
    with open(args.test_dset_path) as fh:
        txt_test = fh.readlines()

    if not args.da == 'uda':
        # original class id -> compact 0..K-1 index over the source classes
        label_map_s = {c: i for i, c in enumerate(args.src_classes)}
        unknown_idx = len(label_map_s)
        remapped = []
        for rec in txt_tar:
            fields = rec.strip().split(' ')
            lbl = int(fields[1])
            if lbl in args.tar_classes:
                # .get() falls back to the "unknown" index for non-source classes
                remapped.append(fields[0] + ' ' + str(label_map_s.get(lbl, unknown_idx)) + '\n')
        txt_tar = remapped
        txt_test = list(txt_tar)

    root = "../data/{}/".format(args.dset)
    dsets = {}
    dset_loaders = {}
    dsets["target"] = ImageList_idx(txt_tar, root=root, transform=image_train())
    dset_loaders["target"] = DataLoader(dsets["target"], batch_size=train_bs, shuffle=True,
                                        num_workers=args.worker, drop_last=False)
    dsets["test"] = ImageList_idx(txt_test, root=root, transform=image_test())
    dset_loaders["test"] = DataLoader(dsets["test"], batch_size=train_bs * 3, shuffle=False,
                                      num_workers=args.worker, drop_last=False)
    # 'valid' iterates the target list deterministically for pseudo-label refinement
    dsets["valid"] = ImageList_idx(txt_tar, root=root,
                                   transform=image_train() if args.use_train_transform else image_test())
    dset_loaders["valid"] = DataLoader(dsets["valid"], batch_size=train_bs * 3, shuffle=False,
                                       num_workers=args.worker, drop_last=False)

    return dset_loaders
def cal_acc(loader, netF, netB, netC):
    """Evaluate the netF->netB->netC pipeline over `loader`.

    Returns (accuracy %, mean prediction entropy, per-class accuracy %,
    per-class accuracy string). NOTE: reads the module-level `args` for the
    `pda` special case, where per-class metrics are blanked.
    """
    start_test = True
    with torch.no_grad():
        iter_test = iter(loader)
        for i in range(len(loader)):
            # fix: `iterator.next()` is Python-2 / legacy-PyTorch API; use next()
            data = next(iter_test)
            inputs = data[0]
            labels = data[1]
            inputs = inputs.cuda()
            outputs = netC(netB(netF(inputs)))
            if start_test:
                all_output = outputs.float().cpu()
                all_label = labels.float()
                start_test = False
            else:
                all_output = torch.cat((all_output, outputs.float().cpu()), 0)
                all_label = torch.cat((all_label, labels.float()), 0)
    _, predict = torch.max(all_output, 1)
    acc = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0]) * 100
    mean_ent = torch.mean(loss.Entropy(nn.Softmax(dim=1)(all_output))).cpu().data.item()

    matrix = confusion_matrix(all_label, torch.squeeze(predict).float())
    acc_list = matrix.diagonal() / matrix.sum(axis=1) * 100
    per_class_acc = acc_list.mean()
    if args.da == 'pda':
        # partial DA: per-class stats are not meaningful here, blank them out
        acc_list = ''
        per_class_acc = 0
    acc_list = ' '.join([str(np.round(i, 2)) for i in acc_list])
    return acc, mean_ent, per_class_acc, acc_list


def train_target(args):
    """kSHOT target-domain adaptation loop.

    Loads the pretrained source model (F/B/C), freezes the classifier head C,
    builds the prior-knowledge solver from target label statistics, then
    alternates pseudo-label refinement (obtain_label) with self-training on
    pseudo-label cross-entropy plus information-maximization losses.
    Saves the adapted F/B/C checkpoints when args.issave is set.
    """
    dset_loaders = data_load(args)
    ## set base network
    if args.net[0:3] == 'res':
        netF = network.ResBase(res_name=args.net).cuda()
    elif args.net[0:3] == 'vgg':
        netF = network.VGGBase(vgg_name=args.net).cuda()

    netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features, bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer, class_num=args.class_num, bottleneck_dim=args.bottleneck).cuda()

    modelpath = osp.join(args.output_dir_src, "{}_{}_source_F.pt".format(args.s, args.net))
    netF.load_state_dict(torch.load(modelpath))
    modelpath = osp.join(args.output_dir_src, "{}_{}_source_B.pt".format(args.s, args.net))
    netB.load_state_dict(torch.load(modelpath))
    modelpath = osp.join(args.output_dir_src, "{}_{}_source_C.pt".format(args.s, args.net))
    netC.load_state_dict(torch.load(modelpath))

    # the source classifier head stays frozen during target adaptation
    netC.eval()
    for k, v in netC.named_parameters():
        v.requires_grad = False

    param_group = []
    for k, v in netF.named_parameters():
        if args.lr_decay1 > 0:
            param_group += [{'params': v, 'lr': args.lr * args.lr_decay1}]
        else:
            v.requires_grad = False
    for k, v in netB.named_parameters():
        if args.lr_decay2 > 0:
            param_group += [{'params': v, 'lr': args.lr * args.lr_decay2}]
        else:
            v.requires_grad = False

    optimizer = optim.SGD(param_group)
    optimizer = op_copy(optimizer)

    max_iter = args.max_epoch * len(dset_loaders["target"])
    interval_iter = max_iter // args.interval
    iter_num = 0

    # get ground-truth label probabilities of target domain
    start = True
    iter_valid = iter(dset_loaders['valid'])
    for _ in range(len(dset_loaders['valid'])):
        # fix: `iterator.next()` is Python-2 / legacy-PyTorch API; use next()
        data = next(iter_valid)
        labels = data[1]
        if start:
            all_label = labels.long()
            start = False
        else:
            all_label = torch.cat((all_label, labels.long()), 0)

    cls_probs = torch.eye(args.class_num)[all_label].sum(0)
    cls_probs = cls_probs / cls_probs.sum()

    if args.pk_dratio < 1.0:
        # estimate the class prior from a random subsample instead of all labels
        ND = int(len(all_label) * args.pk_dratio)
        cls_probs_sample = torch.eye(args.class_num)[all_label[torch.randint(len(all_label), (ND,))]].sum(0)
        cls_probs_sample = cls_probs_sample / cls_probs_sample.sum()
        err = (cls_probs_sample - cls_probs) / cls_probs
        logging.info('True probs: {}'.format(cls_probs))
        logging.info('Sample probs: {}'.format(cls_probs_sample))
        logging.info('Probs err: {}, max err: {}, mean err{}'.format(err, err.abs().max(), err.abs().mean()))
        cls_probs = cls_probs_sample

    # build the prior-knowledge constraint set for the chosen pk_type
    pk_solver = PK_solver(all_label.shape[0], args.class_num, pk_prior_weight=args.pk_prior_weight)
    if args.pk_type == 'ub':
        pk_solver.create_C_ub(cls_probs.cpu().numpy(), args.pk_uconf)
    elif args.pk_type == 'br':
        pk_solver.create_C_br(cls_probs.cpu().numpy(), args.pk_uconf)
    elif args.pk_type == 'ub+rel':
        pk_solver.create_C_ub(cls_probs.cpu().numpy(), args.pk_uconf)
        pk_solver.create_C_br(cls_probs.cpu().numpy(), 1.0)
    elif args.pk_type == 'ub_partial':
        pk_solver.create_C_ub_partial(cls_probs.cpu().numpy(), args.pk_uconf, args.pk_NC)
    elif args.pk_type == 'ub_partial_reverse':
        pk_solver.create_C_ub_partial_reverse(cls_probs.cpu().numpy(), args.pk_uconf, args.pk_NC)
    elif args.pk_type == 'ub_partial_rand':
        pk_solver.create_C_ub_partial_rand(cls_probs.cpu().numpy(), args.pk_uconf, args.pk_NC)
    elif args.pk_type == 'br_partial':
        pk_solver.create_C_br_partial(cls_probs.cpu().numpy(), args.pk_uconf, args.pk_NC)
    elif args.pk_type == 'br_partial_reverse':
        pk_solver.create_C_br_partial_reverse(cls_probs.cpu().numpy(), args.pk_uconf, args.pk_NC)
    elif args.pk_type == 'br_partial_rand':
        pk_solver.create_C_br_partial_rand(cls_probs.cpu().numpy(), args.pk_uconf, args.pk_NC)
    elif args.pk_type == 'ub_noisy':
        pk_solver.create_C_ub_noisy(cls_probs.cpu().numpy(), args.pk_uconf, args.pk_noise)
    elif args.pk_type == 'br_noisy':
        pk_solver.create_C_br_noisy(cls_probs.cpu().numpy(), args.pk_uconf, args.pk_noise)

    epoch = 0
    while iter_num < max_iter:
        try:
            inputs_test, _, tar_idx = next(iter_test)
        # fix: narrowed the bare `except:`. NameError fires on the very first
        # pass (iter_test not yet created); StopIteration when an epoch ends.
        except (StopIteration, NameError):
            iter_test = iter(dset_loaders["target"])
            inputs_test, _, tar_idx = next(iter_test)

        if inputs_test.size(0) == 1:
            # BatchNorm in netB cannot handle a batch of one sample
            continue

        if iter_num % interval_iter == 0 and args.cls_par > 0:
            # periodically refresh pseudo-labels with the PK-constrained solver
            netF.eval()
            netB.eval()
            mem_label = obtain_label(dset_loaders['valid'], netF, netB, netC, args, pk_solver, epoch)
            mem_label = torch.from_numpy(mem_label).cuda()
            netF.train()
            netB.train()
            epoch += 1

        inputs_test = inputs_test.cuda()

        iter_num += 1
        lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter)

        features_test = netB(netF(inputs_test))
        outputs_test = netC(features_test)

        if args.cls_par > 0:
            pred = mem_label[tar_idx]
            classifier_loss = nn.CrossEntropyLoss()(outputs_test, pred)
            classifier_loss *= args.cls_par
            if iter_num < interval_iter and args.dset == "visda-2017":
                # warm-up on visda: skip the pseudo-label loss for the first interval
                classifier_loss *= 0
        else:
            classifier_loss = torch.tensor(0.0).cuda()

        if args.ent:
            # information maximization: minimize per-sample entropy, and
            # (optionally) maximize the entropy of the mean prediction
            softmax_out = nn.Softmax(dim=1)(outputs_test)
            entropy_loss = torch.mean(loss.Entropy(softmax_out))
            if args.gent:
                msoftmax = softmax_out.mean(dim=0)
                gentropy_loss = torch.sum(-msoftmax * torch.log(msoftmax + args.epsilon))
                entropy_loss -= gentropy_loss
            im_loss = entropy_loss * args.ent_par
            classifier_loss += im_loss

        optimizer.zero_grad()
        classifier_loss.backward()
        optimizer.step()

        if iter_num % interval_iter == 0 or iter_num == max_iter:
            netF.eval()
            netB.eval()

            if args.dset == 'visda-2017':
                acc, _, per_class_acc, acc_list = cal_acc(dset_loaders['test'], netF, netB, netC)
                log_str = 'Task: {}->{}, Iter:{}/{}; Accuracy = {:.2f}% Per_class_accuracy={:.2f}'.format(args.s, args.t, iter_num, max_iter, acc, per_class_acc) + '\n' + acc_list
            else:
                acc, _, per_class_acc, acc_list = cal_acc(dset_loaders['test'], netF, netB, netC)
                log_str = 'Task: {}->{}, Iter:{}/{}; Accuracy = {:.2f}% Per_class_accuracy={:.2f}'.format(args.s, args.t, iter_num, max_iter, acc, per_class_acc)

            logging.info(log_str)
            netF.train()
            netB.train()

    if args.issave:
        torch.save(netF.state_dict(), osp.join(args.output_dir, "{}_{}_{}_{}_target_F_".format(args.timestamp, args.s, args.t, args.net) + args.savename + ".pt"))
        torch.save(netB.state_dict(), osp.join(args.output_dir, "{}_{}_{}_{}_target_B_".format(args.timestamp, args.s, args.t, args.net) + args.savename + ".pt"))
        torch.save(netC.state_dict(), osp.join(args.output_dir, "{}_{}_{}_{}_target_C_".format(args.timestamp, args.s, args.t, args.net) + args.savename + ".pt"))

    return netF, netB, netC


def print_args(args):
    """Render every attribute of the argparse namespace as 'name:value' lines."""
    s = "==========================================\n"
    for arg, content in args.__dict__.items():
        s += "{}:{}\n".format(arg, content)
    return s
def obtain_label(loader, netF, netB, netC, args, pk_solver, epoch):
    """Compute refined pseudo-labels for the whole target set.

    Pipeline: (1) forward the entire loader to collect features and softmax
    outputs, (2) SHOT-style weighted-centroid clustering in feature space,
    (3) refine the cluster assignment with the prior-knowledge solver, and
    (4) optionally re-solve with a KNN smoothness constraint on the samples
    where PK refinement disagreed with clustering. Accuracy at each stage is
    logged against the ground-truth labels. Returns an int array of labels.
    """
    start_test = True
    with torch.no_grad():
        iter_test = iter(loader)
        for _ in range(len(loader)):
            # fix: `iterator.next()` is Python-2 / legacy-PyTorch API; use next()
            data = next(iter_test)
            inputs = data[0]
            labels = data[1]
            inputs = inputs.cuda()
            feas = netB(netF(inputs))
            outputs = netC(feas)
            if start_test:
                all_fea = feas.float().cpu()
                all_output = outputs.float().cpu()
                all_label = labels.float()
                start_test = False
            else:
                all_fea = torch.cat((all_fea, feas.float().cpu()), 0)
                all_output = torch.cat((all_output, outputs.float().cpu()), 0)
                all_label = torch.cat((all_label, labels.float()), 0)

    all_output = nn.Softmax(dim=1)(all_output)
    ent = torch.sum(-all_output * torch.log(all_output + args.epsilon), dim=1)
    unknown_weight = 1 - ent / np.log(args.class_num)  # NOTE(review): unused — kept as-is
    _, predict = torch.max(all_output, 1)

    accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0])
    matrix = confusion_matrix(all_label, torch.squeeze(predict).float())
    avg_accuracy = (matrix.diagonal() / matrix.sum(axis=1)).mean()
    if args.da == 'pda':
        avg_accuracy = 0

    if args.distance == 'cosine':
        # append a constant 1 feature and L2-normalize so cosine distance is well-behaved
        all_fea = torch.cat((all_fea, torch.ones(all_fea.size(0), 1)), 1)
        all_fea = (all_fea.t() / torch.norm(all_fea, p=2, dim=1)).t()

    all_fea = all_fea.float().cpu().numpy()
    K = all_output.size(1)
    aff = all_output.float().cpu().numpy()
    # class centroids weighted by softmax responsibility
    initc = aff.transpose().dot(all_fea)
    initc = initc / (1e-8 + aff.sum(axis=0)[:, None])

    dd = cdist(all_fea, initc, args.distance)
    dd[np.isnan(dd)] = np.inf
    pred_label = dd.argmin(axis=1)

    # fix: loop variable renamed from `round` (shadowed the builtin)
    for _pass in range(1):
        # one hard-assignment refinement pass: recompute centroids from labels
        aff = np.eye(K)[pred_label]
        initc = aff.transpose().dot(all_fea)
        initc = initc / (1e-8 + aff.sum(axis=0)[:, None])
        dd = cdist(all_fea, initc, args.distance)
        dd[np.isnan(dd)] = np.inf
        pred_label = dd.argmin(axis=1)

    acc = np.sum(pred_label == all_label.float().numpy()) / len(all_fea)
    matrix = confusion_matrix(all_label.float().numpy(), pred_label)
    avg_acc = (matrix.diagonal() / matrix.sum(axis=1)).mean()
    if args.da == 'pda':
        acc_list = ''
        avg_acc = 0
    log_str = 'Accuracy = {:.2f}% -> {:.2f}% Per_class_accuracy = {:.2f}% -> {:.2f}%'.format(accuracy * 100, acc * 100, avg_accuracy * 100, avg_acc * 100)
    logging.info(log_str)

    # update labels with prior knowledge
    T = args.pk_temp
    probs = np.exp(-dd / T)
    probs = probs / probs.sum(1, keepdims=True)
    # first solve without smooth regularization
    pred_label_PK = pk_solver.solve_soft(probs)

    acc_PK = np.sum(pred_label_PK == all_label.float().numpy()) / len(all_fea)
    matrix_PK = confusion_matrix(all_label.float().numpy(), pred_label_PK)
    avg_acc_PK = (matrix_PK.diagonal() / matrix_PK.sum(axis=1)).mean()
    if args.da == 'pda':
        avg_acc_PK = 0
    log_str = 'PK Accuracy = {:.2f}% -> {:.2f}% Per_class_accuracy = {:.2f}% -> {:.2f}%'.format(acc * 100, acc_PK * 100, avg_acc * 100, avg_acc_PK * 100)
    logging.info(log_str)

    # now solve with smooth regularization
    if args.pk_knn > 0:
        # "unconfident" = samples where the PK solution disagrees with clustering
        idx_unconf = np.where(pred_label_PK != pred_label)[0]
        knn_sample_idx = idx_unconf
        idx_conf = np.where(pred_label_PK == pred_label)[0]

        if len(idx_unconf) > 0 and len(idx_conf) > 0:
            # get knn of each samples
            dd_knn = cdist(all_fea[idx_unconf], all_fea[idx_conf], args.distance)
            knn_idx = []
            K = args.pk_knn
            for i in range(dd_knn.shape[0]):
                ind = np.argpartition(dd_knn[i], K)[:K]
                knn_idx.append(idx_conf[ind])

            knn_idx = np.stack(knn_idx, axis=0)
            knn_regs = list(zip(knn_sample_idx, knn_idx))
            pred_label_PK = pk_solver.solve_soft_knn_cst(probs, knn_regs=knn_regs)

        acc_PK = np.sum(pred_label_PK == all_label.float().numpy()) / len(all_fea)
        matrix_PK = confusion_matrix(all_label.float().numpy(), pred_label_PK)
        avg_acc_PK = (matrix_PK.diagonal() / matrix_PK.sum(axis=1)).mean()
        if args.da == 'pda':
            avg_acc_PK = 0
        log_str = 'PK Accuracy = {:.2f}% -> {:.2f}% Per_class_accuracy = {:.2f}% -> {:.2f}%'.format(acc * 100, acc_PK * 100, avg_acc * 100, avg_acc_PK * 100)
        logging.info(log_str)

    return pred_label_PK.astype('int')
parser.add_argument('--seed', type=int, default=2020, help="random seed") 432 | 433 | parser.add_argument('--gent', type=bool, default=True) 434 | parser.add_argument('--ent', type=bool, default=True) 435 | parser.add_argument('--threshold', type=int, default=0) 436 | parser.add_argument('--cls_par', type=float, default=0.3) 437 | parser.add_argument('--ent_par', type=float, default=1.0) 438 | parser.add_argument('--lr_decay1', type=float, default=0.1) 439 | parser.add_argument('--lr_decay2', type=float, default=1.0) 440 | 441 | parser.add_argument('--bottleneck', type=int, default=256) 442 | parser.add_argument('--epsilon', type=float, default=1e-5) 443 | parser.add_argument('--layer', type=str, default="wn", choices=["linear", "wn"]) 444 | parser.add_argument('--classifier', type=str, default="bn", choices=["ori", "bn"]) 445 | parser.add_argument('--distance', type=str, default='cosine', choices=["euclidean", "cosine"]) 446 | parser.add_argument('--output', type=str, default='san') 447 | parser.add_argument('--output_src', type=str, default='san') 448 | parser.add_argument('--da', type=str, default='uda', choices=['uda', 'pda']) 449 | parser.add_argument('--issave', type=bool, default=True) 450 | 451 | parser.add_argument('--timestamp', default=timestamp, type=str, help='timestamp') 452 | parser.add_argument('--use_file_logger', default='True', type=lambda x: bool(distutils.util.strtobool(x)), 453 | help='whether use file logger') 454 | parser.add_argument('--names', default=[], type=list, help='names of tasks') 455 | parser.add_argument('--use_train_transform', default='False', type=lambda x: bool(distutils.util.strtobool(x)), 456 | help='whether use train transform for label refinement') 457 | 458 | parser.add_argument('--pk_uconf', type=float, default=0.0) 459 | parser.add_argument('--pk_type', type=str, default="ub") 460 | parser.add_argument('--pk_allow', type=int, default=None) 461 | parser.add_argument('--pk_temp', type=float, default=1.0) 462 | 
parser.add_argument('--pk_prior_weight', type=float, default=10.) 463 | parser.add_argument('--pk_knn', type=int, default=1) 464 | parser.add_argument('--pk_NC', type=int, default=None) 465 | parser.add_argument('--pk_noise', type=float, default=0.0) 466 | parser.add_argument('--pk_dratio', type=float, default=1.0) 467 | parser.add_argument('--method', type=str, default="kshot") 468 | 469 | args = parser.parse_args() 470 | 471 | if args.dset == 'office-home': 472 | args.names = ['Art', 'Clipart', 'Product', 'Real_World'] 473 | args.class_num = 65 474 | if args.dset == 'office-home-rsut': 475 | args.names = ['Clipart', 'Product', 'Real_World'] 476 | args.class_num = 65 477 | if args.dset == 'domainnet40': 478 | args.names = ['sketch', 'clipart', 'painting', 'real'] 479 | args.class_num = 40 480 | if args.dset == 'office31': 481 | args.names = ['webcam', 'amazon', 'dslr'] 482 | args.class_num = 31 483 | if args.dset == 'visda-2017': 484 | args.names = ['train', 'validation'] 485 | args.class_num = 12 486 | if args.dset == 'office-caltech': 487 | args.names = ['amazon', 'caltech', 'dslr', 'webcam'] 488 | args.class_num = 10 489 | 490 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id 491 | resetRNGseed(args.seed) 492 | 493 | if args.dset == 'office-home-rsut': 494 | args.s += '_RS' 495 | 496 | dir = "{}_{}_{}_{}".format(args.timestamp, args.s, args.da, args.method) 497 | if args.use_file_logger: 498 | init_logger(dir, True, '../logs/SHOT/{}/'.format(args.method)) 499 | logging.info("{}:{}".format(get_hostname(), get_pid())) 500 | 501 | for t in args.names: 502 | if t == args.s or t == args.s.split('_RS')[0]: 503 | continue 504 | args.t = t 505 | 506 | if args.dset == 'office-home-rsut': 507 | args.t += '_UT' 508 | 509 | folder = '../data/' 510 | args.s_dset_path = folder + args.dset + '/image_list/' + args.s + '.txt' 511 | args.t_dset_path = folder + args.dset + '/image_list/' + args.t + '.txt' 512 | args.test_dset_path = folder + args.dset + '/image_list/' + args.t + 
'.txt' 513 | 514 | if args.dset == 'domainnet40': 515 | args.s_dset_path = folder + args.dset + '/image_list/' + args.s + '_train_mini.txt' 516 | args.t_dset_path = folder + args.dset + '/image_list/' + args.t + '_train_mini.txt' 517 | args.test_dset_path = folder + args.dset + '/image_list/' + args.t + '_test_mini.txt' 518 | 519 | if args.dset == 'office-home': 520 | if args.da == 'pda': 521 | args.class_num = 65 522 | args.src_classes = [i for i in range(65)] 523 | args.tar_classes = [33, 32, 36, 15, 19, 2, 46, 49, 48, 53, 47, 54, 4, 18, 57, 23, 0, 45, 1, 38, 5, 13, 50, 11, 58] 524 | 525 | args.output_dir_src = "../checkpoints/SHOT/source/{}/".format(args.da) 526 | args.output_dir = "../checkpoints/SHOT/target_{}/".format(args.method) 527 | 528 | if not osp.exists(args.output_dir): 529 | os.system('mkdir -p ' + args.output_dir) 530 | if not osp.exists(args.output_dir): 531 | os.mkdir(args.output_dir) 532 | 533 | args.savename = 'par_' + str(args.cls_par) 534 | if args.da == 'pda': 535 | args.gent = '' 536 | args.savename = 'par_' + str(args.cls_par) + '_thr' + str(args.threshold) 537 | 538 | logging.info(print_args(args)) 539 | train_target(args) -------------------------------------------------------------------------------- /SHOT/loss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | import math 6 | import torch.nn.functional as F 7 | import pdb 8 | 9 | def PK_loss(PK_solver, prob): 10 | pk_loss = 0.0 11 | N = PK_solver.N 12 | if PK_solver.C_abs is not None and len(PK_solver.C_abs)>0: 13 | for (c, lb, ub) in PK_solver.C_abs: 14 | if lb is not None: 15 | pk_loss += torch.maximum(lb/N-prob[c], torch.tensor(0.)) 16 | if ub is not None: 17 | pk_loss += torch.maximum(-ub/N+prob[c], torch.tensor(0.)) 18 | 19 | if PK_solver.C_rel is not None and len(PK_solver.C_rel)>0: 20 | for (c1, c2, diff) in PK_solver.C_rel: 21 | pk_loss 
+= torch.maximum(diff-prob[c1]+prob[c2], torch.tensor(0.)) 22 | 23 | return pk_loss 24 | 25 | def Entropy(input_): 26 | bs = input_.size(0) 27 | epsilon = 1e-5 28 | entropy = -input_ * torch.log(input_ + epsilon) 29 | entropy = torch.sum(entropy, dim=1) 30 | return entropy 31 | 32 | def grl_hook(coeff): 33 | def fun1(grad): 34 | return -coeff*grad.clone() 35 | return fun1 36 | 37 | def CDAN(input_list, ad_net, entropy=None, coeff=None, random_layer=None): 38 | softmax_output = input_list[1].detach() 39 | feature = input_list[0] 40 | if random_layer is None: 41 | op_out = torch.bmm(softmax_output.unsqueeze(2), feature.unsqueeze(1)) 42 | ad_out = ad_net(op_out.view(-1, softmax_output.size(1) * feature.size(1))) 43 | else: 44 | random_out = random_layer.forward([feature, softmax_output]) 45 | ad_out = ad_net(random_out.view(-1, random_out.size(1))) 46 | batch_size = softmax_output.size(0) // 2 47 | dc_target = torch.from_numpy(np.array([[1]] * batch_size + [[0]] * batch_size)).float().cuda() 48 | if entropy is not None: 49 | entropy.register_hook(grl_hook(coeff)) 50 | entropy = 1.0+torch.exp(-entropy) 51 | source_mask = torch.ones_like(entropy) 52 | source_mask[feature.size(0)//2:] = 0 53 | source_weight = entropy*source_mask 54 | target_mask = torch.ones_like(entropy) 55 | target_mask[0:feature.size(0)//2] = 0 56 | target_weight = entropy*target_mask 57 | weight = source_weight / torch.sum(source_weight).detach().item() + \ 58 | target_weight / torch.sum(target_weight).detach().item() 59 | return torch.sum(weight.view(-1, 1) * nn.BCELoss(reduction='none')(ad_out, dc_target)) / torch.sum(weight).detach().item() 60 | else: 61 | return nn.BCELoss()(ad_out, dc_target) 62 | 63 | def DANN(features, ad_net): 64 | ad_out = ad_net(features) 65 | batch_size = ad_out.size(0) // 2 66 | dc_target = torch.from_numpy(np.array([[1]] * batch_size + [[0]] * batch_size)).float().cuda() 67 | return nn.BCELoss()(ad_out, dc_target) 68 | 69 | 70 | class CrossEntropyLabelSmooth(nn.Module): 
71 | """Cross entropy loss with label smoothing regularizer. 72 | Reference: 73 | Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016. 74 | Equation: y = (1 - epsilon) * y + epsilon / K. 75 | Args: 76 | num_classes (int): number of classes. 77 | epsilon (float): weight. 78 | """ 79 | 80 | def __init__(self, num_classes, epsilon=0.1, use_gpu=True, reduction=True): 81 | super(CrossEntropyLabelSmooth, self).__init__() 82 | self.num_classes = num_classes 83 | self.epsilon = epsilon 84 | self.use_gpu = use_gpu 85 | self.reduction = reduction 86 | self.logsoftmax = nn.LogSoftmax(dim=1) 87 | 88 | def forward(self, inputs, targets): 89 | """ 90 | Args: 91 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes) 92 | targets: ground truth labels with shape (num_classes) 93 | """ 94 | log_probs = self.logsoftmax(inputs) 95 | targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).cpu(), 1) 96 | if self.use_gpu: targets = targets.cuda() 97 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes 98 | loss = (- targets * log_probs).sum(dim=1) 99 | if self.reduction: 100 | return loss.mean() 101 | else: 102 | return loss 103 | return loss -------------------------------------------------------------------------------- /SHOT/network.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torchvision 5 | from torchvision import models 6 | from torch.autograd import Variable 7 | import math 8 | import torch.nn.utils.weight_norm as weightNorm 9 | from collections import OrderedDict 10 | 11 | def calc_coeff(iter_num, high=1.0, low=0.0, alpha=10.0, max_iter=10000.0): 12 | return np.float(2.0 * (high - low) / (1.0 + np.exp(-alpha*iter_num / max_iter)) - (high - low) + low) 13 | 14 | def init_weights(m): 15 | classname = m.__class__.__name__ 16 | if classname.find('Conv2d') != -1 or 
def init_weights(m):
    """Per-layer-type weight init, applied via Module.apply():
    Kaiming-uniform for conv layers, N(1, 0.02) for batch norm scales,
    Xavier-normal for linear layers; biases zeroed in all cases."""
    classname = m.__class__.__name__
    if classname.find('Conv2d') != -1 or classname.find('ConvTranspose2d') != -1:
        nn.init.kaiming_uniform_(m.weight)
        nn.init.zeros_(m.bias)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight, 1.0, 0.02)
        nn.init.zeros_(m.bias)
    elif classname.find('Linear') != -1:
        nn.init.xavier_normal_(m.weight)
        nn.init.zeros_(m.bias)

# Name -> constructor table for the supported torchvision VGG variants.
vgg_dict = {"vgg11":models.vgg11, "vgg13":models.vgg13, "vgg16":models.vgg16, "vgg19":models.vgg19,
"vgg11bn":models.vgg11_bn, "vgg13bn":models.vgg13_bn, "vgg16bn":models.vgg16_bn, "vgg19bn":models.vgg19_bn}
class VGGBase(nn.Module):
    """ImageNet-pretrained VGG feature extractor: keeps the conv features and
    classifier layers 0..5, dropping the final (classifier[6]) fc layer.
    `in_features` records that dropped layer's input width for downstream heads."""
    def __init__(self, vgg_name):
        super(VGGBase, self).__init__()
        model_vgg = vgg_dict[vgg_name](pretrained=True)
        self.features = model_vgg.features
        self.classifier = nn.Sequential()
        # copy classifier[0..5]; classifier[6] (the 1000-way fc) is omitted
        for i in range(6):
            self.classifier.add_module("classifier"+str(i), model_vgg.classifier[i])
        self.in_features = model_vgg.classifier[6].in_features

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)  # flatten before the fc stack
        x = self.classifier(x)
        return x

# Name -> constructor table for the supported torchvision ResNet/ResNeXt variants.
res_dict = {"resnet18":models.resnet18, "resnet34":models.resnet34, "resnet50":models.resnet50,
"resnet101":models.resnet101, "resnet152":models.resnet152, "resnext50":models.resnext50_32x4d, "resnext101":models.resnext101_32x8d}

class ResBase(nn.Module):
    """ImageNet-pretrained ResNet feature extractor (everything up to and
    including global average pooling; the fc head is not used).
    `in_features` records the fc input width for downstream heads."""
    def __init__(self, res_name):
        super(ResBase, self).__init__()
        model_resnet = res_dict[res_name](pretrained=True)
        self.conv1 = model_resnet.conv1
        self.bn1 = model_resnet.bn1
        self.relu = model_resnet.relu
        self.maxpool = model_resnet.maxpool
        self.layer1 = model_resnet.layer1
        self.layer2 = model_resnet.layer2
        self.layer3 = model_resnet.layer3
        self.layer4 = model_resnet.layer4
        self.avgpool = model_resnet.avgpool
        self.in_features = model_resnet.fc.in_features

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)  # flattened pooled features
        return x

class feat_bootleneck(nn.Module):
    """Bottleneck projection from backbone features to `bottleneck_dim`.

    NOTE(review): `relu` and `dropout` are constructed but never applied in
    forward(); only the linear layer and (for type == "bn") the batch norm
    run. The `type` parameter name shadows the builtin but is part of the
    public interface, so it is kept.
    """
    def __init__(self, feature_dim, bottleneck_dim=256, type="ori"):
        super(feat_bootleneck, self).__init__()
        self.bn = nn.BatchNorm1d(bottleneck_dim, affine=True)
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(p=0.5)
        self.bottleneck = nn.Linear(feature_dim, bottleneck_dim)
        self.bottleneck.apply(init_weights)
        self.type = type

    def forward(self, x):
        x = self.bottleneck(x)
        if self.type == "bn":
            x = self.bn(x)
        return x

class feat_classifier(nn.Module):
    """Linear classification head; type == 'wn' wraps the fc layer in weight
    normalization (as used by SHOT), anything else is a plain Linear."""
    def __init__(self, class_num, bottleneck_dim=256, type="linear"):
        super(feat_classifier, self).__init__()
        self.type = type
        if type == 'wn':
            self.fc = weightNorm(nn.Linear(bottleneck_dim, class_num), name="weight")
            self.fc.apply(init_weights)
        else:
            self.fc = nn.Linear(bottleneck_dim, class_num)
            self.fc.apply(init_weights)

    def forward(self, x):
        x = self.fc(x)
        return x

class feat_classifier_two(nn.Module):
    """Two-layer classification head (input_dim -> bottleneck_dim -> class_num)
    with no nonlinearity between the layers (as written)."""
    def __init__(self, class_num, input_dim, bottleneck_dim=256):
        super(feat_classifier_two, self).__init__()
        self.type = type  # NOTE(review): stores the builtin `type`; looks unintentional but is unused
        self.fc0 = nn.Linear(input_dim, bottleneck_dim)
        self.fc0.apply(init_weights)
        self.fc1 = nn.Linear(bottleneck_dim, class_num)
        self.fc1.apply(init_weights)

    def forward(self, x):
        x = self.fc0(x)
        x = self.fc1(x)
        return x
model_resnet.layer1 129 | self.layer2 = model_resnet.layer2 130 | self.layer3 = model_resnet.layer3 131 | self.layer4 = model_resnet.layer4 132 | self.avgpool = model_resnet.avgpool 133 | self.in_features = model_resnet.fc.in_features 134 | self.fc = model_resnet.fc 135 | 136 | def forward(self, x): 137 | x = self.conv1(x) 138 | x = self.bn1(x) 139 | x = self.relu(x) 140 | x = self.maxpool(x) 141 | x = self.layer1(x) 142 | x = self.layer2(x) 143 | x = self.layer3(x) 144 | x = self.layer4(x) 145 | x = self.avgpool(x) 146 | x = x.view(x.size(0), -1) 147 | y = self.fc(x) 148 | return x, y -------------------------------------------------------------------------------- /SHOT/run_all_kSHOT.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | gpu_id=0 4 | time=`python ../util/get_time.py` 5 | 6 | # office31 ------------------------------------------------------------------------------------------------------------- 7 | for src in "amazon" "webcam" "dslr"; do 8 | echo $src 9 | python image_source.py --trte val --da uda --gpu_id $gpu_id --dset office31 --s $src --max_epoch 100 --timestamp $time 10 | done 11 | 12 | for seed in 2020 2021 2022; do 13 | for src in "amazon" "webcam" "dslr"; do 14 | echo $src 15 | for pk_uconf in 0.0 0.1 0.5 1.0 2.0; do 16 | python image_target_kSHOT.py --cls_par 0.3 --da uda --gpu_id $gpu_id --dset office31 --s $src --timestamp $time --seed $seed --pk_uconf $pk_uconf --pk_type ub 17 | done 18 | python image_target_kSHOT.py --cls_par 0.3 --da uda --gpu_id $gpu_id --dset office31 --s $src --timestamp $time --seed $seed --pk_uconf 1.0 --pk_type br 19 | done 20 | done 21 | 22 | 23 | # office-home-rsut ---------------------------------------------------------------------------------------------------- 24 | for src in "Product" "Clipart" "Real_World"; do 25 | echo $src 26 | python image_source.py --trte val --da uda --gpu_id $gpu_id --dset office-home-rsut --s $src --max_epoch 50 --timestamp 
$time 27 | done 28 | 29 | for seed in 2020 2021 2022; do 30 | for src in "Product" "Clipart" "Real_World"; do 31 | echo $src 32 | for pk_uconf in 0.0 0.1 0.5 1.0 2.0; do 33 | python image_target_kSHOT.py --cls_par 0.3 --da uda --gpu_id $gpu_id --dset office-home-rsut --s $src --timestamp $time --seed $seed --pk_uconf $pk_uconf --pk_type ub 34 | done 35 | python image_target_kSHOT.py --cls_par 0.3 --da uda --gpu_id $gpu_id --dset office-home-rsut --s $src --timestamp $time --seed $seed --pk_uconf 1.0 --pk_type br 36 | done 37 | done 38 | 39 | 40 | # office-home ---------------------------------------------------------------------------------------------------------- 41 | for src in "Product" "Clipart" "Art" "Real_World"; do 42 | echo $src 43 | python image_source.py --trte val --da uda --gpu_id $gpu_id --dset office-home --s $src --max_epoch 50 --timestamp $time 44 | done 45 | 46 | for seed in 2020 2021 2022; do 47 | for src in "Product" "Clipart" "Art" "Real_World"; do 48 | echo $src 49 | for pk_uconf in 0.0 0.1 0.5 1.0 2.0; do 50 | python image_target_kSHOT.py --cls_par 0.3 --da uda --gpu_id $gpu_id --dset office-home --s $src --timestamp $time --seed $seed --pk_uconf $pk_uconf --pk_type ub 51 | done 52 | python image_target_kSHOT.py --cls_par 0.3 --da uda --gpu_id $gpu_id --dset office-home --s $src --timestamp $time --seed $seed --pk_uconf 1.0 --pk_type br 53 | done 54 | done 55 | 56 | 57 | # visda-2017 ----------------------------------------------------------------------------------------------------------- 58 | python image_source.py --trte val --da uda --gpu_id $gpu_id --dset visda-2017 --s train --max_epoch 10 --timestamp $time --net resnet101 --lr 1e-3 59 | 60 | for seed in 2020 2021 2022; do 61 | for pk_uconf in 0.0 0.1 0.5 1.0 2.0; do 62 | python image_target_kSHOT.py --cls_par 0.3 --da uda --gpu_id $gpu_id --dset visda-2017 --s train --timestamp $time --seed $seed --pk_uconf $pk_uconf --net resnet101 --lr 1e-3 --pk_type ub 63 | done 64 | python 
image_target_kSHOT.py --cls_par 0.3 --da uda --gpu_id $gpu_id --dset visda-2017 --s train --timestamp $time --seed $seed --pk_uconf 1.0 --net resnet101 --lr 1e-3 --pk_type br 65 | done 66 | 67 | 68 | # domainnet40 ---------------------------------------------------------------------------------------------------------- 69 | for src in "sketch" "clipart" "painting" "real"; do 70 | echo $src 71 | python image_source.py --trte val --da uda --gpu_id $gpu_id --dset domainnet40 --s $src --max_epoch 50 --timestamp $time 72 | done 73 | 74 | for seed in 2020 2021 2022; do 75 | for src in "sketch" "clipart" "painting" "real"; do 76 | echo $src 77 | for pk_uconf in 0.0 0.1 0.5 1.0 2.0; do 78 | python image_target_kSHOT.py --cls_par 0.3 --da uda --gpu_id $gpu_id --dset domainnet40 --s $src --timestamp $time --seed $seed --pk_uconf $pk_uconf --pk_type ub 79 | done 80 | python image_target_kSHOT.py --cls_par 0.3 --da uda --gpu_id $gpu_id --dset domainnet40 --s $src --timestamp $time --seed $seed --pk_uconf 1.0 --pk_type br 81 | done 82 | done 83 | 84 | 85 | # office-home (PDA)----------------------------------------------------------------------------------------------------- 86 | for src in "Product" "Clipart" "Art" "Real_World"; do 87 | echo $src 88 | python image_source.py --trte val --da pda --gpu_id $gpu_id --dset office-home --s $src --max_epoch 50 --timestamp $time 89 | done 90 | 91 | for seed in 2020 2021 2022; do 92 | for src in "Product" "Clipart" "Art" "Real_World"; do 93 | echo $src 94 | python image_target_kSHOT.py --cls_par 0.3 --da pda --gpu_id $gpu_id --dset office-home --s $src --timestamp $time --seed $seed --pk_uconf 0.0 --pk_type ub 95 | python image_target_kSHOT.py --cls_par 0.3 --da pda --gpu_id $gpu_id --dset office-home --s $src --timestamp $time --seed $seed --pk_uconf 1.0 --pk_type br 96 | done 97 | done -------------------------------------------------------------------------------- /data/office31/image_list/dslr.txt: 
-------------------------------------------------------------------------------- 1 | dslr/images/calculator/frame_0001.jpg 5 2 | dslr/images/calculator/frame_0002.jpg 5 3 | dslr/images/calculator/frame_0003.jpg 5 4 | dslr/images/calculator/frame_0004.jpg 5 5 | dslr/images/calculator/frame_0005.jpg 5 6 | dslr/images/calculator/frame_0006.jpg 5 7 | dslr/images/calculator/frame_0007.jpg 5 8 | dslr/images/calculator/frame_0008.jpg 5 9 | dslr/images/calculator/frame_0009.jpg 5 10 | dslr/images/calculator/frame_0010.jpg 5 11 | dslr/images/calculator/frame_0011.jpg 5 12 | dslr/images/calculator/frame_0012.jpg 5 13 | dslr/images/ring_binder/frame_0001.jpg 24 14 | dslr/images/ring_binder/frame_0002.jpg 24 15 | dslr/images/ring_binder/frame_0003.jpg 24 16 | dslr/images/ring_binder/frame_0004.jpg 24 17 | dslr/images/ring_binder/frame_0005.jpg 24 18 | dslr/images/ring_binder/frame_0006.jpg 24 19 | dslr/images/ring_binder/frame_0007.jpg 24 20 | dslr/images/ring_binder/frame_0008.jpg 24 21 | dslr/images/ring_binder/frame_0009.jpg 24 22 | dslr/images/ring_binder/frame_0010.jpg 24 23 | dslr/images/printer/frame_0001.jpg 21 24 | dslr/images/printer/frame_0002.jpg 21 25 | dslr/images/printer/frame_0003.jpg 21 26 | dslr/images/printer/frame_0004.jpg 21 27 | dslr/images/printer/frame_0005.jpg 21 28 | dslr/images/printer/frame_0006.jpg 21 29 | dslr/images/printer/frame_0007.jpg 21 30 | dslr/images/printer/frame_0008.jpg 21 31 | dslr/images/printer/frame_0009.jpg 21 32 | dslr/images/printer/frame_0010.jpg 21 33 | dslr/images/printer/frame_0011.jpg 21 34 | dslr/images/printer/frame_0012.jpg 21 35 | dslr/images/printer/frame_0013.jpg 21 36 | dslr/images/printer/frame_0014.jpg 21 37 | dslr/images/printer/frame_0015.jpg 21 38 | dslr/images/keyboard/frame_0001.jpg 11 39 | dslr/images/keyboard/frame_0002.jpg 11 40 | dslr/images/keyboard/frame_0003.jpg 11 41 | dslr/images/keyboard/frame_0004.jpg 11 42 | dslr/images/keyboard/frame_0005.jpg 11 43 | dslr/images/keyboard/frame_0006.jpg 11 44 | 
dslr/images/keyboard/frame_0007.jpg 11 45 | dslr/images/keyboard/frame_0008.jpg 11 46 | dslr/images/keyboard/frame_0009.jpg 11 47 | dslr/images/keyboard/frame_0010.jpg 11 48 | dslr/images/scissors/frame_0001.jpg 26 49 | dslr/images/scissors/frame_0002.jpg 26 50 | dslr/images/scissors/frame_0003.jpg 26 51 | dslr/images/scissors/frame_0004.jpg 26 52 | dslr/images/scissors/frame_0005.jpg 26 53 | dslr/images/scissors/frame_0006.jpg 26 54 | dslr/images/scissors/frame_0007.jpg 26 55 | dslr/images/scissors/frame_0008.jpg 26 56 | dslr/images/scissors/frame_0009.jpg 26 57 | dslr/images/scissors/frame_0010.jpg 26 58 | dslr/images/scissors/frame_0011.jpg 26 59 | dslr/images/scissors/frame_0012.jpg 26 60 | dslr/images/scissors/frame_0013.jpg 26 61 | dslr/images/scissors/frame_0014.jpg 26 62 | dslr/images/scissors/frame_0015.jpg 26 63 | dslr/images/scissors/frame_0016.jpg 26 64 | dslr/images/scissors/frame_0017.jpg 26 65 | dslr/images/scissors/frame_0018.jpg 26 66 | dslr/images/laptop_computer/frame_0001.jpg 12 67 | dslr/images/laptop_computer/frame_0002.jpg 12 68 | dslr/images/laptop_computer/frame_0003.jpg 12 69 | dslr/images/laptop_computer/frame_0004.jpg 12 70 | dslr/images/laptop_computer/frame_0005.jpg 12 71 | dslr/images/laptop_computer/frame_0006.jpg 12 72 | dslr/images/laptop_computer/frame_0007.jpg 12 73 | dslr/images/laptop_computer/frame_0008.jpg 12 74 | dslr/images/laptop_computer/frame_0009.jpg 12 75 | dslr/images/laptop_computer/frame_0010.jpg 12 76 | dslr/images/laptop_computer/frame_0011.jpg 12 77 | dslr/images/laptop_computer/frame_0012.jpg 12 78 | dslr/images/laptop_computer/frame_0013.jpg 12 79 | dslr/images/laptop_computer/frame_0014.jpg 12 80 | dslr/images/laptop_computer/frame_0015.jpg 12 81 | dslr/images/laptop_computer/frame_0016.jpg 12 82 | dslr/images/laptop_computer/frame_0017.jpg 12 83 | dslr/images/laptop_computer/frame_0018.jpg 12 84 | dslr/images/laptop_computer/frame_0019.jpg 12 85 | dslr/images/laptop_computer/frame_0020.jpg 12 86 | 
dslr/images/laptop_computer/frame_0021.jpg 12 87 | dslr/images/laptop_computer/frame_0022.jpg 12 88 | dslr/images/laptop_computer/frame_0023.jpg 12 89 | dslr/images/laptop_computer/frame_0024.jpg 12 90 | dslr/images/mouse/frame_0001.jpg 16 91 | dslr/images/mouse/frame_0002.jpg 16 92 | dslr/images/mouse/frame_0003.jpg 16 93 | dslr/images/mouse/frame_0004.jpg 16 94 | dslr/images/mouse/frame_0005.jpg 16 95 | dslr/images/mouse/frame_0006.jpg 16 96 | dslr/images/mouse/frame_0007.jpg 16 97 | dslr/images/mouse/frame_0008.jpg 16 98 | dslr/images/mouse/frame_0009.jpg 16 99 | dslr/images/mouse/frame_0010.jpg 16 100 | dslr/images/mouse/frame_0011.jpg 16 101 | dslr/images/mouse/frame_0012.jpg 16 102 | dslr/images/monitor/frame_0001.jpg 15 103 | dslr/images/monitor/frame_0002.jpg 15 104 | dslr/images/monitor/frame_0003.jpg 15 105 | dslr/images/monitor/frame_0004.jpg 15 106 | dslr/images/monitor/frame_0005.jpg 15 107 | dslr/images/monitor/frame_0006.jpg 15 108 | dslr/images/monitor/frame_0007.jpg 15 109 | dslr/images/monitor/frame_0008.jpg 15 110 | dslr/images/monitor/frame_0009.jpg 15 111 | dslr/images/monitor/frame_0010.jpg 15 112 | dslr/images/monitor/frame_0011.jpg 15 113 | dslr/images/monitor/frame_0012.jpg 15 114 | dslr/images/monitor/frame_0013.jpg 15 115 | dslr/images/monitor/frame_0014.jpg 15 116 | dslr/images/monitor/frame_0015.jpg 15 117 | dslr/images/monitor/frame_0016.jpg 15 118 | dslr/images/monitor/frame_0017.jpg 15 119 | dslr/images/monitor/frame_0018.jpg 15 120 | dslr/images/monitor/frame_0019.jpg 15 121 | dslr/images/monitor/frame_0020.jpg 15 122 | dslr/images/monitor/frame_0021.jpg 15 123 | dslr/images/monitor/frame_0022.jpg 15 124 | dslr/images/mug/frame_0001.jpg 17 125 | dslr/images/mug/frame_0002.jpg 17 126 | dslr/images/mug/frame_0003.jpg 17 127 | dslr/images/mug/frame_0004.jpg 17 128 | dslr/images/mug/frame_0005.jpg 17 129 | dslr/images/mug/frame_0006.jpg 17 130 | dslr/images/mug/frame_0007.jpg 17 131 | dslr/images/mug/frame_0008.jpg 17 132 | 
dslr/images/tape_dispenser/frame_0001.jpg 29 133 | dslr/images/tape_dispenser/frame_0002.jpg 29 134 | dslr/images/tape_dispenser/frame_0003.jpg 29 135 | dslr/images/tape_dispenser/frame_0004.jpg 29 136 | dslr/images/tape_dispenser/frame_0005.jpg 29 137 | dslr/images/tape_dispenser/frame_0006.jpg 29 138 | dslr/images/tape_dispenser/frame_0007.jpg 29 139 | dslr/images/tape_dispenser/frame_0008.jpg 29 140 | dslr/images/tape_dispenser/frame_0009.jpg 29 141 | dslr/images/tape_dispenser/frame_0010.jpg 29 142 | dslr/images/tape_dispenser/frame_0011.jpg 29 143 | dslr/images/tape_dispenser/frame_0012.jpg 29 144 | dslr/images/tape_dispenser/frame_0013.jpg 29 145 | dslr/images/tape_dispenser/frame_0014.jpg 29 146 | dslr/images/tape_dispenser/frame_0015.jpg 29 147 | dslr/images/tape_dispenser/frame_0016.jpg 29 148 | dslr/images/tape_dispenser/frame_0017.jpg 29 149 | dslr/images/tape_dispenser/frame_0018.jpg 29 150 | dslr/images/tape_dispenser/frame_0019.jpg 29 151 | dslr/images/tape_dispenser/frame_0020.jpg 29 152 | dslr/images/tape_dispenser/frame_0021.jpg 29 153 | dslr/images/tape_dispenser/frame_0022.jpg 29 154 | dslr/images/pen/frame_0001.jpg 19 155 | dslr/images/pen/frame_0002.jpg 19 156 | dslr/images/pen/frame_0003.jpg 19 157 | dslr/images/pen/frame_0004.jpg 19 158 | dslr/images/pen/frame_0005.jpg 19 159 | dslr/images/pen/frame_0006.jpg 19 160 | dslr/images/pen/frame_0007.jpg 19 161 | dslr/images/pen/frame_0008.jpg 19 162 | dslr/images/pen/frame_0009.jpg 19 163 | dslr/images/pen/frame_0010.jpg 19 164 | dslr/images/bike/frame_0001.jpg 1 165 | dslr/images/bike/frame_0002.jpg 1 166 | dslr/images/bike/frame_0003.jpg 1 167 | dslr/images/bike/frame_0004.jpg 1 168 | dslr/images/bike/frame_0005.jpg 1 169 | dslr/images/bike/frame_0006.jpg 1 170 | dslr/images/bike/frame_0007.jpg 1 171 | dslr/images/bike/frame_0008.jpg 1 172 | dslr/images/bike/frame_0009.jpg 1 173 | dslr/images/bike/frame_0010.jpg 1 174 | dslr/images/bike/frame_0011.jpg 1 175 | dslr/images/bike/frame_0012.jpg 1 176 
| dslr/images/bike/frame_0013.jpg 1 177 | dslr/images/bike/frame_0014.jpg 1 178 | dslr/images/bike/frame_0015.jpg 1 179 | dslr/images/bike/frame_0016.jpg 1 180 | dslr/images/bike/frame_0017.jpg 1 181 | dslr/images/bike/frame_0018.jpg 1 182 | dslr/images/bike/frame_0019.jpg 1 183 | dslr/images/bike/frame_0020.jpg 1 184 | dslr/images/bike/frame_0021.jpg 1 185 | dslr/images/punchers/frame_0001.jpg 23 186 | dslr/images/punchers/frame_0002.jpg 23 187 | dslr/images/punchers/frame_0003.jpg 23 188 | dslr/images/punchers/frame_0004.jpg 23 189 | dslr/images/punchers/frame_0005.jpg 23 190 | dslr/images/punchers/frame_0006.jpg 23 191 | dslr/images/punchers/frame_0007.jpg 23 192 | dslr/images/punchers/frame_0008.jpg 23 193 | dslr/images/punchers/frame_0009.jpg 23 194 | dslr/images/punchers/frame_0010.jpg 23 195 | dslr/images/punchers/frame_0011.jpg 23 196 | dslr/images/punchers/frame_0012.jpg 23 197 | dslr/images/punchers/frame_0013.jpg 23 198 | dslr/images/punchers/frame_0014.jpg 23 199 | dslr/images/punchers/frame_0015.jpg 23 200 | dslr/images/punchers/frame_0016.jpg 23 201 | dslr/images/punchers/frame_0017.jpg 23 202 | dslr/images/punchers/frame_0018.jpg 23 203 | dslr/images/back_pack/frame_0001.jpg 0 204 | dslr/images/back_pack/frame_0002.jpg 0 205 | dslr/images/back_pack/frame_0003.jpg 0 206 | dslr/images/back_pack/frame_0004.jpg 0 207 | dslr/images/back_pack/frame_0005.jpg 0 208 | dslr/images/back_pack/frame_0006.jpg 0 209 | dslr/images/back_pack/frame_0007.jpg 0 210 | dslr/images/back_pack/frame_0008.jpg 0 211 | dslr/images/back_pack/frame_0009.jpg 0 212 | dslr/images/back_pack/frame_0010.jpg 0 213 | dslr/images/back_pack/frame_0011.jpg 0 214 | dslr/images/back_pack/frame_0012.jpg 0 215 | dslr/images/desktop_computer/frame_0001.jpg 8 216 | dslr/images/desktop_computer/frame_0002.jpg 8 217 | dslr/images/desktop_computer/frame_0003.jpg 8 218 | dslr/images/desktop_computer/frame_0004.jpg 8 219 | dslr/images/desktop_computer/frame_0005.jpg 8 220 | 
dslr/images/desktop_computer/frame_0006.jpg 8 221 | dslr/images/desktop_computer/frame_0007.jpg 8 222 | dslr/images/desktop_computer/frame_0008.jpg 8 223 | dslr/images/desktop_computer/frame_0009.jpg 8 224 | dslr/images/desktop_computer/frame_0010.jpg 8 225 | dslr/images/desktop_computer/frame_0011.jpg 8 226 | dslr/images/desktop_computer/frame_0012.jpg 8 227 | dslr/images/desktop_computer/frame_0013.jpg 8 228 | dslr/images/desktop_computer/frame_0014.jpg 8 229 | dslr/images/desktop_computer/frame_0015.jpg 8 230 | dslr/images/speaker/frame_0001.jpg 27 231 | dslr/images/speaker/frame_0002.jpg 27 232 | dslr/images/speaker/frame_0003.jpg 27 233 | dslr/images/speaker/frame_0004.jpg 27 234 | dslr/images/speaker/frame_0005.jpg 27 235 | dslr/images/speaker/frame_0006.jpg 27 236 | dslr/images/speaker/frame_0007.jpg 27 237 | dslr/images/speaker/frame_0008.jpg 27 238 | dslr/images/speaker/frame_0009.jpg 27 239 | dslr/images/speaker/frame_0010.jpg 27 240 | dslr/images/speaker/frame_0011.jpg 27 241 | dslr/images/speaker/frame_0012.jpg 27 242 | dslr/images/speaker/frame_0013.jpg 27 243 | dslr/images/speaker/frame_0014.jpg 27 244 | dslr/images/speaker/frame_0015.jpg 27 245 | dslr/images/speaker/frame_0016.jpg 27 246 | dslr/images/speaker/frame_0017.jpg 27 247 | dslr/images/speaker/frame_0018.jpg 27 248 | dslr/images/speaker/frame_0019.jpg 27 249 | dslr/images/speaker/frame_0020.jpg 27 250 | dslr/images/speaker/frame_0021.jpg 27 251 | dslr/images/speaker/frame_0022.jpg 27 252 | dslr/images/speaker/frame_0023.jpg 27 253 | dslr/images/speaker/frame_0024.jpg 27 254 | dslr/images/speaker/frame_0025.jpg 27 255 | dslr/images/speaker/frame_0026.jpg 27 256 | dslr/images/mobile_phone/frame_0001.jpg 14 257 | dslr/images/mobile_phone/frame_0002.jpg 14 258 | dslr/images/mobile_phone/frame_0003.jpg 14 259 | dslr/images/mobile_phone/frame_0004.jpg 14 260 | dslr/images/mobile_phone/frame_0005.jpg 14 261 | dslr/images/mobile_phone/frame_0006.jpg 14 262 | dslr/images/mobile_phone/frame_0007.jpg 
14 263 | dslr/images/mobile_phone/frame_0008.jpg 14 264 | dslr/images/mobile_phone/frame_0009.jpg 14 265 | dslr/images/mobile_phone/frame_0010.jpg 14 266 | dslr/images/mobile_phone/frame_0011.jpg 14 267 | dslr/images/mobile_phone/frame_0012.jpg 14 268 | dslr/images/mobile_phone/frame_0013.jpg 14 269 | dslr/images/mobile_phone/frame_0014.jpg 14 270 | dslr/images/mobile_phone/frame_0015.jpg 14 271 | dslr/images/mobile_phone/frame_0016.jpg 14 272 | dslr/images/mobile_phone/frame_0017.jpg 14 273 | dslr/images/mobile_phone/frame_0018.jpg 14 274 | dslr/images/mobile_phone/frame_0019.jpg 14 275 | dslr/images/mobile_phone/frame_0020.jpg 14 276 | dslr/images/mobile_phone/frame_0021.jpg 14 277 | dslr/images/mobile_phone/frame_0022.jpg 14 278 | dslr/images/mobile_phone/frame_0023.jpg 14 279 | dslr/images/mobile_phone/frame_0024.jpg 14 280 | dslr/images/mobile_phone/frame_0025.jpg 14 281 | dslr/images/mobile_phone/frame_0026.jpg 14 282 | dslr/images/mobile_phone/frame_0027.jpg 14 283 | dslr/images/mobile_phone/frame_0028.jpg 14 284 | dslr/images/mobile_phone/frame_0029.jpg 14 285 | dslr/images/mobile_phone/frame_0030.jpg 14 286 | dslr/images/mobile_phone/frame_0031.jpg 14 287 | dslr/images/paper_notebook/frame_0001.jpg 18 288 | dslr/images/paper_notebook/frame_0002.jpg 18 289 | dslr/images/paper_notebook/frame_0003.jpg 18 290 | dslr/images/paper_notebook/frame_0004.jpg 18 291 | dslr/images/paper_notebook/frame_0005.jpg 18 292 | dslr/images/paper_notebook/frame_0006.jpg 18 293 | dslr/images/paper_notebook/frame_0007.jpg 18 294 | dslr/images/paper_notebook/frame_0008.jpg 18 295 | dslr/images/paper_notebook/frame_0009.jpg 18 296 | dslr/images/paper_notebook/frame_0010.jpg 18 297 | dslr/images/ruler/frame_0001.jpg 25 298 | dslr/images/ruler/frame_0002.jpg 25 299 | dslr/images/ruler/frame_0003.jpg 25 300 | dslr/images/ruler/frame_0004.jpg 25 301 | dslr/images/ruler/frame_0005.jpg 25 302 | dslr/images/ruler/frame_0006.jpg 25 303 | dslr/images/ruler/frame_0007.jpg 25 304 | 
dslr/images/letter_tray/frame_0001.jpg 13 305 | dslr/images/letter_tray/frame_0002.jpg 13 306 | dslr/images/letter_tray/frame_0003.jpg 13 307 | dslr/images/letter_tray/frame_0004.jpg 13 308 | dslr/images/letter_tray/frame_0005.jpg 13 309 | dslr/images/letter_tray/frame_0006.jpg 13 310 | dslr/images/letter_tray/frame_0007.jpg 13 311 | dslr/images/letter_tray/frame_0008.jpg 13 312 | dslr/images/letter_tray/frame_0009.jpg 13 313 | dslr/images/letter_tray/frame_0010.jpg 13 314 | dslr/images/letter_tray/frame_0011.jpg 13 315 | dslr/images/letter_tray/frame_0012.jpg 13 316 | dslr/images/letter_tray/frame_0013.jpg 13 317 | dslr/images/letter_tray/frame_0014.jpg 13 318 | dslr/images/letter_tray/frame_0015.jpg 13 319 | dslr/images/letter_tray/frame_0016.jpg 13 320 | dslr/images/file_cabinet/frame_0001.jpg 9 321 | dslr/images/file_cabinet/frame_0002.jpg 9 322 | dslr/images/file_cabinet/frame_0003.jpg 9 323 | dslr/images/file_cabinet/frame_0004.jpg 9 324 | dslr/images/file_cabinet/frame_0005.jpg 9 325 | dslr/images/file_cabinet/frame_0006.jpg 9 326 | dslr/images/file_cabinet/frame_0007.jpg 9 327 | dslr/images/file_cabinet/frame_0008.jpg 9 328 | dslr/images/file_cabinet/frame_0009.jpg 9 329 | dslr/images/file_cabinet/frame_0010.jpg 9 330 | dslr/images/file_cabinet/frame_0011.jpg 9 331 | dslr/images/file_cabinet/frame_0012.jpg 9 332 | dslr/images/file_cabinet/frame_0013.jpg 9 333 | dslr/images/file_cabinet/frame_0014.jpg 9 334 | dslr/images/file_cabinet/frame_0015.jpg 9 335 | dslr/images/phone/frame_0001.jpg 20 336 | dslr/images/phone/frame_0002.jpg 20 337 | dslr/images/phone/frame_0003.jpg 20 338 | dslr/images/phone/frame_0004.jpg 20 339 | dslr/images/phone/frame_0005.jpg 20 340 | dslr/images/phone/frame_0006.jpg 20 341 | dslr/images/phone/frame_0007.jpg 20 342 | dslr/images/phone/frame_0008.jpg 20 343 | dslr/images/phone/frame_0009.jpg 20 344 | dslr/images/phone/frame_0010.jpg 20 345 | dslr/images/phone/frame_0011.jpg 20 346 | dslr/images/phone/frame_0012.jpg 20 347 | 
dslr/images/phone/frame_0013.jpg 20 348 | dslr/images/bookcase/frame_0001.jpg 3 349 | dslr/images/bookcase/frame_0002.jpg 3 350 | dslr/images/bookcase/frame_0003.jpg 3 351 | dslr/images/bookcase/frame_0004.jpg 3 352 | dslr/images/bookcase/frame_0005.jpg 3 353 | dslr/images/bookcase/frame_0006.jpg 3 354 | dslr/images/bookcase/frame_0007.jpg 3 355 | dslr/images/bookcase/frame_0008.jpg 3 356 | dslr/images/bookcase/frame_0009.jpg 3 357 | dslr/images/bookcase/frame_0010.jpg 3 358 | dslr/images/bookcase/frame_0011.jpg 3 359 | dslr/images/bookcase/frame_0012.jpg 3 360 | dslr/images/projector/frame_0001.jpg 22 361 | dslr/images/projector/frame_0002.jpg 22 362 | dslr/images/projector/frame_0003.jpg 22 363 | dslr/images/projector/frame_0004.jpg 22 364 | dslr/images/projector/frame_0005.jpg 22 365 | dslr/images/projector/frame_0006.jpg 22 366 | dslr/images/projector/frame_0007.jpg 22 367 | dslr/images/projector/frame_0008.jpg 22 368 | dslr/images/projector/frame_0009.jpg 22 369 | dslr/images/projector/frame_0010.jpg 22 370 | dslr/images/projector/frame_0011.jpg 22 371 | dslr/images/projector/frame_0012.jpg 22 372 | dslr/images/projector/frame_0013.jpg 22 373 | dslr/images/projector/frame_0014.jpg 22 374 | dslr/images/projector/frame_0015.jpg 22 375 | dslr/images/projector/frame_0016.jpg 22 376 | dslr/images/projector/frame_0017.jpg 22 377 | dslr/images/projector/frame_0018.jpg 22 378 | dslr/images/projector/frame_0019.jpg 22 379 | dslr/images/projector/frame_0020.jpg 22 380 | dslr/images/projector/frame_0021.jpg 22 381 | dslr/images/projector/frame_0022.jpg 22 382 | dslr/images/projector/frame_0023.jpg 22 383 | dslr/images/stapler/frame_0001.jpg 28 384 | dslr/images/stapler/frame_0002.jpg 28 385 | dslr/images/stapler/frame_0003.jpg 28 386 | dslr/images/stapler/frame_0004.jpg 28 387 | dslr/images/stapler/frame_0005.jpg 28 388 | dslr/images/stapler/frame_0006.jpg 28 389 | dslr/images/stapler/frame_0007.jpg 28 390 | dslr/images/stapler/frame_0008.jpg 28 391 | 
dslr/images/stapler/frame_0009.jpg 28 392 | dslr/images/stapler/frame_0010.jpg 28 393 | dslr/images/stapler/frame_0011.jpg 28 394 | dslr/images/stapler/frame_0012.jpg 28 395 | dslr/images/stapler/frame_0013.jpg 28 396 | dslr/images/stapler/frame_0014.jpg 28 397 | dslr/images/stapler/frame_0015.jpg 28 398 | dslr/images/stapler/frame_0016.jpg 28 399 | dslr/images/stapler/frame_0017.jpg 28 400 | dslr/images/stapler/frame_0018.jpg 28 401 | dslr/images/stapler/frame_0019.jpg 28 402 | dslr/images/stapler/frame_0020.jpg 28 403 | dslr/images/stapler/frame_0021.jpg 28 404 | dslr/images/trash_can/frame_0001.jpg 30 405 | dslr/images/trash_can/frame_0002.jpg 30 406 | dslr/images/trash_can/frame_0003.jpg 30 407 | dslr/images/trash_can/frame_0004.jpg 30 408 | dslr/images/trash_can/frame_0005.jpg 30 409 | dslr/images/trash_can/frame_0006.jpg 30 410 | dslr/images/trash_can/frame_0007.jpg 30 411 | dslr/images/trash_can/frame_0008.jpg 30 412 | dslr/images/trash_can/frame_0009.jpg 30 413 | dslr/images/trash_can/frame_0010.jpg 30 414 | dslr/images/trash_can/frame_0011.jpg 30 415 | dslr/images/trash_can/frame_0012.jpg 30 416 | dslr/images/trash_can/frame_0013.jpg 30 417 | dslr/images/trash_can/frame_0014.jpg 30 418 | dslr/images/trash_can/frame_0015.jpg 30 419 | dslr/images/bike_helmet/frame_0001.jpg 2 420 | dslr/images/bike_helmet/frame_0002.jpg 2 421 | dslr/images/bike_helmet/frame_0003.jpg 2 422 | dslr/images/bike_helmet/frame_0004.jpg 2 423 | dslr/images/bike_helmet/frame_0005.jpg 2 424 | dslr/images/bike_helmet/frame_0006.jpg 2 425 | dslr/images/bike_helmet/frame_0007.jpg 2 426 | dslr/images/bike_helmet/frame_0008.jpg 2 427 | dslr/images/bike_helmet/frame_0009.jpg 2 428 | dslr/images/bike_helmet/frame_0010.jpg 2 429 | dslr/images/bike_helmet/frame_0011.jpg 2 430 | dslr/images/bike_helmet/frame_0012.jpg 2 431 | dslr/images/bike_helmet/frame_0013.jpg 2 432 | dslr/images/bike_helmet/frame_0014.jpg 2 433 | dslr/images/bike_helmet/frame_0015.jpg 2 434 | 
dslr/images/bike_helmet/frame_0016.jpg 2 435 | dslr/images/bike_helmet/frame_0017.jpg 2 436 | dslr/images/bike_helmet/frame_0018.jpg 2 437 | dslr/images/bike_helmet/frame_0019.jpg 2 438 | dslr/images/bike_helmet/frame_0020.jpg 2 439 | dslr/images/bike_helmet/frame_0021.jpg 2 440 | dslr/images/bike_helmet/frame_0022.jpg 2 441 | dslr/images/bike_helmet/frame_0023.jpg 2 442 | dslr/images/bike_helmet/frame_0024.jpg 2 443 | dslr/images/headphones/frame_0001.jpg 10 444 | dslr/images/headphones/frame_0002.jpg 10 445 | dslr/images/headphones/frame_0003.jpg 10 446 | dslr/images/headphones/frame_0004.jpg 10 447 | dslr/images/headphones/frame_0005.jpg 10 448 | dslr/images/headphones/frame_0006.jpg 10 449 | dslr/images/headphones/frame_0007.jpg 10 450 | dslr/images/headphones/frame_0008.jpg 10 451 | dslr/images/headphones/frame_0009.jpg 10 452 | dslr/images/headphones/frame_0010.jpg 10 453 | dslr/images/headphones/frame_0011.jpg 10 454 | dslr/images/headphones/frame_0012.jpg 10 455 | dslr/images/headphones/frame_0013.jpg 10 456 | dslr/images/desk_lamp/frame_0001.jpg 7 457 | dslr/images/desk_lamp/frame_0002.jpg 7 458 | dslr/images/desk_lamp/frame_0003.jpg 7 459 | dslr/images/desk_lamp/frame_0004.jpg 7 460 | dslr/images/desk_lamp/frame_0005.jpg 7 461 | dslr/images/desk_lamp/frame_0006.jpg 7 462 | dslr/images/desk_lamp/frame_0007.jpg 7 463 | dslr/images/desk_lamp/frame_0008.jpg 7 464 | dslr/images/desk_lamp/frame_0009.jpg 7 465 | dslr/images/desk_lamp/frame_0010.jpg 7 466 | dslr/images/desk_lamp/frame_0011.jpg 7 467 | dslr/images/desk_lamp/frame_0012.jpg 7 468 | dslr/images/desk_lamp/frame_0013.jpg 7 469 | dslr/images/desk_lamp/frame_0014.jpg 7 470 | dslr/images/desk_chair/frame_0001.jpg 6 471 | dslr/images/desk_chair/frame_0002.jpg 6 472 | dslr/images/desk_chair/frame_0003.jpg 6 473 | dslr/images/desk_chair/frame_0004.jpg 6 474 | dslr/images/desk_chair/frame_0005.jpg 6 475 | dslr/images/desk_chair/frame_0006.jpg 6 476 | dslr/images/desk_chair/frame_0007.jpg 6 477 | 
dslr/images/desk_chair/frame_0008.jpg 6 478 | dslr/images/desk_chair/frame_0009.jpg 6 479 | dslr/images/desk_chair/frame_0010.jpg 6 480 | dslr/images/desk_chair/frame_0011.jpg 6 481 | dslr/images/desk_chair/frame_0012.jpg 6 482 | dslr/images/desk_chair/frame_0013.jpg 6 483 | dslr/images/bottle/frame_0001.jpg 4 484 | dslr/images/bottle/frame_0002.jpg 4 485 | dslr/images/bottle/frame_0003.jpg 4 486 | dslr/images/bottle/frame_0004.jpg 4 487 | dslr/images/bottle/frame_0005.jpg 4 488 | dslr/images/bottle/frame_0006.jpg 4 489 | dslr/images/bottle/frame_0007.jpg 4 490 | dslr/images/bottle/frame_0008.jpg 4 491 | dslr/images/bottle/frame_0009.jpg 4 492 | dslr/images/bottle/frame_0010.jpg 4 493 | dslr/images/bottle/frame_0011.jpg 4 494 | dslr/images/bottle/frame_0012.jpg 4 495 | dslr/images/bottle/frame_0013.jpg 4 496 | dslr/images/bottle/frame_0014.jpg 4 497 | dslr/images/bottle/frame_0015.jpg 4 498 | dslr/images/bottle/frame_0016.jpg 4 499 | -------------------------------------------------------------------------------- /data/setup_data_path.sh: -------------------------------------------------------------------------------- 1 | # sh setup_data_path.sh data_path dataset 2 | data_path=$1 3 | dataset=$2 4 | 5 | if [[ ${dataset} == "domainnet40" ]] ; 6 | then 7 | cd domainnet40 8 | rm clipart 9 | ln -s "${data_path}/clipart" clipart 10 | rm infograph 11 | ln -s "${data_path}/infograph" infograph 12 | rm painting 13 | ln -s "${data_path}/painting" painting 14 | rm quickdraw 15 | ln -s "${data_path}/quickdraw" quickdraw 16 | rm real 17 | ln -s "${data_path}/real" real 18 | rm sketch 19 | ln -s "${data_path}/sketch" sketch 20 | cd .. 
21 | elif [[ ${dataset} == "office31" ]] ; 22 | then 23 | cd office31 24 | rm amazon 25 | ln -s "${data_path}/amazon" amazon 26 | rm webcam 27 | ln -s "${data_path}/webcam" webcam 28 | rm dslr 29 | ln -s "${data_path}/dslr" dslr 30 | elif [[ ${dataset} == "office-home" ]] ; 31 | then 32 | cd office-home 33 | rm Art 34 | ln -s "${data_path}/Art" Art 35 | rm Clipart 36 | ln -s "${data_path}/Clipart" Clipart 37 | rm Product 38 | ln -s "${data_path}/Product" Product 39 | rm Real_World 40 | ln -s "${data_path}/Real_World" Real_World 41 | elif [[ ${dataset} == "office-home-rsut" ]] ; 42 | then 43 | cd office-home-rsut 44 | rm Art 45 | ln -s "${data_path}/Art" Art 46 | rm Clipart 47 | ln -s "${data_path}/Clipart" Clipart 48 | rm Product 49 | ln -s "${data_path}/Product" Product 50 | rm Real_World 51 | ln -s "${data_path}/Real_World" Real_World 52 | elif [[ ${dataset} == "visda" ]] ; 53 | then 54 | cd visda-2017 55 | rm train 56 | ln -s "${data_path}/train" train 57 | rm validation 58 | ln -s "${data_path}/validation" validation 59 | fi 60 | cd .. 
-------------------------------------------------------------------------------- /fig/PK.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsun/KUDA/beeb839456daf5fb5d263783c79bb6cff38e2375/fig/PK.png -------------------------------------------------------------------------------- /fig/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsun/KUDA/beeb839456daf5fb5d263783c79bb6cff38e2375/fig/framework.png -------------------------------------------------------------------------------- /pklib/pksolver.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import gurobipy as grb 3 | import random 4 | 5 | class PK_solver(): 6 | def __init__(self, N, C, C_ub=[], C_br=[], pk_prior_weight=10.): 7 | self.N = N # number of samples 8 | self.C = C # number of classes 9 | self.C_ub = C_ub # constraints of unary bound 10 | self.C_br = C_br # constraints of binary relationship 11 | self.pk_prior_weight = pk_prior_weight 12 | 13 | # create unary bound constraints 14 | def create_C_ub(self, cls_probs, uconf=0.): 15 | ubs = cls_probs * (1 + uconf) 16 | lbs = cls_probs * (1 - uconf) 17 | ubs[ubs > 1.0] = 1.0 18 | lbs[lbs < 0.0] = 0.0 19 | ubs = (ubs*self.N).tolist() 20 | lbs = (lbs*self.N).tolist() 21 | self.C_ub = list(zip(list(range(self.C)), lbs, ubs)) 22 | 23 | # create unary bound constraints with noises 24 | def create_C_ub_noisy(self, cls_probs, uconf=0., noise=0.): 25 | ubs = cls_probs * (1 + uconf) 26 | lbs = cls_probs * (1 - uconf) 27 | bias = (2*np.random.rand(len(cls_probs))-1)*cls_probs*noise 28 | bias -= bias.mean() 29 | ubs += bias 30 | lbs += bias 31 | ubs[ubs > 1.0] = 1.0 32 | lbs[lbs < 0.0] = 0.0 33 | ubs = (ubs*self.N).tolist() 34 | lbs = (lbs*self.N).tolist() 35 | self.C_ub = list(zip(list(range(self.C)), lbs, ubs)) 36 | 37 | # create binary relationship 
# create binary relationship (ranking) constraints from class-prior probabilities.
#
# NOTE(review): the methods below originally sit inside the prior-knowledge
# solver class of pklib/pksolver.py; the class header and __init__ (which set
# self.C = number of classes, self.N = number of samples, and
# self.pk_prior_weight = penalty weight) are outside this chunk, so a header is
# reconstructed here to keep the methods syntactically grouped. `grb` is the
# project's `import gurobipy as grb` alias from the file top (also outside view).
class PKSolver:

    # ---------------- constraint construction helpers ----------------

    def _ranked_pairs(self, cls_probs, uconf):
        """Adjacent (higher, lower, count_gap) triples for classes sorted by
        descending prior. The gap is the expected sample-count difference
        (prior difference times self.N, rounded), shrunk by (1 - uconf)."""
        order = np.argsort(-cls_probs)
        return [(order[c], order[c + 1],
                 (1 - uconf) * np.round((cls_probs[order[c]] - cls_probs[order[c + 1]]) * self.N))
                for c in range(self.C - 1)]

    def _count_bounds(self, cls_probs, uconf):
        """Per-class (class, lower, upper) sample-count bounds: priors widened
        by a factor (1 +/- uconf), clipped into [0, 1], then scaled by self.N."""
        ubs = np.minimum(cls_probs * (1 + uconf), 1.0)
        lbs = np.maximum(cls_probs * (1 - uconf), 0.0)
        return list(zip(range(self.C), (lbs * self.N).tolist(), (ubs * self.N).tolist()))

    # ---------------- public constraint builders ----------------

    # create binary relationship constraints over the full class ranking
    def create_C_br(self, cls_probs, uconf=0.):
        self.C_br = self._ranked_pairs(cls_probs, uconf)

    # create binary relationship constraints with a noise-perturbed ranking
    def create_C_br_noisy(self, cls_probs, uconf=0., noise=0.):
        # NOTE(review): `uconf` is accepted but unused, mirroring the original —
        # the noisy variant keeps only the (perturbed) order and fixes the
        # count gap to 0.
        order = np.argsort(-cls_probs)
        C = len(order)
        # jitter the rank scores; the tiny uniform term breaks exact ties
        score = np.arange(C) + (2 * np.random.rand(C) - 1) * noise + np.random.rand(C) * 0.0001
        order = order[np.argsort(score)]
        self.C_br = [(order[c], order[c + 1], 0) for c in range(self.C - 1)]

    # create unary bound constraints from the first N classes
    # (the "head" classes when the label space is ordered by frequency, as the
    # original comments assume — TODO confirm against callers)
    def create_C_ub_partial(self, cls_probs, uconf=0., N=10):
        self.C_ub = self._count_bounds(cls_probs, uconf)[:N]

    # create unary bound constraints from the last N ("tail") classes
    def create_C_ub_partial_reverse(self, cls_probs, uconf=0., N=10):
        self.C_ub = self._count_bounds(cls_probs, uconf)[-N:]

    # create unary bound constraints from N randomly chosen classes
    def create_C_ub_partial_rand(self, cls_probs, uconf=0., N=10):
        self.C_ub = random.sample(self._count_bounds(cls_probs, uconf), k=N)

    # create binary relationship constraints from the first N ranked ("head") pairs
    def create_C_br_partial(self, cls_probs, uconf=0., N=10):
        self.C_br = self._ranked_pairs(cls_probs, uconf)[:N]

    # create binary relationship constraints from the last N ranked ("tail") pairs
    def create_C_br_partial_reverse(self, cls_probs, uconf=0., N=10):
        self.C_br = self._ranked_pairs(cls_probs, uconf)[-N:]

    # create binary relationship constraints from N randomly chosen ranked pairs
    def create_C_br_partial_rand(self, cls_probs, uconf=0., N=10):
        self.C_br = random.sample(self._ranked_pairs(cls_probs, uconf), k=N)

    # ---------------- solvers ----------------

    def solve_soft_knn_cst(self, probs, fix_set=None, fix_labels=None, knn_regs=None):
        """Refine pseudo-labels by maximizing total assignment confidence under
        soft prior-knowledge constraints, solved as a Gurobi MIP.

        Args:
            probs: (self.N, self.C) per-sample class probabilities.
            fix_set: indices of samples whose labels are given and not optimized.
            fix_labels: labels for the samples in fix_set.
            knn_regs: optional smoothness pairs (i, neighbor_list) forcing sample
                i and each neighbor to take the same label. NOTE: the smoothness
                regularization does not support a non-empty fix_set (its indices
                refer to positions within var_set, per the original comment).

        Returns:
            int32 array of length self.N holding the refined labels.
        """
        # defaults are created per call — never share mutable default arguments
        fix_set = [] if fix_set is None else fix_set
        fix_labels = [] if fix_labels is None else fix_labels
        knn_regs = [] if knn_regs is None else knn_regs

        # class-count contribution of the fixed samples; the count constraints
        # below are stated on totals, so the fixed part is subtracted from them
        fix_cls_probs = np.eye(self.C)[fix_labels].sum(0)

        # var_set are the samples whose (pseudo) labels get refined
        var_set = list(set(range(self.N)) - set(fix_set))
        Nvar = len(var_set)

        # binary assignment vars: x[n, c] == 1 iff sample var_set[n] takes class c
        LP = grb.Model(name="Prior Constraint Problem")
        x = {(n, c): LP.addVar(vtype=grb.GRB.BINARY, name="x_{0}_{1}".format(n, c))
             for n in range(Nvar) for c in range(self.C)}

        # each sample is assigned exactly one class
        LP.addConstrs((grb.quicksum(x[n, c] for c in range(self.C)) == 1) for n in range(Nvar))

        # total assignment confidence (to be maximized)
        objective = grb.quicksum(x[n, c] * probs[var_set[n], c]
                                 for n in range(Nvar) for c in range(self.C))

        # ---- soft unary (per-class count) bound constraints ----
        # xi_*[i, 1] carries the signed violation margin and xi_*[i, 0] its
        # hinge max(margin, 0): only actual violations are penalized.
        xi_ub = {(c, k): LP.addVar(vtype=grb.GRB.CONTINUOUS, lb=-Nvar, ub=Nvar,
                                   name="xi_ub_{0}_{1}".format(c, k))
                 for c in range(len(self.C_ub)) for k in range(2)}
        xi_lb = {(c, k): LP.addVar(vtype=grb.GRB.CONTINUOUS, lb=-Nvar, ub=Nvar,
                                   name="xi_lb_{0}_{1}".format(c, k))
                 for c in range(len(self.C_ub)) for k in range(2)}

        margin_ub = []
        margin_lb = []
        for i, (c, lb, ub) in enumerate(self.C_ub):
            if ub is not None:
                ub = ub - fix_cls_probs[c]  # bound applies to the variable part only
                margin_ub.append(grb.quicksum(x[n, c] for n in range(Nvar)) - ub)
            else:
                margin_ub.append(0.)

            if lb is not None:
                lb = lb - fix_cls_probs[c]
                margin_lb.append(-grb.quicksum(x[n, c] for n in range(Nvar)) + lb)
            else:
                margin_lb.append(0.)

        LP.addConstrs(
            (xi_ub[i, 1] == margin_ub[i] for i in range(len(self.C_ub))), name="slack_ub_0"
        )
        LP.addConstrs(
            (xi_ub[i, 0] == grb.max_(xi_ub[i, 1], 0) for i in range(len(self.C_ub))), name="slack_ub_1"
        )
        LP.addConstrs(
            (xi_lb[i, 1] == margin_lb[i] for i in range(len(self.C_ub))), name="slack_lb_0"
        )
        LP.addConstrs(
            (xi_lb[i, 0] == grb.max_(xi_lb[i, 1], 0) for i in range(len(self.C_ub))), name="slack_lb_1"
        )

        constraint_ub = grb.quicksum(xi_ub[c, 0] for c in range(len(self.C_ub))) + \
                        grb.quicksum(xi_lb[c, 0] for c in range(len(self.C_ub)))
        # mean violation; the epsilon keeps the division safe when C_ub is empty
        constraint_ub /= (len(self.C_ub) * 2 + 1e-10)

        # ---- soft binary (ranking) constraints ----
        margin_br = []
        for (c1, c2, diff) in self.C_br:
            diff = diff - fix_cls_probs[c1] + fix_cls_probs[c2]
            margin_br.append(-grb.quicksum(x[n, c1] for n in range(Nvar))
                             + grb.quicksum(x[n, c2] for n in range(Nvar)) + diff)

        xi_br = {(c, k): LP.addVar(vtype=grb.GRB.CONTINUOUS, lb=-2 * Nvar, ub=2 * Nvar,
                                   name="xi_br_{0}_{1}".format(c, k))
                 for c in range(len(self.C_br)) for k in range(2)}

        LP.addConstrs(
            (xi_br[i, 1] == margin_br[i] for i in range(len(self.C_br))), name="slack_br_0"
        )
        LP.addConstrs(
            (xi_br[i, 0] == grb.max_(xi_br[i, 1], 0) for i in range(len(self.C_br))), name="slack_br_1"
        )

        constraint_br = grb.quicksum(xi_br[c, 0] for c in range(len(self.C_br)))
        constraint_br /= (len(self.C_br) + 1e-10)

        constraint = constraint_ub + constraint_br

        # hard smoothness (kNN) regularization: each sample must share its label
        # with every listed neighbor
        if len(knn_regs) > 0:
            LP.addConstrs(
                (x[knn_regs[i][0], c] == x[knn_regs[i][1][k], c]
                 for i in range(len(knn_regs))
                 for k in range(len(knn_regs[i][1]))
                 for c in range(self.C)), name="smooth_regularization"
            )

        LP.ModelSense = grb.GRB.MAXIMIZE
        LP.setObjective(objective - self.pk_prior_weight * constraint * Nvar)

        LP.optimize()

        # read the optimal assignment back; reshape keeps the Nvar == 0 case valid
        sol = np.array([x[n, c].X for n in range(Nvar) for c in range(self.C)])
        var_labels = sol.reshape([Nvar, self.C]).argmax(axis=-1)

        labels = np.zeros(self.N, dtype=np.int32)
        labels[fix_set] = fix_labels
        labels[var_set] = var_labels

        return labels

    # solver without the smoothness regularization. The original body was a
    # line-for-line duplicate of solve_soft_knn_cst minus the kNN block, so it
    # now simply delegates with an empty kNN set (identical behavior).
    def solve_soft(self, probs, fix_set=None, fix_labels=None):
        return self.solve_soft_knn_cst(probs, fix_set=fix_set, fix_labels=fix_labels,
                                       knn_regs=None)


# =========================== util/__init__.py ===========================
# (package marker; the dump stores it as a raw GitHub content URL)

# =========================== util/get_time.py ===========================
import time


def get_time():
    """Local-time timestamp string, e.g. '2024-01-31_23.59.59'."""
    return time.strftime("%Y-%m-%d_%H.%M.%S", time.localtime())


if __name__ == '__main__':
    print(get_time())


# ============================ util/utils.py =============================
import time
import socket
import os
import numpy as np
import torch
import random
import logging  # single import; the original imported logging a second time mid-file

# silence PIL's very chatty DEBUG-level logging
pil_logger = logging.getLogger('PIL')
pil_logger.setLevel(logging.INFO)


def get_time():
    """Local-time timestamp string, e.g. '2024-01-31_23.59.59'."""
    return time.strftime("%Y-%m-%d_%H.%M.%S", time.localtime())


def get_hostname():
    """Host name of the machine running the experiment."""
    return socket.gethostname()


def get_pid():
    """Process id of the current run (used for log bookkeeping)."""
    return os.getpid()


# set random number generators' seeds
def resetRNGseed(seed):
    """Seed every RNG the project uses: numpy, random, torch CPU and CUDA."""
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


logger_init = False  # flips to True once init_logger has configured the root logger


def init_logger(_log_file, use_file_logger=True, dir='log/'):
    """Configure the root logger with a stream handler and, optionally, a file
    handler writing to <dir>/<_log_file>.log.

    Existing handlers are removed first so repeated calls do not duplicate
    output. NOTE(review): `dir` shadows the builtin, but renaming it would
    break keyword callers, so the name is kept.
    """
    # exist_ok avoids the check-then-create race of the original exists()+makedirs()
    os.makedirs(dir, exist_ok=True)
    log_file = os.path.join(dir, _log_file + '.log')

    logger = logging.getLogger()
    for handler in logger.handlers[:]:
        logger.removeHandler(handler)

    logger.setLevel('DEBUG')
    BASIC_FORMAT = "%(asctime)s:%(levelname)s:%(message)s"
    DATE_FORMAT = '%Y-%m-%d %H:%M:%S'
    formatter = logging.Formatter(BASIC_FORMAT, DATE_FORMAT)
    chlr = logging.StreamHandler()
    chlr.setFormatter(formatter)
    logger.addHandler(chlr)
    if use_file_logger:
        fhlr = logging.FileHandler(log_file)
        fhlr.setFormatter(formatter)
        logger.addHandler(fhlr)

    global logger_init
    logger_init = True