├── fig ├── FigTNSE.png ├── FigFramework.jpg └── FigToyexample.png ├── DataSets ├── VisDA-17 │ └── VisDA-17-example.png ├── office31 │ └── office31-example.png ├── OfficeHome │ └── OfficeHome-example.png └── image_CLEF │ └── image_CLEF-example.png ├── loss_funcs ├── __init__.py ├── coral.py ├── mmd.py ├── adv.py └── lmmd.py ├── DANN ├── readme.md ├── DANN.yaml └── DANN.sh ├── DeepCoral ├── README.md ├── DeepCoral.yaml └── DeepCoral.sh ├── DAN ├── DAN.yaml ├── README.md └── DAN.sh ├── DSAN ├── DSAN.yaml ├── README.md └── DSAN.sh ├── utils.py ├── transfer_losses.py ├── data_loader.py ├── backbones.py ├── README.md ├── models.py ├── metrics.py └── main.py /fig/FigTNSE.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cloudfly-Z/ACAN/HEAD/fig/FigTNSE.png -------------------------------------------------------------------------------- /fig/FigFramework.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cloudfly-Z/ACAN/HEAD/fig/FigFramework.jpg -------------------------------------------------------------------------------- /fig/FigToyexample.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cloudfly-Z/ACAN/HEAD/fig/FigToyexample.png -------------------------------------------------------------------------------- /DataSets/VisDA-17/VisDA-17-example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cloudfly-Z/ACAN/HEAD/DataSets/VisDA-17/VisDA-17-example.png -------------------------------------------------------------------------------- /DataSets/office31/office31-example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cloudfly-Z/ACAN/HEAD/DataSets/office31/office31-example.png -------------------------------------------------------------------------------- /DataSets/OfficeHome/OfficeHome-example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cloudfly-Z/ACAN/HEAD/DataSets/OfficeHome/OfficeHome-example.png -------------------------------------------------------------------------------- /DataSets/image_CLEF/image_CLEF-example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cloudfly-Z/ACAN/HEAD/DataSets/image_CLEF/image_CLEF-example.png -------------------------------------------------------------------------------- /loss_funcs/__init__.py: -------------------------------------------------------------------------------- 1 | from loss_funcs.adv import * 2 | from loss_funcs.mmd import * 3 | from loss_funcs.lmmd import * 4 | from loss_funcs.coral import * -------------------------------------------------------------------------------- /DANN/readme.md: -------------------------------------------------------------------------------- 1 | # Domain adversarial neural network (DANN/RevGrad) 2 | 3 | This is a PyTorch implementation of Unsupervised domain adaptation by backpropagation (also known as *DANN* or *RevGrad*). 4 | 5 | 6 | **Reference** 7 | 8 | Ganin Y, Lempitsky V. Unsupervised domain adaptation by backpropagation. ICML 2015. 
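At its core, DANN relies on a gradient reversal layer: an identity map in the forward pass whose gradient is negated (and scaled) in the backward pass, so the feature extractor learns to fool the domain classifier while the classifier learns to separate the domains. A minimal, self-contained sketch of that trick (the implementation actually used in this repo is `ReverseLayerF` in `loss_funcs/adv.py`; `GradReverse` here is just an illustrative stand-in):

```
import torch
from torch.autograd import Function

class GradReverse(Function):
    # identity in the forward pass; gradient multiplied by -alpha in the backward pass
    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.neg() * ctx.alpha, None

features = torch.randn(8, 256, requires_grad=True)
GradReverse.apply(features, 1.0).sum().backward()
print(features.grad[0, 0])  # tensor(-1.): the gradient arrives reversed
```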
9 | -------------------------------------------------------------------------------- /DeepCoral/README.md: -------------------------------------------------------------------------------- 1 | # Deep Coral 2 | 3 | A PyTorch implementation of '[Deep CORAL Correlation Alignment for Deep Domain Adaptation](https://arxiv.org/pdf/1607.01719.pdf)'. 4 | The contributions of this paper are summarized as 5 | follows. 6 | * They extend CORAL to incorporate it directly into deep networks by constructing a differentiable loss function that minimizes the difference between source and target correlations–the CORAL loss. 7 | * Compared to CORAL, the Deep CORAL approach learns a non-linear transformation that is more powerful and also works seamlessly with deep CNNs. -------------------------------------------------------------------------------- /loss_funcs/coral.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def CORAL(source, target, **kwargs): 5 | d = source.data.shape[1] 6 | ns, nt = source.data.shape[0], target.data.shape[0] 7 | # source covariance 8 | xm = torch.mean(source, 0, keepdim=True) - source 9 | xc = xm.t() @ xm / (ns - 1) 10 | 11 | # target covariance 12 | xmt = torch.mean(target, 0, keepdim=True) - target 13 | xct = xmt.t() @ xmt / (nt - 1) 14 | 15 | # squared Frobenius norm between source and target covariances 16 | loss = torch.mul((xc - xct), (xc - xct)) 17 | loss = torch.sum(loss) / (4*d*d) 18 | return loss 19 | -------------------------------------------------------------------------------- /DAN/DAN.yaml: -------------------------------------------------------------------------------- 1 | # Backbone 2 | backbone: resnet50 3 | 4 | # Transfer loss related 5 | transfer_loss_weight: 1 6 | transfer_loss: mmd 7 | 8 | # Metric related 9 | metric: arc_margin # arc_margin/add_margin/sphere 10 | metric_s: 20 11 | metric_m: 0.5 12 | metric_loss_weight: 0.1 13 | 14 | # Entropy related 15 | entropy_loss_weight: 0.0005 16 | 17 | # Optimizer related 18 | lr: 0.01 19 | weight_decay: 0.001 20 | lr_scheduler: True 21 | lr_gamma: 10 22 | lr_decay: 0.75 23 | momentum: 0.9 24 | 25 | # Training related 26 | n_iter_per_epoch: 200 27 | n_epoch: 20 28 | 29 | # Others 30 | seed: 1 31 | num_workers: 3 32 | -------------------------------------------------------------------------------- /DSAN/DSAN.yaml: -------------------------------------------------------------------------------- 1 | # Backbone 2 | backbone: resnet50 3 | 4 | # Transfer related 5 | transfer_loss_weight: 0.5 6 | transfer_loss: lmmd 7 | 8 | # Metric related 9 | metric: arc_margin # arc_margin/add_margin/sphere 10 | metric_s: 8 11 | metric_m: 0.5 12 | metric_loss_weight: 0.1 13 | 14 | # Entropy related 15 | entropy_loss_weight: 0.0005 16 | 17 | # Optimizer related 18 | lr: 0.01 19 | weight_decay: 5e-4 20 | momentum: 0.9 21 | lr_scheduler: True 22 | lr_gamma: 0.0003 23 | lr_decay: 0.75 24 | 25 | # Training related 26 | n_iter_per_epoch: 100 27 | n_epoch: 20 28 | 29 | # Others 30 | seed: 1 31 | num_workers: 3 32 | -------------------------------------------------------------------------------- /DANN/DANN.yaml: -------------------------------------------------------------------------------- 1 | # Backbone 2 | backbone: resnet50 3 | 4 | # Transfer related 5 | transfer_loss_weight: 1.0 6 | transfer_loss: adv 7 | 8 | # Metric related 9 | metric: arc_margin # arc_margin/add_margin/sphere 10 | metric_s: 20 11 | metric_m: 0.5 12 | metric_loss_weight: 0.5 13 | 14 | # Entropy related 15 | entropy_loss_weight: 
0.075 16 | 17 | # Optimizer related 18 | lr: 0.0065 19 | weight_decay: 0.001 # 5e-4 20 | momentum: 0.9 21 | lr_scheduler: True 22 | lr_gamma: 10.0 23 | lr_decay: 0.75 24 | 25 | # Training related 26 | n_iter_per_epoch: 100 27 | n_epoch: 20 28 | 29 | # Others 30 | seed: 1 31 | num_workers: 3 32 | -------------------------------------------------------------------------------- /DeepCoral/DeepCoral.yaml: -------------------------------------------------------------------------------- 1 | # Backbone 2 | backbone: resnet50 3 | 4 | # Transfer related 5 | transfer_loss_weight: 1.0 6 | transfer_loss: coral 7 | 8 | # Metric related 9 | metric: arc_margin # arc_margin/add_margin/sphere 10 | metric_s: 20 11 | metric_m: 0.5 12 | metric_loss_weight: 0.5 13 | 14 | # Entropy related 15 | entropy_loss_weight: 0.05 16 | 17 | # Optimizer related 18 | lr: 0.01 19 | weight_decay: 0.001 20 | lr_scheduler: True 21 | lr_gamma: 10 22 | lr_decay: 0.75 23 | momentum: 0.9 24 | 25 | # Training related 26 | n_iter_per_epoch: 100 27 | n_epoch: 20 28 | 29 | # Others 30 | seed: 1 31 | num_workers: 3 32 | -------------------------------------------------------------------------------- /DAN/README.md: -------------------------------------------------------------------------------- 1 | # DAN 2 | A PyTorch implementation of '[Learning Transferable Features with Deep Adaptation Networks](http://ise.thss.tsinghua.edu.cn/~mlong/doc/deep-adaptation-networks-icml15.pdf)'. 3 | The contributions of this paper are summarized as follows. 4 | * They propose a novel deep neural network architecture for domain adaptation, in which all the layers corresponding to task-specific features are adapted in a layerwise manner, hence benefiting from “deep adaptation.” 5 | * They explore multiple kernels for adapting deep representations, which substantially enhances adaptation effectiveness compared to single-kernel methods; a minimal sketch of the multi-kernel idea follows below. Their model can yield unbiased deep features with statistical guarantees. 
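To make the multi-kernel idea concrete, here is a small, self-contained sketch of an MMD estimate built from a sum of RBF kernels at fixed bandwidths. It is only an illustration: the `MMDLoss` actually used in this repo (`loss_funcs/mmd.py`) derives its bandwidths from the mean pairwise distance of the batch instead of fixing them by hand.

```
import torch

def multi_kernel_mmd(x, y, bandwidths=(1.0, 2.0, 4.0, 8.0, 16.0)):
    # biased MMD^2 estimate: mean within-domain similarity minus cross-domain similarity
    n = x.size(0)
    z = torch.cat([x, y], dim=0)
    d2 = torch.cdist(z, z).pow(2)                    # pairwise squared distances
    k = sum(torch.exp(-d2 / b) for b in bandwidths)  # multi-kernel: sum of RBF kernels
    return k[:n, :n].mean() + k[n:, n:].mean() - 2 * k[:n, n:].mean()

source = torch.randn(32, 256)
target = torch.randn(32, 256) + 0.5  # a crudely shifted "domain"
print(multi_kernel_mmd(source, target))
```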
-------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | class AverageMeter(object): 2 | """Computes and stores the average and current value""" 3 | 4 | def __init__(self): 5 | self.reset() 6 | 7 | def reset(self): 8 | self.val = 0 9 | self.avg = 0 10 | self.sum = 0 11 | self.count = 0 12 | 13 | def update(self, val, n=1): 14 | self.val = val 15 | self.sum += val * n 16 | self.count += n 17 | self.avg = self.sum / self.count 18 | 19 | def str2bool(v): 20 | if isinstance(v, bool): 21 | return v 22 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 23 | return True 24 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 25 | return False 26 | else: 27 | raise ValueError('Boolean value expected.') -------------------------------------------------------------------------------- /transfer_losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from loss_funcs import * 4 | 5 | class TransferLoss(nn.Module): 6 | def __init__(self, loss_type, **kwargs): 7 | super(TransferLoss, self).__init__() 8 | self.loss_type = loss_type 9 | if loss_type == "adv": 10 | self.loss_func = AdversarialLoss(**kwargs) 11 | elif loss_type == "mmd": 12 | self.loss_func = MMDLoss(**kwargs) 13 | elif loss_type == "lmmd": 14 | self.loss_func = LMMDLoss(**kwargs) 15 | elif loss_type == "coral": 16 | self.loss_func = CORAL 17 | else: 18 | print("WARNING: No valid transfer loss function is used.") 19 | self.loss_func = lambda x, y: 0 # return 0 20 | 21 | def forward(self, source, target, **kwargs): 22 | return self.loss_func(source, target, **kwargs) -------------------------------------------------------------------------------- /DSAN/README.md: -------------------------------------------------------------------------------- 1 | # DSAN 2 | 3 | A PyTorch implementation of 'Deep Subdomain Adaptation Network for Image Classification', which was published in IEEE Transactions on Neural Networks and Learning Systems. 4 | The contributions of this paper are summarized as follows. 5 | * They propose a novel deep neural network architecture for Subdomain Adaptation, which can extend the ability of deep adaptation networks by capturing fine-grained information for each category. 6 | * They show that DSAN, a non-adversarial method, can achieve remarkable results. In addition, their DSAN is very simple and easy to implement. 7 | 8 | 9 | ## Reference 10 | 11 | ``` 12 | Zhu Y, Zhuang F, Wang J, et al. Deep Subdomain Adaptation Network for Image Classification[J]. IEEE Transactions on Neural Networks and Learning Systems, 2020. 
13 | ``` 14 | 15 | or in bibtex style: 16 | 17 | ``` 18 | @article{zhu2020deep, 19 | title={Deep Subdomain Adaptation Network for Image Classification}, 20 | author={Zhu, Yongchun and Zhuang, Fuzhen and Wang, Jindong and Ke, Guolin and Chen, Jingwu and Bian, Jiang and Xiong, Hui and He, Qing}, 21 | journal={IEEE Transactions on Neural Networks and Learning Systems}, 22 | year={2020}, 23 | publisher={IEEE} 24 | } 25 | ``` -------------------------------------------------------------------------------- /loss_funcs/mmd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class MMDLoss(nn.Module): 5 | def __init__(self, kernel_type='rbf', kernel_mul=2.0, kernel_num=5, fix_sigma=None, **kwargs): 6 | super(MMDLoss, self).__init__() 7 | self.kernel_num = kernel_num 8 | self.kernel_mul = kernel_mul 9 | self.fix_sigma = fix_sigma 10 | self.kernel_type = kernel_type 11 | 12 | def guassian_kernel(self, source, target, kernel_mul, kernel_num, fix_sigma): 13 | n_samples = int(source.size()[0]) + int(target.size()[0]) 14 | total = torch.cat([source, target], dim=0) 15 | total0 = total.unsqueeze(0).expand( 16 | int(total.size(0)), int(total.size(0)), int(total.size(1))) 17 | total1 = total.unsqueeze(1).expand( 18 | int(total.size(0)), int(total.size(0)), int(total.size(1))) 19 | L2_distance = ((total0-total1)**2).sum(2) 20 | if fix_sigma: 21 | bandwidth = fix_sigma 22 | else: 23 | bandwidth = torch.sum(L2_distance.data) / (n_samples**2-n_samples) # bandwidth from the mean pairwise squared distance 24 | bandwidth /= kernel_mul ** (kernel_num // 2) 25 | bandwidth_list = [bandwidth * (kernel_mul**i) 26 | for i in range(kernel_num)] 27 | kernel_val = [torch.exp(-L2_distance / bandwidth_temp) 28 | for bandwidth_temp in bandwidth_list] 29 | return sum(kernel_val) 30 | 31 | def linear_mmd2(self, f_of_X, f_of_Y): 32 | loss = 0.0 33 | delta = f_of_X.float().mean(0) - f_of_Y.float().mean(0) 34 | loss = delta.dot(delta.T) 35 | return loss 36 | 37 | def forward(self, source, target): 38 | if self.kernel_type == 'linear': 39 | return self.linear_mmd2(source, target) 40 | elif self.kernel_type == 'rbf': 41 | batch_size = int(source.size()[0]) 42 | kernels = self.guassian_kernel( 43 | source, target, kernel_mul=self.kernel_mul, kernel_num=self.kernel_num, fix_sigma=self.fix_sigma) 44 | XX = torch.mean(kernels[:batch_size, :batch_size]) 45 | YY = torch.mean(kernels[batch_size:, batch_size:]) 46 | XY = torch.mean(kernels[:batch_size, batch_size:]) 47 | YX = torch.mean(kernels[batch_size:, :batch_size]) 48 | loss = torch.mean(XX + YY - XY - YX) 49 | return loss 50 | -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | from torchvision import datasets, transforms 2 | import torch 3 | 4 | def load_data(data_folder, batch_size, train, num_workers=0, **kwargs): 5 | transform = { 6 | 'train': transforms.Compose( 7 | [transforms.Resize([256, 256]), 8 | transforms.RandomCrop(224), 9 | transforms.RandomHorizontalFlip(), 10 | transforms.ToTensor(), 11 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 12 | std=[0.229, 0.224, 0.225])]), 13 | 'test': transforms.Compose( 14 | [transforms.Resize([224, 224]), 15 | transforms.ToTensor(), 16 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 17 | std=[0.229, 0.224, 0.225])]) 18 | } 19 | data = datasets.ImageFolder(root=data_folder, transform=transform['train' if train else 'test']) 20 | data_loader = get_data_loader(data, 
batch_size=batch_size, 21 | shuffle=True if train else False, 22 | num_workers=num_workers, **kwargs, drop_last=True if train else False) 23 | n_class = len(data.classes) 24 | return data_loader, n_class 25 | 26 | 27 | def get_data_loader(dataset, batch_size, shuffle=True, drop_last=False, num_workers=0, infinite_data_loader=False, **kwargs): 28 | if not infinite_data_loader: 29 | return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=num_workers, **kwargs) 30 | else: 31 | return InfiniteDataLoader(dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=num_workers, **kwargs) 32 | 33 | class _InfiniteSampler(torch.utils.data.Sampler): 34 | """Wraps another Sampler to yield an infinite stream.""" 35 | def __init__(self, sampler): 36 | self.sampler = sampler 37 | 38 | def __iter__(self): 39 | while True: 40 | for batch in self.sampler: 41 | yield batch 42 | 43 | class InfiniteDataLoader: 44 | def __init__(self, dataset, batch_size, shuffle=True, drop_last=False, num_workers=0, weights=None, **kwargs): 45 | if weights is not None: 46 | sampler = torch.utils.data.WeightedRandomSampler(weights, 47 | replacement=False, 48 | num_samples=batch_size) 49 | else: 50 | sampler = torch.utils.data.RandomSampler(dataset, 51 | replacement=False) 52 | 53 | batch_sampler = torch.utils.data.BatchSampler( 54 | sampler, 55 | batch_size=batch_size, 56 | drop_last=drop_last) 57 | 58 | self._infinite_iterator = iter(torch.utils.data.DataLoader( 59 | dataset, 60 | num_workers=num_workers, 61 | batch_sampler=_InfiniteSampler(batch_sampler) 62 | )) 63 | 64 | def __iter__(self): 65 | while True: 66 | yield next(self._infinite_iterator) 67 | 68 | def __len__(self): 69 | return 0 # Always return 0 -------------------------------------------------------------------------------- /backbones.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torchvision import models 3 | 4 | resnet_dict = { 5 | "resnet18": models.resnet18, 6 | "resnet34": models.resnet34, 7 | "resnet50": models.resnet50, 8 | "resnet101": models.resnet101, 9 | "resnet152": models.resnet152, 10 | } 11 | 12 | def get_backbone(name): 13 | if "resnet" in name.lower(): 14 | return ResNetBackbone(name) 15 | elif "alexnet" == name.lower(): 16 | return AlexNetBackbone() 17 | elif "dann" == name.lower(): 18 | return DaNNBackbone() 19 | 20 | class DaNNBackbone(nn.Module): 21 | def __init__(self, n_input=224*224*3, n_hidden=256): 22 | super(DaNNBackbone, self).__init__() 23 | self.layer_input = nn.Linear(n_input, n_hidden) 24 | self.dropout = nn.Dropout(p=0.5) 25 | self.relu = nn.ReLU() 26 | self._feature_dim = n_hidden 27 | 28 | def forward(self, x): 29 | x = x.view(x.size(0), -1) 30 | x = self.layer_input(x) 31 | x = self.dropout(x) 32 | x = self.relu(x) 33 | return x 34 | 35 | def output_num(self): 36 | return self._feature_dim 37 | 38 | # convnet without the last layer 39 | class AlexNetBackbone(nn.Module): 40 | def __init__(self): 41 | super(AlexNetBackbone, self).__init__() 42 | model_alexnet = models.alexnet(pretrained=True) 43 | self.features = model_alexnet.features 44 | self.classifier = nn.Sequential() 45 | for i in range(6): 46 | self.classifier.add_module( 47 | "classifier"+str(i), model_alexnet.classifier[i]) 48 | self._feature_dim = model_alexnet.classifier[6].in_features 49 | 50 | def forward(self, x): 51 | x = self.features(x) 52 | x = x.view(x.size(0), 256*6*6) # flatten AlexNet conv features (256 channels x 6x6 spatial) 53 | x = self.classifier(x) 54 | return x 
55 | 56 | def output_num(self): 57 | return self._feature_dim 58 | 59 | class ResNetBackbone(nn.Module): 60 | def __init__(self, network_type): 61 | super(ResNetBackbone, self).__init__() 62 | resnet = resnet_dict[network_type](pretrained=True) 63 | self.conv1 = resnet.conv1 64 | self.bn1 = resnet.bn1 65 | self.relu = resnet.relu 66 | self.maxpool = resnet.maxpool 67 | self.layer1 = resnet.layer1 68 | self.layer2 = resnet.layer2 69 | self.layer3 = resnet.layer3 70 | self.layer4 = resnet.layer4 71 | self.avgpool = resnet.avgpool 72 | self._feature_dim = resnet.fc.in_features 73 | del resnet 74 | 75 | def forward(self, x): 76 | x = self.conv1(x) 77 | x = self.bn1(x) 78 | x = self.relu(x) 79 | x = self.maxpool(x) 80 | x = self.layer1(x) 81 | x = self.layer2(x) 82 | x = self.layer3(x) 83 | x = self.layer4(x) 84 | x = self.avgpool(x) 85 | x = x.view(x.size(0), -1) # after conv/pooling the tensor has shape (batch_size, channels, h, w); x.size(0) is the batch size, so view(x.size(0), -1) flattens (channels, h, w) into one vector per sample, ready for the fc layer 86 | return x 87 | 88 | def output_num(self): 89 | return self._feature_dim -------------------------------------------------------------------------------- /loss_funcs/adv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Function 4 | import torch.nn.functional as F 5 | import numpy as np 6 | 7 | class LambdaSheduler(nn.Module): 8 | def __init__(self, gamma=10.0, max_iter=1000, **kwargs): 9 | super(LambdaSheduler, self).__init__() 10 | self.gamma = gamma 11 | self.max_iter = max_iter 12 | self.curr_iter = 0 13 | 14 | def lamb(self): 15 | p = self.curr_iter / self.max_iter 16 | lamb = 2. / (1. + np.exp(-self.gamma * p)) - 1 17 | return lamb 18 | 19 | def step(self): 20 | self.curr_iter = min(self.curr_iter + 1, self.max_iter) 21 | 22 | class AdversarialLoss(nn.Module): 23 | ''' 24 | Acknowledgement: The adversarial loss implementation is inspired by http://transfer.thuml.ai/ 25 | ''' 26 | def __init__(self, gamma=10.0, max_iter=1000, use_lambda_scheduler=True, **kwargs): 27 | super(AdversarialLoss, self).__init__() 28 | self.domain_classifier = Discriminator() 29 | self.use_lambda_scheduler = use_lambda_scheduler 30 | if self.use_lambda_scheduler: 31 | self.lambda_scheduler = LambdaSheduler(gamma, max_iter) 32 | 33 | def forward(self, source, target): 34 | lamb = 1.0 35 | if self.use_lambda_scheduler: 36 | lamb = self.lambda_scheduler.lamb() 37 | self.lambda_scheduler.step() 38 | source_loss = self.get_adversarial_result(source, True, lamb) 39 | target_loss = self.get_adversarial_result(target, False, lamb) 40 | adv_loss = 0.5 * (source_loss + target_loss) * lamb 41 | return adv_loss 42 | 43 | def get_adversarial_result(self, x, source=True, lamb=1.0): 44 | x = ReverseLayerF.apply(x, lamb) 45 | domain_pred = self.domain_classifier(x) 46 | device = domain_pred.device 47 | if source: 48 | domain_label = torch.ones(len(x), 1).long() 49 | else: 50 | domain_label = torch.zeros(len(x), 1).long() 51 | loss_fn = nn.BCELoss() 52 | loss_adv = loss_fn(domain_pred, domain_label.float().to(device)) 53 | return loss_adv 54 | 55 | 56 | class ReverseLayerF(Function): 57 | @staticmethod 58 | def forward(ctx, x, alpha): 59 | ctx.alpha = alpha 60 | return x.view_as(x) 61 | 62 | @staticmethod 63 | def backward(ctx, grad_output): 64 | output = grad_output.neg() * ctx.alpha 65 | return output, None 66 | 67 | class Discriminator(nn.Module): 68 | def __init__(self, input_dim=256, 
hidden_dim=256): 69 | super(Discriminator, self).__init__() 70 | self.input_dim = input_dim 71 | self.hidden_dim = hidden_dim 72 | layers = [ 73 | nn.Linear(input_dim, hidden_dim), 74 | nn.BatchNorm1d(hidden_dim), 75 | nn.ReLU(), 76 | nn.Linear(hidden_dim, hidden_dim), 77 | nn.BatchNorm1d(hidden_dim), 78 | nn.ReLU(), 79 | nn.Linear(hidden_dim, 1), 80 | nn.Sigmoid() 81 | ] 82 | self.layers = torch.nn.Sequential(*layers) 83 | 84 | def forward(self, x): 85 | return self.layers(x) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ACAN: A plug-and-play Adaptive Center-Aligned Network for unsupervised domain adaptation 2 | Official implementation of [ACAN](https://doi.org/10.1016/j.engappai.2024.109132) (EAAI 2024). 3 | 4 | Abstract 5 | --- 6 | Domain adaptation is an important topic due to its capability of transferring knowledge from a source domain to a target domain. However, many existing domain adaptation methods primarily concentrate on aligning the data distributions between the source and target domains, often neglecting discriminative feature learning. As a result, target samples with low confidence are embedded near the decision boundary, where they are susceptible to being misclassified, resulting in negative transfer. To address this problem, a novel Adaptive Center-Aligned Network dubbed ACAN is proposed for unsupervised domain adaptation in this work. The main innovations of ACAN are fourfold. Firstly, it is a plug-and-play module that can be easily incorporated into any domain alignment method without increasing the model complexity or computational burden. Secondly, in contrast to the conventional softmax plus cross-entropy loss, an angular margin loss is employed to enhance the discriminative power of the classifier. Thirdly, entropy regularization is exploited to highlight the probability of the potentially related class, which keeps the learned feature representations far away from the decision boundary. Fourthly, to improve the discriminative capacity of the model on the target domain, we propose to align target-domain samples to their corresponding class centers via pseudo labels. When ACAN is incorporated, the performance of baseline domain alignment methods improves significantly. Extensive ablation and comparison experiments on four widely adopted databases demonstrate the effectiveness of our ACAN. 7 | 8 | Motivation 9 | --- 10 | ![Motivation](/fig/FigToyexample.png "Toy Example") 11 | 12 | Network Architecture 13 | --- 14 | ![Framework](/fig/FigFramework.jpg "Network Architecture") 15 | 16 | Visualization 17 | --- 18 | ![TSNE](/fig/FigTNSE.png "Visualization") 19 | 20 | Usage 21 | --- 22 | 1. Data 23 | 24 | Office-31, Office-Home, ImageCLEF-DA, VisDA-2017 datasets are available at [Datasets download](https://github.com/jindongwang/transferlearning/tree/master/data). 25 | 26 | Place the downloaded datasets in `ACAN/DataSets/`. 27 | 28 | 2. Dependencies 29 | ``` 30 | CUDA Version: 12.0 31 | python==3.8.16 32 | torch==2.0.1+cu118 33 | torchvision==0.15.2+cu118 34 | torchaudio==2.0.2+cu118 35 | numpy==1.2.14 36 | configargparse==1.7 37 | pyyaml==6.0.1 38 | ``` 39 | 40 | 3. Run shell script files such as: 41 | ``` 42 | bash DSAN/DSAN.sh 43 | ``` 44 | 45 | Acknowledgement 46 | --- 47 | Our code is based on the project [Everything about Transfer Learning and Domain Adaptation](https://github.com/jindongwang/transferlearning). 
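For reference, the entropy regularization term (L_EM) mentioned in the abstract reduces to a few lines. The following self-contained sketch mirrors the `entropy_loss` computed in `models.py`; the logits here are random placeholders, purely for illustration:

```
import torch
import torch.nn.functional as F

logits = torch.randn(64, 31)  # e.g. one target batch over the 31 Office-31 classes
probs = F.softmax(logits, dim=1)
# minimizing the mean prediction entropy pushes target samples away from the decision boundary
entropy_loss = -torch.mean((probs * torch.log(probs + 1e-6)).sum(dim=1))
print(entropy_loss)
```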
48 | 49 | Contact 50 | --- 51 | If you have any questions, please create an issue on this repository or contact us at 1923351867@qq.com. 52 | 53 | Citation 54 | --- 55 | If you find our code or paper helpful, please cite us! 56 | ``` 57 | @article{ZHANG2024109132, 58 | title = {ACAN: A plug-and-play Adaptive Center-Aligned Network for unsupervised domain adaptation}, 59 | journal = {Engineering Applications of Artificial Intelligence}, 60 | volume = {137}, 61 | pages = {109132}, 62 | year = {2024}, 63 | issn = {0952-1976}, 64 | doi = {https://doi.org/10.1016/j.engappai.2024.109132}, 65 | author = {Yunfei Zhang and Jun Zhang and Tonglu Li and Feixue Shao and Xuetao Ma and Yongfei Wu and Shu Feng and Daoxiang Zhou} 66 | } 67 | ``` 68 | -------------------------------------------------------------------------------- /DAN/DAN.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | GPU_ID=0 3 | data_dir=./DataSets/office31 4 | # Office31 5 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DAN/DAN.yaml --data_dir $data_dir --src_domain dslr --tgt_domain amazon | tee DAN_D2A.log 6 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DAN/DAN.yaml --data_dir $data_dir --src_domain dslr --tgt_domain webcam | tee DAN_D2W.log 7 | 8 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DAN/DAN.yaml --data_dir $data_dir --src_domain amazon --tgt_domain dslr | tee DAN_A2D.log 9 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DAN/DAN.yaml --data_dir $data_dir --src_domain amazon --tgt_domain webcam | tee DAN_A2W.log 10 | 11 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DAN/DAN.yaml --data_dir $data_dir --src_domain webcam --tgt_domain amazon | tee DAN_W2A.log 12 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DAN/DAN.yaml --data_dir $data_dir --src_domain webcam --tgt_domain dslr | tee DAN_W2D.log 13 | 14 | 15 | 16 | data_dir=./DataSets/OfficeHome 17 | # Office-Home 18 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DAN/DAN.yaml --data_dir $data_dir --src_domain Art --tgt_domain Clipart | tee DAN_A2C.log 19 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DAN/DAN.yaml --data_dir $data_dir --src_domain Art --tgt_domain Real_World | tee DAN_A2R.log 20 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DAN/DAN.yaml --data_dir $data_dir --src_domain Art --tgt_domain Product | tee DAN_A2P.log 21 | 22 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DAN/DAN.yaml --data_dir $data_dir --src_domain Clipart --tgt_domain Art | tee DAN_C2A.log 23 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DAN/DAN.yaml --data_dir $data_dir --src_domain Clipart --tgt_domain Real_World | tee DAN_C2R.log 24 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DAN/DAN.yaml --data_dir $data_dir --src_domain Clipart --tgt_domain Product | tee DAN_C2P.log 25 | 26 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DAN/DAN.yaml --data_dir $data_dir --src_domain Product --tgt_domain Art | tee DAN_P2A.log 27 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DAN/DAN.yaml --data_dir $data_dir --src_domain Product --tgt_domain Real_World | tee DAN_P2R.log 28 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DAN/DAN.yaml --data_dir $data_dir --src_domain Product --tgt_domain Clipart | tee DAN_P2C.log 29 | 30 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DAN/DAN.yaml --data_dir $data_dir --src_domain Real_World --tgt_domain Art | tee DAN_R2A.log 31 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py 
--config DAN/DAN.yaml --data_dir $data_dir --src_domain Real_World --tgt_domain Product | tee DAN_R2P.log 32 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DAN/DAN.yaml --data_dir $data_dir --src_domain Real_World --tgt_domain Clipart | tee DAN_R2C.log 33 | 34 | data_dir=./DataSets/image_CLEF 35 | # image_CLEF 36 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DAN/DAN.yaml --data_dir $data_dir --src_domain c --tgt_domain i_tar | tee DAN_C2I.log 37 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DAN/DAN.yaml --data_dir $data_dir --src_domain c --tgt_domain p_tar | tee DAN_C2P.log 38 | 39 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DAN/DAN.yaml --data_dir $data_dir --src_domain i --tgt_domain c_tar | tee DAN_I2C.log 40 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DAN/DAN.yaml --data_dir $data_dir --src_domain i --tgt_domain p_tar | tee DAN_I2P.log 41 | 42 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DAN/DAN.yaml --data_dir $data_dir --src_domain p --tgt_domain c_tar | tee DAN_P2C.log 43 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DAN/DAN.yaml --data_dir $data_dir --src_domain p --tgt_domain i_tar | tee DAN_P2I.log 44 | 45 | 46 | data_dir=./DataSets/VisDA-17 47 | # VisDA-17 48 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DAN/DAN.yaml --data_dir $data_dir --src_domain train --tgt_domain validation | tee DAN_S2R.log -------------------------------------------------------------------------------- /DANN/DANN.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | GPU_ID=0 3 | data_dir=./DataSets/office31 4 | # Office31 5 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DANN/DANN.yaml --data_dir $data_dir --src_domain dslr --tgt_domain amazon | tee DANN_D2A.log 6 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DANN/DANN.yaml --data_dir $data_dir --src_domain dslr --tgt_domain webcam | tee DANN_D2W.log 7 | 8 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DANN/DANN.yaml --data_dir $data_dir --src_domain amazon --tgt_domain dslr | tee DANN_A2D.log 9 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DANN/DANN.yaml --data_dir $data_dir --src_domain amazon --tgt_domain webcam | tee DANN_A2W.log 10 | 11 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DANN/DANN.yaml --data_dir $data_dir --src_domain webcam --tgt_domain amazon | tee DANN_W2A.log 12 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DANN/DANN.yaml --data_dir $data_dir --src_domain webcam --tgt_domain dslr | tee DANN_W2D.log 13 | 14 | 15 | data_dir=./DataSets/OfficeHome 16 | # Office-Home 17 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DANN/DANN.yaml --data_dir $data_dir --src_domain Art --tgt_domain Clipart | tee DANN_A2C.log 18 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DANN/DANN.yaml --data_dir $data_dir --src_domain Art --tgt_domain Real_World | tee DANN_A2R.log 19 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DANN/DANN.yaml --data_dir $data_dir --src_domain Art --tgt_domain Product | tee DANN_A2P.log 20 | 21 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DANN/DANN.yaml --data_dir $data_dir --src_domain Clipart --tgt_domain Art | tee DANN_C2A.log 22 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DANN/DANN.yaml --data_dir $data_dir --src_domain Clipart --tgt_domain Real_World | tee DANN_C2R.log 23 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DANN/DANN.yaml --data_dir $data_dir --src_domain Clipart --tgt_domain Product 
| tee DANN_C2P.log 24 | 25 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DANN/DANN.yaml --data_dir $data_dir --src_domain Product --tgt_domain Art | tee DANN_P2A.log 26 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DANN/DANN.yaml --data_dir $data_dir --src_domain Product --tgt_domain Real_World | tee DANN_P2R.log 27 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DANN/DANN.yaml --data_dir $data_dir --src_domain Product --tgt_domain Clipart | tee DANN_P2C.log 28 | 29 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DANN/DANN.yaml --data_dir $data_dir --src_domain Real_World --tgt_domain Art | tee DANN_R2A.log 30 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DANN/DANN.yaml --data_dir $data_dir --src_domain Real_World --tgt_domain Product | tee DANN_R2P.log 31 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DANN/DANN.yaml --data_dir $data_dir --src_domain Real_World --tgt_domain Clipart | tee DANN_R2C.log 32 | 33 | data_dir=./DataSets/image_CLEF 34 | # image_CLEF 35 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DANN/DANN.yaml --data_dir $data_dir --src_domain c --tgt_domain i_tar | tee DANN_C2I.log 36 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DANN/DANN.yaml --data_dir $data_dir --src_domain c --tgt_domain p_tar | tee DANN_C2P.log 37 | 38 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DANN/DANN.yaml --data_dir $data_dir --src_domain i --tgt_domain c_tar | tee DANN_I2C.log 39 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DANN/DANN.yaml --data_dir $data_dir --src_domain i --tgt_domain p_tar | tee DANN_I2P.log 40 | 41 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DANN/DANN.yaml --data_dir $data_dir --src_domain p --tgt_domain c_tar | tee DANN_P2C.log 42 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DANN/DANN.yaml --data_dir $data_dir --src_domain p --tgt_domain i_tar | tee DANN_P2I.log 43 | 44 | 45 | data_dir=./DataSets/VisDA-17 46 | # VisDA-17 47 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DANN/DANN.yaml --data_dir $data_dir --src_domain train --tgt_domain validation | tee DANN_S2R.log -------------------------------------------------------------------------------- /DSAN/DSAN.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | GPU_ID=0 3 | data_dir=./DataSets/office31 4 | # Office31 5 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DSAN/DSAN.yaml --data_dir $data_dir --src_domain dslr --tgt_domain amazon | tee DSAN_D2A.log 6 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DSAN/DSAN.yaml --data_dir $data_dir --src_domain dslr --tgt_domain webcam | tee DSAN_D2W.log 7 | 8 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DSAN/DSAN.yaml --data_dir $data_dir --src_domain amazon --tgt_domain dslr | tee DSAN_A2D.log 9 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DSAN/DSAN.yaml --data_dir $data_dir --src_domain amazon --tgt_domain webcam | tee DSAN_A2W.log 10 | 11 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DSAN/DSAN.yaml --data_dir $data_dir --src_domain webcam --tgt_domain amazon | tee DSAN_W2A.log 12 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DSAN/DSAN.yaml --data_dir $data_dir --src_domain webcam --tgt_domain dslr | tee DSAN_W2D.log 13 | 14 | 15 | data_dir=./DataSets/OfficeHome 16 | # Office-Home 17 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DSAN/DSAN.yaml --data_dir $data_dir --src_domain Art --tgt_domain Clipart | tee DSAN_A2C.log 18 | 
CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DSAN/DSAN.yaml --data_dir $data_dir --src_domain Art --tgt_domain Real_World | tee DSAN_A2R.log 19 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DSAN/DSAN.yaml --data_dir $data_dir --src_domain Art --tgt_domain Product | tee DSAN_A2P.log 20 | 21 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DSAN/DSAN.yaml --data_dir $data_dir --src_domain Clipart --tgt_domain Art | tee DSAN_C2A.log 22 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DSAN/DSAN.yaml --data_dir $data_dir --src_domain Clipart --tgt_domain Real_World | tee DSAN_C2R.log 23 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DSAN/DSAN.yaml --data_dir $data_dir --src_domain Clipart --tgt_domain Product | tee DSAN_C2P.log 24 | 25 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DSAN/DSAN.yaml --data_dir $data_dir --src_domain Product --tgt_domain Art | tee DSAN_P2A.log 26 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DSAN/DSAN.yaml --data_dir $data_dir --src_domain Product --tgt_domain Real_World | tee DSAN_P2R.log 27 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DSAN/DSAN.yaml --data_dir $data_dir --src_domain Product --tgt_domain Clipart | tee DSAN_P2C.log 28 | 29 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DSAN/DSAN.yaml --data_dir $data_dir --src_domain Real_World --tgt_domain Art | tee DSAN_R2A.log 30 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DSAN/DSAN.yaml --data_dir $data_dir --src_domain Real_World --tgt_domain Product | tee DSAN_R2P.log 31 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DSAN/DSAN.yaml --data_dir $data_dir --src_domain Real_World --tgt_domain Clipart | tee DSAN_R2C.log 32 | 33 | data_dir=./DataSets/image_CLEF 34 | # image_CLEF 35 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DSAN/DSAN.yaml --data_dir $data_dir --src_domain c --tgt_domain i_tar | tee DSAN_C2I.log 36 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DSAN/DSAN.yaml --data_dir $data_dir --src_domain c --tgt_domain p_tar | tee DSAN_C2P.log 37 | 38 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DSAN/DSAN.yaml --data_dir $data_dir --src_domain i --tgt_domain c_tar | tee DSAN_I2C.log 39 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DSAN/DSAN.yaml --data_dir $data_dir --src_domain i --tgt_domain p_tar | tee DSAN_I2P.log 40 | 41 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DSAN/DSAN.yaml --data_dir $data_dir --src_domain p --tgt_domain c_tar | tee DSAN_P2C.log 42 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DSAN/DSAN.yaml --data_dir $data_dir --src_domain p --tgt_domain i_tar | tee DSAN_P2I.log 43 | 44 | 45 | data_dir=./DataSets/VisDA-17 46 | # VisDA-17 47 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DSAN/DSAN.yaml --data_dir $data_dir --src_domain train --tgt_domain validation | tee DSAN_S2R.log -------------------------------------------------------------------------------- /loss_funcs/lmmd.py: -------------------------------------------------------------------------------- 1 | from loss_funcs.mmd import MMDLoss 2 | from loss_funcs.adv import LambdaSheduler 3 | import torch 4 | import numpy as np 5 | 6 | class LMMDLoss(MMDLoss, LambdaSheduler): 7 | def __init__(self, num_class, kernel_type='rbf', kernel_mul=2.0, kernel_num=5, fix_sigma=None, 8 | gamma=1.0, max_iter=1000, **kwargs): 9 | ''' 10 | Local MMD 11 | ''' 12 | super(LMMDLoss, self).__init__(kernel_type, kernel_mul, kernel_num, fix_sigma, **kwargs) 13 | super(MMDLoss, self).__init__(gamma, max_iter, 
**kwargs) 14 | self.num_class = num_class 15 | 16 | def forward(self, source, target, source_label, target_logits): 17 | if self.kernel_type == 'linear': 18 | raise NotImplementedError("Linear kernel is not supported yet.") 19 | 20 | elif self.kernel_type == 'rbf': 21 | batch_size = source.size()[0] 22 | weight_ss, weight_tt, weight_st = self.cal_weight(source_label, target_logits) 23 | weight_ss = torch.from_numpy(weight_ss).cuda() # B, B 24 | weight_tt = torch.from_numpy(weight_tt).cuda() 25 | weight_st = torch.from_numpy(weight_st).cuda() 26 | 27 | kernels = self.guassian_kernel(source, target, 28 | kernel_mul=self.kernel_mul, kernel_num=self.kernel_num, fix_sigma=self.fix_sigma) 29 | loss = torch.Tensor([0]).cuda() 30 | if torch.sum(torch.isnan(sum(kernels))): 31 | return loss 32 | SS = kernels[:batch_size, :batch_size] 33 | TT = kernels[batch_size:, batch_size:] 34 | ST = kernels[:batch_size, batch_size:] 35 | 36 | loss += torch.sum( weight_ss * SS + weight_tt * TT - 2 * weight_st * ST ) 37 | # Dynamic weighting 38 | lamb = self.lamb() 39 | self.step() 40 | loss = loss * lamb 41 | return loss 42 | 43 | def cal_weight(self, source_label, target_logits): 44 | batch_size = source_label.size()[0] 45 | source_label = source_label.cpu().data.numpy() 46 | source_label_onehot = np.eye(self.num_class)[source_label] # one hot 47 | 48 | source_label_sum = np.sum(source_label_onehot, axis=0).reshape(1, self.num_class) 49 | source_label_sum[source_label_sum == 0] = 100 50 | source_label_onehot = source_label_onehot / source_label_sum # label ratio 51 | 52 | # Pseudo label 53 | target_label = target_logits.cpu().data.max(1)[1].numpy() 54 | 55 | target_logits = target_logits.cpu().data.numpy() 56 | target_logits_sum = np.sum(target_logits, axis=0).reshape(1, self.num_class) 57 | target_logits_sum[target_logits_sum == 0] = 100 58 | target_logits = target_logits / target_logits_sum 59 | 60 | weight_ss = np.zeros((batch_size, batch_size)) 61 | weight_tt = np.zeros((batch_size, batch_size)) 62 | weight_st = np.zeros((batch_size, batch_size)) 63 | 64 | set_s = set(source_label) 65 | set_t = set(target_label) 66 | count = 0 67 | for i in range(self.num_class): # (B, C) 68 | if i in set_s and i in set_t: 69 | s_tvec = source_label_onehot[:, i].reshape(batch_size, -1) # (B, 1) 70 | t_tvec = target_logits[:, i].reshape(batch_size, -1) # (B, 1) 71 | 72 | ss = np.dot(s_tvec, s_tvec.T) # (B, B) 73 | weight_ss = weight_ss + ss 74 | tt = np.dot(t_tvec, t_tvec.T) 75 | weight_tt = weight_tt + tt 76 | st = np.dot(s_tvec, t_tvec.T) 77 | weight_st = weight_st + st 78 | count += 1 79 | 80 | length = count 81 | if length != 0: 82 | weight_ss = weight_ss / length 83 | weight_tt = weight_tt / length 84 | weight_st = weight_st / length 85 | else: 86 | weight_ss = np.array([0]) 87 | weight_tt = np.array([0]) 88 | weight_st = np.array([0]) 89 | return weight_ss.astype('float32'), weight_tt.astype('float32'), weight_st.astype('float32') 90 | 91 | 92 | -------------------------------------------------------------------------------- /DeepCoral/DeepCoral.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | GPU_ID=0 3 | data_dir=./DataSets/office31 4 | # Office31 5 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DeepCoral/DeepCoral.yaml --data_dir $data_dir --src_domain dslr --tgt_domain amazon | tee DeepCoral_D2A.log 6 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DeepCoral/DeepCoral.yaml --data_dir $data_dir --src_domain dslr --tgt_domain webcam | tee 
DeepCoral_D2W.log 7 | 8 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DeepCoral/DeepCoral.yaml --data_dir $data_dir --src_domain amazon --tgt_domain dslr | tee DeepCoral_A2D.log 9 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DeepCoral/DeepCoral.yaml --data_dir $data_dir --src_domain amazon --tgt_domain webcam | tee DeepCoral_A2W.log 10 | 11 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DeepCoral/DeepCoral.yaml --data_dir $data_dir --src_domain webcam --tgt_domain amazon | tee DeepCoral_W2A.log 12 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DeepCoral/DeepCoral.yaml --data_dir $data_dir --src_domain webcam --tgt_domain dslr | tee DeepCoral_W2D.log 13 | 14 | 15 | data_dir=./DataSets/OfficeHome 16 | # Office-Home 17 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DeepCoral/DeepCoral.yaml --data_dir $data_dir --src_domain Art --tgt_domain Clipart | tee DeepCoral_A2C.log 18 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DeepCoral/DeepCoral.yaml --data_dir $data_dir --src_domain Art --tgt_domain Real_World | tee DeepCoral_A2R.log 19 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DeepCoral/DeepCoral.yaml --data_dir $data_dir --src_domain Art --tgt_domain Product | tee DeepCoral_A2P.log 20 | 21 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DeepCoral/DeepCoral.yaml --data_dir $data_dir --src_domain Clipart --tgt_domain Art | tee DeepCoral_C2A.log 22 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DeepCoral/DeepCoral.yaml --data_dir $data_dir --src_domain Clipart --tgt_domain Real_World | tee DeepCoral_C2R.log 23 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DeepCoral/DeepCoral.yaml --data_dir $data_dir --src_domain Clipart --tgt_domain Product | tee DeepCoral_C2P.log 24 | 25 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DeepCoral/DeepCoral.yaml --data_dir $data_dir --src_domain Product --tgt_domain Art | tee DeepCoral_P2A.log 26 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DeepCoral/DeepCoral.yaml --data_dir $data_dir --src_domain Product --tgt_domain Real_World | tee DeepCoral_P2R.log 27 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DeepCoral/DeepCoral.yaml --data_dir $data_dir --src_domain Product --tgt_domain Clipart | tee DeepCoral_P2C.log 28 | 29 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DeepCoral/DeepCoral.yaml --data_dir $data_dir --src_domain Real_World --tgt_domain Art | tee DeepCoral_R2A.log 30 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DeepCoral/DeepCoral.yaml --data_dir $data_dir --src_domain Real_World --tgt_domain Product | tee DeepCoral_R2P.log 31 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DeepCoral/DeepCoral.yaml --data_dir $data_dir --src_domain Real_World --tgt_domain Clipart | tee DeepCoral_R2C.log 32 | 33 | data_dir=./DataSets/image_CLEF 34 | # image_CLEF 35 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DeepCoral/DeepCoral.yaml --data_dir $data_dir --src_domain c --tgt_domain i_tar | tee DeepCoral_C2I.log 36 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DeepCoral/DeepCoral.yaml --data_dir $data_dir --src_domain c --tgt_domain p_tar | tee DeepCoral_C2P.log 37 | 38 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DeepCoral/DeepCoral.yaml --data_dir $data_dir --src_domain i --tgt_domain c_tar | tee DeepCoral_I2C.log 39 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DeepCoral/DeepCoral.yaml --data_dir $data_dir --src_domain i --tgt_domain p_tar | tee DeepCoral_I2P.log 40 | 41 | 
CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DeepCoral/DeepCoral.yaml --data_dir $data_dir --src_domain p --tgt_domain c_tar | tee DeepCoral_P2C.log 42 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DeepCoral/DeepCoral.yaml --data_dir $data_dir --src_domain p --tgt_domain i_tar | tee DeepCoral_P2I.log 43 | 44 | 45 | data_dir=./DataSets/VisDA-17 46 | # VisDA-17 47 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py --config DeepCoral/DeepCoral.yaml --data_dir $data_dir --src_domain train --tgt_domain validation | tee DeepCoral_S2R.log -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from transfer_losses import TransferLoss 4 | import backbones 5 | from metrics import * 6 | import copy 7 | 8 | 9 | class TransferNet(nn.Module): 10 | def __init__(self, num_class, base_net='resnet50', transfer_loss='mmd', use_bottleneck=True, bottleneck_width=256, metric='arc_margin', metric_s=20, metric_m=0.5, max_iter=1000, **kwargs): 11 | super(TransferNet, self).__init__() 12 | self.num_class = num_class 13 | self.base_network = backbones.get_backbone(base_net) 14 | self.use_bottleneck = use_bottleneck 15 | self.transfer_loss = transfer_loss 16 | self.metric = metric 17 | self.s = metric_s 18 | self.m = metric_m 19 | if self.use_bottleneck: 20 | bottleneck_list = [ 21 | nn.Linear(self.base_network.output_num(), bottleneck_width), 22 | nn.ReLU() 23 | ] 24 | self.bottleneck_layer = nn.Sequential(*bottleneck_list) 25 | feature_dim = bottleneck_width 26 | else: 27 | feature_dim = self.base_network.output_num() 28 | 29 | if self.metric == 'add_margin': 30 | self.classifier_layer = AddMarginProduct(feature_dim, self.num_class, self.s, self.m) 31 | elif self.metric == 'arc_margin': 32 | self.classifier_layer = ArcMarginProduct(feature_dim, self.num_class, self.s, self.m, easy_margin=False) 33 | elif self.metric == 'sphere': 34 | self.classifier_layer = SphereProduct(feature_dim, self.num_class, self.s, self.m) 35 | else: 36 | self.classifier_layer = nn.Linear(feature_dim, self.num_class) 37 | transfer_loss_args = { 38 | "loss_type": self.transfer_loss, 39 | "max_iter": max_iter, 40 | "num_class": num_class 41 | } 42 | self.adapt_loss = TransferLoss(**transfer_loss_args) 43 | self.criterion = torch.nn.CrossEntropyLoss() 44 | 45 | def forward(self, source, target, source_label): 46 | source = self.base_network(source) 47 | target = self.base_network(target) 48 | if self.use_bottleneck: 49 | source = self.bottleneck_layer(source) 50 | target = self.bottleneck_layer(target) 51 | 52 | # source arc classification (L_AM) 53 | source_clf = self.classifier_layer(source,source_label) 54 | clf_loss = self.criterion(source_clf, source_label) 55 | 56 | # source-target transfer (L_D) 57 | kwargs = {} 58 | if self.transfer_loss == "lmmd": 59 | kwargs['source_label'] = source_label 60 | target_clf = self.classifier_layer(target,None) 61 | kwargs['target_logits'] = torch.nn.functional.softmax(target_clf, dim=1) 62 | transfer_loss = self.adapt_loss(source, target, **kwargs) 63 | 64 | 65 | # target center alignment loss (L_CA) 66 | loss_metric = torch.tensor(0.0,requires_grad=True) 67 | target_norm = F.normalize(target) 68 | target_clf = self.classifier_layer(target,None) 69 | target_logits = torch.nn.functional.softmax(target_clf, dim=1) 70 | num = 0 71 | 72 | for i in range(target_logits.size(0)): # iterate over the target batch 73 | if max(target_logits[i]) > 0.9: # threshold T=0.9 74 | index = 
torch.argmax(target_logits[i]) 75 | source_w_norm = F.normalize(self.classifier_layer.weight) 76 | source_w_norm_index = source_w_norm[index] 77 | metric_loss_i = (F.linear(target_norm[i], source_w_norm_index)) 78 | loss_metric = loss_metric + metric_loss_i 79 | num = num + 1 80 | 81 | metric_loss = 1-(loss_metric / (num+torch.tensor(1e-6))) 82 | 83 | # target entropy loss (L_EM) 84 | weight_copy = copy.deepcopy(self.classifier_layer.weight) 85 | weight_copy.requires_grad = False 86 | target_clf_copy = F.linear(target, weight_copy) 87 | target_logits_copy = torch.nn.functional.softmax(target_clf_copy, dim=1) 88 | 89 | entropy_loss = -torch.mean((target_logits_copy * torch.log(target_logits_copy + 1e-6)).sum(dim=1)) 90 | 91 | return clf_loss, transfer_loss, metric_loss, entropy_loss 92 | 93 | def get_parameters(self, initial_lr=1.0): 94 | params = [ 95 | {'params': self.base_network.parameters(), 'lr': 0.1 * initial_lr}, 96 | {'params': self.classifier_layer.parameters(), 'lr': 1.0 * initial_lr}, 97 | ] 98 | if self.use_bottleneck: 99 | params.append( 100 | {'params': self.bottleneck_layer.parameters(), 'lr': 1.0 * initial_lr} 101 | ) 102 | # Loss-dependent 103 | if self.transfer_loss == "adv": 104 | params.append( 105 | {'params': self.adapt_loss.loss_func.domain_classifier.parameters(), 'lr': 1.0 * initial_lr} 106 | ) 107 | return params 108 | 109 | def predict(self, x): 110 | features = self.base_network(x) 111 | x = self.bottleneck_layer(features) if self.use_bottleneck else features # skip the bottleneck when it is disabled 112 | clf = self.classifier_layer(x,None) 113 | return clf -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import division 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.nn import Parameter 7 | import math 8 | 9 | 10 | class ArcMarginProduct(nn.Module): 11 | r"""Implementation of large-margin arc distance: 12 | Args: 13 | in_features: size of each input sample 14 | out_features: size of each output sample 15 | s: norm of input feature 16 | m: margin 17 | 18 | cos(theta + m) 19 | """ 20 | def __init__(self, in_features, out_features, s=20.0, m=0.50, easy_margin=False): 21 | super(ArcMarginProduct, self).__init__() 22 | self.in_features = in_features 23 | self.out_features = out_features 24 | self.s = s 25 | self.m = m 26 | self.weight = Parameter(torch.FloatTensor(out_features, in_features)) # n×d 27 | nn.init.xavier_uniform_(self.weight) 28 | 29 | self.easy_margin = easy_margin 30 | self.cos_m = math.cos(m) 31 | self.sin_m = math.sin(m) 32 | self.th = math.cos(math.pi - m) 33 | self.mm = math.sin(math.pi - m) * m 34 | 35 | def forward(self, input, label): 36 | 37 | # --------------------------- cos(theta) & phi(theta) --------------------------- 38 | cosine = F.linear(F.normalize(input), F.normalize(self.weight)) 39 | 40 | if label is None: 41 | return cosine*self.s 42 | 43 | else: 44 | # compute the corresponding sinθ from cosθ 45 | sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1)) 46 | # expand phi = cos(θ+m) = cosθ*cos(m) - sinθ*sin(m); this computes the new target logit cos(θ_yi + m) for every class, so the one-hot mask below is needed to apply it only to the true class of each input x_i 47 | phi = cosine * self.cos_m - sine * self.sin_m 48 | if self.easy_margin: # relaxed constraint: only apply the margin where cos(θ) > 0 
49 | phi = torch.where(cosine > 0, phi, cosine) 50 | else: 51 | phi = torch.where(cosine > self.th, phi, cosine - self.mm) 52 | # --------------------------- convert label to one-hot --------------------------- 53 | # one_hot = torch.zeros(cosine.size(), requires_grad=True, device='cuda') 54 | one_hot = torch.zeros(cosine.size(), device='cuda') # requires_grad=False 55 | one_hot.scatter_(1, label.view(-1, 1).long(), 1) 56 | # ------------- torch.where(condition, x, y): out_i = x_i if condition_i else y_i ------------- 57 | # - only the true class y_i of input x_i (one_hot=1) takes the new target logit cos(θ_yi + m) 58 | # - all other classes (one_hot=0) keep their original logit cosθ_j 59 | output = (one_hot * phi) + ((1.0 - one_hot) * cosine) # you can use torch.where if your torch.__version__ is 0.4 60 | output *= self.s 61 | 62 | return output 63 | 64 | 65 | class AddMarginProduct(nn.Module): 66 | r"""Implementation of large-margin cosine distance: 67 | Args: 68 | in_features: size of each input sample 69 | out_features: size of each output sample 70 | s: norm of input feature 71 | m: margin 72 | cos(theta) - m 73 | """ 74 | 75 | def __init__(self, in_features, out_features, s=30.0, m=0.40): 76 | super(AddMarginProduct, self).__init__() 77 | self.in_features = in_features 78 | self.out_features = out_features 79 | self.s = s 80 | self.m = m 81 | self.weight = Parameter(torch.FloatTensor(out_features, in_features)) 82 | nn.init.xavier_uniform_(self.weight) 83 | 84 | def forward(self, input, label): 85 | # --------------------------- cos(theta) & phi(theta) --------------------------- 86 | cosine = F.linear(F.normalize(input), F.normalize(self.weight)) 87 | 88 | if label is None: 89 | return cosine*self.s 90 | 91 | else: 92 | phi = cosine - self.m 93 | # --------------------------- convert label to one-hot --------------------------- 94 | one_hot = torch.zeros(cosine.size(), device='cuda') 95 | # one_hot = one_hot.cuda() if cosine.is_cuda else one_hot 96 | one_hot.scatter_(1, label.view(-1, 1).long(), 1) 97 | # ------------- torch.where(condition, x, y): out_i = x_i if condition_i else y_i ------------- 98 | output = (one_hot * phi) + ((1.0 - one_hot) * cosine) # you can use torch.where if your torch.__version__ is 0.4 99 | output *= self.s 100 | 101 | return output 102 | 103 | def __repr__(self): 104 | return self.__class__.__name__ + '(' \ 105 | + 'in_features=' + str(self.in_features) \ 106 | + ', out_features=' + str(self.out_features) \ 107 | + ', s=' + str(self.s) \ 108 | + ', m=' + str(self.m) + ')' 109 | 110 | 111 | class SphereProduct(nn.Module): 112 | r"""Implementation of large-margin cosine distance: 113 | Args: 114 | in_features: size of each input sample 115 | out_features: size of each output sample 116 | m: margin 117 | cos(m*theta) 118 | """ 119 | def __init__(self, in_features, out_features, s=20,m=4): 120 | super(SphereProduct, self).__init__() 121 | self.in_features = in_features 122 | self.out_features = out_features 123 | self.m = m 124 | self.s = s 125 | self.base = 1000.0 126 | self.gamma = 0.12 127 | self.power = 1 128 | self.LambdaMin = 5.0 129 | self.iter = 0 130 | self.weight = Parameter(torch.FloatTensor(out_features, in_features)) 131 | nn.init.xavier_uniform_(self.weight) 132 | 133 | # multiple-angle formulas: cos(m*theta) as Chebyshev polynomials of cos(theta) 134 | self.mlambda = [ 135 | lambda x: x ** 0, 136 | lambda x: x ** 1, 137 | lambda x: 2 * x ** 2 - 1, 138 | lambda x: 4 * x ** 3 - 3 * x, 139 | lambda x: 8 * x ** 4 - 8 * x ** 2 + 1, 140 | lambda x: 16 * x ** 5 - 20 * x ** 3 + 5 * x 141 | ] 142 | 143 | def forward(self, input, label): 144 | # lambda = 
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import configargparse
import data_loader
import os
import torch
import models
import utils
from utils import str2bool
import numpy as np
import random


def get_parser():
    """Get default arguments."""
    parser = configargparse.ArgumentParser(
        description="Transfer learning config parser",
        config_file_parser_class=configargparse.YAMLConfigFileParser,
        formatter_class=configargparse.ArgumentDefaultsHelpFormatter,
    )
    # general configuration
    parser.add_argument("--config", is_config_file=True, help="config file path")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument('--num_workers', type=int, default=0)

    # network related
    parser.add_argument('--backbone', type=str, default='resnet50')
    parser.add_argument('--use_bottleneck', type=str2bool, default=True)

    # data loading related
    parser.add_argument('--data_dir', type=str, required=True)
    parser.add_argument('--src_domain', type=str, required=True)
    parser.add_argument('--tgt_domain', type=str, required=True)

    # training related
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--n_epoch', type=int, default=100)
    parser.add_argument('--early_stop', type=int, default=0, help="Early stopping")
    parser.add_argument('--epoch_based_training', type=str2bool, default=False, help="Epoch-based training / iteration-based training")
    parser.add_argument("--n_iter_per_epoch", type=int, default=20, help="Used in iteration-based training")

    # optimizer related
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--weight_decay', type=float, default=5e-4)

    # learning rate scheduler related
    parser.add_argument('--lr_gamma', type=float, default=0.0003)
    parser.add_argument('--lr_decay', type=float, default=0.75)
    parser.add_argument('--lr_scheduler', type=str2bool, default=True)

    # transfer related
    parser.add_argument('--transfer_loss_weight', type=float, default=1)
    parser.add_argument('--transfer_loss', type=str, default='mmd')

    # metric related
    parser.add_argument('--metric', type=str, default='arc_margin')
    parser.add_argument('--metric_s', type=float, default=20)
    parser.add_argument('--metric_m', type=float, default=0.5)
    parser.add_argument('--metric_loss_weight', type=float, default=1)

    # entropy related
    parser.add_argument('--entropy_loss_weight', type=float, default=1)
    return parser
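

# ------------------------------------------------------------------------------
# Invocation sketch (the dataset path and domain names are assumptions for
# illustration; any of the per-method YAML files in this repo can be passed
# via --config, and the three data arguments are required):
#
#   python main.py --config DAN/DAN.yaml \
#       --data_dir DataSets/office31 --src_domain amazon --tgt_domain webcam
#
# configargparse resolves conflicts as: command line > config file > defaults.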


def set_random_seed(seed=0):
    # seed every RNG source used in training
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def load_data(args):
    """Build data loaders for the src_domain and tgt_domain folders."""
    folder_src = os.path.join(args.data_dir, args.src_domain)
    folder_tgt = os.path.join(args.data_dir, args.tgt_domain)
    source_loader, n_class = data_loader.load_data(
        folder_src, args.batch_size, infinite_data_loader=not args.epoch_based_training, train=True, num_workers=args.num_workers)
    target_train_loader, _ = data_loader.load_data(
        folder_tgt, args.batch_size, infinite_data_loader=not args.epoch_based_training, train=True, num_workers=args.num_workers)
    target_test_loader, _ = data_loader.load_data(
        folder_tgt, args.batch_size, infinite_data_loader=False, train=False, num_workers=args.num_workers)
    return source_loader, target_train_loader, target_test_loader, n_class


def get_model(args):
    model = models.TransferNet(
        args.n_class, transfer_loss=args.transfer_loss, base_net=args.backbone, max_iter=args.max_iter, use_bottleneck=args.use_bottleneck, metric=args.metric).to(args.device)
    return model


def get_optimizer(model, args):
    # when the scheduler is enabled, per-group base lrs are relative factors
    # and the absolute rate is injected by the scheduler lambda below
    initial_lr = args.lr if not args.lr_scheduler else 1.0
    params = model.get_parameters(initial_lr=initial_lr)
    optimizer = torch.optim.SGD(params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=False)
    return optimizer


def get_scheduler(optimizer, args):
    # annealing schedule from CDAN, https://arxiv.org/abs/1705.10667
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x) / args.max_iter) ** (-args.lr_decay))
    return scheduler


def test(model, target_test_loader, args):
    model.eval()
    test_loss = utils.AverageMeter()
    correct = 0
    criterion = torch.nn.CrossEntropyLoss()
    len_target_dataset = len(target_test_loader.dataset)
    with torch.no_grad():
        for data, target in target_test_loader:
            data, target = data.to(args.device), target.to(args.device)
            s_output = model.predict(data)
            loss = criterion(s_output, target)
            test_loss.update(loss.item())
            pred = torch.max(s_output, 1)[1]
            correct += torch.sum(pred == target).item()
    acc = 100. * correct / len_target_dataset
    return acc, test_loss.avg
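

# ------------------------------------------------------------------------------
# Effective learning rate sketch: with the scheduler enabled, the rate for a
# parameter group at iteration x is
#
#   group_factor * lr * (1 + lr_gamma * x / max_iter) ** (-lr_decay)
#
# where group_factor is 0.1 for the backbone and 1.0 elsewhere (see
# TransferNet.get_parameters). Illustrative arithmetic (values assumed):
#
#   lr, lr_gamma, lr_decay, max_iter = 0.01, 0.0003, 0.75, 2000
#   for x in (0, 1000, 2000):
#       print(lr * (1 + lr_gamma * x / max_iter) ** (-lr_decay))
#   # 0.010000, 0.009999, 0.009998 -- a slow polynomial decay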


def train(source_loader, target_train_loader, target_test_loader, model, optimizer, lr_scheduler, args):
    len_source_loader = len(source_loader)
    len_target_loader = len(target_train_loader)
    n_batch = min(len_source_loader, len_target_loader)
    if n_batch == 0:
        # infinite data loaders report zero length; fall back to a fixed iteration count
        n_batch = args.n_iter_per_epoch

    iter_source, iter_target = iter(source_loader), iter(target_train_loader)

    best_acc = 0
    stop = 0
    log = []

    for e in range(1, args.n_epoch + 1):
        model.train()  # switch back to train mode at the start of every epoch
        train_loss_clf = utils.AverageMeter()
        train_loss_transfer = utils.AverageMeter()
        train_loss_metric = utils.AverageMeter()
        train_loss_entropy = utils.AverageMeter()
        train_loss_total = utils.AverageMeter()

        if max(len_target_loader, len_source_loader) != 0:
            iter_source, iter_target = iter(source_loader), iter(target_train_loader)

        for i in range(n_batch):
            data_source, label_source = next(iter_source)
            data_target, _ = next(iter_target)
            data_source, label_source = data_source.to(args.device), label_source.to(args.device)
            data_target = data_target.to(args.device)

            clf_loss, transfer_loss, metric_loss, entropy_loss = model(data_source, data_target, label_source)

            loss = clf_loss + args.transfer_loss_weight * transfer_loss + args.metric_loss_weight * metric_loss + args.entropy_loss_weight * entropy_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if lr_scheduler:
                lr_scheduler.step()

            train_loss_clf.update(clf_loss.item())
            train_loss_transfer.update(transfer_loss.item())
            train_loss_metric.update(metric_loss.item())
            train_loss_entropy.update(entropy_loss.item())
            train_loss_total.update(loss.item())

        log.append([train_loss_clf.avg, train_loss_transfer.avg, train_loss_metric.avg, train_loss_entropy.avg, train_loss_total.avg])

        info = 'Epoch: [{:2d}/{}], cls_loss: {:.4f}, transfer_loss: {:.4f}, metric_loss: {:.4f}, entropy_loss: {:.4f}, total_loss: {:.4f}'.format(
            e, args.n_epoch, train_loss_clf.avg, train_loss_transfer.avg, train_loss_metric.avg, train_loss_entropy.avg, train_loss_total.avg)

        # Test
        stop += 1
        test_acc, test_loss = test(model, target_test_loader, args)
        info += ', test_loss: {:.4f}, test_acc: {:.4f}'.format(test_loss, test_acc)
        np_log = np.array(log, dtype=float)
        np.savetxt('train_log.csv', np_log, delimiter=',', fmt='%.6f')
        if best_acc < test_acc:
            best_acc = test_acc
            stop = 0
        if args.early_stop > 0 and stop >= args.early_stop:
            print(info)
            break
        print(info)
    print('Transfer result: {:.4f}'.format(best_acc))
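

# ------------------------------------------------------------------------------
# The per-batch objective assembled in train(), in the notation of the YAML
# configs:
#
#   L = L_cls + transfer_loss_weight * L_transfer
#             + metric_loss_weight   * L_metric
#             + entropy_loss_weight  * L_EM
#
# where L_transfer is mmd / lmmd / coral / adv depending on --transfer_loss,
# and L_metric and L_EM are the cosine-metric and target-entropy terms
# returned by TransferNet.forward.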


def main():
    parser = get_parser()
    args = parser.parse_args()
    setattr(args, "device", torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
    print(args)
    set_random_seed(args.seed)
    source_loader, target_train_loader, target_test_loader, n_class = load_data(args)
    setattr(args, "n_class", n_class)
    if args.epoch_based_training:
        setattr(args, "max_iter", args.n_epoch * min(len(source_loader), len(target_train_loader)))
    else:
        setattr(args, "max_iter", args.n_epoch * args.n_iter_per_epoch)
    model = get_model(args)
    optimizer = get_optimizer(model, args)

    if args.lr_scheduler:
        scheduler = get_scheduler(optimizer, args)
    else:
        scheduler = None
    train(source_loader, target_train_loader, target_test_loader, model, optimizer, scheduler, args)


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------