├── gen_list.py
├── LICENSE.md
├── README.md
├── loss.py
├── data_list.py
├── network.py
├── train_source.py
├── distill.py
└── adapt_multi.py

/gen_list.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | dataset = 'office-home'
4 | 
5 | if dataset == 'office':
6 |     domains = ['amazon', 'dslr', 'webcam']
7 | elif dataset == 'office-caltech':
8 |     domains = ['amazon', 'dslr', 'webcam', 'caltech']
9 | elif dataset == 'office-home':
10 |     domains = ['Art', 'Clipart', 'Product', 'Real_World']
11 | else:
12 |     raise SystemExit('No such dataset exists!')
13 | 
14 | for domain in domains:
15 |     log = open(dataset+'/'+domain+'_list.txt','w')
16 |     directory = os.path.join(dataset, os.path.join(domain,'images'))
17 |     classes = [x[0] for x in os.walk(directory)]
18 |     classes = classes[1:]
19 |     classes.sort()
20 |     for idx,f in enumerate(classes):
21 |         files = os.listdir(f)
22 |         for file in files:
23 |             s = os.path.abspath(os.path.join(f,file)) + ' ' + str(idx) + '\n'
24 |             log.write(s)
25 |     log.close()
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2021 Dripta S. Raychaudhuri
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DECISION
2 | Unsupervised Multi-source Domain Adaptation Without Access to Source Data (CVPR '21 Oral)
3 | 
4 | ### Overview
5 | This repository is a PyTorch implementation of the paper [Unsupervised Multi-source Domain Adaptation Without Access to Source Data](https://arxiv.org/pdf/2104.01845.pdf) published at [CVPR 2021](http://cvpr2021.thecvf.com/). This code is based on the [SHOT](https://github.com/tim-learn/SHOT) repository.
6 | 
7 | ### Dependencies
8 | Create a conda environment with `environment.yml`.
9 | 
10 | ### Dataset
11 | - Manually download the datasets [Office](https://drive.google.com/file/d/0B4IapRTv9pJ1WGZVd1VDMmhwdlE/view), [Office-Home](https://drive.google.com/file/d/0B81rNlvomiwed0V1YUxQdC1uOTg/view), [Office-Caltech](http://www.vision.caltech.edu/Image_Datasets/Caltech256/256_ObjectCategories.tar) from the official websites.
12 | - Move `gen_list.py` inside the data directory.
13 | - Generate a `.txt` list file for each domain using `gen_list.py` (set the `dataset` variable at the top of the file accordingly).
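Each generated `<domain>_list.txt` file contains one image per line, written by `gen_list.py` as an absolute path followed by an integer class index, e.g. (paths illustrative):
```
/home/user/data/office/amazon/images/back_pack/frame_0001.jpg 0
/home/user/data/office/amazon/images/bike/frame_0001.jpg 1
```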
14 | 
15 | ### Training
16 | - Train source models (shown here for Office with source A)
17 | ```
18 | python train_source.py --dset office --s 0 --max_epoch 100 --trte val --gpu_id 0 --output ckps/source/
19 | ```
20 | - Adapt to target (shown here for Office with target D)
21 | ```
22 | python adapt_multi.py --dset office --t 1 --max_epoch 15 --gpu_id 0 --output_src ckps/source/ --output ckps/adapt
23 | ```
24 | - Distill to single target model (shown here for Office with target D)
25 | ```
26 | python distill.py --dset office --t 1 --max_epoch 15 --gpu_id 0 --output_src ckps/adapt --output ckps/dist
27 | ```
28 | 
29 | ### Citation
30 | If you use this code in your research, please consider citing:
31 | ```
32 | @article{ahmed2021unsupervised,
33 |   title={Unsupervised Multi-source Domain Adaptation Without Access to Source Data},
34 |   author={Ahmed, Sk Miraj and Raychaudhuri, Dripta S and Paul, Sujoy and Oymak, Samet and Roy-Chowdhury, Amit K},
35 |   journal={arXiv preprint arXiv:2104.01845},
36 |   year={2021}
37 | }
38 | ```
39 | 
--------------------------------------------------------------------------------
/loss.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | from torch.autograd import Variable
5 | import math
6 | import torch.nn.functional as F
7 | import pdb
8 | 
9 | def Entropy(input_):
10 |     bs = input_.size(0)
11 |     epsilon = 1e-5
12 |     entropy = -input_ * torch.log(input_ + epsilon)
13 |     entropy = torch.sum(entropy, dim=1)
14 |     return entropy
15 | 
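# Usage note (added sketch, not part of the original code): Entropy expects
# softmax probabilities rather than raw logits, as in its callers below:
#   probs = torch.softmax(logits, dim=1)   # (batch_size, num_classes)
#   per_sample_ent = Entropy(probs)        # (batch_size,) entropy in nats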
16 | class CrossEntropyLabelSmooth(nn.Module):
17 |     """Cross entropy loss with label smoothing regularizer.
18 |     Reference:
19 |     Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016.
20 |     Equation: y = (1 - epsilon) * y + epsilon / K.
21 |     Args:
22 |         num_classes (int): number of classes.
23 |         epsilon (float): smoothing weight.
24 |     """
25 | 
26 |     def __init__(self, num_classes, epsilon=0.1, use_gpu=True, reduction=True):
27 |         super(CrossEntropyLabelSmooth, self).__init__()
28 |         self.num_classes = num_classes
29 |         self.epsilon = epsilon
30 |         self.use_gpu = use_gpu
31 |         self.reduction = reduction
32 |         self.logsoftmax = nn.LogSoftmax(dim=1)
33 | 
34 |     def forward(self, inputs, targets):
35 |         """
36 |         Args:
37 |             inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)
38 |             targets: ground truth labels with shape (batch_size)
39 |         """
40 |         log_probs = self.logsoftmax(inputs)
41 |         targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).cpu(), 1)
42 |         if self.use_gpu: targets = targets.cuda()
43 |         targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
44 |         loss = (- targets * log_probs).sum(dim=1)
45 |         if self.reduction:
46 |             return loss.mean()
47 |         else:
48 |             return loss
49 | 
50 | 
51 | 
52 | class softCrossEntropy(nn.Module):
53 |     def __init__(self):
54 |         super(softCrossEntropy, self).__init__()
55 |         return
56 | 
57 |     def forward(self, inputs, target):
58 |         """
59 |         :param inputs: predictions (logits) with shape (batch_size, num_classes)
60 |         :param target: soft target distribution with the same shape
61 |         :return: loss
62 |         """
63 |         log_likelihood = - F.log_softmax(inputs, dim=1)
64 |         sample_num, class_num = target.shape
65 |         loss = torch.sum(torch.mul(log_likelihood, target))/sample_num
66 | 
67 |         return loss
--------------------------------------------------------------------------------
/data_list.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import random
4 | from PIL import Image
5 | from torch.utils.data import Dataset
6 | import os
7 | import os.path
8 | import cv2
9 | import torchvision
10 | 
11 | def make_dataset(image_list, labels):
12 |     if labels:
13 |         len_ = len(image_list)
14 |         images = [(image_list[i].strip(), labels[i, :]) for i in range(len_)]
15 |     else:
16 |         if len(image_list[0].split()) > 2:
17 |             images = [(val.split()[0], np.array([int(la) for la in val.split()[1:]])) for val in image_list]
18 |         else:
19 |             images = [(val.split()[0], int(val.split()[1])) for val in image_list]
20 |     return images
21 | 
22 | 
23 | def rgb_loader(path):
24 |     with open(path, 'rb') as f:
25 |         with Image.open(f) as img:
26 |             return img.convert('RGB')
27 | 
28 | def l_loader(path):
29 |     with open(path, 'rb') as f:
30 |         with Image.open(f) as img:
31 |             return img.convert('L')
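# make_dataset parses each "<path> <label>" line of a generated list file into
# (path, int) tuples; lines with more than two fields are treated as multi-label
# targets. A minimal sketch (illustrative path):
#   imgs = make_dataset(open('office/amazon_list.txt').readlines(), None)
#   # -> [('/abs/path/back_pack/frame_0001.jpg', 0), ...]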
32 | 
33 | class ImageList(Dataset):
34 |     def __init__(self, image_list, labels=None, transform=None, target_transform=None, mode='RGB'):
35 |         imgs = make_dataset(image_list, labels)
36 |         if len(imgs) == 0:
37 |             raise RuntimeError("Found 0 images in the image list; "
38 |                                "check that the '_list.txt' files were generated correctly")
39 | 
40 |         self.imgs = imgs
41 |         self.transform = transform
42 |         self.target_transform = target_transform
43 |         if mode == 'RGB':
44 |             self.loader = rgb_loader
45 |         elif mode == 'L':
46 |             self.loader = l_loader
47 | 
48 |     def __getitem__(self, index):
49 |         path, target = self.imgs[index]
50 |         img = self.loader(path)
51 |         if self.transform is not None:
52 |             img = self.transform(img)
53 |         if self.target_transform is not None:
54 |             target = self.target_transform(target)
55 | 
56 |         return img, target
57 | 
58 |     def __len__(self):
59 |         return len(self.imgs)
60 | 
61 | class ImageList_idx(Dataset):
62 |     def __init__(self, image_list, labels=None, transform=None, target_transform=None, mode='RGB'):
63 |         imgs = make_dataset(image_list, labels)
64 |         if len(imgs) == 0:
65 |             raise RuntimeError("Found 0 images in the image list; "
66 |                                "check that the '_list.txt' files were generated correctly")
67 | 
68 |         self.imgs = imgs
69 |         self.transform = transform
70 |         self.target_transform = target_transform
71 |         if mode == 'RGB':
72 |             self.loader = rgb_loader
73 |         elif mode == 'L':
74 |             self.loader = l_loader
75 | 
76 |     def __getitem__(self, index):
77 |         path, target = self.imgs[index]
78 |         img = self.loader(path)
79 |         if self.transform is not None:
80 |             img = self.transform(img)
81 |         if self.target_transform is not None:
82 |             target = self.target_transform(target)
83 | 
84 |         return img, target, index
85 | 
86 |     def __len__(self):
87 |         return len(self.imgs)
--------------------------------------------------------------------------------
/network.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torchvision
5 | from torchvision import models
6 | from torch.autograd import Variable
7 | import math
8 | import torch.nn.utils.weight_norm as weightNorm
9 | from collections import OrderedDict
10 | 
11 | def calc_coeff(iter_num, high=1.0, low=0.0, alpha=10.0, max_iter=10000.0):
12 |     return float(2.0 * (high - low) / (1.0 + np.exp(-alpha*iter_num / max_iter)) - (high - low) + low)
13 | 
14 | def init_weights(m):
15 |     classname = m.__class__.__name__
16 |     if classname.find('Conv2d') != -1 or classname.find('ConvTranspose2d') != -1:
17 |         nn.init.kaiming_uniform_(m.weight)
18 |         nn.init.zeros_(m.bias)
19 |     elif classname.find('BatchNorm') != -1:
20 |         nn.init.normal_(m.weight, 1.0, 0.02)
21 |         nn.init.zeros_(m.bias)
22 |     elif classname.find('Linear') != -1:
23 |         nn.init.xavier_normal_(m.weight)
24 |         nn.init.zeros_(m.bias)
25 | 
26 | vgg_dict = {"vgg11":models.vgg11, "vgg13":models.vgg13, "vgg16":models.vgg16, "vgg19":models.vgg19,
27 |     "vgg11bn":models.vgg11_bn, "vgg13bn":models.vgg13_bn, "vgg16bn":models.vgg16_bn, "vgg19bn":models.vgg19_bn}
28 | class VGGBase(nn.Module):
29 |     def __init__(self, vgg_name):
30 |         super(VGGBase, self).__init__()
31 |         model_vgg = vgg_dict[vgg_name](pretrained=True)
32 |         self.features = model_vgg.features
33 |         self.classifier = nn.Sequential()
34 |         for i in range(6):
35 |             self.classifier.add_module("classifier"+str(i), model_vgg.classifier[i])
36 |         self.in_features = model_vgg.classifier[6].in_features
37 | 
38 |     def forward(self, x):
39 |         x = self.features(x)
40 |         x = x.view(x.size(0), -1)
41 |         x = self.classifier(x)
42 |         return x
43 | 
44 | res_dict = {"resnet18":models.resnet18, "resnet34":models.resnet34, "resnet50":models.resnet50,
45 |     "resnet101":models.resnet101, "resnet152":models.resnet152, "resnext50":models.resnext50_32x4d, "resnext101":models.resnext101_32x8d}
46 | 
47 | class ResBase(nn.Module):
48 |     def __init__(self, res_name):
49 |         super(ResBase, self).__init__()
50 |         model_resnet = res_dict[res_name](pretrained=True)
51 |         self.conv1 = model_resnet.conv1
52 |         self.bn1 = model_resnet.bn1
53 |         self.relu = model_resnet.relu
54 |         self.maxpool = model_resnet.maxpool
55 |         self.layer1 = model_resnet.layer1
56 |         self.layer2 = model_resnet.layer2
57 |         self.layer3 = model_resnet.layer3
58 |         self.layer4 = model_resnet.layer4
59 |         self.avgpool = model_resnet.avgpool
60 |         self.in_features = model_resnet.fc.in_features
61 | 
62 |     def forward(self, x):
63 |         x = self.conv1(x)
64 |         x = self.bn1(x)
65 |         x = self.relu(x)
66 |         x = self.maxpool(x)
67 |         x = self.layer1(x)
68 |         x = self.layer2(x)
69 |         x = self.layer3(x)
70 |         x = self.layer4(x)
71 |         x = self.avgpool(x)
72 |         x = x.view(x.size(0), -1)
73 |         return x
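# Shape flow of the three-part model used throughout this repo (a sketch; the
# dims assume the default resnet50 backbone and 256-d bottleneck):
#   netF = ResBase('resnet50')                      # (B, 3, 224, 224) -> (B, 2048)
#   netB = feat_bottleneck(2048, 256, type='bn')    # (B, 2048) -> (B, 256)
#   netC = feat_classifier(class_num, 256, 'wn')    # (B, 256) -> (B, class_num)
#   logits = netC(netB(netF(images)))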
74 | 
75 | class feat_bottleneck(nn.Module):
76 |     def __init__(self, feature_dim, bottleneck_dim=256, type="ori"):
77 |         super(feat_bottleneck, self).__init__()
78 |         self.bn = nn.BatchNorm1d(bottleneck_dim, affine=True)
79 |         self.relu = nn.ReLU(inplace=True)
80 |         self.dropout = nn.Dropout(p=0.5)
81 |         self.bottleneck = nn.Linear(feature_dim, bottleneck_dim)
82 |         self.bottleneck.apply(init_weights)
83 |         self.type = type
84 | 
85 |     def forward(self, x):
86 |         x = self.bottleneck(x)
87 |         if self.type == "bn":
88 |             x = self.bn(x)
89 |         return x
90 | 
91 | class feat_classifier(nn.Module):
92 |     def __init__(self, class_num, bottleneck_dim=256, type="linear"):
93 |         super(feat_classifier, self).__init__()
94 |         self.type = type
95 |         if type == 'wn':
96 |             self.fc = weightNorm(nn.Linear(bottleneck_dim, class_num), name="weight")
97 |             self.fc.apply(init_weights)
98 |         else:
99 |             self.fc = nn.Linear(bottleneck_dim, class_num)
100 |             self.fc.apply(init_weights)
101 | 
102 |     def forward(self, x):
103 |         x = self.fc(x)
104 |         return x
105 | 
106 | class feat_classifier_two(nn.Module):
107 |     def __init__(self, class_num, input_dim, bottleneck_dim=256):
108 |         super(feat_classifier_two, self).__init__()
109 |         self.type = type
110 |         self.fc0 = nn.Linear(input_dim, bottleneck_dim)
111 |         self.fc0.apply(init_weights)
112 |         self.fc1 = nn.Linear(bottleneck_dim, class_num)
113 |         self.fc1.apply(init_weights)
114 | 
115 |     def forward(self, x):
116 |         x = self.fc0(x)
117 |         x = self.fc1(x)
118 |         return x
119 | 
120 | class Res50(nn.Module):
121 |     def __init__(self):
122 |         super(Res50, self).__init__()
123 |         model_resnet = models.resnet50(pretrained=True)
124 |         self.conv1 = model_resnet.conv1
125 |         self.bn1 = model_resnet.bn1
126 |         self.relu = model_resnet.relu
127 |         self.maxpool = model_resnet.maxpool
128 |         self.layer1 = model_resnet.layer1
129 |         self.layer2 = model_resnet.layer2
130 |         self.layer3 = model_resnet.layer3
131 |         self.layer4 = model_resnet.layer4
132 |         self.avgpool = model_resnet.avgpool
133 |         self.in_features = model_resnet.fc.in_features
134 |         self.fc = model_resnet.fc
135 | 
136 |     def forward(self, x):
137 |         x = self.conv1(x)
138 |         x = self.bn1(x)
139 |         x = self.relu(x)
140 |         x = self.maxpool(x)
141 |         x = self.layer1(x)
142 |         x = self.layer2(x)
143 |         x = self.layer3(x)
144 |         x = self.layer4(x)
145 |         x = self.avgpool(x)
146 |         x = x.view(x.size(0), -1)
147 |         y = self.fc(x)
148 |         return x, y
149 | 
150 | 
151 | class scalar(nn.Module):
152 |     def __init__(self, init_weights):
153 |         super(scalar, self).__init__()
154 |         self.w = nn.Parameter(torch.tensor(1.)*init_weights)
155 | 
156 |     def forward(self,x):
157 |         x = self.w*torch.ones((x.shape[0]),1).cuda()
158 |         x = torch.sigmoid(x)
159 |         return x
--------------------------------------------------------------------------------
/train_source.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os, sys
3 | import os.path as osp
4 | import torchvision
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | import torch.optim as optim
9 | from torchvision import transforms
10 | import network, loss
11 | from torch.utils.data import DataLoader
12 | from data_list import ImageList
13 | import random, pdb, math, copy
14 | from tqdm import tqdm
15 | from loss import CrossEntropyLabelSmooth
16 | from scipy.spatial.distance import cdist
17 | from sklearn.metrics import confusion_matrix
18 | from sklearn.cluster import KMeans
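# Stage 1 of the DECISION pipeline: train one model per source domain with
# label-smoothed cross entropy. The saved checkpoints (source_F/B/C.pt) are the
# inputs to adapt_multi.py (stage 2), whose outputs in turn feed distill.py (stage 3).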
19 | 
20 | def op_copy(optimizer):
21 |     for param_group in optimizer.param_groups:
22 |         param_group['lr0'] = param_group['lr']
23 |     return optimizer
24 | 
25 | def lr_scheduler(optimizer, iter_num, max_iter, gamma=10, power=0.75):
26 |     decay = (1 + gamma * iter_num / max_iter) ** (-power)
27 |     for param_group in optimizer.param_groups:
28 |         param_group['lr'] = param_group['lr0'] * decay
29 |         param_group['weight_decay'] = 1e-3
30 |         param_group['momentum'] = 0.9
31 |         param_group['nesterov'] = True
32 |     return optimizer
33 | 
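# The schedule above anneals each group's lr from its initial lr0 down to about
# lr0/6: decay = (1 + 10*p)^(-0.75) with p = iter_num/max_iter, so decay(0) = 1
# and decay(1) = 11^(-0.75) ≈ 0.166.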
34 | def image_train(resize_size=256, crop_size=224, alexnet=False):
35 |     if not alexnet:
36 |         normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
37 |                                          std=[0.229, 0.224, 0.225])
38 |     else:
39 |         raise NotImplementedError('AlexNet normalization (ilsvrc_2012_mean.npy) is not included in this repo')
40 |     return transforms.Compose([
41 |         transforms.Resize((resize_size, resize_size)),
42 |         transforms.RandomCrop(crop_size),
43 |         transforms.RandomHorizontalFlip(),
44 |         transforms.ToTensor(),
45 |         normalize
46 |     ])
47 | 
48 | def image_test(resize_size=256, crop_size=224, alexnet=False):
49 |     if not alexnet:
50 |         normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
51 |                                          std=[0.229, 0.224, 0.225])
52 |     else:
53 |         raise NotImplementedError('AlexNet normalization (ilsvrc_2012_mean.npy) is not included in this repo')
54 |     return transforms.Compose([
55 |         transforms.Resize((resize_size, resize_size)),
56 |         transforms.CenterCrop(crop_size),
57 |         transforms.ToTensor(),
58 |         normalize
59 |     ])
60 | 
61 | def data_load(args):
62 |     ## prepare data
63 |     dsets = {}
64 |     dset_loaders = {}
65 |     train_bs = args.batch_size
66 |     txt_src = open(args.s_dset_path).readlines()
67 |     txt_test = open(args.test_dset_path).readlines()
68 | 
69 |     if args.trte == "val":
70 |         dsize = len(txt_src)
71 |         tr_size = int(0.9*dsize)
72 |         # print(dsize, tr_size, dsize - tr_size)
73 |         tr_txt, te_txt = torch.utils.data.random_split(txt_src, [tr_size, dsize - tr_size])
74 |     else:
75 |         dsize = len(txt_src)
76 |         tr_size = int(0.9*dsize)
77 |         _, te_txt = torch.utils.data.random_split(txt_src, [tr_size, dsize - tr_size])
78 |         tr_txt = txt_src
79 | 
80 |     dsets["source_tr"] = ImageList(tr_txt, transform=image_train())
81 |     dset_loaders["source_tr"] = DataLoader(dsets["source_tr"], batch_size=train_bs, shuffle=True, num_workers=args.worker, drop_last=False)
82 |     dsets["source_te"] = ImageList(te_txt, transform=image_test())
83 |     dset_loaders["source_te"] = DataLoader(dsets["source_te"], batch_size=train_bs, shuffle=True, num_workers=args.worker, drop_last=False)
84 |     dsets["test"] = ImageList(txt_test, transform=image_test())
85 |     dset_loaders["test"] = DataLoader(dsets["test"], batch_size=train_bs*2, shuffle=True, num_workers=args.worker, drop_last=False)
86 | 
87 |     return dset_loaders
88 | 
89 | def cal_acc(loader, netF, netB, netC, flag=False):
90 |     start_test = True
91 |     with torch.no_grad():
92 |         iter_test = iter(loader)
93 |         for i in range(len(loader)):
94 |             data = next(iter_test)
95 |             inputs = data[0]
96 |             labels = data[1]
97 |             inputs = inputs.cuda()
98 |             outputs = netC(netB(netF(inputs)))
99 |             if start_test:
100 |                 all_output = outputs.float().cpu()
101 |                 all_label = labels.float()
102 |                 start_test = False
103 |             else:
104 |                 all_output = torch.cat((all_output, outputs.float().cpu()), 0)
105 |                 all_label = torch.cat((all_label, labels.float()), 0)
106 | 
107 |     all_output = nn.Softmax(dim=1)(all_output)
108 |     _, predict = torch.max(all_output, 1)
109 |     accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0])
110 |     mean_ent = torch.mean(loss.Entropy(all_output)).cpu().data.item()
111 | 
112 |     if flag:
113 |         matrix = confusion_matrix(all_label, torch.squeeze(predict).float())
114 |         acc = matrix.diagonal()/matrix.sum(axis=1) * 100
115 |         aacc = acc.mean()
116 |         aa = [str(np.round(i, 2)) for i in acc]
117 |         acc = ' '.join(aa)
118 |         return aacc, acc
119 |     else:
120 |         return accuracy*100, mean_ent
121 | 
122 | def train_source(args):
123 |     dset_loaders = data_load(args)
124 |     ## set base network
125 |     if args.net[0:3] == 'res':
126 |         netF = network.ResBase(res_name=args.net).cuda()
127 |     elif args.net[0:3] == 'vgg':
128 |         netF = network.VGGBase(vgg_name=args.net).cuda()
129 | 
130 |     netB = network.feat_bottleneck(type=args.classifier, feature_dim=netF.in_features, bottleneck_dim=args.bottleneck).cuda()
131 |     netC = network.feat_classifier(type=args.layer, class_num = args.class_num, bottleneck_dim=args.bottleneck).cuda()
132 | 
133 |     param_group = []
134 |     learning_rate = args.lr
135 |     for k, v in netF.named_parameters():
136 |         param_group += [{'params': v, 'lr': learning_rate*0.1}]
137 |     for k, v in netB.named_parameters():
138 |         param_group += [{'params': v, 'lr': learning_rate}]
139 |     for k, v in netC.named_parameters():
140 |         param_group += [{'params': v, 'lr': learning_rate}]
141 |     optimizer = optim.SGD(param_group)
142 |     optimizer = op_copy(optimizer)
143 | 
144 |     acc_init = 0
145 |     max_iter = args.max_epoch * len(dset_loaders["source_tr"])
146 |     interval_iter = max_iter // 10
147 |     iter_num = 0
148 | 
149 |     netF.train()
150 |     netB.train()
151 |     netC.train()
152 | 
153 |     while iter_num < max_iter:
154 |         try:
155 |             inputs_source, labels_source = next(iter_source)
156 |         except (StopIteration, NameError):
157 |             iter_source = iter(dset_loaders["source_tr"])
158 |             inputs_source, labels_source = next(iter_source)
159 | 
160 |         if inputs_source.size(0) == 1:
161 |             continue
162 | 
163 |         iter_num += 1
164 |         lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter)
165 | 
166 |         inputs_source, labels_source = inputs_source.cuda(), labels_source.cuda()
167 |         outputs_source = netC(netB(netF(inputs_source)))
168 |         classifier_loss = CrossEntropyLabelSmooth(num_classes=args.class_num, epsilon=args.smooth)(outputs_source, labels_source)
169 | 
170 |         optimizer.zero_grad()
171 |         classifier_loss.backward()
172 |         optimizer.step()
173 | 
174 |         if iter_num % interval_iter == 0 or iter_num == max_iter:
175 |             netF.eval()
176 |             netB.eval()
177 |             netC.eval()
178 |             acc_s_te, _ = cal_acc(dset_loaders['source_te'], netF, netB, netC, False)
179 |             log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(args.name_src, iter_num, max_iter, acc_s_te)
180 |             args.out_file.write(log_str + '\n')
181 |             args.out_file.flush()
182 |             print(log_str+'\n')
183 | 
184 |             if acc_s_te >= acc_init:
185 |                 acc_init = acc_s_te
186 |                 best_netF = netF.state_dict()
187 |                 best_netB = netB.state_dict()
188 |                 best_netC = netC.state_dict()
189 | 
190 |             netF.train()
191 |             netB.train()
192 |             netC.train()
193 | 
194 |     torch.save(best_netF, osp.join(args.output_dir_src, "source_F.pt"))
195 |     torch.save(best_netB, osp.join(args.output_dir_src, "source_B.pt"))
196 |     torch.save(best_netC, osp.join(args.output_dir_src, "source_C.pt"))
197 | 
198 |     return netF, netB, netC
199 | 
200 | def test_target(args):
201 |     dset_loaders = data_load(args)
202 |     ## set base network
203 |     if args.net[0:3] == 'res':
204 |         netF = network.ResBase(res_name=args.net).cuda()
205 |     elif args.net[0:3] == 'vgg':
206 |         netF = network.VGGBase(vgg_name=args.net).cuda()
207 | 
208 |     netB = network.feat_bottleneck(type=args.classifier,
feature_dim=netF.in_features, bottleneck_dim=args.bottleneck).cuda() 209 | netC = network.feat_classifier(type=args.layer, class_num = args.class_num, bottleneck_dim=args.bottleneck).cuda() 210 | 211 | args.modelpath = args.output_dir_src + '/source_F.pt' 212 | netF.load_state_dict(torch.load(args.modelpath)) 213 | args.modelpath = args.output_dir_src + '/source_B.pt' 214 | netB.load_state_dict(torch.load(args.modelpath)) 215 | args.modelpath = args.output_dir_src + '/source_C.pt' 216 | netC.load_state_dict(torch.load(args.modelpath)) 217 | netF.eval() 218 | netB.eval() 219 | netC.eval() 220 | 221 | acc, _ = cal_acc(dset_loaders['test'], netF, netB, netC, False) 222 | log_str = '\nTraining: {}, Task: {}, Accuracy = {:.2f}%'.format(args.trte, args.name, acc) 223 | 224 | args.out_file.write(log_str) 225 | args.out_file.flush() 226 | print(log_str) 227 | 228 | def print_args(args): 229 | s = "==========================================\n" 230 | for arg, content in args.__dict__.items(): 231 | s += "{}:{}\n".format(arg, content) 232 | return s 233 | 234 | if __name__ == "__main__": 235 | parser = argparse.ArgumentParser(description='SHOT') 236 | parser.add_argument('--gpu_id', type=str, nargs='?', default='0', help="device id to run") 237 | parser.add_argument('--s', type=int, default=0, help="source") 238 | parser.add_argument('--t', type=int, default=1, help="target") 239 | parser.add_argument('--max_epoch', type=int, default=20, help="max iterations") 240 | parser.add_argument('--batch_size', type=int, default=32, help="batch_size") 241 | parser.add_argument('--worker', type=int, default=4, help="number of workers") 242 | parser.add_argument('--dset', type=str, default='office-home', choices=['office', 'office-home', 'office-caltech']) 243 | parser.add_argument('--lr', type=float, default=1e-2, help="learning rate") 244 | parser.add_argument('--net', type=str, default='resnet50', help="vgg16, resnet50, resnet101") 245 | parser.add_argument('--seed', type=int, default=2021, help="random seed") 246 | parser.add_argument('--bottleneck', type=int, default=256) 247 | parser.add_argument('--epsilon', type=float, default=1e-5) 248 | parser.add_argument('--layer', type=str, default="wn", choices=["linear", "wn"]) 249 | parser.add_argument('--classifier', type=str, default="bn", choices=["ori", "bn"]) 250 | parser.add_argument('--smooth', type=float, default=0.1) 251 | parser.add_argument('--output', type=str, default='ckps/source') 252 | parser.add_argument('--trte', type=str, default='val', choices=['full', 'val']) 253 | args = parser.parse_args() 254 | 255 | if args.dset == 'office-home': 256 | names = ['Art', 'Clipart', 'Product', 'Real_World'] 257 | args.class_num = 65 258 | if args.dset == 'office': 259 | names = ['amazon', 'dslr', 'webcam'] 260 | args.class_num = 31 261 | if args.dset == 'office-caltech': 262 | names = ['amazon', 'caltech', 'dslr', 'webcam'] 263 | args.class_num = 10 264 | 265 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id 266 | SEED = args.seed 267 | torch.manual_seed(SEED) 268 | torch.cuda.manual_seed(SEED) 269 | np.random.seed(SEED) 270 | random.seed(SEED) 271 | # torch.backends.cudnn.deterministic = True 272 | 273 | folder = './data/' 274 | args.s_dset_path = folder + args.dset + '/' + names[args.s] + '_list.txt' 275 | args.test_dset_path = folder + args.dset + '/' + names[args.t] + '_list.txt' 276 | 277 | args.output_dir_src = osp.join(args.output, args.dset, names[args.s][0].upper()) 278 | args.name_src = names[args.s][0].upper() 279 | if not 
osp.exists(args.output_dir_src): 280 | os.system('mkdir -p ' + args.output_dir_src) 281 | if not osp.exists(args.output_dir_src): 282 | os.mkdir(args.output_dir_src) 283 | 284 | args.out_file = open(osp.join(args.output_dir_src, 'log.txt'), 'w') 285 | args.out_file.write(print_args(args)+'\n') 286 | args.out_file.flush() 287 | 288 | train_source(args) 289 | 290 | args.out_file = open(osp.join(args.output_dir_src, 'log_test.txt'), 'w') 291 | for i in range(len(names)): 292 | if i == args.s: 293 | continue 294 | args.t = i 295 | args.name = names[args.s][0].upper() + names[args.t][0].upper() 296 | 297 | folder = 'data/' 298 | args.s_dset_path = folder + args.dset + '/' + names[args.s] + '_list.txt' 299 | args.test_dset_path = folder + args.dset + '/' + names[args.t] + '_list.txt' 300 | 301 | test_target(args) 302 | -------------------------------------------------------------------------------- /distill.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os, sys 3 | import os.path as osp 4 | import torchvision 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | from torchvision import transforms 10 | import network, loss 11 | from torch.utils.data import DataLoader 12 | from data_list import ImageList, ImageList_idx 13 | import random, pdb, math, copy 14 | from tqdm import tqdm 15 | from scipy.spatial.distance import cdist 16 | from sklearn.metrics import confusion_matrix 17 | from loss import softCrossEntropy 18 | 19 | 20 | def op_copy(optimizer): 21 | for param_group in optimizer.param_groups: 22 | param_group['lr0'] = param_group['lr'] 23 | return optimizer 24 | 25 | def lr_scheduler(optimizer, iter_num, max_iter, gamma=10, power=0.75): 26 | decay = (1 + gamma * iter_num / max_iter) ** (-power) 27 | for param_group in optimizer.param_groups: 28 | param_group['lr'] = param_group['lr0'] * decay 29 | param_group['weight_decay'] = 1e-3 30 | param_group['momentum'] = 0.9 31 | param_group['nesterov'] = True 32 | return optimizer 33 | 34 | def get_labels(inputs, netF_list, netB_list, netC_list, netG_list): 35 | with torch.no_grad(): 36 | inputs = inputs.cuda() 37 | outputs_all = torch.zeros(len(args.src), inputs.shape[0], args.class_num) 38 | weights_all = torch.ones(inputs.shape[0], len(args.src)) 39 | outputs_all_w = torch.zeros(inputs.shape[0], args.class_num) 40 | 41 | for i in range(len(args.src)): 42 | features = netB_list[i](netF_list[i](inputs)) 43 | outputs = netC_list[i](features) 44 | weights = netG_list[i](features) 45 | outputs_all[i] = outputs 46 | weights_all[:, i] = weights.squeeze() 47 | 48 | z = torch.sum(weights_all, dim=1) 49 | z = z + 1e-16 50 | 51 | weights_all = torch.transpose(torch.transpose(weights_all,0,1)/z,0,1) 52 | # print(weights_all.mean(dim=0)) 53 | outputs_all = torch.transpose(outputs_all, 0, 1) 54 | for i in range(inputs.shape[0]): 55 | outputs_all_w[i] = torch.matmul(torch.transpose(outputs_all[i],0,1), weights_all[i]) 56 | 57 | all_output = outputs_all_w.float().cpu() 58 | 59 | _, predict = torch.max(all_output, 1) 60 | 61 | return predict, all_output 62 | 63 | def data_load(args): 64 | ## prepare data 65 | dsets = {} 66 | dset_loaders = {} 67 | train_bs = args.batch_size 68 | txt_tar = open(args.t_dset_path).readlines() 69 | txt_test = open(args.test_dset_path).readlines() 70 | 71 | dsets["target"] = ImageList_idx(txt_tar, transform=image_train()) 72 | dset_loaders["target"] = DataLoader(dsets["target"], batch_size=train_bs, shuffle=True, 
num_workers=args.worker, drop_last=False)
73 |     dsets['target_'] = ImageList_idx(txt_tar, transform=image_train())
74 |     dset_loaders['target_'] = DataLoader(dsets['target_'], batch_size=train_bs*3, shuffle=False, num_workers=args.worker, drop_last=False)
75 |     dsets["test"] = ImageList_idx(txt_test, transform=image_test())
76 |     dset_loaders["test"] = DataLoader(dsets["test"], batch_size=train_bs*3, shuffle=False, num_workers=args.worker, drop_last=False)
77 |     return dset_loaders
78 | 
79 | def image_train(resize_size=256, crop_size=224, alexnet=False):
80 |     if not alexnet:
81 |         normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
82 |                                          std=[0.229, 0.224, 0.225])
83 |     else:
84 |         raise NotImplementedError('AlexNet normalization (ilsvrc_2012_mean.npy) is not included in this repo')
85 |     return transforms.Compose([
86 |         transforms.Resize((resize_size, resize_size)),
87 |         transforms.RandomCrop(crop_size),
88 |         transforms.RandomHorizontalFlip(),
89 |         transforms.ToTensor(),
90 |         normalize
91 |     ])
92 | 
93 | def image_test(resize_size=256, crop_size=224, alexnet=False):
94 |     if not alexnet:
95 |         normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
96 |                                          std=[0.229, 0.224, 0.225])
97 |     else:
98 |         raise NotImplementedError('AlexNet normalization (ilsvrc_2012_mean.npy) is not included in this repo')
99 |     return transforms.Compose([
100 |         transforms.Resize((resize_size, resize_size)),
101 |         transforms.CenterCrop(crop_size),
102 |         transforms.ToTensor(),
103 |         normalize
104 |     ])
105 | 
106 | def cal_acc(loader, netF, netB, netC, flag=False):
107 |     start_test = True
108 |     with torch.no_grad():
109 |         iter_test = iter(loader)
110 |         for i in range(len(loader)):
111 |             data = next(iter_test)
112 |             inputs = data[0]
113 |             labels = data[1]
114 |             inputs = inputs.cuda()
115 |             outputs = netC(netB(netF(inputs)))
116 |             if start_test:
117 |                 all_output = outputs.float().cpu()
118 |                 all_label = labels.float()
119 |                 start_test = False
120 |             else:
121 |                 all_output = torch.cat((all_output, outputs.float().cpu()), 0)
122 |                 all_label = torch.cat((all_label, labels.float()), 0)
123 |     _, predict = torch.max(all_output, 1)
124 |     accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0])
125 |     mean_ent = torch.mean(loss.Entropy(nn.Softmax(dim=1)(all_output))).cpu().data.item()
126 | 
127 |     if flag:
128 |         matrix = confusion_matrix(all_label, torch.squeeze(predict).float())
129 |         acc = matrix.diagonal()/matrix.sum(axis=1) * 100
130 |         aacc = acc.mean()
131 |         aa = [str(np.round(i, 2)) for i in acc]
132 |         acc = ' '.join(aa)
133 |         return aacc, acc
134 |     else:
135 |         return accuracy*100, mean_ent
136 | 
137 | def train_distill(args):
138 |     dset_loaders = data_load(args)
139 |     # load sources
140 |     if args.net[0:3] == 'res':
141 |         netF_list = [network.ResBase(res_name=args.net).cuda() for i in range(len(args.src))]
142 |         netF = network.ResBase(res_name=args.net).cuda()
143 |     elif args.net[0:3] == 'vgg':
144 |         netF_list = [network.VGGBase(vgg_name=args.net).cuda() for i in range(len(args.src))]
145 |         netF = network.VGGBase(vgg_name=args.net).cuda()
146 | 
147 |     netB_list = [network.feat_bottleneck(type=args.classifier, feature_dim=netF_list[i].in_features, bottleneck_dim=args.bottleneck).cuda() for i in range(len(args.src))]
148 |     netC_list = [network.feat_classifier(type=args.layer, class_num = args.class_num, bottleneck_dim=args.bottleneck).cuda() for i in range(len(args.src))]
149 |     netG_list = [network.scalar(1).cuda() for i in range(len(args.src))]
150 | 
151 |     for i in range(len(args.src)):
152 |         modelpath = args.output_dir_src + '/target_F_'+str(i)+'_par_0.3.pt'
153 |         netF_list[i].load_state_dict(torch.load(modelpath))
154 |         netF_list[i].eval()
155 |         netF_list[i].cuda()
156 |         for k, v in netF_list[i].named_parameters():
157 |             v.requires_grad = False
158 | 
159 |         modelpath = args.output_dir_src + '/target_B_'+str(i)+'_par_0.3.pt'
160 |         netB_list[i].load_state_dict(torch.load(modelpath))
161 |         netB_list[i].eval()
162 |         netB_list[i].cuda()
163 |         for k, v in netB_list[i].named_parameters():
164 |             v.requires_grad = False
165 | 
166 |         modelpath = args.output_dir_src + '/target_C_'+str(i)+'_par_0.3.pt'
167 |         netC_list[i].load_state_dict(torch.load(modelpath))
168 |         netC_list[i].eval()
169 |         netC_list[i].cuda()
170 |         for k, v in netC_list[i].named_parameters():
171 |             v.requires_grad = False
172 | 
173 |         modelpath = args.output_dir_src + '/target_G_'+str(i)+'_par_0.3.pt'
174 |         netG_list[i].load_state_dict(torch.load(modelpath))
175 |         netG_list[i].eval()
176 |         netG_list[i].cuda()
177 |         for k, v in netG_list[i].named_parameters():
178 |             v.requires_grad = False
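# The adapted source models loaded above act as frozen teachers: get_labels()
# mixes their softmax outputs with the learned netG domain weights to produce
# pseudo-labels for the student. Note that softCrossEntropy and --temp are
# imported/declared but unused here; the released code trains the student on
# hard pseudo-labels with nn.CrossEntropyLoss.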
179 | 
180 |     # create student
181 |     netB = network.feat_bottleneck(type=args.classifier, feature_dim=netF.in_features, bottleneck_dim=args.bottleneck).cuda()
182 |     netC = network.feat_classifier(type=args.layer, class_num = args.class_num, bottleneck_dim=args.bottleneck).cuda()
183 | 
184 |     param_group = []
185 |     learning_rate = args.lr
186 |     for k, v in netF.named_parameters():
187 |         param_group += [{'params': v, 'lr': learning_rate}]
188 |     for k, v in netB.named_parameters():
189 |         param_group += [{'params': v, 'lr': learning_rate}]
190 |     for k, v in netC.named_parameters():
191 |         param_group += [{'params': v, 'lr': learning_rate}]
192 |     optimizer = optim.SGD(param_group)
193 |     optimizer = op_copy(optimizer)
194 | 
195 |     acc_init = 0
196 |     max_iter = args.max_epoch * len(dset_loaders["target"])
197 |     interval_iter = max_iter // 10
198 |     iter_num = 0
199 | 
200 |     netF.train()
201 |     netB.train()
202 |     netC.train()
203 | 
204 |     while iter_num < max_iter:
205 |         try:
206 |             inputs = next(iter_source)
207 |         except (StopIteration, NameError):
208 |             iter_source = iter(dset_loaders["target"])
209 |             inputs = next(iter_source)
210 | 
211 |         inputs = inputs[0]
212 |         if inputs.size(0) == 1:
213 |             continue
214 | 
215 |         iter_num += 1
216 |         lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter)
217 | 
218 |         labels, logits = get_labels(inputs, netF_list, netB_list, netC_list, netG_list)
219 | 
220 |         inputs, labels, logits = inputs.cuda(), labels.cuda(), logits.cuda()
221 |         labels, logits = labels.detach(), logits.detach()
222 |         outputs = netC(netB(netF(inputs)))
223 |         classifier_loss = nn.CrossEntropyLoss()(outputs, labels)
224 | 
225 |         optimizer.zero_grad()
226 |         classifier_loss.backward()
227 |         optimizer.step()
228 | 
229 |         if iter_num % interval_iter == 0 or iter_num == max_iter:
230 |             netF.eval()
231 |             netB.eval()
232 |             netC.eval()
233 |             acc_s_te, _ = cal_acc(dset_loaders['test'], netF, netB, netC, False)
234 |             log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(args.tgt, iter_num, max_iter, acc_s_te)
235 |             # args.out_file.write(log_str + '\n')
236 |             # args.out_file.flush()
237 |             print(log_str+'\n')
238 | 
239 |             if acc_s_te >= acc_init:
240 |                 acc_init = acc_s_te
241 |                 best_netF = netF.state_dict()
242 |                 best_netB = netB.state_dict()
243 |                 best_netC = netC.state_dict()
244 | 
245 |             netF.train()
246 |             netB.train()
247 |             netC.train()
248 | 
249 |     torch.save(best_netF, osp.join(args.output_dir_src, "source_F.pt"))
250 |     torch.save(best_netB, osp.join(args.output_dir_src, "source_B.pt"))
251 |     torch.save(best_netC, osp.join(args.output_dir_src,
"source_C.pt")) 252 | 253 | if __name__ == "__main__": 254 | parser = argparse.ArgumentParser(description='SHOT') 255 | parser.add_argument('--gpu_id', type=str, nargs='?', default='0', help="device id to run") 256 | parser.add_argument('--t', type=int, default=0, help="target") ## Choose which domain to set as target {0 to len(names)-1} 257 | parser.add_argument('--max_epoch', type=int, default=15, help="max iterations") 258 | parser.add_argument('--interval', type=int, default=15) 259 | parser.add_argument('--batch_size', type=int, default=32, help="batch_size") 260 | parser.add_argument('--worker', type=int, default=4, help="number of workers") 261 | parser.add_argument('--dset', type=str, default='office-home', choices=['office', 'office-home', 'office-caltech']) 262 | parser.add_argument('--lr', type=float, default=1*1e-2, help="learning rate") 263 | parser.add_argument('--net', type=str, default='resnet50', help="vgg16, resnet50, res101") 264 | parser.add_argument('--temp', type=float, default=1.0) 265 | parser.add_argument('--seed', type=int, default=2020, help="random seed") 266 | 267 | parser.add_argument('--gent', type=bool, default=True) 268 | parser.add_argument('--ent', type=bool, default=True) 269 | parser.add_argument('--threshold', type=int, default=0) 270 | parser.add_argument('--cls_par', type=float, default=0.3) 271 | parser.add_argument('--ent_par', type=float, default=1.0) 272 | parser.add_argument('--lr_decay1', type=float, default=0.1) 273 | parser.add_argument('--lr_decay2', type=float, default=1.0) 274 | 275 | parser.add_argument('--bottleneck', type=int, default=256) 276 | parser.add_argument('--epsilon', type=float, default=1e-5) 277 | parser.add_argument('--layer', type=str, default="wn", choices=["linear", "wn"]) 278 | parser.add_argument('--classifier', type=str, default="bn", choices=["ori", "bn"]) 279 | parser.add_argument('--distance', type=str, default='cosine', choices=["euclidean", "cosine"]) 280 | parser.add_argument('--output', type=str, default='san') 281 | parser.add_argument('--output_src', type=str, default='ckps/adapt') 282 | parser.add_argument('--issave', type=bool, default=True) 283 | args = parser.parse_args() 284 | 285 | if args.dset == 'office-home': 286 | names = ['Art', 'Clipart', 'Product', 'Real_World'] 287 | args.class_num = 65 288 | if args.dset == 'office': 289 | names = ['amazon', 'dslr' , 'webcam'] 290 | args.class_num = 31 291 | if args.dset == 'office-caltech': 292 | names = ['amazon', 'caltech', 'dslr', 'webcam'] 293 | args.class_num = 10 294 | 295 | args.src = [] 296 | for i in range(len(names)): 297 | if i == args.t: 298 | continue 299 | else: 300 | args.src.append(names[i]) 301 | args.tgt = names[args.t] 302 | 303 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id 304 | SEED = args.seed 305 | torch.manual_seed(SEED) 306 | torch.cuda.manual_seed(SEED) 307 | np.random.seed(SEED) 308 | random.seed(SEED) 309 | 310 | for i in range(len(names)): 311 | if i != args.t: 312 | continue 313 | folder = 'data/' 314 | args.t_dset_path = folder + args.dset + '/' + names[args.t] + '_list.txt' 315 | args.test_dset_path = folder + args.dset + '/' + names[args.t] + '_list.txt' 316 | 317 | 318 | args.output_dir_src = osp.join(args.output_src, args.dset, names[args.t][0].upper()) 319 | print(args.output_dir_src) 320 | args.output_dir = osp.join(args.output, 'adapt_distill', args.dset, names[args.t][0].upper()) 321 | 322 | if not osp.exists(args.output_dir): 323 | os.system('mkdir -p ' + args.output_dir) 324 | if not osp.exists(args.output_dir): 
325 | os.mkdir(args.output_dir) 326 | 327 | args.savename = 'distill_' + str(args.cls_par) 328 | 329 | train_distill(args) -------------------------------------------------------------------------------- /adapt_multi.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os, sys 3 | import os.path as osp 4 | import torchvision 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | from torchvision import transforms 10 | import network, loss 11 | from torch.utils.data import DataLoader 12 | from data_list import ImageList, ImageList_idx 13 | import random, pdb, math, copy 14 | from tqdm import tqdm 15 | from scipy.spatial.distance import cdist 16 | from sklearn.metrics import confusion_matrix 17 | 18 | def op_copy(optimizer): 19 | for param_group in optimizer.param_groups: 20 | param_group['lr0'] = param_group['lr'] 21 | return optimizer 22 | 23 | def lr_scheduler(optimizer, iter_num, max_iter, gamma=10, power=0.75): 24 | decay = (1 + gamma * iter_num / max_iter) ** (-power) 25 | for param_group in optimizer.param_groups: 26 | param_group['lr'] = param_group['lr0'] * decay 27 | param_group['weight_decay'] = 1e-3 28 | param_group['momentum'] = 0.9 29 | param_group['nesterov'] = True 30 | return optimizer 31 | 32 | def image_train(resize_size=256, crop_size=224, alexnet=False): 33 | if not alexnet: 34 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 35 | std=[0.229, 0.224, 0.225]) 36 | else: 37 | normalize = Normalize(meanfile='./ilsvrc_2012_mean.npy') 38 | return transforms.Compose([ 39 | transforms.Resize((resize_size, resize_size)), 40 | transforms.RandomCrop(crop_size), 41 | transforms.RandomHorizontalFlip(), 42 | transforms.ToTensor(), 43 | normalize 44 | ]) 45 | 46 | def image_test(resize_size=256, crop_size=224, alexnet=False): 47 | if not alexnet: 48 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 49 | std=[0.229, 0.224, 0.225]) 50 | else: 51 | normalize = Normalize(meanfile='./ilsvrc_2012_mean.npy') 52 | return transforms.Compose([ 53 | transforms.Resize((resize_size, resize_size)), 54 | transforms.CenterCrop(crop_size), 55 | transforms.ToTensor(), 56 | normalize 57 | ]) 58 | 59 | def data_load(args): 60 | ## prepare data 61 | dsets = {} 62 | dset_loaders = {} 63 | train_bs = args.batch_size 64 | txt_tar = open(args.t_dset_path).readlines() 65 | txt_test = open(args.test_dset_path).readlines() 66 | 67 | dsets["target"] = ImageList_idx(txt_tar, transform=image_train()) 68 | dset_loaders["target"] = DataLoader(dsets["target"], batch_size=train_bs, shuffle=True, num_workers=args.worker, drop_last=False) 69 | dsets['target_'] = ImageList_idx(txt_tar, transform=image_train()) 70 | dset_loaders['target_'] = DataLoader(dsets['target_'], batch_size=train_bs*3, shuffle=False, num_workers=args.worker, drop_last=False) 71 | dsets["test"] = ImageList_idx(txt_test, transform=image_test()) 72 | dset_loaders["test"] = DataLoader(dsets["test"], batch_size=train_bs*3, shuffle=False, num_workers=args.worker, drop_last=False) 73 | 74 | return dset_loaders 75 | 76 | def train_target(args): 77 | dset_loaders = data_load(args) 78 | ## set base network 79 | if args.net[0:3] == 'res': 80 | netF_list = [network.ResBase(res_name=args.net).cuda() for i in range(len(args.src))] 81 | elif args.net[0:3] == 'vgg': 82 | netF_list = [network.VGGBase(vgg_name=args.net).cuda() for i in range(len(args.src))] 83 | 84 | w = 2*torch.rand((len(args.src),))-1 85 | print(w) 86 | 87 | 
netB_list = [network.feat_bottleneck(type=args.classifier, feature_dim=netF_list[i].in_features, bottleneck_dim=args.bottleneck).cuda() for i in range(len(args.src))] 88 | netC_list = [network.feat_classifier(type=args.layer, class_num = args.class_num, bottleneck_dim=args.bottleneck).cuda() for i in range(len(args.src))] 89 | netG_list = [network.scalar(w[i]).cuda() for i in range(len(args.src))] 90 | 91 | param_group = [] 92 | for i in range(len(args.src)): 93 | modelpath = args.output_dir_src[i] + '/source_F.pt' 94 | print(modelpath) 95 | netF_list[i].load_state_dict(torch.load(modelpath)) 96 | netF_list[i].eval() 97 | for k, v in netF_list[i].named_parameters(): 98 | param_group += [{'params':v, 'lr':args.lr * args.lr_decay1}] 99 | 100 | modelpath = args.output_dir_src[i] + '/source_B.pt' 101 | print(modelpath) 102 | netB_list[i].load_state_dict(torch.load(modelpath)) 103 | netB_list[i].eval() 104 | for k, v in netB_list[i].named_parameters(): 105 | param_group += [{'params':v, 'lr':args.lr * args.lr_decay2}] 106 | 107 | modelpath = args.output_dir_src[i] + '/source_C.pt' 108 | print(modelpath) 109 | netC_list[i].load_state_dict(torch.load(modelpath)) 110 | netC_list[i].eval() 111 | for k, v in netC_list[i].named_parameters(): 112 | v.requires_grad = False 113 | 114 | for k, v in netG_list[i].named_parameters(): 115 | param_group += [{'params':v, 'lr':args.lr}] 116 | 117 | optimizer = optim.SGD(param_group) 118 | optimizer = op_copy(optimizer) 119 | 120 | max_iter = args.max_epoch * len(dset_loaders["target"]) 121 | interval_iter = max_iter // args.interval 122 | iter_num = 0 123 | 124 | c = 0 125 | 126 | while iter_num < max_iter: 127 | try: 128 | inputs_test, _, tar_idx = iter_test.next() 129 | except: 130 | iter_test = iter(dset_loaders["target"]) 131 | inputs_test, _, tar_idx = iter_test.next() 132 | 133 | if inputs_test.size(0) == 1: 134 | continue 135 | 136 | if iter_num % interval_iter == 0 and args.cls_par > 0: 137 | initc = [] 138 | all_feas = [] 139 | for i in range(len(args.src)): 140 | netF_list[i].eval() 141 | netB_list[i].eval() 142 | temp1, temp2 = obtain_label(dset_loaders['target_'], netF_list[i], netB_list[i], netC_list[i], args) 143 | temp1 = torch.from_numpy(temp1).cuda() 144 | temp2 = torch.from_numpy(temp2).cuda() 145 | initc.append(temp1) 146 | all_feas.append(temp2) 147 | netF_list[i].train() 148 | netB_list[i].train() 149 | 150 | inputs_test = inputs_test.cuda() 151 | 152 | iter_num += 1 153 | lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter) 154 | 155 | outputs_all = torch.zeros(len(args.src), inputs_test.shape[0], args.class_num) 156 | weights_all = torch.ones(inputs_test.shape[0], len(args.src)) 157 | outputs_all_w = torch.zeros(inputs_test.shape[0], args.class_num) 158 | init_ent = torch.zeros(1,len(args.src)) 159 | 160 | for i in range(len(args.src)): 161 | features_test = netB_list[i](netF_list[i](inputs_test)) 162 | outputs_test = netC_list[i](features_test) 163 | softmax_ = nn.Softmax(dim=1)(outputs_test) 164 | ent_loss = torch.mean(loss.Entropy(softmax_)) 165 | init_ent[:,i] = ent_loss 166 | weights_test = netG_list[i](features_test) 167 | outputs_all[i] = outputs_test 168 | weights_all[:, i] = weights_test.squeeze() 169 | 170 | z = torch.sum(weights_all, dim=1) 171 | z = z + 1e-16 172 | 173 | weights_all = torch.transpose(torch.transpose(weights_all,0,1)/z,0,1) 174 | outputs_all = torch.transpose(outputs_all, 0, 1) 175 | 176 | z_ = torch.sum(weights_all, dim=0) 177 | 178 | z_2 = torch.sum(weights_all) 179 | z_ = z_/z_2 180 | 181 | for i in 
range(inputs_test.shape[0]): 182 | outputs_all_w[i] = torch.matmul(torch.transpose(outputs_all[i],0,1), weights_all[i]) 183 | 184 | if args.cls_par > 0: 185 | initc_ = torch.zeros(initc[0].size()).cuda() 186 | temp = all_feas[0] 187 | all_feas_ = torch.zeros(temp[tar_idx, :].size()).cuda() 188 | for i in range(len(args.src)): 189 | initc_ = initc_ + z_[i] * initc[i].float() 190 | src_fea = all_feas[i] 191 | all_feas_ = all_feas_ + z_[i] * src_fea[tar_idx, :] 192 | dd = torch.cdist(all_feas_.float(), initc_.float(), p=2) 193 | pred_label = dd.argmin(dim=1) 194 | pred_label = pred_label.int() 195 | pred = pred_label.long() 196 | classifier_loss = args.cls_par * nn.CrossEntropyLoss()(outputs_all_w, pred.cpu()) 197 | else: 198 | classifier_loss = torch.tensor(0.0) 199 | 200 | if args.ent: 201 | softmax_out = nn.Softmax(dim=1)(outputs_all_w) 202 | entropy_loss = torch.mean(loss.Entropy(softmax_out)) 203 | if args.gent: 204 | msoftmax = softmax_out.mean(dim=0) 205 | entropy_loss -= torch.sum(-msoftmax * torch.log(msoftmax + 1e-5)) 206 | 207 | im_loss = entropy_loss * args.ent_par 208 | classifier_loss += im_loss 209 | 210 | optimizer.zero_grad() 211 | classifier_loss.backward() 212 | optimizer.step() 213 | 214 | if iter_num % interval_iter == 0 or iter_num == max_iter: 215 | for i in range(len(args.src)): 216 | netF_list[i].eval() 217 | netB_list[i].eval() 218 | acc, _ = cal_acc_multi(dset_loaders['test'], netF_list, netB_list, netC_list, netG_list, args) 219 | log_str = 'Iter:{}/{}; Accuracy = {:.2f}%'.format(iter_num, max_iter, acc) 220 | print(log_str+'\n') 221 | for i in range(len(args.src)): 222 | torch.save(netF_list[i].state_dict(), osp.join(args.output_dir, "target_F_" + str(i) + "_" + args.savename + ".pt")) 223 | torch.save(netB_list[i].state_dict(), osp.join(args.output_dir, "target_B_" + str(i) + "_" + args.savename + ".pt")) 224 | torch.save(netC_list[i].state_dict(), osp.join(args.output_dir, "target_C_" + str(i) + "_" + args.savename + ".pt")) 225 | torch.save(netG_list[i].state_dict(), osp.join(args.output_dir, "target_G_" + str(i) + "_" + args.savename + ".pt")) 226 | 227 | def obtain_label(loader, netF, netB, netC, args): 228 | start_test = True 229 | with torch.no_grad(): 230 | iter_test = iter(loader) 231 | for _ in range(len(loader)): 232 | data = iter_test.next() 233 | inputs = data[0] 234 | labels = data[1] 235 | inputs = inputs.cuda() 236 | feas = netB(netF(inputs.float())) 237 | outputs = netC(feas) 238 | if start_test: 239 | all_fea = feas.float().cpu() 240 | all_output = outputs.float().cpu() 241 | all_label = labels.float() 242 | start_test = False 243 | else: 244 | all_fea = torch.cat((all_fea, feas.float().cpu()), 0) 245 | all_output = torch.cat((all_output, outputs.float().cpu()), 0) 246 | all_label = torch.cat((all_label, labels.float()), 0) 247 | all_output = nn.Softmax(dim=1)(all_output) 248 | _, predict = torch.max(all_output, 1) 249 | accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0]) 250 | 251 | all_fea = torch.cat((all_fea, torch.ones(all_fea.size(0), 1)), 1) 252 | all_fea = (all_fea.t() / torch.norm(all_fea, p=2, dim=1)).t() 253 | all_fea = all_fea.float().cpu().numpy() 254 | 255 | K = all_output.size(1) 256 | aff = all_output.float().cpu().numpy() 257 | initc = aff.transpose().dot(all_fea) 258 | initc = initc / (1e-8 + aff.sum(axis=0)[:,None]) 259 | 260 | dd = cdist(all_fea, initc, 'cosine') 261 | pred_label = dd.argmin(axis=1) 262 | acc = np.sum(pred_label == all_label.float().numpy()) / len(all_fea) 263 
| 264 | for round in range(1): 265 | aff = np.eye(K)[pred_label] 266 | initc = aff.transpose().dot(all_fea) 267 | initc = initc / (1e-8 + aff.sum(axis=0)[:,None]) 268 | dd = cdist(all_fea, initc, 'cosine') 269 | pred_label = dd.argmin(axis=1) 270 | acc = np.sum(pred_label == all_label.float().numpy()) / len(all_fea) 271 | 272 | log_str = 'Accuracy = {:.2f}% -> {:.2f}%'.format(accuracy*100, acc*100) 273 | print(log_str+'\n') 274 | #return pred_label.astype('int') 275 | return initc,all_fea 276 | 277 | 278 | def cal_acc_multi(loader, netF_list, netB_list, netC_list, netG_list, args): 279 | start_test = True 280 | with torch.no_grad(): 281 | iter_test = iter(loader) 282 | for _ in range(len(loader)): 283 | data = iter_test.next() 284 | inputs = data[0] 285 | labels = data[1] 286 | inputs = inputs.cuda() 287 | outputs_all = torch.zeros(len(args.src), inputs.shape[0], args.class_num) 288 | weights_all = torch.ones(inputs.shape[0], len(args.src)) 289 | outputs_all_w = torch.zeros(inputs.shape[0], args.class_num) 290 | 291 | for i in range(len(args.src)): 292 | features = netB_list[i](netF_list[i](inputs)) 293 | outputs = netC_list[i](features) 294 | weights = netG_list[i](features) 295 | outputs_all[i] = outputs 296 | weights_all[:, i] = weights.squeeze() 297 | 298 | z = torch.sum(weights_all, dim=1) 299 | z = z + 1e-16 300 | 301 | weights_all = torch.transpose(torch.transpose(weights_all,0,1)/z,0,1) 302 | print(weights_all.mean(dim=0)) 303 | outputs_all = torch.transpose(outputs_all, 0, 1) 304 | 305 | for i in range(inputs.shape[0]): 306 | outputs_all_w[i] = torch.matmul(torch.transpose(outputs_all[i],0,1), weights_all[i]) 307 | 308 | if start_test: 309 | all_output = outputs_all_w.float().cpu() 310 | all_label = labels.float() 311 | start_test = False 312 | else: 313 | all_output = torch.cat((all_output, outputs_all_w.float().cpu()), 0) 314 | all_label = torch.cat((all_label, labels.float()), 0) 315 | _, predict = torch.max(all_output, 1) 316 | accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0]) 317 | mean_ent = torch.mean(loss.Entropy(nn.Softmax(dim=1)(all_output))).cpu().data.item() 318 | return accuracy*100, mean_ent 319 | 320 | def print_args(args): 321 | s = "==========================================\n" 322 | for arg, content in args.__dict__.items(): 323 | s += "{}:{}\n".format(arg, content) 324 | return s 325 | 326 | if __name__ == "__main__": 327 | parser = argparse.ArgumentParser(description='SHOT') 328 | parser.add_argument('--gpu_id', type=str, nargs='?', default='0', help="device id to run") 329 | parser.add_argument('--t', type=int, default=0, help="target") ## Choose which domain to set as target {0 to len(names)-1} 330 | parser.add_argument('--max_epoch', type=int, default=15, help="max iterations") 331 | parser.add_argument('--interval', type=int, default=15) 332 | parser.add_argument('--batch_size', type=int, default=32, help="batch_size") 333 | parser.add_argument('--worker', type=int, default=4, help="number of workers") 334 | parser.add_argument('--dset', type=str, default='office-caltech', choices=['office', 'office-home', 'office-caltech']) 335 | parser.add_argument('--lr', type=float, default=1*1e-2, help="learning rate") 336 | parser.add_argument('--net', type=str, default='resnet50', help="vgg16, resnet50, res101") 337 | parser.add_argument('--seed', type=int, default=2021, help="random seed") 338 | 339 | parser.add_argument('--gent', type=bool, default=True) 340 | parser.add_argument('--ent', type=bool, default=True) 
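# Caveat: argparse's type=bool does not parse strings as booleans -- any
# non-empty value (e.g. '--gent False') evaluates to True, so the two flags
# above are effectively always on unless the defaults are changed.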
341 | parser.add_argument('--threshold', type=int, default=0) 342 | parser.add_argument('--cls_par', type=float, default=0.3) 343 | parser.add_argument('--ent_par', type=float, default=1.0) 344 | parser.add_argument('--lr_decay1', type=float, default=0.1) 345 | parser.add_argument('--lr_decay2', type=float, default=1.0) 346 | 347 | parser.add_argument('--bottleneck', type=int, default=256) 348 | parser.add_argument('--epsilon', type=float, default=1e-5) 349 | parser.add_argument('--layer', type=str, default="wn", choices=["linear", "wn"]) 350 | parser.add_argument('--classifier', type=str, default="bn", choices=["ori", "bn"]) 351 | parser.add_argument('--distance', type=str, default='cosine', choices=["euclidean", "cosine"]) 352 | parser.add_argument('--output', type=str, default='ckps/adapt_ours') 353 | parser.add_argument('--output_src', type=str, default='ckps/source') 354 | args = parser.parse_args() 355 | 356 | if args.dset == 'office-home': 357 | names = ['Art', 'Clipart', 'Product', 'Real_World'] 358 | args.class_num = 65 359 | if args.dset == 'office': 360 | names = ['amazon', 'dslr' , 'webcam'] 361 | args.class_num = 31 362 | if args.dset == 'office-caltech': 363 | names = ['amazon', 'caltech', 'dslr', 'webcam'] 364 | args.class_num = 10 365 | 366 | args.src = [] 367 | for i in range(len(names)): 368 | if i == args.t: 369 | continue 370 | else: 371 | args.src.append(names[i]) 372 | 373 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id 374 | SEED = args.seed 375 | torch.manual_seed(SEED) 376 | torch.cuda.manual_seed(SEED) 377 | np.random.seed(SEED) 378 | random.seed(SEED) 379 | 380 | for i in range(len(names)): 381 | if i != args.t: 382 | continue 383 | folder = './data/' 384 | args.t_dset_path = folder + args.dset + '/' + names[args.t] + '_list.txt' 385 | args.test_dset_path = folder + args.dset + '/' + names[args.t] + '_list.txt' 386 | print(args.t_dset_path) 387 | 388 | args.output_dir_src = [] 389 | for i in range(len(args.src)): 390 | args.output_dir_src.append(osp.join(args.output_src, args.dset, args.src[i][0].upper())) 391 | print(args.output_dir_src) 392 | args.output_dir = osp.join(args.output, args.dset, names[args.t][0].upper()) 393 | 394 | if not osp.exists(args.output_dir): 395 | os.system('mkdir -p ' + args.output_dir) 396 | if not osp.exists(args.output_dir): 397 | os.mkdir(args.output_dir) 398 | 399 | args.savename = 'par_' + str(args.cls_par) 400 | 401 | train_target(args) 402 | 403 | --------------------------------------------------------------------------------
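For reference, a complete Office-Home run with `Art` as the target (`--t 0`) chains the three stages as follows, mirroring the README commands (epoch counts illustrative):
```
python train_source.py --dset office-home --s 1 --max_epoch 50 --trte val --gpu_id 0 --output ckps/source/
python train_source.py --dset office-home --s 2 --max_epoch 50 --trte val --gpu_id 0 --output ckps/source/
python train_source.py --dset office-home --s 3 --max_epoch 50 --trte val --gpu_id 0 --output ckps/source/
python adapt_multi.py --dset office-home --t 0 --max_epoch 15 --gpu_id 0 --output_src ckps/source/ --output ckps/adapt
python distill.py --dset office-home --t 0 --max_epoch 15 --gpu_id 0 --output_src ckps/adapt --output ckps/dist
```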