├── .gitignore ├── LICENSE ├── README.md ├── core ├── __init__.py ├── adapt.py └── test.py ├── datasets ├── __init__.py ├── mnist.py └── usps.py ├── main.py ├── misc ├── params.py └── utils.py ├── models ├── __init__.py ├── classifier.py ├── discriminator.py └── generator.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # datasets and model snapshots 2 | data 3 | snapshots 4 | 5 | # trash 6 | .DS_Store 7 | 8 | # Byte-compiled / optimized / DLL files 9 | __pycache__/ 10 | *.py[cod] 11 | *$py.class 12 | 13 | # C extensions 14 | *.so 15 | 16 | # Distribution / packaging 17 | .Python 18 | env/ 19 | build/ 20 | develop-eggs/ 21 | dist/ 22 | downloads/ 23 | eggs/ 24 | .eggs/ 25 | lib/ 26 | lib64/ 27 | parts/ 28 | sdist/ 29 | var/ 30 | wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | 35 | # PyInstaller 36 | # Usually these files are written by a python script from a template 37 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 38 | *.manifest 39 | *.spec 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | .hypothesis/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # pyenv 81 | .python-version 82 | 83 | # celery beat schedule file 84 | celerybeat-schedule 85 | 86 | # SageMath parsed files 87 | *.sage.py 88 | 89 | # dotenv 90 | .env 91 | 92 | # virtualenv 93 | .venv 94 | venv/ 95 | ENV/ 96 | 97 | # Spyder project settings 98 | .spyderproject 99 | .spyproject 100 | 101 | # Rope project settings 102 | .ropeproject 103 | 104 | # mkdocs documentation 105 | /site 106 | 107 | # mypy 108 | .mypy_cache/ 109 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Yusu Pan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch-ARDA (Work In Progress) 2 | A PyTorch implementation of [Adversarial Representation Learning for Domain Adaptation](https://arxiv.org/abs/1707.01217) 3 | 4 | ## Result 5 | I cannot yet reproduce the expected results with the current version of this code; further work is needed. 6 | 7 | 8 | | | MNIST (Source) | USPS (Target) | 9 | | :-----------------: | :------------: | :-----------: | 10 | | Generator + Encoder | 98.44167% | 88.70664% | 11 | -------------------------------------------------------------------------------- /core/__init__.py: -------------------------------------------------------------------------------- 1 | from .adapt import train 2 | from .test import test 3 | 4 | __all__ = ("train", "test") 5 | -------------------------------------------------------------------------------- /core/adapt.py: -------------------------------------------------------------------------------- 1 | """Execute domain adaptation for ARDA.""" 2 | 3 | import torch 4 | from torch import nn 5 | 6 | from misc import params 7 | from misc.utils import (calc_gradient_penalty, get_inf_iterator, get_optimizer, 8 | make_variable, save_model) 9 | 10 | 11 | def train(classifier, generator, critic, src_data_loader, tgt_data_loader): 12 | """Train generator, classifier and critic jointly.""" 13 | #################### 14 | # 1. setup network # 15 | #################### 16 | 17 | # set train state for Dropout and BN layers 18 | classifier.train() 19 | generator.train() 20 | critic.train() 21 | 22 | # set criterion for classifier and optimizers 23 | criterion = nn.CrossEntropyLoss() 24 | optimizer_c = get_optimizer(classifier, "Adam") 25 | optimizer_g = get_optimizer(generator, "Adam") 26 | optimizer_d = get_optimizer(critic, "Adam") 27 | 28 | # build infinite iterators over source and target data 29 | data_iter_src = get_inf_iterator(src_data_loader) 30 | data_iter_tgt = get_inf_iterator(tgt_data_loader) 31 | 32 | # counter 33 | g_step = 0 34 | 35 | # positive and negative labels 36 | pos_labels = make_variable(torch.FloatTensor([1])) 37 | neg_labels = make_variable(torch.FloatTensor([-1])) 38 | 39 | #################### 40 | # 2. train network # 41 | #################### 42 | 43 | for epoch in range(params.num_epochs): 44 | ########################### 45 | # 2.1 train discriminator # 46 | ########################### 47 | # enable gradient computation for D 48 | for p in critic.parameters(): 49 | p.requires_grad = True 50 | 51 | # set steps for discriminator 52 | if g_step < 25 or g_step % 500 == 0: 53 | # this helps to start with the critic at optimum 54 | # even in the first iterations.
55 | critic_iters = 100 56 | else: 57 | critic_iters = params.d_steps 58 | 59 | # loop for optimizing discriminator 60 | for d_step in range(critic_iters): 61 | # convert images into torch.Variable 62 | images_src, labels_src = next(data_iter_src) 63 | images_tgt, _ = next(data_iter_tgt) 64 | images_src = make_variable(images_src) 65 | labels_src = make_variable(labels_src.squeeze_()) 66 | images_tgt = make_variable(images_tgt) 67 | if images_src.size(0) != params.batch_size or \ 68 | images_tgt.size(0) != params.batch_size: 69 | continue 70 | 71 | # zero gradients for optimizer 72 | optimizer_d.zero_grad() 73 | 74 | # compute source data loss for discriminator 75 | feat_src = generator(images_src) 76 | d_loss_src = critic(feat_src.detach()) 77 | d_loss_src = d_loss_src.mean() 78 | d_loss_src.backward(neg_labels) 79 | 80 | # compute target data loss for discriminator 81 | feat_tgt = generator(images_tgt) 82 | d_loss_tgt = critic(feat_tgt.detach()) 83 | d_loss_tgt = d_loss_tgt.mean() 84 | d_loss_tgt.backward(pos_labels) 85 | 86 | # compute gradient penalty 87 | gradient_penalty = calc_gradient_penalty( 88 | critic, feat_src.data, feat_tgt.data) 89 | gradient_penalty.backward() 90 | 91 | # record critic loss (for logging) and optimize weights of discriminator 92 | d_loss = -d_loss_src + d_loss_tgt + gradient_penalty 93 | optimizer_d.step() 94 | 95 | ######################## 96 | # 2.2 train classifier # 97 | ######################## 98 | 99 | # zero gradients for optimizer 100 | optimizer_c.zero_grad() 101 | 102 | # compute classification loss for classifier 103 | preds_c = classifier(generator(images_src).detach()) 104 | c_loss = criterion(preds_c, labels_src) 105 | 106 | # optimize source classifier 107 | c_loss.backward() 108 | optimizer_c.step() 109 | 110 | ####################### 111 | # 2.3 train generator # 112 | ####################### 113 | # avoid computing gradients for D 114 | for p in critic.parameters(): 115 | p.requires_grad = False 116 | 117 | # zero grad for optimizer of generator 118 | optimizer_g.zero_grad() 119 | 120 | # compute source data classification loss for generator 121 | feat_src = generator(images_src) 122 | preds_c = classifier(feat_src) 123 | g_loss_cls = criterion(preds_c, labels_src) 124 | g_loss_cls.backward() 125 | 126 | # compute source data discrimination loss for generator 127 | feat_src = generator(images_src) 128 | g_loss_src = critic(feat_src).mean() 129 | g_loss_src.backward(pos_labels) 130 | 131 | # compute target data discrimination loss for generator 132 | feat_tgt = generator(images_tgt) 133 | g_loss_tgt = critic(feat_tgt).mean() 134 | g_loss_tgt.backward(neg_labels) 135 | 136 | # record generator loss (for logging) 137 | g_loss = g_loss_src - g_loss_tgt + g_loss_cls 138 | 139 | # optimize weights of generator 140 | optimizer_g.step() 141 | g_step += 1 142 | 143 | ################## 144 | # 2.4 print info # 145 | ################## 146 | if ((epoch + 1) % params.log_step == 0): 147 | print("Epoch [{}/{}]: " 148 | "d_loss={:.5f} c_loss={:.5f} g_loss={:.5f} " 149 | "D(src)={:.5f} D(tgt)={:.5f} GP={:.5f}" 150 | .format(epoch + 1, 151 | params.num_epochs, 152 | d_loss.data[0], 153 | c_loss.data[0], 154 | g_loss.data[0], 155 | d_loss_src.data[0], 156 | d_loss_tgt.data[0], 157 | gradient_penalty.data[0] 158 | )) 159 | 160 | ############################# 161 | # 2.5 save model parameters # 162 | ############################# 163 | if ((epoch + 1) % params.save_step == 0): 164 | save_model(critic, "WGAN-GP_critic-{}.pt".format(epoch + 1)) 165 | save_model(classifier, 166 | "WGAN-GP_classifier-{}.pt".format(epoch + 1))
167 | save_model(generator, "WGAN-GP_generator-{}.pt".format(epoch + 1)) 168 | 169 | return classifier, generator 170 | -------------------------------------------------------------------------------- /core/test.py: -------------------------------------------------------------------------------- 1 | """Test result of domain adaptation for ARDA.""" 2 | 3 | import torch.nn as nn 4 | 5 | from misc.utils import make_variable 6 | 7 | 8 | def test(classifier, generator, data_loader, dataset="MNIST"): 9 | """Evaluate classifier on source or target domains.""" 10 | # set eval state for Dropout and BN layers 11 | generator.eval() 12 | classifier.eval() 13 | 14 | # init loss and accuracy 15 | loss = 0 16 | acc = 0 17 | 18 | # set loss function 19 | criterion = nn.CrossEntropyLoss() 20 | 21 | # evaluate network 22 | for (images, labels) in data_loader: 23 | images = make_variable(images, volatile=True) 24 | labels = make_variable(labels.squeeze_()) 25 | 26 | preds = classifier(generator(images)) 27 | loss += criterion(preds, labels).data[0] 28 | 29 | pred_cls = preds.data.max(1)[1] 30 | acc += pred_cls.eq(labels.data).cpu().sum() 31 | 32 | loss /= len(data_loader) 33 | acc /= len(data_loader.dataset) 34 | 35 | print("Avg Loss = {:.5f}, Avg Accuracy = {:2.5%}".format(loss, acc)) 36 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .mnist import get_mnist 2 | from .usps import get_usps 3 | 4 | __all__ = ("get_usps", "get_mnist") 5 | -------------------------------------------------------------------------------- /datasets/mnist.py: -------------------------------------------------------------------------------- 1 | """Dataset setting and data loader for MNIST.""" 2 | 3 | 4 | import torch 5 | from torchvision import datasets, transforms 6 | 7 | from misc import params 8 | 9 | 10 | def get_mnist(train): 11 | """Get MNIST dataset loader.""" 12 | # image pre-processing 13 | pre_process = transforms.Compose([transforms.ToTensor(), 14 | transforms.Normalize( 15 | mean=params.dataset_mean, 16 | std=params.dataset_std)]) 17 | 18 | # dataset and data loader 19 | mnist_dataset = datasets.MNIST(root=params.data_root, 20 | train=train, 21 | transform=pre_process, 22 | download=True) 23 | 24 | mnist_data_loader = torch.utils.data.DataLoader( 25 | dataset=mnist_dataset, 26 | batch_size=params.batch_size, 27 | shuffle=True) 28 | 29 | return mnist_data_loader 30 | -------------------------------------------------------------------------------- /datasets/usps.py: -------------------------------------------------------------------------------- 1 | """Dataset setting and data loader for USPS. 2 | 3 | Modified from 4 | https://github.com/mingyuliutw/CoGAN_PyTorch/blob/master/src/dataset_usps.py 5 | """ 6 | 7 | import gzip 8 | import os 9 | import pickle 10 | import urllib.request 11 | 12 | import numpy as np 13 | import torch 14 | import torch.utils.data as data 15 | from torchvision import transforms 16 | 17 | from misc import params 18 | 19 | 20 | class USPS(data.Dataset): 21 | """USPS Dataset. 22 | 23 | Args: 24 | root (string): Root directory of dataset where dataset file exist. 25 | train (bool, optional): If True, resample from dataset randomly. 26 | download (bool, optional): If true, downloads the dataset 27 | from the internet and puts it in root directory. 28 | If dataset is already downloaded, it is not downloaded again.
29 | transform (callable, optional): A function/transform that takes in 30 | a PIL image and returns a transformed version. 31 | E.g., ``transforms.RandomCrop`` 32 | """ 33 | 34 | url = "https://raw.githubusercontent.com/mingyuliutw/CoGAN_PyTorch/master/data/uspssample/usps_28x28.pkl" 35 | 36 | def __init__(self, root, train=True, transform=None, download=False): 37 | """Init USPS dataset.""" 38 | # init params 39 | self.root = os.path.expanduser(root) 40 | self.filename = "usps_28x28.pkl" 41 | self.train = train 42 | # Num of Train = 7438, Num of Test = 1860 43 | self.transform = transform 44 | self.dataset_size = None 45 | 46 | # download dataset. 47 | if download: 48 | self.download() 49 | if not self._check_exists(): 50 | raise RuntimeError("Dataset not found." + 51 | " You can use download=True to download it") 52 | 53 | self.train_data, self.train_labels = self.load_samples() 54 | if self.train: 55 | total_num_samples = self.train_labels.shape[0] 56 | indices = np.arange(total_num_samples) 57 | np.random.shuffle(indices) 58 | self.train_data = self.train_data[indices[0:self.dataset_size], ::] 59 | self.train_labels = self.train_labels[indices[0:self.dataset_size]] 60 | self.train_data *= 255.0 61 | self.train_data = self.train_data.transpose( 62 | (0, 2, 3, 1)) # convert to HWC 63 | 64 | def __getitem__(self, index): 65 | """Get images and target for data loader. 66 | 67 | Args: 68 | index (int): Index 69 | Returns: 70 | tuple: (image, target) where target is index of the target class. 71 | """ 72 | img, label = self.train_data[index, ::], self.train_labels[index] 73 | if self.transform is not None: 74 | img = self.transform(img) 75 | label = torch.LongTensor([np.int64(label).item()]) 76 | # label = torch.FloatTensor([label.item()]) 77 | return img, label 78 | 79 | def __len__(self): 80 | """Return size of dataset.""" 81 | return self.dataset_size 82 | 83 | def _check_exists(self): 84 | """Check if dataset is downloaded and in the right place.""" 85 | return os.path.exists(os.path.join(self.root, self.filename)) 86 | 87 | def download(self): 88 | """Download dataset.""" 89 | filename = os.path.join(self.root, self.filename) 90 | dirname = os.path.dirname(filename) 91 | if not os.path.isdir(dirname): 92 | os.makedirs(dirname) 93 | if os.path.isfile(filename): 94 | return 95 | print("Download %s to %s" % (self.url, os.path.abspath(filename))) 96 | urllib.request.urlretrieve(self.url, filename) 97 | print("[DONE]") 98 | return 99 | 100 | def load_samples(self): 101 | """Load sample images from dataset.""" 102 | filename = os.path.join(self.root, self.filename) 103 | f = gzip.open(filename, "rb") 104 | data_set = pickle.load(f, encoding="bytes") 105 | f.close() 106 | if self.train: 107 | images = data_set[0][0] 108 | labels = data_set[0][1] 109 | self.dataset_size = labels.shape[0] 110 | else: 111 | images = data_set[1][0] 112 | labels = data_set[1][1] 113 | self.dataset_size = labels.shape[0] 114 | return images, labels 115 | 116 | 117 | def get_usps(train): 118 | """Get USPS dataset loader.""" 119 | # image pre-processing 120 | pre_process = transforms.Compose([transforms.ToTensor(), 121 | transforms.Normalize( 122 | mean=params.dataset_mean, 123 | std=params.dataset_std)]) 124 | 125 | # dataset and data loader 126 | usps_dataset = USPS(root=params.data_root, 127 | train=train, 128 | transform=pre_process, 129 | download=True) 130 | 131 | usps_data_loader = torch.utils.data.DataLoader( 132 | dataset=usps_dataset, 133 | batch_size=params.batch_size, 134 | shuffle=True) 135 | 136 | return usps_data_loader
137 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | """Main script for ARDA.""" 2 | 3 | from core import test, train 4 | from misc import params 5 | from misc.utils import get_data_loader, init_model, init_random_seed 6 | from models import Classifier, Discriminator, Generator 7 | 8 | if __name__ == '__main__': 9 | # init random seed 10 | init_random_seed(params.manual_seed) 11 | 12 | # load dataset 13 | src_data_loader = get_data_loader(params.src_dataset) 14 | src_data_loader_test = get_data_loader(params.src_dataset, train=False) 15 | tgt_data_loader = get_data_loader(params.tgt_dataset) 16 | tgt_data_loader_test = get_data_loader(params.tgt_dataset, train=False) 17 | 18 | # init models 19 | classifier = init_model(net=Classifier(), 20 | restore=params.c_model_restore) 21 | generator = init_model(net=Generator(), 22 | restore=params.g_model_restore) 23 | critic = init_model(net=Discriminator(input_dims=params.d_input_dims, 24 | hidden_dims=params.d_hidden_dims, 25 | output_dims=params.d_output_dims), 26 | restore=params.d_model_restore) 27 | 28 | # train models 29 | print("=== Training models ===") 30 | print(">>> Classifier <<<") 31 | print(classifier) 32 | print(">>> Generator <<<") 33 | print(generator) 34 | print(">>> Critic <<<") 35 | print(critic) 36 | 37 | if not (params.eval_only and classifier.restored and 38 | generator.restored and critic.restored): 39 | classifier, generator = train( 40 | classifier, generator, critic, src_data_loader, tgt_data_loader) 41 | 42 | # evaluate models on the held-out test splits 43 | print("=== Evaluating models ===") 44 | print(">>> on source domain <<<") 45 | test(classifier, generator, src_data_loader_test, params.src_dataset) 46 | print(">>> on target domain <<<") 47 | test(classifier, generator, tgt_data_loader_test, params.tgt_dataset) 48 | -------------------------------------------------------------------------------- /misc/params.py: -------------------------------------------------------------------------------- 1 | """Parameters for ARDA.""" 2 | 3 | # params for dataset and data loader 4 | data_root = "data" 5 | dataset_mean_value = 0.5 6 | dataset_std_value = 0.5 7 | dataset_mean = (dataset_mean_value, dataset_mean_value, dataset_mean_value) 8 | dataset_std = (dataset_std_value, dataset_std_value, dataset_std_value) 9 | batch_size = 50 10 | image_size = 64 11 | src_dataset = "MNIST" 12 | tgt_dataset = "USPS" 13 | 14 | # params for critic 15 | d_input_dims = 500 16 | d_hidden_dims = 500 17 | d_output_dims = 1 18 | d_steps = 5 19 | d_model_restore = None 20 | 21 | # params for generator 22 | g_model_restore = None 23 | 24 | # params for classifier 25 | c_model_restore = None 26 | 27 | # params for training network 28 | num_gpu = 1 29 | num_epochs = 20000 30 | log_step = 100 31 | save_step = 5000 32 | manual_seed = None 33 | model_root = "snapshots" 34 | eval_only = False 35 | 36 | # params for optimizing models 37 | learning_rate = 1e-4 38 | beta1 = 0.5 39 | beta2 = 0.9 40 | 41 | # params for WGAN and WGAN-GP 42 | use_gradient_penalty = True # quickly switch between WGAN and WGAN-GP 43 | penalty_lambda = 10 44 | 45 | # params for interaction of discriminative and transferable feature learning 46 | dc_lambda = 10 47 | -------------------------------------------------------------------------------- /misc/utils.py: -------------------------------------------------------------------------------- 1 | """Helpful functions for ARDA.""" 2 | 3 |
import os 4 | import random 5 | 6 | import torch 7 | import torch.backends.cudnn as cudnn 8 | import torch.optim as optim 9 | from torch.autograd import Variable, grad 10 | 11 | from datasets import get_mnist, get_usps 12 | from misc import params 13 | 14 | 15 | def make_variable(tensor, volatile=False): 16 | """Convert Tensor to Variable.""" 17 | if torch.cuda.is_available(): 18 | tensor = tensor.cuda() 19 | return Variable(tensor, volatile=volatile) 20 | 21 | 22 | def make_cuda(tensor): 23 | """Use CUDA if it's available.""" 24 | if torch.cuda.is_available(): 25 | tensor = tensor.cuda() 26 | return tensor 27 | 28 | 29 | def denormalize(x, std, mean): 30 | """Invert normalization, and then convert array into image.""" 31 | out = x * std + mean 32 | return out.clamp(0, 1) 33 | 34 | 35 | def calc_gradient_penalty(D, real_data, fake_data): 36 | """Calculate gradient penalty for WGAN-GP.""" 37 | alpha = torch.rand(params.batch_size, 1) 38 | alpha = alpha.expand(real_data.size()) 39 | alpha = make_cuda(alpha) 40 | 41 | interpolates = make_variable(alpha * real_data + ((1 - alpha) * fake_data)) 42 | interpolates.requires_grad = True 43 | 44 | disc_interpolates = D(interpolates) 45 | 46 | gradients = grad(outputs=disc_interpolates, 47 | inputs=interpolates, 48 | grad_outputs=make_cuda( 49 | torch.ones(disc_interpolates.size())), 50 | create_graph=True, 51 | retain_graph=True, 52 | only_inputs=True)[0] 53 | 54 | gradient_penalty = params.penalty_lambda * \ 55 | ((gradients.norm(2, dim=1) - 1) ** 2).mean() 56 | 57 | return gradient_penalty 58 | 59 | 60 | def init_weights(layer): 61 | """Init weights for layers.""" 62 | layer_name = layer.__class__.__name__ 63 | if layer_name.find("Conv") != -1: 64 | layer.weight.data.normal_(0.0, 0.02) 65 | elif layer_name.find("BatchNorm") != -1: 66 | layer.weight.data.normal_(1.0, 0.02) 67 | layer.bias.data.fill_(0) 68 | 69 | 70 | def init_random_seed(manual_seed): 71 | """Init random seed.""" 72 | seed = None 73 | if manual_seed is None: 74 | seed = random.randint(1, 10000) 75 | else: 76 | seed = manual_seed 77 | print("using random seed: {}".format(seed)) 78 | random.seed(seed) 79 | torch.manual_seed(seed) 80 | if torch.cuda.is_available(): 81 | torch.cuda.manual_seed_all(seed) 82 | 83 | 84 | def init_model(net, restore): 85 | """Init models with cuda and weights.""" 86 | # init weights of model 87 | net.apply(init_weights) 88 | 89 | # restore model weights 90 | if restore is not None and os.path.exists(restore): 91 | net.load_state_dict(torch.load(restore)) 92 | net.restored = True 93 | print("Restore model from: {}".format(os.path.abspath(restore))) 94 | 95 | # check if cuda is available 96 | if torch.cuda.is_available(): 97 | cudnn.benchmark = True 98 | net.cuda() 99 | 100 | return net 101 | 102 | 103 | def save_model(net, filename): 104 | """Save trained model.""" 105 | if not os.path.exists(params.model_root): 106 | os.makedirs(params.model_root) 107 | torch.save(net.state_dict(), 108 | os.path.join(params.model_root, filename)) 109 | print("save trained model to: {}".format(os.path.join(params.model_root, 110 | filename))) 111 | 112 | 113 | def get_optimizer(net, name="Adam"): 114 | """Get optimizer by name.""" 115 | if name == "Adam": 116 | return optim.Adam(net.parameters(), 117 | lr=params.learning_rate, 118 | betas=(params.beta1, params.beta2)) 119 | 120 | 121 | def get_data_loader(name, train=True): 122 | """Get data loader by name.""" 123 | if name == "MNIST": 124 | return get_mnist(train) 125 | elif name == "USPS": 126 | return get_usps(train)
127 | 128 | 129 | def get_inf_iterator(data_loader): 130 | """Yield batches from a data loader indefinitely.""" 131 | while True: 132 | for images, labels in data_loader: 133 | yield (images, labels) 134 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .classifier import Classifier 2 | from .discriminator import Discriminator 3 | from .generator import Generator 4 | 5 | __all__ = ("Classifier", "Discriminator", "Generator") 6 | -------------------------------------------------------------------------------- /models/classifier.py: -------------------------------------------------------------------------------- 1 | """Classifier for ARDA. 2 | 3 | Guarantee that the learned domain-invariant representations are 4 | discriminative enough to accomplish the final classification task. 5 | """ 6 | 7 | import torch.nn.functional as F 8 | from torch import nn 9 | 10 | 11 | class Classifier(nn.Module): 12 | """LeNet classifier model for ARDA.""" 13 | 14 | def __init__(self): 15 | """Init LeNet classifier.""" 16 | super(Classifier, self).__init__() 17 | self.restored = False 18 | self.fc2 = nn.Linear(500, 10) 19 | 20 | def forward(self, feat): 21 | """Forward the LeNet classifier.""" 22 | out = F.dropout(F.relu(feat), training=self.training) 23 | out = self.fc2(out) 24 | return out 25 | -------------------------------------------------------------------------------- /models/discriminator.py: -------------------------------------------------------------------------------- 1 | """Discriminator for ARDA. 2 | 3 | Estimate the Wasserstein distance between the source and target 4 | representation distributions. 5 | """ 6 | 7 | 8 | from torch import nn 9 | 10 | 11 | class Discriminator(nn.Module): 12 | """Discriminator model.""" 13 | 14 | def __init__(self, input_dims, hidden_dims, output_dims): 15 | """Init discriminator.""" 16 | super(Discriminator, self).__init__() 17 | 18 | self.restored = False 19 | 20 | self.layer = nn.Sequential( 21 | nn.Linear(input_dims, hidden_dims), 22 | nn.ReLU(), 23 | nn.Linear(hidden_dims, hidden_dims), 24 | nn.ReLU(), 25 | nn.Linear(hidden_dims, output_dims) 26 | ) 27 | 28 | def forward(self, input): 29 | """Forward the discriminator.""" 30 | out = self.layer(input) 31 | return out.view(-1) 32 | -------------------------------------------------------------------------------- /models/generator.py: -------------------------------------------------------------------------------- 1 | """Generator for ARDA. 2 | 3 | Learn the domain-invariant feature representations from inputs across domains.
4 | """ 5 | 6 | from torch import nn 7 | 8 | 9 | class Generator(nn.Module): 10 | """LeNet encoder model for ARDA.""" 11 | 12 | def __init__(self): 13 | """Init LeNet encoder.""" 14 | super(Generator, self).__init__() 15 | 16 | self.restored = False 17 | 18 | self.encoder = nn.Sequential( 19 | # 1st conv block 20 | # input [1 x 28 x 28] 21 | # output [64 x 12 x 12] 22 | nn.Conv2d(1, 64, 5, 1, 0, bias=False), 23 | nn.MaxPool2d(2), 24 | nn.ReLU(), 25 | # 2nd conv block 26 | # input [64 x 12 x 12] 27 | # output [50 x 4 x 4] 28 | nn.Conv2d(64, 50, 5, 1, 0, bias=False), 29 | nn.Dropout2d(), 30 | nn.MaxPool2d(2), 31 | nn.ReLU() 32 | ) 33 | self.fc1 = nn.Linear(50 * 4 * 4, 500) 34 | 35 | def forward(self, input): 36 | """Forward the LeNet.""" 37 | conv_out = self.encoder(input) 38 | feat = self.fc1(conv_out.view(-1, 50 * 4 * 4)) 39 | return feat 40 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | --------------------------------------------------------------------------------