├── README.md ├── loaders ├── __init__.py └── data_list.py ├── main.py ├── metric ├── __init__.py └── mmd.py ├── model ├── __init__.py ├── basenet.py └── resnet.py ├── perturb.py ├── requirements.txt ├── test.py └── utils ├── __init__.py ├── lr_schedule.py ├── return_dataset.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # [Attract, Perturb, and Explore: Learning a Feature Alignment Network for Semi-supervised Domain Adaptation (ECCV 2020)](https://arxiv.org/pdf/2007.09375.pdf) 2 | 3 | ### Acknowledgment 4 | 5 | The implementation is built on the pytorch implementation of [SSDA_MME](https://github.com/VisionLearningGroup/SSDA_MME) and we refer a specific module in [DTA](https://github.com/postBG/DTA.pytorch). 6 | 7 | ### Prerequisites 8 | 9 | * CUDA 10.0 or 10.1 10 | * Python 3.7 (or 3.6) 11 | * Pytorch 1.0.1 12 | ``` 13 | conda install pytorch==1.0.1 torchvision==0.2.2 cudatoolkit=10.0 -c pytorch 14 | ``` 15 | * Pillow, numpy, tqdm 16 | * You can easily install dependencies through 17 | ``` 18 | pip install -r requirements.txt 19 | ``` 20 | 21 | ### Dataset Structure 22 | You can download the datasets by following the instructions in [SSDA_MME](https://github.com/VisionLearningGroup/SSDA_MME). 23 | ``` 24 | data--- 25 | | 26 | multi--- 27 | | | 28 | | Real 29 | | Clipart 30 | | Product 31 | | Real 32 | office_home--- 33 | | | 34 | | Art 35 | | Clipart 36 | | Product 37 | | Real 38 | office--- 39 | | | 40 | | amazon 41 | | dslr 42 | | webcam 43 | txt--- 44 | | 45 | multi--- 46 | | | 47 | | labeled_source_images_real.txt 48 | | unlabeled_target_images_real_3.txt 49 | | labeled_target_images_real_3.txt 50 | | unlabeled_source_images_sketch.txt 51 | | ... 52 | office--- 53 | | | 54 | | labeled_source_images_amazon.txt 55 | | unlabeled_target_images_amazon_3.txt 56 | | labeled_target_images_amazon_3.txt 57 | | unlabeled_source_images_webcam.txt 58 | | ... 59 | office_home--- 60 | | 61 | ... 62 | ``` 63 | 64 | ### Example 65 | #### Train 66 | * DomainNet (clipart, painting, real, sketch) 67 | ``` 68 | python main.py --dataset multi --source real --target sketch --save_interval 5000 --steps 70000 --net resnet34 --num 3 --save_check 69 | ``` 70 | * Office-home (Art, Clipart, Product, Real) 71 | * Office (amazon, dslr, webcam) 72 | 73 | ### Test 74 | * DomainNet (clipart, painting, real, sketch) 75 | ``` 76 | python test.py --dataset multi --source real --target sketch --steps 70000 77 | ``` 78 | ### Checkpoint samples 79 | * (DomainNet) Real to Sketch [BaseNet](https://drive.google.com/file/d/1mwG1ClXzsyC3Pvq7WnlJfvtVwZdlQLxy/view?usp=sharing) / 80 | [Classifier](https://drive.google.com/file/d/1cO8YEaFWykRw7Pzw-xJcWx3ERioUBp_L/view?usp=sharing) 81 | ### Additional Splits 82 | We provide 5, 10, 20-shot splits for four domains (clipart, painting, real, sketch) of the DomainNet dataset. 83 | * (DomainNet) [splits](https://drive.google.com/file/d/1PhNe8-CmKJq3zCdl0MM8a4tcEbwMbjb0/view?usp=sharing) 84 | -------------------------------------------------------------------------------- /loaders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TKKim93/APE/28b0cb8353b06b818b9db00209c403876e84dfd2/loaders/__init__.py -------------------------------------------------------------------------------- /loaders/data_list.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import os.path 4 | from PIL import Image 5 | 6 | 7 | def pil_loader(path): 8 | with open(path, 'rb') as f: 9 | img = Image.open(f) 10 | return img.convert('RGB') 11 | 12 | 13 | def make_dataset_fromlist(image_list): 14 | with open(image_list) as f: 15 | image_index = [x.split(' ')[0] for x in f.readlines()] 16 | with open(image_list) as f: 17 | label_list = [] 18 | selected_list = [] 19 | for ind, x in enumerate(f.readlines()): 20 | label = x.split(' ')[1].strip() 21 | label_list.append(int(label)) 22 | selected_list.append(ind) 23 | image_index = np.array(image_index) 24 | label_list = np.array(label_list) 25 | image_index = image_index[selected_list] 26 | return image_index, label_list 27 | 28 | 29 | def return_classlist(image_list): 30 | with open(image_list) as f: 31 | label_list = [] 32 | for ind, x in enumerate(f.readlines()): 33 | label = x.split(' ')[0].split('/')[-2] 34 | if label not in label_list: 35 | label_list.append(str(label)) 36 | return label_list 37 | 38 | 39 | class Imagelists_VISDA(object): 40 | def __init__(self, image_list, root="./data/multi/", 41 | transform=None, target_transform=None, test=False): 42 | imgs, labels = make_dataset_fromlist(image_list) 43 | self.imgs = imgs 44 | self.labels = labels 45 | self.transform = transform 46 | self.target_transform = target_transform 47 | self.loader = pil_loader 48 | self.root = root 49 | self.test = test 50 | 51 | def __getitem__(self, index): 52 | """ 53 | Args: 54 | index (int): Index 55 | Returns: 56 | tuple: (image, target) where target is 57 | class_index of the target class. 58 | """ 59 | path = os.path.join(self.root, self.imgs[index]) 60 | target = self.labels[index] 61 | img = self.loader(path) 62 | if self.transform is not None: 63 | img = self.transform(img) 64 | if self.target_transform is not None: 65 | target = self.target_transform(target) 66 | if not self.test: 67 | return img, target 68 | else: 69 | return img, target, self.imgs[index] 70 | 71 | def __len__(self): 72 | return len(self.imgs) 73 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import os 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.optim as optim 8 | from torch.autograd import Variable 9 | from model.resnet import resnet34 10 | from model.basenet import AlexNetBase, Predictor_latent, Predictor_deep_latent, grad_reverse 11 | from utils.lr_schedule import inv_lr_scheduler 12 | from utils.return_dataset import return_dataset 13 | import torch.nn.functional as F 14 | import metric.mmd as mmd 15 | from perturb import PerturbationGenerator 16 | from utils.utils import weights_init, group_step 17 | 18 | # Training settings 19 | parser = argparse.ArgumentParser(description='SSDA Classification') 20 | parser.add_argument('--steps', type=int, default=50000, metavar='N', 21 | help='maximum number of iterations ' 22 | 'to train (default: 50000)') 23 | parser.add_argument('--lr', type=float, default=0.01, metavar='LR', 24 | help='learning rate (default: 0.001)') 25 | parser.add_argument('--multi', type=float, default=0.1, metavar='MLT', 26 | help='learning rate multiplication') 27 | parser.add_argument('--T', type=float, default=0.05, metavar='T', 28 | help='temperature (default: 0.05)') 29 | parser.add_argument('--save_check', action='store_true', default=False, 30 | help='save checkpoint or not') 31 | parser.add_argument('--checkpath', type=str, default='./checkpoints', 32 | help='dir to save checkpoint') 33 | parser.add_argument('--save_interval', type=int, default=5000, metavar='N', 34 | help='how many batches to wait before saving a model') 35 | parser.add_argument('--net', type=str, default='alexnet', 36 | help='which network to use') 37 | parser.add_argument('--source', type=str, default='real', 38 | help='source domain') 39 | parser.add_argument('--target', type=str, default='sketch', 40 | help='target domain') 41 | parser.add_argument('--dataset', type=str, default='multi', 42 | choices=['multi', 'office', 'office_home'], 43 | help='the name of dataset') 44 | parser.add_argument('--num', type=int, default=3, 45 | help='number of labeled examples in the target') 46 | parser.add_argument('--thr', type=float, default=0.5, 47 | help='threshold for exploration scheme') 48 | 49 | args = parser.parse_args() 50 | source_loader, target_loader, target_loader_unl, target_loader_val, \ 51 | target_loader_test, class_list = return_dataset(args) 52 | use_gpu = torch.cuda.is_available() 53 | 54 | 55 | if args.net == 'resnet34': 56 | G = resnet34() 57 | inc = 512 58 | elif args.net == "alexnet": 59 | G = AlexNetBase() 60 | inc = 4096 61 | else: 62 | raise ValueError('Model cannot be recognized.') 63 | 64 | params = [] 65 | for key, value in dict(G.named_parameters()).items(): 66 | if value.requires_grad: 67 | if 'classifier' not in key: 68 | params += [{'params': [value], 'lr': args.multi, 69 | 'weight_decay': 0.0005}] 70 | else: 71 | params += [{'params': [value], 'lr': args.multi * 10, 72 | 'weight_decay': 0.0005}] 73 | 74 | if "resnet" in args.net: 75 | F1 = Predictor_deep_latent(num_class=len(class_list), inc=inc) 76 | else: 77 | F1 = Predictor_latent(num_class=len(class_list), inc=inc, temp=args.T) 78 | weights_init(F1) 79 | lr = args.lr 80 | G = torch.nn.DataParallel(G).cuda() 81 | F1 = torch.nn.DataParallel(F1).cuda() 82 | 83 | 84 | if os.path.exists(args.checkpath) == False: 85 | os.mkdir(args.checkpath) 86 | 87 | 88 | def train(): 89 | G.train() 90 | F1.train() 91 | optimizer_g = optim.SGD(params, momentum=0.9, 92 | weight_decay=0.0005, nesterov=True) 93 | optimizer_f = optim.SGD(list(F1.parameters()), 94 | lr=1.0, momentum=0.9, weight_decay=0.0005, nesterov=True) 95 | param_lr_g = [] 96 | for param_group in optimizer_g.param_groups: 97 | param_lr_g.append(param_group["lr"]) 98 | param_lr_f = [] 99 | for param_group in optimizer_f.param_groups: 100 | param_lr_f.append(param_group["lr"]) 101 | 102 | 103 | ################################################################################################################ 104 | ################################################# train model ################################################## 105 | ################################################################################################################ 106 | 107 | def zero_grad_all(): 108 | optimizer_g.zero_grad() 109 | optimizer_f.zero_grad() 110 | 111 | class AbstractConsistencyLoss(nn.Module): 112 | def __init__(self, reduction='mean'): 113 | super().__init__() 114 | self.reduction = reduction 115 | 116 | def forward(self, logits1, logits2): 117 | raise NotImplementedError 118 | 119 | class KLDivLossWithLogits(AbstractConsistencyLoss): 120 | def __init__(self, reduction='mean'): 121 | super().__init__(reduction) 122 | self.kl_div_loss = nn.KLDivLoss(reduction=reduction) 123 | 124 | def forward(self, logits1, logits2): 125 | return self.kl_div_loss(F.log_softmax(logits1, dim=1), F.softmax(logits2, dim=1)) 126 | 127 | class EntropyLoss(nn.Module): 128 | def __init__(self, reduction='mean'): 129 | super().__init__() 130 | self.reduction = reduction 131 | 132 | def forward(self, logits): 133 | p = F.softmax(logits, dim=1) 134 | elementwise_entropy = -p * F.log_softmax(logits, dim=1) 135 | if self.reduction == 'none': 136 | return elementwise_entropy 137 | 138 | sum_entropy = torch.sum(elementwise_entropy, dim=1) 139 | if self.reduction == 'sum': 140 | return sum_entropy 141 | 142 | return torch.mean(sum_entropy) 143 | 144 | 145 | P = PerturbationGenerator(G, F1, xi=1, eps=25, ip=1) 146 | criterion = nn.CrossEntropyLoss().cuda() 147 | criterion_reduce = nn.CrossEntropyLoss(reduce=False).cuda() 148 | target_consistency_criterion = KLDivLossWithLogits(reduction='mean').cuda() 149 | criterion_entropy = EntropyLoss() 150 | 151 | all_step = args.steps 152 | data_iter_s = iter(source_loader) 153 | data_iter_t = iter(target_loader) 154 | data_iter_t_unl = iter(target_loader_unl) 155 | len_train_source = len(source_loader) 156 | len_train_target = len(target_loader) 157 | len_train_target_semi = len(target_loader_unl) 158 | best_acc = 0 159 | counter = 0 160 | if args.net == 'resnet34': 161 | thr = 0.5 162 | else: 163 | thr = 0.3 164 | 165 | for step in range(all_step): 166 | optimizer_g = inv_lr_scheduler(param_lr_g, optimizer_g, step, 167 | init_lr=args.lr) 168 | optimizer_f = inv_lr_scheduler(param_lr_f, optimizer_f, step, 169 | init_lr=args.lr) 170 | lr = optimizer_f.param_groups[0]['lr'] 171 | 172 | if step % len_train_target == 0: 173 | data_iter_t = iter(target_loader) 174 | if step % len_train_target_semi == 0: 175 | data_iter_t_unl = iter(target_loader_unl) 176 | if step % len_train_source == 0: 177 | data_iter_s = iter(source_loader) 178 | 179 | 180 | data_t = next(data_iter_t) 181 | data_t_unl = next(data_iter_t_unl) 182 | data_s = next(data_iter_s) 183 | 184 | im_data_s = Variable(data_s[0].cuda()) 185 | gt_labels_s = Variable(data_s[1].cuda()) 186 | im_data_t = Variable(data_t[0].cuda()) 187 | gt_labels_t = Variable(data_t[1].cuda()) 188 | im_data_tu = Variable(data_t_unl[0].cuda()) 189 | gt_labels_tu = Variable(data_t_unl[1].cuda()) 190 | gt_labels = torch.cat((gt_labels_s, gt_labels_t), 0) 191 | gt_dom_s = Variable(torch.zeros(im_data_s.size(0)).cuda().long()) 192 | gt_dom_t = Variable(torch.ones(im_data_t.size(0)).cuda().long()) 193 | gt_dom = torch.cat((gt_dom_s, gt_dom_t)) 194 | zero_grad_all() 195 | 196 | ################################################################################################################ 197 | ################################################# train model ################################################## 198 | ################################################################################################################ 199 | data = torch.cat((im_data_s, im_data_t), 0) 200 | target = torch.cat((gt_labels_s, gt_labels_t), 0) 201 | sigma = [1, 2, 5, 10] 202 | 203 | output = G(data) 204 | output_tu = G(im_data_tu) 205 | latent_F1, out1 = F1(output) 206 | latent_F1_tu, out_F1_tu = F1(output_tu) 207 | 208 | # supervision loss 209 | loss = criterion(out1, target) 210 | 211 | # attraction scheme 212 | loss_msda = 10 * mmd.mix_rbf_mmd2(latent_F1, latent_F1_tu, sigma) 213 | 214 | # exploration scheme 215 | pred = out_F1_tu.data.max(1)[1].detach() 216 | ent = - torch.sum(F.softmax(out_F1_tu, 1) * (torch.log(F.softmax(out_F1_tu, 1) + 1e-5)), 1) 217 | mask_reliable = (ent < thr).float().detach() 218 | loss_cls_F1 = (mask_reliable * criterion_reduce(out_F1_tu, pred)).sum(0) / (1e-5 + mask_reliable.sum()) 219 | 220 | (loss + loss_cls_F1 + loss_msda).backward(retain_graph=False) 221 | group_step([optimizer_g, optimizer_f]) 222 | zero_grad_all() 223 | if step % 20 == 0: 224 | print('step %d' % step, 'loss_cls: {:.4f}'.format(loss.cpu().data), ' | ', 'loss_Attract: {:.4f}'.format(loss_msda.cpu().data), ' | ', \ 225 | 'loss_Explore: {:.4f}'.format(loss_cls_F1.cpu().item()), end=' | ') 226 | 227 | # perturbation scheme 228 | bs = gt_labels_s.size(0) 229 | target_data = torch.cat((im_data_t, im_data_tu), 0) 230 | perturb, clean_vat_logits = P(target_data) 231 | perturb_inputs = target_data + perturb 232 | perturb_inputs = torch.cat(perturb_inputs.split(bs), 0) 233 | perturb_features = G(perturb_inputs) 234 | perturb_logits = F1(perturb_features)[0] 235 | target_vat_loss2 = 10 * target_consistency_criterion(perturb_logits, clean_vat_logits) 236 | 237 | target_vat_loss2.backward() 238 | group_step([optimizer_g, optimizer_f]) 239 | zero_grad_all() 240 | 241 | 242 | if step % 20 == 0: 243 | print('loss_Perturb: {:.4f}'.format(target_vat_loss2.cpu().data)) 244 | G.zero_grad() 245 | F1.zero_grad() 246 | zero_grad_all() 247 | 248 | if step % args.save_interval == 0 and step > 0: 249 | loss_test, acc_test = test(G, F1, target_loader_test) 250 | loss_val, acc_val = test(G, F1, target_loader_val) 251 | G.train() 252 | F1.train() 253 | 254 | if args.save_check: 255 | print('saving model') 256 | torch.save(G.state_dict(), 257 | os.path.join(args.checkpath, 258 | "G_{}_{}_" 259 | "to_{}_step_{}.pth.tar". 260 | format(args.dataset, args.source, 261 | args.target, step))) 262 | torch.save(F1.state_dict(), 263 | os.path.join(args.checkpath, 264 | "F1_{}_{}_" 265 | "to_{}_step_{}.pth.tar". 266 | format(args.dataset, args.source, 267 | args.target, step))) 268 | 269 | 270 | def test(base, classifier, loader): 271 | base.eval() 272 | classifier.eval() 273 | test_loss = 0 274 | correct = 0 275 | size = 0 276 | num_class = len(class_list) 277 | criterion = nn.CrossEntropyLoss().cuda() 278 | confusion_matrix = torch.zeros(num_class, num_class) 279 | with torch.no_grad(): 280 | for batch_idx, data_t in enumerate(loader): 281 | im_data_t = Variable(data_t[0].cuda()) 282 | gt_labels_t = Variable(data_t[1].cuda()) 283 | feat = base(im_data_t) 284 | _, output1 = classifier(feat) 285 | size += im_data_t.size(0) 286 | pred1 = output1.data.max(1)[1] 287 | for t, p in zip(gt_labels_t.view(-1), pred1.view(-1)): 288 | confusion_matrix[t.long(), p.long()] += 1 289 | correct += pred1.eq(gt_labels_t.data).cpu().sum() 290 | test_loss += criterion(output1, gt_labels_t) / len(loader) 291 | print('Test set: Average loss: {:.4f}, ' 292 | 'Accuracy: {}/{} F1 ({:.0f}%)'. 293 | format(test_loss, correct, size, 294 | 100. * correct / size)) 295 | return test_loss.data, 100. * float(correct) / size 296 | 297 | 298 | train() 299 | 300 | 301 | -------------------------------------------------------------------------------- /metric/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TKKim93/APE/28b0cb8353b06b818b9db00209c403876e84dfd2/metric/__init__.py -------------------------------------------------------------------------------- /metric/mmd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | min_var_est = 1e-8 4 | 5 | def _mix_rbf_kernel(X, Y, sigma_list): 6 | assert(X.size(0) == Y.size(0)) 7 | m = X.size(0) 8 | 9 | Z = torch.cat((X, Y), 0) 10 | ZZT = torch.mm(Z, Z.t()) 11 | diag_ZZT = torch.diag(ZZT).unsqueeze(1) 12 | Z_norm_sqr = diag_ZZT.expand_as(ZZT) 13 | exponent = Z_norm_sqr - 2 * ZZT + Z_norm_sqr.t() 14 | 15 | K = 0.0 16 | for sigma in sigma_list: 17 | gamma = 1.0 / (2 * sigma**2) 18 | K += torch.exp(-gamma * exponent) 19 | 20 | return K[:m, :m], K[:m, m:], K[m:, m:], len(sigma_list) 21 | 22 | 23 | def mix_rbf_mmd2(X, Y, sigma_list, biased=True): 24 | K_XX, K_XY, K_YY, d = _mix_rbf_kernel(X, Y, sigma_list) 25 | # return _mmd2(K_XX, K_XY, K_YY, const_diagonal=d, biased=biased) 26 | return _mmd2(K_XX, K_XY, K_YY, const_diagonal=False, biased=biased) 27 | 28 | def _mmd2(K_XX, K_XY, K_YY, const_diagonal=False, biased=False): 29 | m = K_XX.size(0) # assume X, Y are same shape 30 | 31 | # Get the various sums of kernels that we'll use 32 | # Kts drop the diagonal, but we don't need to compute them explicitly 33 | if const_diagonal is not False: 34 | diag_X = diag_Y = const_diagonal 35 | sum_diag_X = sum_diag_Y = m * const_diagonal 36 | else: 37 | diag_X = torch.diag(K_XX) # (m,) 38 | diag_Y = torch.diag(K_YY) # (m,) 39 | sum_diag_X = torch.sum(diag_X) 40 | sum_diag_Y = torch.sum(diag_Y) 41 | 42 | Kt_XX_sums = K_XX.sum(dim=1) - diag_X # \tilde{K}_XX * e = K_XX * e - diag_X 43 | Kt_YY_sums = K_YY.sum(dim=1) - diag_Y # \tilde{K}_YY * e = K_YY * e - diag_Y 44 | K_XY_sums_0 = K_XY.sum(dim=0) # K_{XY}^T * e 45 | 46 | Kt_XX_sum = Kt_XX_sums.sum() # e^T * \tilde{K}_XX * e 47 | Kt_YY_sum = Kt_YY_sums.sum() # e^T * \tilde{K}_YY * e 48 | K_XY_sum = K_XY_sums_0.sum() # e^T * K_{XY} * e 49 | 50 | if biased: 51 | mmd2 = ((Kt_XX_sum + sum_diag_X) / (m * m) 52 | + (Kt_YY_sum + sum_diag_Y) / (m * m) 53 | - 2.0 * K_XY_sum / (m * m)) 54 | else: 55 | mmd2 = (Kt_XX_sum / (m * (m - 1)) 56 | + Kt_YY_sum / (m * (m - 1)) 57 | - 2.0 * K_XY_sum / (m * m)) 58 | 59 | return mmd2 60 | 61 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /model/basenet.py: -------------------------------------------------------------------------------- 1 | from torchvision import models 2 | import torch.nn.functional as F 3 | import torch 4 | import torch.nn as nn 5 | from torch.autograd import Function 6 | 7 | 8 | class GradReverse(Function): 9 | def __init__(self, lambd): 10 | self.lambd = lambd 11 | 12 | def forward(self, x): 13 | return x.view_as(x) 14 | 15 | def backward(self, grad_output): 16 | return (grad_output * -self.lambd) 17 | 18 | 19 | def grad_reverse(x, lambd=1.0): 20 | return GradReverse(lambd)(x) 21 | 22 | 23 | def l2_norm(input): 24 | input_size = input.size() 25 | buffer = torch.pow(input, 2) 26 | 27 | normp = torch.sum(buffer, 1).add_(1e-10) 28 | norm = torch.sqrt(normp) 29 | 30 | _output = torch.div(input, norm.view(-1, 1).expand_as(input)) 31 | 32 | output = _output.view(input_size) 33 | 34 | return output 35 | 36 | 37 | class AlexNetBase(nn.Module): 38 | def __init__(self, pret=True): 39 | super(AlexNetBase, self).__init__() 40 | model_alexnet = models.alexnet(pretrained=pret) 41 | self.features = nn.Sequential(*list(model_alexnet. 42 | features._modules.values())[:]) 43 | self.classifier = nn.Sequential() 44 | for i in range(6): 45 | self.classifier.add_module("classifier" + str(i), 46 | model_alexnet.classifier[i]) 47 | self.__in_features = model_alexnet.classifier[6].in_features 48 | 49 | def forward(self, x): 50 | x = self.features(x) 51 | x = x.view(x.size(0), 256 * 6 * 6) 52 | x = self.classifier(x) 53 | return x 54 | 55 | def output_num(self): 56 | return self.__in_features 57 | 58 | 59 | class Predictor_latent(nn.Module): 60 | def __init__(self, num_class=64, inc=4096, temp=0.05): 61 | super(Predictor_latent, self).__init__() 62 | self.fc = nn.Linear(inc, num_class, bias=False) 63 | self.num_class = num_class 64 | self.temp = temp 65 | 66 | def forward(self, x, reverse=False, eta=0.1): 67 | # x = self.bn(x) 68 | if reverse: 69 | x = grad_reverse(x, eta) 70 | x = F.normalize(x) 71 | x_out = self.fc(x) / self.temp 72 | return x, x_out 73 | 74 | 75 | class Predictor_deep_latent(nn.Module): 76 | def __init__(self, num_class=64, inc=4096, temp=0.05): 77 | super(Predictor_deep_latent, self).__init__() 78 | self.fc1 = nn.Linear(inc, 512) 79 | self.fc2 = nn.Linear(512, num_class, bias=False) 80 | self.num_class = num_class 81 | self.temp = temp 82 | 83 | def forward(self, x, reverse=False, eta=0.1): 84 | x = self.fc1(x) 85 | if reverse: 86 | x = grad_reverse(x, eta) 87 | x = F.normalize(x) 88 | x_out = self.fc2(x) / self.temp 89 | return x, x_out 90 | 91 | 92 | -------------------------------------------------------------------------------- /model/resnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from torch.autograd import Function 6 | import torch 7 | import torch.nn as nn 8 | import math 9 | import torch.utils.model_zoo as model_zoo 10 | 11 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 12 | 'resnet152'] 13 | 14 | model_urls = { 15 | 'resnet18': 'https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth', 16 | 'resnet34': 'https://s3.amazonaws.com/pytorch/models/resnet34-333f7ec4.pth', 17 | 'resnet50': 'https://s3.amazonaws.com/pytorch/models/resnet50-19c8e357.pth', 18 | 'resnet101': 19 | 'https://s3.amazonaws.com/pytorch/models/resnet101-5d3b4d8f.pth', 20 | 'resnet152': 21 | 'https://s3.amazonaws.com/pytorch/models/resnet152-b121ed2d.pth', 22 | } 23 | 24 | 25 | def init_weights(m): 26 | classname = m.__class__.__name__ 27 | if classname.find('Conv2d') != -1 or \ 28 | classname.find('ConvTranspose2d') != -1: 29 | nn.init.kaiming_uniform_(m.weight) 30 | nn.init.zeros_(m.bias) 31 | elif classname.find('BatchNorm') != -1: 32 | nn.init.normal_(m.weight, 1.0, 0.02) 33 | nn.init.zeros_(m.bias) 34 | elif classname.find('Linear') != -1: 35 | nn.init.xavier_normal_(m.weight) 36 | 37 | 38 | def conv3x3(in_planes, out_planes, stride=1): 39 | "3x3 convolution with padding" 40 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 41 | padding=1, bias=False) 42 | 43 | 44 | class GradReverse(Function): 45 | def __init__(self, lambd): 46 | self.lambd = lambd 47 | 48 | def forward(self, x): 49 | return x.view_as(x) 50 | 51 | def backward(self, grad_output): 52 | return (grad_output * -self.lambd) 53 | 54 | 55 | def grad_reverse(x, lambd=1.0): 56 | return GradReverse(lambd)(x) 57 | 58 | 59 | class BasicBlock(nn.Module): 60 | expansion = 1 61 | 62 | def __init__(self, inplanes, planes, stride=1, downsample=None, nobn=False): 63 | super(BasicBlock, self).__init__() 64 | self.conv1 = conv3x3(inplanes, planes, stride) 65 | self.bn1 = nn.BatchNorm2d(planes) 66 | self.relu = nn.ReLU(inplace=True) 67 | self.conv2 = conv3x3(planes, planes) 68 | self.bn2 = nn.BatchNorm2d(planes) 69 | self.downsample = downsample 70 | self.stride = stride 71 | self.nobn = nobn 72 | 73 | def forward(self, x, source=True): 74 | 75 | residual = x 76 | out = self.conv1(x) 77 | out = self.bn1(out) 78 | out = self.relu(out) 79 | 80 | out = self.conv2(out) 81 | out = self.bn2(out) 82 | 83 | if self.downsample is not None: 84 | residual = self.downsample(x) 85 | out += residual 86 | out = self.relu(out) 87 | 88 | return out 89 | 90 | 91 | class ScaleLayer(nn.Module): 92 | def __init__(self, init_value=1e-3): 93 | super(ScaleLayer, self).__init__() 94 | self.scale = nn.Parameter(torch.FloatTensor([init_value])) 95 | 96 | def forward(self, input): 97 | print(self.scale) 98 | return input * self.scale 99 | 100 | 101 | class Bottleneck(nn.Module): 102 | expansion = 4 103 | 104 | def __init__(self, inplanes, planes, stride=1, downsample=None, nobn=False): 105 | super(Bottleneck, self).__init__() 106 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, 107 | stride=stride, bias=False) 108 | self.bn1 = nn.BatchNorm2d(planes) 109 | 110 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 111 | stride=1, padding=1, bias=False) 112 | self.bn2 = nn.BatchNorm2d(planes) 113 | 114 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 115 | self.bn3 = nn.BatchNorm2d(planes * 4) 116 | 117 | self.relu = nn.ReLU(inplace=True) 118 | self.downsample = downsample 119 | 120 | self.stride = stride 121 | self.nobn = nobn 122 | 123 | def forward(self, x): 124 | residual = x 125 | out = self.conv1(x) 126 | out = self.bn1(out) 127 | out = self.relu(out) 128 | out = self.conv2(out) 129 | out = self.bn2(out) 130 | out = self.relu(out) 131 | out = self.conv3(out) 132 | out = self.bn3(out) 133 | if self.downsample is not None: 134 | residual = self.downsample(x) 135 | out += residual 136 | out = self.relu(out) 137 | 138 | return out 139 | 140 | 141 | class ResNet(nn.Module): 142 | def __init__(self, block, layers, num_classes=1000): 143 | self.inplanes = 64 144 | super(ResNet, self).__init__() 145 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 146 | bias=False) 147 | self.bn1 = nn.BatchNorm2d(64) 148 | self.in1 = nn.InstanceNorm2d(64) 149 | self.in2 = nn.InstanceNorm2d(128) 150 | self.relu = nn.ReLU(inplace=True) 151 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, 152 | padding=0, ceil_mode=True) 153 | self.layer1 = self._make_layer(block, 64, layers[0]) 154 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 155 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 156 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 157 | self.avgpool = nn.AvgPool2d(7) 158 | self.fc = nn.Linear(512 * block.expansion, num_classes) 159 | 160 | for m in self.modules(): 161 | if isinstance(m, nn.Conv2d): 162 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 163 | m.weight.data.normal_(0, math.sqrt(2. / n)) 164 | elif isinstance(m, nn.BatchNorm2d): 165 | m.weight.data.fill_(1) 166 | m.bias.data.zero_() 167 | 168 | def stash_grad(self, grad_dict): 169 | for k, v in self.named_parameters(): 170 | if k in grad_dict: 171 | grad_dict[k] += v.grad.clone() 172 | else: 173 | grad_dict[k] = v.grad.clone() 174 | self.zero_grad() 175 | return grad_dict 176 | 177 | def restore_grad(self, grad_dict): 178 | for k, v in self.named_parameters(): 179 | grad = grad_dict[k] if k in grad_dict else torch.zeros_like(v.grad) 180 | 181 | if v.grad is None: 182 | v.grad = grad 183 | else: 184 | v.grad += grad 185 | 186 | def _make_layer(self, block, planes, blocks, stride=1, nobn=False): 187 | downsample = None 188 | if stride != 1 or self.inplanes != planes * block.expansion: 189 | downsample = nn.Sequential( 190 | nn.Conv2d(self.inplanes, planes * block.expansion, 191 | kernel_size=1, stride=stride, bias=False), 192 | nn.BatchNorm2d(planes * block.expansion), 193 | ) 194 | 195 | layers = [] 196 | layers.append(block(self.inplanes, planes, stride, downsample)) 197 | self.inplanes = planes * block.expansion 198 | 199 | for i in range(1, blocks): 200 | layers.append(block(self.inplanes, planes, nobn=nobn)) 201 | return nn.Sequential(*layers) 202 | 203 | def forward(self, x): 204 | x = self.conv1(x) 205 | x = self.bn1(x) 206 | x = self.relu(x) 207 | x = self.maxpool(x) 208 | x = self.layer1(x) 209 | x = self.layer2(x) 210 | x = self.layer3(x) 211 | x = self.layer4(x) 212 | x = self.avgpool(x) 213 | x = x.view(x.size(0), -1) 214 | return x 215 | 216 | 217 | def resnet18(pretrained=True): 218 | """Constructs a ResNet-18 model. 219 | Args: 220 | pretrained (bool): If True, returns a model pre-trained on ImageNet 221 | """ 222 | model = ResNet(BasicBlock, [2, 2, 2, 2]) 223 | if pretrained: 224 | pretrained_dict = model_zoo.load_url(model_urls['resnet18']) 225 | model_dict = model.state_dict() 226 | # 1. filter out unnecessary keys 227 | pretrained_dict = {k: v for k, v in pretrained_dict.items() 228 | if k in model_dict} 229 | # 2. overwrite entries in the existing state dict 230 | model_dict.update(pretrained_dict) 231 | # 3. load the new state dict 232 | model.load_state_dict(model_dict) 233 | return model 234 | 235 | 236 | def resnet34(pretrained=True): 237 | """Constructs a ResNet-34 model. 238 | Args: 239 | pretrained (bool): If True, returns a model pre-trained on ImageNet 240 | """ 241 | model = ResNet(BasicBlock, [3, 4, 6, 3]) 242 | if pretrained: 243 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 244 | return model 245 | 246 | 247 | def resnet50(pretrained=True): 248 | """Constructs a ResNet-50 model. 249 | Args: 250 | pretrained (bool): If True, returns a model pre-trained on ImageNet 251 | """ 252 | model = ResNet(Bottleneck, [3, 4, 6, 3]) 253 | if pretrained: 254 | pretrained_dict = model_zoo.load_url(model_urls['resnet50']) 255 | model_dict = model.state_dict() 256 | pretrained_dict = {k: v for k, v in pretrained_dict.items() 257 | if k in model_dict} 258 | model_dict.update(pretrained_dict) 259 | model.load_state_dict(model_dict) 260 | return model 261 | 262 | 263 | def resnet101(pretrained=False): 264 | """Constructs a ResNet-101 model. 265 | Args: 266 | pretrained (bool): If True, returns a model pre-trained on ImageNet 267 | """ 268 | model = ResNet(Bottleneck, [3, 4, 23, 3]) 269 | if pretrained: 270 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 271 | return model 272 | 273 | 274 | def resnet152(pretrained=False): 275 | """Constructs a ResNet-152 model. 276 | Args: 277 | pretrained (bool): If True, returns a model pre-trained on ImageNet 278 | """ 279 | model = ResNet(Bottleneck, [3, 8, 36, 3]) 280 | if pretrained: 281 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 282 | return model 283 | -------------------------------------------------------------------------------- /perturb.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import torch.nn.functional as F 5 | import contextlib 6 | 7 | 8 | def set_requires_grad(model, requires_grad): 9 | for param in model.parameters(): 10 | param.requires_grad = requires_grad 11 | 12 | @contextlib.contextmanager 13 | def disable_tracking_bn_stats(model): 14 | def switch_attr(m): 15 | if hasattr(m, 'track_running_stats'): 16 | m.track_running_stats ^= True 17 | 18 | model.apply(switch_attr) 19 | yield 20 | model.apply(switch_attr) 21 | 22 | def l2_normalize(d): 23 | d_reshaped = d.view(d.size(0), -1, *(1 for _ in range(d.dim() - 2))) 24 | d /= torch.norm(d_reshaped, dim=1, keepdim=True) + 1e-8 25 | return d 26 | 27 | class AbstractConsistencyLoss(nn.Module): 28 | def __init__(self, reduction='mean'): 29 | super().__init__() 30 | self.reduction = reduction 31 | 32 | def forward(self, logits1, logits2): 33 | raise NotImplementedError 34 | 35 | class KLDivLossWithLogits(AbstractConsistencyLoss): 36 | def __init__(self, reduction='mean'): 37 | super().__init__(reduction) 38 | self.kl_div_loss = nn.KLDivLoss(reduction=reduction) 39 | 40 | def forward(self, logits1, logits2): 41 | return self.kl_div_loss(F.log_softmax(logits1, dim=1), F.softmax(logits2, dim=1)) 42 | 43 | 44 | class PerturbationGenerator(nn.Module): 45 | 46 | def __init__(self, feature_extractor, classifier, xi=1e-6, eps=3.5, ip=1): 47 | super().__init__() 48 | self.feature_extractor = feature_extractor 49 | self.classifier = classifier 50 | self.xi = xi 51 | self.eps = eps 52 | self.ip = ip 53 | self.kl_div = KLDivLossWithLogits() 54 | 55 | def forward(self, inputs): 56 | with disable_tracking_bn_stats(self.feature_extractor): 57 | with disable_tracking_bn_stats(self.classifier): 58 | features = self.feature_extractor(inputs) 59 | logits = self.classifier(features)[1].detach() 60 | 61 | # prepare random unit tensor 62 | d = l2_normalize(torch.randn_like(inputs).to(inputs.device)) 63 | 64 | # calc adversarial direction 65 | x_hat = inputs 66 | x_hat = x_hat + self.xi * d 67 | x_hat.requires_grad = True 68 | features_hat = self.feature_extractor(x_hat) 69 | logits_hat = self.classifier(features_hat, reverse=True, eta=1)[1] 70 | prob_hat = F.softmax(logits_hat, 1) 71 | adv_distance = (prob_hat * torch.log(1e-4 + prob_hat)).sum(1).mean() 72 | adv_distance.backward() 73 | d = l2_normalize(x_hat.grad) 74 | self.feature_extractor.zero_grad() 75 | self.classifier.zero_grad() 76 | r_adv = d * self.eps 77 | return r_adv.detach(), features 78 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.0.0 2 | numpy==1.14.3 3 | torchvision==0.2.1 4 | tqdm==4.26.0 5 | Pillow==6.0.0 6 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import os 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.optim as optim 8 | from torch.autograd import Variable 9 | from model.resnet import resnet34 10 | from model.basenet import AlexNetBase, Predictor_latent, Predictor_deep_latent, grad_reverse 11 | from utils.lr_schedule import inv_lr_scheduler 12 | from utils.return_dataset import return_dataset, return_dataset_test 13 | import torch.nn.functional as F 14 | import metric.mmd as mmd 15 | from perturb import PerturbationGenerator 16 | from utils.utils import weights_init, group_step 17 | 18 | # Training settings 19 | parser = argparse.ArgumentParser(description='SSDA Classification') 20 | parser.add_argument('--steps', type=int, default=50000, metavar='N', 21 | help='maximum number of iterations ' 22 | 'to train (default: 50000)') 23 | parser.add_argument('--multi', type=float, default=0.1, metavar='MLT', 24 | help='learning rate multiplication') 25 | parser.add_argument('--T', type=float, default=0.05, metavar='T', 26 | help='temperature (default: 0.05)') 27 | parser.add_argument('--checkpath', type=str, default='./checkpoints', 28 | help='dir to save checkpoint') 29 | parser.add_argument('--net', type=str, default='resnet34', 30 | help='which network to use') 31 | parser.add_argument('--source', type=str, default='real', 32 | help='source domain') 33 | parser.add_argument('--target', type=str, default='sketch', 34 | help='target domain') 35 | parser.add_argument('--dataset', type=str, default='multi', 36 | choices=['multi', 'office', 'office_home'], 37 | help='the name of dataset') 38 | parser.add_argument('--num', type=int, default=3, 39 | help='number of labeled examples in the target') 40 | 41 | args = parser.parse_args() 42 | target_loader_test, class_list = return_dataset_test(args) 43 | use_gpu = torch.cuda.is_available() 44 | 45 | if args.net == 'resnet34': 46 | G = resnet34() 47 | inc = 512 48 | elif args.net == "alexnet": 49 | G = AlexNetBase() 50 | inc = 4096 51 | else: 52 | raise ValueError('Model cannot be recognized.') 53 | 54 | 55 | if "resnet" in args.net: 56 | F1 = Predictor_deep_latent(num_class=len(class_list), inc=inc) 57 | else: 58 | F1 = Predictor_latent(num_class=len(class_list), inc=inc, temp=args.T) 59 | G = torch.nn.DataParallel(G).cuda() 60 | F1 = torch.nn.DataParallel(F1).cuda() 61 | 62 | G_dict = os.path.join(args.checkpath, "G_{}_{}_to_{}_step_{}.pth.tar".format(args.dataset, args.source, args.target, args.steps)) 63 | pretrained_dict = torch.load(G_dict) 64 | model_dict = G.state_dict() 65 | model_dict.update(pretrained_dict) 66 | G.load_state_dict(model_dict) 67 | 68 | F_dict = os.path.join(args.checkpath, "F1_{}_{}_to_{}_step_{}.pth.tar".format(args.dataset, args.source, args.target, args.steps)) 69 | pretrained_dict = torch.load(F_dict) 70 | model_dict = F1.state_dict() 71 | model_dict.update(pretrained_dict) 72 | F1.load_state_dict(model_dict) 73 | 74 | 75 | 76 | def test(base, classifier, loader): 77 | base.eval() 78 | classifier.eval() 79 | test_loss = 0 80 | correct = 0 81 | size = 0 82 | num_class = len(class_list) 83 | criterion = nn.CrossEntropyLoss().cuda() 84 | confusion_matrix = torch.zeros(num_class, num_class) 85 | with torch.no_grad(): 86 | for batch_idx, data_t in enumerate(loader): 87 | im_data_t = Variable(data_t[0].cuda()) 88 | gt_labels_t = Variable(data_t[1].cuda()) 89 | feat = base(im_data_t) 90 | _, output1 = classifier(feat) 91 | size += im_data_t.size(0) 92 | pred1 = output1.data.max(1)[1] 93 | for t, p in zip(gt_labels_t.view(-1), pred1.view(-1)): 94 | confusion_matrix[t.long(), p.long()] += 1 95 | correct += pred1.eq(gt_labels_t.data).cpu().sum() 96 | test_loss += criterion(output1, gt_labels_t) / len(loader) 97 | print('Test set: Average loss: {:.4f}, ' 98 | 'Accuracy: {}/{} F1 ({:.0f}%)'. 99 | format(test_loss, correct, size, 100 | 100. * correct / size)) 101 | return test_loss.data, 100. * float(correct) / size 102 | 103 | loss_test, acc_test = test(G, F1, target_loader_test) 104 | 105 | 106 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /utils/lr_schedule.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def inv_lr_scheduler(param_lr, optimizer, iter_num, gamma=0.0001, 5 | power=0.75, init_lr=0.001): 6 | """Decay learning rate by a factor of 0.1 every lr_decay_epoch epochs.""" 7 | lr = init_lr * (1 + gamma * iter_num) ** (- power) 8 | i = 0 9 | for param_group in optimizer.param_groups: 10 | param_group['lr'] = lr * param_lr[i] 11 | i += 1 12 | return optimizer 13 | 14 | 15 | def calc_coeff(iter_num, high=1.0, low=0.0, alpha=10.0, max_iter=10000.0): 16 | return np.float(2.0 * (high - low) / 17 | (1.0 + np.exp(- alpha * iter_num / max_iter)) - 18 | (high - low) + low) 19 | -------------------------------------------------------------------------------- /utils/return_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torchvision import transforms 4 | from loaders.data_list import Imagelists_VISDA, return_classlist 5 | 6 | 7 | class ResizeImage(): 8 | def __init__(self, size): 9 | if isinstance(size, int): 10 | self.size = (int(size), int(size)) 11 | else: 12 | self.size = size 13 | 14 | def __call__(self, img): 15 | th, tw = self.size 16 | return img.resize((th, tw)) 17 | 18 | 19 | def return_dataset(args): 20 | base_path = './data/txt/%s' % args.dataset 21 | root = './data/%s' % args.dataset 22 | image_set_file_s = \ 23 | os.path.join(base_path, 24 | 'labeled_source_images_' + 25 | args.source + '.txt') 26 | image_set_file_t = \ 27 | os.path.join(base_path, 28 | 'labeled_target_images_' + 29 | args.target + '_%d.txt' % (args.num)) 30 | image_set_file_t_val = \ 31 | os.path.join(base_path, 32 | 'validation_target_images_' + 33 | args.target + '_3.txt') 34 | image_set_file_unl = \ 35 | os.path.join(base_path, 36 | 'unlabeled_target_images_' + 37 | args.target + '_%d.txt' % (args.num)) 38 | 39 | if args.net == 'alexnet': 40 | crop_size = 227 41 | else: 42 | crop_size = 224 43 | data_transforms = { 44 | 'train': transforms.Compose([ 45 | ResizeImage(256), 46 | transforms.RandomHorizontalFlip(), 47 | transforms.RandomCrop(crop_size), 48 | transforms.ToTensor(), 49 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 50 | ]), 51 | 'val': transforms.Compose([ 52 | ResizeImage(256), 53 | transforms.RandomHorizontalFlip(), 54 | transforms.RandomCrop(crop_size), 55 | transforms.ToTensor(), 56 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 57 | ]), 58 | 'test': transforms.Compose([ 59 | ResizeImage(256), 60 | transforms.CenterCrop(crop_size), 61 | transforms.ToTensor(), 62 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 63 | ]), 64 | } 65 | source_dataset = Imagelists_VISDA(image_set_file_s, root=root, 66 | transform=data_transforms['train']) 67 | target_dataset = Imagelists_VISDA(image_set_file_t, root=root, 68 | transform=data_transforms['val']) 69 | target_dataset_val = Imagelists_VISDA(image_set_file_t_val, root=root, 70 | transform=data_transforms['val']) 71 | target_dataset_unl = Imagelists_VISDA(image_set_file_unl, root=root, 72 | transform=data_transforms['val']) 73 | target_dataset_test = Imagelists_VISDA(image_set_file_unl, root=root, 74 | transform=data_transforms['test']) 75 | class_list = return_classlist(image_set_file_s) 76 | print("%d classes in this dataset" % len(class_list)) 77 | if args.net == 'alexnet': 78 | bs = 24 79 | else: 80 | bs = 16 81 | source_loader = torch.utils.data.DataLoader(source_dataset, batch_size=bs, 82 | num_workers=2, shuffle=True, 83 | drop_last=True) 84 | target_loader = \ 85 | torch.utils.data.DataLoader(target_dataset, 86 | batch_size=min(bs, len(target_dataset)), 87 | num_workers=2, 88 | shuffle=True, drop_last=True) 89 | target_loader_val = \ 90 | torch.utils.data.DataLoader(target_dataset_val, 91 | batch_size=min(bs, 92 | len(target_dataset_val)), 93 | num_workers=2, 94 | shuffle=True, drop_last=True) 95 | target_loader_unl = \ 96 | torch.utils.data.DataLoader(target_dataset_unl, 97 | batch_size=bs * 2, num_workers=3, 98 | shuffle=True, drop_last=True) 99 | target_loader_test = \ 100 | torch.utils.data.DataLoader(target_dataset_test, 101 | batch_size=bs * 2, num_workers=2, 102 | shuffle=True, drop_last=True) 103 | return source_loader, target_loader, target_loader_unl, \ 104 | target_loader_val, target_loader_test, class_list 105 | 106 | def return_dataset_test(args): 107 | base_path = './data/txt/%s' % args.dataset 108 | root = './data/%s/' % args.dataset 109 | image_set_file_s = \ 110 | os.path.join(base_path, 111 | 'labeled_source_images_' + 112 | args.source + '.txt') 113 | image_set_file_test = os.path.join(base_path, 114 | 'unlabeled_target_images_' + 115 | args.target + '_%d.txt' % (args.num)) 116 | if args.net == 'alexnet': 117 | crop_size = 227 118 | else: 119 | crop_size = 224 120 | 121 | data_transforms = { 122 | 'test': transforms.Compose([ 123 | ResizeImage(256), 124 | transforms.CenterCrop(crop_size), 125 | transforms.ToTensor(), 126 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 127 | ]), 128 | } 129 | target_dataset_unl = Imagelists_VISDA(image_set_file_test, root=root, 130 | transform=data_transforms['test'], 131 | test=True) 132 | class_list = return_classlist(image_set_file_s) 133 | print("%d classes in this dataset" % len(class_list)) 134 | if args.net == 'alexnet': 135 | bs = 24 136 | else: 137 | bs = 16 138 | target_loader_unl = \ 139 | torch.utils.data.DataLoader(target_dataset_unl, 140 | batch_size=bs * 2, num_workers=3, 141 | shuffle=False, drop_last=False) 142 | return target_loader_unl, class_list 143 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import shutil 5 | 6 | 7 | def weights_init(m): 8 | classname = m.__class__.__name__ 9 | if classname.find('Conv') != -1: 10 | m.weight.data.normal_(0.0, 0.1) 11 | elif classname.find('Linear') != -1: 12 | nn.init.xavier_normal_(m.weight) 13 | nn.init.zeros_(m.bias) 14 | elif classname.find('BatchNorm') != -1: 15 | m.weight.data.normal_(1.0, 0.1) 16 | m.bias.data.fill_(0) 17 | 18 | 19 | def save_checkpoint(state, is_best, checkpoint='checkpoint', 20 | filename='checkpoint.pth.tar'): 21 | filepath = os.path.join(checkpoint, filename) 22 | torch.save(state, filepath) 23 | if is_best: 24 | shutil.copyfile(filepath, os.path.join(checkpoint, 25 | 'model_best.pth.tar')) 26 | 27 | def group_step(step_list): 28 | for i in range(len(step_list)): 29 | step_list[i].step() --------------------------------------------------------------------------------