├── data
│   ├── __init__.py
│   ├── .gitignore
│   └── SynDig.py
├── models
│   ├── __init__.py
│   └── models.py
├── train
│   ├── __init__.py
│   ├── params.py
│   ├── test.py
│   └── train.py
├── util
│   ├── __init__.py
│   ├── preprocess.py
│   └── utils.py
├── requirements.txt
├── LICENSE
├── README.md
└── main.py

--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/train/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/util/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/data/.gitignore:
--------------------------------------------------------------------------------
# Created by .ignore support plugin (hsz.mobi)
/MNIST
/MNIST_M
/SVHN
/SynthDigits

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
scipy>=1.1.0
numpy>=1.14.3
matplotlib>=2.1.0
torch>=0.4.1
Pillow>=5.1.0
scikit_learn>=0.19.2
torchvision>=0.2.1

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2018 CuthbertCai

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/train/params.py:
--------------------------------------------------------------------------------
from models import models

# utility params
fig_mode = None
embed_plot_epoch = 10

# model params
use_gpu = True
dataset_mean = (0.5, 0.5, 0.5)
dataset_std = (0.5, 0.5, 0.5)

batch_size = 512
epochs = 1000
gamma = 10
theta = 1

# path params
data_root = './data'

mnist_path = data_root + '/MNIST'
mnistm_path = data_root + '/MNIST_M'
svhn_path = data_root + '/SVHN'
syndig_path = data_root + '/SynthDigits'

save_dir = './experiment'


# specific dataset params
extractor_dict = {'MNIST_MNIST_M': models.Extractor(),
                  'SVHN_MNIST': models.SVHN_Extractor(),
                  'SynDig_SVHN': models.SVHN_Extractor()}

class_dict = {'MNIST_MNIST_M': models.Class_classifier(),
              'SVHN_MNIST': models.SVHN_Class_classifier(),
              'SynDig_SVHN': models.SVHN_Class_classifier()}

domain_dict = {'MNIST_MNIST_M': models.Domain_classifier(),
               'SVHN_MNIST': models.SVHN_Domain_classifier(),
               'SynDig_SVHN': models.SVHN_Domain_classifier()}

--------------------------------------------------------------------------------
/util/preprocess.py:
--------------------------------------------------------------------------------
import os, shutil

data_dir = './data/MNIST_M'
train_labels = './data/MNIST_M/mnist_m_train_labels.txt'
test_labels = './data/MNIST_M/mnist_m_test_labels.txt'
train_images = './data/MNIST_M/mnist_m_train'
test_images = './data/MNIST_M/mnist_m_test'

def mkdirs(path):
    # create train/ and test/ with one sub-directory per digit class
    train_dir = path + '/' + 'train'
    test_dir = path + '/' + 'test'
    if not os.path.exists(train_dir):
        os.mkdir(train_dir)
    if not os.path.exists(test_dir):
        os.mkdir(test_dir)
    for i in range(0, 10):
        if not os.path.exists(train_dir + '/' + str(i)):
            os.mkdir(train_dir + '/' + str(i))
        if not os.path.exists(test_dir + '/' + str(i)):
            os.mkdir(test_dir + '/' + str(i))

def process(labels_path, images_path, data_dir):
    # each line of the labels file is "<image name> <label>"
    with open(labels_path) as f:
        for line in f.readlines():
            img = images_path + '/' + line.split()[0]
            dst = data_dir + '/' + line.split()[1]
            shutil.move(img, dst)

mkdirs(data_dir)
process(train_labels, train_images, data_dir + '/train')
process(test_labels, test_images, data_dir + '/test')
# the original image folders are directories, so os.remove() would fail here
shutil.rmtree(train_images)
shutil.rmtree(test_images)

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## PyTorch_DANN
> This is an implementation of [Domain-Adversarial Training of Neural Networks][1]
> in PyTorch. The paper introduces a simple and effective method for accomplishing
> domain adaptation with SGD and a GRL (Gradient Reversal Layer). According to the
> paper, a domain classifier is used to decrease the H-divergence between the
> source domain distribution and the target domain distribution. For a TensorFlow
> version, see [tf-dann][2].

### Requirements
> python 3.6.2
> `pip install -r requirements.txt`

### Data
> In this work, the MNIST and MNIST_M datasets are used in the experiments. The
> MNIST dataset can be downloaded with `torchvision.datasets`. The MNIST_M dataset
> can be downloaded from [Yaroslav Ganin's homepage][3]. Extract the archive into
> your data directory and run `preprocess.py` to make the directory usable with
> `torchvision.datasets.ImageFolder`:
```
python preprocess.py
```

### Experiments
> You can run `main.py` to reproduce the MNIST experiments of the paper with a
> similar model and the same parameters. The paper's results and this work's
> results are as follows:

|Method | Target Acc (paper) | Target Acc (this work)|
|:----------:|:-----------------:|:---------------------:|
|Source Only| 0.5225 | 0.5189|
|DANN | 0.7666 | 0.7600|

> An experiment on SVHN->MNIST has been added to this project, but some bugs are
> not yet fixed: the accuracies on the source and target domains are not good at
> the same time.

> An experiment on SynDig->SVHN has been added.



[1]:https://arxiv.org/pdf/1505.07818.pdf
[2]:https://github.com/pumpikano/tf-dann
[3]:http://yaroslav.ganin.net/

--------------------------------------------------------------------------------
/data/SynDig.py:
--------------------------------------------------------------------------------
from __future__ import print_function
import torch.utils.data as data
from PIL import Image
import zipfile
import os
import urllib.request
import os.path
import numpy as np


class SynDig(data.Dataset):

    url = 'https://doc-08-a8-docs.googleusercontent.com/docs/securesc/4gco78h4r5v7n2eq50hcumr89oar2vtn/' \
          'j64h6ekj56csgpnthf6revr2h1sogunh/1513591200000/02005382952228186512/07954859324473388693/' \
          '0B9Z4d7lAwbnTSVR1dEFSRUFxOUU?e=download&nonce=i254fkf8136em&user=' \
          '07954859324473388693&hash=cbcagg6svrku8ot6c9e27m3saorf50m1'
    zipname = 'SynDigits.zip'
    split_list = {
        'train': ["synth_train_32x32.mat"],
        'train_small': ["synth_train_32x32_small.mat"],
        'test': ["synth_test_32x32.mat"],
        'test_small': ["synth_test_32x32_small.mat"]
    }

    def __init__(self, root, split='train', transform=None,
                 target_transform=None, download=False):
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        self.split = split

        if self.split not in self.split_list:
            raise ValueError('Wrong split entered! Please use split="train" '
                             'or split="train_small" or split="test" or split="test_small"')
        self.filename = self.split_list[self.split][0]

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found. '
                               'You can use download=True to download it.')


        import scipy.io as sio

        loaded_mat = sio.loadmat(os.path.join(self.root, self.filename))
        self.data = loaded_mat['X']
        self.labels = loaded_mat['y'].astype(np.int64).squeeze()

        # the .mat stores HWCN; reorder to NCHW
        self.data = np.transpose(self.data, (3, 2, 0, 1))

    def __getitem__(self, index):

        img, target = self.data[index], self.labels[index]

        img = Image.fromarray(np.transpose(img, (1, 2, 0)))

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        return len(self.data)

    def _check_exists(self):
        return os.path.exists(os.path.join(self.root, self.filename))


    def download(self):
        """Download dataset."""
        filename = os.path.join(self.root, self.zipname)
        dirname = os.path.dirname(filename)
        if not os.path.isdir(dirname):
            os.makedirs(dirname)
        if os.path.isfile(filename):
            return
        print("Download %s to %s" % (self.url, os.path.abspath(filename)))
        urllib.request.urlretrieve(self.url, filename)
        file = zipfile.ZipFile(filename)
        # extract next to the downloaded archive rather than into the cwd
        file.extractall(self.root)
        file.close()
        print("[DONE]")
        return

--------------------------------------------------------------------------------
/train/test.py:
--------------------------------------------------------------------------------
"""
Test the model with target domain
"""
import torch
from torch.autograd import Variable
import numpy as np

from train import params


def test(feature_extractor, class_classifier, domain_classifier, source_dataloader, target_dataloader):
    """
    Test the performance of the model
    :param feature_extractor: network used to extract feature from target samples
    :param class_classifier: network used to predict labels
    :param domain_classifier: network used to predict domain
    :param source_dataloader: test dataloader of source domain
    :param target_dataloader: test dataloader of target domain
    :return: None
    """
    # setup the network
    feature_extractor.eval()
    class_classifier.eval()
    domain_classifier.eval()
    source_correct = 0.0
    target_correct = 0.0
    domain_correct = 0.0
    tgt_correct = 0.0
    src_correct = 0.0

    for batch_idx, sdata in enumerate(source_dataloader):
        # setup hyperparameters
        p = float(batch_idx) / len(source_dataloader)
        constant = 2. / (1. + np.exp(-10 * p)) - 1.
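        # The GRL coefficient follows the DANN schedule 2 / (1 + exp(-10p)) - 1;
        # at test time it only scales the reversed gradient in the backward pass,
        # so it has no effect on the forward predictions below.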

        input1, label1 = sdata
        if params.use_gpu:
            input1, label1 = Variable(input1.cuda()), Variable(label1.cuda())
            src_labels = Variable(torch.zeros((input1.size()[0])).type(torch.LongTensor).cuda())
        else:
            input1, label1 = Variable(input1), Variable(label1)
            src_labels = Variable(torch.zeros((input1.size()[0])).type(torch.LongTensor))

        output1 = class_classifier(feature_extractor(input1))
        pred1 = output1.data.max(1, keepdim=True)[1]
        source_correct += pred1.eq(label1.data.view_as(pred1)).cpu().sum()

        src_preds = domain_classifier(feature_extractor(input1), constant)
        src_preds = src_preds.data.max(1, keepdim=True)[1]
        src_correct += src_preds.eq(src_labels.data.view_as(src_preds)).cpu().sum()

    for batch_idx, tdata in enumerate(target_dataloader):
        # setup hyperparameters
        p = float(batch_idx) / len(target_dataloader)
        constant = 2. / (1. + np.exp(-10 * p)) - 1

        input2, label2 = tdata
        if params.use_gpu:
            input2, label2 = Variable(input2.cuda()), Variable(label2.cuda())
            tgt_labels = Variable(torch.ones((input2.size()[0])).type(torch.LongTensor).cuda())
        else:
            input2, label2 = Variable(input2), Variable(label2)
            tgt_labels = Variable(torch.ones((input2.size()[0])).type(torch.LongTensor))

        output2 = class_classifier(feature_extractor(input2))
        pred2 = output2.data.max(1, keepdim=True)[1]
        target_correct += pred2.eq(label2.data.view_as(pred2)).cpu().sum()

        tgt_preds = domain_classifier(feature_extractor(input2), constant)
        tgt_preds = tgt_preds.data.max(1, keepdim=True)[1]
        tgt_correct += tgt_preds.eq(tgt_labels.data.view_as(tgt_preds)).cpu().sum()

    domain_correct = tgt_correct + src_correct

    print('\nSource Accuracy: {}/{} ({:.4f}%)\nTarget Accuracy: {}/{} ({:.4f}%)\n'
          'Domain Accuracy: {}/{} ({:.4f}%)\n'.
          format(
              source_correct, len(source_dataloader.dataset), 100. * float(source_correct) / len(source_dataloader.dataset),
              target_correct, len(target_dataloader.dataset), 100. * float(target_correct) / len(target_dataloader.dataset),
              domain_correct, len(source_dataloader.dataset) + len(target_dataloader.dataset),
              100. * float(domain_correct) / (len(source_dataloader.dataset) + len(target_dataloader.dataset))
          ))

--------------------------------------------------------------------------------
/models/models.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.nn.init as init

class GradReverse(torch.autograd.Function):
    """
    Gradient reversal layer: identity in the forward pass, negates and
    scales the gradient by `constant` in the backward pass.
    """
    @staticmethod
    def forward(ctx, x, constant):
        ctx.constant = constant
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        grad_output = grad_output.neg() * ctx.constant
        return grad_output, None

def grad_reverse(x, constant):
    return GradReverse.apply(x, constant)

class Extractor(nn.Module):

    def __init__(self):
        super(Extractor, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=5)
        self.conv2 = nn.Conv2d(32, 48, kernel_size=5)
        # self.conv1 = nn.Conv2d(3, 64, kernel_size= 5)
        # self.bn1 = nn.BatchNorm2d(64)
        # self.conv2 = nn.Conv2d(64, 50, kernel_size= 5)
        # self.bn2 = nn.BatchNorm2d(50)
        self.conv2_drop = nn.Dropout2d()

    def forward(self, input):
        # expand 1-channel MNIST images to 3 channels
        input = input.expand(input.data.shape[0], 3, 28, 28)
        # x = F.relu(F.max_pool2d(self.bn1(self.conv1(input)), 2))
        # x = F.relu(F.max_pool2d(self.conv2_drop(self.bn2(self.conv2(x))), 2))
        # x = x.view(-1, 50 * 4 * 4)
        x = F.relu(F.max_pool2d(self.conv1(input), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 48 * 4 * 4)

        return x

class Class_classifier(nn.Module):

    def __init__(self):
        super(Class_classifier, self).__init__()
        # self.fc1 = nn.Linear(50 * 4 * 4, 100)
        # self.bn1 = nn.BatchNorm1d(100)
        # self.fc2 = nn.Linear(100, 100)
        # self.bn2 = nn.BatchNorm1d(100)
        # self.fc3 = nn.Linear(100, 10)
        self.fc1 = nn.Linear(48 * 4 * 4, 100)
        self.fc2 = nn.Linear(100, 100)
        self.fc3 = nn.Linear(100, 10)

    def forward(self, input):
        # logits = F.relu(self.bn1(self.fc1(input)))
        # logits = self.fc2(F.dropout(logits))
        # logits = F.relu(self.bn2(logits))
        # logits = self.fc3(logits)
        logits = F.relu(self.fc1(input))
        logits = self.fc2(F.dropout(logits))
        logits = F.relu(logits)
        logits = self.fc3(logits)

        return F.log_softmax(logits, 1)

class Domain_classifier(nn.Module):

    def __init__(self):
        super(Domain_classifier, self).__init__()
        # self.fc1 = nn.Linear(50 * 4 * 4, 100)
        # self.bn1 = nn.BatchNorm1d(100)
        # self.fc2 = nn.Linear(100, 2)
        self.fc1 = nn.Linear(48 * 4 * 4, 100)
        self.fc2 = nn.Linear(100, 2)

    def forward(self, input, constant):
        # grad_reverse is the module-level helper above; GradReverse itself
        # has no grad_reverse attribute, so GradReverse.grad_reverse would fail
        input = grad_reverse(input, constant)
        # logits = F.relu(self.bn1(self.fc1(input)))
        # logits = F.log_softmax(self.fc2(logits), 1)
        logits = F.relu(self.fc1(input))
        logits = F.log_softmax(self.fc2(logits), 1)

        return logits



class SVHN_Extractor(nn.Module):

    def __init__(self):
        super(SVHN_Extractor, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=5)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=5)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=5, padding=2)
        self.bn3 = nn.BatchNorm2d(128)
        self.conv3_drop = nn.Dropout2d()
        self.init_params()

    def init_params(self):

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            if isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)

    def forward(self, input):
        input = input.expand(input.data.shape[0], 3, 28, 28)
        x = F.relu(self.bn1(self.conv1(input)))
        x = F.max_pool2d(x, 3, 2)
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool2d(x, 3, 2)
        x = F.relu(self.bn3(self.conv3(x)))
        x = self.conv3_drop(x)

        return x.view(-1, 128 * 3 * 3)

class SVHN_Class_classifier(nn.Module):

    def __init__(self):
        super(SVHN_Class_classifier, self).__init__()
        self.fc1 = nn.Linear(128 * 3 * 3, 3072)
        self.bn1 = nn.BatchNorm1d(3072)
        self.fc2 = nn.Linear(3072, 2048)
        self.bn2 = nn.BatchNorm1d(2048)
        self.fc3 = nn.Linear(2048, 10)

    def forward(self, input):
        logits = F.relu(self.bn1(self.fc1(input)))
        logits = F.dropout(logits)
        logits = F.relu(self.bn2(self.fc2(logits)))
        logits = self.fc3(logits)

        return F.log_softmax(logits, 1)

class SVHN_Domain_classifier(nn.Module):

    def __init__(self):
        super(SVHN_Domain_classifier, self).__init__()
        self.fc1 = nn.Linear(128 * 3 * 3, 1024)
        self.bn1 = nn.BatchNorm1d(1024)
        self.fc2 = nn.Linear(1024, 1024)
        self.bn2 = nn.BatchNorm1d(1024)
        self.fc3 = nn.Linear(1024, 2)

    def forward(self, input, constant):
        input = grad_reverse(input, constant)
        logits = F.relu(self.bn1(self.fc1(input)))
        logits = F.dropout(logits)
        logits = F.relu(self.bn2(self.fc2(logits)))
        logits = F.dropout(logits)
        logits = self.fc3(logits)

        return F.log_softmax(logits, 1)

--------------------------------------------------------------------------------
/train/train.py:
--------------------------------------------------------------------------------
import torch
from torch.autograd import Variable
import numpy as np

from train import params
from util import utils

import torch.optim as optim


def train(training_mode, feature_extractor, class_classifier, domain_classifier, class_criterion, domain_criterion,
          source_dataloader, target_dataloader, optimizer, epoch):
    """
    Execute domain adaptation training for one epoch
    :param training_mode: 'dann', 'source' or 'target'
    :param feature_extractor: network used to extract feature from samples
    :param class_classifier: network used to predict labels
    :param domain_classifier: network used to predict domain
    :param class_criterion: criterion for the class loss
    :param domain_criterion: criterion for the domain loss
    :param source_dataloader: train dataloader of source domain
    :param target_dataloader: train dataloader of target domain
    :param optimizer: optimizer for all the networks
    :return: None
    """

    # setup models
    feature_extractor.train()
    class_classifier.train()
    domain_classifier.train()

    # steps
    start_steps = epoch * len(source_dataloader)
    total_steps = params.epochs * len(source_dataloader)

    for batch_idx, (sdata, tdata) in enumerate(zip(source_dataloader, target_dataloader)):

        if training_mode == 'dann':
            # setup hyperparameters
            p = float(batch_idx + start_steps) / total_steps
            constant = 2. / (1. + np.exp(-params.gamma * p)) - 1

            # prepare the data
            input1, label1 = sdata
            input2, label2 = tdata
            size = min((input1.shape[0], input2.shape[0]))
            input1, label1 = input1[0:size, :, :, :], label1[0:size]
            input2, label2 = input2[0:size, :, :, :], label2[0:size]
            if params.use_gpu:
                input1, label1 = Variable(input1.cuda()), Variable(label1.cuda())
                input2, label2 = Variable(input2.cuda()), Variable(label2.cuda())
            else:
                input1, label1 = Variable(input1), Variable(label1)
                input2, label2 = Variable(input2), Variable(label2)

            # setup optimizer
            optimizer = utils.optimizer_scheduler(optimizer, p)
            optimizer.zero_grad()

            # prepare domain labels
            if params.use_gpu:
                source_labels = Variable(torch.zeros((input1.size()[0])).type(torch.LongTensor).cuda())
                target_labels = Variable(torch.ones((input2.size()[0])).type(torch.LongTensor).cuda())
            else:
                source_labels = Variable(torch.zeros((input1.size()[0])).type(torch.LongTensor))
                target_labels = Variable(torch.ones((input2.size()[0])).type(torch.LongTensor))

            # compute the output of source domain and target domain
            src_feature = feature_extractor(input1)
            tgt_feature = feature_extractor(input2)

            # compute the class loss of src_feature
            class_preds = class_classifier(src_feature)
            class_loss = class_criterion(class_preds, label1)

            # compute the domain loss of src_feature and tgt_feature
            tgt_preds = domain_classifier(tgt_feature, constant)
            src_preds = domain_classifier(src_feature, constant)
            tgt_loss = domain_criterion(tgt_preds, target_labels)
            src_loss = domain_criterion(src_preds, source_labels)
            domain_loss = tgt_loss + src_loss

            loss = class_loss + params.theta * domain_loss
            loss.backward()
            optimizer.step()

            # print loss
            if (batch_idx + 1) % 10 == 0:
                print('[{}/{} ({:.0f}%)]\tLoss: {:.6f}\tClass Loss: {:.6f}\tDomain Loss: {:.6f}'.format(
                    batch_idx * len(input2), len(target_dataloader.dataset),
                    100. * batch_idx / len(target_dataloader), loss.item(), class_loss.item(),
                    domain_loss.item()
                ))


        elif training_mode == 'source':
            # prepare the data
            input1, label1 = sdata
            size = input1.shape[0]
            input1, label1 = input1[0:size, :, :, :], label1[0:size]

            if params.use_gpu:
                input1, label1 = Variable(input1.cuda()), Variable(label1.cuda())
            else:
                input1, label1 = Variable(input1), Variable(label1)

            # setup optimizer
            optimizer = optim.SGD(list(feature_extractor.parameters()) + list(class_classifier.parameters()), lr=0.01, momentum=0.9)
            optimizer.zero_grad()

            # compute the output of source domain
            src_feature = feature_extractor(input1)

            # compute the class loss of src_feature
            class_preds = class_classifier(src_feature)
            class_loss = class_criterion(class_preds, label1)

            class_loss.backward()
            optimizer.step()

            # print loss
            if (batch_idx + 1) % 10 == 0:
                print('[{}/{} ({:.0f}%)]\tClass Loss: {:.6f}'.format(
                    batch_idx * len(input1), len(source_dataloader.dataset),
                    100. * batch_idx / len(source_dataloader), class_loss.item()
                ))

        elif training_mode == 'target':
            # prepare the data
            input2, label2 = tdata
            size = input2.shape[0]
            input2, label2 = input2[0:size, :, :, :], label2[0:size]
            if params.use_gpu:
                input2, label2 = Variable(input2.cuda()), Variable(label2.cuda())
            else:
                input2, label2 = Variable(input2), Variable(label2)

            # setup optimizer
            optimizer = optim.SGD(list(feature_extractor.parameters()) + list(class_classifier.parameters()), lr=0.01,
                                  momentum=0.9)
            optimizer.zero_grad()

            # compute the output of target domain
            tgt_feature = feature_extractor(input2)

            # compute the class loss of tgt_feature
            class_preds = class_classifier(tgt_feature)
            class_loss = class_criterion(class_preds, label2)

            class_loss.backward()
            optimizer.step()

            # print loss
            if (batch_idx + 1) % 10 == 0:
                print('[{}/{} ({:.0f}%)]\tClass Loss: {:.6f}'.format(
                    batch_idx * len(input2), len(target_dataloader.dataset),
                    100. * batch_idx / len(target_dataloader), class_loss.item()
                ))

--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
"""
Main script for models
"""

from __future__ import absolute_import, division, print_function, unicode_literals

import torch.nn as nn
import torch.optim as optim

import numpy as np

from models import models
from train import test, train, params
from util import utils
from sklearn.manifold import TSNE

import argparse, sys, os

import torch
from torch.autograd import Variable

import time



def visualizePerformance(feature_extractor, class_classifier, domain_classifier, src_test_dataloader,
                         tgt_test_dataloader, num_of_samples=None, imgName=None):
    """
    Evaluate the performance of dann and source only by visualization.

    :param feature_extractor: network used to extract feature from target samples
    :param class_classifier: network used to predict labels
    :param domain_classifier: network used to predict domain
    :param src_test_dataloader: test dataloader of source domain
    :param tgt_test_dataloader: test dataloader of target domain
    :param num_of_samples: the number of samples (from source and target respectively) for t-sne
    :param imgName: the name of saving image

    :return:
    """

    # Setup the network
    feature_extractor.eval()
    class_classifier.eval()
    domain_classifier.eval()

    # Randomly select samples from source domain and target domain.
    if num_of_samples is None:
        num_of_samples = params.batch_size
    else:
        assert num_of_samples <= len(src_test_dataloader.dataset), \
            'The number of samples cannot be bigger than the dataset.'

    # Collect source data.
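    # Batches are appended until at least num_of_samples images have been
    # collected; the concatenated tensors are truncated to num_of_samples below.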
    s_images, s_labels, s_tags = [], [], []
    for batch in src_test_dataloader:
        images, labels = batch

        if params.use_gpu:
            s_images.append(images.cuda())
        else:
            s_images.append(images)
        s_labels.append(labels)

        s_tags.append(torch.zeros((labels.size()[0])).type(torch.LongTensor))

        # count the images actually collected so far
        if sum(img.shape[0] for img in s_images) >= num_of_samples:
            break

    s_images, s_labels, s_tags = torch.cat(s_images)[:num_of_samples], \
                                 torch.cat(s_labels)[:num_of_samples], torch.cat(s_tags)[:num_of_samples]


    # Collect target data.
    t_images, t_labels, t_tags = [], [], []
    for batch in tgt_test_dataloader:
        images, labels = batch

        if params.use_gpu:
            t_images.append(images.cuda())
        else:
            t_images.append(images)
        t_labels.append(labels)

        t_tags.append(torch.ones((labels.size()[0])).type(torch.LongTensor))

        if sum(img.shape[0] for img in t_images) >= num_of_samples:
            break

    t_images, t_labels, t_tags = torch.cat(t_images)[:num_of_samples], \
                                 torch.cat(t_labels)[:num_of_samples], torch.cat(t_tags)[:num_of_samples]

    # Compute the embeddings of both domains.
    embedding1 = feature_extractor(s_images)
    embedding2 = feature_extractor(t_images)

    tsne = TSNE(perplexity=30, n_components=2, init='pca', n_iter=3000)

    if params.use_gpu:
        dann_tsne = tsne.fit_transform(np.concatenate((embedding1.cpu().detach().numpy(),
                                                       embedding2.cpu().detach().numpy())))
    else:
        dann_tsne = tsne.fit_transform(np.concatenate((embedding1.detach().numpy(),
                                                       embedding2.detach().numpy())))


    utils.plot_embedding(dann_tsne, np.concatenate((s_labels, t_labels)),
                         np.concatenate((s_tags, t_tags)), 'Domain Adaptation', imgName)




def main(args):

    # Set global parameters.
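    # Command-line arguments override the defaults defined in train/params.py.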
    params.fig_mode = args.fig_mode
    params.epochs = args.max_epoch
    params.training_mode = args.training_mode
    params.source_domain = args.source_domain
    params.target_domain = args.target_domain
    if args.embed_plot_epoch is not None:
        params.embed_plot_epoch = args.embed_plot_epoch
    params.lr = args.lr


    if args.save_dir is not None:
        params.save_dir = args.save_dir
    else:
        print('Figures will be saved in ./experiment folder.')

    # prepare the source data and target data

    src_train_dataloader = utils.get_train_loader(params.source_domain)
    src_test_dataloader = utils.get_test_loader(params.source_domain)
    tgt_train_dataloader = utils.get_train_loader(params.target_domain)
    tgt_test_dataloader = utils.get_test_loader(params.target_domain)

    if params.fig_mode is not None:
        print('Images from training on source domain:')

        utils.displayImages(src_train_dataloader, imgName='source')

        print('Images from test on target domain:')
        utils.displayImages(tgt_test_dataloader, imgName='target')

    # init models
    model_index = params.source_domain + '_' + params.target_domain
    feature_extractor = params.extractor_dict[model_index]
    class_classifier = params.class_dict[model_index]
    domain_classifier = params.domain_dict[model_index]

    if params.use_gpu:
        feature_extractor.cuda()
        class_classifier.cuda()
        domain_classifier.cuda()

    # init criterions
    class_criterion = nn.NLLLoss()
    domain_criterion = nn.NLLLoss()

    # init optimizer
    optimizer = optim.SGD([{'params': feature_extractor.parameters()},
                           {'params': class_classifier.parameters()},
                           {'params': domain_classifier.parameters()}], lr=params.lr, momentum=0.9)

    for epoch in range(params.epochs):
        print('Epoch: {}'.format(epoch))
        train.train(args.training_mode, feature_extractor, class_classifier, domain_classifier, class_criterion, domain_criterion,
                    src_train_dataloader, tgt_train_dataloader, optimizer, epoch)
        test.test(feature_extractor, class_classifier, domain_classifier, src_test_dataloader, tgt_test_dataloader)


        # Plot embeddings periodically.
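        # embed_plot_epoch defaults to 10 in train/params.py unless
        # --embed_plot_epoch is given; plotting also needs --fig_mode set.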
        if epoch % params.embed_plot_epoch == 0 and params.fig_mode is not None:
            visualizePerformance(feature_extractor, class_classifier, domain_classifier, src_test_dataloader,
                                 tgt_test_dataloader, imgName='embedding_' + str(epoch))



def parse_arguments(argv):
    """Command line parse."""
    parser = argparse.ArgumentParser()

    parser.add_argument('--source_domain', type=str, default='MNIST', help='Choose source domain.')

    parser.add_argument('--target_domain', type=str, default='MNIST_M', help='Choose target domain.')

    parser.add_argument('--fig_mode', type=str, default=None, help='Plot experiment figures.')

    parser.add_argument('--save_dir', type=str, default=None, help='Path to save plotted images.')

    parser.add_argument('--training_mode', type=str, default='dann', help='Choose a mode to train the model.')

    parser.add_argument('--max_epoch', type=int, default=100, help='The max number of epochs.')

    parser.add_argument('--embed_plot_epoch', type=int, default=None,
                        help='Epoch interval of plotting embeddings (defaults to the value in train/params.py).')

    parser.add_argument('--lr', type=float, default=0.01, help='Learning rate.')

    return parser.parse_args()



if __name__ == '__main__':
    main(parse_arguments(sys.argv[1:]))

--------------------------------------------------------------------------------
/util/utils.py:
--------------------------------------------------------------------------------
import torchvision
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from train import params
from sklearn.manifold import TSNE


import matplotlib.pyplot as plt
plt.switch_backend('agg')

import numpy as np
import os, time
from data import SynDig


def get_train_loader(dataset):
    """
    Get train dataloader of source domain or target domain
    :return: dataloader
    """
    if dataset == 'MNIST':
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=params.dataset_mean, std=params.dataset_std)
        ])

        data = datasets.MNIST(root=params.mnist_path, train=True, transform=transform,
                              download=True)

        dataloader = DataLoader(dataset=data, batch_size=params.batch_size, shuffle=True)


    elif dataset == 'MNIST_M':
        transform = transforms.Compose([
            transforms.RandomCrop((28)),
            transforms.ToTensor(),
            transforms.Normalize(mean=params.dataset_mean, std=params.dataset_std)
        ])

        data = datasets.ImageFolder(root=params.mnistm_path + '/train', transform=transform)

        dataloader = DataLoader(dataset=data, batch_size=params.batch_size, shuffle=True)

    elif dataset == 'SVHN':
        transform = transforms.Compose([
            transforms.RandomCrop((28)),
            transforms.ToTensor(),
            transforms.Normalize(mean=params.dataset_mean, std=params.dataset_std)
        ])

        data1 = datasets.SVHN(root=params.svhn_path, split='train', transform=transform, download=True)
        data2 = datasets.SVHN(root=params.svhn_path, split='extra', transform=transform, download=True)

        data = torch.utils.data.ConcatDataset((data1, data2))

        dataloader = DataLoader(dataset=data, batch_size=params.batch_size, shuffle=True)
    elif dataset == 'SynDig':
        transform = transforms.Compose([
            transforms.RandomCrop((28)),
            transforms.ToTensor(),
            transforms.Normalize(mean=params.dataset_mean, std=params.dataset_std)
        ])

        data = SynDig.SynDig(root=params.syndig_path, split='train', transform=transform, download=False)

        dataloader = DataLoader(dataset=data, batch_size=params.batch_size, shuffle=True)


    else:
        raise Exception('There is no dataset named {}'.format(str(dataset)))

    return dataloader



def get_test_loader(dataset):
    """
    Get test dataloader of source domain or target domain
    :return: dataloader
    """
    if dataset == 'MNIST':
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=params.dataset_mean, std=params.dataset_std)
        ])

        data = datasets.MNIST(root=params.mnist_path, train=False, transform=transform,
                              download=True)

        dataloader = DataLoader(dataset=data, batch_size=1, shuffle=False)
    elif dataset == 'MNIST_M':
        transform = transforms.Compose([
            # transforms.RandomCrop((28)),
            transforms.CenterCrop((28)),
            transforms.ToTensor(),
            transforms.Normalize(mean=params.dataset_mean, std=params.dataset_std)
        ])

        data = datasets.ImageFolder(root=params.mnistm_path + '/test', transform=transform)

        dataloader = DataLoader(dataset=data, batch_size=1, shuffle=False)
    elif dataset == 'SVHN':
        transform = transforms.Compose([
            transforms.CenterCrop((28)),
            transforms.ToTensor(),
            transforms.Normalize(mean=params.dataset_mean, std=params.dataset_std)
        ])

        data = datasets.SVHN(root=params.svhn_path, split='test', transform=transform, download=True)

        dataloader = DataLoader(dataset=data, batch_size=1, shuffle=False)
    elif dataset == 'SynDig':
        transform = transforms.Compose([
            transforms.CenterCrop((28)),
            transforms.ToTensor(),
            transforms.Normalize(mean=params.dataset_mean, std=params.dataset_std)
        ])

        data = SynDig.SynDig(root=params.syndig_path, split='test', transform=transform, download=False)

        dataloader = DataLoader(dataset=data, batch_size=1, shuffle=False)
    else:
        raise Exception('There is no dataset named {}'.format(str(dataset)))

    return dataloader



def optimizer_scheduler(optimizer, p):
    """
    Adjust the learning rate of optimizer
    :param optimizer: optimizer for updating parameters
    :param p: a variable for adjusting learning rate
    :return: optimizer
    """
    # annealing schedule from the DANN paper: mu_p = mu_0 / (1 + alpha * p) ** beta
    # with mu_0 = 0.01, alpha = 10 and beta = 0.75
    for param_group in optimizer.param_groups:
        param_group['lr'] = 0.01 / (1. + 10 * p) ** 0.75

    return optimizer



def displayImages(dataloader, length=8, imgName=None):
    """
    Randomly sample some images and display
    :param dataloader: maybe trainloader or testloader
    :param length: number of images to be displayed
    :param imgName: the name of saving image
    :return:
    """
    if params.fig_mode is None:
        return

    # randomly sample some images.
    dataiter = iter(dataloader)
    images, labels = next(dataiter)

    # process images so they can be displayed.
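    # make_grid tiles the batch into a single image; dividing by 2 and adding
    # 0.5 undoes Normalize(mean=0.5, std=0.5), mapping pixels back to [0, 1].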
    images = images[:length]

    images = torchvision.utils.make_grid(images).numpy()
    images = images / 2 + 0.5
    images = np.transpose(images, (1, 2, 0))


    if params.fig_mode == 'display':

        plt.imshow(images)
        plt.show()

    if params.fig_mode == 'save':
        # Check if folder exists, otherwise create it.
        folder = os.path.abspath(params.save_dir)

        if not os.path.exists(folder):
            os.makedirs(folder)

        if imgName is None:
            imgName = 'displayImages' + str(int(time.time()))


        # Add a default extension if needed; join with the save folder either way.
        if not imgName.endswith(('.jpg', '.png', '.jpeg')):
            imgName += '.jpg'
        imgName = os.path.join(folder, imgName)

        plt.imsave(imgName, images)
        plt.close()

    # print labels
    print(' '.join('%5s' % labels[j].item() for j in range(length)))




def plot_embedding(X, y, d, title=None, imgName=None):
    """
    Plot an embedding X with the class label y colored by the domain d.

    :param X: embedding
    :param y: label
    :param d: domain
    :param title: title on the figure
    :param imgName: the name of saving image

    :return:
    """
    if params.fig_mode is None:
        return

    # normalization
    x_min, x_max = np.min(X, 0), np.max(X, 0)
    X = (X - x_min) / (x_max - x_min)

    # Plot colors numbers
    plt.figure(figsize=(10, 10))
    ax = plt.subplot(111)

    for i in range(X.shape[0]):
        # plot colored number
        plt.text(X[i, 0], X[i, 1], str(y[i]),
                 color=plt.cm.bwr(d[i] / 1.),
                 fontdict={'weight': 'bold', 'size': 9})

    plt.xticks([]), plt.yticks([])

    # If title is not given, we assign training_mode to the title.
    if title is not None:
        plt.title(title)
    else:
        plt.title(params.training_mode)

    if params.fig_mode == 'display':
        # Directly display if no folder provided.
        plt.show()

    if params.fig_mode == 'save':
        # Check if folder exists, otherwise create it.
        folder = os.path.abspath(params.save_dir)

        if not os.path.exists(folder):
            os.makedirs(folder)

        if imgName is None:
            imgName = 'plot_embedding' + str(int(time.time()))

        # Add a default extension if needed; join with the save folder either way.
        if not imgName.endswith(('.jpg', '.png', '.jpeg')):
            imgName += '.jpg'
        imgName = os.path.join(folder, imgName)

        print('Saving ' + imgName + ' ...')
        plt.savefig(imgName)
        plt.close()
--------------------------------------------------------------------------------
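
A quick sanity check for the gradient reversal layer in `models/models.py`: the forward pass is the identity, while the backward pass negates the gradient and scales it by `constant`. A minimal sketch (assuming the repository root is on `PYTHONPATH`):

```
import torch
from models.models import grad_reverse

x = torch.ones(3, requires_grad=True)
y = grad_reverse(x, constant=0.5).sum()
y.backward()
print(y.item())  # 3.0 -- the forward pass is the identity
print(x.grad)    # tensor([-0.5000, -0.5000, -0.5000]) -- negated, scaled gradient
```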