├── README.md
├── data
│   ├── mnistsample
│   │   └── mnist.pkl.gz
│   └── uspssample
│       └── usps_28x28.pkl
├── dataset_mnist.py
├── dataset_usps.py
├── main.py
├── mnist2usps.yaml
├── net_config.py
├── requirements.txt
├── train_src.py
└── usps2mnist.yaml

/README.md:
--------------------------------------------------------------------------------
# adda-pytorch
Implementation of "Adversarial Discriminative Domain Adaptation" (https://arxiv.org/abs/1702.05464) in PyTorch


### dataset
mnist -> usps
usps -> mnist
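
Both directions use the sample pickles bundled under `data/`. As a quick sanity check that the loaders work, here is a minimal sketch of reading one batch through the `USPSSAMPLE` class (Python 2, matching the loaders' use of `cPickle`; the batch size and seed are illustrative):

```python
from torchvision import transforms
import torch.utils.data
from dataset_usps import USPSSAMPLE

# 7438 is the full USPS training split, per the comment in dataset_usps.py
usps = USPSSAMPLE(root='data/uspssample', num_training_samples=7438,
                  train=True, transform=transforms.ToTensor(), seed=0)
loader = torch.utils.data.DataLoader(usps, batch_size=64, shuffle=True)
images, labels = next(iter(loader))  # images: 64x1x28x28, labels: 64x1
```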

### command
##### $ python train_src.py --config "A->B config-file"
Trains a model on the source-domain data.
13 | "A->B config-file" should be usps2mnist.yaml or mnist2usps.yaml
14 | val reports the accuracy on target data 15 |
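
For example, to pretrain the source model for the usps -> mnist experiment:

```
$ python train_src.py --config usps2mnist.yaml
```

The script saves the trained weights to the path named by `pretrained_path` in the same yaml, which is where `main.py` later looks for them.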

##### $ python main.py --config "A->B config-file"
Adapts the model trained on the source domain to the target domain, using the ADDA method.
19 | "A->B config-file" should be usps2mnist.yaml or mnist2usps.yaml 20 |

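A condensed sketch of that per-batch update (written against the current PyTorch tensor API rather than the Variable-era API used in main.py; `adda_step` and its argument names are illustrative, and unlike main.py, which reuses the critic scores computed before the critic step for the encoder loss, this sketch recomputes them):

```python
import torch
import torch.nn.functional as F

def adda_step(model_src, model, critic, opt_d, opt_g, data_src, data):
    """One ADDA batch: update the domain critic, then the target encoder."""
    feat_src, _ = model_src(data_src)   # features from the frozen source encoder
    feat, _ = model(data)               # features from the encoder being adapted

    # 1) Critic update: source features are labeled 1, target features 0.
    d_scores = critic(torch.cat((feat_src, feat), 0).detach())
    d_labels = torch.cat((torch.ones(len(feat_src), dtype=torch.long),
                          torch.zeros(len(feat), dtype=torch.long))).to(d_scores.device)
    d_loss = F.nll_loss(d_scores, d_labels)   # the critic outputs log-probabilities
    opt_d.zero_grad()
    d_loss.backward()
    opt_d.step()

    # 2) Encoder update: train the target encoder so the critic calls its
    #    features "source" (label 1) -- the usual GAN label inversion.
    g_scores = critic(model(data)[0])
    g_labels = torch.ones(len(data), dtype=torch.long, device=g_scores.device)
    g_loss = F.nll_loss(g_scores, g_labels)
    opt_g.zero_grad()
    g_loss.backward()
    opt_g.step()
    return d_loss.item(), g_loss.item()
```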

### performance on the test set
| Method | mnist -> usps | usps -> mnist |
| ------------- |:-------------:|:-------------:|
| source only | 84% (1560/1860) | 78% (7820/10000) |
| adapted | 92% (1709/1860) | 91% (9074/10000) |
--------------------------------------------------------------------------------
/data/mnistsample/mnist.pkl.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nitahhhh/adda-pytorch/b305c7e86070a9bae78f4fb4cf2ebfbf78b742f9/data/mnistsample/mnist.pkl.gz
--------------------------------------------------------------------------------
/data/uspssample/usps_28x28.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nitahhhh/adda-pytorch/b305c7e86070a9bae78f4fb4cf2ebfbf78b742f9/data/uspssample/usps_28x28.pkl
--------------------------------------------------------------------------------
/dataset_mnist.py:
--------------------------------------------------------------------------------
from __future__ import print_function
import torch
import os
import numpy as np
import cPickle
import gzip
import torch.utils.data as data
import urllib


class MNISTSAMPLE(data.Dataset):

    def __init__(self, root, num_training_samples, train=True, transform=None, seed=None):
        self.url = 'http://deeplearning.net/data/mnist/mnist.pkl.gz'
        self.filename = 'mnist.pkl.gz'
        self.train = train
        self.root = root
        self.num_training_samples = num_training_samples
        self.transform = transform
        self.download()
        self.test_set_size = 0
        self.train_data, self.train_labels = self.load_samples()
        if seed is not None:
            np.random.seed(seed)
        if self.train:
            # subsample a fixed number of training images
            total_num_samples = self.train_labels.shape[0]
            indices = np.arange(total_num_samples)
            np.random.shuffle(indices)
            self.train_data = self.train_data[indices[0:self.num_training_samples], ::]
            self.train_labels = self.train_labels[indices[0:self.num_training_samples]]
        self.train_data *= 255.0
        self.train_data = self.train_data.transpose((0, 2, 3, 1))  # convert to HWC

    def __getitem__(self, index):
        img, label = self.train_data[index, ::], self.train_labels[index]
        if self.transform is not None:
            img = self.transform(img)
        label = torch.LongTensor([np.int64(label)])
        return img, label

    def __len__(self):
        if self.train:
            return self.num_training_samples
        else:
            return self.test_set_size

    def download(self):
        filename = os.path.join(self.root, self.filename)
        dirname = os.path.dirname(filename)
        if not os.path.isdir(dirname):
            os.mkdir(dirname)
        if os.path.isfile(filename):
            return
        print("Download %s to %s" % (self.url, filename))
        urllib.urlretrieve(self.url, filename)
        print("[DONE]")
        return

    def load_samples(self):
        filename = os.path.join(self.root, self.filename)
        f = gzip.open(filename, 'rb')
        train_set, valid_set, test_set = cPickle.load(f)
        f.close()
        if self.train:
            # the train and validation splits are merged for training
            images = np.concatenate((train_set[0], valid_set[0]), axis=0)
            labels = np.concatenate((train_set[1], valid_set[1]), axis=0)
        else:
            images = test_set[0]
            labels = test_set[1]
            self.test_set_size = labels.shape[0]
        images = images.reshape((images.shape[0], 1, 28, 28))
        return images, labels
--------------------------------------------------------------------------------
/dataset_usps.py:
--------------------------------------------------------------------------------
from __future__ import print_function
import os
import numpy as np
import cPickle
import gzip
import torch.utils.data as data
import torch
import urllib


class USPSSAMPLE(data.Dataset):
    # Num of Train = 7438, Num of Test = 1860
    def __init__(self, root, num_training_samples, train=True, transform=None, seed=None):
        self.filename = 'usps_28x28.pkl'
        self.train = train
        self.root = root
        self.num_training_samples = num_training_samples
        self.transform = transform
        self.test_set_size = 0
        # self.download()
        self.train_data, self.train_labels = self.load_samples()
        if seed is not None:
            np.random.seed(seed)
        if self.train:
            total_num_samples = self.train_labels.shape[0]
            indices = np.arange(total_num_samples)
            np.random.shuffle(indices)
            self.train_data = self.train_data[indices[0:self.num_training_samples], ::]
            self.train_labels = self.train_labels[indices[0:self.num_training_samples]]
        self.train_data *= 255.0
        self.train_data = self.train_data.transpose((0, 2, 3, 1))  # convert to HWC

    def __getitem__(self, index):
        img, label = self.train_data[index, ::], self.train_labels[index]
        if self.transform is not None:
            img = self.transform(img)
        label = torch.LongTensor([np.int64(label)])
        return img, label

    def __len__(self):
        if self.train:
            return self.num_training_samples
        else:
            return self.test_set_size

    def download(self):
        # NOTE: unused, and would fail if enabled -- self.url is never set for the
        # USPS pickle, so the file must already exist under `root`.
        filename = os.path.join(self.root, self.filename)
        dirname = os.path.dirname(filename)
        if not os.path.isdir(dirname):
            os.mkdir(dirname)
        if os.path.isfile(filename):
            return
        print("Download %s to %s" % (self.url, filename))
        urllib.urlretrieve(self.url, filename)
        print("[DONE]")
        return

    def load_samples(self):
        filename = os.path.join(self.root, self.filename)
        f = gzip.open(filename, 'rb')
        data_set = cPickle.load(f)
        f.close()
        if self.train:
            images = data_set[0][0]
            labels = data_set[0][1]
        else:
            images = data_set[1][0]
            labels = data_set[1][1]
            self.test_set_size = labels.shape[0]
        return images, labels
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
import sys
import os
import itertools
import logging
from dataset_mnist import *
from dataset_usps import *
from net_config import *
from optparse import OptionParser

# Training settings
parser = OptionParser()
parser.add_option('--config',
                  type=str,
                  help="net configuration",
                  default="usps2mnist.yaml")

(opts, args) = parser.parse_args(sys.argv)
config = NetConfig(opts.config)
kwargs = {'num_workers': 1, 'pin_memory': True} if config.use_cuda else {}
torch.manual_seed(config.seed)
if not torch.cuda.is_available():
    config.use_cuda = False
    print("CUDA not available; falling back to CPU")
if config.use_cuda:
    torch.cuda.manual_seed(config.seed)


def read(argv, config):
    print(config)
    if os.path.exists(config.log):
        os.remove(config.log)
    base_folder_name = os.path.dirname(config.log)
    if not os.path.isdir(base_folder_name):
        os.mkdir(base_folder_name)
    logging.basicConfig(filename=config.log, level=logging.INFO, filemode='w')
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    logging.getLogger('').addHandler(console)
    logging.info("Let the journey begin!")
    logging.info(config)
    # The dataset class names (MNISTSAMPLE / USPSSAMPLE) come from the yaml config.
    exec("train_dataset_a = %s(root=config.train_data_a_path, \
        num_training_samples=config.train_data_a_size, \
        train=config.train_data_a_use_train_data, \
        transform=transforms.ToTensor(), \
        seed=config.train_data_a_seed)" % config.train_data_a)
    train_loader_a = torch.utils.data.DataLoader(dataset=train_dataset_a, batch_size=config.batch_size, shuffle=True)

    exec("train_dataset_b = %s(root=config.train_data_b_path, \
        num_training_samples=config.train_data_b_size, \
        train=config.train_data_b_use_train_data, \
        transform=transforms.ToTensor(), \
        seed=config.train_data_b_seed)" % config.train_data_b)
    train_loader_b = torch.utils.data.DataLoader(dataset=train_dataset_b, batch_size=config.batch_size, shuffle=True)

    exec("test_dataset_b = %s(root=config.test_data_b_path, \
        num_training_samples=config.test_data_b_size, \
        train=config.test_data_b_use_train_data, \
        transform=transforms.ToTensor(), \
        seed=config.test_data_b_seed)" % config.test_data_b)
    test_loader_b = torch.utils.data.DataLoader(dataset=test_dataset_b, batch_size=config.test_batch_size, shuffle=True)

    return train_loader_a, train_loader_b, test_loader_b


train_loader_a, train_loader_b, test_loader_b = read(sys.argv, config)


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        # NOTE: the 50-d feature fed to the critic is the pre-ReLU fc1 output here,
        # whereas train_src.py returns the post-ReLU activation.
        x_f = self.fc1(x.view(-1, 320))
        x = F.dropout(F.relu(x_f), training=self.training)
        x = self.fc2(x)
        return x_f, F.log_softmax(x)


class Discrimer(nn.Module):
    # Domain discriminator (critic): classifies a 50-d feature as source (1) or target (0).
    def __init__(self):
        super(Discrimer, self).__init__()
        self.fc1 = nn.Linear(50, 512)
        self.fc2 = nn.Linear(512, 2)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x)


model = Net()
model_src = Net()
critic = Discrimer()

if config.use_cuda:
    model.cuda()
    model_src.cuda()
    critic.cuda()

optimizer_d = optim.Adam(critic.parameters(), lr=config.lr)
optimizer_g = optim.Adam(model.parameters(), lr=config.lr)
print("load model...")
PATH = config.pretrained_path  # e.g. 'pytorch_model_usps2mnist'
model.load_state_dict(torch.load(PATH))  # model to adapt
model_src.load_state_dict(torch.load(PATH))  # frozen source reference


def train(epoch):
    model.train()

    for batch_idx, ((data_src, target_src), (data, target)) in enumerate(itertools.izip(train_loader_a, train_loader_b)):
        if config.use_cuda:
            data_src, target_src = data_src.cuda(), target_src.cuda()
            data, target = data.cuda(), target.cuda()
        data_src, target_src = Variable(data_src), Variable(target_src)
        data, target = Variable(data), Variable(target)

        feat_src, output_src = model_src(data_src)
        feat, output = model(data)
        all_d_feat = torch.cat((feat_src, feat), 0)
        all_d_score = critic(all_d_feat)
        # source features are labeled 1, target features 0
        half = all_d_score.size()[0] / 2
        d_labels = torch.cat((torch.ones(half).long(), torch.zeros(half).long()), 0)
        if config.use_cuda:
            d_labels = d_labels.cuda()
        all_d_label = Variable(d_labels)
        # D loss
        domain_loss = F.nll_loss(all_d_score, all_d_label)
        ### domain accuracy ###
        predict = torch.squeeze(all_d_score.max(1)[1])
        d_accu = (predict == all_d_label).float().mean()

        critic.zero_grad()
        model.zero_grad()
        domain_loss.backward(retain_variables=True)  # keep the graph: gen_loss below reuses all_d_score
        optimizer_d.step()

        # G loss: train the target encoder so the critic labels its features as source (1)
        g_labels = torch.ones(half).long()
        if config.use_cuda:
            g_labels = g_labels.cuda()
        gen_loss = F.nll_loss(all_d_score[half:, ...], Variable(g_labels))

        model.zero_grad()
        critic.zero_grad()
        gen_loss.backward()
        optimizer_g.step()

        if batch_idx % config.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tG Loss: {:.6f}\tD Loss: {:.6f}\tD accu: {:.3f}'.format(
                epoch, batch_idx * len(data), len(train_loader_a.dataset),
                100. * batch_idx / len(train_loader_a), gen_loss.data[0], domain_loss.data[0], d_accu.data[0]))


def test(epoch):
    model.eval()
    test_loss = 0
    correct = 0
    for data, target in test_loader_b:
        if config.use_cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data, volatile=True), Variable(target)
        feat, output = model(data)
        target = torch.squeeze(target)
        test_loss += F.nll_loss(output, target).data[0]
        pred = output.data.max(1)[1]  # get the index of the max log-probability
        correct += pred.eq(target.data).cpu().sum()

    test_loss /= len(test_loader_b)  # loss function already averages over batch size
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader_b.dataset),
        100. * correct / len(test_loader_b.dataset)))


for epoch in range(1, config.epochs + 1):
    train(epoch)
    test(epoch)
--------------------------------------------------------------------------------
/mnist2usps.yaml:
--------------------------------------------------------------------------------
train:
  batch_size: 64
  test_batch_size: 100
  lr: 0.0001
  use_cuda: True
  seed: 1
  epochs: 100
  log_interval: 20
  log: outputs/mnist2usps.log
  pretrained_path: pytorch_model_mnist2usps
  train_data_a: MNISTSAMPLE
  train_data_a_path: data/mnistsample
  train_data_a_size: 60000
  train_data_a_seed: 0
  train_data_a_use_train_data: True
  train_data_b: USPSSAMPLE
  train_data_b_path: data/uspssample
  train_data_b_size: 7438
  train_data_b_seed: 0
  train_data_b_use_train_data: True
  test_data_b: USPSSAMPLE
  test_data_b_path: data/uspssample
  test_data_b_size: 1860
  test_data_b_seed: 0
  test_data_b_use_train_data: False
--------------------------------------------------------------------------------
/net_config.py:
--------------------------------------------------------------------------------
import yaml


class NetConfig(object):
    def __init__(self, config):
        stream = open(config, 'r')
        docs = yaml.load_all(stream)
        for doc in docs:
            for k, v in doc.items():
                if k == "train":
                    for k1, v1 in v.items():
                        # expose every key under `train:` as an attribute,
                        # e.g. "batch_size: 64" becomes self.batch_size = 64
                        cmd = "self." + k1 + "=" + repr(v1)
                        print(cmd)
                        exec(cmd)
        stream.close()
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
torch
torchvision
--------------------------------------------------------------------------------
/train_src.py:
--------------------------------------------------------------------------------
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
import sys
import os
import itertools
import logging
from dataset_mnist import *
from dataset_usps import *
from net_config import *
from optparse import OptionParser

# Training settings
parser = OptionParser()
parser.add_option('--config',
                  type=str,
                  help="net configuration",
                  default="usps2mnist.yaml")
(opts, args) = parser.parse_args(sys.argv)
config = NetConfig(opts.config)
kwargs = {'num_workers': 1, 'pin_memory': True} if config.use_cuda else {}
torch.manual_seed(config.seed)
if not torch.cuda.is_available():
    config.use_cuda = False
    print("CUDA not available; falling back to CPU")
if config.use_cuda:
    torch.cuda.manual_seed(config.seed)


def read(argv, config):
    print(config)
    if os.path.exists(config.log):
        os.remove(config.log)
    base_folder_name = os.path.dirname(config.log)
    if not os.path.isdir(base_folder_name):
        os.mkdir(base_folder_name)
    logging.basicConfig(filename=config.log, level=logging.INFO, filemode='w')
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    logging.getLogger('').addHandler(console)
    logging.info("Let the journey begin!")
    logging.info(config)
    # The dataset class names (MNISTSAMPLE / USPSSAMPLE) come from the yaml config.
    exec("train_dataset_a = %s(root=config.train_data_a_path, \
        num_training_samples=config.train_data_a_size, \
        train=config.train_data_a_use_train_data, \
        transform=transforms.ToTensor(), \
        seed=config.train_data_a_seed)" % config.train_data_a)
    train_loader_a = torch.utils.data.DataLoader(dataset=train_dataset_a, batch_size=config.batch_size, shuffle=True)

    exec("train_dataset_b = %s(root=config.train_data_b_path, \
        num_training_samples=config.train_data_b_size, \
        train=config.train_data_b_use_train_data, \
        transform=transforms.ToTensor(), \
        seed=config.train_data_b_seed)" % config.train_data_b)
    train_loader_b = torch.utils.data.DataLoader(dataset=train_dataset_b, batch_size=config.batch_size, shuffle=True)

    exec("test_dataset_b = %s(root=config.test_data_b_path, \
        num_training_samples=config.test_data_b_size, \
        train=config.test_data_b_use_train_data, \
        transform=transforms.ToTensor(), \
        seed=config.test_data_b_seed)" % config.test_data_b)
    test_loader_b = torch.utils.data.DataLoader(dataset=test_dataset_b, batch_size=config.test_batch_size, shuffle=True)

    return train_loader_a, train_loader_b, test_loader_b


train_loader_a, train_loader_b, test_loader_b = read(sys.argv, config)


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        # NOTE: the returned feature is the post-ReLU fc1 activation here,
        # whereas main.py returns the pre-ReLU fc1 output.
        x_f = F.relu(self.fc1(x))
        x = F.dropout(x_f, training=self.training)
        x = self.fc2(x)
        return x_f, F.log_softmax(x)


model = Net()
if config.use_cuda:
    model.cuda()

# NOTE: the learning rate is hard-coded here; config.lr is only used by main.py
optimizer = optim.Adam(model.parameters(), lr=0.01)


def train(epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader_a):
        if config.use_cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)
        optimizer.zero_grad()
        feat, output = model(data)
        target = torch.squeeze(target)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % config.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader_a.dataset),
                100. * batch_idx / len(train_loader_a), loss.data[0]))


def test(epoch):
    model.eval()
    test_loss = 0
    correct = 0
    for data, target in test_loader_b:
        if config.use_cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data, volatile=True), Variable(target)
        feat, output = model(data)
        target = torch.squeeze(target)
        test_loss += F.nll_loss(output, target).data[0]
        pred = output.data.max(1)[1]  # get the index of the max log-probability
        correct += pred.eq(target.data).cpu().sum()

    test_loss /= len(test_loader_b)  # loss function already averages over batch size
    print('\nTest on target valid set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader_b.dataset),
        100. * correct / len(test_loader_b.dataset)))


for epoch in range(1, config.epochs + 1):
    train(epoch)
    test(epoch)

PATH = config.pretrained_path  # save where main.py will look for the pretrained weights
torch.save(model.state_dict(), PATH)
--------------------------------------------------------------------------------
/usps2mnist.yaml:
--------------------------------------------------------------------------------
train:
  batch_size: 64
  test_batch_size: 100
  lr: 0.001
  use_cuda: True
  seed: 1
  epochs: 70
  log_interval: 20
  log: outputs/usps2mnist.log
  pretrained_path: pytorch_model_usps2mnist
  train_data_a: USPSSAMPLE
  train_data_a_path: data/uspssample
  train_data_a_size: 7438
  train_data_a_seed: 0
  train_data_a_use_train_data: True
  train_data_b: MNISTSAMPLE
  train_data_b_path: data/mnistsample
  train_data_b_size: 60000
  train_data_b_seed: 0
  train_data_b_use_train_data: True
  test_data_b: MNISTSAMPLE
  test_data_b_path: data/mnistsample
  test_data_b_size: 10000
  test_data_b_seed: 0
  test_data_b_use_train_data: False
--------------------------------------------------------------------------------