├── .gitignore
├── DINE
│   ├── DINE_dist.py
│   ├── DINE_dist_kDINE.py
│   ├── DINE_ft.py
│   ├── data_list.py
│   ├── loss.py
│   ├── network.py
│   └── run_all_kDINE.sh
├── LICENSE
├── README.md
├── SHOT
│   ├── __init__.py
│   ├── augmentations.py
│   ├── data_list.py
│   ├── image_source.py
│   ├── image_target.py
│   ├── image_target_kSHOT.py
│   ├── loss.py
│   ├── network.py
│   └── run_all_kSHOT.sh
├── data
│   ├── domainnet40
│   │   └── image_list
│   │       ├── clipart_test_mini.txt
│   │       ├── clipart_train_mini.txt
│   │       ├── painting_test_mini.txt
│   │       ├── painting_train_mini.txt
│   │       ├── real_test_mini.txt
│   │       ├── real_train_mini.txt
│   │       ├── sketch_test_mini.txt
│   │       └── sketch_train_mini.txt
│   ├── multi
│   │   └── image_list
│   │       ├── clipart.txt
│   │       ├── painting.txt
│   │       ├── real.txt
│   │       └── sketch.txt
│   ├── office-home-rsut
│   │   └── image_list
│   │       ├── Clipart_RS.txt
│   │       ├── Clipart_UT.txt
│   │       ├── Product_RS.txt
│   │       ├── Product_UT.txt
│   │       ├── Real_World_RS.txt
│   │       └── Real_World_UT.txt
│   ├── office-home
│   │   └── image_list
│   │       ├── Art.txt
│   │       ├── Clipart.txt
│   │       ├── Product.txt
│   │       └── Real_World.txt
│   ├── office31
│   │   └── image_list
│   │       ├── amazon.txt
│   │       ├── dslr.txt
│   │       └── webcam.txt
│   ├── setup_data_path.sh
│   └── visda-2017
│       └── image_list
│           ├── train.txt
│           └── validation.txt
├── fig
│   ├── PK.png
│   └── framework.png
├── pklib
│   └── pksolver.py
└── util
    ├── __init__.py
    ├── get_time.py
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 |
--------------------------------------------------------------------------------
/DINE/DINE_dist.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import os.path as osp
4 | import torchvision
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | import torch.optim as optim
9 | from torchvision import transforms
10 | import network
11 | import loss
12 | from torch.utils.data import DataLoader
13 | from data_list import ImageList, ImageList_idx
14 | from loss import CrossEntropyLabelSmooth
15 | from sklearn.metrics import confusion_matrix
16 | import distutils
17 | import distutils.util
18 | import logging
19 |
20 | import sys
21 | sys.path.append("../util/")
22 | from utils import resetRNGseed, init_logger, get_hostname, get_pid
23 |
24 | import time
25 | timestamp = time.strftime("%Y-%m-%d_%H.%M.%S", time.localtime())
26 |
27 | torch.backends.cudnn.deterministic = True
28 | torch.backends.cudnn.benchmark = False
29 |
30 | def op_copy(optimizer):
31 | for param_group in optimizer.param_groups:
32 | param_group['lr0'] = param_group['lr']
33 | return optimizer
34 |
35 | def lr_scheduler(optimizer, iter_num, max_iter, gamma=10, power=0.75):
36 | decay = (1 + gamma * iter_num / max_iter) ** (-power)
37 | for param_group in optimizer.param_groups:
38 | param_group['lr'] = param_group['lr0'] * decay
39 | param_group['weight_decay'] = 1e-3
40 | param_group['momentum'] = 0.9
41 | param_group['nesterov'] = True
42 | return optimizer
43 |
44 | def image_train(resize_size=256, crop_size=224, alexnet=False):
45 | if not alexnet:
46 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
47 | std=[0.229, 0.224, 0.225])
48 | else:
49 | normalize = Normalize(meanfile='./ilsvrc_2012_mean.npy')
50 | return transforms.Compose([
51 | transforms.Resize((resize_size, resize_size)),
52 | transforms.RandomCrop(crop_size),
53 | transforms.RandomHorizontalFlip(),
54 | transforms.ToTensor(),
55 | normalize
56 | ])
57 |
58 | def image_test(resize_size=256, crop_size=224, alexnet=False):
59 | if not alexnet:
60 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
61 | std=[0.229, 0.224, 0.225])
62 | else:
63 | normalize = Normalize(meanfile='./ilsvrc_2012_mean.npy')
64 | return transforms.Compose([
65 | transforms.Resize((resize_size, resize_size)),
66 | transforms.CenterCrop(crop_size),
67 | transforms.ToTensor(),
68 | normalize
69 | ])
70 |
71 | def data_load(args):
72 | ## prepare data
73 | dsets = {}
74 | dset_loaders = {}
75 | train_bs = args.batch_size
76 | txt_src = open(args.s_dset_path).readlines()
77 | txt_tar = open(args.t_dset_path).readlines()
78 | txt_test = open(args.test_dset_path).readlines()
79 |
80 | count = np.zeros(args.class_num)
81 | tr_txt = []
82 | te_txt = []
83 | for i in range(len(txt_src)):
84 | line = txt_src[i]
85 | reci = line.strip().split(' ')
86 | if count[int(reci[1])] < 3:
87 | count[int(reci[1])] += 1
88 | te_txt.append(line)
89 | else:
90 | tr_txt.append(line)
91 |
92 | if not args.da == 'uda':
93 | label_map_s = {}
94 | for i in range(len(args.src_classes)):
95 | label_map_s[args.src_classes[i]] = i
96 |
97 | new_tar = []
98 | for i in range(len(txt_tar)):
99 | rec = txt_tar[i]
100 | reci = rec.strip().split(' ')
101 | if int(reci[1]) in args.tar_classes:
102 | if int(reci[1]) in args.src_classes:
103 | line = reci[0] + ' ' + str(label_map_s[int(reci[1])]) + '\n'
104 | new_tar.append(line)
105 | else:
106 | line = reci[0] + ' ' + str(len(label_map_s)) + '\n'
107 | new_tar.append(line)
108 | txt_tar = new_tar.copy()
109 | txt_test = txt_tar.copy()
110 |
111 | dsets["source_tr"] = ImageList(tr_txt, root="../data/{}/".format(args.dset), transform=image_train())
112 | dset_loaders["source_tr"] = DataLoader(dsets["source_tr"], batch_size=train_bs, shuffle=True, num_workers=args.worker, drop_last=False)
113 | dsets["source_te"] = ImageList(te_txt, root="../data/{}/".format(args.dset), transform=image_test())
114 | dset_loaders["source_te"] = DataLoader(dsets["source_te"], batch_size=train_bs, shuffle=True, num_workers=args.worker, drop_last=False)
115 | dsets["target"] = ImageList_idx(txt_tar, root="../data/{}/".format(args.dset), transform=image_train())
116 | dset_loaders["target"] = DataLoader(dsets["target"], batch_size=train_bs, shuffle=True, num_workers=args.worker, drop_last=False)
117 | dsets["target_te"] = ImageList(txt_tar, root="../data/{}/".format(args.dset), transform=image_test())
118 | dset_loaders["target_te"] = DataLoader(dsets["target_te"], batch_size=train_bs, shuffle=False, num_workers=args.worker, drop_last=False)
119 | dsets["test"] = ImageList(txt_test, root="../data/{}/".format(args.dset), transform=image_test())
120 | dset_loaders["test"] = DataLoader(dsets["test"], batch_size=train_bs*2, shuffle=False, num_workers=args.worker, drop_last=False)
121 |
122 | return dset_loaders
123 |
124 | def cal_acc(loader, netF, netB, netC, flag=False):
125 | start_test = True
126 | with torch.no_grad():
127 | iter_test = iter(loader)
128 | for i in range(len(loader)):
129 |             data = next(iter_test)
130 | inputs = data[0]
131 | labels = data[1]
132 | inputs = inputs.cuda()
133 | if netB is None:
134 | outputs = netC(netF(inputs))
135 | else:
136 | outputs = netC(netB(netF(inputs)))
137 | if start_test:
138 | all_output = outputs.float().cpu()
139 | all_label = labels.float()
140 | start_test = False
141 | else:
142 | all_output = torch.cat((all_output, outputs.float().cpu()), 0)
143 | all_label = torch.cat((all_label, labels.float()), 0)
144 |
145 | all_output = nn.Softmax(dim=1)(all_output)
146 | _, predict = torch.max(all_output, 1)
147 | accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0])
148 | mean_ent = torch.mean(loss.Entropy(all_output)).cpu().data.item() / np.log(all_label.size()[0])
149 |
150 | if flag:
151 | matrix = confusion_matrix(all_label, torch.squeeze(predict).float())
152 | matrix = matrix[np.unique(all_label).astype(int),:]
153 | acc = matrix.diagonal()/matrix.sum(axis=1) * 100
154 | aacc = acc.mean()
155 | aa = [str(np.round(i, 2)) for i in acc]
156 | acc = ' '.join(aa)
157 | return aacc, acc, mean_ent
158 | else:
159 | return accuracy*100, mean_ent
160 |
161 | def train_source_simp(args):
162 | dset_loaders = data_load(args)
163 | if args.net_src[0:3] == 'res':
164 | netF = network.ResBase(res_name=args.net_src).cuda()
165 | netC = network.feat_classifier_simpl(class_num=args.class_num, feat_dim=netF.in_features).cuda()
166 |
167 | param_group = []
168 | learning_rate = args.lr_src
169 | for k, v in netF.named_parameters():
170 | param_group += [{'params': v, 'lr': learning_rate*0.1}]
171 | for k, v in netC.named_parameters():
172 | param_group += [{'params': v, 'lr': learning_rate}]
173 | optimizer = optim.SGD(param_group)
174 | optimizer = op_copy(optimizer)
175 |
176 | acc_init = 0
177 | max_iter = args.max_epoch * len(dset_loaders["source_tr"])
178 | interval_iter = max_iter // 10
179 | iter_num = 0
180 |
181 | netF.train()
182 | netC.train()
183 |
184 | while iter_num < max_iter:
185 |         try:
186 |             inputs_source, labels_source = next(iter_source)
187 |         except (StopIteration, NameError):
188 |             iter_source = iter(dset_loaders["source_tr"])
189 |             inputs_source, labels_source = next(iter_source)
190 |
191 | if inputs_source.size(0) == 1:
192 | continue
193 |
194 | iter_num += 1
195 | lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter)
196 |
197 | inputs_source, labels_source = inputs_source.cuda(), labels_source.cuda()
198 | outputs_source = netC(netF(inputs_source))
199 | classifier_loss = CrossEntropyLabelSmooth(num_classes=args.class_num, epsilon=0.1)(outputs_source, labels_source)
200 |
201 | optimizer.zero_grad()
202 | classifier_loss.backward()
203 | optimizer.step()
204 |
205 | if iter_num % interval_iter == 0 or iter_num == max_iter:
206 | netF.eval()
207 | netC.eval()
208 | acc_s_te, _ = cal_acc(dset_loaders['source_te'], netF, None, netC, False)
209 | log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(args.s, iter_num, max_iter, acc_s_te)
210 | if args.dset == 'visda-2017':
211 | acc_s_te, acc_list, _ = cal_acc(dset_loaders['source_te'], netF, None, netC, True)
212 | log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(args.s, iter_num, max_iter,
213 | acc_s_te) + '\n' + acc_list
214 | logging.info(log_str)
215 |
216 | if acc_s_te >= acc_init:
217 | acc_init = acc_s_te
218 | best_netF = netF.state_dict()
219 | best_netC = netC.state_dict()
220 |
221 | netF.train()
222 | netC.train()
223 |
224 | torch.save(best_netF, osp.join(args.output_dir_src,'{}_{}_source_F.pt'.format(args.s, args.net_src)))
225 | torch.save(best_netC, osp.join(args.output_dir_src, '{}_{}_source_C.pt'.format(args.s, args.net_src)))
226 |
227 | return netF, netC
228 |
229 | def test_target_simp(args):
230 | dset_loaders = data_load(args)
231 | if args.net_src[0:3] == 'res':
232 | netF = network.ResBase(res_name=args.net_src).cuda()
233 | netC = network.feat_classifier_simpl(class_num = args.class_num, feat_dim=netF.in_features).cuda()
234 |
235 | args.modelpath = args.output_dir_src + '/{}_{}_source_F.pt'.format(args.s, args.net_src)
236 | netF.load_state_dict(torch.load(args.modelpath))
237 | args.modelpath = args.output_dir_src + '/{}_{}_source_C.pt'.format(args.s, args.net_src)
238 | netC.load_state_dict(torch.load(args.modelpath))
239 | netF.eval()
240 | netC.eval()
241 |
242 | acc, _ = cal_acc(dset_loaders['test'], netF, None, netC, False)
243 | log_str = '\nTask: {}->{}, Accuracy = {:.2f}%'.format(args.s, args.t, acc)
244 | if args.dset == 'visda-2017':
245 | acc_s_te, acc_list, _ = cal_acc(dset_loaders['test'], netF, None, netC, True)
246 |         log_str = '\nTask: {}->{}, Accuracy = {:.2f}%'.format(args.s, args.t, acc_s_te) + '\n' + acc_list
247 |
248 | logging.info(log_str)
249 |
250 | def copy_target_simp(args):
251 | dset_loaders = data_load(args)
252 | if args.net_src[0:3] == 'res':
253 | netF = network.ResBase(res_name=args.net_src).cuda()
254 | netC = network.feat_classifier_simpl(class_num=args.class_num, feat_dim=netF.in_features).cuda()
255 |
256 | args.modelpath = args.output_dir_src + '/{}_{}_source_F.pt'.format(args.s, args.net_src)
257 | netF.load_state_dict(torch.load(args.modelpath))
258 | args.modelpath = args.output_dir_src + '/{}_{}_source_C.pt'.format(args.s, args.net_src)
259 | netC.load_state_dict(torch.load(args.modelpath))
260 | source_model = nn.Sequential(netF, netC).cuda()
261 | source_model.eval()
262 |
263 | if args.net[0:3] == 'res':
264 | netF = network.ResBase(res_name=args.net, pretrain=True).cuda()
265 | netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features, bottleneck_dim=args.bottleneck).cuda()
266 | netC = network.feat_classifier(type=args.layer, class_num = args.class_num, bottleneck_dim=args.bottleneck).cuda()
267 |
268 | param_group = []
269 | learning_rate = args.lr
270 | for k, v in netF.named_parameters():
271 | param_group += [{'params': v, 'lr': learning_rate*0.1}]
272 | for k, v in netB.named_parameters():
273 | param_group += [{'params': v, 'lr': learning_rate}]
274 | for k, v in netC.named_parameters():
275 | param_group += [{'params': v, 'lr': learning_rate}]
276 | optimizer = optim.SGD(param_group)
277 | optimizer = op_copy(optimizer)
278 |
279 | ent_best = 1.0
280 | max_iter = args.max_epoch * len(dset_loaders["target"])
281 | interval_iter = max_iter // 10
282 | iter_num = 0
283 |
284 | model = nn.Sequential(netF, netB, netC).cuda()
285 | model.eval()
286 |
287 | start_test = True
288 | with torch.no_grad():
289 | iter_test = iter(dset_loaders["target_te"])
290 | for i in range(len(dset_loaders["target_te"])):
291 |             data = next(iter_test)
292 | inputs, labels = data[0], data[1]
293 | inputs = inputs.cuda()
294 | outputs = source_model(inputs)
295 | outputs = nn.Softmax(dim=1)(outputs)
296 | _, src_idx = torch.sort(outputs, 1, descending=True)
297 | if args.topk > 0:
298 | topk = np.min([args.topk, args.class_num])
299 | for i in range(outputs.size()[0]):
300 | outputs[i, src_idx[i, topk:]] = (1.0 - outputs[i, src_idx[i, :topk]].sum())/ (outputs.size()[1] - topk)
301 |
302 | if start_test:
303 | all_output = outputs.float()
304 | all_label = labels
305 | start_test = False
306 | else:
307 | all_output = torch.cat((all_output, outputs.float()), 0)
308 | all_label = torch.cat((all_label, labels), 0)
309 | mem_P = all_output.detach()
310 |
311 | model.train()
312 | while iter_num < max_iter:
313 |
314 | if args.ema < 1.0 and iter_num > 0 and iter_num % interval_iter == 0:
315 | model.eval()
316 | start_test = True
317 | with torch.no_grad():
318 | iter_test = iter(dset_loaders["target_te"])
319 | for i in range(len(dset_loaders["target_te"])):
320 |                     data = next(iter_test)
321 | inputs = data[0]
322 | inputs = inputs.cuda()
323 | outputs = model(inputs)
324 | outputs = nn.Softmax(dim=1)(outputs)
325 | if start_test:
326 | all_output = outputs.float()
327 | start_test = False
328 | else:
329 | all_output = torch.cat((all_output, outputs.float()), 0)
330 | mem_P = mem_P * args.ema + all_output.detach() * (1 - args.ema)
331 | model.train()
332 |
333 |         try:
334 |             inputs_target, y, tar_idx = next(iter_target)
335 |         except (StopIteration, NameError):
336 |             iter_target = iter(dset_loaders["target"])
337 |             inputs_target, y, tar_idx = next(iter_target)
338 |
339 | if inputs_target.size(0) == 1:
340 | continue
341 |
342 | iter_num += 1
343 | lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter, power=1.5)
344 | inputs_target = inputs_target.cuda()
345 | with torch.no_grad():
346 | outputs_target_by_source = mem_P[tar_idx, :]
347 | _, src_idx = torch.sort(outputs_target_by_source, 1, descending=True)
348 | outputs_target = model(inputs_target)
349 | outputs_target = torch.nn.Softmax(dim=1)(outputs_target)
350 | classifier_loss = nn.KLDivLoss(reduction='batchmean')(outputs_target.log(), outputs_target_by_source)
351 | optimizer.zero_grad()
352 |
353 | entropy_loss = torch.mean(loss.Entropy(outputs_target))
354 | msoftmax = outputs_target.mean(dim=0)
355 | gentropy_loss = torch.sum(- msoftmax * torch.log(msoftmax + 1e-5))
356 | entropy_loss -= gentropy_loss
357 | classifier_loss += entropy_loss
358 |
359 | classifier_loss.backward()
360 |
361 | if args.mix > 0:
362 | alpha = 0.3
363 | lam = np.random.beta(alpha, alpha)
364 | index = torch.randperm(inputs_target.size()[0]).cuda()
365 | mixed_input = lam * inputs_target + (1 - lam) * inputs_target[index, :]
366 | mixed_output = (lam * outputs_target + (1 - lam) * outputs_target[index, :]).detach()
367 |
368 | update_batch_stats(model, False)
369 | outputs_target_m = model(mixed_input)
370 | update_batch_stats(model, True)
371 | outputs_target_m = torch.nn.Softmax(dim=1)(outputs_target_m)
372 | classifier_loss = args.mix*nn.KLDivLoss(reduction='batchmean')(outputs_target_m.log(), mixed_output)
373 | classifier_loss.backward()
374 | optimizer.step()
375 |
376 | if iter_num % interval_iter == 0 or iter_num == max_iter:
377 | model.eval()
378 | acc_s_te, mean_ent = cal_acc(dset_loaders['test'], netF, netB, netC, False)
379 | log_str = 'Task: {}->{}, Iter:{}/{}; Accuracy = {:.2f}%, Ent = {:.4f}'.format(args.s, args.t, iter_num, max_iter, acc_s_te, mean_ent)
380 | if args.dset == 'visda-2017':
381 | acc_s_te, acc_list, mean_ent = cal_acc(dset_loaders['test'], netF, netB, netC, True)
382 | log_str = 'Task: {}->{}, Iter:{}/{}; Accuracy = {:.2f}%, Ent = {:.4f}'.format(args.s, args.t, iter_num, max_iter,
383 | acc_s_te, mean_ent) + '\n' + acc_list
384 |
385 | logging.info(log_str)
386 | model.train()
387 |
388 | torch.save(netF.state_dict(), osp.join(args.output_dir, "{}_{}_{}_{}_target_F".format(args.timestamp, args.s, args.t, args.net) + ".pt"))
389 | torch.save(netB.state_dict(), osp.join(args.output_dir, "{}_{}_{}_{}_target_B".format(args.timestamp, args.s, args.t, args.net) + ".pt"))
390 | torch.save(netC.state_dict(), osp.join(args.output_dir, "{}_{}_{}_{}_target_C".format(args.timestamp, args.s, args.t, args.net) + ".pt"))
391 |
392 | def update_batch_stats(model, flag):
393 | for m in model.modules():
394 | if isinstance(m, nn.BatchNorm2d):
395 | m.update_batch_stats = flag
396 |
397 | def print_args(args):
398 | s = "==========================================\n"
399 | for arg, content in args.__dict__.items():
400 | s += "{}:{}\n".format(arg, content)
401 | return s
402 |
403 | if __name__ == "__main__":
404 | parser = argparse.ArgumentParser(description='DINE')
405 | parser.add_argument('--gpu_id', type=str, nargs='?', default='0', help="device id to run")
406 | parser.add_argument('--s', type=str, default=None, help="source")
407 | parser.add_argument('--t', type=str, default=None, help="target")
408 |     parser.add_argument('--max_epoch', type=int, default=20, help="max epochs")
409 | parser.add_argument('--batch_size', type=int, default=64, help="batch_size")
410 | parser.add_argument('--worker', type=int, default=4, help="number of workers")
411 | parser.add_argument('--dset', type=str, default='office-home', choices=['visda-2017', 'office31', 'image-clef', 'office-home', 'office-caltech'])
412 | parser.add_argument('--lr', type=float, default=1e-2, help="learning rate")
413 | parser.add_argument('--net', type=str, default='resnet50', help="alexnet, vgg16, resnet18, resnet34, resnet50, resnet101")
414 | parser.add_argument('--output', type=str, default='san')
415 | parser.add_argument('--lr_src', type=float, default=1e-2, help="learning rate")
416 | parser.add_argument('--net_src', type=str, default='resnet50', help="alexnet, vgg16, resnet18, resnet34, resnet50, resnet101")
417 | parser.add_argument('--output_src', type=str, default='san')
418 |
419 | parser.add_argument('--seed', type=int, default=2020, help="random seed")
420 | parser.add_argument('--bottleneck', type=int, default=256)
421 | parser.add_argument('--layer', type=str, default="wn", choices=["linear", "wn"])
422 | parser.add_argument('--classifier', type=str, default="bn", choices=["ori", "bn"])
423 | parser.add_argument('--da', type=str, default='uda', choices=['uda', 'pda'])
424 | parser.add_argument('--topk', type=int, default=1)
425 |
426 | parser.add_argument('--distill', action='store_true')
427 | parser.add_argument('--ema', type=float, default=0.6)
428 | parser.add_argument('--mix', type=float, default=1.0)
429 |
430 | parser.add_argument('--timestamp', default=timestamp, type=str, help='timestamp')
431 | parser.add_argument('--use_file_logger', default='True', type=lambda x: bool(distutils.util.strtobool(x)),
432 | help='whether use file logger')
433 | parser.add_argument('--names', default=[], type=list, help='names of tasks')
434 |
435 | parser.add_argument('--method', type=str, default="dine")
436 |
437 | args = parser.parse_args()
438 | if args.dset == 'office-home':
439 | args.names = ['Art', 'Clipart', 'Product', 'Real_World']
440 | args.class_num = 65
441 | if args.dset == 'visda-2017':
442 | args.names = ['train', 'validation']
443 | args.class_num = 12
444 | if args.dset == 'office31':
445 | args.names = ['amazon', 'dslr', 'webcam']
446 | args.class_num = 31
447 |
448 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
449 | resetRNGseed(args.seed)
450 |
451 | if not args.distill:
452 | dir = "{}_{}_{}_{}_source".format(args.timestamp, args.s, args.da, args.method)
453 | else:
454 | dir = "{}_{}_{}_{}".format(args.timestamp, args.s, args.da, args.method)
455 | if args.use_file_logger:
456 | init_logger(dir, True, '../logs/DINE/{}/'.format(args.method))
457 | logging.info("{}:{}".format(get_hostname(), get_pid()))
458 |
459 | folder = '../data/'
460 | args.s_dset_path = folder + args.dset + '/image_list/' + args.s + '.txt'
461 | args.t_dset_path = folder + args.dset + '/image_list/' + args.s + '.txt'
462 | args.test_dset_path = folder + args.dset + '/image_list/' + args.s + '.txt'
463 |
464 | if args.dset == 'office-home':
465 | if args.da == 'pda':
466 | args.class_num = 65
467 | args.src_classes = [i for i in range(65)]
468 | args.tar_classes = [33, 32, 36, 15, 19, 2, 46, 49, 48, 53, 47, 54, 4, 18, 57, 23, 0, 45, 1, 38, 5, 13, 50, 11, 58]
469 |
470 |
471 | args.output_dir_src = "../checkpoints/DINE/{}/source/{}/".format(args.seed, args.da)
472 |
473 |     # create the source checkpoint directory if missing; os.makedirs
474 |     # replaces the original non-portable `os.system('mkdir -p ...')`
475 |     if not osp.exists(args.output_dir_src):
476 |         os.makedirs(args.output_dir_src, exist_ok=True)
477 |
478 | if not args.distill:
479 | logging.info(print_args(args))
480 | train_source_simp(args)
481 |
482 | for t in args.names:
483 | if t == args.s:
484 | continue
485 | args.t = t
486 | args.t_dset_path = folder + args.dset + '/image_list/' + args.t + '.txt'
487 | args.test_dset_path = folder + args.dset + '/image_list/' + args.t + '.txt'
488 |
489 | test_target_simp(args)
490 |
491 | if args.distill:
492 | for t in args.names:
493 | if t == args.s:
494 | continue
495 | args.t = t
496 |
497 | args.output_dir = "../checkpoints/DINE/{}/target/{}/".format(args.seed, args.da)
498 |             # create the target checkpoint directory if missing
499 |             # (portable replacement for `os.system('mkdir -p ...')`)
500 |             if not osp.exists(args.output_dir):
501 |                 os.makedirs(args.output_dir, exist_ok=True)
502 |
503 | args.t_dset_path = folder + args.dset + '/image_list/' + args.t + '.txt'
504 | args.test_dset_path = folder + args.dset + '/image_list/' + args.t + '.txt'
505 |
506 | logging.info(print_args(args))
507 |
508 | copy_target_simp(args)
--------------------------------------------------------------------------------
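
Note on the `--topk` option: in `copy_target_simp` above, the teacher's softmax output is post-processed so that only each sample's top-k probabilities survive and the leftover mass is spread uniformly over the remaining classes before being stored in `mem_P`. A minimal vectorized sketch of that transform, for illustration only (`topk_smooth` is not a function defined in this repository):

    import torch

    def topk_smooth(probs: torch.Tensor, topk: int) -> torch.Tensor:
        # probs: (batch, num_classes) softmax outputs of the source model;
        # assumes k < num_classes (as in the script, where topk defaults to 1)
        num_classes = probs.size(1)
        k = min(topk, num_classes)
        vals, idx = probs.topk(k, dim=1)
        # leftover mass, shared uniformly by the num_classes - k other classes
        rest = (1.0 - vals.sum(dim=1, keepdim=True)) / (num_classes - k)
        out = rest.expand(-1, num_classes).clone()
        out.scatter_(1, idx, vals)  # restore the kept top-k entries
        return out

With the default `topk=1`, each memorized teacher prediction keeps only its argmax probability, which is what the per-sample loop in `copy_target_simp` computes one row at a time.
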
/DINE/DINE_dist_kDINE.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os.path as osp
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.optim as optim
7 | from torchvision import transforms
8 | import network
9 | import loss
10 | from torch.utils.data import DataLoader
11 | from data_list import ImageList, ImageList_idx
12 | from loss import CrossEntropyLabelSmooth
13 | from scipy.spatial.distance import cdist
14 | from sklearn.metrics import confusion_matrix
15 | import distutils
16 | import distutils.util
17 | import logging
18 |
19 | import sys, os
20 | sys.path.append("../util/")
21 | from utils import resetRNGseed, init_logger, get_hostname, get_pid
22 | sys.path.append("../pklib")
23 | from pksolver import PK_solver
24 |
25 | import time
26 | timestamp = time.strftime("%Y-%m-%d_%H.%M.%S", time.localtime())
27 |
28 | torch.backends.cudnn.deterministic = True
29 | torch.backends.cudnn.benchmark = False
30 |
31 | def op_copy(optimizer):
32 | for param_group in optimizer.param_groups:
33 | param_group['lr0'] = param_group['lr']
34 | return optimizer
35 |
36 | def lr_scheduler(optimizer, iter_num, max_iter, gamma=10, power=0.75):
37 | decay = (1 + gamma * iter_num / max_iter) ** (-power)
38 | for param_group in optimizer.param_groups:
39 | param_group['lr'] = param_group['lr0'] * decay
40 | param_group['weight_decay'] = 1e-3
41 | param_group['momentum'] = 0.9
42 | param_group['nesterov'] = True
43 | return optimizer
44 |
45 | def image_train(resize_size=256, crop_size=224, alexnet=False):
46 | if not alexnet:
47 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
48 | std=[0.229, 0.224, 0.225])
49 | else:
50 | normalize = Normalize(meanfile='./ilsvrc_2012_mean.npy')
51 | return transforms.Compose([
52 | transforms.Resize((resize_size, resize_size)),
53 | transforms.RandomCrop(crop_size),
54 | transforms.RandomHorizontalFlip(),
55 | transforms.ToTensor(),
56 | normalize
57 | ])
58 |
59 | def image_test(resize_size=256, crop_size=224, alexnet=False):
60 | if not alexnet:
61 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
62 | std=[0.229, 0.224, 0.225])
63 | else:
64 | normalize = Normalize(meanfile='./ilsvrc_2012_mean.npy')
65 | return transforms.Compose([
66 | transforms.Resize((resize_size, resize_size)),
67 | transforms.CenterCrop(crop_size),
68 | transforms.ToTensor(),
69 | normalize
70 | ])
71 |
72 | def data_load(args):
73 | ## prepare data
74 | dsets = {}
75 | dset_loaders = {}
76 | train_bs = args.batch_size
77 | txt_src = open(args.s_dset_path).readlines()
78 | txt_tar = open(args.t_dset_path).readlines()
79 | txt_test = open(args.test_dset_path).readlines()
80 |
81 | count = np.zeros(args.class_num)
82 | tr_txt = []
83 | te_txt = []
84 | for i in range(len(txt_src)):
85 | line = txt_src[i]
86 | reci = line.strip().split(' ')
87 | if count[int(reci[1])] < 3:
88 | count[int(reci[1])] += 1
89 | te_txt.append(line)
90 | else:
91 | tr_txt.append(line)
92 |
93 | if not args.da == 'uda':
94 | label_map_s = {}
95 | for i in range(len(args.src_classes)):
96 | label_map_s[args.src_classes[i]] = i
97 |
98 | new_tar = []
99 | for i in range(len(txt_tar)):
100 | rec = txt_tar[i]
101 | reci = rec.strip().split(' ')
102 | if int(reci[1]) in args.tar_classes:
103 | if int(reci[1]) in args.src_classes:
104 | line = reci[0] + ' ' + str(label_map_s[int(reci[1])]) + '\n'
105 | new_tar.append(line)
106 | else:
107 | line = reci[0] + ' ' + str(len(label_map_s)) + '\n'
108 | new_tar.append(line)
109 | txt_tar = new_tar.copy()
110 | txt_test = txt_tar.copy()
111 |
112 | dsets["source_tr"] = ImageList(tr_txt, root="../data/{}/".format(args.dset), transform=image_train())
113 | dset_loaders["source_tr"] = DataLoader(dsets["source_tr"], batch_size=train_bs, shuffle=True, num_workers=args.worker, drop_last=False)
114 | dsets["source_te"] = ImageList(te_txt, root="../data/{}/".format(args.dset), transform=image_test())
115 | dset_loaders["source_te"] = DataLoader(dsets["source_te"], batch_size=train_bs, shuffle=True, num_workers=args.worker, drop_last=False)
116 | dsets["target"] = ImageList_idx(txt_tar, root="../data/{}/".format(args.dset), transform=image_train())
117 | dset_loaders["target"] = DataLoader(dsets["target"], batch_size=train_bs, shuffle=True, num_workers=args.worker, drop_last=False)
118 | dsets["target_te"] = ImageList(txt_tar, root="../data/{}/".format(args.dset), transform=image_test())
119 | dset_loaders["target_te"] = DataLoader(dsets["target_te"], batch_size=train_bs, shuffle=False, num_workers=args.worker, drop_last=False)
120 | dsets["test"] = ImageList(txt_test, root="../data/{}/".format(args.dset), transform=image_test())
121 | dset_loaders["test"] = DataLoader(dsets["test"], batch_size=train_bs*2, shuffle=False, num_workers=args.worker, drop_last=False)
122 |
123 | return dset_loaders
124 |
125 | def cal_acc(loader, netF, netB, netC, flag=False):
126 | start_test = True
127 | with torch.no_grad():
128 | iter_test = iter(loader)
129 | for i in range(len(loader)):
130 |             data = next(iter_test)
131 | inputs = data[0]
132 | labels = data[1]
133 | inputs = inputs.cuda()
134 | if netB is None:
135 | outputs = netC(netF(inputs))
136 | else:
137 | outputs = netC(netB(netF(inputs)))
138 | if start_test:
139 | all_output = outputs.float().cpu()
140 | all_label = labels.float()
141 | start_test = False
142 | else:
143 | all_output = torch.cat((all_output, outputs.float().cpu()), 0)
144 | all_label = torch.cat((all_label, labels.float()), 0)
145 |
146 | all_output = nn.Softmax(dim=1)(all_output)
147 | _, predict = torch.max(all_output, 1)
148 | accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0])
149 | mean_ent = torch.mean(loss.Entropy(all_output)).cpu().data.item() / np.log(all_label.size()[0])
150 |
151 | if flag:
152 | matrix = confusion_matrix(all_label, torch.squeeze(predict).float())
153 | matrix = matrix[np.unique(all_label).astype(int),:]
154 | acc = matrix.diagonal()/matrix.sum(axis=1) * 100
155 | aacc = acc.mean()
156 | aa = [str(np.round(i, 2)) for i in acc]
157 | acc = ' '.join(aa)
158 | return aacc, acc, mean_ent
159 | else:
160 | return accuracy*100, mean_ent
161 |
162 | def train_source_simp(args):
163 | dset_loaders = data_load(args)
164 | if args.net_src[0:3] == 'res':
165 | netF = network.ResBase(res_name=args.net_src).cuda()
166 | netC = network.feat_classifier_simpl(class_num=args.class_num, feat_dim=netF.in_features).cuda()
167 |
168 | param_group = []
169 | learning_rate = args.lr_src
170 | for k, v in netF.named_parameters():
171 | param_group += [{'params': v, 'lr': learning_rate*0.1}]
172 | for k, v in netC.named_parameters():
173 | param_group += [{'params': v, 'lr': learning_rate}]
174 | optimizer = optim.SGD(param_group)
175 | optimizer = op_copy(optimizer)
176 |
177 | acc_init = 0
178 | max_iter = args.max_epoch * len(dset_loaders["source_tr"])
179 | interval_iter = max_iter // 10
180 | iter_num = 0
181 |
182 | netF.train()
183 | netC.train()
184 |
185 | while iter_num < max_iter:
186 |         try:
187 |             inputs_source, labels_source = next(iter_source)
188 |         except (StopIteration, NameError):
189 |             iter_source = iter(dset_loaders["source_tr"])
190 |             inputs_source, labels_source = next(iter_source)
191 |
192 | if inputs_source.size(0) == 1:
193 | continue
194 |
195 | iter_num += 1
196 | lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter)
197 |
198 | inputs_source, labels_source = inputs_source.cuda(), labels_source.cuda()
199 | outputs_source = netC(netF(inputs_source))
200 | classifier_loss = CrossEntropyLabelSmooth(num_classes=args.class_num, epsilon=0.1)(outputs_source, labels_source)
201 |
202 | optimizer.zero_grad()
203 | classifier_loss.backward()
204 | optimizer.step()
205 |
206 | if iter_num % interval_iter == 0 or iter_num == max_iter:
207 | netF.eval()
208 | netC.eval()
209 | acc_s_te, _ = cal_acc(dset_loaders['source_te'], netF, None, netC, False)
210 | log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(args.s, iter_num, max_iter, acc_s_te)
211 | if args.dset == 'visda-2017':
212 | acc_s_te, acc_list, _ = cal_acc(dset_loaders['source_te'], netF, None, netC, True)
213 | log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(args.s, iter_num, max_iter,
214 | acc_s_te) + '\n' + acc_list
215 | logging.info(log_str)
216 |
217 | if acc_s_te >= acc_init:
218 | acc_init = acc_s_te
219 | best_netF = netF.state_dict()
220 | best_netC = netC.state_dict()
221 |
222 | netF.train()
223 | netC.train()
224 |
225 | torch.save(best_netF, osp.join(args.output_dir_src,'{}_{}_source_F.pt'.format(args.s, args.net_src)))
226 | torch.save(best_netC, osp.join(args.output_dir_src, '{}_{}_source_C.pt'.format(args.s, args.net_src)))
227 |
228 | return netF, netC
229 |
230 | def test_target_simp(args):
231 | dset_loaders = data_load(args)
232 | if args.net_src[0:3] == 'res':
233 | netF = network.ResBase(res_name=args.net_src).cuda()
234 | netC = network.feat_classifier_simpl(class_num = args.class_num, feat_dim=netF.in_features).cuda()
235 |
236 | args.modelpath = args.output_dir_src + '/{}_{}_source_F.pt'.format(args.s, args.net_src)
237 | netF.load_state_dict(torch.load(args.modelpath))
238 | args.modelpath = args.output_dir_src + '/{}_{}_source_C.pt'.format(args.s, args.net_src)
239 | netC.load_state_dict(torch.load(args.modelpath))
240 | netF.eval()
241 | netC.eval()
242 |
243 | acc, _ = cal_acc(dset_loaders['test'], netF, None, netC, False)
244 | log_str = '\nTask: {}->{}, Accuracy = {:.2f}%'.format(args.s, args.t, acc)
245 | if args.dset == 'visda-2017':
246 | acc_s_te, acc_list, _ = cal_acc(dset_loaders['test'], netF, None, netC, True)
247 |         log_str = '\nTask: {}->{}, Accuracy = {:.2f}%'.format(args.s, args.t, acc_s_te) + '\n' + acc_list
248 |
249 | logging.info(log_str)
250 |
251 | def copy_target_simp(args):
252 | dset_loaders = data_load(args)
253 | if args.net_src[0:3] == 'res':
254 | netF = network.ResBase(res_name=args.net_src).cuda()
255 | netC = network.feat_classifier_simpl(class_num=args.class_num, feat_dim=netF.in_features).cuda()
256 |
257 | args.modelpath = args.output_dir_src + '/{}_{}_source_F.pt'.format(args.s, args.net_src)
258 | netF.load_state_dict(torch.load(args.modelpath))
259 | args.modelpath = args.output_dir_src + '/{}_{}_source_C.pt'.format(args.s, args.net_src)
260 | netC.load_state_dict(torch.load(args.modelpath))
261 | source_model = nn.Sequential(netF, netC).cuda()
262 | source_model.eval()
263 |
264 | if args.net[0:3] == 'res':
265 | netF = network.ResBase(res_name=args.net, pretrain=True).cuda()
266 | netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features, bottleneck_dim=args.bottleneck).cuda()
267 | netC = network.feat_classifier(type=args.layer, class_num = args.class_num, bottleneck_dim=args.bottleneck).cuda()
268 |
269 | param_group = []
270 | learning_rate = args.lr
271 | for k, v in netF.named_parameters():
272 | param_group += [{'params': v, 'lr': learning_rate*0.1}]
273 | for k, v in netB.named_parameters():
274 | param_group += [{'params': v, 'lr': learning_rate}]
275 | for k, v in netC.named_parameters():
276 | param_group += [{'params': v, 'lr': learning_rate}]
277 | optimizer = optim.SGD(param_group)
278 | optimizer = op_copy(optimizer)
279 |
280 | ent_best = 1.0
281 | max_iter = args.max_epoch * len(dset_loaders["target"])
282 | interval_iter = max_iter // 10
283 | iter_num = 0
284 |
285 | model = nn.Sequential(netF, netB, netC).cuda()
286 | model.eval()
287 |
288 | start_test = True
289 | with torch.no_grad():
290 | iter_test = iter(dset_loaders["target_te"])
291 | for i in range(len(dset_loaders["target_te"])):
292 |             data = next(iter_test)
293 | inputs, labels = data[0], data[1]
294 | inputs = inputs.cuda()
295 | outputs = source_model(inputs)
296 | outputs = nn.Softmax(dim=1)(outputs)
297 | _, src_idx = torch.sort(outputs, 1, descending=True)
298 | if args.topk > 0:
299 | topk = np.min([args.topk, args.class_num])
300 | for i in range(outputs.size()[0]):
301 | outputs[i, src_idx[i, topk:]] = (1.0 - outputs[i, src_idx[i, :topk]].sum())/ (outputs.size()[1] - topk)
302 |
303 | if start_test:
304 | all_output = outputs.float()
305 | all_label = labels
306 | start_test = False
307 | else:
308 | all_output = torch.cat((all_output, outputs.float()), 0)
309 | all_label = torch.cat((all_label, labels), 0)
310 | mem_P = all_output.detach()
311 |
312 | # get ground-truth label probabilities of target domain
313 | cls_probs = torch.eye(args.class_num)[all_label].sum(0)
314 | cls_probs = cls_probs / cls_probs.sum()
315 |
316 | pk_solver = PK_solver(all_label.shape[0], args.class_num, pk_prior_weight=args.pk_prior_weight)
317 | if args.pk_type == 'ub':
318 | pk_solver.create_C_ub(cls_probs.cpu().numpy(), args.pk_uconf)
319 | elif args.pk_type == 'br':
320 | pk_solver.create_C_br(cls_probs.cpu().numpy(), args.pk_uconf)
321 | else:
322 | raise NotImplementedError
323 |
324 | mem_label = obtain_label(mem_P.cpu(), all_label.cpu(), None, args, pk_solver)
325 | mem_label = torch.from_numpy(mem_label).cuda()
326 | mem_label = torch.eye(args.class_num)[mem_label].cuda()
327 |
328 | model.train()
329 | while iter_num < max_iter:
330 |
331 | if args.ema < 1.0 and iter_num > 0 and iter_num % interval_iter == 0:
332 | model.eval()
333 | start_test = True
334 | with torch.no_grad():
335 | iter_test = iter(dset_loaders["target_te"])
336 | for i in range(len(dset_loaders["target_te"])):
337 |                     data = next(iter_test)
338 | inputs = data[0]
339 | inputs = inputs.cuda()
340 | outputs = model(inputs)
341 | feas = model[1](model[0](inputs))
342 | outputs = nn.Softmax(dim=1)(outputs)
343 | if start_test:
344 | all_fea = feas.float().cpu()
345 | all_output = outputs.float()
346 | start_test = False
347 | else:
348 | all_fea = torch.cat((all_fea, feas.float().cpu()), 0)
349 | all_output = torch.cat((all_output, outputs.float()), 0)
350 | mem_P = mem_P * args.ema + all_output.detach() * (1 - args.ema)
351 | model.train()
352 |
353 | mem_label = obtain_label(mem_P.cpu(), all_label.cpu(), all_fea, args, pk_solver)
354 | mem_label = torch.from_numpy(mem_label).cuda()
355 | mem_label = torch.eye(args.class_num)[mem_label].cuda()
356 |
357 |         try:
358 |             inputs_target, y, tar_idx = next(iter_target)
359 |         except (StopIteration, NameError):
360 |             iter_target = iter(dset_loaders["target"])
361 |             inputs_target, y, tar_idx = next(iter_target)
362 |
363 | if inputs_target.size(0) == 1:
364 | continue
365 |
366 | iter_num += 1
367 | lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter, power=1.5)
368 | inputs_target = inputs_target.cuda()
369 | with torch.no_grad():
370 | outputs_target_by_source = mem_P[tar_idx, :]
371 | _, src_idx = torch.sort(outputs_target_by_source, 1, descending=True)
372 | outputs_target = model(inputs_target)
373 | outputs_target = torch.nn.Softmax(dim=1)(outputs_target)
374 |
375 | target = (outputs_target_by_source + mem_label[tar_idx, :]*0.9 + 1/mem_label.shape[-1]*0.1) / 2
376 | if iter_num < interval_iter and args.dset == "visda-2017":
377 | target = outputs_target_by_source
378 |
379 | classifier_loss = nn.KLDivLoss(reduction='batchmean')(outputs_target.log(), target)
380 | optimizer.zero_grad()
381 |
382 | entropy_loss = torch.mean(loss.Entropy(outputs_target))
383 | msoftmax = outputs_target.mean(dim=0)
384 | gentropy_loss = torch.sum(- msoftmax * torch.log(msoftmax + 1e-5))
385 | entropy_loss -= gentropy_loss
386 | classifier_loss += entropy_loss
387 |
388 | classifier_loss.backward()
389 |
390 | if args.mix > 0:
391 | alpha = 0.3
392 | lam = np.random.beta(alpha, alpha)
393 | index = torch.randperm(inputs_target.size()[0]).cuda()
394 | mixed_input = lam * inputs_target + (1 - lam) * inputs_target[index, :]
395 | mixed_output = (lam * outputs_target + (1 - lam) * outputs_target[index, :]).detach()
396 |
397 | update_batch_stats(model, False)
398 | outputs_target_m = model(mixed_input)
399 | update_batch_stats(model, True)
400 | outputs_target_m = torch.nn.Softmax(dim=1)(outputs_target_m)
401 | classifier_loss = args.mix*nn.KLDivLoss(reduction='batchmean')(outputs_target_m.log(), mixed_output)
402 | classifier_loss.backward()
403 | optimizer.step()
404 |
405 | if iter_num % interval_iter == 0 or iter_num == max_iter:
406 | model.eval()
407 | acc_s_te, mean_ent = cal_acc(dset_loaders['test'], netF, netB, netC, False)
408 | log_str = 'Task: {}->{}, Iter:{}/{}; Accuracy = {:.2f}%, Ent = {:.4f}'.format(args.s, args.t, iter_num, max_iter, acc_s_te, mean_ent)
409 | if args.dset == 'visda-2017':
410 | acc_s_te, acc_list, mean_ent = cal_acc(dset_loaders['test'], netF, netB, netC, True)
411 | log_str = 'Task: {}->{}, Iter:{}/{}; Accuracy = {:.2f}%, Ent = {:.4f}'.format(args.s, args.t, iter_num, max_iter,
412 | acc_s_te, mean_ent) + '\n' + acc_list
413 | logging.info(log_str)
414 | model.train()
415 |
416 | torch.save(netF.state_dict(), osp.join(args.output_dir, "{}_{}_{}_{}_target_F".format(args.timestamp, args.s, args.t, args.net) + ".pt"))
417 | torch.save(netB.state_dict(), osp.join(args.output_dir, "{}_{}_{}_{}_target_B".format(args.timestamp, args.s, args.t, args.net) + ".pt"))
418 | torch.save(netC.state_dict(), osp.join(args.output_dir, "{}_{}_{}_{}_target_C".format(args.timestamp, args.s, args.t, args.net) + ".pt"))
419 |
420 |
421 | def obtain_label(mem_P, all_label, all_fea, args, pk_solver):
422 | predict = mem_P.argmax(-1)
423 | accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0])
424 | matrix = confusion_matrix(all_label, torch.squeeze(predict).float())
425 | avg_accuracy = (matrix.diagonal() / matrix.sum(axis=1)).mean()
426 |
427 | # update labels with prior knowledge
428 | probs = mem_P
429 | # first solve without smooth regularization
430 | pred_label_PK = pk_solver.solve_soft(probs)
431 |
432 | acc_PK = np.sum(pred_label_PK == all_label.float().numpy()) / float(all_label.size()[0])
433 | matrix_PK = confusion_matrix(all_label.float().numpy(), pred_label_PK)
434 | avg_acc_PK = (matrix_PK.diagonal() / matrix_PK.sum(axis=1)).mean()
435 | log_str = 'PK Accuracy = {:.2f}% -> {:.2f}% Per_class_accuracy = {:.2f}% -> {:.2f}%'.format(accuracy * 100, acc_PK * 100, avg_accuracy * 100, avg_acc_PK * 100)
436 | logging.info(log_str)
437 |
438 | if args.pk_knn > 0 and all_fea is not None:
439 | # now solve with smooth regularization
440 | predict = predict.cpu().numpy()
441 | all_fea = torch.cat((all_fea, torch.ones(all_fea.size(0), 1)), 1)
442 | all_fea = (all_fea.t() / torch.norm(all_fea, p=2, dim=1)).t()
443 | all_fea = all_fea.float().cpu().numpy()
444 |
445 | idx_unconf = np.where(pred_label_PK != predict)[0]
446 | knn_sample_idx = idx_unconf
447 | idx_conf = np.where(pred_label_PK == predict)[0]
448 |
449 | if len(idx_unconf) > 0 and len(idx_conf) > 0:
450 | # get knn of each samples
451 | dd_knn = cdist(all_fea[idx_unconf], all_fea[idx_conf], args.distance)
452 | knn_idx = []
453 | K = args.pk_knn
454 | for i in range(dd_knn.shape[0]):
455 | ind = np.argpartition(dd_knn[i], K)[:K]
456 | knn_idx.append(idx_conf[ind])
457 |
458 | knn_idx = np.stack(knn_idx, axis=0)
459 | knn_regs = list(zip(knn_sample_idx, knn_idx))
460 | pred_label_PK = pk_solver.solve_soft_knn_cst(probs, knn_regs=knn_regs)
461 |
462 |
463 | acc_PK = np.sum(pred_label_PK == all_label.float().numpy()) / len(all_fea)
464 | matrix_PK = confusion_matrix(all_label.float().numpy(), pred_label_PK)
465 | avg_acc_PK = (matrix_PK.diagonal() / matrix_PK.sum(axis=1)).mean()
466 | if args.da == 'pda':
467 | avg_acc_PK = 0
468 | log_str = 'PK Accuracy = {:.2f}% -> {:.2f}% Per_class_accuracy = {:.2f}% -> {:.2f}%'.format(accuracy * 100, acc_PK * 100, avg_accuracy * 100, avg_acc_PK * 100)
469 | logging.info(log_str)
470 |
471 | return pred_label_PK.astype('int')
472 |
473 | def update_batch_stats(model, flag):
474 | for m in model.modules():
475 | if isinstance(m, nn.BatchNorm2d):
476 | m.update_batch_stats = flag
477 |
478 | def print_args(args):
479 | s = "==========================================\n"
480 | for arg, content in args.__dict__.items():
481 | s += "{}:{}\n".format(arg, content)
482 | return s
483 |
484 | if __name__ == "__main__":
485 | parser = argparse.ArgumentParser(description='DINE')
486 | parser.add_argument('--gpu_id', type=str, nargs='?', default='0', help="device id to run")
487 | parser.add_argument('--s', type=str, default=None, help="source")
488 | parser.add_argument('--t', type=str, default=None, help="target")
489 |     parser.add_argument('--max_epoch', type=int, default=20, help="max epochs")
490 | parser.add_argument('--batch_size', type=int, default=64, help="batch_size")
491 | parser.add_argument('--worker', type=int, default=4, help="number of workers")
492 | parser.add_argument('--dset', type=str, default='office-home', choices=['visda-2017', 'office31', 'image-clef', 'office-home', 'office-caltech'])
493 | parser.add_argument('--lr', type=float, default=1e-2, help="learning rate")
494 | parser.add_argument('--net', type=str, default='resnet50', help="alexnet, vgg16, resnet18, resnet34, resnet50, resnet101")
495 | parser.add_argument('--output', type=str, default='san')
496 | parser.add_argument('--lr_src', type=float, default=1e-2, help="learning rate")
497 | parser.add_argument('--net_src', type=str, default='resnet50', help="alexnet, vgg16, resnet18, resnet34, resnet50, resnet101")
498 | parser.add_argument('--output_src', type=str, default='san')
499 |
500 | parser.add_argument('--seed', type=int, default=2020, help="random seed")
501 | parser.add_argument('--bottleneck', type=int, default=256)
502 | parser.add_argument('--layer', type=str, default="wn", choices=["linear", "wn"])
503 | parser.add_argument('--classifier', type=str, default="bn", choices=["ori", "bn"])
504 | parser.add_argument('--da', type=str, default='uda', choices=['uda', 'pda'])
505 | parser.add_argument('--topk', type=int, default=1)
506 |
507 | parser.add_argument('--distill', action='store_true')
508 | parser.add_argument('--ema', type=float, default=0.6)
509 | parser.add_argument('--mix', type=float, default=1.0)
510 |
511 | parser.add_argument('--timestamp', default=timestamp, type=str, help='timestamp')
512 | parser.add_argument('--use_file_logger', default='True', type=lambda x: bool(distutils.util.strtobool(x)),
513 | help='whether use file logger')
514 | parser.add_argument('--names', default=[], type=list, help='names of tasks')
515 |
516 | parser.add_argument('--cls_par', type=float, default=0.3)
517 | parser.add_argument('--distance', type=str, default='cosine', choices=["euclidean", "cosine"])
518 |
519 | parser.add_argument('--pk_uconf', type=float, default=0.0)
520 | parser.add_argument('--pk_type', type=str, default="ub")
521 | parser.add_argument('--pk_allow', type=int, default=None)
522 | parser.add_argument('--pk_temp', type=float, default=1.0)
523 | parser.add_argument('--pk_prior_weight', type=float, default=10.)
524 | parser.add_argument('--pk_knn', type=int, default=1)
525 | parser.add_argument('--method', type=str, default="kdine")
526 |
527 | args = parser.parse_args()
528 |
529 | if args.dset == 'office-home':
530 | args.names = ['Art', 'Clipart', 'Product', 'Real_World']
531 | args.class_num = 65
532 | if args.dset == 'visda-2017':
533 | args.names = ['train', 'validation']
534 | args.class_num = 12
535 | if args.dset == 'office31':
536 | args.names = ['amazon', 'dslr', 'webcam']
537 | args.class_num = 31
538 |
539 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
540 | resetRNGseed(args.seed)
541 |
542 | if not args.distill:
543 | dir = "{}_{}_{}_{}_source".format(args.timestamp, args.s, args.da, args.method)
544 | else:
545 | dir = "{}_{}_{}_{}".format(args.timestamp, args.s, args.da, args.method)
546 | if args.use_file_logger:
547 | init_logger(dir, True, '../logs/DINE/{}/'.format(args.method))
548 | logging.info("{}:{}".format(get_hostname(), get_pid()))
549 |
550 | folder = '../data/'
551 | args.s_dset_path = folder + args.dset + '/image_list/' + args.s + '.txt'
552 | args.t_dset_path = folder + args.dset + '/image_list/' + args.s + '.txt'
553 | args.test_dset_path = folder + args.dset + '/image_list/' + args.s + '.txt'
554 |
555 | if args.dset == 'office-home':
556 | if args.da == 'pda':
557 | args.class_num = 65
558 | args.src_classes = [i for i in range(65)]
559 | args.tar_classes = [33, 32, 36, 15, 19, 2, 46, 49, 48, 53, 47, 54, 4, 18, 57, 23, 0, 45, 1, 38, 5, 13, 50, 11, 58]
560 |
561 | args.output_dir_src = "../checkpoints/DINE/{}/source/{}/".format(args.seed, args.da)
562 |
563 |     # create the source checkpoint directory if missing; os.makedirs
564 |     # replaces the original non-portable `os.system('mkdir -p ...')`
565 |     if not osp.exists(args.output_dir_src):
566 |         os.makedirs(args.output_dir_src, exist_ok=True)
567 |
568 | if not args.distill:
569 | logging.info(print_args(args))
570 | train_source_simp(args)
571 |
572 | for t in args.names:
573 | if t == args.s:
574 | continue
575 | args.t = t
576 | args.t_dset_path = folder + args.dset + '/image_list/' + args.t + '.txt'
577 | args.test_dset_path = folder + args.dset + '/image_list/' + args.t + '.txt'
578 |
579 | test_target_simp(args)
580 |
581 | if args.distill:
582 | for t in args.names:
583 | if t == args.s:
584 | continue
585 | args.t = t
586 | args.output_dir = "../checkpoints/DINE/{}/target/{}/".format(args.seed, args.da)
587 |             # create the target checkpoint directory if missing
588 |             # (portable replacement for `os.system('mkdir -p ...')`)
589 |             if not osp.exists(args.output_dir):
590 |                 os.makedirs(args.output_dir, exist_ok=True)
591 |
592 | args.t_dset_path = folder + args.dset + '/image_list/' + args.t + '.txt'
593 | args.test_dset_path = folder + args.dset + '/image_list/' + args.t + '.txt'
594 |
595 | logging.info(print_args(args))
596 |
597 | copy_target_simp(args)
--------------------------------------------------------------------------------
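
Note on the prior-knowledge solver: DINE_dist_kDINE.py above touches `pklib/pksolver.py` through a small surface. A `PK_solver` is built from the number of target samples, the number of classes, and `pk_prior_weight`; `create_C_ub` / `create_C_br` install the class-prior constraints from an estimated class distribution and the `pk_uconf` slack; `solve_soft` turns soft predictions into refined hard labels; and `solve_soft_knn_cst` additionally encourages agreement with each unconfident sample's k nearest confident neighbors. A toy usage sketch mirroring the calls in `copy_target_simp` and `obtain_label` (the random data is illustrative, and it assumes `solve_soft` returns a NumPy array of hard labels, as `obtain_label` treats it):

    import sys
    import numpy as np
    import torch

    sys.path.append("../pklib")
    from pksolver import PK_solver

    n_samples, n_classes = 100, 5
    probs = torch.softmax(torch.randn(n_samples, n_classes), dim=1)  # toy predictions
    cls_priors = np.full(n_classes, 1.0 / n_classes)  # assumed uniform class prior

    solver = PK_solver(n_samples, n_classes, pk_prior_weight=10.)
    solver.create_C_ub(cls_priors, 0.0)      # pk_type='ub' with pk_uconf=0.0
    pred_labels = solver.solve_soft(probs)   # refined hard labels, one per sample
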
/DINE/DINE_ft.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os, sys
3 | import os.path as osp
4 | import torchvision
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | import torch.optim as optim
9 | from torchvision import transforms
10 | import network, loss
11 | from torch.utils.data import DataLoader
12 | from data_list import ImageList, ImageList_idx
13 | import random, pdb, math, copy
14 | from tqdm import tqdm
15 | from scipy.spatial.distance import cdist
16 | from sklearn.metrics import confusion_matrix
17 | import distutils
18 | import distutils.util
19 | import logging
20 |
21 | import sys
22 | sys.path.append("../util/")
23 | from utils import resetRNGseed, init_logger, get_hostname, get_pid
24 |
25 | import time
26 | timestamp = time.strftime("%Y-%m-%d_%H.%M.%S", time.localtime())
27 |
28 | torch.backends.cudnn.deterministic = True
29 | torch.backends.cudnn.benchmark = False
30 |
31 | def op_copy(optimizer):
32 | for param_group in optimizer.param_groups:
33 | param_group['lr0'] = param_group['lr']
34 | return optimizer
35 |
36 | def lr_scheduler(optimizer, iter_num, max_iter, gamma=10, power=0.75):
37 |     decay = (1 + gamma * iter_num / max_iter) ** (-power)
38 |
39 | for param_group in optimizer.param_groups:
40 | param_group['lr'] = param_group['lr0'] * decay
41 | param_group['weight_decay'] = 1e-3
42 | param_group['momentum'] = 0.9
43 | param_group['nesterov'] = True
44 | return optimizer
45 |
46 | def image_train(resize_size=256, crop_size=224, alexnet=False):
47 | if not alexnet:
48 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
49 | std=[0.229, 0.224, 0.225])
50 | else:
51 | normalize = Normalize(meanfile='./ilsvrc_2012_mean.npy')
52 | return transforms.Compose([
53 | transforms.Resize((resize_size, resize_size)),
54 | transforms.RandomCrop(crop_size),
55 | transforms.RandomHorizontalFlip(),
56 | transforms.ToTensor(),
57 | normalize
58 | ])
59 |
60 | def image_test(resize_size=256, crop_size=224, alexnet=False):
61 | if not alexnet:
62 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
63 | std=[0.229, 0.224, 0.225])
64 | else:
65 | normalize = Normalize(meanfile='./ilsvrc_2012_mean.npy')
66 | return transforms.Compose([
67 | transforms.Resize((resize_size, resize_size)),
68 | transforms.CenterCrop(crop_size),
69 | transforms.ToTensor(),
70 | normalize
71 | ])
72 |
73 | def data_load(args):
74 | ## prepare data
75 | dsets = {}
76 | dset_loaders = {}
77 | train_bs = args.batch_size
78 | txt_tar = open(args.t_dset_path).readlines()
79 | txt_test = open(args.test_dset_path).readlines()
80 |
81 | if not args.da == 'uda':
82 | label_map_s = {}
83 | for i in range(len(args.src_classes)):
84 | label_map_s[args.src_classes[i]] = i
85 |
86 | new_tar = []
87 | for i in range(len(txt_tar)):
88 | rec = txt_tar[i]
89 | reci = rec.strip().split(' ')
90 | if int(reci[1]) in args.tar_classes:
91 | if int(reci[1]) in args.src_classes:
92 | line = reci[0] + ' ' + str(label_map_s[int(reci[1])]) + '\n'
93 | new_tar.append(line)
94 | else:
95 | line = reci[0] + ' ' + str(len(label_map_s)) + '\n'
96 | new_tar.append(line)
97 | txt_tar = new_tar.copy()
98 | txt_test = txt_tar.copy()
99 |
100 | dsets["target"] = ImageList_idx(txt_tar, root="../data/{}/".format(args.dset), transform=image_train())
101 | dset_loaders["target"] = DataLoader(dsets["target"], batch_size=train_bs, shuffle=True, num_workers=args.worker, drop_last=False)
102 | dsets["test"] = ImageList_idx(txt_test, root="../data/{}/".format(args.dset), transform=image_test())
103 | dset_loaders["test"] = DataLoader(dsets["test"], batch_size=train_bs*3, shuffle=False, num_workers=args.worker, drop_last=False)
104 | dsets["target_te"] = ImageList(txt_tar, root="../data/{}/".format(args.dset), transform=image_test())
105 | dset_loaders["target_te"] = DataLoader(dsets["target_te"], batch_size=train_bs, shuffle=True, num_workers=args.worker, drop_last=False)
106 |
107 | return dset_loaders
108 |
109 | def cal_acc(loader, netF, netB, netC, flag=False):
110 | start_test = True
111 | with torch.no_grad():
112 | iter_test = iter(loader)
113 | for i in range(len(loader)):
114 |             data = next(iter_test)
115 | inputs = data[0]
116 | labels = data[1]
117 | inputs = inputs.cuda()
118 | outputs = netC(netB(netF(inputs)))
119 | if start_test:
120 | all_output = outputs.float().cpu()
121 | all_label = labels.float()
122 | start_test = False
123 | else:
124 | all_output = torch.cat((all_output, outputs.float().cpu()), 0)
125 | all_label = torch.cat((all_label, labels.float()), 0)
126 | all_output = nn.Softmax(dim=1)(all_output)
127 | _, predict = torch.max(all_output, 1)
128 | accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0])
129 | mean_ent = torch.mean(loss.Entropy(all_output)).cpu().data.item() / np.log(all_label.size()[0])
130 |
131 | if flag:
132 | matrix = confusion_matrix(all_label, torch.squeeze(predict).float())
133 | matrix = matrix[np.unique(all_label).astype(int),:]
134 | acc = matrix.diagonal()/matrix.sum(axis=1) * 100
135 | aacc = acc.mean()
136 | aa = [str(np.round(i, 2)) for i in acc]
137 | acc = ' '.join(aa)
138 | return aacc, acc, predict, mean_ent
139 | else:
140 | return accuracy*100, mean_ent, predict, mean_ent
141 |
142 | def train_target(args):
143 | dset_loaders = data_load(args)
144 | if args.net[0:3] == 'res':
145 | netF = network.ResBase(res_name=args.net).cuda()
146 |
147 | netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features, bottleneck_dim=args.bottleneck).cuda()
148 | netC = network.feat_classifier(type=args.layer, class_num=args.class_num, bottleneck_dim=args.bottleneck).cuda()
149 |
150 | modelpath = osp.join(args.output_dir, "{}_{}_{}_{}_target_F".format(args.timestamp, args.s, args.t, args.net) + ".pt" )
151 | netF.load_state_dict(torch.load(modelpath))
152 | modelpath = osp.join(args.output_dir, "{}_{}_{}_{}_target_B".format(args.timestamp, args.s, args.t, args.net) + ".pt")
153 | netB.load_state_dict(torch.load(modelpath))
154 | modelpath = osp.join(args.output_dir, "{}_{}_{}_{}_target_C".format(args.timestamp, args.s, args.t, args.net) + ".pt")
155 | netC.load_state_dict(torch.load(modelpath))
156 |
157 | param_group = []
158 | for k, v in netF.named_parameters():
159 | param_group += [{'params': v, 'lr': args.lr*0.1}]
160 | for k, v in netB.named_parameters():
161 | param_group += [{'params': v, 'lr': args.lr}]
162 | for k, v in netC.named_parameters():
163 | param_group += [{'params': v, 'lr': args.lr}]
164 |
165 | optimizer = optim.SGD(param_group)
166 | optimizer = op_copy(optimizer)
167 |
168 | max_iter = args.max_epoch * len(dset_loaders["target"])
169 | interval_iter = max_iter // 10
170 | iter_num = 0
171 |
172 | netF.eval()
173 | netB.eval()
174 | netC.eval()
175 | acc_s_te, _, pry, mean_ent = cal_acc(dset_loaders['test'], netF, netB, netC, False)
176 | log_str = 'Task: {}->{}, Iter:{}/{}; Accuracy={:.2f}%, Ent={:.3f}'.format(args.s, args.t, iter_num, max_iter, acc_s_te, mean_ent)
177 | if args.dset == 'visda-2017':
178 | acc_s_te, acc_list, pry, mean_ent = cal_acc(dset_loaders['test'], netF, netB, netC, True)
179 | log_str = 'Task: {}->{}, Iter:{}/{}; Accuracy = {:.2f}%, Ent = {:.4f}'.format(args.s, args.t, iter_num, max_iter, acc_s_te,
180 | mean_ent) + '\n' + acc_list
181 |
182 | logging.info(log_str)
183 | netF.train()
184 | netB.train()
185 | netC.train()
186 |
187 | old_pry = 0
188 | while iter_num < max_iter:
189 | optimizer.zero_grad()
190 |         try:
191 |             inputs_test, _, tar_idx = next(iter_test)
192 |         except (StopIteration, NameError):
193 |             iter_test = iter(dset_loaders["target"])
194 |             inputs_test, _, tar_idx = next(iter_test)
195 |
196 | if inputs_test.size(0) == 1:
197 | continue
198 |
199 | inputs_test = inputs_test.cuda()
200 |
201 | iter_num += 1
202 | lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter, power=0.75)
203 |
204 | features_test = netB(netF(inputs_test))
205 | outputs_test = netC(features_test)
206 |
207 | softmax_out = nn.Softmax(dim=1)(outputs_test)
208 | entropy_loss = torch.mean(loss.Entropy(softmax_out))
209 |
210 | msoftmax = softmax_out.mean(dim=0)
211 | gentropy_loss = -torch.sum(msoftmax * torch.log(msoftmax + 1e-5))
212 | entropy_loss -= gentropy_loss
213 | entropy_loss.backward()
214 | optimizer.step()
215 |
216 | if iter_num % interval_iter == 0 or iter_num == max_iter:
217 | netF.eval()
218 | netB.eval()
219 | netC.eval()
220 | acc_s_te, _, pry, mean_ent = cal_acc(dset_loaders['test'], netF, netB, netC, False)
221 | log_str = 'Task: {}->{}, Iter:{}/{}; Accuracy={:.2f}%, Ent={:.3f}'.format(args.s, args.t, iter_num, max_iter, acc_s_te, mean_ent)
222 | if args.dset == 'visda-2017':
223 | acc_s_te, acc_list, pry, mean_ent = cal_acc(dset_loaders['test'], netF, netB, netC, True)
224 | log_str = 'Task: {}->{}, Iter:{}/{}; Accuracy = {:.2f}%, Ent = {:.4f}'.format(args.s, args.t, iter_num, max_iter, acc_s_te, mean_ent) + '\n' + acc_list
225 | logging.info(log_str)
226 |
227 | netF.train()
228 | netB.train()
229 | netC.train()
230 |
231 | if torch.abs(pry - old_pry).sum() == 0:
232 | break
233 | else:
234 | old_pry = pry.clone()
235 |
236 | return netF, netB, netC
237 |
238 | def print_args(args):
239 | s = "==========================================\n"
240 | for arg, content in args.__dict__.items():
241 | s += "{}:{}\n".format(arg, content)
242 | return s
243 |
244 | if __name__ == "__main__":
245 | parser = argparse.ArgumentParser(description='DINE')
246 | parser.add_argument('--gpu_id', type=str, nargs='?', default='0', help="device id to run")
247 | parser.add_argument('--s', type=str, default=None, help="source")
248 | parser.add_argument('--t', type=str, default=None, help="target")
249 |     parser.add_argument('--max_epoch', type=int, default=30, help="maximum number of epochs")
250 | parser.add_argument('--batch_size', type=int, default=64, help="batch_size")
251 | parser.add_argument('--worker', type=int, default=4, help="number of workers")
252 | parser.add_argument('--dset', type=str, default='office-home', choices=['visda-2017', 'office31', 'image-clef', 'office-home', 'office-caltech'])
253 | parser.add_argument('--lr', type=float, default=1e-2, help="learning rate")
254 | parser.add_argument('--net', type=str, default='resnet50', help="alexnet, vgg16, resnet18, resnet50, resnext50")
255 | parser.add_argument('--net_src', type=str, default='resnet50', help="alexnet, vgg16, resnet18, resnet34, resnet50, resnet101")
256 | parser.add_argument('--seed', type=int, default=2020, help="random seed")
257 |
258 | parser.add_argument('--bottleneck', type=int, default=256)
259 | parser.add_argument('--layer', type=str, default="wn", choices=["linear", "wn"])
260 | parser.add_argument('--classifier', type=str, default="bn", choices=["ori", "bn"])
261 | parser.add_argument('--output', type=str, default='san')
262 | parser.add_argument('--da', type=str, default='uda', choices=['uda', 'pda'])
263 |
264 | parser.add_argument('--timestamp', default=timestamp, type=str, help='timestamp')
265 | parser.add_argument('--use_file_logger', default='True', type=lambda x: bool(distutils.util.strtobool(x)),
266 | help='whether use file logger')
267 | parser.add_argument('--names', default=[], type=list, help='names of tasks')
268 | parser.add_argument('--method', type=str, default=None)
269 |
270 | args = parser.parse_args()
271 | if args.dset == 'office-home':
272 | args.names = ['Art', 'Clipart', 'Product', 'Real_World']
273 | args.class_num = 65
274 | if args.dset == 'visda-2017':
275 | args.names = ['train', 'validation']
276 | args.class_num = 12
277 | if args.dset == 'office31':
278 | args.names = ['amazon', 'dslr', 'webcam']
279 | args.class_num = 31
280 |
281 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
282 | resetRNGseed(args.seed)
283 |
284 | if args.dset == 'office-home':
285 | if args.da == 'pda':
286 | args.class_num = 65
287 | args.src_classes = [i for i in range(65)]
288 | args.tar_classes = [33, 32, 36, 15, 19, 2, 46, 49, 48, 53, 47, 54, 4, 18, 57, 23, 0, 45, 1, 38, 5, 13, 50, 11, 58]
289 |
290 | if args.method is not None:
291 | dir = "{}_{}_{}_{}".format(args.timestamp, args.s, args.da, args.method)
292 | if args.use_file_logger:
293 | init_logger(dir, True, '../logs/DINE/{}/'.format(args.method))
294 | else:
295 | dir = "{}_{}_{}".format(args.timestamp, args.s, args.da)
296 | if args.use_file_logger:
297 | init_logger(dir, True, '../logs/DINE/')
298 | logging.info("{}:{}".format(get_hostname(), get_pid()))
299 |
300 | folder = '../data/'
301 | for t in args.names:
302 | if t == args.s:
303 | continue
304 | args.t = t
305 | args.s_dset_path = folder + args.dset + '/image_list/' + args.s + '.txt'
306 | args.t_dset_path = folder + args.dset + '/image_list/' + args.t + '.txt'
307 | args.test_dset_path = folder + args.dset + '/image_list/' + args.t + '.txt'
308 |
309 | args.output_dir = "../checkpoints/DINE/{}/target/{}/".format(args.seed, args.da)
310 |
311 |
312 |     # create the output directory (and parents) if needed
313 |     if not osp.exists(args.output_dir):
314 |         os.makedirs(args.output_dir, exist_ok=True)
315 |
316 |
317 | logging.info(print_args(args))
318 |
319 | train_target(args)
--------------------------------------------------------------------------------
/DINE/data_list.py:
--------------------------------------------------------------------------------
1 | #from __future__ import print_function, division
2 |
3 | import torch
4 | import numpy as np
5 | import random
6 | from PIL import Image
7 | from torch.utils.data import Dataset
8 | import os
9 | import os.path
10 |
11 | import cv2
12 | import torchvision
13 |
14 | def make_dataset(image_list, labels):
15 | if labels:
16 | len_ = len(image_list)
17 | images = [(image_list[i].strip(), labels[i, :]) for i in range(len_)]
18 | else:
19 | if len(image_list[0].split()) > 2:
20 | images = [(val.split()[0], np.array([int(la) for la in val.split()[1:]])) for val in image_list]
21 | else:
22 | images = [(val.split()[0], int(val.split()[1])) for val in image_list]
23 | return images
24 |
25 |
26 | def rgb_loader(path):
27 | with open(path, 'rb') as f:
28 | with Image.open(f) as img:
29 | return img.convert('RGB')
30 |
31 | def l_loader(path):
32 | with open(path, 'rb') as f:
33 | with Image.open(f) as img:
34 | return img.convert('L')
35 |
36 | class ImageList(Dataset):
37 | def __init__(self, image_list, root, labels=None, transform=None, target_transform=None, mode='RGB'):
38 | imgs = make_dataset(image_list, labels)
39 | if len(imgs) == 0:
40 |             raise RuntimeError("Found 0 images in subfolders of: " + root +
41 |                                "\nCheck the image list for valid 'path label' lines.")
42 |
43 | self.root = root
44 | self.imgs = imgs
45 | self.transform = transform
46 | self.target_transform = target_transform
47 | if mode == 'RGB':
48 | self.loader = rgb_loader
49 | elif mode == 'L':
50 | self.loader = l_loader
51 |
52 | def __getitem__(self, index):
53 | path, target = self.imgs[index]
54 | path = os.path.join(self.root, path)
55 | img = self.loader(path)
56 | if self.transform is not None:
57 | img = self.transform(img)
58 | if self.target_transform is not None:
59 | target = self.target_transform(target)
60 |
61 | return img, target
62 |
63 | def __len__(self):
64 | return len(self.imgs)
65 |
66 | class ImageList_idx(Dataset):
67 | def __init__(self, image_list, root, labels=None, transform=None, target_transform=None, mode='RGB'):
68 | imgs = make_dataset(image_list, labels)
69 | if len(imgs) == 0:
70 |             raise RuntimeError("Found 0 images in subfolders of: " + root +
71 |                                "\nCheck the image list for valid 'path label' lines.")
72 |
73 | self.root = root
74 | self.imgs = imgs
75 | self.transform = transform
76 | self.target_transform = target_transform
77 | if mode == 'RGB':
78 | self.loader = rgb_loader
79 | elif mode == 'L':
80 | self.loader = l_loader
81 |
82 | def __getitem__(self, index):
83 | path, target = self.imgs[index]
84 | path = os.path.join(self.root, path)
85 | img = self.loader(path)
86 | if self.transform is not None:
87 | img = self.transform(img)
88 | if self.target_transform is not None:
89 | target = self.target_transform(target)
90 |
91 | return img, target, index
92 |
93 | def __len__(self):
94 | return len(self.imgs)
--------------------------------------------------------------------------------
/DINE/loss.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | from torch.autograd import Variable
5 | import math
6 | import torch.nn.functional as F
7 | import pdb
8 |
9 | def Entropy(input_):
10 | bs = input_.size(0)
11 | epsilon = 1e-5
12 | entropy = -input_ * torch.log(input_ + epsilon)
13 | entropy = torch.sum(entropy, dim=1)
14 | return entropy
15 |
16 | class CrossEntropyLabelSmooth(nn.Module):
17 | """Cross entropy loss with label smoothing regularizer.
18 | Reference:
19 | Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016.
20 | Equation: y = (1 - epsilon) * y + epsilon / K.
21 | Args:
22 | num_classes (int): number of classes.
23 | epsilon (float): weight.
24 | """
25 |
26 | def __init__(self, num_classes, epsilon=0.1, use_gpu=True, reduction=True):
27 | super(CrossEntropyLabelSmooth, self).__init__()
28 | self.num_classes = num_classes
29 | self.epsilon = epsilon
30 | self.use_gpu = use_gpu
31 | self.reduction = reduction
32 | self.logsoftmax = nn.LogSoftmax(dim=1)
33 |
34 | def forward(self, inputs, targets):
35 | """
36 | Args:
37 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)
38 |             targets: ground truth labels with shape (batch_size)
39 | """
40 | log_probs = self.logsoftmax(inputs)
41 | targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).cpu(), 1)
42 | if self.use_gpu: targets = targets.cuda()
43 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
44 | loss = (- targets * log_probs).sum(dim=1)
45 | if self.reduction:
46 | return loss.mean()
47 | else:
48 | return loss
--------------------------------------------------------------------------------
/DINE/network.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torchvision
5 | from torchvision import models
6 | from torch.autograd import Variable
7 | import math
8 | import pdb
9 | import torch.nn.utils.weight_norm as weightNorm
10 | from collections import OrderedDict
11 |
12 | def init_weights(m):
13 | classname = m.__class__.__name__
14 | if classname.find('Conv2d') != -1 or classname.find('ConvTranspose2d') != -1:
15 | nn.init.kaiming_uniform_(m.weight)
16 | nn.init.zeros_(m.bias)
17 | elif classname.find('BatchNorm') != -1:
18 | nn.init.normal_(m.weight, 1.0, 0.02)
19 | nn.init.zeros_(m.bias)
20 | elif classname.find('Linear') != -1:
21 | nn.init.xavier_normal_(m.weight)
22 | nn.init.zeros_(m.bias)
23 |
24 | res_dict = {"resnet18":models.resnet18, "resnet34":models.resnet34, "resnet50":models.resnet50,
25 | "resnet101":models.resnet101, "resnet152":models.resnet152, "resnext50":models.resnext50_32x4d, "resnext101":models.resnext101_32x8d}
26 |
27 | class ResBase(nn.Module):
28 | def __init__(self, res_name, pretrain=True):
29 | super(ResBase, self).__init__()
30 | model_resnet = res_dict[res_name](pretrained=pretrain)
31 | self.conv1 = model_resnet.conv1
32 | self.bn1 = model_resnet.bn1
33 | self.relu = model_resnet.relu
34 | self.maxpool = model_resnet.maxpool
35 | self.layer1 = model_resnet.layer1
36 | self.layer2 = model_resnet.layer2
37 | self.layer3 = model_resnet.layer3
38 | self.layer4 = model_resnet.layer4
39 | self.avgpool = model_resnet.avgpool
40 | self.in_features = model_resnet.fc.in_features
41 |
42 | def forward(self, x):
43 | x = self.conv1(x)
44 | x = self.bn1(x)
45 | x = self.relu(x)
46 | x = self.maxpool(x)
47 | x = self.layer1(x)
48 | x = self.layer2(x)
49 | x = self.layer3(x)
50 | x = self.layer4(x)
51 | x = self.avgpool(x)
52 | x = x.view(x.size(0), -1)
53 | return x
54 |
55 | class feat_bootleneck(nn.Module):
56 | def __init__(self, feature_dim, bottleneck_dim=256, type="ori"):
57 | super(feat_bootleneck, self).__init__()
58 | self.bn = nn.BatchNorm1d(bottleneck_dim, affine=True)
59 | self.relu = nn.ReLU(inplace=True)
60 | self.dropout = nn.Dropout(p=0.5)
61 | self.bottleneck = nn.Linear(feature_dim, bottleneck_dim)
62 | self.bottleneck.apply(init_weights)
63 | self.type = type
64 |
65 | def forward(self, x):
66 | x = self.bottleneck(x)
67 | if self.type == "bn" or self.type == "bn_relu" or self.type == "bn_relu_drop":
68 | x = self.bn(x)
69 | if self.type == "bn_relu" or self.type == "bn_relu_drop":
70 | x = self.relu(x)
71 | if self.type == "bn_relu_drop":
72 | x = self.dropout(x)
73 | return x
74 |
75 | class feat_classifier(nn.Module):
76 | def __init__(self, class_num, bottleneck_dim=256, type="linear"):
77 | super(feat_classifier, self).__init__()
78 | self.type = type
79 | if type == 'wn':
80 | self.fc = weightNorm(nn.Linear(bottleneck_dim, class_num), name="weight")
81 | self.fc.apply(init_weights)
82 | elif type == 'linear':
83 | self.fc = nn.Linear(bottleneck_dim, class_num)
84 | self.fc.apply(init_weights)
85 | else:
86 | self.fc = nn.Linear(bottleneck_dim, class_num, bias=False)
87 | nn.init.xavier_normal_(self.fc.weight)
88 |
89 | def forward(self, x):
90 | if not self.type in {'wn', 'linear'}:
91 | w = self.fc.weight
92 | w = torch.nn.functional.normalize(w, dim=1, p=2)
93 |
94 | x = torch.nn.functional.normalize(x, dim=1, p=2)
95 | x = torch.nn.functional.linear(x, w)
96 | else:
97 | x = self.fc(x)
98 | return x
99 |
100 | class feat_classifier_simpl(nn.Module):
101 | def __init__(self, class_num, feat_dim):
102 | super(feat_classifier_simpl, self).__init__()
103 | self.fc = nn.Linear(feat_dim, class_num)
104 | nn.init.xavier_normal_(self.fc.weight)
105 |
106 | def forward(self, x):
107 | x = self.fc(x)
108 | return x
--------------------------------------------------------------------------------
/DINE/run_all_kDINE.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | gpu_id=0
4 | time=`python ../util/get_time.py`
5 |
6 | # office31 -------------------------------------------------------------------------------------------------------------
7 | for seed in 2020 2021 2022; do
8 | for src in 'webcam' 'amazon' 'dslr' ; do
9 | echo $src
10 | python DINE_dist.py --gpu_id $gpu_id --seed $seed --dset office31 --s $src --da uda --net_src resnet50 --max_epoch 50 --timestamp $time
11 | for pk_uconf in 0.0 0.1 0.5 1.0 2.0; do
12 | python DINE_dist_kDINE.py --gpu_id $gpu_id --seed $seed --dset office31 --s $src --da uda --net_src resnet50 --max_epoch 30 --net resnet50 --distill --topk 1 --timestamp $time --pk_type ub --pk_uconf $pk_uconf
13 | python DINE_ft.py --gpu_id $gpu_id --seed $seed --dset office31 --s $src --da uda --net_src resnet50 --max_epoch 30 --net resnet50 --lr 1e-2 --timestamp $time --method kdine
14 | done
15 |
16 | python DINE_dist_kDINE.py --gpu_id $gpu_id --seed $seed --dset office31 --s $src --da uda --net_src resnet50 --max_epoch 30 --net resnet50 --distill --topk 1 --timestamp $time --pk_type br --pk_uconf 1.0
17 | python DINE_ft.py --gpu_id $gpu_id --seed $seed --dset office31 --s $src --da uda --net_src resnet50 --max_epoch 30 --net resnet50 --lr 1e-2 --timestamp $time --method kdine
18 | done
19 | done
20 |
21 |
22 | # office-home ----------------------------------------------------------------------------------------------------------
23 | for seed in 2020 2021 2022; do
24 | for src in 'Product' 'Real_World' 'Art' 'Clipart' ; do
25 | echo $src
26 | python DINE_dist.py --gpu_id $gpu_id --seed $seed --dset office-home --s $src --da uda --net_src resnet50 --max_epoch 50 --timestamp $time
27 | for pk_uconf in 0.0 0.1 0.5 1.0 2.0; do
28 | python DINE_dist_kDINE.py --gpu_id $gpu_id --seed $seed --dset office-home --s $src --da uda --net_src resnet50 --max_epoch 30 --net resnet50 --distill --topk 1 --timestamp $time --pk_type ub --pk_uconf $pk_uconf
29 | python DINE_ft.py --gpu_id $gpu_id --seed $seed --dset office-home --s $src --da uda --net_src resnet50 --max_epoch 30 --net resnet50 --lr 1e-2 --timestamp $time --method kdine
30 | done
31 |
32 | python DINE_dist_kDINE.py --gpu_id $gpu_id --seed $seed --dset office-home --s $src --da uda --net_src resnet50 --max_epoch 30 --net resnet50 --distill --topk 1 --timestamp $time --pk_type br --pk_uconf 1.0
33 | python DINE_ft.py --gpu_id $gpu_id --seed $seed --dset office-home --s $src --da uda --net_src resnet50 --max_epoch 30 --net resnet50 --lr 1e-2 --timestamp $time --method kdine
34 | done
35 | done
36 |
37 | # office-home (PDA)-----------------------------------------------------------------------------------------------------
38 | for seed in 2020 2021 2022; do
39 | for src in 'Product' 'Real_World' 'Art' 'Clipart' ; do
40 | echo $src
41 | python DINE_dist.py --gpu_id $gpu_id --seed $seed --dset office-home --s $src --da pda --net_src resnet50 --max_epoch 50 --timestamp $time
42 |
43 | python DINE_dist_kDINE.py --gpu_id $gpu_id --seed $seed --dset office-home --s $src --da pda --net_src resnet50 --max_epoch 30 --net resnet50 --distill --topk 1 --timestamp $time --pk_type ub --pk_uconf 0.0
44 | python DINE_dist_kDINE.py --gpu_id $gpu_id --seed $seed --dset office-home --s $src --da pda --net_src resnet50 --max_epoch 30 --net resnet50 --distill --topk 1 --timestamp $time --pk_type br --pk_uconf 1.0
45 | done
46 | done
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 tsun
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # KUDA
2 | PyTorch implementation of KUDA.
3 | > [Prior Knowledge Guided Unsupervised Domain Adaptation](https://arxiv.org/abs/2207.08877)
4 | > Tao Sun, Cheng Lu, and Haibin Ling
5 | > *ECCV 2022*
6 |
7 | ## Abstract
8 | The absence of labels in the target domain makes Unsupervised Domain Adaptation (UDA) an attractive technique in many real-world applications, though it also brings great challenges as model adaptation becomes harder without labeled target data. In this paper, we address this issue by seeking compensation from target domain prior knowledge, which is often (partially) available in practice, e.g., from human expertise. This leads to a novel yet practical setting where, in addition to the training data, some prior knowledge about the target class distribution is available. We term this setting Knowledge-guided Unsupervised Domain Adaptation (KUDA). In particular, we consider two specific types of prior knowledge about the class distribution in the target domain: Unary Bound, which describes the lower and upper bounds of individual class probabilities, and Binary Relationship, which describes the relation between two class probabilities. We propose a general rectification module that uses such prior knowledge to refine model-generated pseudo labels. The module is formulated as a Zero-One Programming problem derived from the prior knowledge and a smooth regularizer. It can be easily plugged into self-training based UDA methods, and we combine it with two state-of-the-art methods, SHOT and DINE. Empirical results on four benchmarks confirm that the rectification module clearly improves the quality of pseudo labels, which in turn benefits the self-training stage. With the guidance from prior knowledge, the performance of both methods is substantially boosted. We expect our work to inspire further investigations in integrating prior knowledge in UDA.
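
In shorthand (our notation, simply formalizing the two priors above): writing $p_k$ for the probability of class $k$ in the target domain,

```latex
\text{Unary Bound: } \alpha_k \le p_k \le \beta_k, \qquad
\text{Binary Relationship: } p_i \ge p_j .
```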
9 |
10 | ### Knowledge-guided Unsupervised Domain Adaptation (KUDA)
11 |
12 |
13 | ### Integrating rectification module into SHOT and DINE
14 |
15 |
16 | ## Usage
17 | ### Prerequisites
18 |
19 | We experimented with python==3.8, pytorch==1.8.0, cudatoolkit==11.1, gurobi==9.5.0.
20 |
21 | For Zero-One programming, we use [Gurobi Optimizer](https://www.gurobi.com/). A free [academic license](https://www.gurobi.com/academia/academic-program-and-licenses/) can be obtained from its official website.
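
The rectification step itself is a small zero-one program. As a rough sketch of the idea only (the actual solver lives in `pklib/pksolver.py`; the function name `rectify_ub`, the plain log-likelihood objective, and the bound handling below are assumptions of this illustration), refining hard pseudo labels under a Unary Bound prior with gurobipy could look like:

```python
import numpy as np
import gurobipy as gp
from gurobipy import GRB

def rectify_ub(probs, lb, ub):
    """Illustrative only: refine pseudo labels under a Unary Bound prior.

    probs: (n, K) softmax outputs; lb, ub: (K,) bounds on class frequencies.
    """
    n, K = probs.shape
    m = gp.Model("rectify")
    m.Params.OutputFlag = 0
    z = m.addVars(n, K, vtype=GRB.BINARY)               # z[i, k] = 1 iff sample i gets label k
    m.addConstrs(z.sum(i, "*") == 1 for i in range(n))  # exactly one label per sample
    for k in range(K):                                  # class frequencies obey the prior
        m.addConstr(z.sum("*", k) >= lb[k] * n)
        m.addConstr(z.sum("*", k) <= ub[k] * n)
    logp = np.log(probs + 1e-8)                         # stay close to model confidences
    m.setObjective(gp.quicksum(logp[i, k] * z[i, k]
                               for i in range(n) for k in range(K)), GRB.MAXIMIZE)
    m.optimize()
    return np.array([max(range(K), key=lambda k: z[i, k].X) for i in range(n)])
```

A Binary Relationship prior would swap the per-class bounds for pairwise constraints such as `z.sum("*", i) >= z.sum("*", j)`.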
22 |
23 |
24 | ### Data Preparation
25 | Download the [office31](https://faculty.cc.gatech.edu/~judy/domainadapt/), [Office-Home](https://www.hemanthdv.org/officeHomeDataset.html), [VisDA](https://ai.bu.edu/visda-2017/), [DomainNet](http://ai.bu.edu/M3SDA/) datasets.
26 |
27 | Set up the dataset paths in ./data:
28 | ```shell
29 | bash setup_data_path.sh /Path_to_data/office/domain_adaptation_images office31
30 | bash setup_data_path.sh /Path_to_data/office-home/images office-home
31 | bash setup_data_path.sh /Path_to_data/office-home/images office-home-rsut
32 | bash setup_data_path.sh /Path_to_data/VisDA visda
33 | bash setup_data_path.sh /Path_to_data/DomainNet domainnet40
34 | ```
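
Each `image_list/*.txt` file lists one sample per line as a relative image path followed by an integer class label; this `path label` format is what `make_dataset` in `data_list.py` parses (the entries below are illustrative):

```
Art/Alarm_Clock/00001.jpg 0
Art/Backpack/00023.jpg 1
```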
35 |
36 | ### kSHOT
37 | Unsupervised Closed-set Domain Adaptation (UDA) on the Office-Home dataset
38 | ```shell
39 | cd SHOT
40 |
41 | time=`python ../util/get_time.py`
42 | gpu_id=0
43 |
44 | # generate source models
45 | for src in "Product" "Clipart" "Art" "Real_World"; do
46 | echo $src
47 | python image_source.py --trte val --da uda --gpu_id $gpu_id --dset office-home --max_epoch 50 --s $src --timestamp $time
48 | done
49 |
50 | # adapt to other target domains with Unary Bound prior knowledge
51 | for seed in 2020 2021 2022; do
52 | for src in "Product" "Clipart" "Art" "Real_World"; do
53 | echo $src
54 | python image_target_kSHOT.py --cls_par 0.3 --da uda --gpu_id $gpu_id --dset office-home --s $src --timestamp $time --pk_uconf 0.0 --seed $seed --pk_type ub
55 | done
56 | done
57 | ```
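
The loop above uses the Unary Bound prior (`--pk_type ub`). Switching to the Binary Relationship prior only changes the knowledge flags; mirroring the kDINE commands in `./DINE/run_all_kDINE.sh`, a single illustrative run (one source, one seed) would be:

```shell
python image_target_kSHOT.py --cls_par 0.3 --da uda --gpu_id $gpu_id --dset office-home --s Art --timestamp $time --pk_uconf 1.0 --seed 2020 --pk_type br
```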
58 |
59 | ### kDINE
60 | Unsupervised Closed-set Domain Adaptation (UDA) on the Office-Home dataset
61 | ```shell
62 | cd DINE
63 |
64 | time=`python ../util/get_time.py`
65 | gpu=0
66 |
67 | for seed in 2020 2021 2022; do
68 | for src in 'Product' 'Real_World' 'Art' 'Clipart' ; do
69 | echo $src
70 | # training the source model first
71 | python DINE_dist.py --gpu_id $gpu --seed $seed --dset office-home --s $src --da uda --net_src resnet50 --max_epoch 50 --timestamp $time
72 | # the first step (Distill) with Unary Bound prior knowledge
73 | python DINE_dist_kDINE.py --gpu_id $gpu --seed $seed --dset office-home --s $src --da uda --net_src resnet50 --max_epoch 30 --net resnet50 --distill --topk 1 --timestamp $time --pk_type ub --pk_uconf 0.0
74 | # the second step (Finetune)
75 | python DINE_ft.py --gpu_id $gpu --seed $seed --dset office-home --s $src --da uda --net_src resnet50 --max_epoch 30 --net resnet50 --lr 1e-2 --timestamp $time --method kdine
76 | done
77 | done
78 | ```
79 | Complete commands are available in ./SHOT/run_all_kSHOT.sh and ./DINE/run_all_kDINE.sh.
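
For reference, the scripts save checkpoints under `../checkpoints/` (e.g., `checkpoints/SHOT/source/<da>/` for source models and `checkpoints/DINE/<seed>/target/<da>/` for DINE) and, since `--use_file_logger` defaults to `True`, write logs under `../logs/SHOT/` and `../logs/DINE/`, both relative to the `SHOT`/`DINE` working directories.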
80 |
81 | ## Acknowledgements
82 | The implementations are adapted from [SHOT](https://github.com/tim-learn/SHOT) and
83 | [DINE](https://github.com/tim-learn/DINE).
84 |
85 |
86 | ## Citation
87 | If you find our paper and code useful for your research, please consider citing
88 | ```bibtex
89 | @inproceedings{sun2022prior,
90 | author = {Sun, Tao and Lu, Cheng and Ling, Haibin},
91 | title = {Prior Knowledge Guided Unsupervised Domain Adaptation},
92 | booktitle = {Proceedings of the European Conference on Computer Vision (ECCV)},
93 | year = {2022}
94 | }
95 | ```
--------------------------------------------------------------------------------
/SHOT/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tsun/KUDA/beeb839456daf5fb5d263783c79bb6cff38e2375/SHOT/__init__.py
--------------------------------------------------------------------------------
/SHOT/augmentations.py:
--------------------------------------------------------------------------------
1 | # code in this file is adapted from rpmcruz/autoaugment
2 | # https://github.com/rpmcruz/autoaugment/blob/master/transformations.py
3 | import random
4 |
5 | import PIL, PIL.ImageOps, PIL.ImageEnhance, PIL.ImageDraw
6 | import numpy as np
7 | from PIL import Image
8 | import torch
9 |
10 | def ShearX(img, v): # [-0.3, 0.3]
11 | assert -0.3 <= v <= 0.3
12 | if random.random() > 0.5:
13 | v = -v
14 | return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0))
15 |
16 |
17 | def ShearY(img, v): # [-0.3, 0.3]
18 | assert -0.3 <= v <= 0.3
19 | if random.random() > 0.5:
20 | v = -v
21 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0))
22 |
23 |
24 | def TranslateX(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
25 | assert -0.45 <= v <= 0.45
26 | if random.random() > 0.5:
27 | v = -v
28 | v = v * img.size[0]
29 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
30 |
31 |
32 | def TranslateXabs(img, v):  # absolute translation in pixels, v >= 0
33 | assert 0 <= v
34 | if random.random() > 0.5:
35 | v = -v
36 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
37 |
38 |
39 | def TranslateY(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
40 | assert -0.45 <= v <= 0.45
41 | if random.random() > 0.5:
42 | v = -v
43 | v = v * img.size[1]
44 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
45 |
46 |
47 | def TranslateYabs(img, v):  # absolute translation in pixels, v >= 0
48 | assert 0 <= v
49 | if random.random() > 0.5:
50 | v = -v
51 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
52 |
53 |
54 | def Rotate(img, v): # [-30, 30]
55 | assert -30 <= v <= 30
56 | if random.random() > 0.5:
57 | v = -v
58 | return img.rotate(v)
59 |
60 |
61 | def AutoContrast(img, _):
62 | return PIL.ImageOps.autocontrast(img)
63 |
64 |
65 | def Invert(img, _):
66 | return PIL.ImageOps.invert(img)
67 |
68 |
69 | def Equalize(img, _):
70 | return PIL.ImageOps.equalize(img)
71 |
72 |
73 | def Flip(img, _): # not from the paper
74 | return PIL.ImageOps.mirror(img)
75 |
76 |
77 | def Solarize(img, v): # [0, 256]
78 | assert 0 <= v <= 256
79 | return PIL.ImageOps.solarize(img, v)
80 |
81 |
82 | def SolarizeAdd(img, addition=0, threshold=128):
83 |     img_np = np.array(img).astype(int)
84 | img_np = img_np + addition
85 | img_np = np.clip(img_np, 0, 255)
86 | img_np = img_np.astype(np.uint8)
87 | img = Image.fromarray(img_np)
88 | return PIL.ImageOps.solarize(img, threshold)
89 |
90 |
91 | def Posterize(img, v): # [4, 8]
92 | v = int(v)
93 | v = max(1, v)
94 | return PIL.ImageOps.posterize(img, v)
95 |
96 |
97 | def Contrast(img, v): # [0.1,1.9]
98 | assert 0.1 <= v <= 1.9
99 | return PIL.ImageEnhance.Contrast(img).enhance(v)
100 |
101 |
102 | def Color(img, v): # [0.1,1.9]
103 | assert 0.1 <= v <= 1.9
104 | return PIL.ImageEnhance.Color(img).enhance(v)
105 |
106 |
107 | def Brightness(img, v): # [0.1,1.9]
108 | assert 0.1 <= v <= 1.9
109 | return PIL.ImageEnhance.Brightness(img).enhance(v)
110 |
111 |
112 | def Sharpness(img, v): # [0.1,1.9]
113 | assert 0.1 <= v <= 1.9
114 | return PIL.ImageEnhance.Sharpness(img).enhance(v)
115 |
116 |
117 | def Cutout(img, v): # [0, 60] => percentage: [0, 0.2]
118 | assert 0.0 <= v <= 0.2
119 | if v <= 0.:
120 | return img
121 |
122 | v = v * img.size[0]
123 | return CutoutAbs(img, v)
124 |
125 |
126 | def CutoutAbs(img, v): # [0, 60] => percentage: [0, 0.2]
127 | # assert 0 <= v <= 20
128 | if v < 0:
129 | return img
130 | w, h = img.size
131 | x0 = np.random.uniform(w)
132 | y0 = np.random.uniform(h)
133 |
134 | x0 = int(max(0, x0 - v / 2.))
135 | y0 = int(max(0, y0 - v / 2.))
136 | x1 = min(w, x0 + v)
137 | y1 = min(h, y0 + v)
138 |
139 | xy = (x0, y0, x1, y1)
140 | color = (125, 123, 114)
141 | # color = (0, 0, 0)
142 | img = img.copy()
143 | PIL.ImageDraw.Draw(img).rectangle(xy, color)
144 | return img
145 |
146 |
147 | def SamplePairing(imgs): # [0, 0.4]
148 | def f(img1, v):
149 | i = np.random.choice(len(imgs))
150 | img2 = PIL.Image.fromarray(imgs[i])
151 | return PIL.Image.blend(img1, img2, v)
152 |
153 | return f
154 |
155 |
156 | def Identity(img, v):
157 | return img
158 |
159 |
160 | def augment_list():  # 16 operations and their ranges
161 | # https://github.com/google-research/uda/blob/master/image/randaugment/policies.py#L57
162 | # l = [
163 | # (Identity, 0., 1.0),
164 | # (ShearX, 0., 0.3), # 0
165 | # (ShearY, 0., 0.3), # 1
166 | # (TranslateX, 0., 0.33), # 2
167 | # (TranslateY, 0., 0.33), # 3
168 | # (Rotate, 0, 30), # 4
169 | # (AutoContrast, 0, 1), # 5
170 | # (Invert, 0, 1), # 6
171 | # (Equalize, 0, 1), # 7
172 | # (Solarize, 0, 110), # 8
173 | # (Posterize, 4, 8), # 9
174 | # # (Contrast, 0.1, 1.9), # 10
175 | # (Color, 0.1, 1.9), # 11
176 | # (Brightness, 0.1, 1.9), # 12
177 | # (Sharpness, 0.1, 1.9), # 13
178 | # # (Cutout, 0, 0.2), # 14
179 | # # (SamplePairing(imgs), 0, 0.4), # 15
180 | # ]
181 |
182 | # https://github.com/tensorflow/tpu/blob/8462d083dd89489a79e3200bcc8d4063bf362186/models/official/efficientnet/autoaugment.py#L505
183 | l = [
184 | (AutoContrast, 0, 1),
185 | (Equalize, 0, 1),
186 | (Invert, 0, 1),
187 | (Rotate, 0, 30),
188 | (Posterize, 0, 4),
189 | (Solarize, 0, 256),
190 | (SolarizeAdd, 0, 110),
191 | (Color, 0.1, 1.9),
192 | (Contrast, 0.1, 1.9),
193 | (Brightness, 0.1, 1.9),
194 | (Sharpness, 0.1, 1.9),
195 | (ShearX, 0., 0.3),
196 | (ShearY, 0., 0.3),
197 | (CutoutAbs, 0, 40),
198 | (TranslateXabs, 0., 100),
199 | (TranslateYabs, 0., 100),
200 | ]
201 |
202 | return l
203 |
204 |
205 | class Lighting(object):
206 | """Lighting noise(AlexNet - style PCA - based noise)"""
207 |
208 | def __init__(self, alphastd, eigval, eigvec):
209 | self.alphastd = alphastd
210 | self.eigval = torch.Tensor(eigval)
211 | self.eigvec = torch.Tensor(eigvec)
212 |
213 | def __call__(self, img):
214 | if self.alphastd == 0:
215 | return img
216 |
217 | alpha = img.new().resize_(3).normal_(0, self.alphastd)
218 | rgb = self.eigvec.type_as(img).clone() \
219 | .mul(alpha.view(1, 3).expand(3, 3)) \
220 | .mul(self.eigval.view(1, 3).expand(3, 3)) \
221 | .sum(1).squeeze()
222 |
223 | return img.add(rgb.view(3, 1, 1).expand_as(img))
224 |
225 |
226 | class CutoutDefault(object):
227 | """
228 | Reference : https://github.com/quark0/darts/blob/master/cnn/utils.py
229 | """
230 | def __init__(self, length):
231 | self.length = length
232 |
233 | def __call__(self, img):
234 | h, w = img.size(1), img.size(2)
235 | mask = np.ones((h, w), np.float32)
236 | y = np.random.randint(h)
237 | x = np.random.randint(w)
238 |
239 | y1 = np.clip(y - self.length // 2, 0, h)
240 | y2 = np.clip(y + self.length // 2, 0, h)
241 | x1 = np.clip(x - self.length // 2, 0, w)
242 | x2 = np.clip(x + self.length // 2, 0, w)
243 |
244 | mask[y1: y2, x1: x2] = 0.
245 | mask = torch.from_numpy(mask)
246 | mask = mask.expand_as(img)
247 | img *= mask
248 | return img
249 |
250 |
251 | class RandAugment:
252 | def __init__(self, n, m):
253 | self.n = n
254 | self.m = m # [0, 30]
255 | self.augment_list = augment_list()
256 |
257 | def __call__(self, img):
258 |
259 | if self.n == 0:
260 | return img
261 |
262 | ops = random.choices(self.augment_list, k=self.n)
263 | for op, minval, maxval in ops:
264 | val = (float(self.m) / 30) * float(maxval - minval) + minval
265 | img = op(img, val)
266 |
267 | return img
--------------------------------------------------------------------------------
/SHOT/data_list.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 | from torch.utils.data import Dataset
4 | import os
5 | from augmentations import RandAugment
6 | import copy
7 |
8 | def make_dataset(image_list, labels):
9 | if labels:
10 | len_ = len(image_list)
11 | images = [(image_list[i].strip(), labels[i, :]) for i in range(len_)]
12 | else:
13 | if len(image_list[0].split()) > 2:
14 | images = [(val.split()[0], np.array([int(la) for la in val.split()[1:]])) for val in image_list]
15 | else:
16 | images = [(val.split()[0], int(val.split()[1])) for val in image_list]
17 | return images
18 |
19 |
20 | def rgb_loader(path):
21 | with open(path, 'rb') as f:
22 | with Image.open(f) as img:
23 | return img.convert('RGB')
24 |
25 | def l_loader(path):
26 | with open(path, 'rb') as f:
27 | with Image.open(f) as img:
28 | return img.convert('L')
29 |
30 | class ImageList(Dataset):
31 | def __init__(self, image_list, root, labels=None, transform=None, target_transform=None, mode='RGB'):
32 | imgs = make_dataset(image_list, labels)
33 | if len(imgs) == 0:
34 |             raise RuntimeError("Found 0 images in subfolders of: " + root +
35 |                                "\nCheck the image list for valid 'path label' lines.")
36 |
37 | self.root = root
38 | self.imgs = imgs
39 | self.transform = transform
40 | self.target_transform = target_transform
41 | if mode == 'RGB':
42 | self.loader = rgb_loader
43 | elif mode == 'L':
44 | self.loader = l_loader
45 |
46 | def __getitem__(self, index):
47 | path, target = self.imgs[index]
48 | path = os.path.join(self.root, path)
49 | img = self.loader(path)
50 | if self.transform is not None:
51 | img = self.transform(img)
52 | if self.target_transform is not None:
53 | target = self.target_transform(target)
54 |
55 | return img, target
56 |
57 | def __len__(self):
58 | return len(self.imgs)
59 |
60 | class ImageList_idx_aug(Dataset):
61 | def __init__(self, image_list, root, labels=None, transform=None, target_transform=None, mode='RGB',
62 | rand_aug_size=0, rand_aug_n=2, rand_aug_m=2.):
63 | imgs = make_dataset(image_list, labels)
64 | if len(imgs) == 0:
65 |             raise RuntimeError("Found 0 images in subfolders of: " + root +
66 |                                "\nCheck the image list for valid 'path label' lines.")
67 |
68 | self.root = root
69 | self.imgs = imgs
70 | self.transform = transform
71 | self.target_transform = target_transform
72 | if mode == 'RGB':
73 | self.loader = rgb_loader
74 | elif mode == 'L':
75 | self.loader = l_loader
76 |
77 | self.rand_aug_size = rand_aug_size
78 |
79 | if self.rand_aug_size > 0:
80 | self.rand_aug_transform = copy.deepcopy(self.transform)
81 | self.rand_aug_transform.transforms.insert(0, RandAugment(rand_aug_n, rand_aug_m))
82 |
83 | def __getitem__(self, index):
84 | path, target = self.imgs[index]
85 | path = os.path.join(self.root, path)
86 | img = self.loader(path)
87 | img_ = self.loader(path)
88 | if self.transform is not None:
89 | img = self.transform(img)
90 | if self.target_transform is not None:
91 | target = self.target_transform(target)
92 |
93 | rand_imgs = [self.rand_aug_transform(img_) for _ in range(self.rand_aug_size)]
94 | return img, target, index, rand_imgs
95 |
96 | def __len__(self):
97 | return len(self.imgs)
98 |
99 |
100 | class ImageList_idx(Dataset):
101 | def __init__(self, image_list, root, labels=None, transform=None, target_transform=None, mode='RGB'):
102 | imgs = make_dataset(image_list, labels)
103 | if len(imgs) == 0:
104 |             raise RuntimeError("Found 0 images in subfolders of: " + root +
105 |                                "\nCheck the image list for valid 'path label' lines.")
106 |
107 | self.root = root
108 | self.imgs = imgs
109 | self.transform = transform
110 | self.target_transform = target_transform
111 | if mode == 'RGB':
112 | self.loader = rgb_loader
113 | elif mode == 'L':
114 | self.loader = l_loader
115 |
116 | def __getitem__(self, index):
117 | path, target = self.imgs[index]
118 | path = os.path.join(self.root, path)
119 | img = self.loader(path)
120 | if self.transform is not None:
121 | img = self.transform(img)
122 | if self.target_transform is not None:
123 | target = self.target_transform(target)
124 |
125 | return img, target, index
126 |
127 | def __len__(self):
128 | return len(self.imgs)
--------------------------------------------------------------------------------
/SHOT/image_source.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import os.path as osp
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 | import torch.optim as optim
8 | from torchvision import transforms
9 | import network, loss
10 | from torch.utils.data import DataLoader
11 | from data_list import ImageList
12 | from loss import CrossEntropyLabelSmooth
13 | from sklearn.metrics import confusion_matrix
14 | from sklearn.cluster import KMeans
15 | import distutils
16 | import distutils.util
17 | import logging
18 |
19 | import sys
20 | sys.path.append("../util/")
21 | from utils import resetRNGseed, init_logger, get_hostname, get_pid
22 |
23 | import time
24 | timestamp = time.strftime("%Y-%m-%d_%H.%M.%S", time.localtime())
25 |
26 | torch.backends.cudnn.deterministic = True
27 | torch.backends.cudnn.benchmark = False
28 |
29 |
30 | def op_copy(optimizer):
31 | for param_group in optimizer.param_groups:
32 | param_group['lr0'] = param_group['lr']
33 | return optimizer
34 |
35 | def lr_scheduler(optimizer, iter_num, max_iter, gamma=10, power=0.75):
36 | decay = (1 + gamma * iter_num / max_iter) ** (-power)
37 | for param_group in optimizer.param_groups:
38 | param_group['lr'] = param_group['lr0'] * decay
39 | param_group['weight_decay'] = 1e-3
40 | param_group['momentum'] = 0.9
41 | param_group['nesterov'] = True
42 | return optimizer
43 |
44 | def image_train(resize_size=256, crop_size=224, alexnet=False):
45 | if not alexnet:
46 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
47 | std=[0.229, 0.224, 0.225])
48 | else:
49 | normalize = Normalize(meanfile='./ilsvrc_2012_mean.npy')
50 | return transforms.Compose([
51 | transforms.Resize((resize_size, resize_size)),
52 | transforms.RandomCrop(crop_size),
53 | transforms.RandomHorizontalFlip(),
54 | transforms.ToTensor(),
55 | normalize
56 | ])
57 |
58 | def image_test(resize_size=256, crop_size=224, alexnet=False):
59 | if not alexnet:
60 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
61 | std=[0.229, 0.224, 0.225])
62 | else:
63 | normalize = Normalize(meanfile='./ilsvrc_2012_mean.npy')
64 | return transforms.Compose([
65 | transforms.Resize((resize_size, resize_size)),
66 | transforms.CenterCrop(crop_size),
67 | transforms.ToTensor(),
68 | normalize
69 | ])
70 |
71 | def data_load(args):
72 | ## prepare data
73 | dsets = {}
74 | dset_loaders = {}
75 | train_bs = args.batch_size
76 | txt_src = open(args.s_dset_path).readlines()
77 | txt_test = open(args.test_dset_path).readlines()
78 |
79 | if not args.da == 'uda':
80 | label_map_s = {}
81 | for i in range(len(args.src_classes)):
82 | label_map_s[args.src_classes[i]] = i
83 |
84 | new_src = []
85 | for i in range(len(txt_src)):
86 | rec = txt_src[i]
87 | reci = rec.strip().split(' ')
88 | if int(reci[1]) in args.src_classes:
89 | line = reci[0] + ' ' + str(label_map_s[int(reci[1])]) + '\n'
90 | new_src.append(line)
91 | txt_src = new_src.copy()
92 |
93 | new_tar = []
94 | for i in range(len(txt_test)):
95 | rec = txt_test[i]
96 | reci = rec.strip().split(' ')
97 | if int(reci[1]) in args.tar_classes:
98 | if int(reci[1]) in args.src_classes:
99 | line = reci[0] + ' ' + str(label_map_s[int(reci[1])]) + '\n'
100 | new_tar.append(line)
101 | else:
102 | line = reci[0] + ' ' + str(len(label_map_s)) + '\n'
103 | new_tar.append(line)
104 | txt_test = new_tar.copy()
105 |
106 | if args.trte == "val":
107 | dsize = len(txt_src)
108 | tr_size = int(0.9*dsize)
109 | # print(dsize, tr_size, dsize - tr_size)
110 | tr_txt, te_txt = torch.utils.data.random_split(txt_src, [tr_size, dsize - tr_size])
111 | else:
112 | dsize = len(txt_src)
113 | tr_size = int(0.9*dsize)
114 | _, te_txt = torch.utils.data.random_split(txt_src, [tr_size, dsize - tr_size])
115 | tr_txt = txt_src
116 |
117 | dsets["source_tr"] = ImageList(tr_txt, root="../data/{}/".format(args.dset), transform=image_train())
118 | dset_loaders["source_tr"] = DataLoader(dsets["source_tr"], batch_size=train_bs, shuffle=True, num_workers=args.worker, drop_last=False)
119 | dsets["source_te"] = ImageList(te_txt, root="../data/{}/".format(args.dset), transform=image_test())
120 | dset_loaders["source_te"] = DataLoader(dsets["source_te"], batch_size=train_bs, shuffle=True, num_workers=args.worker, drop_last=False)
121 | dsets["test"] = ImageList(txt_test, root="../data/{}/".format(args.dset), transform=image_test())
122 | dset_loaders["test"] = DataLoader(dsets["test"], batch_size=train_bs*2, shuffle=True, num_workers=args.worker, drop_last=False)
123 |
124 | return dset_loaders
125 |
126 | def cal_acc(loader, netF, netB, netC, flag=False):
127 | start_test = True
128 | with torch.no_grad():
129 | iter_test = iter(loader)
130 | for i in range(len(loader)):
131 |             data = next(iter_test)
132 | inputs = data[0]
133 | labels = data[1]
134 | inputs = inputs.cuda()
135 | outputs = netC(netB(netF(inputs)))
136 | if start_test:
137 | all_output = outputs.float().cpu()
138 | all_label = labels.float()
139 | start_test = False
140 | else:
141 | all_output = torch.cat((all_output, outputs.float().cpu()), 0)
142 | all_label = torch.cat((all_label, labels.float()), 0)
143 |
144 | all_output = nn.Softmax(dim=1)(all_output)
145 | _, predict = torch.max(all_output, 1)
146 | accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0])
147 | mean_ent = torch.mean(loss.Entropy(all_output)).cpu().data.item()
148 |
149 | if flag:
150 | matrix = confusion_matrix(all_label, torch.squeeze(predict).float())
151 | acc = matrix.diagonal()/matrix.sum(axis=1) * 100
152 | aacc = acc.mean()
153 | aa = [str(np.round(i, 2)) for i in acc]
154 | acc = ' '.join(aa)
155 | return aacc, acc
156 | else:
157 | return accuracy*100, mean_ent
158 |
159 | def cal_acc_oda(loader, netF, netB, netC):
160 | start_test = True
161 | with torch.no_grad():
162 | iter_test = iter(loader)
163 | for i in range(len(loader)):
164 |             data = next(iter_test)
165 | inputs = data[0]
166 | labels = data[1]
167 | inputs = inputs.cuda()
168 | outputs = netC(netB(netF(inputs)))
169 | if start_test:
170 | all_output = outputs.float().cpu()
171 | all_label = labels.float()
172 | start_test = False
173 | else:
174 | all_output = torch.cat((all_output, outputs.float().cpu()), 0)
175 | all_label = torch.cat((all_label, labels.float()), 0)
176 |
177 | all_output = nn.Softmax(dim=1)(all_output)
178 | _, predict = torch.max(all_output, 1)
179 | ent = torch.sum(-all_output * torch.log(all_output + args.epsilon), dim=1) / np.log(args.class_num)
180 | ent = ent.float().cpu()
181 | initc = np.array([[0], [1]])
182 | kmeans = KMeans(n_clusters=2, random_state=0, init=initc, n_init=1).fit(ent.reshape(-1,1))
183 | threshold = (kmeans.cluster_centers_).mean()
184 |
185 | predict[ent>threshold] = args.class_num
186 | matrix = confusion_matrix(all_label, torch.squeeze(predict).float())
187 | matrix = matrix[np.unique(all_label).astype(int),:]
188 |
189 | acc = matrix.diagonal()/matrix.sum(axis=1) * 100
190 | unknown_acc = acc[-1:].item()
191 |
192 | return np.mean(acc[:-1]), np.mean(acc), unknown_acc
193 | # return np.mean(acc), np.mean(acc[:-1])
194 |
195 | def train_source(args):
196 | dset_loaders = data_load(args)
197 | ## set base network
198 | if args.net[0:3] == 'res':
199 | netF = network.ResBase(res_name=args.net).cuda()
200 | elif args.net[0:3] == 'vgg':
201 | netF = network.VGGBase(vgg_name=args.net).cuda()
202 |
203 | netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features, bottleneck_dim=args.bottleneck).cuda()
204 | netC = network.feat_classifier(type=args.layer, class_num = args.class_num, bottleneck_dim=args.bottleneck).cuda()
205 |
206 | param_group = []
207 | learning_rate = args.lr
208 | for k, v in netF.named_parameters():
209 | param_group += [{'params': v, 'lr': learning_rate*0.1}]
210 | for k, v in netB.named_parameters():
211 | param_group += [{'params': v, 'lr': learning_rate}]
212 | for k, v in netC.named_parameters():
213 | param_group += [{'params': v, 'lr': learning_rate}]
214 | optimizer = optim.SGD(param_group)
215 | optimizer = op_copy(optimizer)
216 |
217 | acc_init = 0
218 | max_iter = args.max_epoch * len(dset_loaders["source_tr"])
219 | interval_iter = max_iter // 10
220 | iter_num = 0
221 |
222 | netF.train()
223 | netB.train()
224 | netC.train()
225 |
226 | while iter_num < max_iter:
227 |         try:
228 |             inputs_source, labels_source = next(iter_source)
229 |         except (StopIteration, NameError):
230 |             iter_source = iter(dset_loaders["source_tr"])
231 |             inputs_source, labels_source = next(iter_source)
232 |
233 | if inputs_source.size(0) == 1:
234 | continue
235 |
236 | iter_num += 1
237 | lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter)
238 |
239 | inputs_source, labels_source = inputs_source.cuda(), labels_source.cuda()
240 | outputs_source = netC(netB(netF(inputs_source)))
241 | classifier_loss = CrossEntropyLabelSmooth(num_classes=args.class_num, epsilon=args.smooth)(outputs_source, labels_source)
242 |
243 | optimizer.zero_grad()
244 | classifier_loss.backward()
245 | optimizer.step()
246 |
247 | if iter_num % interval_iter == 0 or iter_num == max_iter:
248 | netF.eval()
249 | netB.eval()
250 | netC.eval()
251 | if args.dset=='visda-2017':
252 | acc_s_te, acc_list = cal_acc(dset_loaders['source_te'], netF, netB, netC, True)
253 | log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(args.s, iter_num, max_iter, acc_s_te) + '\n' + acc_list
254 | else:
255 | acc_s_te, _ = cal_acc(dset_loaders['source_te'], netF, netB, netC, False)
256 | log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(args.s, iter_num, max_iter, acc_s_te)
257 | # args.out_file.write(log_str + '\n')
258 | # args.out_file.flush()
259 | # print(log_str+'\n')
260 | logging.info(log_str)
261 |
262 | if acc_s_te >= acc_init:
263 | acc_init = acc_s_te
264 | best_netF = netF.state_dict()
265 | best_netB = netB.state_dict()
266 | best_netC = netC.state_dict()
267 |
268 | netF.train()
269 | netB.train()
270 | netC.train()
271 |
272 | torch.save(best_netF, osp.join(args.output_dir_src, "{}_{}_source_F.pt".format(args.s, args.net)))
273 | torch.save(best_netB, osp.join(args.output_dir_src, "{}_{}_source_B.pt".format(args.s, args.net)))
274 | torch.save(best_netC, osp.join(args.output_dir_src, "{}_{}_source_C.pt".format(args.s, args.net)))
275 |
276 | return netF, netB, netC
277 |
278 | def test_target(args):
279 | dset_loaders = data_load(args)
280 | ## set base network
281 | if args.net[0:3] == 'res':
282 | netF = network.ResBase(res_name=args.net).cuda()
283 | elif args.net[0:3] == 'vgg':
284 | netF = network.VGGBase(vgg_name=args.net).cuda()
285 |
286 | netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features, bottleneck_dim=args.bottleneck).cuda()
287 | netC = network.feat_classifier(type=args.layer, class_num = args.class_num, bottleneck_dim=args.bottleneck).cuda()
288 |
289 | args.modelpath = osp.join(args.output_dir_src, '{}_{}_source_F.pt'.format(args.s, args.net))
290 | netF.load_state_dict(torch.load(args.modelpath))
291 | args.modelpath = osp.join(args.output_dir_src, '{}_{}_source_B.pt'.format(args.s, args.net))
292 | netB.load_state_dict(torch.load(args.modelpath))
293 | args.modelpath = osp.join(args.output_dir_src, '{}_{}_source_C.pt'.format(args.s, args.net))
294 | netC.load_state_dict(torch.load(args.modelpath))
295 | netF.eval()
296 | netB.eval()
297 | netC.eval()
298 |
299 | if args.da == 'oda':
300 | acc_os1, acc_os2, acc_unknown = cal_acc_oda(dset_loaders['test'], netF, netB, netC)
301 | log_str = '\nTraining: {}, Task: {}->{}, Accuracy = {:.2f}% / {:.2f}% / {:.2f}%'.format(args.trte, args.s, args.t, acc_os2, acc_os1, acc_unknown)
302 | else:
303 | if args.dset=='visda-2017':
304 | acc, acc_list = cal_acc(dset_loaders['test'], netF, netB, netC, True)
305 | log_str = '\nTraining: {}, Task: {}->{}, Accuracy = {:.2f}%'.format(args.trte, args.s, args.t, acc) + '\n' + acc_list
306 | else:
307 | acc, _ = cal_acc(dset_loaders['test'], netF, netB, netC, False)
308 | log_str = '\nTraining: {}, Task: {}->{}, Accuracy = {:.2f}%'.format(args.trte, args.s, args.t, acc)
309 |
310 | # args.out_file.write(log_str)
311 | # args.out_file.flush()
312 | # print(log_str)
313 | logging.info(log_str)
314 |
315 | def print_args(args):
316 | s = "==========================================\n"
317 | for arg, content in args.__dict__.items():
318 | s += "{}:{}\n".format(arg, content)
319 | return s
320 |
321 | if __name__ == "__main__":
322 | parser = argparse.ArgumentParser(description='SHOT')
323 | parser.add_argument('--gpu_id', type=str, nargs='?', default='0', help="device id to run")
324 | parser.add_argument('--s', type=str, default=None, help="source")
325 | parser.add_argument('--t', type=str, default=None, help="target")
326 |     parser.add_argument('--max_epoch', type=int, default=20, help="maximum number of epochs")
327 | parser.add_argument('--batch_size', type=int, default=64, help="batch_size")
328 | parser.add_argument('--worker', type=int, default=4, help="number of workers")
329 | parser.add_argument('--dset', type=str, default='office-home', choices=['visda-2017', 'domainnet40', 'office31',
330 | 'office-home', 'office-home-rsut', 'office-caltech', 'multi'])
331 | parser.add_argument('--lr', type=float, default=1e-2, help="learning rate")
332 | parser.add_argument('--net', type=str, default='resnet50', help="vgg16, resnet50, resnet101")
333 | parser.add_argument('--seed', type=int, default=2020, help="random seed")
334 | parser.add_argument('--bottleneck', type=int, default=256)
335 | parser.add_argument('--epsilon', type=float, default=1e-5)
336 | parser.add_argument('--layer', type=str, default="wn", choices=["linear", "wn"])
337 | parser.add_argument('--classifier', type=str, default="bn", choices=["ori", "bn"])
338 | parser.add_argument('--smooth', type=float, default=0.1)
339 | parser.add_argument('--output', type=str, default='san')
340 | parser.add_argument('--da', type=str, default='uda', choices=['uda', 'pda', 'oda'])
341 | parser.add_argument('--trte', type=str, default='val', choices=['full', 'val'])
342 |
343 | parser.add_argument('--timestamp', default=timestamp, type=str, help='timestamp')
344 | parser.add_argument('--use_file_logger', default='True', type=lambda x: bool(distutils.util.strtobool(x)),
345 | help='whether use file logger')
346 | parser.add_argument('--names', default=[], type=list, help='names of tasks')
347 |
348 |
349 | args = parser.parse_args()
350 |
351 | if args.dset == 'office-home':
352 | args.names = ['Art', 'Clipart', 'Product', 'Real_World']
353 | args.class_num = 65
354 | if args.dset == 'office-home-rsut':
355 | args.names = ['Clipart', 'Product', 'Real_World']
356 | args.class_num = 65
357 | if args.dset == 'domainnet40':
358 | args.names = ['sketch', 'clipart', 'painting', 'real']
359 | args.class_num = 40
360 | if args.dset == 'multi':
361 | args.names = ['real', 'clipart', 'sketch', 'painting']
362 | args.class_num = 126
363 | if args.dset == 'office31':
364 | args.names = ['amazon', 'dslr', 'webcam']
365 | args.class_num = 31
366 | if args.dset == 'visda-2017':
367 | args.names = ['train', 'validation']
368 | args.class_num = 12
369 | if args.dset == 'office-caltech':
370 | args.names = ['amazon', 'caltech', 'dslr', 'webcam']
371 | args.class_num = 10
372 |
373 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
374 | resetRNGseed(args.seed)
375 |
376 | if args.dset == 'office-home-rsut':
377 | args.s += '_RS'
378 |
379 | dir = "{}_{}_{}_source".format(args.timestamp, args.s, args.da)
380 | if args.use_file_logger:
381 | init_logger(dir, True, '../logs/SHOT/source/')
382 | logging.info("{}:{}".format(get_hostname(), get_pid()))
383 |
384 | folder = '../data/'
385 | args.s_dset_path = folder + args.dset + '/image_list/' + args.s + '.txt'
386 | args.test_dset_path = folder + args.dset + '/image_list/' + args.s + '.txt'
387 |
388 | if args.dset == 'domainnet40':
389 | args.s_dset_path = folder + args.dset + '/image_list/' + args.s + '_train_mini.txt'
390 | args.test_dset_path = folder + args.dset + '/image_list/' + args.s + '_test_mini.txt'
391 |
392 | if args.dset == 'office-home':
393 | if args.da == 'pda':
394 | args.class_num = 65
395 | args.src_classes = [i for i in range(65)]
396 | args.tar_classes = [33, 32, 36, 15, 19, 2, 46, 49, 48, 53, 47, 54, 4, 18, 57, 23, 0, 45, 1, 38, 5, 13, 50, 11, 58]
397 | if args.da == 'oda':
398 | args.class_num = 25
399 | args.src_classes = [33, 32, 36, 15, 19, 2, 46, 49, 48, 53, 47, 54, 4, 18, 57, 23, 0, 45, 1, 38, 5, 13, 50, 11, 58]
400 | args.tar_classes = [i for i in range(65)]
401 |
402 | # args.output_dir_src = osp.join(args.output, args.da, args.dset, names[args.s][0].upper())
403 | # args.name_src = names[args.s][0].upper()
404 | args.output_dir_src = "../checkpoints/SHOT/source/{}/".format(args.da)
405 |
406 |     # create the source checkpoint directory (and parents) if needed
407 |     if not osp.exists(args.output_dir_src):
408 |         os.makedirs(args.output_dir_src, exist_ok=True)
409 |
410 |
411 | # args.out_file = open(osp.join(args.output_dir_src, 'log.txt'), 'w')
412 | # args.out_file.write(print_args(args)+'\n')
413 | # args.out_file.flush()
414 |
415 | logging.info(print_args(args))
416 | train_source(args)
417 |
418 | # args.out_file = open(osp.join(args.output_dir_src, 'log_test.txt'), 'w')
419 | for t in args.names:
420 | if t == args.s or t == args.s.split('_RS')[0]:
421 | continue
422 | args.t = t
423 | # args.name = args.names[args.s][0].upper() + args.t[0].upper()
424 |
425 | if args.dset == 'office-home-rsut':
426 | args.t += '_UT'
427 |
428 | folder = '../data/'
429 | args.s_dset_path = folder + args.dset + '/image_list/' + args.s + '.txt'
430 | args.test_dset_path = folder + args.dset + '/image_list/' + args.t + '.txt'
431 |
432 | if args.dset == 'domainnet40':
433 | args.s_dset_path = folder + args.dset + '/image_list/' + args.s + '_train_mini.txt'
434 | args.test_dset_path = folder + args.dset + '/image_list/' + args.t + '_test_mini.txt'
435 |
436 | if args.dset == 'office-home':
437 | if args.da == 'pda':
438 | args.class_num = 65
439 | args.src_classes = [i for i in range(65)]
440 | args.tar_classes = [33, 32, 36, 15, 19, 2, 46, 49, 48, 53, 47, 54, 4, 18, 57, 23, 0, 45, 1, 38, 5, 13, 50, 11, 58]
441 | if args.da == 'oda':
442 | args.class_num = 25
443 | args.src_classes = [33, 32, 36, 15, 19, 2, 46, 49, 48, 53, 47, 54, 4, 18, 57, 23, 0, 45, 1, 38, 5, 13, 50, 11, 58]
444 | args.tar_classes = [i for i in range(65)]
445 |
446 | test_target(args)
--------------------------------------------------------------------------------
/SHOT/image_target.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os, sys
3 | import os.path as osp
4 | import torchvision
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | import torch.optim as optim
9 | from torchvision import transforms
10 | import network, loss
11 | from torch.utils.data import DataLoader
12 | from data_list import ImageList, ImageList_idx
13 | import random, pdb, math, copy
14 | from tqdm import tqdm
15 | from scipy.spatial.distance import cdist
16 | from sklearn.metrics import confusion_matrix
17 | import distutils
18 | import distutils.util
19 | import logging
20 |
21 | import sys
22 | sys.path.append("../util/")
23 | from utils import resetRNGseed, init_logger, get_hostname, get_pid
24 |
25 | import time
26 | timestamp = time.strftime("%Y-%m-%d_%H.%M.%S", time.localtime())
27 |
28 | torch.backends.cudnn.deterministic = True
29 | torch.backends.cudnn.benchmark = False
30 |
31 |
32 | def op_copy(optimizer):
33 | for param_group in optimizer.param_groups:
34 | param_group['lr0'] = param_group['lr']
35 | return optimizer
36 |
37 | def lr_scheduler(optimizer, iter_num, max_iter, gamma=10, power=0.75):
38 | decay = (1 + gamma * iter_num / max_iter) ** (-power)
39 | for param_group in optimizer.param_groups:
40 | param_group['lr'] = param_group['lr0'] * decay
41 | param_group['weight_decay'] = 1e-3
42 | param_group['momentum'] = 0.9
43 | param_group['nesterov'] = True
44 | return optimizer
45 |
46 | def image_train(resize_size=256, crop_size=224, alexnet=False):
47 | if not alexnet:
48 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
49 | std=[0.229, 0.224, 0.225])
50 | else:
51 |         normalize = Normalize(meanfile='./ilsvrc_2012_mean.npy')  # NOTE: this custom Normalize is not defined/imported in this file; the alexnet branch is not exercised by the provided run scripts
52 | return transforms.Compose([
53 | transforms.Resize((resize_size, resize_size)),
54 | transforms.RandomCrop(crop_size),
55 | transforms.RandomHorizontalFlip(),
56 | transforms.ToTensor(),
57 | normalize
58 | ])
59 |
60 | def image_test(resize_size=256, crop_size=224, alexnet=False):
61 | if not alexnet:
62 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
63 | std=[0.229, 0.224, 0.225])
64 | else:
65 | normalize = Normalize(meanfile='./ilsvrc_2012_mean.npy')
66 | return transforms.Compose([
67 | transforms.Resize((resize_size, resize_size)),
68 | transforms.CenterCrop(crop_size),
69 | transforms.ToTensor(),
70 | normalize
71 | ])
72 |
73 | def data_load(args):
74 | ## prepare data
75 | dsets = {}
76 | dset_loaders = {}
77 | train_bs = args.batch_size
78 | txt_tar = open(args.t_dset_path).readlines()
79 | txt_test = open(args.test_dset_path).readlines()
80 |
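    # Non-UDA (e.g. PDA) label remapping: target samples from classes shared with the
    # source keep a compacted label in [0, len(src_classes)); samples from target-only
    # classes are mapped to the extra index len(label_map_s), i.e. a single "unknown" class.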
81 | if not args.da == 'uda':
82 | label_map_s = {}
83 | for i in range(len(args.src_classes)):
84 | label_map_s[args.src_classes[i]] = i
85 |
86 | new_tar = []
87 | for i in range(len(txt_tar)):
88 | rec = txt_tar[i]
89 | reci = rec.strip().split(' ')
90 | if int(reci[1]) in args.tar_classes:
91 | if int(reci[1]) in args.src_classes:
92 | line = reci[0] + ' ' + str(label_map_s[int(reci[1])]) + '\n'
93 | new_tar.append(line)
94 | else:
95 | line = reci[0] + ' ' + str(len(label_map_s)) + '\n'
96 | new_tar.append(line)
97 | txt_tar = new_tar.copy()
98 | txt_test = txt_tar.copy()
99 |
100 | dsets["target"] = ImageList_idx(txt_tar, root="../data/{}/".format(args.dset), transform=image_train())
101 | dset_loaders["target"] = DataLoader(dsets["target"], batch_size=train_bs, shuffle=True, num_workers=args.worker, drop_last=False)
102 |
103 | dsets["test"] = ImageList_idx(txt_test, root="../data/{}/".format(args.dset), transform=image_test())
104 | dset_loaders["test"] = DataLoader(dsets["test"], batch_size=train_bs*3, shuffle=False, num_workers=args.worker, drop_last=False)
105 | dsets["valid"] = ImageList_idx(txt_tar, root="../data/{}/".format(args.dset), transform=image_train() if args.use_train_transform else image_test())
106 | dset_loaders["valid"] = DataLoader(dsets["valid"], batch_size=train_bs * 3, shuffle=False, num_workers=args.worker, drop_last=False)
107 |
108 | return dset_loaders
109 |
110 | def cal_acc(loader, netF, netB, netC):
111 | start_test = True
112 | with torch.no_grad():
113 | iter_test = iter(loader)
114 | for i in range(len(loader)):
115 |             data = next(iter_test)
116 | inputs = data[0]
117 | labels = data[1]
118 | inputs = inputs.cuda()
119 | outputs = netC(netB(netF(inputs)))
120 | if start_test:
121 | all_output = outputs.float().cpu()
122 | all_label = labels.float()
123 | start_test = False
124 | else:
125 | all_output = torch.cat((all_output, outputs.float().cpu()), 0)
126 | all_label = torch.cat((all_label, labels.float()), 0)
127 | _, predict = torch.max(all_output, 1)
128 | acc = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0]) * 100
129 | mean_ent = torch.mean(loss.Entropy(nn.Softmax(dim=1)(all_output))).cpu().data.item()
130 |
141 | matrix = confusion_matrix(all_label, torch.squeeze(predict).float())
142 | acc_list = matrix.diagonal() / (matrix.sum(axis=1)+1e-12) * 100
143 | per_class_acc = acc_list.mean()
144 | if args.da == 'pda':
145 | acc_list = ''
146 | per_class_acc = 0
147 | return acc, mean_ent, per_class_acc, acc_list
148 |
149 |
150 | def train_target(args):
151 | dset_loaders = data_load(args)
152 | ## set base network
153 | if args.net[0:3] == 'res':
154 | netF = network.ResBase(res_name=args.net).cuda()
155 | elif args.net[0:3] == 'vgg':
156 | netF = network.VGGBase(vgg_name=args.net).cuda()
157 |
158 | netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features, bottleneck_dim=args.bottleneck).cuda()
159 | netC = network.feat_classifier(type=args.layer, class_num = args.class_num, bottleneck_dim=args.bottleneck).cuda()
160 |
161 | modelpath = osp.join(args.output_dir_src, "{}_{}_source_F.pt".format(args.s, args.net))
162 | netF.load_state_dict(torch.load(modelpath))
163 | modelpath = osp.join(args.output_dir_src, "{}_{}_source_B.pt".format(args.s, args.net))
164 | netB.load_state_dict(torch.load(modelpath))
165 | modelpath = osp.join(args.output_dir_src, "{}_{}_source_C.pt".format(args.s, args.net))
166 | netC.load_state_dict(torch.load(modelpath))
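    # SHOT keeps the source hypothesis (classifier netC) frozen; only the feature
    # extractor netF and the bottleneck netB are adapted to the target domain.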
167 | netC.eval()
168 | for k, v in netC.named_parameters():
169 | v.requires_grad = False
170 |
171 | param_group = []
172 | for k, v in netF.named_parameters():
173 | if args.lr_decay1 > 0:
174 | param_group += [{'params': v, 'lr': args.lr * args.lr_decay1}]
175 | else:
176 | v.requires_grad = False
177 | for k, v in netB.named_parameters():
178 | if args.lr_decay2 > 0:
179 | param_group += [{'params': v, 'lr': args.lr * args.lr_decay2}]
180 | else:
181 | v.requires_grad = False
182 |
183 | optimizer = optim.SGD(param_group)
184 | optimizer = op_copy(optimizer)
185 |
186 | max_iter = args.max_epoch * len(dset_loaders["target"])
187 | interval_iter = max_iter // args.interval
188 | iter_num = 0
189 |
190 | while iter_num < max_iter:
191 |         try:
192 |             inputs_test, _, tar_idx = next(iter_test)
193 |         except (NameError, StopIteration):  # first pass or exhausted iterator
194 |             iter_test = iter(dset_loaders["target"])
195 |             inputs_test, _, tar_idx = next(iter_test)
196 |
197 | if inputs_test.size(0) == 1:
198 | continue
199 |
200 | if iter_num % interval_iter == 0 and args.cls_par > 0:
201 | netF.eval()
202 | netB.eval()
203 | mem_label = obtain_label(dset_loaders['valid'], netF, netB, netC, args)
204 | mem_label = torch.from_numpy(mem_label).cuda()
205 | netF.train()
206 | netB.train()
207 |
208 | if args.use_balanced_sampler:
209 |                 dset_loaders["target"].sampler.update(mem_label)  # assumes data_load installed a sampler exposing .update(); the default RandomSampler from shuffle=True does not
210 |
211 | inputs_test = inputs_test.cuda()
212 |
213 | iter_num += 1
214 | lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter)
215 |
216 | features_test = netB(netF(inputs_test))
217 | outputs_test = netC(features_test)
218 |
219 | if args.cls_par > 0:
220 | pred = mem_label[tar_idx]
221 | classifier_loss = nn.CrossEntropyLoss()(outputs_test, pred)
222 | classifier_loss *= args.cls_par
223 | if iter_num < interval_iter and args.dset == "visda-2017":
224 | classifier_loss *= 0
225 | else:
226 | classifier_loss = torch.tensor(0.0).cuda()
227 |
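        # Information-maximization loss: minimize per-sample prediction entropy
        # (confidence) while maximizing the entropy of the batch-mean prediction
        # (class diversity), weighted by ent_par.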
228 | if args.ent:
229 | softmax_out = nn.Softmax(dim=1)(outputs_test)
230 | entropy_loss = torch.mean(loss.Entropy(softmax_out))
231 | if args.gent:
232 | msoftmax = softmax_out.mean(dim=0)
233 | gentropy_loss = torch.sum(-msoftmax * torch.log(msoftmax + args.epsilon))
234 | entropy_loss -= gentropy_loss
235 | im_loss = entropy_loss * args.ent_par
236 | classifier_loss += im_loss
237 |
238 | optimizer.zero_grad()
239 | classifier_loss.backward()
240 | optimizer.step()
241 |
242 | if iter_num % interval_iter == 0 or iter_num == max_iter:
243 | netF.eval()
244 | netB.eval()
245 |
246 | if args.dset=='visda-2017':
247 | acc, _, per_class_acc, acc_list = cal_acc(dset_loaders['test'], netF, netB, netC)
248 | aa = [str(np.round(i, 2)) for i in acc_list]
249 | aa = ' '.join(aa)
250 |                 log_str = 'Task: {}->{}, Iter:{}/{}; Accuracy = {:.2f}%, Per-class accuracy = {:.2f}%'.format(args.s, args.t, iter_num, max_iter, acc, per_class_acc) + '\n' + aa
251 | else:
252 | acc, _, per_class_acc, acc_list = cal_acc(dset_loaders['test'], netF, netB, netC)
253 |                 log_str = 'Task: {}->{}, Iter:{}/{}; Accuracy = {:.2f}%, Per-class accuracy = {:.2f}%'.format(args.s, args.t, iter_num, max_iter, acc, per_class_acc)
254 |
258 | logging.info(log_str)
259 | netF.train()
260 | netB.train()
261 |
262 | if args.issave:
263 | torch.save(netF.state_dict(), osp.join(args.output_dir, "{}_{}_{}_{}_target_F_".format(args.timestamp, args.s, args.t, args.net) + args.savename + ".pt"))
264 | torch.save(netB.state_dict(), osp.join(args.output_dir, "{}_{}_{}_{}_target_B_".format(args.timestamp, args.s, args.t, args.net) + args.savename + ".pt"))
265 | torch.save(netC.state_dict(), osp.join(args.output_dir, "{}_{}_{}_{}_target_C_".format(args.timestamp, args.s, args.t, args.net) + args.savename + ".pt"))
266 |
267 | return netF, netB, netC
268 |
269 | def print_args(args):
270 | s = "==========================================\n"
271 | for arg, content in args.__dict__.items():
272 | s += "{}:{}\n".format(arg, content)
273 | return s
274 |
275 | def obtain_label(loader, netF, netB, netC, args):
276 | start_test = True
277 | with torch.no_grad():
278 | iter_test = iter(loader)
279 | for _ in range(len(loader)):
280 |             data = next(iter_test)
281 | inputs = data[0]
282 | labels = data[1]
283 | inputs = inputs.cuda()
284 | feas = netB(netF(inputs))
285 | outputs = netC(feas)
286 | if start_test:
287 | all_fea = feas.float().cpu()
288 | all_output = outputs.float().cpu()
289 | all_label = labels.float()
290 | start_test = False
291 | else:
292 | all_fea = torch.cat((all_fea, feas.float().cpu()), 0)
293 | all_output = torch.cat((all_output, outputs.float().cpu()), 0)
294 | all_label = torch.cat((all_label, labels.float()), 0)
295 |
296 | all_output = nn.Softmax(dim=1)(all_output)
297 | ent = torch.sum(-all_output * torch.log(all_output + args.epsilon), dim=1)
298 | unknown_weight = 1 - ent / np.log(args.class_num)
299 | _, predict = torch.max(all_output, 1)
300 |
301 | accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0])
302 | matrix = confusion_matrix(all_label, torch.squeeze(predict).float())
303 |
304 | acc_list = matrix.diagonal() / (matrix.sum(axis=1)+1e-12)
305 | avg_accuracy = (acc_list).mean()
306 | if args.da == 'pda':
307 | acc_list = ''
308 | avg_accuracy = 0
309 |
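    # Self-labeling: build class centroids as softmax-weighted means of the features,
    # assign each sample to its nearest centroid (restricted to classes with more than
    # `threshold` predicted samples), then refine the centroids once with the hard
    # assignments. For cosine distance, features are first augmented with a constant 1
    # and L2-normalized.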
310 | if args.distance == 'cosine':
311 | all_fea = torch.cat((all_fea, torch.ones(all_fea.size(0), 1)), 1)
312 | all_fea = (all_fea.t() / torch.norm(all_fea, p=2, dim=1)).t()
313 |
314 | all_fea = all_fea.float().cpu().numpy()
315 | K = all_output.size(1)
316 | aff = all_output.float().cpu().numpy()
317 | initc = aff.transpose().dot(all_fea)
318 | initc = initc / (1e-8 + aff.sum(axis=0)[:,None])
319 | cls_count = np.eye(K)[predict].sum(axis=0)
320 | labelset = np.where(cls_count>args.threshold)
321 | labelset = labelset[0]
323 |
324 | dd = cdist(all_fea, initc[labelset], args.distance)
325 | pred_label = dd.argmin(axis=1)
326 | pred_label = labelset[pred_label]
327 |
328 |     for _ in range(1):  # one refinement round
329 | aff = np.eye(K)[pred_label]
330 | initc = aff.transpose().dot(all_fea)
331 | initc = initc / (1e-8 + aff.sum(axis=0)[:,None])
332 | dd = cdist(all_fea, initc[labelset], args.distance)
333 | pred_label = dd.argmin(axis=1)
334 | pred_label = labelset[pred_label]
335 |
336 | acc = np.sum(pred_label == all_label.float().numpy()) / len(all_fea)
337 | matrix = confusion_matrix(all_label.float().numpy(), pred_label)
338 | acc_list = matrix.diagonal() / (matrix.sum(axis=1)+1e-12)
339 | avg_acc = acc_list.mean()
340 | if args.da == 'pda':
341 | acc_list = ''
342 | avg_acc = 0
343 | log_str = 'Accuracy = {:.2f}% -> {:.2f}% Per_class_accuracy = {:.2f}% -> {:.2f}%'.format(accuracy * 100, acc * 100, avg_accuracy * 100, avg_acc * 100)
344 |
348 | logging.info(log_str)
349 |
350 | return pred_label.astype('int')
351 |
352 |
353 | if __name__ == "__main__":
354 | parser = argparse.ArgumentParser(description='SHOT')
355 | parser.add_argument('--gpu_id', type=str, nargs='?', default='0', help="device id to run")
356 | parser.add_argument('--s', type=str, default=None, help="source")
357 | parser.add_argument('--t', type=str, default=None, help="target")
358 |     parser.add_argument('--max_epoch', type=int, default=15, help="max epochs")
359 | parser.add_argument('--interval', type=int, default=15)
360 | parser.add_argument('--batch_size', type=int, default=64, help="batch_size")
361 | parser.add_argument('--worker', type=int, default=4, help="number of workers")
362 | parser.add_argument('--dset', type=str, default='office-home', choices=['visda-2017', 'domainnet40', 'office31',
363 | 'office-home', 'office-home-rsut', 'office-caltech', 'multi'])
364 | parser.add_argument('--lr', type=float, default=1e-2, help="learning rate")
365 | parser.add_argument('--net', type=str, default='resnet50', help="alexnet, vgg16, resnet50, res101")
366 | parser.add_argument('--seed', type=int, default=2020, help="random seed")
367 |
368 |     parser.add_argument('--gent', type=lambda x: bool(distutils.util.strtobool(x)), default=True)
369 |     parser.add_argument('--ent', type=lambda x: bool(distutils.util.strtobool(x)), default=True)
370 | parser.add_argument('--threshold', type=int, default=0)
371 | parser.add_argument('--cls_par', type=float, default=0.3)
372 | parser.add_argument('--ent_par', type=float, default=1.0)
373 | parser.add_argument('--lr_decay1', type=float, default=0.1)
374 | parser.add_argument('--lr_decay2', type=float, default=1.0)
375 |
376 | parser.add_argument('--bottleneck', type=int, default=256)
377 | parser.add_argument('--epsilon', type=float, default=1e-5)
378 | parser.add_argument('--layer', type=str, default="wn", choices=["linear", "wn"])
379 | parser.add_argument('--classifier', type=str, default="bn", choices=["ori", "bn"])
380 | parser.add_argument('--distance', type=str, default='cosine', choices=["euclidean", "cosine"])
381 | parser.add_argument('--output', type=str, default='san')
382 | parser.add_argument('--output_src', type=str, default='san')
383 | parser.add_argument('--da', type=str, default='uda', choices=['uda', 'pda'])
384 |     parser.add_argument('--issave', type=lambda x: bool(distutils.util.strtobool(x)), default=True)
385 |
386 | parser.add_argument('--timestamp', default=timestamp, type=str, help='timestamp')
387 | parser.add_argument('--use_file_logger', default='True', type=lambda x: bool(distutils.util.strtobool(x)),
388 | help='whether use file logger')
389 | parser.add_argument('--names', default=[], type=list, help='names of tasks')
390 | parser.add_argument('--use_train_transform', default='False', type=lambda x: bool(distutils.util.strtobool(x)),
391 | help='whether use train transform for label refinement')
392 | parser.add_argument('--use_balanced_sampler', default='False', type=lambda x: bool(distutils.util.strtobool(x)),
393 | help='whether use class balanced sampler')
394 | args = parser.parse_args()
395 |
396 | if args.dset == 'office-home':
397 | args.names = ['Art', 'Clipart', 'Product', 'Real_World']
398 | args.class_num = 65
399 | if args.dset == 'office-home-rsut':
400 | args.names = ['Clipart', 'Product', 'Real_World']
401 | args.class_num = 65
402 | if args.dset == 'domainnet40':
403 | args.names = ['sketch', 'clipart', 'painting', 'real']
404 | args.class_num = 40
405 | if args.dset == 'multi':
406 | args.names = ['real', 'clipart', 'sketch', 'painting']
407 | args.class_num = 126
408 | if args.dset == 'office31':
409 | args.names = ['amazon', 'dslr', 'webcam']
410 | args.class_num = 31
411 | if args.dset == 'visda-2017':
412 | args.names = ['train', 'validation']
413 | args.class_num = 12
414 | if args.dset == 'office-caltech':
415 | args.names = ['amazon', 'caltech', 'dslr', 'webcam']
416 | args.class_num = 10
417 |
418 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
419 | resetRNGseed(args.seed)
420 |
421 | if args.dset == 'office-home-rsut':
422 | args.s += '_RS'
423 |
424 |     log_dir = "{}_{}_{}".format(args.timestamp, args.s, args.da)
425 |     if args.use_file_logger:
426 |         init_logger(log_dir, True, '../logs/SHOT/shot/')
427 | logging.info("{}:{}".format(get_hostname(), get_pid()))
428 |
429 | for t in args.names:
430 | if t == args.s or t == args.s.split('_RS')[0]:
431 | continue
432 | args.t = t
433 |
434 | if args.dset == 'office-home-rsut':
435 | args.t += '_UT'
436 |
437 | folder = '../data/'
438 | args.s_dset_path = folder + args.dset + '/image_list/' + args.s + '.txt'
439 | args.t_dset_path = folder + args.dset + '/image_list/' + args.t + '.txt'
440 | args.test_dset_path = folder + args.dset + '/image_list/' + args.t + '.txt'
441 |
442 | if args.dset == 'domainnet40':
443 | args.s_dset_path = folder + args.dset + '/image_list/' + args.s + '_train_mini.txt'
444 | args.t_dset_path = folder + args.dset + '/image_list/' + args.t + '_train_mini.txt'
445 | args.test_dset_path = folder + args.dset + '/image_list/' + args.t + '_test_mini.txt'
446 |
447 | if args.dset == 'office-home':
448 | if args.da == 'pda':
449 | args.class_num = 65
450 | args.src_classes = [i for i in range(65)]
451 | args.tar_classes = [33, 32, 36, 15, 19, 2, 46, 49, 48, 53, 47, 54, 4, 18, 57, 23, 0, 45, 1, 38, 5, 13, 50, 11, 58]
452 |
456 | args.output_dir_src = "../checkpoints/SHOT/source/{}/".format(args.da)
457 | args.output_dir = "../checkpoints/SHOT/target/{}/".format(args.da)
458 |
459 |         os.makedirs(args.output_dir, exist_ok=True)
463 |
464 | args.savename = 'par_' + str(args.cls_par)
465 | if args.da == 'pda':
466 | args.gent = ''
467 | args.savename = 'par_' + str(args.cls_par) + '_thr' + str(args.threshold)
471 | logging.info(print_args(args))
472 | train_target(args)
--------------------------------------------------------------------------------
/SHOT/image_target_kSHOT.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os.path as osp
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.optim as optim
7 | from torchvision import transforms
8 | import network, loss
9 | from torch.utils.data import DataLoader
10 | from data_list import ImageList_idx
11 | from scipy.spatial.distance import cdist
12 | from sklearn.metrics import confusion_matrix
13 | import distutils
14 | import distutils.util
15 | import logging
16 |
17 | import sys, os
18 | sys.path.append("../util/")
19 | from utils import resetRNGseed, init_logger, get_hostname, get_pid
20 |
21 | sys.path.append("../pklib")
22 | from pksolver import PK_solver
23 |
24 | import time
25 | timestamp = time.strftime("%Y-%m-%d_%H.%M.%S", time.localtime())
26 |
27 | torch.backends.cudnn.deterministic = True
28 | torch.backends.cudnn.benchmark = False
29 |
30 |
31 | def op_copy(optimizer):
32 | for param_group in optimizer.param_groups:
33 | param_group['lr0'] = param_group['lr']
34 | return optimizer
35 |
36 | def lr_scheduler(optimizer, iter_num, max_iter, gamma=10, power=0.75):
37 | decay = (1 + gamma * iter_num / max_iter) ** (-power)
38 | for param_group in optimizer.param_groups:
39 | param_group['lr'] = param_group['lr0'] * decay
40 | param_group['weight_decay'] = 1e-3
41 | param_group['momentum'] = 0.9
42 | param_group['nesterov'] = True
43 | return optimizer
44 |
45 | def image_train(resize_size=256, crop_size=224, alexnet=False):
46 | if not alexnet:
47 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
48 | std=[0.229, 0.224, 0.225])
49 | else:
50 |         normalize = Normalize(meanfile='./ilsvrc_2012_mean.npy')  # NOTE: this custom Normalize is not defined/imported in this file; the alexnet branch is not exercised by the provided run scripts
51 | return transforms.Compose([
52 | transforms.Resize((resize_size, resize_size)),
53 | transforms.RandomCrop(crop_size),
54 | transforms.RandomHorizontalFlip(),
55 | transforms.ToTensor(),
56 | normalize
57 | ])
58 |
59 | def image_test(resize_size=256, crop_size=224, alexnet=False):
60 | if not alexnet:
61 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
62 | std=[0.229, 0.224, 0.225])
63 | else:
64 | normalize = Normalize(meanfile='./ilsvrc_2012_mean.npy')
65 | return transforms.Compose([
66 | transforms.Resize((resize_size, resize_size)),
67 | transforms.CenterCrop(crop_size),
68 | transforms.ToTensor(),
69 | normalize
70 | ])
71 |
72 | def data_load(args):
73 | ## prepare data
74 | dsets = {}
75 | dset_loaders = {}
76 | train_bs = args.batch_size
77 | txt_tar = open(args.t_dset_path).readlines()
78 | txt_test = open(args.test_dset_path).readlines()
79 |
80 | if not args.da == 'uda':
81 | label_map_s = {}
82 | for i in range(len(args.src_classes)):
83 | label_map_s[args.src_classes[i]] = i
84 |
85 | new_tar = []
86 | for i in range(len(txt_tar)):
87 | rec = txt_tar[i]
88 | reci = rec.strip().split(' ')
89 | if int(reci[1]) in args.tar_classes:
90 | if int(reci[1]) in args.src_classes:
91 | line = reci[0] + ' ' + str(label_map_s[int(reci[1])]) + '\n'
92 | new_tar.append(line)
93 | else:
94 | line = reci[0] + ' ' + str(len(label_map_s)) + '\n'
95 | new_tar.append(line)
96 | txt_tar = new_tar.copy()
97 | txt_test = txt_tar.copy()
98 |
99 | dsets["target"] = ImageList_idx(txt_tar, root="../data/{}/".format(args.dset), transform=image_train())
100 | dset_loaders["target"] = DataLoader(dsets["target"], batch_size=train_bs, shuffle=True, num_workers=args.worker, drop_last=False)
101 |
102 | dsets["test"] = ImageList_idx(txt_test, root="../data/{}/".format(args.dset), transform=image_test())
103 | dset_loaders["test"] = DataLoader(dsets["test"], batch_size=train_bs*3, shuffle=False, num_workers=args.worker, drop_last=False)
104 | dsets["valid"] = ImageList_idx(txt_tar, root="../data/{}/".format(args.dset), transform=image_train() if args.use_train_transform else image_test())
105 | dset_loaders["valid"] = DataLoader(dsets["valid"], batch_size=train_bs * 3, shuffle=False, num_workers=args.worker, drop_last=False)
106 |
107 | return dset_loaders
108 |
109 | def cal_acc(loader, netF, netB, netC):
110 | start_test = True
111 | with torch.no_grad():
112 | iter_test = iter(loader)
113 | for i in range(len(loader)):
114 |             data = next(iter_test)
115 | inputs = data[0]
116 | labels = data[1]
117 | inputs = inputs.cuda()
118 | outputs = netC(netB(netF(inputs)))
119 | if start_test:
120 | all_output = outputs.float().cpu()
121 | all_label = labels.float()
122 | start_test = False
123 | else:
124 | all_output = torch.cat((all_output, outputs.float().cpu()), 0)
125 | all_label = torch.cat((all_label, labels.float()), 0)
126 | _, predict = torch.max(all_output, 1)
127 | acc = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0]) * 100
128 | mean_ent = torch.mean(loss.Entropy(nn.Softmax(dim=1)(all_output))).cpu().data.item()
129 |
130 | matrix = confusion_matrix(all_label, torch.squeeze(predict).float())
131 |     acc_list = matrix.diagonal() / (matrix.sum(axis=1) + 1e-12) * 100
132 | per_class_acc = acc_list.mean()
133 | if args.da == 'pda':
134 | acc_list = ''
135 | per_class_acc = 0
136 | acc_list = ' '.join([str(np.round(i, 2)) for i in acc_list])
137 | return acc, mean_ent, per_class_acc, acc_list
138 |
139 |
140 | def train_target(args):
141 | dset_loaders = data_load(args)
142 | ## set base network
143 | if args.net[0:3] == 'res':
144 | netF = network.ResBase(res_name=args.net).cuda()
145 | elif args.net[0:3] == 'vgg':
146 | netF = network.VGGBase(vgg_name=args.net).cuda()
147 |
148 | netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features, bottleneck_dim=args.bottleneck).cuda()
149 | netC = network.feat_classifier(type=args.layer, class_num = args.class_num, bottleneck_dim=args.bottleneck).cuda()
150 |
151 | modelpath = osp.join(args.output_dir_src, "{}_{}_source_F.pt".format(args.s, args.net))
152 | netF.load_state_dict(torch.load(modelpath))
153 | modelpath = osp.join(args.output_dir_src, "{}_{}_source_B.pt".format(args.s, args.net))
154 | netB.load_state_dict(torch.load(modelpath))
155 | modelpath = osp.join(args.output_dir_src, "{}_{}_source_C.pt".format(args.s, args.net))
156 | netC.load_state_dict(torch.load(modelpath))
157 | netC.eval()
158 | for k, v in netC.named_parameters():
159 | v.requires_grad = False
160 |
161 | param_group = []
162 | for k, v in netF.named_parameters():
163 | if args.lr_decay1 > 0:
164 | param_group += [{'params': v, 'lr': args.lr * args.lr_decay1}]
165 | else:
166 | v.requires_grad = False
167 | for k, v in netB.named_parameters():
168 | if args.lr_decay2 > 0:
169 | param_group += [{'params': v, 'lr': args.lr * args.lr_decay2}]
170 | else:
171 | v.requires_grad = False
172 |
173 | optimizer = optim.SGD(param_group)
174 | optimizer = op_copy(optimizer)
175 |
176 | max_iter = args.max_epoch * len(dset_loaders["target"])
177 | interval_iter = max_iter // args.interval
178 | iter_num = 0
179 |
180 | # get ground-truth label probabilities of target domain
181 | start = True
182 | iter_valid = iter(dset_loaders['valid'])
183 | for _ in range(len(dset_loaders['valid'])):
184 |         data = next(iter_valid)
185 | labels = data[1]
186 | if start:
187 | all_label = labels.long()
188 | start = False
189 | else:
190 | all_label = torch.cat((all_label, labels.long()), 0)
191 |
192 | cls_probs = torch.eye(args.class_num)[all_label].sum(0)
193 | cls_probs = cls_probs / cls_probs.sum()
194 |
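    # Optionally estimate the class prior from a randomly drawn subset (with replacement)
    # of the target labels; pk_dratio < 1 emulates only partially known, noisier
    # class-distribution knowledge.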
195 | if args.pk_dratio < 1.0:
196 | ND = int(len(all_label)*args.pk_dratio)
197 | cls_probs_sample = torch.eye(args.class_num)[all_label[torch.randint(len(all_label), (ND,))]].sum(0)
198 | cls_probs_sample = cls_probs_sample / cls_probs_sample.sum()
199 | err = (cls_probs_sample-cls_probs)/cls_probs
200 | logging.info('True probs: {}'.format(cls_probs))
201 | logging.info('Sample probs: {}'.format(cls_probs_sample))
202 |         logging.info('Probs err: {}, max err: {}, mean err: {}'.format(err, err.abs().max(), err.abs().mean()))
203 | cls_probs = cls_probs_sample
204 |
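    # Build the prior-knowledge constraints on target class proportions. Per the builder
    # names: 'ub' adds unary bounds on each class probability and 'br' adds binary
    # (pairwise) relations; the *_partial/*_rand variants keep only pk_NC of the
    # constraints, *_noisy perturbs the prior by pk_noise, and pk_uconf is the
    # uncertainty level passed to each builder.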
205 | pk_solver = PK_solver(all_label.shape[0], args.class_num, pk_prior_weight=args.pk_prior_weight)
206 | if args.pk_type == 'ub':
207 | pk_solver.create_C_ub(cls_probs.cpu().numpy(), args.pk_uconf)
208 | elif args.pk_type == 'br':
209 | pk_solver.create_C_br(cls_probs.cpu().numpy(), args.pk_uconf)
210 | elif args.pk_type == 'ub+rel':
211 | pk_solver.create_C_ub(cls_probs.cpu().numpy(), args.pk_uconf)
212 | pk_solver.create_C_br(cls_probs.cpu().numpy(), 1.0)
213 | elif args.pk_type == 'ub_partial':
214 | pk_solver.create_C_ub_partial(cls_probs.cpu().numpy(), args.pk_uconf, args.pk_NC)
215 | elif args.pk_type == 'ub_partial_reverse':
216 | pk_solver.create_C_ub_partial_reverse(cls_probs.cpu().numpy(), args.pk_uconf, args.pk_NC)
217 | elif args.pk_type == 'ub_partial_rand':
218 | pk_solver.create_C_ub_partial_rand(cls_probs.cpu().numpy(), args.pk_uconf, args.pk_NC)
219 | elif args.pk_type == 'br_partial':
220 | pk_solver.create_C_br_partial(cls_probs.cpu().numpy(), args.pk_uconf, args.pk_NC)
221 | elif args.pk_type == 'br_partial_reverse':
222 | pk_solver.create_C_br_partial_reverse(cls_probs.cpu().numpy(), args.pk_uconf, args.pk_NC)
223 | elif args.pk_type == 'br_partial_rand':
224 | pk_solver.create_C_br_partial_rand(cls_probs.cpu().numpy(), args.pk_uconf, args.pk_NC)
225 | elif args.pk_type == 'ub_noisy':
226 | pk_solver.create_C_ub_noisy(cls_probs.cpu().numpy(), args.pk_uconf, args.pk_noise)
227 | elif args.pk_type == 'br_noisy':
228 | pk_solver.create_C_br_noisy(cls_probs.cpu().numpy(), args.pk_uconf, args.pk_noise)
229 |
230 | epoch = 0
231 | while iter_num < max_iter:
232 |         try:
233 |             inputs_test, _, tar_idx = next(iter_test)
234 |         except (NameError, StopIteration):  # first pass or exhausted iterator
235 |             iter_test = iter(dset_loaders["target"])
236 |             inputs_test, _, tar_idx = next(iter_test)
237 |
238 | if inputs_test.size(0) == 1:
239 | continue
240 |
241 | if iter_num % interval_iter == 0 and args.cls_par > 0:
242 | netF.eval()
243 | netB.eval()
244 | mem_label = obtain_label(dset_loaders['valid'], netF, netB, netC, args, pk_solver, epoch)
245 | mem_label = torch.from_numpy(mem_label).cuda()
246 | netF.train()
247 | netB.train()
248 | epoch += 1
249 |
250 |
251 | inputs_test = inputs_test.cuda()
252 |
253 | iter_num += 1
254 | lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter)
255 |
256 | features_test = netB(netF(inputs_test))
257 | outputs_test = netC(features_test)
258 |
259 | if args.cls_par > 0:
260 | pred = mem_label[tar_idx]
261 | classifier_loss = nn.CrossEntropyLoss()(outputs_test, pred)
262 | classifier_loss *= args.cls_par
263 | if iter_num < interval_iter and args.dset == "visda-2017":
264 | classifier_loss *= 0
265 | else:
266 | classifier_loss = torch.tensor(0.0).cuda()
267 |
268 | if args.ent:
269 | softmax_out = nn.Softmax(dim=1)(outputs_test)
270 | entropy_loss = torch.mean(loss.Entropy(softmax_out))
271 | if args.gent:
272 | msoftmax = softmax_out.mean(dim=0)
273 | gentropy_loss = torch.sum(-msoftmax * torch.log(msoftmax + args.epsilon))
274 | entropy_loss -= gentropy_loss
275 | im_loss = entropy_loss * args.ent_par
276 | classifier_loss += im_loss
277 |
278 | optimizer.zero_grad()
279 | classifier_loss.backward()
280 | optimizer.step()
281 |
282 | if iter_num % interval_iter == 0 or iter_num == max_iter:
283 | netF.eval()
284 | netB.eval()
285 |
286 | if args.dset=='visda-2017':
287 | acc, _, per_class_acc, acc_list = cal_acc(dset_loaders['test'], netF, netB, netC)
288 |                 log_str = 'Task: {}->{}, Iter:{}/{}; Accuracy = {:.2f}%, Per-class accuracy = {:.2f}%'.format(args.s, args.t, iter_num, max_iter, acc, per_class_acc) + '\n' + acc_list
289 | else:
290 | acc, _, per_class_acc, acc_list = cal_acc(dset_loaders['test'], netF, netB, netC)
291 |                 log_str = 'Task: {}->{}, Iter:{}/{}; Accuracy = {:.2f}%, Per-class accuracy = {:.2f}%'.format(args.s, args.t, iter_num, max_iter, acc, per_class_acc)
292 |
293 | logging.info(log_str)
294 | netF.train()
295 | netB.train()
296 |
297 | if args.issave:
298 | torch.save(netF.state_dict(), osp.join(args.output_dir, "{}_{}_{}_{}_target_F_".format(args.timestamp, args.s, args.t, args.net) + args.savename + ".pt"))
299 | torch.save(netB.state_dict(), osp.join(args.output_dir, "{}_{}_{}_{}_target_B_".format(args.timestamp, args.s, args.t, args.net) + args.savename + ".pt"))
300 | torch.save(netC.state_dict(), osp.join(args.output_dir, "{}_{}_{}_{}_target_C_".format(args.timestamp, args.s, args.t, args.net) + args.savename + ".pt"))
301 |
302 | return netF, netB, netC
303 |
304 | def print_args(args):
305 | s = "==========================================\n"
306 | for arg, content in args.__dict__.items():
307 | s += "{}:{}\n".format(arg, content)
308 | return s
309 |
310 | def obtain_label(loader, netF, netB, netC, args, pk_solver, epoch):
311 | start_test = True
312 | with torch.no_grad():
313 | iter_test = iter(loader)
314 | for _ in range(len(loader)):
315 |             data = next(iter_test)
316 | inputs = data[0]
317 | labels = data[1]
318 | inputs = inputs.cuda()
319 | feas = netB(netF(inputs))
320 | outputs = netC(feas)
321 | if start_test:
322 | all_fea = feas.float().cpu()
323 | all_output = outputs.float().cpu()
324 | all_label = labels.float()
325 | start_test = False
326 | else:
327 | all_fea = torch.cat((all_fea, feas.float().cpu()), 0)
328 | all_output = torch.cat((all_output, outputs.float().cpu()), 0)
329 | all_label = torch.cat((all_label, labels.float()), 0)
330 |
331 | all_output = nn.Softmax(dim=1)(all_output)
332 | ent = torch.sum(-all_output * torch.log(all_output + args.epsilon), dim=1)
333 | unknown_weight = 1 - ent / np.log(args.class_num)
334 | _, predict = torch.max(all_output, 1)
335 |
336 | accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0])
337 | matrix = confusion_matrix(all_label, torch.squeeze(predict).float())
338 | avg_accuracy = (matrix.diagonal() / matrix.sum(axis=1)).mean()
339 | if args.da == 'pda':
340 | avg_accuracy = 0
341 |
342 | if args.distance == 'cosine':
343 | all_fea = torch.cat((all_fea, torch.ones(all_fea.size(0), 1)), 1)
344 | all_fea = (all_fea.t() / torch.norm(all_fea, p=2, dim=1)).t()
345 |
346 | all_fea = all_fea.float().cpu().numpy()
347 | K = all_output.size(1)
348 | aff = all_output.float().cpu().numpy()
349 | initc = aff.transpose().dot(all_fea)
350 | initc = initc / (1e-8 + aff.sum(axis=0)[:,None])
351 |
352 | dd = cdist(all_fea, initc, args.distance)
353 | dd[np.isnan(dd)] = np.inf
354 | pred_label = dd.argmin(axis=1)
355 |
356 |     for _ in range(1):  # one refinement round
357 | aff = np.eye(K)[pred_label]
358 | initc = aff.transpose().dot(all_fea)
359 | initc = initc / (1e-8 + aff.sum(axis=0)[:,None])
360 | dd = cdist(all_fea, initc, args.distance)
361 | dd[np.isnan(dd)] = np.inf
362 | pred_label = dd.argmin(axis=1)
363 |
364 | acc = np.sum(pred_label == all_label.float().numpy()) / len(all_fea)
365 | matrix = confusion_matrix(all_label.float().numpy(), pred_label)
366 | avg_acc = (matrix.diagonal() / matrix.sum(axis=1)).mean()
367 | if args.da == 'pda':
368 | acc_list = ''
369 | avg_acc = 0
370 | log_str = 'Accuracy = {:.2f}% -> {:.2f}% Per_class_accuracy = {:.2f}% -> {:.2f}%'.format(accuracy * 100, acc * 100, avg_accuracy * 100, avg_acc * 100)
371 | logging.info(log_str)
372 |
373 |     # update labels with prior knowledge: temperature-softmax the (negative) centroid distances into soft assignments, then solve the constrained labeling
374 | T = args.pk_temp
375 | probs = np.exp(-dd / T)
376 | probs = probs / probs.sum(1, keepdims=True)
377 | # first solve without smooth regularization
378 | pred_label_PK = pk_solver.solve_soft(probs)
379 |
380 | acc_PK = np.sum(pred_label_PK == all_label.float().numpy()) / len(all_fea)
381 | matrix_PK = confusion_matrix(all_label.float().numpy(), pred_label_PK)
382 | avg_acc_PK = (matrix_PK.diagonal() / matrix_PK.sum(axis=1)).mean()
383 | if args.da == 'pda':
384 | avg_acc_PK = 0
385 | log_str = 'PK Accuracy = {:.2f}% -> {:.2f}% Per_class_accuracy = {:.2f}% -> {:.2f}%'.format(acc * 100, acc_PK * 100, avg_acc * 100, avg_acc_PK * 100)
386 | logging.info(log_str)
387 |
388 |     # now re-solve with a kNN smoothness constraint: samples whose PK label disagrees with the cluster label are tied to their nearest confidently-labeled neighbors
389 | if args.pk_knn > 0:
390 | idx_unconf = np.where(pred_label_PK != pred_label)[0]
391 | knn_sample_idx = idx_unconf
392 | idx_conf = np.where(pred_label_PK == pred_label)[0]
393 |
394 | if len(idx_unconf) > 0 and len(idx_conf) > 0:
395 | # get knn of each samples
396 | dd_knn = cdist(all_fea[idx_unconf], all_fea[idx_conf], args.distance)
397 | knn_idx = []
398 | K = args.pk_knn
399 | for i in range(dd_knn.shape[0]):
400 | ind = np.argpartition(dd_knn[i], K)[:K]
401 | knn_idx.append(idx_conf[ind])
402 |
403 | knn_idx = np.stack(knn_idx, axis=0)
404 | knn_regs = list(zip(knn_sample_idx, knn_idx))
405 | pred_label_PK = pk_solver.solve_soft_knn_cst(probs, knn_regs=knn_regs)
406 |
407 |
408 | acc_PK = np.sum(pred_label_PK == all_label.float().numpy()) / len(all_fea)
409 | matrix_PK = confusion_matrix(all_label.float().numpy(), pred_label_PK)
410 | avg_acc_PK = (matrix_PK.diagonal() / matrix_PK.sum(axis=1)).mean()
411 | if args.da == 'pda':
412 | avg_acc_PK = 0
413 | log_str = 'PK Accuracy = {:.2f}% -> {:.2f}% Per_class_accuracy = {:.2f}% -> {:.2f}%'.format(acc * 100, acc_PK * 100, avg_acc * 100, avg_acc_PK * 100)
414 | logging.info(log_str)
415 |
416 | return pred_label_PK.astype('int')
417 |
418 |
419 | if __name__ == "__main__":
420 | parser = argparse.ArgumentParser(description='SHOT')
421 | parser.add_argument('--gpu_id', type=str, nargs='?', default='0', help="device id to run")
422 | parser.add_argument('--s', type=str, default=None, help="source")
423 | parser.add_argument('--t', type=str, default=None, help="target")
424 |     parser.add_argument('--max_epoch', type=int, default=15, help="max epochs")
425 | parser.add_argument('--interval', type=int, default=15)
426 | parser.add_argument('--batch_size', type=int, default=64, help="batch_size")
427 | parser.add_argument('--worker', type=int, default=4, help="number of workers")
428 | parser.add_argument('--dset', type=str, default='office-home', choices=['visda-2017', 'domainnet40', 'office31', 'office-home', 'office-home-rsut', 'office-caltech'])
429 | parser.add_argument('--lr', type=float, default=1e-2, help="learning rate")
430 | parser.add_argument('--net', type=str, default='resnet50', help="alexnet, vgg16, resnet50, res101")
431 | parser.add_argument('--seed', type=int, default=2020, help="random seed")
432 |
433 |     parser.add_argument('--gent', type=lambda x: bool(distutils.util.strtobool(x)), default=True)
434 |     parser.add_argument('--ent', type=lambda x: bool(distutils.util.strtobool(x)), default=True)
435 | parser.add_argument('--threshold', type=int, default=0)
436 | parser.add_argument('--cls_par', type=float, default=0.3)
437 | parser.add_argument('--ent_par', type=float, default=1.0)
438 | parser.add_argument('--lr_decay1', type=float, default=0.1)
439 | parser.add_argument('--lr_decay2', type=float, default=1.0)
440 |
441 | parser.add_argument('--bottleneck', type=int, default=256)
442 | parser.add_argument('--epsilon', type=float, default=1e-5)
443 | parser.add_argument('--layer', type=str, default="wn", choices=["linear", "wn"])
444 | parser.add_argument('--classifier', type=str, default="bn", choices=["ori", "bn"])
445 | parser.add_argument('--distance', type=str, default='cosine', choices=["euclidean", "cosine"])
446 | parser.add_argument('--output', type=str, default='san')
447 | parser.add_argument('--output_src', type=str, default='san')
448 | parser.add_argument('--da', type=str, default='uda', choices=['uda', 'pda'])
449 |     parser.add_argument('--issave', type=lambda x: bool(distutils.util.strtobool(x)), default=True)
450 |
451 | parser.add_argument('--timestamp', default=timestamp, type=str, help='timestamp')
452 | parser.add_argument('--use_file_logger', default='True', type=lambda x: bool(distutils.util.strtobool(x)),
453 | help='whether use file logger')
454 | parser.add_argument('--names', default=[], type=list, help='names of tasks')
455 | parser.add_argument('--use_train_transform', default='False', type=lambda x: bool(distutils.util.strtobool(x)),
456 | help='whether use train transform for label refinement')
457 |
458 | parser.add_argument('--pk_uconf', type=float, default=0.0)
459 | parser.add_argument('--pk_type', type=str, default="ub")
460 | parser.add_argument('--pk_allow', type=int, default=None)
461 | parser.add_argument('--pk_temp', type=float, default=1.0)
462 | parser.add_argument('--pk_prior_weight', type=float, default=10.)
463 | parser.add_argument('--pk_knn', type=int, default=1)
464 | parser.add_argument('--pk_NC', type=int, default=None)
465 | parser.add_argument('--pk_noise', type=float, default=0.0)
466 | parser.add_argument('--pk_dratio', type=float, default=1.0)
467 | parser.add_argument('--method', type=str, default="kshot")
468 |
469 | args = parser.parse_args()
470 |
471 | if args.dset == 'office-home':
472 | args.names = ['Art', 'Clipart', 'Product', 'Real_World']
473 | args.class_num = 65
474 | if args.dset == 'office-home-rsut':
475 | args.names = ['Clipart', 'Product', 'Real_World']
476 | args.class_num = 65
477 | if args.dset == 'domainnet40':
478 | args.names = ['sketch', 'clipart', 'painting', 'real']
479 | args.class_num = 40
480 | if args.dset == 'office31':
481 | args.names = ['webcam', 'amazon', 'dslr']
482 | args.class_num = 31
483 | if args.dset == 'visda-2017':
484 | args.names = ['train', 'validation']
485 | args.class_num = 12
486 | if args.dset == 'office-caltech':
487 | args.names = ['amazon', 'caltech', 'dslr', 'webcam']
488 | args.class_num = 10
489 |
490 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
491 | resetRNGseed(args.seed)
492 |
493 | if args.dset == 'office-home-rsut':
494 | args.s += '_RS'
495 |
496 |     log_dir = "{}_{}_{}_{}".format(args.timestamp, args.s, args.da, args.method)
497 |     if args.use_file_logger:
498 |         init_logger(log_dir, True, '../logs/SHOT/{}/'.format(args.method))
499 | logging.info("{}:{}".format(get_hostname(), get_pid()))
500 |
501 | for t in args.names:
502 | if t == args.s or t == args.s.split('_RS')[0]:
503 | continue
504 | args.t = t
505 |
506 | if args.dset == 'office-home-rsut':
507 | args.t += '_UT'
508 |
509 | folder = '../data/'
510 | args.s_dset_path = folder + args.dset + '/image_list/' + args.s + '.txt'
511 | args.t_dset_path = folder + args.dset + '/image_list/' + args.t + '.txt'
512 | args.test_dset_path = folder + args.dset + '/image_list/' + args.t + '.txt'
513 |
514 | if args.dset == 'domainnet40':
515 | args.s_dset_path = folder + args.dset + '/image_list/' + args.s + '_train_mini.txt'
516 | args.t_dset_path = folder + args.dset + '/image_list/' + args.t + '_train_mini.txt'
517 | args.test_dset_path = folder + args.dset + '/image_list/' + args.t + '_test_mini.txt'
518 |
519 | if args.dset == 'office-home':
520 | if args.da == 'pda':
521 | args.class_num = 65
522 | args.src_classes = [i for i in range(65)]
523 | args.tar_classes = [33, 32, 36, 15, 19, 2, 46, 49, 48, 53, 47, 54, 4, 18, 57, 23, 0, 45, 1, 38, 5, 13, 50, 11, 58]
524 |
525 | args.output_dir_src = "../checkpoints/SHOT/source/{}/".format(args.da)
526 | args.output_dir = "../checkpoints/SHOT/target_{}/".format(args.method)
527 |
528 |         os.makedirs(args.output_dir, exist_ok=True)
532 |
533 | args.savename = 'par_' + str(args.cls_par)
534 | if args.da == 'pda':
535 | args.gent = ''
536 | args.savename = 'par_' + str(args.cls_par) + '_thr' + str(args.threshold)
537 |
538 | logging.info(print_args(args))
539 | train_target(args)
--------------------------------------------------------------------------------
/SHOT/loss.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | from torch.autograd import Variable
5 | import math
6 | import torch.nn.functional as F
7 | import pdb
8 |
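# Differentiable prior-knowledge penalty: a hinge loss for each violated constraint.
# C_abs entries (c, lb, ub) bound class c's expected count within [lb, ub] (compared
# as probabilities via /N); C_rel entries (c1, c2, diff) require prob[c1] - prob[c2] >= diff.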
9 | def PK_loss(PK_solver, prob):
10 | pk_loss = 0.0
11 | N = PK_solver.N
12 | if PK_solver.C_abs is not None and len(PK_solver.C_abs)>0:
13 | for (c, lb, ub) in PK_solver.C_abs:
14 | if lb is not None:
15 | pk_loss += torch.maximum(lb/N-prob[c], torch.tensor(0.))
16 | if ub is not None:
17 | pk_loss += torch.maximum(-ub/N+prob[c], torch.tensor(0.))
18 |
19 | if PK_solver.C_rel is not None and len(PK_solver.C_rel)>0:
20 | for (c1, c2, diff) in PK_solver.C_rel:
21 | pk_loss += torch.maximum(diff-prob[c1]+prob[c2], torch.tensor(0.))
22 |
23 | return pk_loss
24 |
25 | def Entropy(input_):
26 | bs = input_.size(0)
27 | epsilon = 1e-5
28 | entropy = -input_ * torch.log(input_ + epsilon)
29 | entropy = torch.sum(entropy, dim=1)
30 | return entropy
31 |
32 | def grl_hook(coeff):
33 | def fun1(grad):
34 | return -coeff*grad.clone()
35 | return fun1
36 |
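# CDAN (Long et al., NeurIPS 2018): adversarial domain loss on the outer product of
# features and detached softmax predictions (or their random-layer projection); when
# `entropy` is given, samples are re-weighted by 1 + exp(-H(p)) so confident
# predictions contribute more, with source/target halves normalized separately.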
37 | def CDAN(input_list, ad_net, entropy=None, coeff=None, random_layer=None):
38 | softmax_output = input_list[1].detach()
39 | feature = input_list[0]
40 | if random_layer is None:
41 | op_out = torch.bmm(softmax_output.unsqueeze(2), feature.unsqueeze(1))
42 | ad_out = ad_net(op_out.view(-1, softmax_output.size(1) * feature.size(1)))
43 | else:
44 | random_out = random_layer.forward([feature, softmax_output])
45 | ad_out = ad_net(random_out.view(-1, random_out.size(1)))
46 | batch_size = softmax_output.size(0) // 2
47 | dc_target = torch.from_numpy(np.array([[1]] * batch_size + [[0]] * batch_size)).float().cuda()
48 | if entropy is not None:
49 | entropy.register_hook(grl_hook(coeff))
50 | entropy = 1.0+torch.exp(-entropy)
51 | source_mask = torch.ones_like(entropy)
52 | source_mask[feature.size(0)//2:] = 0
53 | source_weight = entropy*source_mask
54 | target_mask = torch.ones_like(entropy)
55 | target_mask[0:feature.size(0)//2] = 0
56 | target_weight = entropy*target_mask
57 | weight = source_weight / torch.sum(source_weight).detach().item() + \
58 | target_weight / torch.sum(target_weight).detach().item()
59 | return torch.sum(weight.view(-1, 1) * nn.BCELoss(reduction='none')(ad_out, dc_target)) / torch.sum(weight).detach().item()
60 | else:
61 | return nn.BCELoss()(ad_out, dc_target)
62 |
63 | def DANN(features, ad_net):
64 | ad_out = ad_net(features)
65 | batch_size = ad_out.size(0) // 2
66 | dc_target = torch.from_numpy(np.array([[1]] * batch_size + [[0]] * batch_size)).float().cuda()
67 | return nn.BCELoss()(ad_out, dc_target)
68 |
69 |
70 | class CrossEntropyLabelSmooth(nn.Module):
71 | """Cross entropy loss with label smoothing regularizer.
72 | Reference:
73 | Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016.
74 | Equation: y = (1 - epsilon) * y + epsilon / K.
75 | Args:
76 | num_classes (int): number of classes.
77 | epsilon (float): weight.
78 | """
79 |
80 | def __init__(self, num_classes, epsilon=0.1, use_gpu=True, reduction=True):
81 | super(CrossEntropyLabelSmooth, self).__init__()
82 | self.num_classes = num_classes
83 | self.epsilon = epsilon
84 | self.use_gpu = use_gpu
85 | self.reduction = reduction
86 | self.logsoftmax = nn.LogSoftmax(dim=1)
87 |
88 | def forward(self, inputs, targets):
89 | """
90 | Args:
91 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)
92 |             targets: ground truth labels with shape (batch_size)
93 | """
94 | log_probs = self.logsoftmax(inputs)
95 | targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).cpu(), 1)
96 | if self.use_gpu: targets = targets.cuda()
97 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
98 | loss = (- targets * log_probs).sum(dim=1)
99 | if self.reduction:
100 | return loss.mean()
101 | else:
102 | return loss
--------------------------------------------------------------------------------
/SHOT/network.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torchvision
5 | from torchvision import models
6 | from torch.autograd import Variable
7 | import math
8 | import torch.nn.utils.weight_norm as weightNorm
9 | from collections import OrderedDict
10 |
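# Gradient-reversal ramp-up coefficient used in DANN/CDAN-style training: increases
# smoothly from `low` to `high` as iter_num approaches max_iter.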
11 | def calc_coeff(iter_num, high=1.0, low=0.0, alpha=10.0, max_iter=10000.0):
12 |     return float(2.0 * (high - low) / (1.0 + np.exp(-alpha * iter_num / max_iter)) - (high - low) + low)
13 |
14 | def init_weights(m):
15 | classname = m.__class__.__name__
16 | if classname.find('Conv2d') != -1 or classname.find('ConvTranspose2d') != -1:
17 | nn.init.kaiming_uniform_(m.weight)
18 | nn.init.zeros_(m.bias)
19 | elif classname.find('BatchNorm') != -1:
20 | nn.init.normal_(m.weight, 1.0, 0.02)
21 | nn.init.zeros_(m.bias)
22 | elif classname.find('Linear') != -1:
23 | nn.init.xavier_normal_(m.weight)
24 | nn.init.zeros_(m.bias)
25 |
26 | vgg_dict = {"vgg11":models.vgg11, "vgg13":models.vgg13, "vgg16":models.vgg16, "vgg19":models.vgg19,
27 | "vgg11bn":models.vgg11_bn, "vgg13bn":models.vgg13_bn, "vgg16bn":models.vgg16_bn, "vgg19bn":models.vgg19_bn}
28 | class VGGBase(nn.Module):
29 | def __init__(self, vgg_name):
30 | super(VGGBase, self).__init__()
31 | model_vgg = vgg_dict[vgg_name](pretrained=True)
32 | self.features = model_vgg.features
33 | self.classifier = nn.Sequential()
34 | for i in range(6):
35 | self.classifier.add_module("classifier"+str(i), model_vgg.classifier[i])
36 | self.in_features = model_vgg.classifier[6].in_features
37 |
38 | def forward(self, x):
39 | x = self.features(x)
40 | x = x.view(x.size(0), -1)
41 | x = self.classifier(x)
42 | return x
43 |
44 | res_dict = {"resnet18":models.resnet18, "resnet34":models.resnet34, "resnet50":models.resnet50,
45 | "resnet101":models.resnet101, "resnet152":models.resnet152, "resnext50":models.resnext50_32x4d, "resnext101":models.resnext101_32x8d}
46 |
47 | class ResBase(nn.Module):
48 | def __init__(self, res_name):
49 | super(ResBase, self).__init__()
50 | model_resnet = res_dict[res_name](pretrained=True)
51 | self.conv1 = model_resnet.conv1
52 | self.bn1 = model_resnet.bn1
53 | self.relu = model_resnet.relu
54 | self.maxpool = model_resnet.maxpool
55 | self.layer1 = model_resnet.layer1
56 | self.layer2 = model_resnet.layer2
57 | self.layer3 = model_resnet.layer3
58 | self.layer4 = model_resnet.layer4
59 | self.avgpool = model_resnet.avgpool
60 | self.in_features = model_resnet.fc.in_features
61 |
62 | def forward(self, x):
63 | x = self.conv1(x)
64 | x = self.bn1(x)
65 | x = self.relu(x)
66 | x = self.maxpool(x)
67 | x = self.layer1(x)
68 | x = self.layer2(x)
69 | x = self.layer3(x)
70 | x = self.layer4(x)
71 | x = self.avgpool(x)
72 | x = x.view(x.size(0), -1)
73 | return x
74 |
75 | class feat_bootleneck(nn.Module):
76 | def __init__(self, feature_dim, bottleneck_dim=256, type="ori"):
77 | super(feat_bootleneck, self).__init__()
78 | self.bn = nn.BatchNorm1d(bottleneck_dim, affine=True)
79 | self.relu = nn.ReLU(inplace=True)
80 | self.dropout = nn.Dropout(p=0.5)
81 | self.bottleneck = nn.Linear(feature_dim, bottleneck_dim)
82 | self.bottleneck.apply(init_weights)
83 | self.type = type
84 |
85 | def forward(self, x):
86 | x = self.bottleneck(x)
87 | if self.type == "bn":
88 | x = self.bn(x)
89 | return x
90 |
91 | class feat_classifier(nn.Module):
92 | def __init__(self, class_num, bottleneck_dim=256, type="linear"):
93 | super(feat_classifier, self).__init__()
94 | self.type = type
95 | if type == 'wn':
96 | self.fc = weightNorm(nn.Linear(bottleneck_dim, class_num), name="weight")
97 | self.fc.apply(init_weights)
98 | else:
99 | self.fc = nn.Linear(bottleneck_dim, class_num)
100 | self.fc.apply(init_weights)
101 |
102 | def forward(self, x):
103 | x = self.fc(x)
104 | return x
105 |
106 | class feat_classifier_two(nn.Module):
107 | def __init__(self, class_num, input_dim, bottleneck_dim=256):
108 | super(feat_classifier_two, self).__init__()
110 | self.fc0 = nn.Linear(input_dim, bottleneck_dim)
111 | self.fc0.apply(init_weights)
112 | self.fc1 = nn.Linear(bottleneck_dim, class_num)
113 | self.fc1.apply(init_weights)
114 |
115 | def forward(self, x):
116 | x = self.fc0(x)
117 | x = self.fc1(x)
118 | return x
119 |
120 | class Res50(nn.Module):
121 | def __init__(self):
122 | super(Res50, self).__init__()
123 | model_resnet = models.resnet50(pretrained=True)
124 | self.conv1 = model_resnet.conv1
125 | self.bn1 = model_resnet.bn1
126 | self.relu = model_resnet.relu
127 | self.maxpool = model_resnet.maxpool
128 | self.layer1 = model_resnet.layer1
129 | self.layer2 = model_resnet.layer2
130 | self.layer3 = model_resnet.layer3
131 | self.layer4 = model_resnet.layer4
132 | self.avgpool = model_resnet.avgpool
133 | self.in_features = model_resnet.fc.in_features
134 | self.fc = model_resnet.fc
135 |
136 | def forward(self, x):
137 | x = self.conv1(x)
138 | x = self.bn1(x)
139 | x = self.relu(x)
140 | x = self.maxpool(x)
141 | x = self.layer1(x)
142 | x = self.layer2(x)
143 | x = self.layer3(x)
144 | x = self.layer4(x)
145 | x = self.avgpool(x)
146 | x = x.view(x.size(0), -1)
147 | y = self.fc(x)
148 | return x, y
--------------------------------------------------------------------------------
/SHOT/run_all_kSHOT.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | gpu_id=0
4 | time=$(python ../util/get_time.py)
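# Pipeline per benchmark: train a source model for each source domain, then run kSHOT
# target adaptation for three seeds, sweeping pk_uconf for the unary-bound ('ub') prior
# plus one run with the binary-relationship ('br') prior.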
5 |
6 | # office31 -------------------------------------------------------------------------------------------------------------
7 | for src in "amazon" "webcam" "dslr"; do
8 | echo $src
9 | python image_source.py --trte val --da uda --gpu_id $gpu_id --dset office31 --s $src --max_epoch 100 --timestamp $time
10 | done
11 |
12 | for seed in 2020 2021 2022; do
13 | for src in "amazon" "webcam" "dslr"; do
14 | echo $src
15 | for pk_uconf in 0.0 0.1 0.5 1.0 2.0; do
16 | python image_target_kSHOT.py --cls_par 0.3 --da uda --gpu_id $gpu_id --dset office31 --s $src --timestamp $time --seed $seed --pk_uconf $pk_uconf --pk_type ub
17 | done
18 | python image_target_kSHOT.py --cls_par 0.3 --da uda --gpu_id $gpu_id --dset office31 --s $src --timestamp $time --seed $seed --pk_uconf 1.0 --pk_type br
19 | done
20 | done
21 |
22 |
23 | # office-home-rsut ----------------------------------------------------------------------------------------------------
24 | for src in "Product" "Clipart" "Real_World"; do
25 | echo $src
26 | python image_source.py --trte val --da uda --gpu_id $gpu_id --dset office-home-rsut --s $src --max_epoch 50 --timestamp $time
27 | done
28 |
29 | for seed in 2020 2021 2022; do
30 | for src in "Product" "Clipart" "Real_World"; do
31 | echo $src
32 | for pk_uconf in 0.0 0.1 0.5 1.0 2.0; do
33 | python image_target_kSHOT.py --cls_par 0.3 --da uda --gpu_id $gpu_id --dset office-home-rsut --s $src --timestamp $time --seed $seed --pk_uconf $pk_uconf --pk_type ub
34 | done
35 | python image_target_kSHOT.py --cls_par 0.3 --da uda --gpu_id $gpu_id --dset office-home-rsut --s $src --timestamp $time --seed $seed --pk_uconf 1.0 --pk_type br
36 | done
37 | done
38 |
39 |
40 | # office-home ----------------------------------------------------------------------------------------------------------
41 | for src in "Product" "Clipart" "Art" "Real_World"; do
42 | echo $src
43 | python image_source.py --trte val --da uda --gpu_id $gpu_id --dset office-home --s $src --max_epoch 50 --timestamp $time
44 | done
45 |
46 | for seed in 2020 2021 2022; do
47 | for src in "Product" "Clipart" "Art" "Real_World"; do
48 | echo $src
49 | for pk_uconf in 0.0 0.1 0.5 1.0 2.0; do
50 | python image_target_kSHOT.py --cls_par 0.3 --da uda --gpu_id $gpu_id --dset office-home --s $src --timestamp $time --seed $seed --pk_uconf $pk_uconf --pk_type ub
51 | done
52 | python image_target_kSHOT.py --cls_par 0.3 --da uda --gpu_id $gpu_id --dset office-home --s $src --timestamp $time --seed $seed --pk_uconf 1.0 --pk_type br
53 | done
54 | done
55 |
56 |
57 | # visda-2017 -----------------------------------------------------------------------------------------------------------
58 | python image_source.py --trte val --da uda --gpu_id $gpu_id --dset visda-2017 --s train --max_epoch 10 --timestamp $time --net resnet101 --lr 1e-3
59 |
60 | for seed in 2020 2021 2022; do
61 | for pk_uconf in 0.0 0.1 0.5 1.0 2.0; do
62 | python image_target_kSHOT.py --cls_par 0.3 --da uda --gpu_id $gpu_id --dset visda-2017 --s train --timestamp $time --seed $seed --pk_uconf $pk_uconf --net resnet101 --lr 1e-3 --pk_type ub
63 | done
64 | python image_target_kSHOT.py --cls_par 0.3 --da uda --gpu_id $gpu_id --dset visda-2017 --s train --timestamp $time --seed $seed --pk_uconf 1.0 --net resnet101 --lr 1e-3 --pk_type br
65 | done
66 |
67 |
68 | # domainnet40 ----------------------------------------------------------------------------------------------------------
69 | for src in "sketch" "clipart" "painting" "real"; do
70 | echo $src
71 | python image_source.py --trte val --da uda --gpu_id $gpu_id --dset domainnet40 --s $src --max_epoch 50 --timestamp $time
72 | done
73 |
74 | for seed in 2020 2021 2022; do
75 | for src in "sketch" "clipart" "painting" "real"; do
76 | echo $src
77 | for pk_uconf in 0.0 0.1 0.5 1.0 2.0; do
78 | python image_target_kSHOT.py --cls_par 0.3 --da uda --gpu_id $gpu_id --dset domainnet40 --s $src --timestamp $time --seed $seed --pk_uconf $pk_uconf --pk_type ub
79 | done
80 | python image_target_kSHOT.py --cls_par 0.3 --da uda --gpu_id $gpu_id --dset domainnet40 --s $src --timestamp $time --seed $seed --pk_uconf 1.0 --pk_type br
81 | done
82 | done
83 |
84 |
85 | # office-home (PDA)-----------------------------------------------------------------------------------------------------
86 | for src in "Product" "Clipart" "Art" "Real_World"; do
87 | echo $src
88 | python image_source.py --trte val --da pda --gpu_id $gpu_id --dset office-home --s $src --max_epoch 50 --timestamp $time
89 | done
90 |
91 | for seed in 2020 2021 2022; do
92 | for src in "Product" "Clipart" "Art" "Real_World"; do
93 | echo $src
94 | python image_target_kSHOT.py --cls_par 0.3 --da pda --gpu_id $gpu_id --dset office-home --s $src --timestamp $time --seed $seed --pk_uconf 0.0 --pk_type ub
95 | python image_target_kSHOT.py --cls_par 0.3 --da pda --gpu_id $gpu_id --dset office-home --s $src --timestamp $time --seed $seed --pk_uconf 1.0 --pk_type br
96 | done
97 | done
--------------------------------------------------------------------------------
/data/office31/image_list/dslr.txt:
--------------------------------------------------------------------------------
1 | dslr/images/calculator/frame_0001.jpg 5
2 | dslr/images/calculator/frame_0002.jpg 5
3 | dslr/images/calculator/frame_0003.jpg 5
4 | dslr/images/calculator/frame_0004.jpg 5
5 | dslr/images/calculator/frame_0005.jpg 5
6 | dslr/images/calculator/frame_0006.jpg 5
7 | dslr/images/calculator/frame_0007.jpg 5
8 | dslr/images/calculator/frame_0008.jpg 5
9 | dslr/images/calculator/frame_0009.jpg 5
10 | dslr/images/calculator/frame_0010.jpg 5
11 | dslr/images/calculator/frame_0011.jpg 5
12 | dslr/images/calculator/frame_0012.jpg 5
13 | dslr/images/ring_binder/frame_0001.jpg 24
14 | dslr/images/ring_binder/frame_0002.jpg 24
15 | dslr/images/ring_binder/frame_0003.jpg 24
16 | dslr/images/ring_binder/frame_0004.jpg 24
17 | dslr/images/ring_binder/frame_0005.jpg 24
18 | dslr/images/ring_binder/frame_0006.jpg 24
19 | dslr/images/ring_binder/frame_0007.jpg 24
20 | dslr/images/ring_binder/frame_0008.jpg 24
21 | dslr/images/ring_binder/frame_0009.jpg 24
22 | dslr/images/ring_binder/frame_0010.jpg 24
23 | dslr/images/printer/frame_0001.jpg 21
24 | dslr/images/printer/frame_0002.jpg 21
25 | dslr/images/printer/frame_0003.jpg 21
26 | dslr/images/printer/frame_0004.jpg 21
27 | dslr/images/printer/frame_0005.jpg 21
28 | dslr/images/printer/frame_0006.jpg 21
29 | dslr/images/printer/frame_0007.jpg 21
30 | dslr/images/printer/frame_0008.jpg 21
31 | dslr/images/printer/frame_0009.jpg 21
32 | dslr/images/printer/frame_0010.jpg 21
33 | dslr/images/printer/frame_0011.jpg 21
34 | dslr/images/printer/frame_0012.jpg 21
35 | dslr/images/printer/frame_0013.jpg 21
36 | dslr/images/printer/frame_0014.jpg 21
37 | dslr/images/printer/frame_0015.jpg 21
38 | dslr/images/keyboard/frame_0001.jpg 11
39 | dslr/images/keyboard/frame_0002.jpg 11
40 | dslr/images/keyboard/frame_0003.jpg 11
41 | dslr/images/keyboard/frame_0004.jpg 11
42 | dslr/images/keyboard/frame_0005.jpg 11
43 | dslr/images/keyboard/frame_0006.jpg 11
44 | dslr/images/keyboard/frame_0007.jpg 11
45 | dslr/images/keyboard/frame_0008.jpg 11
46 | dslr/images/keyboard/frame_0009.jpg 11
47 | dslr/images/keyboard/frame_0010.jpg 11
48 | dslr/images/scissors/frame_0001.jpg 26
49 | dslr/images/scissors/frame_0002.jpg 26
50 | dslr/images/scissors/frame_0003.jpg 26
51 | dslr/images/scissors/frame_0004.jpg 26
52 | dslr/images/scissors/frame_0005.jpg 26
53 | dslr/images/scissors/frame_0006.jpg 26
54 | dslr/images/scissors/frame_0007.jpg 26
55 | dslr/images/scissors/frame_0008.jpg 26
56 | dslr/images/scissors/frame_0009.jpg 26
57 | dslr/images/scissors/frame_0010.jpg 26
58 | dslr/images/scissors/frame_0011.jpg 26
59 | dslr/images/scissors/frame_0012.jpg 26
60 | dslr/images/scissors/frame_0013.jpg 26
61 | dslr/images/scissors/frame_0014.jpg 26
62 | dslr/images/scissors/frame_0015.jpg 26
63 | dslr/images/scissors/frame_0016.jpg 26
64 | dslr/images/scissors/frame_0017.jpg 26
65 | dslr/images/scissors/frame_0018.jpg 26
66 | dslr/images/laptop_computer/frame_0001.jpg 12
67 | dslr/images/laptop_computer/frame_0002.jpg 12
68 | dslr/images/laptop_computer/frame_0003.jpg 12
69 | dslr/images/laptop_computer/frame_0004.jpg 12
70 | dslr/images/laptop_computer/frame_0005.jpg 12
71 | dslr/images/laptop_computer/frame_0006.jpg 12
72 | dslr/images/laptop_computer/frame_0007.jpg 12
73 | dslr/images/laptop_computer/frame_0008.jpg 12
74 | dslr/images/laptop_computer/frame_0009.jpg 12
75 | dslr/images/laptop_computer/frame_0010.jpg 12
76 | dslr/images/laptop_computer/frame_0011.jpg 12
77 | dslr/images/laptop_computer/frame_0012.jpg 12
78 | dslr/images/laptop_computer/frame_0013.jpg 12
79 | dslr/images/laptop_computer/frame_0014.jpg 12
80 | dslr/images/laptop_computer/frame_0015.jpg 12
81 | dslr/images/laptop_computer/frame_0016.jpg 12
82 | dslr/images/laptop_computer/frame_0017.jpg 12
83 | dslr/images/laptop_computer/frame_0018.jpg 12
84 | dslr/images/laptop_computer/frame_0019.jpg 12
85 | dslr/images/laptop_computer/frame_0020.jpg 12
86 | dslr/images/laptop_computer/frame_0021.jpg 12
87 | dslr/images/laptop_computer/frame_0022.jpg 12
88 | dslr/images/laptop_computer/frame_0023.jpg 12
89 | dslr/images/laptop_computer/frame_0024.jpg 12
90 | dslr/images/mouse/frame_0001.jpg 16
91 | dslr/images/mouse/frame_0002.jpg 16
92 | dslr/images/mouse/frame_0003.jpg 16
93 | dslr/images/mouse/frame_0004.jpg 16
94 | dslr/images/mouse/frame_0005.jpg 16
95 | dslr/images/mouse/frame_0006.jpg 16
96 | dslr/images/mouse/frame_0007.jpg 16
97 | dslr/images/mouse/frame_0008.jpg 16
98 | dslr/images/mouse/frame_0009.jpg 16
99 | dslr/images/mouse/frame_0010.jpg 16
100 | dslr/images/mouse/frame_0011.jpg 16
101 | dslr/images/mouse/frame_0012.jpg 16
102 | dslr/images/monitor/frame_0001.jpg 15
103 | dslr/images/monitor/frame_0002.jpg 15
104 | dslr/images/monitor/frame_0003.jpg 15
105 | dslr/images/monitor/frame_0004.jpg 15
106 | dslr/images/monitor/frame_0005.jpg 15
107 | dslr/images/monitor/frame_0006.jpg 15
108 | dslr/images/monitor/frame_0007.jpg 15
109 | dslr/images/monitor/frame_0008.jpg 15
110 | dslr/images/monitor/frame_0009.jpg 15
111 | dslr/images/monitor/frame_0010.jpg 15
112 | dslr/images/monitor/frame_0011.jpg 15
113 | dslr/images/monitor/frame_0012.jpg 15
114 | dslr/images/monitor/frame_0013.jpg 15
115 | dslr/images/monitor/frame_0014.jpg 15
116 | dslr/images/monitor/frame_0015.jpg 15
117 | dslr/images/monitor/frame_0016.jpg 15
118 | dslr/images/monitor/frame_0017.jpg 15
119 | dslr/images/monitor/frame_0018.jpg 15
120 | dslr/images/monitor/frame_0019.jpg 15
121 | dslr/images/monitor/frame_0020.jpg 15
122 | dslr/images/monitor/frame_0021.jpg 15
123 | dslr/images/monitor/frame_0022.jpg 15
124 | dslr/images/mug/frame_0001.jpg 17
125 | dslr/images/mug/frame_0002.jpg 17
126 | dslr/images/mug/frame_0003.jpg 17
127 | dslr/images/mug/frame_0004.jpg 17
128 | dslr/images/mug/frame_0005.jpg 17
129 | dslr/images/mug/frame_0006.jpg 17
130 | dslr/images/mug/frame_0007.jpg 17
131 | dslr/images/mug/frame_0008.jpg 17
132 | dslr/images/tape_dispenser/frame_0001.jpg 29
133 | dslr/images/tape_dispenser/frame_0002.jpg 29
134 | dslr/images/tape_dispenser/frame_0003.jpg 29
135 | dslr/images/tape_dispenser/frame_0004.jpg 29
136 | dslr/images/tape_dispenser/frame_0005.jpg 29
137 | dslr/images/tape_dispenser/frame_0006.jpg 29
138 | dslr/images/tape_dispenser/frame_0007.jpg 29
139 | dslr/images/tape_dispenser/frame_0008.jpg 29
140 | dslr/images/tape_dispenser/frame_0009.jpg 29
141 | dslr/images/tape_dispenser/frame_0010.jpg 29
142 | dslr/images/tape_dispenser/frame_0011.jpg 29
143 | dslr/images/tape_dispenser/frame_0012.jpg 29
144 | dslr/images/tape_dispenser/frame_0013.jpg 29
145 | dslr/images/tape_dispenser/frame_0014.jpg 29
146 | dslr/images/tape_dispenser/frame_0015.jpg 29
147 | dslr/images/tape_dispenser/frame_0016.jpg 29
148 | dslr/images/tape_dispenser/frame_0017.jpg 29
149 | dslr/images/tape_dispenser/frame_0018.jpg 29
150 | dslr/images/tape_dispenser/frame_0019.jpg 29
151 | dslr/images/tape_dispenser/frame_0020.jpg 29
152 | dslr/images/tape_dispenser/frame_0021.jpg 29
153 | dslr/images/tape_dispenser/frame_0022.jpg 29
154 | dslr/images/pen/frame_0001.jpg 19
155 | dslr/images/pen/frame_0002.jpg 19
156 | dslr/images/pen/frame_0003.jpg 19
157 | dslr/images/pen/frame_0004.jpg 19
158 | dslr/images/pen/frame_0005.jpg 19
159 | dslr/images/pen/frame_0006.jpg 19
160 | dslr/images/pen/frame_0007.jpg 19
161 | dslr/images/pen/frame_0008.jpg 19
162 | dslr/images/pen/frame_0009.jpg 19
163 | dslr/images/pen/frame_0010.jpg 19
164 | dslr/images/bike/frame_0001.jpg 1
165 | dslr/images/bike/frame_0002.jpg 1
166 | dslr/images/bike/frame_0003.jpg 1
167 | dslr/images/bike/frame_0004.jpg 1
168 | dslr/images/bike/frame_0005.jpg 1
169 | dslr/images/bike/frame_0006.jpg 1
170 | dslr/images/bike/frame_0007.jpg 1
171 | dslr/images/bike/frame_0008.jpg 1
172 | dslr/images/bike/frame_0009.jpg 1
173 | dslr/images/bike/frame_0010.jpg 1
174 | dslr/images/bike/frame_0011.jpg 1
175 | dslr/images/bike/frame_0012.jpg 1
176 | dslr/images/bike/frame_0013.jpg 1
177 | dslr/images/bike/frame_0014.jpg 1
178 | dslr/images/bike/frame_0015.jpg 1
179 | dslr/images/bike/frame_0016.jpg 1
180 | dslr/images/bike/frame_0017.jpg 1
181 | dslr/images/bike/frame_0018.jpg 1
182 | dslr/images/bike/frame_0019.jpg 1
183 | dslr/images/bike/frame_0020.jpg 1
184 | dslr/images/bike/frame_0021.jpg 1
185 | dslr/images/punchers/frame_0001.jpg 23
186 | dslr/images/punchers/frame_0002.jpg 23
187 | dslr/images/punchers/frame_0003.jpg 23
188 | dslr/images/punchers/frame_0004.jpg 23
189 | dslr/images/punchers/frame_0005.jpg 23
190 | dslr/images/punchers/frame_0006.jpg 23
191 | dslr/images/punchers/frame_0007.jpg 23
192 | dslr/images/punchers/frame_0008.jpg 23
193 | dslr/images/punchers/frame_0009.jpg 23
194 | dslr/images/punchers/frame_0010.jpg 23
195 | dslr/images/punchers/frame_0011.jpg 23
196 | dslr/images/punchers/frame_0012.jpg 23
197 | dslr/images/punchers/frame_0013.jpg 23
198 | dslr/images/punchers/frame_0014.jpg 23
199 | dslr/images/punchers/frame_0015.jpg 23
200 | dslr/images/punchers/frame_0016.jpg 23
201 | dslr/images/punchers/frame_0017.jpg 23
202 | dslr/images/punchers/frame_0018.jpg 23
203 | dslr/images/back_pack/frame_0001.jpg 0
204 | dslr/images/back_pack/frame_0002.jpg 0
205 | dslr/images/back_pack/frame_0003.jpg 0
206 | dslr/images/back_pack/frame_0004.jpg 0
207 | dslr/images/back_pack/frame_0005.jpg 0
208 | dslr/images/back_pack/frame_0006.jpg 0
209 | dslr/images/back_pack/frame_0007.jpg 0
210 | dslr/images/back_pack/frame_0008.jpg 0
211 | dslr/images/back_pack/frame_0009.jpg 0
212 | dslr/images/back_pack/frame_0010.jpg 0
213 | dslr/images/back_pack/frame_0011.jpg 0
214 | dslr/images/back_pack/frame_0012.jpg 0
215 | dslr/images/desktop_computer/frame_0001.jpg 8
216 | dslr/images/desktop_computer/frame_0002.jpg 8
217 | dslr/images/desktop_computer/frame_0003.jpg 8
218 | dslr/images/desktop_computer/frame_0004.jpg 8
219 | dslr/images/desktop_computer/frame_0005.jpg 8
220 | dslr/images/desktop_computer/frame_0006.jpg 8
221 | dslr/images/desktop_computer/frame_0007.jpg 8
222 | dslr/images/desktop_computer/frame_0008.jpg 8
223 | dslr/images/desktop_computer/frame_0009.jpg 8
224 | dslr/images/desktop_computer/frame_0010.jpg 8
225 | dslr/images/desktop_computer/frame_0011.jpg 8
226 | dslr/images/desktop_computer/frame_0012.jpg 8
227 | dslr/images/desktop_computer/frame_0013.jpg 8
228 | dslr/images/desktop_computer/frame_0014.jpg 8
229 | dslr/images/desktop_computer/frame_0015.jpg 8
230 | dslr/images/speaker/frame_0001.jpg 27
231 | dslr/images/speaker/frame_0002.jpg 27
232 | dslr/images/speaker/frame_0003.jpg 27
233 | dslr/images/speaker/frame_0004.jpg 27
234 | dslr/images/speaker/frame_0005.jpg 27
235 | dslr/images/speaker/frame_0006.jpg 27
236 | dslr/images/speaker/frame_0007.jpg 27
237 | dslr/images/speaker/frame_0008.jpg 27
238 | dslr/images/speaker/frame_0009.jpg 27
239 | dslr/images/speaker/frame_0010.jpg 27
240 | dslr/images/speaker/frame_0011.jpg 27
241 | dslr/images/speaker/frame_0012.jpg 27
242 | dslr/images/speaker/frame_0013.jpg 27
243 | dslr/images/speaker/frame_0014.jpg 27
244 | dslr/images/speaker/frame_0015.jpg 27
245 | dslr/images/speaker/frame_0016.jpg 27
246 | dslr/images/speaker/frame_0017.jpg 27
247 | dslr/images/speaker/frame_0018.jpg 27
248 | dslr/images/speaker/frame_0019.jpg 27
249 | dslr/images/speaker/frame_0020.jpg 27
250 | dslr/images/speaker/frame_0021.jpg 27
251 | dslr/images/speaker/frame_0022.jpg 27
252 | dslr/images/speaker/frame_0023.jpg 27
253 | dslr/images/speaker/frame_0024.jpg 27
254 | dslr/images/speaker/frame_0025.jpg 27
255 | dslr/images/speaker/frame_0026.jpg 27
256 | dslr/images/mobile_phone/frame_0001.jpg 14
257 | dslr/images/mobile_phone/frame_0002.jpg 14
258 | dslr/images/mobile_phone/frame_0003.jpg 14
259 | dslr/images/mobile_phone/frame_0004.jpg 14
260 | dslr/images/mobile_phone/frame_0005.jpg 14
261 | dslr/images/mobile_phone/frame_0006.jpg 14
262 | dslr/images/mobile_phone/frame_0007.jpg 14
263 | dslr/images/mobile_phone/frame_0008.jpg 14
264 | dslr/images/mobile_phone/frame_0009.jpg 14
265 | dslr/images/mobile_phone/frame_0010.jpg 14
266 | dslr/images/mobile_phone/frame_0011.jpg 14
267 | dslr/images/mobile_phone/frame_0012.jpg 14
268 | dslr/images/mobile_phone/frame_0013.jpg 14
269 | dslr/images/mobile_phone/frame_0014.jpg 14
270 | dslr/images/mobile_phone/frame_0015.jpg 14
271 | dslr/images/mobile_phone/frame_0016.jpg 14
272 | dslr/images/mobile_phone/frame_0017.jpg 14
273 | dslr/images/mobile_phone/frame_0018.jpg 14
274 | dslr/images/mobile_phone/frame_0019.jpg 14
275 | dslr/images/mobile_phone/frame_0020.jpg 14
276 | dslr/images/mobile_phone/frame_0021.jpg 14
277 | dslr/images/mobile_phone/frame_0022.jpg 14
278 | dslr/images/mobile_phone/frame_0023.jpg 14
279 | dslr/images/mobile_phone/frame_0024.jpg 14
280 | dslr/images/mobile_phone/frame_0025.jpg 14
281 | dslr/images/mobile_phone/frame_0026.jpg 14
282 | dslr/images/mobile_phone/frame_0027.jpg 14
283 | dslr/images/mobile_phone/frame_0028.jpg 14
284 | dslr/images/mobile_phone/frame_0029.jpg 14
285 | dslr/images/mobile_phone/frame_0030.jpg 14
286 | dslr/images/mobile_phone/frame_0031.jpg 14
287 | dslr/images/paper_notebook/frame_0001.jpg 18
288 | dslr/images/paper_notebook/frame_0002.jpg 18
289 | dslr/images/paper_notebook/frame_0003.jpg 18
290 | dslr/images/paper_notebook/frame_0004.jpg 18
291 | dslr/images/paper_notebook/frame_0005.jpg 18
292 | dslr/images/paper_notebook/frame_0006.jpg 18
293 | dslr/images/paper_notebook/frame_0007.jpg 18
294 | dslr/images/paper_notebook/frame_0008.jpg 18
295 | dslr/images/paper_notebook/frame_0009.jpg 18
296 | dslr/images/paper_notebook/frame_0010.jpg 18
297 | dslr/images/ruler/frame_0001.jpg 25
298 | dslr/images/ruler/frame_0002.jpg 25
299 | dslr/images/ruler/frame_0003.jpg 25
300 | dslr/images/ruler/frame_0004.jpg 25
301 | dslr/images/ruler/frame_0005.jpg 25
302 | dslr/images/ruler/frame_0006.jpg 25
303 | dslr/images/ruler/frame_0007.jpg 25
304 | dslr/images/letter_tray/frame_0001.jpg 13
305 | dslr/images/letter_tray/frame_0002.jpg 13
306 | dslr/images/letter_tray/frame_0003.jpg 13
307 | dslr/images/letter_tray/frame_0004.jpg 13
308 | dslr/images/letter_tray/frame_0005.jpg 13
309 | dslr/images/letter_tray/frame_0006.jpg 13
310 | dslr/images/letter_tray/frame_0007.jpg 13
311 | dslr/images/letter_tray/frame_0008.jpg 13
312 | dslr/images/letter_tray/frame_0009.jpg 13
313 | dslr/images/letter_tray/frame_0010.jpg 13
314 | dslr/images/letter_tray/frame_0011.jpg 13
315 | dslr/images/letter_tray/frame_0012.jpg 13
316 | dslr/images/letter_tray/frame_0013.jpg 13
317 | dslr/images/letter_tray/frame_0014.jpg 13
318 | dslr/images/letter_tray/frame_0015.jpg 13
319 | dslr/images/letter_tray/frame_0016.jpg 13
320 | dslr/images/file_cabinet/frame_0001.jpg 9
321 | dslr/images/file_cabinet/frame_0002.jpg 9
322 | dslr/images/file_cabinet/frame_0003.jpg 9
323 | dslr/images/file_cabinet/frame_0004.jpg 9
324 | dslr/images/file_cabinet/frame_0005.jpg 9
325 | dslr/images/file_cabinet/frame_0006.jpg 9
326 | dslr/images/file_cabinet/frame_0007.jpg 9
327 | dslr/images/file_cabinet/frame_0008.jpg 9
328 | dslr/images/file_cabinet/frame_0009.jpg 9
329 | dslr/images/file_cabinet/frame_0010.jpg 9
330 | dslr/images/file_cabinet/frame_0011.jpg 9
331 | dslr/images/file_cabinet/frame_0012.jpg 9
332 | dslr/images/file_cabinet/frame_0013.jpg 9
333 | dslr/images/file_cabinet/frame_0014.jpg 9
334 | dslr/images/file_cabinet/frame_0015.jpg 9
335 | dslr/images/phone/frame_0001.jpg 20
336 | dslr/images/phone/frame_0002.jpg 20
337 | dslr/images/phone/frame_0003.jpg 20
338 | dslr/images/phone/frame_0004.jpg 20
339 | dslr/images/phone/frame_0005.jpg 20
340 | dslr/images/phone/frame_0006.jpg 20
341 | dslr/images/phone/frame_0007.jpg 20
342 | dslr/images/phone/frame_0008.jpg 20
343 | dslr/images/phone/frame_0009.jpg 20
344 | dslr/images/phone/frame_0010.jpg 20
345 | dslr/images/phone/frame_0011.jpg 20
346 | dslr/images/phone/frame_0012.jpg 20
347 | dslr/images/phone/frame_0013.jpg 20
348 | dslr/images/bookcase/frame_0001.jpg 3
349 | dslr/images/bookcase/frame_0002.jpg 3
350 | dslr/images/bookcase/frame_0003.jpg 3
351 | dslr/images/bookcase/frame_0004.jpg 3
352 | dslr/images/bookcase/frame_0005.jpg 3
353 | dslr/images/bookcase/frame_0006.jpg 3
354 | dslr/images/bookcase/frame_0007.jpg 3
355 | dslr/images/bookcase/frame_0008.jpg 3
356 | dslr/images/bookcase/frame_0009.jpg 3
357 | dslr/images/bookcase/frame_0010.jpg 3
358 | dslr/images/bookcase/frame_0011.jpg 3
359 | dslr/images/bookcase/frame_0012.jpg 3
360 | dslr/images/projector/frame_0001.jpg 22
361 | dslr/images/projector/frame_0002.jpg 22
362 | dslr/images/projector/frame_0003.jpg 22
363 | dslr/images/projector/frame_0004.jpg 22
364 | dslr/images/projector/frame_0005.jpg 22
365 | dslr/images/projector/frame_0006.jpg 22
366 | dslr/images/projector/frame_0007.jpg 22
367 | dslr/images/projector/frame_0008.jpg 22
368 | dslr/images/projector/frame_0009.jpg 22
369 | dslr/images/projector/frame_0010.jpg 22
370 | dslr/images/projector/frame_0011.jpg 22
371 | dslr/images/projector/frame_0012.jpg 22
372 | dslr/images/projector/frame_0013.jpg 22
373 | dslr/images/projector/frame_0014.jpg 22
374 | dslr/images/projector/frame_0015.jpg 22
375 | dslr/images/projector/frame_0016.jpg 22
376 | dslr/images/projector/frame_0017.jpg 22
377 | dslr/images/projector/frame_0018.jpg 22
378 | dslr/images/projector/frame_0019.jpg 22
379 | dslr/images/projector/frame_0020.jpg 22
380 | dslr/images/projector/frame_0021.jpg 22
381 | dslr/images/projector/frame_0022.jpg 22
382 | dslr/images/projector/frame_0023.jpg 22
383 | dslr/images/stapler/frame_0001.jpg 28
384 | dslr/images/stapler/frame_0002.jpg 28
385 | dslr/images/stapler/frame_0003.jpg 28
386 | dslr/images/stapler/frame_0004.jpg 28
387 | dslr/images/stapler/frame_0005.jpg 28
388 | dslr/images/stapler/frame_0006.jpg 28
389 | dslr/images/stapler/frame_0007.jpg 28
390 | dslr/images/stapler/frame_0008.jpg 28
391 | dslr/images/stapler/frame_0009.jpg 28
392 | dslr/images/stapler/frame_0010.jpg 28
393 | dslr/images/stapler/frame_0011.jpg 28
394 | dslr/images/stapler/frame_0012.jpg 28
395 | dslr/images/stapler/frame_0013.jpg 28
396 | dslr/images/stapler/frame_0014.jpg 28
397 | dslr/images/stapler/frame_0015.jpg 28
398 | dslr/images/stapler/frame_0016.jpg 28
399 | dslr/images/stapler/frame_0017.jpg 28
400 | dslr/images/stapler/frame_0018.jpg 28
401 | dslr/images/stapler/frame_0019.jpg 28
402 | dslr/images/stapler/frame_0020.jpg 28
403 | dslr/images/stapler/frame_0021.jpg 28
404 | dslr/images/trash_can/frame_0001.jpg 30
405 | dslr/images/trash_can/frame_0002.jpg 30
406 | dslr/images/trash_can/frame_0003.jpg 30
407 | dslr/images/trash_can/frame_0004.jpg 30
408 | dslr/images/trash_can/frame_0005.jpg 30
409 | dslr/images/trash_can/frame_0006.jpg 30
410 | dslr/images/trash_can/frame_0007.jpg 30
411 | dslr/images/trash_can/frame_0008.jpg 30
412 | dslr/images/trash_can/frame_0009.jpg 30
413 | dslr/images/trash_can/frame_0010.jpg 30
414 | dslr/images/trash_can/frame_0011.jpg 30
415 | dslr/images/trash_can/frame_0012.jpg 30
416 | dslr/images/trash_can/frame_0013.jpg 30
417 | dslr/images/trash_can/frame_0014.jpg 30
418 | dslr/images/trash_can/frame_0015.jpg 30
419 | dslr/images/bike_helmet/frame_0001.jpg 2
420 | dslr/images/bike_helmet/frame_0002.jpg 2
421 | dslr/images/bike_helmet/frame_0003.jpg 2
422 | dslr/images/bike_helmet/frame_0004.jpg 2
423 | dslr/images/bike_helmet/frame_0005.jpg 2
424 | dslr/images/bike_helmet/frame_0006.jpg 2
425 | dslr/images/bike_helmet/frame_0007.jpg 2
426 | dslr/images/bike_helmet/frame_0008.jpg 2
427 | dslr/images/bike_helmet/frame_0009.jpg 2
428 | dslr/images/bike_helmet/frame_0010.jpg 2
429 | dslr/images/bike_helmet/frame_0011.jpg 2
430 | dslr/images/bike_helmet/frame_0012.jpg 2
431 | dslr/images/bike_helmet/frame_0013.jpg 2
432 | dslr/images/bike_helmet/frame_0014.jpg 2
433 | dslr/images/bike_helmet/frame_0015.jpg 2
434 | dslr/images/bike_helmet/frame_0016.jpg 2
435 | dslr/images/bike_helmet/frame_0017.jpg 2
436 | dslr/images/bike_helmet/frame_0018.jpg 2
437 | dslr/images/bike_helmet/frame_0019.jpg 2
438 | dslr/images/bike_helmet/frame_0020.jpg 2
439 | dslr/images/bike_helmet/frame_0021.jpg 2
440 | dslr/images/bike_helmet/frame_0022.jpg 2
441 | dslr/images/bike_helmet/frame_0023.jpg 2
442 | dslr/images/bike_helmet/frame_0024.jpg 2
443 | dslr/images/headphones/frame_0001.jpg 10
444 | dslr/images/headphones/frame_0002.jpg 10
445 | dslr/images/headphones/frame_0003.jpg 10
446 | dslr/images/headphones/frame_0004.jpg 10
447 | dslr/images/headphones/frame_0005.jpg 10
448 | dslr/images/headphones/frame_0006.jpg 10
449 | dslr/images/headphones/frame_0007.jpg 10
450 | dslr/images/headphones/frame_0008.jpg 10
451 | dslr/images/headphones/frame_0009.jpg 10
452 | dslr/images/headphones/frame_0010.jpg 10
453 | dslr/images/headphones/frame_0011.jpg 10
454 | dslr/images/headphones/frame_0012.jpg 10
455 | dslr/images/headphones/frame_0013.jpg 10
456 | dslr/images/desk_lamp/frame_0001.jpg 7
457 | dslr/images/desk_lamp/frame_0002.jpg 7
458 | dslr/images/desk_lamp/frame_0003.jpg 7
459 | dslr/images/desk_lamp/frame_0004.jpg 7
460 | dslr/images/desk_lamp/frame_0005.jpg 7
461 | dslr/images/desk_lamp/frame_0006.jpg 7
462 | dslr/images/desk_lamp/frame_0007.jpg 7
463 | dslr/images/desk_lamp/frame_0008.jpg 7
464 | dslr/images/desk_lamp/frame_0009.jpg 7
465 | dslr/images/desk_lamp/frame_0010.jpg 7
466 | dslr/images/desk_lamp/frame_0011.jpg 7
467 | dslr/images/desk_lamp/frame_0012.jpg 7
468 | dslr/images/desk_lamp/frame_0013.jpg 7
469 | dslr/images/desk_lamp/frame_0014.jpg 7
470 | dslr/images/desk_chair/frame_0001.jpg 6
471 | dslr/images/desk_chair/frame_0002.jpg 6
472 | dslr/images/desk_chair/frame_0003.jpg 6
473 | dslr/images/desk_chair/frame_0004.jpg 6
474 | dslr/images/desk_chair/frame_0005.jpg 6
475 | dslr/images/desk_chair/frame_0006.jpg 6
476 | dslr/images/desk_chair/frame_0007.jpg 6
477 | dslr/images/desk_chair/frame_0008.jpg 6
478 | dslr/images/desk_chair/frame_0009.jpg 6
479 | dslr/images/desk_chair/frame_0010.jpg 6
480 | dslr/images/desk_chair/frame_0011.jpg 6
481 | dslr/images/desk_chair/frame_0012.jpg 6
482 | dslr/images/desk_chair/frame_0013.jpg 6
483 | dslr/images/bottle/frame_0001.jpg 4
484 | dslr/images/bottle/frame_0002.jpg 4
485 | dslr/images/bottle/frame_0003.jpg 4
486 | dslr/images/bottle/frame_0004.jpg 4
487 | dslr/images/bottle/frame_0005.jpg 4
488 | dslr/images/bottle/frame_0006.jpg 4
489 | dslr/images/bottle/frame_0007.jpg 4
490 | dslr/images/bottle/frame_0008.jpg 4
491 | dslr/images/bottle/frame_0009.jpg 4
492 | dslr/images/bottle/frame_0010.jpg 4
493 | dslr/images/bottle/frame_0011.jpg 4
494 | dslr/images/bottle/frame_0012.jpg 4
495 | dslr/images/bottle/frame_0013.jpg 4
496 | dslr/images/bottle/frame_0014.jpg 4
497 | dslr/images/bottle/frame_0015.jpg 4
498 | dslr/images/bottle/frame_0016.jpg 4
499 |
--------------------------------------------------------------------------------
/data/setup_data_path.sh:
--------------------------------------------------------------------------------
1 | # usage: bash setup_data_path.sh <data_path> <dataset>  (uses [[ ]], so run with bash, from the data/ directory)
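  | # example (hypothetical local path):
  | #   bash setup_data_path.sh /datasets/office31 office31
  | # valid <dataset> values: domainnet40 | office31 | office-home | office-home-rsut | visda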
2 | data_path=$1
3 | dataset=$2
4 |
5 | if [[ ${dataset} == "domainnet40" ]] ;
6 | then
7 | cd domainnet40
8 | rm clipart
9 | ln -s "${data_path}/clipart" clipart
10 | rm infograph
11 | ln -s "${data_path}/infograph" infograph
12 | rm painting
13 | ln -s "${data_path}/painting" painting
14 | rm quickdraw
15 | ln -s "${data_path}/quickdraw" quickdraw
16 | rm real
17 | ln -s "${data_path}/real" real
18 | rm sketch
19 | ln -s "${data_path}/sketch" sketch
20 | # (no 'cd ..' here; the single 'cd ..' after the fi restores the directory)
21 | elif [[ ${dataset} == "office31" ]] ;
22 | then
23 | cd office31
24 | rm amazon
25 | ln -s "${data_path}/amazon" amazon
26 | rm webcam
27 | ln -s "${data_path}/webcam" webcam
28 | rm dslr
29 | ln -s "${data_path}/dslr" dslr
30 | elif [[ ${dataset} == "office-home" ]] ;
31 | then
32 | cd office-home
33 | rm Art
34 | ln -s "${data_path}/Art" Art
35 | rm Clipart
36 | ln -s "${data_path}/Clipart" Clipart
37 | rm Product
38 | ln -s "${data_path}/Product" Product
39 | rm Real_World
40 | ln -s "${data_path}/Real_World" Real_World
41 | elif [[ ${dataset} == "office-home-rsut" ]] ;
42 | then
43 | cd office-home-rsut
44 | rm Art
45 | ln -s "${data_path}/Art" Art
46 | rm Clipart
47 | ln -s "${data_path}/Clipart" Clipart
48 | rm Product
49 | ln -s "${data_path}/Product" Product
50 | rm Real_World
51 | ln -s "${data_path}/Real_World" Real_World
52 | elif [[ ${dataset} == "visda" ]] ;
53 | then
54 | cd visda-2017
55 | rm train
56 | ln -s "${data_path}/train" train
57 | rm validation
58 | ln -s "${data_path}/validation" validation
59 | fi
60 | cd ..
--------------------------------------------------------------------------------
/fig/PK.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tsun/KUDA/beeb839456daf5fb5d263783c79bb6cff38e2375/fig/PK.png
--------------------------------------------------------------------------------
/fig/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tsun/KUDA/beeb839456daf5fb5d263783c79bb6cff38e2375/fig/framework.png
--------------------------------------------------------------------------------
/pklib/pksolver.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import gurobipy as grb
3 | import random
4 |
5 | class PK_solver():
6 | def __init__(self, N, C, C_ub=[], C_br=[], pk_prior_weight=10.):
7 |         self.N = N  # number of samples
8 |         self.C = C  # number of classes
9 |         self.C_ub = C_ub  # unary-bound constraints
10 |         self.C_br = C_br  # binary-relationship constraints
11 |         self.pk_prior_weight = pk_prior_weight  # weight of the prior-violation penalty
12 |
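  |     # Constraint formats consumed by the solvers below:
  |     #   C_ub: list of (class, lower_bound, upper_bound) in expected sample
  |     #         counts; either bound may be None to leave that side free.
  |     #   C_br: list of (c1, c2, diff) tuples, softly enforcing
  |     #         count(c1) - count(c2) >= diff.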
13 | # create unary bound constraints
14 | def create_C_ub(self, cls_probs, uconf=0.):
15 | ubs = cls_probs * (1 + uconf)
16 | lbs = cls_probs * (1 - uconf)
17 | ubs[ubs > 1.0] = 1.0
18 | lbs[lbs < 0.0] = 0.0
19 | ubs = (ubs*self.N).tolist()
20 | lbs = (lbs*self.N).tolist()
21 | self.C_ub = list(zip(list(range(self.C)), lbs, ubs))
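  |         # e.g. N=100, cls_probs=[0.7, 0.3], uconf=0.1 gives
  |         # C_ub = [(0, 63.0, 77.0), (1, 27.0, 33.0)]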
22 |
23 |     # create unary bound constraints with additive noise
24 | def create_C_ub_noisy(self, cls_probs, uconf=0., noise=0.):
25 | ubs = cls_probs * (1 + uconf)
26 | lbs = cls_probs * (1 - uconf)
27 | bias = (2*np.random.rand(len(cls_probs))-1)*cls_probs*noise
28 | bias -= bias.mean()
29 | ubs += bias
30 | lbs += bias
31 | ubs[ubs > 1.0] = 1.0
32 | lbs[lbs < 0.0] = 0.0
33 | ubs = (ubs*self.N).tolist()
34 | lbs = (lbs*self.N).tolist()
35 | self.C_ub = list(zip(list(range(self.C)), lbs, ubs))
36 |
37 | # create binary relationship constraints
38 | def create_C_br(self, cls_probs, uconf=0.):
39 | idx = np.argsort(-cls_probs)
40 | self.C_br = [(idx[c], idx[c+1], (1-uconf)*np.round((cls_probs[idx[c]]-cls_probs[idx[c+1]])*self.N)) for c in range(self.C-1)]
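  |         # e.g. cls_probs=[0.7, 0.3], self.N=100, uconf=0.1 yields C_br=[(0, 1, 36.0)],
  |         # i.e. class 0 should outnumber class 1 by at least ~36 samples (softly)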
41 |
42 |     # create binary relationship constraints with a noise-perturbed class ordering
43 | def create_C_br_noisy(self, cls_probs, uconf=0., noise=0.):
44 | idx = np.argsort(-cls_probs)
45 | C=len(idx)
46 | score = np.arange(C) + (2*np.random.rand(C)-1)*noise + np.random.rand(C)*0.0001
47 | idd = np.argsort(score)
48 | idx = idx[idd]
49 | self.C_br = [(idx[c], idx[c+1], 0) for c in range(self.C-1)]
50 |
51 | # create unary bound constraints from (head) partial classes
52 | def create_C_ub_partial(self, cls_probs, uconf=0., N=10):
53 | ubs = cls_probs * (1 + uconf)
54 | lbs = cls_probs * (1 - uconf)
55 | ubs[ubs > 1.0] = 1.0
56 | lbs[lbs < 0.0] = 0.0
57 | ubs = (ubs*self.N).tolist()
58 | lbs = (lbs*self.N).tolist()
59 | self.C_ub = list(zip(list(range(self.C)), lbs, ubs))[:N]
60 |
61 | # create unary bound constraints from (tail) partial classes
62 | def create_C_ub_partial_reverse(self, cls_probs, uconf=0., N=10):
63 | ubs = cls_probs * (1 + uconf)
64 | lbs = cls_probs * (1 - uconf)
65 | ubs[ubs > 1.0] = 1.0
66 | lbs[lbs < 0.0] = 0.0
67 | ubs = (ubs*self.N).tolist()
68 | lbs = (lbs*self.N).tolist()
69 | self.C_ub = list(zip(list(range(self.C)), lbs, ubs))[-N:]
70 |
71 | # create unary bound constraints from (random) partial classes
72 | def create_C_ub_partial_rand(self, cls_probs, uconf=0., N=10):
73 | ubs = cls_probs * (1 + uconf)
74 | lbs = cls_probs * (1 - uconf)
75 | ubs[ubs > 1.0] = 1.0
76 | lbs[lbs < 0.0] = 0.0
77 | ubs = (ubs*self.N).tolist()
78 | lbs = (lbs*self.N).tolist()
79 | self.C_ub = random.sample(list(zip(list(range(self.C)), lbs, ubs)), k=N)
80 |
81 | # create binary relationship constraints from (head) partial classes
82 | def create_C_br_partial(self, cls_probs, uconf=0., N=10):
83 | idx = np.argsort(-cls_probs)
84 | self.C_br = [(idx[c], idx[c+1], (1-uconf)*np.round((cls_probs[idx[c]]-cls_probs[idx[c+1]])*self.N)) for c in range(self.C-1)][:N]
85 |
86 | # create binary relationship constraints from (tail) partial classes
87 | def create_C_br_partial_reverse(self, cls_probs, uconf=0., N=10):
88 | idx = np.argsort(-cls_probs)
89 | self.C_br = [(idx[c], idx[c+1], (1-uconf)*np.round((cls_probs[idx[c]]-cls_probs[idx[c+1]])*self.N)) for c in range(self.C-1)][-N:]
90 |
91 | # create binary relationship constraints from (random) partial classes
92 | def create_C_br_partial_rand(self, cls_probs, uconf=0., N=10):
93 | idx = np.argsort(-cls_probs)
94 | self.C_br = random.sample([(idx[c], idx[c+1], (1-uconf)*np.round((cls_probs[idx[c]]-cls_probs[idx[c+1]])*self.N)) for c in range(self.C-1)], k=N)
95 |
96 |
97 | # solver with smooth regularization
98 | def solve_soft_knn_cst(self, probs, fix_set=[], fix_labels=[], knn_regs=[]):
99 |         # fix_set / fix_labels: samples whose (pseudo) labels are given and excluded from optimization
100 | fix_cls_probs = np.eye(self.C)[fix_labels].sum(0)
101 |
102 |         # var_set: samples whose (pseudo) labels are to be refined
103 | var_set = list(set(range(self.N)) - set(fix_set))
104 | Nvar = len(var_set)
105 |
106 | # create an optimization model
107 | LP = grb.Model(name="Prior Constraint Problem")
108 | x = {(n, c): LP.addVar(vtype=grb.GRB.BINARY,
109 | name="x_{0}_{1}".format(n, c))
110 | for n in range(Nvar) for c in range(self.C)}
111 |
112 | LP.addConstrs( (grb.quicksum(x[n, c] for c in range(self.C))==1) for n in range(len(var_set)))
113 |
114 | objective = grb.quicksum(x[n, c] * probs[var_set[n], c]
115 | for n in range(Nvar)
116 | for c in range(self.C))
117 |
118 | # add soft constraints of unary bound
119 | xi_ub = {(c,k): LP.addVar(vtype=grb.GRB.CONTINUOUS, lb=-Nvar, ub=Nvar,
120 | name="xi_ub_{0}_{1}".format(c,k))
121 | for c in range(len(self.C_ub)) for k in range(2)}
122 |
123 | xi_lb = {(c,k): LP.addVar(vtype=grb.GRB.CONTINUOUS, lb=-Nvar, ub=Nvar,
124 | name="xi_lb_{0}_{1}".format(c,k))
125 | for c in range(len(self.C_ub)) for k in range(2)}
126 |
127 | margin_ub = []
128 | margin_lb = []
129 | for i, (c, lb, ub) in enumerate(self.C_ub):
130 | if ub is not None:
131 | ub = ub - fix_cls_probs[c]
132 | margin_ub.append(grb.quicksum(x[n, c] for n in range(Nvar))-ub)
133 | else:
134 | margin_ub.append(0.)
135 |
136 | if lb is not None:
137 | lb = lb - fix_cls_probs[c]
138 | margin_lb.append( - grb.quicksum(x[n, c] for n in range(Nvar)) + lb)
139 | else:
140 | margin_lb.append(0.)
141 |
142 |
143 | LP.addConstrs(
144 | (xi_ub[i, 1] == margin_ub[i] for i in range(len(self.C_ub))), name="slack_ub_0"
145 | )
146 | LP.addConstrs(
147 | (xi_ub[i, 0] == grb.max_(xi_ub[i, 1], 0) for i in range(len(self.C_ub))), name="slack_ub_1"
148 | )
149 |
150 | LP.addConstrs(
151 | (xi_lb[i, 1] == margin_lb[i] for i in range(len(self.C_ub))), name="slack_lb_0"
152 | )
153 | LP.addConstrs(
154 | (xi_lb[i, 0] == grb.max_(xi_lb[i, 1], 0) for i in range(len(self.C_ub))), name="slack_lb_1"
155 | )
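  |         # xi_*[i, 0] = max(margin, 0) acts as a hinge: only violated bounds
  |         # are penalized; satisfied constraints contribute zero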
156 |
157 | constraint_ub = grb.quicksum(xi_ub[c, 0] for c in range(len(self.C_ub))) + \
158 | grb.quicksum(xi_lb[c, 0] for c in range(len(self.C_ub)))
159 |
160 | constraint_ub /= (len(self.C_ub) * 2 + 1e-10)
161 |
162 | # add soft constraints of binary relationship
163 | margin_br = []
164 | for (c1, c2, diff) in self.C_br:
165 | diff = diff - fix_cls_probs[c1] + fix_cls_probs[c2]
166 | margin_br.append(
167 | -grb.quicksum(x[n, c1] for n in range(Nvar)) + grb.quicksum(x[n, c2] for n in range(Nvar)) + diff)
168 |
169 | xi_br = {(c, k): LP.addVar(vtype=grb.GRB.CONTINUOUS, lb=-2 * Nvar, ub=2 * Nvar,
170 | name="xi_br_{0}_{1}".format(c, k))
171 | for c in range(len(self.C_br)) for k in range(2)}
172 |
173 | LP.addConstrs(
174 | (xi_br[i, 1] == margin_br[i] for i in range(len(self.C_br))), name="slack_br_0"
175 | )
176 | LP.addConstrs(
177 | (xi_br[i, 0] == grb.max_(xi_br[i, 1], 0) for i in range(len(self.C_br))), name="slack_br_1"
178 | )
179 |
180 | constraint_br = grb.quicksum(xi_br[c, 0] for c in range(len(self.C_br)))
181 | constraint_br /= (len(self.C_br) + 1e-10)
182 |
183 | constraint = constraint_ub + constraint_br
184 |
185 | # add smooth regularization
186 |         # NOTE: the smooth regularization currently does NOT support a non-empty fix_set
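  |         # knn_regs format: list of (anchor_index, neighbor_index_list) pairs;
  |         # each anchor must share its one-hot assignment with all its neighbors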
187 | if len(knn_regs) > 0:
188 | LP.addConstrs(
189 | (x[knn_regs[i][0], c] == x[knn_regs[i][1][k], c]
190 | for i in range(len(knn_regs))
191 | for k in range(len(knn_regs[i][1]))
192 | for c in range(self.C) ), name="smooth_regularization"
193 | )
194 |
195 | LP.ModelSense = grb.GRB.MAXIMIZE
196 | LP.setObjective(objective - self.pk_prior_weight*constraint*Nvar)
197 |
198 | LP.optimize()
199 |
200 | # get refined (pseudo) labels from optimal solution
201 | var_labels = []
202 | for n in range(Nvar):
203 | for c in range(self.C):
204 | var_labels.append(x[n, c].X)
205 |
206 | var_labels = np.array(var_labels)
207 | var_labels = var_labels.reshape([Nvar, self.C])
208 | var_labels = np.argmax(var_labels, axis=-1)
209 |
210 | labels = np.zeros(self.N).astype(np.int32)
211 | labels[fix_set] = fix_labels
212 | labels[var_set] = var_labels
213 |
214 | return labels
215 |
216 | # solver without smooth regularization
217 | def solve_soft(self, probs, fix_set=[], fix_labels=[]):
218 | fix_cls_probs = np.eye(self.C)[fix_labels].sum(0)
219 |
220 | var_set = list(set(range(self.N)) - set(fix_set))
221 | Nvar = len(var_set)
222 |
223 | LP = grb.Model(name="Prior Constraint Problem")
224 | x = {(n, c): LP.addVar(vtype=grb.GRB.BINARY,
225 | name="x_{0}_{1}".format(n, c))
226 | for n in range(Nvar) for c in range(self.C)}
227 |
228 | LP.addConstrs( (grb.quicksum(x[n, c] for c in range(self.C))==1) for n in range(len(var_set)))
229 |
230 | objective = grb.quicksum(x[n, c] * probs[var_set[n], c]
231 | for n in range(Nvar)
232 | for c in range(self.C))
233 |
234 | # add soft constraints of unary bound
235 | xi_ub = {(c,k): LP.addVar(vtype=grb.GRB.CONTINUOUS, lb=-Nvar, ub=Nvar,
236 | name="xi_ub_{0}_{1}".format(c,k))
237 | for c in range(len(self.C_ub)) for k in range(2)}
238 |
239 | xi_lb = {(c,k): LP.addVar(vtype=grb.GRB.CONTINUOUS, lb=-Nvar, ub=Nvar,
240 | name="xi_lb_{0}_{1}".format(c,k))
241 | for c in range(len(self.C_ub)) for k in range(2)}
242 |
243 | margin_ub = []
244 | margin_lb = []
245 | for i, (c, lb, ub) in enumerate(self.C_ub):
246 | if ub is not None:
247 | ub = ub - fix_cls_probs[c]
248 | margin_ub.append(grb.quicksum(x[n, c] for n in range(Nvar))-ub)
249 | else:
250 | margin_ub.append(0.)
251 |
252 | if lb is not None:
253 | lb = lb - fix_cls_probs[c]
254 | margin_lb.append( - grb.quicksum(x[n, c] for n in range(Nvar)) + lb)
255 | else:
256 | margin_lb.append(0.)
257 |
258 |
259 | LP.addConstrs(
260 | (xi_ub[i, 1] == margin_ub[i] for i in range(len(self.C_ub))), name="slack_ub_0"
261 | )
262 | LP.addConstrs(
263 | (xi_ub[i, 0] == grb.max_(xi_ub[i, 1], 0) for i in range(len(self.C_ub))), name="slack_ub_1"
264 | )
265 |
266 | LP.addConstrs(
267 | (xi_lb[i, 1] == margin_lb[i] for i in range(len(self.C_ub))), name="slack_lb_0"
268 | )
269 | LP.addConstrs(
270 | (xi_lb[i, 0] == grb.max_(xi_lb[i, 1], 0) for i in range(len(self.C_ub))), name="slack_lb_1"
271 | )
272 |
273 | constraint_ub = grb.quicksum(xi_ub[c,0] for c in range(len(self.C_ub))) + \
274 | grb.quicksum(xi_lb[c,0] for c in range(len(self.C_ub)))
275 |
276 | constraint_ub /= (len(self.C_ub)*2 + 1e-10)
277 |
278 | # add soft constraints of binary relationship
279 | margin_br = []
280 | for (c1, c2, diff) in self.C_br:
281 | diff = diff - fix_cls_probs[c1] + fix_cls_probs[c2]
282 | margin_br.append(-grb.quicksum(x[n, c1] for n in range(Nvar)) + grb.quicksum(x[n, c2] for n in range(Nvar)) + diff)
283 |
284 | xi_br = {(c, k): LP.addVar(vtype=grb.GRB.CONTINUOUS, lb=-2*Nvar, ub=2*Nvar,
285 | name="xi_br_{0}_{1}".format(c, k))
286 | for c in range(len(self.C_br)) for k in range(2)}
287 |
288 | LP.addConstrs(
289 | (xi_br[i, 1] == margin_br[i] for i in range(len(self.C_br))), name="slack_br_0"
290 | )
291 | LP.addConstrs(
292 | (xi_br[i, 0] == grb.max_(xi_br[i, 1], 0) for i in range(len(self.C_br))), name="slack_br_1"
293 | )
294 |
295 | constraint_br = grb.quicksum(xi_br[c, 0] for c in range(len(self.C_br)))
296 | constraint_br /= (len(self.C_br) + 1e-10)
297 |
298 | constraint = constraint_ub + constraint_br
299 |
300 |
301 | LP.ModelSense = grb.GRB.MAXIMIZE
302 | LP.setObjective(objective - self.pk_prior_weight*constraint*Nvar)
303 |
304 | LP.optimize()
305 |
306 | # get refined (pseudo) labels from optimal solution
307 | var_labels = []
308 | for n in range(Nvar):
309 | for c in range(self.C):
310 | var_labels.append(x[n, c].X)
311 |
312 | var_labels = np.array(var_labels)
313 | var_labels = var_labels.reshape([Nvar, self.C])
314 | var_labels = np.argmax(var_labels, axis=-1)
315 |
316 | labels = np.zeros(self.N).astype(np.int32)
317 | labels[fix_set] = fix_labels
318 | labels[var_set] = var_labels
319 |
320 | return labels
321 |
322 |
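323 | # --------------------------------------------------------------------------
324 | # Hypothetical usage sketch (not part of the original API; requires a licensed
325 | # Gurobi installation): refine pseudo labels for N samples under a roughly
326 | # uniform class-proportion prior.
327 | if __name__ == '__main__':
328 |     np.random.seed(0)
329 |     N, C = 6, 3
330 |     probs = np.random.rand(N, C)
331 |     probs /= probs.sum(1, keepdims=True)           # per-sample class probabilities
332 |     solver = PK_solver(N, C)
333 |     solver.create_C_ub(np.ones(C) / C, uconf=0.1)  # counts softly bounded near N/C
334 |     print(solver.solve_soft(probs))                # refined hard labels, shape (N,)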
--------------------------------------------------------------------------------
/util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tsun/KUDA/beeb839456daf5fb5d263783c79bb6cff38e2375/util/__init__.py
--------------------------------------------------------------------------------
/util/get_time.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | def get_time():
4 | return time.strftime("%Y-%m-%d_%H.%M.%S", time.localtime())
5 |
6 | if __name__ == '__main__':
7 | print(get_time())
8 |
--------------------------------------------------------------------------------
/util/utils.py:
--------------------------------------------------------------------------------
1 | import time
2 | import socket
3 | import os
4 | import numpy as np
5 | import torch
6 | import random
7 | import logging
8 | pil_logger = logging.getLogger('PIL')
9 | pil_logger.setLevel(logging.INFO)
10 |
11 | def get_time():
12 | return time.strftime("%Y-%m-%d_%H.%M.%S", time.localtime())
13 |
14 | def get_hostname():
15 | return socket.gethostname()
16 |
17 | def get_pid():
18 | return os.getpid()
19 |
20 | # set random number generators' seeds
21 | def resetRNGseed(seed):
22 | np.random.seed(seed)
23 | random.seed(seed)
24 | torch.manual_seed(seed)
25 | torch.cuda.manual_seed(seed)
26 | torch.cuda.manual_seed_all(seed)
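  |     # note: full GPU determinism additionally requires the cuDNN flags
  |     # (torch.backends.cudnn.deterministic = True, .benchmark = False)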
27 |
28 |
29 | logger_init = False
30 |
31 | def init_logger(_log_file, use_file_logger=True, dir='log/'):
32 | if not os.path.exists(dir):
33 | os.makedirs(dir)
34 | log_file = os.path.join(dir, _log_file + '.log')
35 | #logging.basicConfig(filename=log_file, format='%(asctime)s %(message)s', datefmt='%Y-%m-%d %H:%M:%S',level=logging.DEBUG)
36 | logger = logging.getLogger()
37 | for handler in logger.handlers[:]:
38 | logger.removeHandler(handler)
39 |
40 | logger.setLevel('DEBUG')
41 | BASIC_FORMAT = "%(asctime)s:%(levelname)s:%(message)s"
42 | DATE_FORMAT = '%Y-%m-%d %H:%M:%S'
43 | formatter = logging.Formatter(BASIC_FORMAT, DATE_FORMAT)
44 | chlr = logging.StreamHandler()
45 | chlr.setFormatter(formatter)
46 | logger.addHandler(chlr)
47 | if use_file_logger:
48 | fhlr = logging.FileHandler(log_file)
49 | fhlr.setFormatter(formatter)
50 | logger.addHandler(fhlr)
51 |
52 | global logger_init
53 | logger_init = True
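54 |
55 | # Typical call site (hypothetical): init_logger('run_' + get_time()) logs to
56 | # stdout and, when use_file_logger=True, to log/run_<time>.log as well.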
--------------------------------------------------------------------------------