├── DM-ADA
│   ├── README.md
│   ├── main.py
│   ├── models.py
│   ├── trainer_mixed.py
│   └── utils.py
├── Mixup_RevGrad
│   ├── README.md
│   ├── RevGrad.py
│   ├── RevGrad_mixup.py
│   ├── data_loader.py
│   └── models.py
├── README.md
└── docs
    ├── model.png
    └── visda_results.png

/DM-ADA/README.md:
--------------------------------------------------------------------------------
# Adversarial Domain Adaptation with Domain Mixup

This is the implementation of the proposed DM-ADA method.

## Requirements

* Python 2.7
* PyTorch 0.4.0 / 0.4.1

## Prerequisites

Download the MNIST, SVHN and USPS datasets, and arrange them in the following structure:
```
/Dataset_Root
├── mnist
│   ├── trainset
│   │   └── subfolders for 0 ~ 9
│   └── testset
├── svhn
└── usps
```

## Training

* Train the Source-only baseline (validation on the target domain is run after every epoch):
```
python main.py --dataroot --method sourceonly --source_dataset --target_dataset
```

* Train the DM-ADA model (validation on the target domain is run after every epoch):
```
python main.py --dataroot --method DM-ADA --source_dataset --target_dataset
```

--------------------------------------------------------------------------------
/DM-ADA/main.py:
--------------------------------------------------------------------------------
from __future__ import print_function
from __future__ import absolute_import
from __future__ import division
import argparse
import os, sys
import random
import torch
import torch.backends.cudnn as cudnn
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import numpy as np
import trainer_mixed

def main():

    parser = argparse.ArgumentParser()
    parser.add_argument('--dataroot', required=True, help='root directory that contains the source and target datasets')
    parser.add_argument('--checkpoint', type=str, default=None, help='pretrained model')
    parser.add_argument('--workers', type=int, default=2, help='number of data loading workers')
    parser.add_argument('--batchSize', type=int, default=100, help='input batch size')
    parser.add_argument('--imageSize', type=int, default=32, help='the height / width of the input image to network')
    parser.add_argument('--nz', type=int, default=512, help='size of the latent z vector')
    parser.add_argument('--ngf', type=int, default=64, help='Number of filters to use in the generator network')
    parser.add_argument('--ndf', type=int, default=64, help='Number of filters to use in the discriminator network')
    parser.add_argument('--nepochs', type=int, default=50, help='number of epochs to train for')
    parser.add_argument('--lr', type=float, default=0.0005, help='learning rate, default=0.0005')
    parser.add_argument('--gpu', type=int, default=0, help='GPU to use, -1 for CPU training')
    parser.add_argument('--outf', default='./results', help='folder to output images and model checkpoints')
    parser.add_argument('--method', default='DM-ADA', help='method to train: DM-ADA | sourceonly')
    parser.add_argument('--manualSeed', type=int, default=400, help='manual seed')
    parser.add_argument('--KL_weight', type=float, default=1.0, help='weight for KL divergence')
    parser.add_argument('--adv_weight', type=float, default=0.1, help='weight for adv loss')
    parser.add_argument('--lrd', type=float, default=0.0001, help='learning rate decay: lr = lr / (1 + iter * lrd), default=0.0001')
    parser.add_argument('--gamma', type=float, default=0.3, help='multiplicative factor for target adv. loss')
    parser.add_argument('--delta', type=float, default=0.3, help='multiplicative factor for mix adv. loss')
    parser.add_argument('--source_dataset', default='svhn', help='name of the source dataset')
    parser.add_argument('--target_dataset', default='mnist', help='name of the target dataset')
    parser.add_argument('--alpha', type=float, default=2.0, help='parameter of the Beta(alpha, alpha) distribution for the mixup ratio')
    parser.add_argument('--clip_thr', type=float, default=0.1, help='mixup ratios closer than clip_thr to 0.5 are pushed to 0.5 +/- clip_thr')

    opt = parser.parse_args()
    print(opt)

    # Select the GPU before any CUDA call: CUDA_VISIBLE_DEVICES is ignored once CUDA is initialized
    if opt.gpu >= 0:
        os.environ['CUDA_VISIBLE_DEVICES'] = str(opt.gpu)

    # Creating log directories
    for subdir in ['', 'source_generation', 'target_generation', 'models', 'mix_images', 'mix_generation']:
        try:
            os.makedirs(os.path.join(opt.outf, subdir))
        except OSError:
            pass

    # Setting random seed
    if opt.manualSeed is None:
        opt.manualSeed = random.randint(1, 10000)
    print("Random Seed: ", opt.manualSeed)
    random.seed(opt.manualSeed)
    np.random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)
    if opt.gpu >= 0:
        torch.cuda.manual_seed_all(opt.manualSeed)

    # GPU/CPU flags
    cudnn.benchmark = True
    if torch.cuda.is_available() and opt.gpu == -1:
        print("WARNING: You have a CUDA device, so you should probably run with --gpu [gpu id]")

    # Creating data loaders
    mean = np.array([0.44, 0.44, 0.44])
    std = np.array([0.19, 0.19, 0.19])

    # define the directory for train and validation
    source_train_root = os.path.join(opt.dataroot, opt.source_dataset, 'trainset')
    source_val_root = os.path.join(opt.dataroot, opt.source_dataset, 'testset')
    target_train_root = os.path.join(opt.dataroot, opt.target_dataset, 'trainset')
    target_val_root = os.path.join(opt.dataroot, opt.target_dataset, 'testset')

    # define the preprocess operation
    resize_shape = (opt.imageSize, opt.imageSize)
    transform_source = transforms.Compose(
        [transforms.Resize(resize_shape), transforms.ToTensor(), transforms.Normalize(mean, std)])
    transform_target = transforms.Compose(
        [transforms.Resize(resize_shape), transforms.ToTensor(), transforms.Normalize(mean, std)])

    # define dataloaders
    source_train = dset.ImageFolder(root=source_train_root, transform=transform_source)
    source_val = dset.ImageFolder(root=source_val_root, transform=transform_source)
    target_train = dset.ImageFolder(root=target_train_root, transform=transform_target)
    target_val = dset.ImageFolder(root=target_val_root, transform=transform_target)

    source_trainloader = torch.utils.data.DataLoader(source_train, batch_size=opt.batchSize, shuffle=True,
                                                     num_workers=opt.workers, drop_last=True)
    source_valloader = torch.utils.data.DataLoader(source_val, batch_size=opt.batchSize, shuffle=False,
                                                   num_workers=opt.workers, drop_last=False)
    target_trainloader = torch.utils.data.DataLoader(target_train, batch_size=opt.batchSize, shuffle=True,
                                                     num_workers=opt.workers, drop_last=True)
    target_valloader = torch.utils.data.DataLoader(target_val, batch_size=opt.batchSize, shuffle=False,
                                                   num_workers=opt.workers, drop_last=False)

    nclasses = len(source_train.classes)

    # Training
    if opt.method == 'DM-ADA':
        DM_ADA_trainer = trainer_mixed.DM_ADA(opt, nclasses, mean, std, source_trainloader,
                                              source_valloader, target_trainloader, target_valloader)
        DM_ADA_trainer.train()
    elif opt.method == 'sourceonly':
        sourceonly_trainer = trainer_mixed.Sourceonly(opt, nclasses, source_trainloader,
                                                      target_valloader)
        sourceonly_trainer.train()
    else:
        raise ValueError('method argument should be DM-ADA or sourceonly')

if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/DM-ADA/models.py:
--------------------------------------------------------------------------------
import torch, utils
import torch.nn as nn
from torch.autograd import Variable


"""
Generator network
"""
class _netG(nn.Module):
    def __init__(self, opt, nclasses):
        super(_netG, self).__init__()

        self.ndim = opt.ndf*4
        self.ngf = opt.ngf
        self.nz = opt.nz
        self.gpu = opt.gpu
        self.nclasses = nclasses

        self.main = nn.Sequential(
            nn.ConvTranspose2d(self.ndim + self.nz + nclasses + 1, self.ngf*8, 2, 1, 0, bias=False),
            nn.BatchNorm2d(self.ngf*8),
            nn.ReLU(True),

            nn.ConvTranspose2d(self.ngf*8, self.ngf*4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.ngf*4),
            nn.ReLU(True),

            nn.ConvTranspose2d(self.ngf*4, self.ngf*2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.ngf*2),
            nn.ReLU(True),

            nn.ConvTranspose2d(self.ngf*2, self.ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.ngf),
            nn.ReLU(True),

            nn.ConvTranspose2d(self.ngf, 3, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, input):
        # input is [mean, std, (nclasses + 1)-dim label vector], i.e. 4*ndf + nclasses + 1
        # values per sample; Gaussian noise of size nz is appended before decoding
        batchSize = input.size()[0]
        input = input.view(-1, self.ndim+self.nclasses+1, 1, 1)
        noise = torch.FloatTensor(batchSize, self.nz, 1, 1).normal_(0, 1)
        if self.gpu>=0:
            noise = noise.cuda()
        noisev = Variable(noise)
        output = self.main(torch.cat((input, noisev),1))

        return output

"""
Discriminator network
"""
class _netD(nn.Module):
    def __init__(self, opt, nclasses):
        super(_netD, self).__init__()

        self.ndf = opt.ndf
        self.feature = nn.Sequential(
            nn.Conv2d(3, self.ndf, 3, 1, 1),
            nn.BatchNorm2d(self.ndf),
            nn.LeakyReLU(0.2, inplace=True),
            nn.MaxPool2d(2,2),

            nn.Conv2d(self.ndf, self.ndf*2, 3, 1, 1),
            nn.BatchNorm2d(self.ndf*2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.MaxPool2d(2,2),

            nn.Conv2d(self.ndf*2, self.ndf*4, 3, 1, 1),
            nn.BatchNorm2d(self.ndf*4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.MaxPool2d(2,2),

            nn.Conv2d(self.ndf*4, self.ndf*2, 3, 1, 1),
            nn.BatchNorm2d(self.ndf*2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.MaxPool2d(4,4)
        )

        self.classifier_c = nn.Sequential(nn.Linear(self.ndf*2, nclasses))
        self.classifier_s = nn.Sequential(
            nn.Linear(self.ndf, 1),
            nn.Sigmoid())
        self.classifier_t = nn.Sequential(nn.Linear(self.ndf*2, self.ndf))

    def forward(self, input):
        output = self.feature(input)
        output_c = self.classifier_c(output.view(-1, self.ndf * 2))  # class logits
        output_t = self.classifier_t(output.view(-1, self.ndf * 2))  # embedding used by the triplet loss
        output_s = self.classifier_s(output_t)                       # sigmoid real/fake score
        output_s = output_s.view(-1)
        return output_s, output_c, output_t

"""
Feature extraction network
"""
class _netF(nn.Module):
    def __init__(self, opt):
        super(_netF, self).__init__()

        self.ndf = opt.ndf
        self.nz = opt.nz
        self.gpu = opt.gpu

        self.feature = nn.Sequential(
            nn.Conv2d(3, self.ndf, 5, 1, 0),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(self.ndf, self.ndf, 5, 1, 0),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(self.ndf, self.ndf*2, 5, 1, 0),
            nn.ReLU(inplace=True)
        )

        self.mean = nn.Sequential(nn.Linear(self.ndf*2, self.ndf*2))
        self.std = nn.Sequential(nn.Linear(self.ndf*2, self.ndf*2))

    def forward(self, input):
        batchSize = input.size()[0]
        output = self.feature(input)

        mean_vector = self.mean(output.view(-1, 2*self.ndf))
        # std_vector is treated as a log-variance by the KL term in trainer_mixed.py
        std_vector = self.std(output.view(-1, 2*self.ndf))
        if self.gpu>=0:
            mean_vector = mean_vector.cuda()
            std_vector = std_vector.cuda()

        return output.view(-1, 2*self.ndf), mean_vector, std_vector

"""
Classifier network
"""
class _netC(nn.Module):
    def __init__(self, opt, nclasses):
        super(_netC, self).__init__()
        self.ndf = opt.ndf
        # input: concatenated [mean, std] from _netF, i.e. 4*ndf features
        self.main = nn.Sequential(
            nn.Linear(4*self.ndf, 2*self.ndf),
            nn.ReLU(inplace=True),
            nn.Linear(2*self.ndf, nclasses),
        )
        self.soft = nn.Sequential(nn.Sigmoid())

    def forward(self, input):
        output_logit = self.main(input)
        output = self.soft(output_logit)
        return output_logit, output


"""
Feature extraction network for Sourceonly
"""
class netF(nn.Module):
    def __init__(self, opt):
        super(netF, self).__init__()

        self.ndf = opt.ndf
        self.feature = nn.Sequential(
            nn.Conv2d(3, self.ndf, 5, 1, 0),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(self.ndf, self.ndf, 5, 1, 0),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(self.ndf, self.ndf * 2, 5, 1, 0),
            nn.ReLU(inplace=True)
        )

    def forward(self, input):
        output = self.feature(input)
        return output.view(-1, 2 * self.ndf)


"""
Classifier network for Sourceonly
"""
class netC(nn.Module):
    def __init__(self, opt, nclasses):
        super(netC, self).__init__()
        self.ndf = opt.ndf
        self.main = nn.Sequential(
            nn.Linear(2 * self.ndf, 2 * self.ndf),
            nn.ReLU(inplace=True),
            nn.Linear(2 * self.ndf, nclasses),
        )

    def forward(self, input):
        output = self.main(input)
        return output
--------------------------------------------------------------------------------
/DM-ADA/trainer_mixed.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable, Function
import torch.optim as optim
import torchvision.utils as vutils
import itertools, datetime
import numpy as np
import models
import utils
import random
import os, sys


class DM_ADA(object):

    def __init__(self, opt, nclasses, mean, std, source_trainloader, source_valloader, target_trainloader,
                 target_valloader):

        self.source_trainloader = source_trainloader
        self.source_valloader = source_valloader
        self.target_trainloader = target_trainloader
        self.target_valloader = target_valloader
        self.opt = opt
        self.mean = mean
        self.std = std
        self.best_val = 0

        # Defining networks and optimizers
        self.nclasses = nclasses
        self.netG = models._netG(opt, nclasses)
        self.netD = models._netD(opt, nclasses)
        self.netF = models._netF(opt)
        self.netC = models._netC(opt, nclasses)

        # Weight initialization
        self.netG.apply(utils.weights_init)
        self.netD.apply(utils.weights_init)
        self.netF.apply(utils.weights_init)
        self.netC.apply(utils.weights_init)

        # Defining loss criterions
        self.criterion_c = nn.CrossEntropyLoss()
        self.criterion_s = nn.BCELoss()

        if opt.gpu >= 0:
            self.netD.cuda()
            self.netG.cuda()
            self.netF.cuda()
            self.netC.cuda()
            self.criterion_c.cuda()
            self.criterion_s.cuda()

        # Defining optimizers
        self.optimizerD = optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(0.8, 0.999))
        self.optimizerG = optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(0.8, 0.999))
        self.optimizerF = optim.Adam(self.netF.parameters(), lr=opt.lr, betas=(0.8, 0.999))
        self.optimizerC = optim.Adam(self.netC.parameters(), lr=opt.lr, betas=(0.8, 0.999))

        # Other variables
        self.real_label_val = 1
        self.fake_label_val = 0

    """
    Validation function
    """

    def validate(self, epoch):

        self.netF.eval()
        self.netC.eval()
        total = 0
        correct = 0

        # Testing the model
        for i, datas in enumerate(self.target_valloader):
            inputs, labels = datas
            if self.opt.gpu >= 0:
                inputs, labels = inputs.cuda(), labels.cuda()
            inputv = Variable(inputs)

            embedding, mean, std = self.netF(inputv)
            mean_std = torch.cat((mean, std), 1)
            outC_logit, _ = self.netC(mean_std)
            _, predicted = torch.max(outC_logit.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        val_acc = 100 * float(correct) / total

        # Saving checkpoints
        if val_acc > self.best_val:
            self.best_val = val_acc
            torch.save(self.netF.state_dict(), '%s/models/model_best_netF.pth' % (self.opt.outf))
            torch.save(self.netC.state_dict(), '%s/models/model_best_netC.pth' % (self.opt.outf))
            torch.save(self.netD.state_dict(), '%s/models/model_best_netD.pth' % (self.opt.outf))
            torch.save(self.netG.state_dict(), '%s/models/model_best_netG.pth' % (self.opt.outf))

        # Print the validation information
        print('%s| Epoch: %d, Correct/Total: %d / %d, Val Accuracy: %f %%, Best Accuracy: %f %%\n'
              % (datetime.datetime.now(), epoch, correct, total, val_acc, self.best_val))

    """
    Train function
    """

    def train(self):

        curr_iter = 0

        reallabel = torch.FloatTensor(self.opt.batchSize).fill_(self.real_label_val)
        fakelabel = torch.FloatTensor(self.opt.batchSize).fill_(self.fake_label_val)
        if self.opt.gpu >= 0:
            reallabel = reallabel.cuda()
            fakelabel = fakelabel.cuda()
        reallabelv = Variable(reallabel)
        fakelabelv = Variable(fakelabel)

        for epoch in range(self.opt.nepochs):

            self.netG.train()
            self.netF.train()
            self.netC.train()
            self.netD.train()

            for i, (datas, datat) in enumerate(itertools.izip(self.source_trainloader, self.target_trainloader)):

                ###########################
                # Forming input variables
                ###########################

                src_inputs, src_labels = datas
                tgt_inputs, __ = datat
                # Undo Normalize (every channel shares mean[0] / std[0]) and rescale to [-1, 1], the Tanh range of netG
                src_inputs_unnorm = (((src_inputs * self.std[0]) + self.mean[0]) - 0.5) * 2
                tgt_inputs_unnorm = (((tgt_inputs * self.std[0]) + self.mean[0]) - 0.5) * 2

                # Creating one hot vector
                labels_onehot = np.zeros((self.opt.batchSize, self.nclasses + 1), dtype=np.float32)
                for num in range(self.opt.batchSize):
                    labels_onehot[num, src_labels[num]] = 1
                src_labels_onehot = torch.from_numpy(labels_onehot)

                labels_onehot = np.zeros((self.opt.batchSize, self.nclasses + 1), dtype=np.float32)
                for num in range(self.opt.batchSize):
                    labels_onehot[num, self.nclasses] = 1
                tgt_labels_onehot = torch.from_numpy(labels_onehot)

                # feed variables to gpu
                if self.opt.gpu >= 0:
                    src_inputs, src_labels = src_inputs.cuda(), src_labels.cuda()
                    src_inputs_unnorm = src_inputs_unnorm.cuda()
                    tgt_inputs_unnorm = tgt_inputs_unnorm.cuda()
                    tgt_inputs = tgt_inputs.cuda()
                    src_labels_onehot = src_labels_onehot.cuda()
                    tgt_labels_onehot = tgt_labels_onehot.cuda()

                # Wrapping in variable
                src_inputsv, src_labelsv = Variable(src_inputs), Variable(src_labels)
                src_inputs_unnormv = Variable(src_inputs_unnorm)
                tgt_inputsv = Variable(tgt_inputs)
                tgt_inputs_unnormv = Variable(tgt_inputs_unnorm)
                src_labels_onehotv = Variable(src_labels_onehot)
                tgt_labels_onehotv = Variable(tgt_labels_onehot)

                ###########################
                # Updates
                ###########################

                # Mix source and target domain images
                mix_ratio = np.random.beta(self.opt.alpha, self.opt.alpha)
                mix_ratio = round(mix_ratio, 2)
                # clip the mixup ratio: values within clip_thr of 0.5 are pushed to 0.5 +/- clip_thr
                if (mix_ratio >= 0.5 and mix_ratio < (0.5 + self.opt.clip_thr)):
                    mix_ratio = 0.5 + self.opt.clip_thr
                if (mix_ratio > (0.5 - self.opt.clip_thr) and mix_ratio < 0.5):
                    mix_ratio = 0.5 - self.opt.clip_thr

                # Define labels for mixed images
                mix_label = torch.FloatTensor(self.opt.batchSize).fill_(mix_ratio)
                if self.opt.gpu >= 0:
                    mix_label = mix_label.cuda()
                mix_labelv = Variable(mix_label)

                mix_samples = mix_ratio * src_inputs_unnormv + (1 - mix_ratio) * tgt_inputs_unnormv

                # Define the label for mixed input
                labels_onehot = np.zeros((self.opt.batchSize, self.nclasses + 1), dtype=np.float32)
                for num in range(self.opt.batchSize):
                    labels_onehot[num, src_labels[num]] = mix_ratio
                    labels_onehot[num, self.nclasses] = 1.0 - mix_ratio
                mix_labels_onehot = torch.from_numpy(labels_onehot)

                if self.opt.gpu >= 0:
                    mix_labels_onehot = mix_labels_onehot.cuda()
                mix_labels_onehotv = Variable(mix_labels_onehot)

                # Generating images for both domains (add mixed images)

                src_emb, src_mn, src_sd = self.netF(src_inputsv)
                tgt_emb, tgt_mn, tgt_sd = self.netF(tgt_inputsv)

                # Generate mean and std for mixed samples
                mix_mn = src_mn * mix_ratio + tgt_mn * (1.0 - mix_ratio)
                mix_sd = src_sd * mix_ratio + tgt_sd * (1.0 - mix_ratio)

                src_mn_sd = torch.cat((src_mn, src_sd), 1)
                outC_src_logit, outC_src = self.netC(src_mn_sd)

                src_emb_cat = torch.cat((src_mn, src_sd, src_labels_onehotv), 1)
                src_gen = self.netG(src_emb_cat)

                tgt_emb_cat = torch.cat((tgt_mn, tgt_sd, tgt_labels_onehotv), 1)
                tgt_gen = self.netG(tgt_emb_cat)

                mix_emb_cat = torch.cat((mix_mn, mix_sd, mix_labels_onehotv), 1)
                mix_gen = self.netG(mix_emb_cat)

                # Updating D network

                self.netD.zero_grad()

                src_realoutputD_s, src_realoutputD_c, src_realoutputD_t = self.netD(src_inputs_unnormv)
                errD_src_real_s = self.criterion_s(src_realoutputD_s, reallabelv)
                errD_src_real_c = self.criterion_c(src_realoutputD_c, src_labelsv)

                src_fakeoutputD_s, src_fakeoutputD_c, _ = self.netD(src_gen)
                errD_src_fake_s = self.criterion_s(src_fakeoutputD_s, fakelabelv)

                tgt_realoutputD_s, tgt_realoutputD_c, tgt_realoutputD_t = self.netD(tgt_inputs_unnormv)
                tgt_fakeoutputD_s, tgt_fakeoutputD_c, _ = self.netD(tgt_gen)
                errD_tgt_fake_s = self.criterion_s(tgt_fakeoutputD_s, fakelabelv)

                # Triplet loss: the mixed embedding should lie closer to the dominant domain,
                # with margin |2 * mix_ratio - 1|
                mix_s, _, mix_t = self.netD(mix_samples)
                if (mix_ratio > 0.5):
                    tmp_margin = 2 * mix_ratio - 1.
                    errD_mix_t = F.triplet_margin_loss(mix_t, src_realoutputD_t, tgt_realoutputD_t, margin=tmp_margin)
                else:
                    tmp_margin = 1. - 2 * mix_ratio
                    errD_mix_t = F.triplet_margin_loss(mix_t, tgt_realoutputD_t, src_realoutputD_t, margin=tmp_margin)
                errD_mix_s = self.criterion_s(mix_s, mix_labelv)
                errD_mix = errD_mix_s + errD_mix_t

                mix_gen_s, _, _ = self.netD(mix_gen)
                errD_mix_gen = self.criterion_s(mix_gen_s, fakelabelv)

                errD = errD_src_real_c + errD_src_real_s + errD_src_fake_s + errD_tgt_fake_s + errD_mix + errD_mix_gen
                errD.backward(retain_graph=True)
                self.optimizerD.step()

                # Updating G network

                self.netG.zero_grad()

                src_fakeoutputD_s, src_fakeoutputD_c, _ = self.netD(src_gen)
                errG_src_c = self.criterion_c(src_fakeoutputD_c, src_labelsv)
                errG_src_s = self.criterion_s(src_fakeoutputD_s, reallabelv)

                mix_gen_s, _, _ = self.netD(mix_gen)
                errG_mix_gen_s = self.criterion_s(mix_gen_s, reallabelv)

                errG = errG_src_c + errG_src_s + errG_mix_gen_s
                errG.backward(retain_graph=True)
                self.optimizerG.step()

                # Updating C network

                self.netC.zero_grad()
                errC = self.criterion_c(outC_src_logit, src_labelsv)
                errC.backward(retain_graph=True)
                self.optimizerC.step()

                # Updating F network

                self.netF.zero_grad()
                # KL(N(mn, exp(sd)) || N(0, I)) for both domains, with sd used as a log-variance
                err_KL_src = torch.mean(0.5 * torch.sum(torch.exp(src_sd) + src_mn ** 2 - 1. - src_sd, 1))
                err_KL_tgt = torch.mean(0.5 * torch.sum(torch.exp(tgt_sd) + tgt_mn ** 2 - 1. - tgt_sd, 1))
                err_KL = (err_KL_src + err_KL_tgt) * (self.opt.KL_weight)

                errF_fromC = self.criterion_c(outC_src_logit, src_labelsv)

                src_fakeoutputD_s, src_fakeoutputD_c, _ = self.netD(src_gen)
                errF_src_fromD = self.criterion_c(src_fakeoutputD_c, src_labelsv) * (self.opt.adv_weight)

                tgt_fakeoutputD_s, tgt_fakeoutputD_c, _ = self.netD(tgt_gen)
                errF_tgt_fromD = self.criterion_s(tgt_fakeoutputD_s, reallabelv) * (
                    self.opt.adv_weight * self.opt.gamma)

                mix_gen_s, _, _ = self.netD(mix_gen)
                errF_mix_fromD = self.criterion_s(mix_gen_s, reallabelv) * (self.opt.adv_weight * self.opt.delta)

                errF = err_KL + errF_fromC + errF_src_fromD + errF_tgt_fromD + errF_mix_fromD
                errF.backward()
                self.optimizerF.step()

                curr_iter += 1

                # print training information
                if ((i + 1) % 50 == 0):
                    text_format = ('epoch: {}, iteration: {}, errD: {}, errG: {}, '
                                   'errC: {}, errF: {}')
                    train_text = text_format.format(epoch + 1, i + 1,
                                                    errD.item(), errG.item(), errC.item(), errF.item())
                    print(train_text)

                # Visualization
                if i == 1:
                    vutils.save_image((src_gen.data / 2) + 0.5,
                                      '%s/source_generation/source_gen_%d.png' % (self.opt.outf, epoch))
                    vutils.save_image((tgt_gen.data / 2) + 0.5,
                                      '%s/target_generation/target_gen_%d.png' % (self.opt.outf, epoch))
                    vutils.save_image((mix_gen.data / 2) + 0.5,
                                      '%s/mix_generation/mix_gen_%d.png' % (self.opt.outf, epoch))
                    vutils.save_image((mix_samples.data / 2) + 0.5,
                                      '%s/mix_images/mix_samples_%d.png' % (self.opt.outf, epoch))

                # Learning rate scheduling
                if self.opt.lrd:
                    self.optimizerD = utils.exp_lr_scheduler(self.optimizerD, epoch, self.opt.lr, self.opt.lrd,
                                                             curr_iter)
                    self.optimizerF = utils.exp_lr_scheduler(self.optimizerF, epoch, self.opt.lr, self.opt.lrd,
                                                             curr_iter)
                    self.optimizerC = utils.exp_lr_scheduler(self.optimizerC, epoch, self.opt.lr, self.opt.lrd,
                                                             curr_iter)

            # Validate every epoch
            self.validate(epoch + 1)


class Sourceonly(object):

    def __init__(self, opt, nclasses, source_trainloader, target_valloader):

        self.source_trainloader = source_trainloader
        self.target_valloader = target_valloader
        self.opt = opt
        self.best_val = 0

        # Defining networks and optimizers
        self.nclasses = nclasses
        self.netF = models.netF(opt)
        self.netC = models.netC(opt, nclasses)

        # Weight initialization
        self.netF.apply(utils.weights_init)
        self.netC.apply(utils.weights_init)

        # Defining loss criterions
        self.criterion = nn.CrossEntropyLoss()

        if opt.gpu >= 0:
            self.netF.cuda()
            self.netC.cuda()
            self.criterion.cuda()

        # Defining optimizers
        self.optimizerF = optim.Adam(self.netF.parameters(), lr=opt.lr, betas=(0.8, 0.999))
        self.optimizerC = optim.Adam(self.netC.parameters(), lr=opt.lr, betas=(0.8, 0.999))

    """
    Validation function
    """

    def validate(self, epoch):

        self.netF.eval()
        self.netC.eval()
        total = 0
        correct = 0

        # Testing the model
        for i, datas in enumerate(self.target_valloader):
            inputs, labels = datas
            if self.opt.gpu >= 0:
                inputs, labels = inputs.cuda(), labels.cuda()
            inputv = Variable(inputs)

            outC = self.netC(self.netF(inputv))
            _, predicted = torch.max(outC.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        val_acc = 100 * float(correct) / total

        # Saving checkpoints
        if val_acc > self.best_val:
            self.best_val = val_acc
            torch.save(self.netF.state_dict(), '%s/models/model_best_netF_sourceonly.pth' % (self.opt.outf))
            torch.save(self.netC.state_dict(), '%s/models/model_best_netC_sourceonly.pth' % (self.opt.outf))

        print('%s| Epoch: %d, Val Accuracy: %f %%, Best Accuracy: %f %%\n' % (
            datetime.datetime.now(), epoch, val_acc, self.best_val))

    """
    Train function
    """

    def train(self):

        curr_iter = 0
        for epoch in range(self.opt.nepochs):

            self.netF.train()
            self.netC.train()

            for i, datas in enumerate(self.source_trainloader):

                ###########################
                # Forming input variables
                ###########################

                src_inputs, src_labels = datas
                if self.opt.gpu >= 0:
                    src_inputs, src_labels = src_inputs.cuda(), src_labels.cuda()
                src_inputsv, src_labelsv = Variable(src_inputs), Variable(src_labels)

                ###########################
                # Updates
                ###########################

                self.netC.zero_grad()
                self.netF.zero_grad()
                outC = self.netC(self.netF(src_inputsv))
                loss = self.criterion(outC, src_labelsv)
                loss.backward()
                self.optimizerC.step()
                self.optimizerF.step()

                curr_iter += 1

                # print training information
                if ((i + 1) % 50 == 0):
                    text_format = 'epoch: {}, iteration: {}, errC: {}'
                    train_text = text_format.format(epoch + 1, i + 1, loss.item())
                    print(train_text)

                # Learning rate scheduling
                if self.opt.lrd:
                    self.optimizerF = utils.exp_lr_scheduler(self.optimizerF, epoch, self.opt.lr, self.opt.lrd,
                                                             curr_iter)
                    self.optimizerC = utils.exp_lr_scheduler(self.optimizerC, epoch, self.opt.lr, self.opt.lrd,
                                                             curr_iter)

            # Validate every epoch (1-based, matching the training log)
            self.validate(epoch + 1)

--------------------------------------------------------------------------------
/DM-ADA/utils.py:
--------------------------------------------------------------------------------
import torch
from torch.nn import init

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
    elif classname.find('Linear') != -1:
        m.weight.data.normal_(0.0, 0.1)
        m.bias.data.fill_(0)

def weights_init_xavier(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        init.xavier_normal(m.weight.data, gain=0.02)
    elif classname.find('Linear') != -1:
        init.xavier_normal(m.weight.data, gain=0.02)
    elif classname.find('BatchNorm2d') != -1:
        init.normal(m.weight.data, 1.0, 0.02)
        init.constant(m.bias.data, 0.0)

def lr_scheduler(optimizer, lr):
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return optimizer

def exp_lr_scheduler(optimizer, epoch, init_lr, lrd, nevals):
    """Inverse-time learning rate decay: lr = init_lr / (1 + nevals * lrd).

    Works with any torch optimizer; `epoch` is unused, the decay is driven by the
    iteration count `nevals`.
    """
    lr = init_lr / (1 + nevals*lrd)

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    return optimizer
--------------------------------------------------------------------------------
/Mixup_RevGrad/README.md:
--------------------------------------------------------------------------------
# Adversarial Domain Adaptation with Domain Mixup

This is the implementation of the Domain Mixup strategy on a classical adversarial domain adaptation method, [RevGrad](https://arxiv.org/abs/1409.7495v2).
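
Two scalar quantities drive this combination, and both can be illustrated without a deep-learning framework. The sketch below is a minimal, stdlib-only illustration rather than code from this repository, and every function name in it is hypothetical: `grl_coefficient` reproduces the gradient-reversal weight that `RevGrad.py` computes from the training progress `p` and passes to the model as `alpha`, while `sample_mix_ratio` and `mix_pixels` mirror the Beta sampling, rounding, clipping and pixel-level mixing done in `DM-ADA/trainer_mixed.py`. It assumes that `RevGrad_mixup.py` draws its mixing ratio the same way, which is not confirmed here.

```python
import math
import random


def grl_coefficient(step, total_steps):
    """Gradient-reversal weight as computed in RevGrad.py (hypothetical helper name).

    p = step / total_steps is the training progress in [0, 1]; the weight
    2 / (1 + exp(-10 p)) - 1 starts at 0 and saturates close to 1.
    """
    p = float(step) / total_steps
    return 2.0 / (1.0 + math.exp(-10.0 * p)) - 1.0


def sample_mix_ratio(alpha=2.0, clip_thr=0.1, rng=random):
    """Draw lambda ~ Beta(alpha, alpha), round it to two decimals and push it out
    of the band (0.5 - clip_thr, 0.5 + clip_thr), following DM-ADA's trainer."""
    lam = round(rng.betavariate(alpha, alpha), 2)
    if 0.5 <= lam < 0.5 + clip_thr:
        lam = 0.5 + clip_thr
    elif 0.5 - clip_thr < lam < 0.5:
        lam = 0.5 - clip_thr
    return lam


def mix_pixels(src, tgt, lam):
    """Pixel-level domain mixup of two equally sized, flattened images."""
    return [lam * s + (1.0 - lam) * t for s, t in zip(src, tgt)]


if __name__ == "__main__":
    rng = random.Random(0)
    lam = sample_mix_ratio(alpha=2.0, clip_thr=0.1, rng=rng)
    print("mix ratio:", lam)
    print("GRL weight halfway through training:", grl_coefficient(50, 100))
    print("mixed pixels:", mix_pixels([1.0, 0.0], [0.0, 1.0], 0.75))
```

Note that `alpha` names two different things in this repository: the gradient-reversal weight in `RevGrad.py`, and the Beta-distribution parameter set by `--alpha` in `DM-ADA/main.py`. With a positive `clip_thr`, the clipping also keeps the DM-ADA triplet margin `|2 * mix_ratio - 1|` at least `2 * clip_thr`, so it never collapses to zero.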

## Requirements

* Python 2.7
* PyTorch 0.4.0 / 0.4.1

## Prerequisites

Download the MNIST, SVHN and USPS datasets, and arrange them in the following structure:
```
/Dataset_Root
├── mnist
│   ├── trainset
│   │   └── subfolders for 0 ~ 9
│   └── testset
├── svhn
└── usps
```

## Training

* Train the original RevGrad model (validation on the target domain is run after every epoch):
```
python RevGrad.py --root_path --source --target
```

* Train the RevGrad model with the Domain Mixup strategy (validation on the target domain is run after every epoch):
```
python RevGrad_mixup.py --root_path --source --target
```
--------------------------------------------------------------------------------
/Mixup_RevGrad/RevGrad.py:
--------------------------------------------------------------------------------
from __future__ import print_function
import os, sys
import math
import argparse
import numpy as np

import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils import model_zoo

import data_loader
import models

###################################

parser = argparse.ArgumentParser()
parser.add_argument('--root_path', required = True, help = 'root directory of the datasets')
parser.add_argument('--source', type = str, default = 'usps', help = 'the source domain')
parser.add_argument('--target', type = str, default = 'mnist', help = 'the target domain')
parser.add_argument('--model_dir', type = str, default = './models/', help = 'the path to save models')
parser.add_argument('--batch_size', type = int, default = 100, help = 'the size of mini-batch')
parser.add_argument('--epochs', type = int, default = 100, help = 'the number of epochs')
parser.add_argument('--lr', type = float, default = 0.001, help = 'the initial learning rate')
parser.add_argument('--momentum', type = float, default = 0.9, help = 'the momentum of gradient')
parser.add_argument('--l2_decay', type = float, default = 5e-4, help = 'the l2 decay used in training')
parser.add_argument('--seed', type = int, default = 100, help = 'the manual seed')
parser.add_argument('--log_interval', type = int, default = 50, help = 'number of iterations between log messages')
parser.add_argument('--gpu_id', type = str, default = '0', help = 'the gpu device id')

opt = parser.parse_args()
print(opt)

###################################

# Training settings
os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu_id

cuda = torch.cuda.is_available()
torch.manual_seed(opt.seed)
if cuda:
    torch.cuda.manual_seed(opt.seed)

# Dataloader

kwargs = {'num_workers': 1, 'pin_memory': True} if cuda else {}

source_loader = data_loader.load_training(opt.root_path, opt.source, opt.batch_size, kwargs)
target_train_loader = data_loader.load_training(opt.root_path, opt.target, opt.batch_size, kwargs)
target_test_loader = data_loader.load_testing(opt.root_path, opt.target, opt.batch_size, kwargs)

len_source_dataset = len(source_loader.dataset)
len_target_dataset = len(target_test_loader.dataset)
len_source_loader = len(source_loader)
len_target_loader = len(target_train_loader)
nclasses = len(source_loader.dataset.classes)

###################################

# For every epoch training
def train(epoch, model):

    optimizer = torch.optim.Adam(model.parameters(), lr = opt.lr)
    loss_class = torch.nn.CrossEntropyLoss()
    loss_domain = torch.nn.CrossEntropyLoss()

    data_source_iter = iter(source_loader)
    data_target_iter = iter(target_train_loader)
    # domain labels: 1 for source, 0 for target
    dlabel_src = torch.ones(opt.batch_size).long()
    dlabel_tgt = torch.zeros(opt.batch_size).long()
    if cuda:
        dlabel_src, dlabel_tgt = dlabel_src.cuda(), dlabel_tgt.cuda()
    dlabel_src, dlabel_tgt = Variable(dlabel_src), Variable(dlabel_tgt)
71 | 72 | i = 1 73 | while i <= len_source_loader: 74 | model.train() 75 | 76 | # the parameter for reversing gradients 77 | p = float(i + epoch * len_source_loader) / opt.epochs / len_source_loader 78 | alpha = 2. / (1. + np.exp(-10 * p)) - 1 79 | 80 | # for the source domain batch 81 | source_data, source_label = next(data_source_iter) 82 | if cuda: 83 | source_data, source_label = source_data.cuda(), source_label.cuda() 84 | source_data, source_label = Variable(source_data), Variable(source_label) 85 | 86 | _, clabel_src, dlabel_pred_src = model(source_data, alpha = alpha) 87 | label_loss = loss_class(clabel_src, source_label) 88 | domain_loss_src = loss_domain(dlabel_pred_src, dlabel_src) 89 | 90 | # for the target domain batch 91 | target_data, target_label = next(data_target_iter) 92 | if i % len_target_loader == 0: 93 | data_target_iter = iter(target_train_loader) 94 | if cuda: 95 | target_data, target_label = target_data.cuda(), target_label.cuda() 96 | target_data = Variable(target_data) 97 | 98 | _, clabel_tgt, dlabel_pred_tgt = model(target_data, alpha = alpha) 99 | domain_loss_tgt = loss_domain(dlabel_pred_tgt, dlabel_tgt) 100 | 101 | domain_loss_total = domain_loss_src + domain_loss_tgt 102 | loss_total = label_loss + domain_loss_total 103 | 104 | optimizer.zero_grad() 105 | # label_loss.backward() 106 | loss_total.backward() 107 | optimizer.step() 108 | 109 | if i % opt.log_interval == 0: 110 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tlabel_Loss: {:.6f}\tdomain_Loss: {:.6f}'.format( 111 | epoch, i * len(source_data), len_source_dataset, 112 | 100.
* i / len_source_loader, label_loss.item(), domain_loss_total.item())) 113 | i = i + 1 114 | 115 | 116 | # For every epoch evaluation 117 | def test(model): 118 | model.eval() 119 | test_loss = 0 120 | correct = 0 121 | 122 | for data, target in target_test_loader: 123 | if cuda: 124 | data, target = data.cuda(), target.cuda() 125 | data, target = Variable(data), Variable(target) 126 | 127 | _, s_output, t_output = model(data, alpha = 0) 128 | test_loss += F.nll_loss(F.log_softmax(s_output, dim = 1), target, size_average=False).item() 129 | pred = s_output.max(1)[1] 130 | correct += pred.eq(target.data.view_as(pred)).cpu().sum().item() # accumulate a Python int so the accuracy below uses float division 131 | 132 | test_loss /= len_target_dataset 133 | 134 | print('\n{} set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format( 135 | opt.target, test_loss, correct, len_target_dataset, 136 | 100. * correct / len_target_dataset)) 137 | 138 | return correct 139 | 140 | 141 | if __name__ == '__main__': 142 | 143 | model = models.RevGrad(num_classes = nclasses) 144 | print (model) 145 | 146 | max_correct = 0 147 | if cuda: 148 | model.cuda() 149 | 150 | # start training 151 | for epoch in range(1, opt.epochs + 1): 152 | train(epoch, model) 153 | # test for every epoch 154 | t_correct = test(model) 155 | if t_correct > max_correct: 156 | max_correct = t_correct 157 | if not os.path.exists(opt.model_dir): 158 | os.mkdir(opt.model_dir) 159 | torch.save(model.state_dict(), os.path.join(opt.model_dir, 'best_model.pkl')) 160 | 161 | print('source: {} to target: {} max correct: {} max accuracy: {:.2f}%\n'.format( 162 | opt.source, opt.target, max_correct, 100.
* max_correct / len_target_dataset )) -------------------------------------------------------------------------------- /Mixup_RevGrad/RevGrad_mixup.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os, sys 3 | import math 4 | import argparse 5 | import numpy as np 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | import torch.optim as optim 10 | from torch.autograd import Variable 11 | from torch.utils import model_zoo 12 | 13 | import data_loader 14 | import models 15 | from models import ReverseLayerF 16 | 17 | ################################### 18 | 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('--root_path', required = True, help = 'root to the data') 21 | parser.add_argument('--source', type = str, default = 'usps', help = 'the source domain') 22 | parser.add_argument('--target', type = str, default = 'mnist', help = 'the target domain') 23 | parser.add_argument('--model_dir', type = str, default = './models/', help = 'the path to save models') 24 | parser.add_argument('--batch_size', type = int, default = 100, help = 'the size of mini-batch') 25 | parser.add_argument('--epochs', type = int, default = 100, help = 'the number of epochs') 26 | parser.add_argument('--lr', type = float, default = 0.001, help = 'the initial learning rate') 27 | parser.add_argument('--momentum', type = float, default = 0.9, help = 'the momentum of gradient') 28 | parser.add_argument('--l2_decay', type = float, default = 5e-4, help = 'the l2 decay used in training') 29 | parser.add_argument("--alpha", type = float, default = 2.0, help = 'the parameter of beta distribution') 30 | parser.add_argument("--clip_thr", type = float, default = 0.3, help = 'the threshold for mixup ratio clipping') 31 | parser.add_argument("--mix_weight", type=float, default = 1.0, help = 'the weight of mixup loss') 32 | parser.add_argument('--seed', type = int, default = 100, help = 'the 
manual seed') 33 | parser.add_argument('--log_interval', type = int, default = 50, help = 'the interval of print') 34 | parser.add_argument('--gpu_id', type = str, default = '0', help = 'the gpu device id') 35 | 36 | opt = parser.parse_args() 37 | print (opt) 38 | 39 | ################################### 40 | 41 | # Training settings 42 | os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu_id 43 | 44 | cuda = torch.cuda.is_available() 45 | if cuda: 46 | torch.cuda.manual_seed(opt.seed) 47 | 48 | # Dataloader 49 | 50 | kwargs = {'num_workers': 1, 'pin_memory': True} if cuda else {} 51 | 52 | source_loader = data_loader.load_training(opt.root_path, opt.source, opt.batch_size, kwargs) 53 | target_train_loader = data_loader.load_training(opt.root_path, opt.target, opt.batch_size, kwargs) 54 | target_test_loader = data_loader.load_testing(opt.root_path, opt.target, opt.batch_size, kwargs) 55 | 56 | len_source_dataset = len(source_loader.dataset) 57 | len_target_dataset = len(target_test_loader.dataset) 58 | len_source_loader = len(source_loader) 59 | len_target_loader = len(target_train_loader) 60 | nclasses = len(source_loader.dataset.classes) 61 | 62 | ################################### 63 | 64 | # For every epoch training 65 | def train(epoch, model): 66 | 67 | optimizer = torch.optim.Adam(model.parameters(), lr = opt.lr) 68 | loss_class = torch.nn.CrossEntropyLoss() 69 | loss_domain = torch.nn.CrossEntropyLoss() 70 | 71 | data_source_iter = iter(source_loader) 72 | data_target_iter = iter(target_train_loader) 73 | dlabel_src = Variable(torch.ones(opt.batch_size).long().cuda()) 74 | dlabel_tgt = Variable(torch.zeros(opt.batch_size).long().cuda()) 75 | 76 | i = 1 77 | while i <= len_source_loader: 78 | model.train() 79 | 80 | # the parameter for reversing gradients 81 | p = float(i + epoch * len_source_loader) / opt.epochs / len_source_loader 82 | alpha = 2. / (1. 
+ np.exp(-10 * p)) - 1 83 | 84 | # for the source domain batch 85 | source_data, source_label = next(data_source_iter) 86 | if cuda: 87 | source_data, source_label = source_data.cuda(), source_label.cuda() 88 | source_data, source_label = Variable(source_data), Variable(source_label) 89 | 90 | emb_src, clabel_src, dlabel_pred_src = model(source_data, alpha = alpha) 91 | label_loss = loss_class(clabel_src, source_label) 92 | domain_loss_src = loss_domain(dlabel_pred_src, dlabel_src) 93 | 94 | # for the target domain batch 95 | target_data, target_label = next(data_target_iter) 96 | if i % len_target_loader == 0: 97 | data_target_iter = iter(target_train_loader) 98 | if cuda: 99 | target_data, target_label = target_data.cuda(), target_label.cuda() 100 | target_data = Variable(target_data) 101 | 102 | emb_tgt, clabel_tgt, dlabel_pred_tgt = model(target_data, alpha = alpha) 103 | domain_loss_tgt = loss_domain(dlabel_pred_tgt, dlabel_tgt) 104 | 105 | # feature-level mixup 106 | mix_ratio = np.random.beta(opt.alpha, opt.alpha) 107 | mix_ratio = round(mix_ratio, 2) 108 | # clip the mixup ratio 109 | if (mix_ratio >= 0.5 and mix_ratio < (0.5 + opt.clip_thr)): 110 | mix_ratio = 0.5 + opt.clip_thr 111 | if (mix_ratio > (0.5 - opt.clip_thr) and mix_ratio < 0.5): 112 | mix_ratio = 0.5 - opt.clip_thr 113 | 114 | dlabel_mix = Variable((torch.ones(opt.batch_size) * mix_ratio).long().cuda()) # note: .long() truncates, so any mix_ratio < 1 becomes label 0 (target) 115 | emb_mix = mix_ratio * emb_src + (1 - mix_ratio) * emb_tgt 116 | reverse_emb_mix = ReverseLayerF.apply(emb_mix, alpha) 117 | dlabel_pred_mix = model.domain_classifier(reverse_emb_mix) 118 | domain_loss_mix = loss_domain(dlabel_pred_mix, dlabel_mix) 119 | 120 | domain_loss_total = domain_loss_src + domain_loss_tgt + domain_loss_mix * opt.mix_weight 121 | loss_total = label_loss + domain_loss_total 122 | 123 | optimizer.zero_grad() 124 | loss_total.backward() 125 | optimizer.step() 126 | 127 | if i % opt.log_interval == 0: 128 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tlabel_Loss:
{:.6f}\tdomain_Loss: {:.6f}'.format( 129 | epoch, i * len(source_data), len_source_dataset, 130 | 100. * i / len_source_loader, label_loss.item(), domain_loss_total.item())) 131 | i = i + 1 132 | 133 | 134 | # For every epoch evaluation 135 | def test(model): 136 | model.eval() 137 | test_loss = 0 138 | correct = 0 139 | 140 | for data, target in target_test_loader: 141 | if cuda: 142 | data, target = data.cuda(), target.cuda() 143 | data, target = Variable(data), Variable(target) 144 | 145 | _, s_output, t_output = model(data, alpha = 0) 146 | test_loss += F.nll_loss(F.log_softmax(s_output, dim = 1), target, size_average=False).item() 147 | pred = s_output.max(1)[1] 148 | correct += pred.eq(target.data.view_as(pred)).cpu().sum().item() # accumulate a Python int so the accuracy below uses float division 149 | 150 | test_loss /= len_target_dataset 151 | 152 | print('\n{} set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format( 153 | opt.target, test_loss, correct, len_target_dataset, 154 | 100. * correct / len_target_dataset)) 155 | 156 | return correct 157 | 158 | 159 | if __name__ == '__main__': 160 | 161 | model = models.RevGrad(num_classes = nclasses) 162 | print (model) 163 | 164 | max_correct = 0 165 | if cuda: 166 | model.cuda() 167 | 168 | # start training 169 | for epoch in range(1, opt.epochs + 1): 170 | train(epoch, model) 171 | # test for every epoch 172 | t_correct = test(model) 173 | if t_correct > max_correct: 174 | max_correct = t_correct 175 | if not os.path.exists(opt.model_dir): 176 | os.mkdir(opt.model_dir) 177 | torch.save(model.state_dict(), os.path.join(opt.model_dir, 'best_model.pkl')) 178 | 179 | print('source: {} to target: {} max correct: {} max accuracy: {:.2f}%\n'.format( 180 | opt.source, opt.target, max_correct, 100.
* max_correct / len_target_dataset )) -------------------------------------------------------------------------------- /Mixup_RevGrad/data_loader.py: -------------------------------------------------------------------------------- 1 | from torchvision import datasets, transforms 2 | import torch 3 | import numpy as np 4 | import os 5 | 6 | mean = np.array([0.5, 0.5, 0.5]) 7 | std = np.array([0.5, 0.5, 0.5]) 8 | 9 | def load_training(root_path, dir, batch_size, kwargs): 10 | transform = transforms.Compose( 11 | [transforms.Resize([28, 28]), 12 | transforms.ToTensor(), 13 | transforms.Normalize(mean, std)]) 14 | data = datasets.ImageFolder(root=os.path.join(root_path, dir, 'trainset'), transform=transform) 15 | train_loader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True, drop_last=True, **kwargs) 16 | 17 | return train_loader 18 | 19 | def load_testing(root_path, dir, batch_size, kwargs): 20 | transform = transforms.Compose( 21 | [transforms.Resize([28, 28]), 22 | transforms.ToTensor(), 23 | transforms.Normalize(mean, std)]) 24 | data = datasets.ImageFolder(root=os.path.join(root_path, dir, 'testset'), transform=transform) 25 | test_loader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=False, **kwargs) 26 | 27 | return test_loader -------------------------------------------------------------------------------- /Mixup_RevGrad/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Function 4 | from torch.autograd import Variable 5 | import torchvision.models as models 6 | import math 7 | 8 | class ReverseLayerF(Function): 9 | 10 | @staticmethod 11 | def forward(ctx, x, alpha): 12 | ctx.alpha = alpha 13 | 14 | return x.view_as(x) 15 | 16 | @staticmethod 17 | def backward(ctx, grad_output): 18 | output = grad_output.neg() * ctx.alpha 19 | 20 | return output, None 21 | 22 | class RevGrad(nn.Module): 23 | 24 | def 
__init__(self, num_classes = 10): 25 | super(RevGrad, self).__init__() 26 | self.nclasses = num_classes 27 | 28 | self.feature = nn.Sequential() 29 | self.feature.add_module('f_conv1', nn.Conv2d(3, 64, kernel_size=5)) 30 | self.feature.add_module('f_bn1', nn.BatchNorm2d(64)) 31 | self.feature.add_module('f_pool1', nn.MaxPool2d(2)) 32 | self.feature.add_module('f_relu1', nn.ReLU(True)) 33 | self.feature.add_module('f_conv2', nn.Conv2d(64, 50, kernel_size=5)) 34 | self.feature.add_module('f_bn2', nn.BatchNorm2d(50)) 35 | self.feature.add_module('f_drop1', nn.Dropout2d()) 36 | self.feature.add_module('f_pool2', nn.MaxPool2d(2)) 37 | self.feature.add_module('f_relu2', nn.ReLU(True)) 38 | 39 | self.class_classifier = nn.Sequential() 40 | self.class_classifier.add_module('c_fc1', nn.Linear(50 * 4 * 4, 100)) 41 | self.class_classifier.add_module('c_bn1', nn.BatchNorm1d(100)) 42 | self.class_classifier.add_module('c_relu1', nn.ReLU(True)) 43 | self.class_classifier.add_module('c_drop1', nn.Dropout2d()) 44 | self.class_classifier.add_module('c_fc2', nn.Linear(100, 100)) 45 | self.class_classifier.add_module('c_bn2', nn.BatchNorm1d(100)) 46 | self.class_classifier.add_module('c_relu2', nn.ReLU(True)) 47 | self.class_classifier.add_module('c_fc3', nn.Linear(100, self.nclasses)) 48 | 49 | self.domain_classifier = nn.Sequential() 50 | self.domain_classifier.add_module('d_fc1', nn.Linear(50 * 4 * 4, 100)) 51 | self.domain_classifier.add_module('d_bn1', nn.BatchNorm1d(100)) 52 | self.domain_classifier.add_module('d_relu1', nn.ReLU(True)) 53 | self.domain_classifier.add_module('d_fc2', nn.Linear(100, 2)) 54 | 55 | def forward(self, input_data, alpha): 56 | input_data = input_data.expand(len(input_data), 3, 28, 28) 57 | feature = self.feature(input_data) 58 | feature = feature.view(-1, 50 * 4 * 4) 59 | reverse_feature = ReverseLayerF.apply(feature, alpha) 60 | class_output = self.class_classifier(feature) 61 | domain_output = self.domain_classifier(reverse_feature) 62 | 63 | return 
feature, class_output, domain_output -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Adversarial Domain Adaptation with Domain Mixup 2 | 3 |

4 | 5 |

6 | 7 |
8 | 9 | This is the PyTorch implementation of Adversarial Domain Adaptation with Domain Mixup. This work was accepted as an oral presentation at AAAI 2020. 10 | 11 | #### Adversarial Domain Adaptation with Domain Mixup: [[Paper (arxiv)]](https://arxiv.org/abs/1912.01805). 12 |
13 | 14 | ## Getting Started 15 | 16 | * We combine the Domain Mixup strategy with a classical adversarial domain adaptation method, [RevGrad](https://arxiv.org/abs/1409.7495v2), to demonstrate its effectiveness in improving feature alignment. Details are in the [Mixup_RevGrad](https://github.com/ChrisAllenMing/Mixup_for_UDA/tree/master/Mixup_RevGrad) folder. 17 | * The proposed DM-ADA approach uses a VAE-GAN based framework and performs Domain Mixup at both the pixel and feature levels. Details are in the [DM-ADA](https://github.com/ChrisAllenMing/Mixup_for_UDA/tree/master/DM-ADA) folder. Some typical images generated from source, target and mixup features on the VisDA-2017 dataset are shown below. 18 | 19 |
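The feature-level Domain Mixup step that `Mixup_RevGrad/RevGrad_mixup.py` adds on top of RevGrad can be sketched in plain NumPy. This is a minimal illustration under stated assumptions, not the repository's code: the helper names `sample_mix_ratio`, `mix_features` and `soft_domain_bce` are hypothetical, the domain classifier is replaced by a given predicted probability of "source" (domain label 1 = source, 0 = target, as in the scripts), and the mixed embedding is scored against the soft label λ. Keeping λ soft is an assumption about the intended objective; `RevGrad_mixup.py` itself converts the ratio to an integer label with `.long()` before its cross-entropy loss.

```python
import numpy as np


def sample_mix_ratio(rng, alpha=2.0, clip_thr=0.3):
    """Draw lambda ~ Beta(alpha, alpha), round it to two decimals, and push it
    out of the band (0.5 - clip_thr, 0.5 + clip_thr), following the clipping
    rule in RevGrad_mixup.py (defaults match its --alpha / --clip_thr)."""
    lam = round(float(rng.beta(alpha, alpha)), 2)
    if 0.5 <= lam < 0.5 + clip_thr:
        lam = 0.5 + clip_thr
    elif 0.5 - clip_thr < lam < 0.5:
        lam = 0.5 - clip_thr
    return lam


def mix_features(emb_src, emb_tgt, lam):
    """Convex combination of source and target embeddings."""
    return lam * emb_src + (1.0 - lam) * emb_tgt


def soft_domain_bce(p_src, lam, eps=1e-7):
    """Mean binary cross-entropy between the predicted probability of
    'source' (p_src) and the soft domain label lam (hypothetical loss)."""
    p = np.clip(p_src, eps, 1.0 - eps)
    return float(np.mean(-(lam * np.log(p) + (1.0 - lam) * np.log(1.0 - p))))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    lam = sample_mix_ratio(rng)
    # 800 = 50 * 4 * 4, the flattened feature size of the RevGrad model
    emb_src = rng.normal(size=(4, 800))
    emb_tgt = rng.normal(size=(4, 800))
    emb_mix = mix_features(emb_src, emb_tgt, lam)
    print(lam, emb_mix.shape, soft_domain_bce(np.full(4, 0.5), lam))
```

With the default `clip_thr = 0.3`, the clipping guarantees that λ always lies in [0, 0.2] or [0.8, 1], so every mixed embedding is dominated by one of the two domains.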

20 | 21 |

22 | 23 | ## Citation 24 | 25 | If this work helps your research, please cite the following paper. 26 | ``` 27 | @inproceedings{xu2020adversarial, 28 | author = {Minghao Xu and Jian Zhang and Bingbing Ni and Teng Li and Chengjie Wang and Qi Tian and Wenjun Zhang}, 29 | title = {Adversarial Domain Adaptation with Domain Mixup}, 30 | booktitle = {The Thirty-Fourth AAAI Conference on Artificial Intelligence}, 31 | pages = {6502--6509}, 32 | publisher = {AAAI Press}, 33 | year = {2020} 34 | } 35 | ``` -------------------------------------------------------------------------------- /docs/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisAllenMing/Mixup_for_UDA/be4b733e606a0449224f54adbf9f3bbcee4aadc3/docs/model.png -------------------------------------------------------------------------------- /docs/visda_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisAllenMing/Mixup_for_UDA/be4b733e606a0449224f54adbf9f3bbcee4aadc3/docs/visda_results.png --------------------------------------------------------------------------------