├── .DS_Store ├── README.md ├── create_mnistm.py ├── main.py ├── mnist.py ├── mnistm.py ├── model.py ├── params.py ├── test.py ├── train.py └── utils.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NaJaeMin92/pytorch-DANN/a70cf59fce69957751bce1e7c1271382e2dee934/.DS_Store -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DANN-PyTorch :fire: 2 | PyTorch implementation of DANN (Domain-Adversarial Training of Neural Networks) 3 | 4 | > **[Unsupervised Domain Adaptation by Backpropagation](http://sites.skoltech.ru/compvision/projects/grl/files/paper.pdf)**
5 | > Yaroslav Ganin, Victor Lempitsky
6 | > *In PMLR-2015* 7 | 8 | > **[Domain-Adversarial Training of Neural Networks](http://jmlr.org/papers/volume17/15-239/15-239.pdf)**
9 | > Yaroslav Ganin et al.
10 | > *In JMLR-2016* 11 | 12 | 13 | ## Getting started 14 | 15 | ### Installation 16 | Install library versions that are compatible with your environment. 17 | ```bash 18 | git clone https://github.com/NaJaeMin92/pytorch-DANN.git 19 | cd pytorch-DANN 20 | conda create -n dann python=3.7 21 | conda activate dann 22 | pip install -r requirements.txt 23 | 24 | ``` 25 | 26 | ### Recommended configuration 27 | 28 | ``` 29 | python=3.7 30 | pytorch=1.12.1 31 | matplotlib=3.2.2 32 | sklearn=1.0.2 33 | ``` 34 | 35 | ### Usages 36 | Running the code below will execute both `source-only` and `DANN` training and testing: 37 | ``` 38 | python main.py 39 | # You can adjust training settings in 'params.py', including batch size and the number of training epochs. 40 | ``` 41 | 42 | ### t-SNE (t-distributed Stochastic Neighbor Embedding) 43 | Our code includes the functionality to visualize `t-SNE`, both before and after the process of domain adaptation using `sklearn.manifold`. 44 | 45 | ## Experimental results 46 | `MNIST -> MNIST-M` 47 | 48 | | Method | Test #1 | Test #2 | Test #3 | Test #4 | Test #5 | Avg. | 49 | | :-------------------------: | :-------: | :-------: | :-------: | :-------: | :---------: | :---------: | 50 | | Source Accuracy | 89 | 98 | 98 | 90 | 98 | **94.6** | 51 | | Target Accuracy | 47 | 56 | 54 | 46 | 53 | **51.2** | 52 | 53 | DANN 54 | | Method | Test #1 | Test #2 | Test #3 | Test #4 | Test #5 | Avg. 
| 55 | | :-------------------------: | :-------: | :-------: | :-------: | :-------: | :---------: | :---------: | 56 | | Source Accuracy | 96 | 96 | 97 | 97 | 96 | **96.4** | 57 | | Target Accuracy | 83 | 78 | 80 | 80 | 78 | **79.8** | 58 | | Domain Accuracy | 60 | 60 | 61 | 64 | 61 | **61.2** | 59 | -------------------------------------------------------------------------------- /create_mnistm.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import tarfile 6 | import os 7 | import pickle as pkl 8 | import numpy as np 9 | import skimage 10 | import skimage.io 11 | import skimage.transform 12 | from tensorflow.examples.tutorials.mnist import input_data 13 | 14 | mnist = input_data.read_data_sets('MNIST_data') 15 | 16 | BST_PATH = 'MNIST_data/BSR_bsds500.tgz' 17 | 18 | rand = np.random.RandomState(42) 19 | 20 | f = tarfile.open(BST_PATH) 21 | train_files = [] 22 | for name in f.getnames(): 23 | if name.startswith('BSR/BSDS500/data/images/train/'): 24 | train_files.append(name) 25 | 26 | print('Loading BSR training images') 27 | background_data = [] 28 | for name in train_files: 29 | try: 30 | fp = f.extractfile(name) 31 | bg_img = skimage.io.imread(fp) 32 | background_data.append(bg_img) 33 | except: 34 | continue 35 | 36 | 37 | def compose_image(digit, background): 38 | """Difference-blend a digit and a random patch from a background image.""" 39 | w, h, _ = background.shape 40 | dw, dh, _ = digit.shape 41 | x = np.random.randint(0, w - dw) 42 | y = np.random.randint(0, h - dh) 43 | 44 | bg = background[x:x + dw, y:y + dh] 45 | return np.abs(bg - digit).astype(np.uint8) 46 | 47 | 48 | def mnist_to_img(x): 49 | """Binarize MNIST digit and convert to RGB.""" 50 | x = (x > 0).astype(np.float32) 51 | d = x.reshape([28, 28, 1]) * 255 52 | return np.concatenate([d, d, d], 2) 53 | 54 | 55 | def 
create_mnistm(X): 56 | """ 57 | Give an array of MNIST digits, blend random background patches to 58 | build the MNIST-M data as described in 59 | http://jmlr.org/papers/volume17/15-239/15-239.pdf 60 | """ 61 | X_ = np.zeros([X.shape[0], 28, 28, 3], np.uint8) 62 | for i in range(X.shape[0]): 63 | 64 | if i % 1000 == 0: 65 | print('Processing example', i) 66 | 67 | bg_img = rand.choice(background_data) 68 | 69 | d = mnist_to_img(X[i]) 70 | d = compose_image(d, bg_img) 71 | X_[i] = d 72 | 73 | return X_ 74 | 75 | 76 | print('Building train set...') 77 | train = create_mnistm(mnist.train.images) 78 | print('Building test set...') 79 | test = create_mnistm(mnist.test.images) 80 | print('Building validation set...') 81 | valid = create_mnistm(mnist.validation.images) 82 | 83 | # Save data as pickle 84 | with open('MNIST_data/mnistm_data.pkl', 'wb') as f: 85 | pkl.dump({'train': train, 'test': test, 'valid': valid}, f, pkl.HIGHEST_PROTOCOL) -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import train 3 | import mnist 4 | import mnistm 5 | import model 6 | 7 | 8 | def main(): 9 | source_train_loader = mnist.mnist_train_loader 10 | target_train_loader = mnistm.mnistm_train_loader 11 | 12 | if torch.cuda.is_available(): 13 | encoder = model.Extractor().cuda() 14 | classifier = model.Classifier().cuda() 15 | discriminator = model.Discriminator().cuda() 16 | 17 | train.source_only(encoder, classifier, source_train_loader, target_train_loader) 18 | train.dann(encoder, classifier, discriminator, source_train_loader, target_train_loader) 19 | else: 20 | print("No GPUs available.") 21 | 22 | 23 | if __name__ == "__main__": 24 | main() 25 | -------------------------------------------------------------------------------- /mnist.py: -------------------------------------------------------------------------------- 1 | import 
torchvision.datasets as datasets 2 | from torch.utils.data import SubsetRandomSampler, DataLoader 3 | from torchvision import transforms 4 | import torch 5 | import params 6 | 7 | transform = transforms.Compose([transforms.ToTensor(), 8 | transforms.Normalize((0.1307,), (0.3081,)) 9 | ]) 10 | 11 | mnist_train_dataset = datasets.MNIST(root='../data/MNIST', train=True, download=True, 12 | transform=transform) 13 | mnist_valid_dataset = datasets.MNIST(root='../data/MNIST', train=True, download=True, 14 | transform=transform) 15 | mnist_test_dataset = datasets.MNIST(root='../data/MNIST', train=False, transform=transform) 16 | 17 | indices = list(range(len(mnist_train_dataset))) 18 | validation_size = 5000 19 | train_idx, valid_idx = indices[validation_size:], indices[:validation_size] 20 | train_sampler = SubsetRandomSampler(train_idx) 21 | valid_sampler = SubsetRandomSampler(valid_idx) 22 | 23 | mnist_train_loader = DataLoader( 24 | mnist_train_dataset, 25 | batch_size=params.batch_size, 26 | sampler=train_sampler, 27 | num_workers=params.num_workers 28 | ) 29 | 30 | mnist_valid_loader = DataLoader( 31 | mnist_valid_dataset, 32 | batch_size=params.batch_size, 33 | sampler=train_sampler, 34 | num_workers=params.num_workers 35 | ) 36 | 37 | mnist_test_loader = DataLoader( 38 | mnist_test_dataset, 39 | batch_size=params.batch_size, 40 | num_workers=params.num_workers 41 | ) 42 | 43 | 44 | def one_hot_embedding(labels, num_classes=10): 45 | """Embedding labels to one-hot form. 46 | 47 | Args: 48 | labels: (LongTensor) class labels, sized [N,]. 49 | num_classes: (int) number of classes. 50 | 51 | Returns: 52 | (tensor) encoded labels, sized [N, #classes]. 
53 | """ 54 | y = torch.eye(num_classes) 55 | return y[labels] 56 | -------------------------------------------------------------------------------- /mnistm.py: -------------------------------------------------------------------------------- 1 | import torchvision.datasets as datasets 2 | from torch.utils.data import SubsetRandomSampler, DataLoader 3 | from torchvision import transforms 4 | import torch.utils.data as data 5 | import torch 6 | import os 7 | import errno 8 | from PIL import Image 9 | import params 10 | 11 | 12 | # MNIST-M 13 | class MNISTM(data.Dataset): 14 | """`MNIST-M Dataset.""" 15 | 16 | url = "https://github.com/VanushVaswani/keras_mnistm/releases/download/1.0/keras_mnistm.pkl.gz" 17 | 18 | raw_folder = 'raw' 19 | processed_folder = 'processed' 20 | training_file = 'mnist_m_train.pt' 21 | test_file = 'mnist_m_test.pt' 22 | 23 | def __init__(self, 24 | root, mnist_root="data", 25 | train=True, 26 | transform=None, target_transform=None, 27 | download=False): 28 | """Init MNIST-M dataset.""" 29 | super(MNISTM, self).__init__() 30 | self.root = os.path.expanduser(root) 31 | self.mnist_root = os.path.expanduser(mnist_root) 32 | self.transform = transform 33 | self.target_transform = target_transform 34 | self.train = train # training set or test set 35 | 36 | if download: 37 | self.download() 38 | 39 | if not self._check_exists(): 40 | raise RuntimeError('Dataset not found.' + 41 | ' You can use download=True to download it') 42 | 43 | if self.train: 44 | self.train_data, self.train_labels = \ 45 | torch.load(os.path.join(self.root, 46 | self.processed_folder, 47 | self.training_file)) 48 | else: 49 | self.test_data, self.test_labels = \ 50 | torch.load(os.path.join(self.root, 51 | self.processed_folder, 52 | self.test_file)) 53 | 54 | def __getitem__(self, index): 55 | """Get images and target for data loader. 56 | 57 | Args: 58 | index (int): Index 59 | 60 | Returns: 61 | tuple: (image, target) where target is index of the target class. 
62 | """ 63 | if self.train: 64 | img, target = self.train_data[index], self.train_labels[index] 65 | else: 66 | img, target = self.test_data[index], self.test_labels[index] 67 | 68 | # doing this so that it is consistent with all other datasets 69 | # to return a PIL Image 70 | # print(type(img)) 71 | img = Image.fromarray(img.squeeze().numpy(), mode='RGB') 72 | 73 | if self.transform is not None: 74 | img = self.transform(img) 75 | 76 | if self.target_transform is not None: 77 | target = self.target_transform(target) 78 | 79 | return img, target 80 | 81 | def __len__(self): 82 | """Return size of dataset.""" 83 | if self.train: 84 | return len(self.train_data) 85 | else: 86 | return len(self.test_data) 87 | 88 | def _check_exists(self): 89 | return os.path.exists(os.path.join(self.root, 90 | self.processed_folder, 91 | self.training_file)) and \ 92 | os.path.exists(os.path.join(self.root, 93 | self.processed_folder, 94 | self.test_file)) 95 | 96 | def download(self): 97 | """Download the MNIST data.""" 98 | # import essential packages 99 | from six.moves import urllib 100 | import gzip 101 | import pickle 102 | from torchvision import datasets 103 | 104 | # check if dataset already exists 105 | if self._check_exists(): 106 | return 107 | 108 | # make data dirs 109 | try: 110 | os.makedirs(os.path.join(self.root, self.raw_folder)) 111 | os.makedirs(os.path.join(self.root, self.processed_folder)) 112 | except OSError as e: 113 | if e.errno == errno.EEXIST: 114 | pass 115 | else: 116 | raise 117 | 118 | # download pkl files 119 | print('Downloading ' + self.url) 120 | filename = self.url.rpartition('/')[2] 121 | file_path = os.path.join(self.root, self.raw_folder, filename) 122 | if not os.path.exists(file_path.replace('.gz', '')): 123 | data = urllib.request.urlopen(self.url) 124 | with open(file_path, 'wb') as f: 125 | f.write(data.read()) 126 | with open(file_path.replace('.gz', ''), 'wb') as out_f, \ 127 | gzip.GzipFile(file_path) as zip_f: 128 | 
out_f.write(zip_f.read()) 129 | os.unlink(file_path) 130 | 131 | # process and save as torch files 132 | print('Processing...') 133 | 134 | # load MNIST-M images from pkl file 135 | with open(file_path.replace('.gz', ''), "rb") as f: 136 | mnist_m_data = pickle.load(f, encoding='bytes') 137 | mnist_m_train_data = torch.ByteTensor(mnist_m_data[b'train']) 138 | mnist_m_test_data = torch.ByteTensor(mnist_m_data[b'test']) 139 | 140 | # get MNIST labels 141 | mnist_train_labels = datasets.MNIST(root=self.mnist_root, 142 | train=True, 143 | download=True).train_labels 144 | mnist_test_labels = datasets.MNIST(root=self.mnist_root, 145 | train=False, 146 | download=True).test_labels 147 | 148 | # save MNIST-M dataset 149 | training_set = (mnist_m_train_data, mnist_train_labels) 150 | test_set = (mnist_m_test_data, mnist_test_labels) 151 | with open(os.path.join(self.root, 152 | self.processed_folder, 153 | self.training_file), 'wb') as f: 154 | torch.save(training_set, f) 155 | with open(os.path.join(self.root, 156 | self.processed_folder, 157 | self.test_file), 'wb') as f: 158 | torch.save(test_set, f) 159 | 160 | print('MNISTM Done!') 161 | 162 | 163 | transform = transforms.Compose([transforms.ToTensor(), 164 | transforms.Normalize((0.29730626, 0.29918741, 0.27534935), 165 | (0.32780124, 0.32292358, 0.32056796)) 166 | ]) 167 | 168 | mnistm_train_dataset = MNISTM(root='../data/MNIST-M', train=True, download=True, 169 | transform=transform) 170 | mnistm_valid_dataset = MNISTM(root='../data/MNIST-M', train=True, download=True, 171 | transform=transform) 172 | mnistm_test_dataset = MNISTM(root='../data/MNIST-M', train=False, transform=transform) 173 | 174 | indices = list(range(len(mnistm_train_dataset))) 175 | validation_size = 5000 176 | train_idx, valid_idx = indices[validation_size:], indices[:validation_size] 177 | train_sampler = SubsetRandomSampler(train_idx) 178 | valid_sampler = SubsetRandomSampler(valid_idx) 179 | 180 | mnistm_train_loader = DataLoader( 181 | 
mnistm_train_dataset, 182 | batch_size=params.batch_size, 183 | sampler=train_sampler, 184 | num_workers=params.num_workers 185 | ) 186 | 187 | mnistm_valid_loader = DataLoader( 188 | mnistm_valid_dataset, 189 | batch_size=params.batch_size, 190 | sampler=train_sampler, 191 | num_workers=params.num_workers 192 | ) 193 | 194 | mnistm_test_loader = DataLoader( 195 | mnistm_test_dataset, 196 | batch_size=params.batch_size, 197 | num_workers=params.num_workers 198 | ) 199 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from utils import ReverseLayerF 4 | 5 | 6 | class Extractor(nn.Module): 7 | def __init__(self): 8 | super(Extractor, self).__init__() 9 | self.extractor = nn.Sequential( 10 | nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, padding=2), 11 | nn.ReLU(), 12 | nn.MaxPool2d(kernel_size=2), 13 | 14 | nn.Conv2d(in_channels=32, out_channels=48, kernel_size=5, padding=2), 15 | nn.ReLU(), 16 | nn.MaxPool2d(kernel_size=2) 17 | ) 18 | 19 | def forward(self, x): 20 | x = self.extractor(x) 21 | x = x.view(-1, 3 * 28 * 28) 22 | return x 23 | 24 | 25 | class Classifier(nn.Module): 26 | def __init__(self): 27 | super(Classifier, self).__init__() 28 | self.classifier = nn.Sequential( 29 | nn.Linear(in_features=3 * 28 * 28, out_features=100), 30 | nn.ReLU(), 31 | nn.Linear(in_features=100, out_features=100), 32 | nn.ReLU(), 33 | nn.Linear(in_features=100, out_features=10) 34 | ) 35 | 36 | def forward(self, x): 37 | x = self.classifier(x) 38 | return x 39 | 40 | 41 | class Discriminator(nn.Module): 42 | def __init__(self): 43 | super(Discriminator, self).__init__() 44 | self.discriminator = nn.Sequential( 45 | nn.Linear(in_features=3 * 28 * 28, out_features=100), 46 | nn.ReLU(), 47 | nn.Linear(in_features=100, out_features=2) 48 | ) 49 | 50 | def forward(self, input_feature, 
alpha): 51 | reversed_input = ReverseLayerF.apply(input_feature, alpha) 52 | x = self.discriminator(reversed_input) 53 | return x 54 | -------------------------------------------------------------------------------- /params.py: -------------------------------------------------------------------------------- 1 | batch_size = 32 2 | epochs = 100 3 | num_workers = 4 -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from model import Discriminator 5 | from utils import set_model_mode 6 | 7 | 8 | def tester(encoder, classifier, discriminator, source_test_loader, target_test_loader, training_mode): 9 | encoder.cuda() 10 | classifier.cuda() 11 | set_model_mode('eval', [encoder, classifier]) 12 | 13 | if training_mode == 'DANN': 14 | discriminator.cuda() 15 | set_model_mode('eval', [discriminator]) 16 | domain_correct = 0 17 | 18 | source_correct = 0 19 | target_correct = 0 20 | 21 | for batch_idx, (source_data, target_data) in enumerate(zip(source_test_loader, target_test_loader)): 22 | p = float(batch_idx) / len(source_test_loader) 23 | alpha = 2. / (1. 
+ np.exp(-10 * p)) - 1 24 | 25 | # Process source and target data 26 | source_image, source_label = process_data(source_data, expand_channels=True) 27 | target_image, target_label = process_data(target_data) 28 | 29 | # Compute source and target predictions 30 | source_pred = compute_output(encoder, classifier, source_image, alpha=None) 31 | target_pred = compute_output(encoder, classifier, target_image, alpha=None) 32 | 33 | # Update correct counts 34 | source_correct += source_pred.eq(source_label.data.view_as(source_pred)).sum().item() 35 | target_correct += target_pred.eq(target_label.data.view_as(target_pred)).sum().item() 36 | 37 | if training_mode == 'DANN': 38 | # Process combined images for domain classification 39 | combined_image = torch.cat((source_image, target_image), 0) 40 | domain_labels = torch.cat((torch.zeros(source_label.size(0), dtype=torch.long), 41 | torch.ones(target_label.size(0), dtype=torch.long)), 0).cuda() 42 | 43 | # Compute domain predictions 44 | domain_pred = compute_output(encoder, discriminator, combined_image, alpha=alpha) 45 | domain_correct += domain_pred.eq(domain_labels.data.view_as(domain_pred)).sum().item() 46 | 47 | source_dataset_len = len(source_test_loader.dataset) 48 | target_dataset_len = len(target_test_loader.dataset) 49 | 50 | accuracies = { 51 | "Source": { 52 | "correct": source_correct, 53 | "total": source_dataset_len, 54 | "accuracy": calculate_accuracy(source_correct, source_dataset_len) 55 | }, 56 | "Target": { 57 | "correct": target_correct, 58 | "total": target_dataset_len, 59 | "accuracy": calculate_accuracy(target_correct, target_dataset_len) 60 | } 61 | } 62 | 63 | if training_mode == 'DANN': 64 | accuracies["Domain"] = { 65 | "correct": domain_correct, 66 | "total": source_dataset_len + target_dataset_len, 67 | "accuracy": calculate_accuracy(domain_correct, source_dataset_len + target_dataset_len) 68 | } 69 | 70 | print_accuracy(training_mode, accuracies) 71 | 72 | 73 | def process_data(data, 
expand_channels=False): 74 | images, labels = data 75 | images, labels = images.cuda(), labels.cuda() 76 | if expand_channels: 77 | images = images.repeat(1, 3, 1, 1) # Repeat channels to convert to 3-channel images 78 | return images, labels 79 | 80 | 81 | def compute_output(encoder, classifier, images, alpha=None): 82 | features = encoder(images) 83 | if isinstance(classifier, Discriminator): 84 | outputs = classifier(features, alpha) # Domain classifier 85 | else: 86 | outputs = classifier(features) # Category classifier 87 | preds = outputs.data.max(1, keepdim=True)[1] 88 | return preds 89 | 90 | 91 | def calculate_accuracy(correct, total): 92 | return 100. * correct / total 93 | 94 | 95 | def print_accuracy(training_mode, accuracies): 96 | print(f"Test Results on {training_mode}:") 97 | for key, value in accuracies.items(): 98 | print(f"{key} Accuracy: {value['correct']}/{value['total']} ({value['accuracy']:.2f}%)") 99 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import utils 4 | import torch.optim as optim 5 | import torch.nn as nn 6 | import test 7 | import mnist 8 | import mnistm 9 | from utils import save_model 10 | from utils import visualize 11 | from utils import set_model_mode 12 | import params 13 | 14 | # Source : 0, Target :1 15 | source_test_loader = mnist.mnist_test_loader 16 | target_test_loader = mnistm.mnistm_test_loader 17 | 18 | 19 | def source_only(encoder, classifier, source_train_loader, target_train_loader): 20 | print("Training with only the source dataset") 21 | 22 | classifier_criterion = nn.CrossEntropyLoss().cuda() 23 | optimizer = optim.SGD( 24 | list(encoder.parameters()) + 25 | list(classifier.parameters()), 26 | lr=0.01, momentum=0.9) 27 | 28 | for epoch in range(params.epochs): 29 | print(f"Epoch: {epoch}") 30 | set_model_mode('train', [encoder, classifier]) 
31 | 32 | start_steps = epoch * len(source_train_loader) 33 | total_steps = params.epochs * len(target_train_loader) 34 | 35 | for batch_idx, (source_data, target_data) in enumerate(zip(source_train_loader, target_train_loader)): 36 | source_image, source_label = source_data 37 | p = float(batch_idx + start_steps) / total_steps 38 | 39 | source_image = torch.cat((source_image, source_image, source_image), 1) # MNIST convert to 3 channel 40 | source_image, source_label = source_image.cuda(), source_label.cuda() # 32 41 | 42 | optimizer = utils.optimizer_scheduler(optimizer=optimizer, p=p) 43 | optimizer.zero_grad() 44 | 45 | source_feature = encoder(source_image) 46 | 47 | # Classification loss 48 | class_pred = classifier(source_feature) 49 | class_loss = classifier_criterion(class_pred, source_label) 50 | 51 | class_loss.backward() 52 | optimizer.step() 53 | if (batch_idx + 1) % 100 == 0: 54 | total_processed = batch_idx * len(source_image) 55 | total_dataset = len(source_train_loader.dataset) 56 | percentage_completed = 100. 
* batch_idx / len(source_train_loader) 57 | print(f'[{total_processed}/{total_dataset} ({percentage_completed:.0f}%)]\tClassification Loss: {class_loss.item():.4f}') 58 | 59 | test.tester(encoder, classifier, None, source_test_loader, target_test_loader, training_mode='Source_only') 60 | 61 | save_model(encoder, classifier, None, 'Source-only') 62 | visualize(encoder, 'Source-only') 63 | 64 | 65 | def dann(encoder, classifier, discriminator, source_train_loader, target_train_loader): 66 | print("Training with the DANN adaptation method") 67 | 68 | classifier_criterion = nn.CrossEntropyLoss().cuda() 69 | discriminator_criterion = nn.CrossEntropyLoss().cuda() 70 | 71 | optimizer = optim.SGD( 72 | list(encoder.parameters()) + 73 | list(classifier.parameters()) + 74 | list(discriminator.parameters()), 75 | lr=0.01, 76 | momentum=0.9) 77 | 78 | for epoch in range(params.epochs): 79 | print(f"Epoch: {epoch}") 80 | set_model_mode('train', [encoder, classifier, discriminator]) 81 | 82 | start_steps = epoch * len(source_train_loader) 83 | total_steps = params.epochs * len(target_train_loader) 84 | 85 | for batch_idx, (source_data, target_data) in enumerate(zip(source_train_loader, target_train_loader)): 86 | 87 | source_image, source_label = source_data 88 | target_image, target_label = target_data 89 | 90 | p = float(batch_idx + start_steps) / total_steps 91 | alpha = 2. / (1. 
+ np.exp(-10 * p)) - 1 92 | 93 | source_image = torch.cat((source_image, source_image, source_image), 1) 94 | 95 | source_image, source_label = source_image.cuda(), source_label.cuda() 96 | target_image, target_label = target_image.cuda(), target_label.cuda() 97 | combined_image = torch.cat((source_image, target_image), 0) 98 | 99 | optimizer = utils.optimizer_scheduler(optimizer=optimizer, p=p) 100 | optimizer.zero_grad() 101 | 102 | combined_feature = encoder(combined_image) 103 | source_feature = encoder(source_image) 104 | 105 | # 1.Classification loss 106 | class_pred = classifier(source_feature) 107 | class_loss = classifier_criterion(class_pred, source_label) 108 | 109 | # 2. Domain loss 110 | domain_pred = discriminator(combined_feature, alpha) 111 | 112 | domain_source_labels = torch.zeros(source_label.shape[0]).type(torch.LongTensor) 113 | domain_target_labels = torch.ones(target_label.shape[0]).type(torch.LongTensor) 114 | domain_combined_label = torch.cat((domain_source_labels, domain_target_labels), 0).cuda() 115 | domain_loss = discriminator_criterion(domain_pred, domain_combined_label) 116 | 117 | total_loss = class_loss + domain_loss 118 | total_loss.backward() 119 | optimizer.step() 120 | 121 | if (batch_idx + 1) % 100 == 0: 122 | print('[{}/{} ({:.0f}%)]\tTotal Loss: {:.4f}\tClassification Loss: {:.4f}\tDomain Loss: {:.4f}'.format( 123 | batch_idx * len(target_image), len(target_train_loader.dataset), 100. 
* batch_idx / len(target_train_loader), total_loss.item(), class_loss.item(), domain_loss.item())) 124 | 125 | test.tester(encoder, classifier, discriminator, source_test_loader, target_test_loader, training_mode='DANN') 126 | 127 | save_model(encoder, classifier, discriminator, 'DANN') 128 | visualize(encoder, 'DANN') 129 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | from torch.autograd import Function 4 | from sklearn.manifold import TSNE 5 | import torch 6 | import mnist 7 | import mnistm 8 | import itertools 9 | import os 10 | 11 | 12 | class ReverseLayerF(Function): 13 | 14 | @staticmethod 15 | def forward(ctx, x, alpha): 16 | ctx.alpha = alpha 17 | 18 | return x.view_as(x) 19 | 20 | @staticmethod 21 | def backward(ctx, grad_output): 22 | output = grad_output.neg() * ctx.alpha 23 | 24 | return output, None 25 | 26 | 27 | def optimizer_scheduler(optimizer, p): 28 | """ 29 | Adjust the learning rate of optimizer 30 | :param optimizer: optimizer for updating parameters 31 | :param p: a variable for adjusting learning rate 32 | :return: optimizer 33 | """ 34 | for param_group in optimizer.param_groups: 35 | param_group['lr'] = 0.01 / (1. + 10 * p) ** 0.75 36 | 37 | return optimizer 38 | 39 | 40 | def one_hot_embedding(labels, num_classes=10): 41 | """Embedding labels to one-hot form. 42 | 43 | Args: 44 | labels: (LongTensor) class labels, sized [N,]. 45 | num_classes: (int) number of classes. 46 | 47 | Returns: 48 | (tensor) encoded labels, sized [N, #classes]. 
49 | """ 50 | y = torch.eye(num_classes) 51 | return y[labels] 52 | 53 | 54 | def save_model(encoder, classifier, discriminator, training_mode): 55 | print('Saving models ...') 56 | save_folder = 'trained_models' 57 | if not os.path.exists(save_folder): 58 | os.makedirs(save_folder) 59 | 60 | torch.save(encoder.state_dict(), 'trained_models/encoder_' + str(training_mode) + '.pt') 61 | torch.save(classifier.state_dict(), 'trained_models/classifier_' + str(training_mode) + '.pt') 62 | 63 | if training_mode == 'dann': 64 | torch.save(discriminator.state_dict(), 'trained_models/discriminator_' + str(training_mode) + '.pt') 65 | 66 | print('The model has been successfully saved!') 67 | 68 | 69 | def plot_embedding(X, y, d, training_mode): 70 | x_min, x_max = np.min(X, 0), np.max(X, 0) 71 | X = (X - x_min) / (x_max - x_min) 72 | y = list(itertools.chain.from_iterable(y)) 73 | y = np.asarray(y) 74 | 75 | plt.figure(figsize=(10, 10)) 76 | for i in range(len(d)): # X.shape[0] : 1024 77 | # plot colored number 78 | if d[i] == 0: 79 | colors = (0.0, 0.0, 1.0, 1.0) 80 | else: 81 | colors = (1.0, 0.0, 0.0, 1.0) 82 | plt.text(X[i, 0], X[i, 1], str(y[i]), 83 | color=colors, 84 | fontdict={'weight': 'bold', 'size': 9}) 85 | 86 | plt.xticks([]), plt.yticks([]) 87 | 88 | save_folder = 'saved_plot' 89 | if not os.path.exists(save_folder): 90 | os.makedirs(save_folder) 91 | 92 | fig_name = 'saved_plot/' + str(training_mode) + '.png' 93 | plt.savefig(fig_name) 94 | print('{} has been successfully saved!'.format(fig_name)) 95 | 96 | 97 | def visualize(encoder, training_mode): 98 | # Draw 512 samples in test_data 99 | source_test_loader = mnist.mnist_test_loader 100 | target_test_loader = mnistm.mnistm_test_loader 101 | 102 | # Get source_test samples 103 | source_label_list = [] 104 | source_img_list = [] 105 | for i, test_data in enumerate(source_test_loader): 106 | if i >= 16: # to get only 512 samples 107 | break 108 | img, label = test_data 109 | label = label.numpy() 110 | img = 
img.cuda() 111 | img = torch.cat((img, img, img), 1) # MNIST channel 1 -> 3 112 | source_label_list.append(label) 113 | source_img_list.append(img) 114 | 115 | source_img_list = torch.stack(source_img_list) 116 | source_img_list = source_img_list.view(-1, 3, 28, 28) 117 | 118 | # Get target_test samples 119 | target_label_list = [] 120 | target_img_list = [] 121 | for i, test_data in enumerate(target_test_loader): 122 | if i >= 16: 123 | break 124 | img, label = test_data 125 | label = label.numpy() 126 | img = img.cuda() 127 | target_label_list.append(label) 128 | target_img_list.append(img) 129 | 130 | target_img_list = torch.stack(target_img_list) 131 | target_img_list = target_img_list.view(-1, 3, 28, 28) 132 | 133 | # Stack source_list + target_list 134 | combined_label_list = source_label_list 135 | combined_label_list.extend(target_label_list) 136 | combined_img_list = torch.cat((source_img_list, target_img_list), 0) 137 | 138 | source_domain_list = torch.zeros(512).type(torch.LongTensor) 139 | target_domain_list = torch.ones(512).type(torch.LongTensor) 140 | combined_domain_list = torch.cat((source_domain_list, target_domain_list), 0).cuda() 141 | 142 | print("Extracting features to draw t-SNE plot...") 143 | combined_feature = encoder(combined_img_list) # combined_feature : 1024,2352 144 | 145 | tsne = TSNE(perplexity=30, n_components=2, init='pca', n_iter=3000) 146 | dann_tsne = tsne.fit_transform(combined_feature.detach().cpu().numpy()) 147 | 148 | print('Drawing t-SNE plot ...') 149 | plot_embedding(dann_tsne, combined_label_list, combined_domain_list, training_mode) 150 | 151 | 152 | def visualize_input(): 153 | source_test_loader = mnist.mnist_test_loader 154 | target_test_loader = mnistm.mnistm_test_loader 155 | 156 | # Get source_test samples 157 | source_label_list = [] 158 | source_img_list = [] 159 | for i, test_data in enumerate(source_test_loader): 160 | if i >= 16: # to get only 512 samples 161 | break 162 | img, label = test_data 163 | label 
= label.numpy() 164 | img = img.cuda() 165 | img = torch.cat((img, img, img), 1) # MNIST channel 1 -> 3 166 | source_label_list.append(label) 167 | source_img_list.append(img) 168 | 169 | source_img_list = torch.stack(source_img_list) 170 | source_img_list = source_img_list.view(-1, 3, 28, 28) 171 | 172 | # Get target_test samples 173 | target_label_list = [] 174 | target_img_list = [] 175 | for i, test_data in enumerate(target_test_loader): 176 | if i >= 16: 177 | break 178 | img, label = test_data 179 | label = label.numpy() 180 | img = img.cuda() 181 | target_label_list.append(label) 182 | target_img_list.append(img) 183 | 184 | target_img_list = torch.stack(target_img_list) 185 | target_img_list = target_img_list.view(-1, 3, 28, 28) 186 | 187 | # Stack source_list + target_list 188 | combined_label_list = source_label_list 189 | combined_label_list.extend(target_label_list) 190 | combined_img_list = torch.cat((source_img_list, target_img_list), 0) 191 | 192 | source_domain_list = torch.zeros(512).type(torch.LongTensor) 193 | target_domain_list = torch.ones(512).type(torch.LongTensor) 194 | combined_domain_list = torch.cat((source_domain_list, target_domain_list), 0).cuda() 195 | 196 | print("Extracting features to draw t-SNE plot...") 197 | combined_feature = combined_img_list # combined_feature : 1024,3,28,28 198 | combined_feature = combined_feature.view(1024, -1) # flatten 199 | # print(type(combined_feature), combined_feature.shape) 200 | 201 | tsne = TSNE(perplexity=30, n_components=2, init='pca', n_iter=3000) 202 | dann_tsne = tsne.fit_transform(combined_feature.detach().cpu().numpy()) 203 | print('Drawing t-SNE plot ...') 204 | plot_embedding(dann_tsne, combined_label_list, combined_domain_list, 'input') 205 | 206 | 207 | def set_model_mode(mode='train', models=None): 208 | for model in models: 209 | if mode == 'train': 210 | model.train() 211 | else: 212 | model.eval() 213 | --------------------------------------------------------------------------------