├── README.md
├── data_loader.py
├── datasets
│   └── resnet_feature.py
├── loss.py
├── main.py
├── models.py
├── requirements.txt
├── teaser.jpg
└── utils.py

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Simultaneous Semantic Alignment Network for Heterogeneous Domain Adaptation
This is a [PyTorch](http://pytorch.org/) implementation of [SSAN](https://arxiv.org/abs/2008.01677).

### Paper

![](./teaser.jpg)

[Simultaneous Semantic Alignment Network for Heterogeneous Domain Adaptation](https://arxiv.org/abs/2008.01677)

[Shuang Li](http://shuangli.xyz), [Binhui Xie](https://binhuixie.github.io), [Jiashu Wu](https://jiashuwu.github.io), Ying Zhao, [Chi Harold Liu](http://cs.bit.edu.cn/szdw/jsml/js/lc_20180927062826951290/index.htm), [Zhengming Ding](http://allanding.net)

*ACM International Conference on Multimedia*, 2020

### Abstract
Heterogeneous domain adaptation (HDA) transfers knowledge across source and target domains that exhibit heterogeneities, e.g., distinct domain distributions and differences in feature type or dimension. Most previous HDA methods tackle this problem by learning a domain-invariant feature subspace that reduces the discrepancy between domains. However, the intrinsic semantic properties contained in the data, which are indispensable for promising adaptability, are under-explored by such alignment strategies. In this paper, we propose a Simultaneous Semantic Alignment Network (SSAN) to simultaneously exploit correlations among categories and align the centroids of each category across domains. In particular, we propose an implicit semantic correlation loss that transfers the correlation knowledge of source categorical prediction distributions to the target domain. Meanwhile, by leveraging target pseudo-labels, a robust triplet-centroid alignment mechanism is explicitly applied to align the feature representations of each category. Notably, a pseudo-label refinement procedure that incorporates geometric similarity is introduced to improve the accuracy of target pseudo-label assignment. Comprehensive experiments on various HDA tasks across text-to-image, image-to-image and text-to-text validate the superiority of our SSAN against state-of-the-art HDA methods.
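For intuition, the implicit semantic correlation loss can be pictured as a soft cross-entropy between the per-class average source predictions (softened by a temperature) and the predictions of labeled target samples of the same class. Below is a minimal, simplified sketch of that idea; the function name is illustrative, and the full version, which also mixes in a hard cross-entropy term, is `knowledge_distillation_loss_func` in `loss.py`:

```python
import torch
import torch.nn.functional as F

def implicit_alignment_loss(source_logits, source_labels,
                            target_logits, target_labels,
                            num_classes, temperature=5.0):
    # q[k]: average temperature-softened source prediction for class k
    # (assumes every class appears at least once in the source batch)
    q = torch.stack([F.softmax(source_logits[source_labels == k] / temperature, dim=1).mean(dim=0)
                     for k in range(num_classes)])
    log_p = F.log_softmax(target_logits, dim=1)
    loss = 0.
    for k in range(num_classes):
        mask = target_labels == k
        if mask.any():  # skip classes absent from the target batch
            loss = loss - (q[k] * log_p[mask]).sum(dim=1).mean()
    return loss

# toy usage with random logits for 10 classes
loss = implicit_alignment_loss(torch.randn(256, 10), torch.randint(0, 10, (256,)),
                               torch.randn(128, 10), torch.randint(0, 10, (128,)), 10)
```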

### Prerequisites
- Python 3.6
- PyTorch 1.3.1
- numpy
- scipy
- matplotlib
- scikit_learn
- CUDA >= 8.0

### Step-by-step installation

```bash
$ conda create -n ssan -y python=3.6
$ conda activate ssan

# this installs the right pip and dependencies for the fresh python
$ conda install -y ipython pip

# to install the required python packages, run
$ pip install -r requirements.txt
```



### Getting started

All datasets can be downloaded [here](https://github.com/BIT-DA/SSAN/releases) and should be placed in `/datasets`.


### Train and Evaluate
```bash
# Image-to-Image
$ python main.py --source amazon_surf --target amazon_decaf --cuda 0 --nepoch 3000 --partition 20 --prototype three --layer double --d_common 256 --optimizer mSGD --lr 0.1 --alpha 0.1 --beta 0.004 --gamma 0.1 --combine_pred Cosine --checkpoint_path checkpoint/ --temperature 5.0
```



### Acknowledgements

Special thanks to [Yuan Yao](https://www.researchgate.net/profile/Yuan_Yao67) for help with the experiments.


### Citation
If you find this code useful for your research, please cite our [paper](https://arxiv.org/abs/2008.01677):
```
@inproceedings{li2020simultaneous,
    title = {Simultaneous Semantic Alignment Network for Heterogeneous Domain Adaptation},
    author = {Li, Shuang and Xie, Binhui and Wu, Jiashu and Zhao, Ying and Liu, Chi Harold and Ding, Zhengming},
    booktitle = {The 28th ACM International Conference on Multimedia (MM'20)},
    pages = {3866--3874},
    publisher = {{ACM}},
    year = {2020}
}
```

### Contact

If you have any problems with our code, feel free to contact
- shuangli@bit.edu.cn
- binhuixie@bit.edu.cn

or describe your problem in Issues.
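### Data format

Each `.mat` file is expected to contain two keys, `features` (one row per instance) and `labels` (1-based class indices), which is what `data_loader.py` reads. A minimal sanity check after downloading (the path below is one example; any dataset file works the same way):

```python
import scipy.io as sio
import numpy as np

mat = sio.loadmat('datasets/ImageToImageObjectRecognition/amazon_decaf.mat')
features, labels = mat['features'], mat['labels']
print(features.shape, labels.shape)  # (n_samples, feature_dim); labels may be (n_samples, 1)
print(np.unique(labels))             # 1-based labels; data_loader.py subtracts 1 at load time
```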

--------------------------------------------------------------------------------
/data_loader.py:
--------------------------------------------------------------------------------
import torch
import os.path as osp
import scipy.io as sio
import numpy as np
from sklearn import preprocessing

IMAGE2IMAGE_PATH = 'datasets/ImageToImageObjectRecognition/'

# image datasets
DATASETS = {
    'amazon_surf': osp.join(IMAGE2IMAGE_PATH, 'amazon_surf.mat'),
    'amazon_decaf': osp.join(IMAGE2IMAGE_PATH, 'amazon_decaf.mat'),
    'amazon_resnet': osp.join(IMAGE2IMAGE_PATH, 'amazon_resnet.mat'),
    'dslr_surf': osp.join(IMAGE2IMAGE_PATH, 'dslr_surf.mat'),
    'dslr_decaf': osp.join(IMAGE2IMAGE_PATH, 'dslr_decaf.mat'),
    'dslr_resnet': osp.join(IMAGE2IMAGE_PATH, 'dslr_resnet.mat'),
    'caltech_surf': osp.join(IMAGE2IMAGE_PATH, 'caltech_surf.mat'),
    'caltech_decaf': osp.join(IMAGE2IMAGE_PATH, 'caltech_decaf.mat'),
    'caltech_resnet': osp.join(IMAGE2IMAGE_PATH, 'caltech_resnet.mat'),
    'webcam_surf': osp.join(IMAGE2IMAGE_PATH, 'webcam_surf.mat'),
    'webcam_decaf': osp.join(IMAGE2IMAGE_PATH, 'webcam_decaf.mat'),
    'webcam_resnet': osp.join(IMAGE2IMAGE_PATH, 'webcam_resnet.mat'),
}


def get_configuration(args):
    SOURCE_PATH = DATASETS[args.source.lower()]
    TARGET_PATH = DATASETS[args.target.lower()]

    # source and target domain info
    print('========= Source & Target Info =========')
    print('Source Domain: ' + SOURCE_PATH)
    print('Target Domain: ' + TARGET_PATH)
    print('========= Loading Data =========')
    source = sio.loadmat(SOURCE_PATH)
    target = sio.loadmat(TARGET_PATH)
    print('========= Loading Data Completed =========')
    print()
    print('========= Data Information =========')

    # Amount of labeled target instances for each class
    if args.target.lower() == 'spanish20':
        labeled_amount = 20
    elif args.target.lower() == 'spanish15':
        labeled_amount = 15
    elif args.target.lower() == 'spanish10':
        labeled_amount = 10
    elif args.target.lower() == 'spanish5':
        labeled_amount = 5
    else:
        labeled_amount = 3

    xs = source['features']
    xs = preprocessing.normalize(xs, norm='l2')
    xs_label = source['labels'] - 1  # label range: 0 - 9, both inclusive
    print('xs.shape = ', xs.shape)
    print('xs_label.shape = ', xs_label.shape)

    entire_t = target['features']
    entire_t = preprocessing.normalize(entire_t, norm='l2')
    entire_t_label = target['labels'] - 1

    print('xt.shape = ', entire_t.shape)
    print('xt_label.shape = ', entire_t_label.shape)
    print('xt_label.len = ', len(entire_t_label))

    assert len(np.unique(xs_label)) == len(np.unique(entire_t_label))
    class_number = len(np.unique(xs_label))  # number of classes

    xl = []
    xl_label = []

    # randomly draw labeled_amount instances per class from the target domain
    # (without replacement); everything left over stays unlabeled
    for cls in range(class_number):
        amount = labeled_amount
        while amount > 0:
            random_index = np.random.randint(0, entire_t.shape[0])
            if entire_t_label[random_index] == cls:
                xl.append(entire_t[random_index])
                xl_label.append(entire_t_label[random_index])
                amount -= 1
                entire_t = np.delete(entire_t, random_index, 0)
                entire_t_label = np.delete(entire_t_label, random_index, 0)
    xl = np.array(xl)
    xl_label = np.array(xl_label)
    xu = entire_t
    xu_label = entire_t_label

    ns, ds = xs.shape  # ns = number of source instances, ds = source feature dimension
    nl, dt = xl.shape  # nl = number of labeled target instances, dt = target feature dimension
    nu, _ = xu.shape
    nt = nl + nu  # total number of target instances
    print('ns = ', ns)
    print('nl = ', nl)
    print('nu = ', nu)
    print('ds = ', ds)
    print('dt = ', dt)
    print('Class_number: ', class_number)
    print()

    # Generate dataset objects
    source_data = [torch.from_numpy(xs), torch.from_numpy(xs_label)]
    labeled_target_data = [torch.from_numpy(xl), torch.from_numpy(xl_label)]
    unlabeled_target_data = [torch.from_numpy(xu), torch.from_numpy(xu_label)]

    # Summary of instance counts
    print('Number of Source Instances: ' + str(ns))
    print('Number of Labeled Target Instances: ' + str(nl))
    print('Number of Unlabeled Target Instances: ' + str(nu))
    print()

    # data configurations
    configuration = {'ns': ns, 'nl': nl, 'nu': nu, 'nt': nt, 'class_number': class_number,
                     'labeled_amount': labeled_amount, 'd_source': ds, 'd_target': dt,
                     'source_data': source_data, 'labeled_target_data': labeled_target_data,
                     'unlabeled_target_data': unlabeled_target_data}

    print('========= Loading Done =========')
    print()
    print('========= Training Started =========')
    return configuration

--------------------------------------------------------------------------------
/datasets/resnet_feature.py:
--------------------------------------------------------------------------------
from torchvision import models, transforms, datasets
import torch
import torchvision
import torch.nn as nn
from torch.utils.data import DataLoader
import os
import scipy.io as sio


class ResBase(nn.Module):
    r"""Constructs a feature extractor based on the ResNet-50 model,
    with the final fully-connected layer removed.
    """

    def __init__(self):
        super(ResBase, self).__init__()
        model_res50 = torchvision.models.resnet50(pretrained=True)
        self.conv1 = model_res50.conv1
        self.bn1 = model_res50.bn1
        self.relu = model_res50.relu
        self.maxpool = model_res50.maxpool
        self.layer1 = model_res50.layer1
        self.layer2 = model_res50.layer2
        self.layer3 = model_res50.layer3
        self.layer4 = model_res50.layer4
        self.avgpool = model_res50.avgpool
        self.feature_layers = nn.Sequential(self.conv1, self.bn1, self.relu, self.maxpool,
                                            self.layer1, self.layer2, self.layer3, self.layer4, self.avgpool)
        self.__in_features = 2048

    def forward(self, x):
        """
        :param x: the input Tensor of shape [bs, 3, 224, 224]
        :return: 2048-dim feature
        """
        feature = self.feature_layers(x)
        feature = feature.view(feature.size(0), -1)
        return feature

    def output_num(self):
        return self.__in_features


# ResNet-50 model
resnet50 = ResBase()
if torch.cuda.is_available():
    resnet50 = resnet50.cuda()
IMAGE_PATH = '/data1/TL/data/office_caltech_10/'
DOMAINS = ['amazon', 'caltech', 'dslr', 'webcam']
data_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

for d in DOMAINS:
    print('start:', d)
    data_set = datasets.ImageFolder(os.path.join(IMAGE_PATH, d), data_transforms)
    dataset_size = len(data_set)
    data_loader = DataLoader(data_set, batch_size=128, shuffle=False, num_workers=4)
    flag = True
    with torch.no_grad():
        for i, data in enumerate(data_loader, 0):
            print(i)
            inputs, labels = data
            if torch.cuda.is_available():
                inputs, labels = inputs.cuda(), labels.cuda()
            labels = labels + 1  # store 1-based labels; data_loader.py subtracts 1 at load time
            features = resnet50(inputs)
            if flag:
                all_features = features
                all_labels = labels
                flag = False
            else:
                all_features = torch.cat((all_features, features), 0)
                all_labels = torch.cat((all_labels, labels), 0)
    save_name = str(d) + '_resnet.mat'
    sio.savemat(save_name, {'features': all_features.cpu().numpy(), 'labels': all_labels.long().cpu().numpy()})

    print('finished', d)

--------------------------------------------------------------------------------
/loss.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.distributions as dist
import torch.nn.functional as F
import numpy as np
from collections import OrderedDict


def classification_loss_func(prediction, true_labels, ce_temperature=1.0):
    celoss_criterion = nn.CrossEntropyLoss()
    return celoss_criterion(prediction / ce_temperature, true_labels)


def explicit_semantic_alignment_loss_func(source_learned_features, l_target_learned_features,
                                          u_target_learned_features, source_labels, l_target_labels,
                                          u_target_pseudo_labels, configuration, prototype):
    """
    Class-level feature alignment: compute the k-th class centroids of the source,
    target, and combined source-target features, and accumulate the MSE loss
    between each pair of centroids.
    :param prototype: number of prototypes ('two' or 'three') used for the alignment loss
    :param source_learned_features: source features
    :param l_target_learned_features: labeled target features
    :param u_target_learned_features: unlabeled target features
    :param source_labels: source ground-truth labels
    :param l_target_labels: labeled target ground-truth labels
    :param u_target_pseudo_labels: pseudo-labels of the unlabeled target instances
    :param configuration: dataset configuration dict
    :return: explicit semantic alignment loss
    """
    class_number = configuration['class_number']
    mu_s = OrderedDict()
    mu_t = OrderedDict()

    if prototype == 'two':
        for i in range(class_number):
            mu_s[i] = []
            mu_t[i] = []

        assert source_learned_features.shape[0] == len(source_labels)
        for i in range(source_learned_features.shape[0]):
            mu_s[int(source_labels[i])].append(source_learned_features[i])

        assert l_target_learned_features.shape[0] == len(l_target_labels)
        for i in range(l_target_learned_features.shape[0]):
            mu_t[int(l_target_labels[i])].append(l_target_learned_features[i])

        assert u_target_learned_features.shape[0] == len(u_target_pseudo_labels)
        for i in range(u_target_learned_features.shape[0]):
            mu_t[int(u_target_pseudo_labels[i])].append(u_target_learned_features[i])

        error_general = 0
        mseloss_criterion = nn.MSELoss(reduction='sum')

        for i in range(class_number):
            mu_s[i] = torch.mean(torch.stack(mu_s[i], 0).float(), 0).float()

            mu_t[i] = torch.mean(torch.stack(mu_t[i], 0).float(), 0).float()

            error_general += mseloss_criterion(mu_s[i], mu_t[i])

        return error_general

    elif prototype == 'three':
        mu_st = OrderedDict()

        for i in range(class_number):
            mu_s[i] = []
            mu_t[i] = []
            mu_st[i] = [[], []]

        assert source_learned_features.shape[0] == len(source_labels)
        for i in range(source_learned_features.shape[0]):
            mu_s[int(source_labels[i])].append(source_learned_features[i])
            mu_st[int(source_labels[i])][0].append(source_learned_features[i])

        assert l_target_learned_features.shape[0] == len(l_target_labels)
        for i in range(l_target_learned_features.shape[0]):
            mu_t[int(l_target_labels[i])].append(l_target_learned_features[i])
            mu_st[int(l_target_labels[i])][1].append(l_target_learned_features[i])

        assert u_target_learned_features.shape[0] == len(u_target_pseudo_labels)
        for i in range(u_target_learned_features.shape[0]):
            mu_t[int(u_target_pseudo_labels[i])].append(u_target_learned_features[i])
            mu_st[int(u_target_pseudo_labels[i])][1].append(u_target_learned_features[i])

        error_general = 0
        mseloss_criterion = nn.MSELoss(reduction='sum')

        for i in range(class_number):
            source_mean = torch.mean(torch.stack(mu_s[i], 0).float(), 0).float()

            target_mean = torch.mean(torch.stack(mu_t[i], 0).float(), 0).float()

            mu_st_numerator = 0
            mu_st_numerator += torch.sum(torch.stack(mu_st[i][0], 0).float(), 0).float()
            mu_st_numerator += torch.sum(torch.stack(mu_st[i][1], 0).float(), 0).float()
            source_target_mean = torch.div(mu_st_numerator, len(mu_st[i][0]) + len(mu_st[i][1]))

            error_general += mseloss_criterion(source_mean, target_mean)
            error_general += mseloss_criterion(source_mean, source_target_mean)
            error_general += mseloss_criterion(target_mean, source_target_mean)

        return error_general


def knowledge_distillation_loss_func(source_predic, source_label, l_target_predic, l_target_label, args):
    """
    Semantic-level alignment between source and labeled target predictions.
    q: the soft label of class k, i.e. the average temperature-softened softmax over all source examples of class k
    p: the softmax output of each labeled target sample
    :param args: command-line arguments (supply temperature and alpha)
    :param source_predic: source predictions (logits)
    :param source_label: source ground-truth labels
    :param l_target_predic: labeled target predictions (logits)
    :param l_target_label: labeled target ground-truth labels
    :return: total loss, weighted hard CE loss, weighted soft CE loss
    """
    if args.alpha == 1.0:
        return classification_loss_func(l_target_predic, l_target_label), \
               torch.Tensor([0.])[0], torch.Tensor([0.])[0]

    assert source_predic.shape[1] == l_target_predic.shape[1]
    class_num = source_predic.shape[1]
    k_categories = torch.zeros((class_num, class_num))
    source_softmax = F.softmax(source_predic / args.temperature, dim=1)
    l_target_softmax = F.softmax(l_target_predic, dim=1)
    soft_loss = 0

    for k in range(class_num):
        k_source_softmax = source_softmax.index_select(dim=0, index=(source_label == k).nonzero().reshape(-1, ))
        k_categories[k] = torch.mean(k_source_softmax, dim=0)

    if torch.cuda.is_available():
        k_categories = k_categories.cuda()

    for k in range(class_num):
        k_l_target_softmax = l_target_softmax.index_select(dim=0, index=(l_target_label == k).nonzero().reshape(-1, ))
        soft_loss -= torch.mean(torch.sum(k_categories[k] * torch.log(k_l_target_softmax + 1e-5), 1))

    hard_loss = classification_loss_func(l_target_predic, l_target_label)
    loss = (1 - args.alpha) * hard_loss + args.alpha * soft_loss
    return loss, (1 - args.alpha) * hard_loss, args.alpha * soft_loss


def get_prototype_label(source_learned_features, l_target_learned_features, u_target_learned_features, source_labels,
                        l_target_labels, configuration, combine_pred, epoch):
    """
    Get prototype labels for the unlabeled target instances.
    :param epoch: training epoch
    :param combine_pred: Euclidean, Cosine
    :param configuration: dataset configuration
    :param source_learned_features: source features
    :param l_target_learned_features: labeled target features
    :param u_target_learned_features: unlabeled target features
    :param source_labels: source labels
    :param l_target_labels: labeled target labels
    :return: prototype labels for the unlabeled target instances
    """
    def prototype_softmax(features, feature_centers):
        # returns one row of similarity scores per sample (higher means closer to a prototype);
        # despite the name, no softmax normalization is applied
        assert features.shape[1] == feature_centers.shape[1]
        n_samples = features.shape[0]
        C, dim = feature_centers.shape
        pred = torch.FloatTensor()
        for i in range(n_samples):
            if combine_pred.find('Euclidean') != -1:
                dis = -torch.sum(torch.pow(features[i].expand(C, dim) - feature_centers, 2), dim=1)
            elif combine_pred.find('Cosine') != -1:
                dis = torch.cosine_similarity(features[i].expand(C, dim), feature_centers)
            if not i:
                pred = dis.reshape(1, -1)
            else:
                pred = torch.cat((pred, dis.reshape(1, -1)), dim=0)
        return pred

    assert source_learned_features.shape[1] == u_target_learned_features.shape[1]
    class_num = configuration['class_number']
    feature_dim = source_learned_features.shape[1]
    feature_centers = torch.zeros((class_num, feature_dim))
    for k in range(class_num):
        # calculate the feature center of each class over the source and labeled target features
        k_source_feature = source_learned_features.index_select(dim=0,
                                                                index=(source_labels == k).nonzero().reshape(-1, ))
        k_l_target_feature = l_target_learned_features.index_select(dim=0, index=(
            l_target_labels == k).nonzero().reshape(-1, ))
        feature_centers[k] = torch.mean(torch.cat((k_source_feature, k_l_target_feature), dim=0), dim=0)

    if torch.cuda.is_available():
        feature_centers = feature_centers.cuda()

    # assign 'pseudo-labels' by Euclidean distance or cosine similarity between features and prototypes,
    # keeping only the most confident samples in each pseudo class (unconfident samples get label -1)
    prototype_pred = prototype_softmax(u_target_learned_features, feature_centers)
    prototype_value, prototype_label = torch.max(prototype_pred.data, 1)

    # add threshold
    if combine_pred.find('threshold') != -1:
        if combine_pred == 'Euclidean_threshold':
            # threshold for Euclidean distance
            select_threshold = 0.2
        elif combine_pred == 'Cosine_threshold':
            # Ref: Progressive Feature Alignment for Unsupervised Domain Adaptation, CVPR 2019
            select_threshold = 1. / (1 + np.exp(-0.8 * (epoch + 1))) - 0.01
            # select_threshold = 0.1
        prototype_label[(prototype_value < select_threshold).nonzero()] = -1

    return prototype_label

--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import os
import time
import random
import datetime
import argparse
import warnings
import torch
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data
import data_loader
import numpy as np
import torch.nn as nn
from collections import defaultdict
from models import Prototypical, Discriminator
from loss import classification_loss_func, explicit_semantic_alignment_loss_func, knowledge_distillation_loss_func, \
    get_prototype_label
from utils import write_log_record, seed_everything, make_dirs

warnings.filterwarnings('ignore')

parser = argparse.ArgumentParser(
    description='Simultaneous Semantic Alignment Network for Heterogeneous Domain Adaptation')
parser.add_argument('--source', type=str, default='amazon_surf', help='Source domain',
                    choices=['amazon_surf', 'amazon_decaf', 'amazon_resnet',
                             'webcam_surf', 'webcam_decaf', 'webcam_resnet',
                             'caltech_surf', 'caltech_decaf', 'caltech_resnet'])
parser.add_argument('--target', type=str, default='amazon_decaf', help='Target domain',
                    choices=['amazon_surf', 'amazon_decaf', 'amazon_resnet',
                             'webcam_surf', 'webcam_decaf', 'webcam_resnet',
                             'caltech_surf', 'caltech_decaf', 'caltech_resnet',
                             'dslr_decaf', 'dslr_resnet'])
parser.add_argument('--cuda', type=str, default='0', help='CUDA device index')
parser.add_argument('--nepoch', type=int, default=3000, help='Number of training epochs')
parser.add_argument('--partition', type=int, default=20, help='Number of random partitions to average over')
parser.add_argument('--prototype', type=str, default='three', choices=['two', 'three'],
                    help='How many prototypes are used for the domain and general alignment loss')
parser.add_argument('--layer', type=str, default='double', choices=['single', 'double'],
                    help='Structure of the projector network: single-layer or double-layer projector')
parser.add_argument('--d_common', type=int, default=256, help='Dimension of the common representation')
parser.add_argument('--optimizer', type=str, default='mSGD', choices=['SGD', 'mSGD', 'Adam'], help='optimizer options')
parser.add_argument('--lr', type=float, default=0.1, help='Learning rate')
parser.add_argument('--temperature', type=float, default=5.0, help='Source softmax temperature')
parser.add_argument('--alpha', type=float, default=0.1,
                    help='Trade-off parameter of L_soft, set to 0.0 to turn it off; '
                         'weights the loss as (1 - alpha) * hard CE loss + alpha * soft CE loss')
parser.add_argument('--beta', type=float, default=0.004, help='Trade-off parameter of L_ESA, set to 0 to turn it off')
parser.add_argument('--gamma', type=float, default=0.1, help='Trade-off parameter of L_D, set to 0 to turn it off')
parser.add_argument('--combine_pred', type=str, default='Cosine',
                    choices=['Euclidean', 'Cosine', 'Euclidean_threshold', 'Cosine_threshold', 'None'],
                    help='Similarity used for prototype predictions: Euclidean, Cosine, or None (not used)')
parser.add_argument('--checkpoint_path', type=str, default='checkpoint', help='Path where all records are saved')
parser.add_argument('--seed', type=int, default=2020, help='Seed for everything')

args = parser.parse_args()
args.time_string = datetime.datetime.strftime(datetime.datetime.now(), '%Y-%m-%d %H-%M-%S')

if torch.cuda.is_available():
    os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda
    if len(args.cuda) == 1:
        torch.cuda.set_device(int(args.cuda))

# seed for everything
seed_everything(args)
# make dirs
make_dirs(args)
print(str(args))


def test(model, configuration, srctar):
    model.eval()
    if srctar == 'source':
        loader = configuration['source_data']
        N = configuration['ns']
    elif srctar == 'labeled_target':
        loader = configuration['labeled_target_data']
        N = configuration['nl']
    elif srctar == 'unlabeled_target':
        loader = configuration['unlabeled_target_data']
        N = configuration['nu']
    else:
        raise Exception('Parameter srctar invalid!')

    with torch.no_grad():
        feature, label = loader[0].float(), loader[1].reshape(-1, ).long()
        if torch.cuda.is_available():
            feature, label = feature.cuda(), label.cuda()
        classifier_output, _ = model(input_feature=feature)
        _, pred = torch.max(classifier_output.data, 1)
        n_correct = (pred == label).sum().item()
        acc = float(n_correct) / N * 100.

    return acc


def train(model, model_d, optimizer, optimizer_d, configuration):
    best_acc = -float('inf')

    # training
    for epoch in range(args.nepoch):

        start_time = time.time()
        model.train()
        model_d.train()
        optimizer.zero_grad()
        optimizer_d.zero_grad()

        # prepare data
        source_data = configuration['source_data']
        l_target_data = configuration['labeled_target_data']
        u_target_data = configuration['unlabeled_target_data']
        source_feature, source_label = source_data[0].float(), source_data[1].reshape(-1, ).long()
        l_target_feature, l_target_label = l_target_data[0].float(), l_target_data[1].reshape(-1, ).long()
        u_target_feature = u_target_data[0].float()
        if torch.cuda.is_available():
            source_feature, source_label = source_feature.cuda(), source_label.cuda()
            l_target_feature, l_target_label = l_target_feature.cuda(), l_target_label.cuda()
            u_target_feature = u_target_feature.cuda()

        # forward propagation
        source_output, source_learned_feature = model(input_feature=source_feature)
        l_target_output, l_target_learned_feature = model(input_feature=l_target_feature)
        u_target_output, u_target_learned_feature = model(input_feature=u_target_feature)
        _, u_target_pseudo_label = torch.max(u_target_output, 1)
        if args.combine_pred == 'None':
            u_target_selected_feature = u_target_learned_feature
            u_target_selected_label = u_target_pseudo_label
            if epoch % 10 == 0:
                n_correct = (u_target_pseudo_label.cpu() == u_target_data[1].reshape(-1, ).long()).sum().item()
                acc_nn = float(n_correct) / configuration['nu'] * 100.
                print('Pseudo acc: (NN)', acc_nn)
        elif args.combine_pred.find('Euclidean') != -1 or args.combine_pred.find('Cosine') != -1:
            # get labels for the unlabeled data via prototype prediction & network prediction
            u_target_prototype_label = get_prototype_label(source_learned_features=source_learned_feature,
                                                           l_target_learned_features=l_target_learned_feature,
                                                           u_target_learned_features=u_target_learned_feature,
                                                           source_labels=source_label,
                                                           l_target_labels=l_target_label,
                                                           configuration=configuration,
                                                           combine_pred=args.combine_pred,
                                                           epoch=epoch)
            # keep only the examples on which both predictions agree
            u_target_selected_feature = u_target_learned_feature.index_select(dim=0, index=(
                    u_target_pseudo_label == u_target_prototype_label).nonzero().reshape(-1, ))
            u_target_selected_label = u_target_pseudo_label.index_select(dim=0, index=(
                    u_target_pseudo_label == u_target_prototype_label).nonzero().reshape(-1, ))

            if epoch % 10 == 0:
                print('shared predictions:', len(u_target_selected_label), '/', len(u_target_pseudo_label))
                n_correct = (u_target_prototype_label.cpu() == u_target_data[1].reshape(-1, ).long()).sum().item()
                acc_pro = float(n_correct) / configuration['nu'] * 100.
                print('Prototype acc: (pro)', acc_pro)

        # ========================source data loss============================
        # labeled source data
        # CrossEntropy loss
        error_overall = classification_loss_func(source_output, source_label)
        if epoch % 10 == 0:
            print('Use source CE loss: ', error_overall)

        # ========================alignment loss============================
        # Calculate the implicit semantic alignment loss
        isa_loss, hard_loss, soft_loss = knowledge_distillation_loss_func(source_output, source_label,
                                                                          l_target_output, l_target_label, args)
        error_overall += isa_loss
        if epoch % 10 == 0:
            print('Use ISA loss: ', isa_loss, 'hard CE loss: ', hard_loss, 'soft CE loss: ', soft_loss)

        # Calculate the global adversarial alignment loss
        if args.gamma:
            transfer_criterion = nn.BCELoss()
            alpha = 2. / (1. + np.exp(-10 * float(epoch / args.nepoch))) - 1
            domain_labels = torch.from_numpy(
                np.array([[1]] * configuration['ns'] + [[0]] * configuration['nt'])).float()
            if torch.cuda.is_available():
                domain_labels = domain_labels.cuda()
            discriminator_out = model_d(
                torch.cat((source_learned_feature, l_target_learned_feature, u_target_learned_feature), dim=0), alpha)
            domain_adv_alignment_loss = transfer_criterion(discriminator_out, domain_labels)
            error_overall += args.gamma * domain_adv_alignment_loss
            if epoch % 10 == 0:
                print('Use domain adversarial loss: ', args.gamma * domain_adv_alignment_loss)

        # Calculate the explicit semantic alignment loss
        if args.beta:
            u_target_selected_label = u_target_selected_label.reshape(-1, )

            general_alignment_loss = explicit_semantic_alignment_loss_func(
                source_learned_features=source_learned_feature,
                l_target_learned_features=l_target_learned_feature,
                u_target_learned_features=u_target_selected_feature,
                source_labels=source_label,
                l_target_labels=l_target_label,
                u_target_pseudo_labels=u_target_selected_label,
                configuration=configuration,
                prototype=args.prototype)
            error_overall += args.beta * general_alignment_loss
            if epoch % 10 == 0:
                print('Use ESA loss:', args.beta * general_alignment_loss)

        # backward propagation
        error_overall.backward()
        optimizer.step()
        optimizer_d.step()

        # Testing Phase
        acc_src = test(model, configuration, 'source')
        acc_labeled_tar = test(model, configuration, 'labeled_target')
        acc_unlabeled_tar = test(model, configuration, 'unlabeled_target')
        end_time = time.time()
        print('ACC -> ', end='')
        print('Epoch: [{}/{}], {:.1f}s, Src acc: {:.4f}%, LTar acc: {:.4f}%, UTar acc: {:.4f}%'.format(
            epoch, args.nepoch, end_time - start_time, acc_src, acc_labeled_tar, acc_unlabeled_tar))

        if best_acc < acc_unlabeled_tar:
            best_acc = acc_unlabeled_tar
            best_text = args.source.ljust(10) + '-> ' + args.target.ljust(10) \
                        + ' The proposed model for HDA achieves current best accuracy. '
            print(best_text)
            if epoch >= 1000:
                print('new best after epoch 1000; more training epochs may help')

    # end for max_epoch
    print('Best Test Accuracy: {:.4f}%'.format(best_acc))
    write_log_record(args, configuration, best_acc)
    return best_acc


if __name__ == '__main__':
    result = 0.
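    # Each of the args.partition rounds below redraws the labeled/unlabeled target
    # split at random (see data_loader.get_configuration), trains a fresh model, and
    # the final reported number is the average of the best unlabeled-target accuracies.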
    for i in range(args.partition):
        configuration = data_loader.get_configuration(args)
        model = Prototypical(configuration['d_source'], configuration['d_target'], args.d_common,
                             configuration['class_number'], args.layer)
        model_D = Discriminator(args.d_common)
        if torch.cuda.is_available():
            model = model.cuda()
            model_D = model_D.cuda()
        if args.optimizer == 'SGD':
            optimizer = optim.SGD(model.parameters(), lr=args.lr)
            optimizer_d = optim.SGD(model_D.parameters(), lr=args.lr)
        elif args.optimizer == 'mSGD':
            optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9,
                                  weight_decay=0.001, nesterov=True)
            optimizer_d = optim.SGD(model_D.parameters(), lr=args.lr, momentum=0.9,
                                    weight_decay=0.001, nesterov=True)
        elif args.optimizer == 'Adam':
            optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.99))
            optimizer_d = optim.Adam(model_D.parameters(), lr=args.lr, betas=(0.9, 0.99))

        result += train(model, model_D, optimizer, optimizer_d, configuration)

    with open(args.log_path, 'a') as fp:
        fp.write('PN_HDA: '
                 + '| src = ' + args.source.ljust(4)
                 + '| tar = ' + args.target.ljust(4)
                 + '| avg acc = ' + str('%.4f' % (result / args.partition)).ljust(4)
                 + '\n'
                 + str(args)
                 + '\n')
    # write the same summary to a separate avg txt
    with open(args.avg_path, 'a') as fp:
        fp.write('PN_HDA: '
                 + '| src = ' + args.source.ljust(4)
                 + '| tar = ' + args.target.ljust(4)
                 + '| avg acc = ' + str('%.4f' % (result / args.partition)).ljust(4)
                 + '\n'
                 + str(args)
                 + '\n')
    print('Avg acc:', str('%.4f' % (result / args.partition)).ljust(4))

--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
import torch.nn as nn
from torch.autograd import Function


# HDA Feature Projector
class Projector(nn.Module):
    def __init__(self, d_input, d_common, layer):
        super(Projector, self).__init__()
        if layer.lower() == "single":
            layer = nn.Linear(d_input, d_common)
            leaky_relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

            # init weight and bias
            nn.init.normal_(layer.weight, std=0.01)
            nn.init.normal_(layer.bias, std=0.01)

            projector = nn.Sequential(layer, leaky_relu)
        elif layer.lower() == "double":
            d_intermediate = int((d_input + d_common) / 2)
            layer1 = nn.Linear(d_input, d_intermediate)
            leaky_relu1 = nn.LeakyReLU(negative_slope=0.2, inplace=True)
            layer2 = nn.Linear(d_intermediate, d_common)
            leaky_relu2 = nn.LeakyReLU(negative_slope=0.2, inplace=True)

            # init weight and bias
            nn.init.normal_(layer1.weight, std=0.01)
            nn.init.normal_(layer1.bias, std=0.01)
            nn.init.normal_(layer2.weight, std=0.01)
            nn.init.normal_(layer2.bias, std=0.01)

            projector = nn.Sequential(layer1, leaky_relu1, layer2, leaky_relu2)
        else:
            raise Exception("Input layer invalid!")
") 34 | self.projector = projector 35 | 36 | def forward(self, x): 37 | return nn.functional.normalize(self.projector(x), dim=1, p=2) 38 | 39 | 40 | # Label Classifier 41 | class Classifier(nn.Module): 42 | def __init__(self, d_common, class_number): 43 | super(Classifier, self).__init__() 44 | layer = nn.Linear(d_common, class_number) 45 | leakey_relu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 46 | 47 | # init weight and bias 48 | nn.init.normal_(layer.weight, std=0.01) 49 | nn.init.normal_(layer.bias, std=0.01) 50 | 51 | self.class_classifier = nn.Sequential(layer, leakey_relu) 52 | 53 | def forward(self, x): 54 | return self.class_classifier(x) 55 | 56 | 57 | class ReverseLayerF(Function): 58 | r"""Gradient Reverse Layer(Unsupervised Domain Adaptation by Backpropagation) 59 | Definition: During the forward propagation, GRL acts as an identity transform. During the back propagation though, 60 | GRL takes the gradient from the subsequent level, multiplies it by -alpha and pass it to the preceding layer. 61 | Args: 62 | x (Tensor): the input tensor 63 | alpha (float): \alpha = \frac{2}{1+\exp^{-\gamma \cdot p}}-1 (\gamma =10) 64 | out (Tensor): the same output tensor as x 65 | """ 66 | 67 | @staticmethod 68 | def forward(ctx, x, alpha): 69 | ctx.alpha = alpha 70 | return x.view_as(x) 71 | 72 | @staticmethod 73 | def backward(ctx, grad_output): 74 | output = grad_output.neg() * ctx.alpha 75 | return output, None 76 | 77 | 78 | # Domain Discriminator 79 | class Discriminator(nn.Module): 80 | def __init__(self, d_common): 81 | super(Discriminator, self).__init__() 82 | layer = nn.Linear(d_common, 1) 83 | sigmod = nn.Sigmoid() 84 | 85 | # init weight and bias 86 | nn.init.normal_(layer.weight, std=0.01) 87 | nn.init.normal_(layer.bias, std=0.01) 88 | 89 | self.discriminator = nn.Sequential(layer, sigmod) 90 | 91 | def forward(self, x, alpha): 92 | x = ReverseLayerF.apply(x, alpha) 93 | x = self.discriminator(x) 94 | 95 | return x 96 | 97 | 98 | # Prototypical Network 99 | class Prototypical(nn.Module): 100 | def __init__(self, d_source, d_target, d_common, class_number, layer): 101 | super(Prototypical, self).__init__() 102 | self.d_common = d_common 103 | self.d_source = d_source 104 | self.d_target = d_target 105 | self.class_number = class_number 106 | self.layer = layer 107 | self.projector_source = Projector(self.d_source, self.d_common, self.layer) 108 | self.projector_target = Projector(self.d_target, self.d_common, self.layer) 109 | self.classifier = Classifier(self.d_common, self.class_number) 110 | 111 | def forward(self, input_feature): 112 | if input_feature.shape[1] == self.d_source: 113 | feature = self.projector_source(input_feature) 114 | elif input_feature.shape[1] == self.d_target: 115 | feature = self.projector_target(input_feature) 116 | else: 117 | raise Exception("Input data wrong dimension! 
") 118 | feature = feature.view(-1, self.d_common) 119 | classifier_output = self.classifier(feature) 120 | return classifier_output, feature 121 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.3.1 2 | matplotlib==3.1.2 3 | scipy==1.3.2 4 | numpy==1.17.4 5 | scikit_learn==0.22.1 6 | -------------------------------------------------------------------------------- /teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/SSAN/2167433dc8a4f2dc05c14744ae6d4514245f4eb1/teaser.jpg -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import pickle 4 | import numpy as np 5 | import random 6 | from matplotlib import pyplot as plt 7 | from collections import defaultdict 8 | 9 | 10 | # seed for everything 11 | def seed_everything(args): 12 | random.seed(args.seed) 13 | torch.manual_seed(args.seed) 14 | torch.cuda.manual_seed_all(args.seed) 15 | np.random.seed(args.seed) 16 | os.environ['PYTHONHASHSEED'] = str(args.seed) 17 | 18 | 19 | # Write Log Record 20 | def write_log_record(args, configuration, best_acc): 21 | with open(args.log_path, 'a') as fp: 22 | fp.write('PN_HDA: ' 23 | + '| seed = ' + str(args.seed).ljust(4) 24 | + '| src = ' + args.source.ljust(4) 25 | + '| tar = ' + args.target.ljust(4) 26 | + '| best tar acc = ' + str('%.4f' % best_acc).ljust(4) 27 | + '| nepoch = ' + str(args.nepoch).ljust(4) 28 | + '| layer =' + str(args.layer).ljust(4) 29 | + '| d_common =' + str(args.d_common).ljust(4) 30 | + '| optimizer =' + str(args.optimizer).ljust(4) 31 | + '| lr = ' + str(args.lr).ljust(4) 32 | + '| temperature =' + str(args.temperature).ljust(4) 33 | + '| alpha =' + str(args.alpha).ljust(4) 34 | + '| beta = ' + str(args.beta).ljust(4) 35 | + '| gamma = ' + str(args.gamma).ljust(4) 36 | + '| time = ' + args.time_string 37 | + '| checkpoint_path = ' + str(args.checkpoint_path) 38 | + '\n') 39 | fp.close() 40 | 41 | 42 | # Command Line Argument Bool Helper 43 | def bool_string(input_string): 44 | if input_string.lower() not in ['true', 'false']: 45 | raise ValueError('Bool String Input Invalid! ') 46 | return input_string.lower() == 'true' 47 | 48 | 49 | # make dirs for model_path, result_path, log_path, diagram_path 50 | def make_dirs(args): 51 | save_name = '_'.join([args.source.lower(), args.target.lower()]) 52 | log_path = os.path.join(args.checkpoint_path, 'logs') 53 | if not os.path.exists(args.checkpoint_path): 54 | os.makedirs(log_path) 55 | print('Makedir: ' + str(log_path)) 56 | args.log_path = os.path.join(log_path, save_name + '.txt') 57 | args.avg_path = os.path.join(log_path, save_name + '_avg.txt') 58 | --------------------------------------------------------------------------------