├── README.md ├── data └── .gitignore ├── main.py ├── models.py ├── params.py ├── preprocess.py ├── test.py ├── train.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | ## Pytoch_DAN 2 | > This is a simple implementation of [Learning Transferable Features with Deep Adaptation 3 | > Networks][1] with pytorch. This paper introduced a simple and effective method for 4 | > accomplishing domian adaptation with MMD loss. According to this paper, 5 | > multi-layer features are adapted with MMD loss. In this paper, model is based 6 | > on AlexNet and tested on several datasets, while this work just utilizes 7 | > LeNet and tests on MNIST and MNIST_M datasets. The original implementation 8 | > in caffe is [here][2]. 9 | 10 | 11 | ### Data 12 | > In this work, MNIST and MNIST_M datasets are used in experiments. MNIST dataset 13 | > can be downloaded with `torchvision.datasets`. MINIST_M dataset can be downloa- 14 | > ded at [Yaroslav Ganin's homepage][3]. Then you can extract the file to your data dire- 15 | > ctory and run the `preprocess.py` to make the directory able to be used with 16 | > `torchvision.datasets.ImageFolder`: 17 | ``` 18 | python preprocess.py 19 | ``` 20 | > If you could not download MNIST_M dataset from [Yaroslav Ganin's homepage][3], you cou- 21 | > ld download it from [MEGA Cloud][4]. Once you download it, then you just need to unzip 22 | > the file to `/data` and the `preprocess.py` should not be used. 23 | 24 | ### Experiments 25 | > You can run `main.py` to implements the MNSIT experiments. This work's results 26 | > are as follows: 27 | 28 | |Method | Target Acc(this work)| 29 | |:----------:|:----------------:| 30 | |Source Only| 0.5189| 31 | |DAN | 0.5829| 32 | 33 | 34 | 35 | [1]:https://arxiv.org/pdf/1502.02791.pdf 36 | [2]:https://github.com/thuml/DAN 37 | [3]:http://yaroslav.ganin.net/ 38 | [4]:https://mega.nz/#!FuIQhYKJ!IVxcmZK1ZH2MZA-7gKkgR4FExPuJl7-m89eDRPVKhF4 39 | -------------------------------------------------------------------------------- /data/.gitignore: -------------------------------------------------------------------------------- 1 | /MNIST 2 | /MNIST_M 3 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | from torch.autograd import Variable 6 | import time 7 | from sklearn.manifold import TSNE 8 | import matplotlib.pyplot as plt 9 | import pylab 10 | from tensorboardX import SummaryWriter 11 | import torchvision.utils as vutils 12 | import utils 13 | import models 14 | import params 15 | import train, test 16 | 17 | src_train_dataloader = utils.get_train_loader('MNIST') 18 | src_test_dataloader =utils.get_test_loader('MNIST') 19 | tgt_train_dataloader = utils.get_train_loader('MNIST_M') 20 | tgt_test_dataloader = utils.get_test_loader('MNIST_M') 21 | 22 | common_net = models.Extractor() 23 | src_net = models.Classifier() 24 | tgt_net = models.Classifier() 25 | 26 | src_dataiter = iter(src_train_dataloader) 27 | tgt_dataiter = iter(tgt_train_dataloader) 28 | src_imgs, src_labels = next(src_dataiter) 29 | tgt_imgs, tgt_labels = next(tgt_dataiter) 30 | 31 | src_imgs_show = src_imgs[:4] 32 | tgt_imgs_show = tgt_imgs[:4] 33 | 34 | utils.imshow(vutils.make_grid(src_imgs_show)) 35 | utils.imshow(vutils.make_grid(tgt_imgs_show)) 36 | 37 | train_hist = {} 38 | train_hist['Total_loss'] = [] 39 | train_hist['Class_loss'] = [] 40 | train_hist['MMD_loss'] = [] 41 | 42 | test_hist = {} 43 | test_hist['Source Accuracy'] = [] 44 | test_hist['Target Accuracy'] = [] 45 | 46 | if params.use_gpu: 47 | common_net.cuda() 48 | src_net.cuda() 49 | tgt_net.cuda() 50 | 51 | src_features = common_net(Variable(src_imgs.expand(src_imgs.shape[0], 3, 28, 28).cuda())) 52 | tgt_features = common_net(Variable(tgt_imgs.expand(tgt_imgs.shape[0], 3, 28, 28).cuda())) 53 | src_features = src_features.cpu().data.numpy() 54 | tgt_features = tgt_features.cpu().data.numpy() 55 | src_features = TSNE(n_components= 2).fit_transform(src_features) 56 | tgt_features = TSNE(n_components= 2).fit_transform(tgt_features) 57 | 58 | plt.scatter(src_features[:, 0], src_features[:, 1], color = 'r') 59 | plt.scatter(tgt_features[:, 0], tgt_features[:, 1], color = 'b') 60 | plt.title('Non-adapted') 61 | pylab.show() 62 | 63 | optimizer = optim.SGD([{'params': common_net.parameters()}, 64 | {'params': src_net.parameters()}, 65 | {'params': tgt_net.parameters()}], lr= params.lr, momentum= params.momentum) 66 | 67 | criterion = nn.CrossEntropyLoss() 68 | 69 | for epoch in range(params.epochs): 70 | t0 = time.time() 71 | print('Epoch: {}'.format(epoch)) 72 | train.train(common_net, src_net, tgt_net, optimizer, criterion, 73 | epoch, src_train_dataloader, tgt_train_dataloader, train_hist) 74 | t1 = time.time() - t0 75 | print('Time: {:.4f}s'.format(t1)) 76 | test.test(common_net, src_net, src_test_dataloader, tgt_test_dataloader, epoch, test_hist) 77 | 78 | src_features = common_net(Variable(src_imgs.expand(src_imgs.shape[0], 3, 28, 28).cuda())) 79 | tgt_features = common_net(Variable(tgt_imgs.expand(tgt_imgs.shape[0], 3, 28, 28).cuda())) 80 | src_features = src_features.cpu().data.numpy() 81 | tgt_features = tgt_features.cpu().data.numpy() 82 | src_features = TSNE(n_components= 2).fit_transform(src_features) 83 | tgt_features = TSNE(n_components= 2).fit_transform(tgt_features) 84 | 85 | 86 | utils.visulize_loss(train_hist) 87 | utils.visualize_accuracy(test_hist) 88 | plt.scatter(src_features[:, 0], src_features[:, 1], color = 'r') 89 | plt.scatter(tgt_features[:, 0], tgt_features[:, 1], color = 'b') 90 | plt.title('Adapted') 91 | pylab.show() 92 | 93 | 94 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | class Extractor(nn.Module): 7 | 8 | def __init__(self): 9 | super(Extractor, self).__init__() 10 | self.conv1 = nn.Conv2d(3, 64, kernel_size= 5) 11 | self.bn1 = nn.BatchNorm2d(64) 12 | self.conv2 = nn.Conv2d(64, 50, kernel_size= 5) 13 | self.bn2 = nn.BatchNorm2d(50) 14 | self.conv2_drop = nn.Dropout2d() 15 | self.fc1 = nn.Linear(50 * 4 * 4, 100) 16 | self.bn3 = nn.BatchNorm1d(100) 17 | self.fc2 = nn.Linear(100, 100) 18 | self.bn4 = nn.BatchNorm1d(100) 19 | 20 | def forward(self, input): 21 | x = F.max_pool2d(F.relu((self.bn1(self.conv1(input)))), 2) 22 | x = F.max_pool2d(F.relu((self.conv2_drop(self.bn2(self.conv2(x))))), 2) 23 | x = x.view(-1, 50 * 4 * 4) 24 | x = self.fc1(x) 25 | x = self.bn3(x) 26 | x = self.fc2(x) 27 | x = self.bn4(x) 28 | 29 | return x 30 | 31 | class Classifier(nn.Module): 32 | 33 | def __init__(self): 34 | super(Classifier, self).__init__() 35 | self.fc3 = nn.Linear(100, 10) 36 | 37 | def forward(self, input): 38 | logits = self.fc3(input) 39 | 40 | return logits 41 | -------------------------------------------------------------------------------- /params.py: -------------------------------------------------------------------------------- 1 | batch_size = 512 2 | use_gpu = True 3 | data_root = './data' 4 | dataset_mean = (0.5, 0.5, 0.5) 5 | dataset_std = (0.5, 0.5, 0.5) 6 | mnist_path = data_root + '/MNIST' 7 | mnistm_path = data_root + '/MNIST_M' 8 | epochs = 1 9 | plot_iter = 10 10 | lr = 0.01 11 | momentum = 0.9 12 | 13 | theta1 = 0.5 14 | theta2 = 0.5 -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | import os, shutil 2 | 3 | data_dir = './data/MNIST_M' 4 | train_labels = './data/MNIST_M/mnist_m_train_labels.txt' 5 | test_labels = './data/MNIST_M/mnist_m_test_labels.txt' 6 | train_images = './data/MNIST_M/mnist_m_train' 7 | test_images = './data/MNIST_M/mnist_m_test' 8 | 9 | def mkdirs(path): 10 | train_dir = path + '/' + 'train' 11 | test_dir = path + '/' + 'test' 12 | if not os.path.exists(train_dir): 13 | os.mkdir(train_dir) 14 | if not os.path.exists(test_dir): 15 | os.mkdir(test_dir) 16 | for i in range(0, 10): 17 | if not os.path.exists(train_dir + '/' + str(i)): 18 | os.mkdir(train_dir + '/' + str(i)) 19 | if not os.path.exists(test_dir + '/' + str(i)): 20 | os.mkdir(test_dir + '/' + str(i)) 21 | 22 | def process(labels_path, images_path, data_dir): 23 | with open(labels_path) as f: 24 | for line in f.readlines(): 25 | img = images_path + '/' + line.split()[0] 26 | dir = data_dir + '/' + line.split()[1] 27 | shutil.move(img, dir) 28 | 29 | mkdirs(data_dir) 30 | process(train_labels, train_images, data_dir + '/train') 31 | process(test_labels, test_images, data_dir + '/test') 32 | os.remove(train_images) 33 | os.remove(test_images) -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | import numpy as np 4 | 5 | import params 6 | import utils 7 | 8 | def test(common_net, src_net, source_dataloader, target_dataloader, epoch, test_hist): 9 | 10 | common_net.eval() 11 | src_net.eval() 12 | 13 | source_correct = 0 14 | target_correct = 0 15 | 16 | for batch_idx, sdata in enumerate(source_dataloader): 17 | input1, label1 = sdata 18 | if params.use_gpu: 19 | input1, label1 = Variable(input1.cuda()), Variable(label1.cuda()) 20 | else: 21 | input1, label1 = Variable(input1), Variable(label1) 22 | 23 | input1 = input1.expand(input1.shape[0], 3, 28, 28) 24 | output1 = src_net(common_net(input1)) 25 | pred1 = output1.data.max(1, keepdim = True)[1] 26 | source_correct += pred1.eq(label1.data.view_as(pred1)).cpu().sum() 27 | 28 | for batch_idx, tdata in enumerate(target_dataloader): 29 | input2, label2 = tdata 30 | if params.use_gpu: 31 | input2, label2 = Variable(input2.cuda()), Variable(label2.cuda()) 32 | else: 33 | input2, label2 = Variable(input2), Variable(label2) 34 | 35 | output2 = src_net(common_net(input2)) 36 | pred2 = output2.data.max(1, keepdim=True)[1] 37 | target_correct += pred2.eq(label2.data.view_as(pred2)).cpu().sum() 38 | 39 | source_accuracy = 100. * source_correct / len(source_dataloader.dataset) 40 | target_accuracy = 100. * target_correct / len(target_dataloader.dataset) 41 | 42 | print('\nSource Accuracy: {}/{} ({:.4f}%)\nTarget Accuracy: {}/{} ({:.4f}%)\n'.format( 43 | source_correct, len(source_dataloader.dataset), source_accuracy, 44 | target_correct, len(target_dataloader.dataset), target_accuracy, 45 | )) 46 | test_hist['Source Accuracy'].append(source_accuracy) 47 | test_hist['Target Accuracy'].append(target_accuracy) 48 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import numpy as np 4 | from torch.autograd import Variable 5 | 6 | import utils 7 | import params 8 | 9 | def train(common_net, src_net, tgt_net, optimizer, criterion, epoch, 10 | source_dataloader, target_dataloader, train_hist): 11 | 12 | common_net.train() 13 | src_net.train() 14 | tgt_net.train() 15 | 16 | start_steps = epoch * len(source_dataloader) 17 | total_steps = params.epochs * len(source_dataloader) 18 | 19 | source_iter = iter(source_dataloader) 20 | target_iter = iter(target_dataloader) 21 | 22 | for batch_idx in range(min(len(source_dataloader), len(target_dataloader))): 23 | # get data 24 | sdata = next(source_iter) 25 | tdata = next(target_iter) 26 | 27 | # prepare the data 28 | input1, label1 = sdata 29 | input2, label2 = tdata 30 | if params.use_gpu: 31 | input1, label1 = Variable(input1.cuda()), Variable(label1.cuda()) 32 | input2, label2 = Variable(input2.cuda()), Variable(label2.cuda()) 33 | else: 34 | input1, label1 = Variable(input1), Variable(label1) 35 | input2, label2 = Variable(input2), Variable(label2) 36 | 37 | optimizer.zero_grad() 38 | 39 | input1 = input1.expand(input1.shape[0], 3, 28, 28) 40 | input = torch.cat((input1, input2), 0) 41 | common_feature = common_net(input) 42 | 43 | src_feature, tgt_feature = torch.split(common_feature, int(params.batch_size)) 44 | 45 | src_output = src_net(src_feature) 46 | tgt_output = tgt_net(tgt_feature) 47 | 48 | class_loss = criterion(src_output, label1) 49 | 50 | mmd_loss = utils.mmd_loss(src_feature, tgt_feature) * params.theta1 + \ 51 | utils.mmd_loss(src_output, tgt_output) * params.theta2 52 | 53 | loss = class_loss + mmd_loss 54 | loss.backward() 55 | optimizer.step() 56 | step = epoch * len(target_dataloader) + batch_idx 57 | 58 | 59 | if (batch_idx + 1) % params.plot_iter == 0: 60 | print('[{}/{} ({:.0f}%)]\tLoss: {:.6f}\tClass Loss: {:.6f}\tMMD Loss: {:.6f}'.format( 61 | batch_idx * len(input2), len(target_dataloader.dataset), 62 | 100. * batch_idx / len(target_dataloader), loss.data[0], class_loss.data[0], 63 | mmd_loss.data[0] 64 | )) 65 | train_hist['Total_loss'].append(loss.cpu().data[0]) 66 | train_hist['Class_loss'].append(class_loss.cpu().data[0]) 67 | train_hist['MMD_loss'].append(mmd_loss.cpu().data[0]) 68 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | from torch.utils.data import DataLoader 4 | from torchvision import datasets, transforms 5 | from functools import partial 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | import pylab 9 | import params 10 | 11 | def visulize_loss(train_hist): 12 | x = range(len(train_hist['Total_loss'])) 13 | x = [i * params.plot_iter for i in x] 14 | 15 | total_loss = train_hist['Total_loss'] 16 | class_loss = train_hist['Class_loss'] 17 | mmd_loss = train_hist['MMD_loss'] 18 | 19 | plt.plot(x, total_loss, label = 'total loss') 20 | plt.plot(x, class_loss, label = 'class loss') 21 | plt.plot(x, mmd_loss, label = 'mmd loss') 22 | 23 | plt.xlabel('Step') 24 | plt.ylabel('Loss') 25 | 26 | plt.grid(True) 27 | pylab.show() 28 | 29 | def visualize_accuracy(test_hist): 30 | x = range(len(test_hist['Source Accuracy'])) 31 | 32 | source_accuracy = test_hist['Source Accuracy'] 33 | target_accuracy = test_hist['Target Accuracy'] 34 | 35 | plt.plot(x, source_accuracy, label = 'source accuracy') 36 | plt.plot(x, target_accuracy, label = 'target accuracy') 37 | 38 | plt.xlabel('Epoch') 39 | plt.ylabel('Accuracy') 40 | 41 | plt.grid(True) 42 | pylab.show() 43 | 44 | def imshow(img): 45 | img = img / 2 + 0.5 46 | npimg = img.numpy() 47 | plt.imshow(np.transpose(npimg, (1, 2, 0))) 48 | pylab.show() 49 | 50 | def pairwise_distance(x, y): 51 | 52 | if not len(x.shape) == len(y.shape) == 2: 53 | raise ValueError('Both inputs should be matrices.') 54 | 55 | if x.shape[1] != y.shape[1]: 56 | raise ValueError('The number of features should be the same.') 57 | 58 | x = x.view(x.shape[0], x.shape[1], 1) 59 | y = torch.transpose(y, 0, 1) 60 | output = torch.sum((x - y) ** 2, 1) 61 | output = torch.transpose(output, 0, 1) 62 | 63 | return output 64 | 65 | def gaussian_kernel_matrix(x, y, sigmas): 66 | 67 | sigmas = sigmas.view(sigmas.shape[0], 1) 68 | beta = 1. / (2. * sigmas) 69 | dist = pairwise_distance(x, y).contiguous() 70 | dist_ = dist.view(1, -1) 71 | s = torch.matmul(beta, dist_) 72 | 73 | return torch.sum(torch.exp(-s), 0).view_as(dist) 74 | 75 | def maximum_mean_discrepancy(x, y, kernel= gaussian_kernel_matrix): 76 | 77 | cost = torch.mean(kernel(x, x)) 78 | cost += torch.mean(kernel(y, y)) 79 | cost -= 2 * torch.mean(kernel(x, y)) 80 | 81 | return cost 82 | 83 | def mmd_loss(source_features, target_features): 84 | 85 | sigmas = [ 86 | 1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1, 5, 10, 15, 20, 25, 30, 35, 100, 87 | 1e3, 1e4, 1e5, 1e6 88 | ] 89 | if params.use_gpu: 90 | gaussian_kernel = partial( 91 | gaussian_kernel_matrix, sigmas = Variable(torch.cuda.FloatTensor(sigmas)) 92 | ) 93 | else: 94 | gaussian_kernel = partial( 95 | gaussian_kernel_matrix, sigmas = Variable(torch.FloatTensor(sigmas)) 96 | ) 97 | loss_value = maximum_mean_discrepancy(source_features, target_features, kernel= gaussian_kernel) 98 | loss_value = loss_value 99 | 100 | return loss_value 101 | 102 | def get_train_loader(dataset): 103 | """ 104 | Get train dataloader of source domain or target domain 105 | :return: dataloader 106 | """ 107 | if dataset == 'MNIST': 108 | transform = transforms.Compose([ 109 | transforms.ToTensor(), 110 | transforms.Normalize(mean= params.dataset_mean, std= params.dataset_std) 111 | ]) 112 | 113 | data = datasets.MNIST(root= params.mnist_path, train= True, transform= transform, 114 | download= True) 115 | 116 | dataloader = DataLoader(dataset= data, batch_size= params.batch_size, shuffle= True, drop_last= True) 117 | elif dataset == 'MNIST_M': 118 | transform = transforms.Compose([ 119 | transforms.RandomCrop((28)), 120 | transforms.ToTensor(), 121 | transforms.Normalize(mean= params.dataset_mean, std= params.dataset_std) 122 | ]) 123 | 124 | data = datasets.ImageFolder(root= params.mnistm_path + '/train', transform= transform) 125 | 126 | dataloader = DataLoader(dataset = data, batch_size= params.batch_size, shuffle= True, drop_last= True) 127 | else: 128 | raise Exception('There is no dataset named {}'.format(str(dataset))) 129 | 130 | return dataloader 131 | 132 | def get_test_loader(dataset): 133 | """ 134 | Get test dataloader of source domain or target domain 135 | :return: dataloader 136 | """ 137 | if dataset == 'MNIST': 138 | transform = transforms.Compose([ 139 | transforms.ToTensor(), 140 | transforms.Normalize(mean= params.dataset_mean, std= params.dataset_std) 141 | ]) 142 | 143 | data = datasets.MNIST(root= params.mnist_path, train= False, transform= transform, 144 | download= True) 145 | 146 | dataloader = DataLoader(dataset= data, batch_size= params.batch_size, shuffle= True) 147 | elif dataset == 'MNIST_M': 148 | transform = transforms.Compose([ 149 | transforms.CenterCrop((28)), 150 | transforms.ToTensor(), 151 | transforms.Normalize(mean= params.dataset_mean, std= params.dataset_std) 152 | ]) 153 | 154 | data = datasets.ImageFolder(root= params.mnistm_path + '/test', transform= transform) 155 | 156 | dataloader = DataLoader(dataset = data, batch_size= params.batch_size, shuffle= True) 157 | else: 158 | raise Exception('There is no dataset named {}'.format(str(dataset))) 159 | 160 | return dataloader --------------------------------------------------------------------------------