├── .gitignore
├── README.md
├── config.yaml
├── core
│   ├── __init__.py
│   ├── test.py
│   └── train.py
├── datasets
│   ├── __init__.py
│   ├── mnist.py
│   └── mnistm.py
├── main.py
├── models
│   ├── __init__.py
│   ├── functions.py
│   └── model.py
├── pytorch_DANN.ipynb
└── utils.py
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# Datasets
data/

# Models
save/

# Visualization
imgs/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# DANN

A PyTorch implementation of DANN (Domain-Adversarial Training of Neural Networks).

- [Unsupervised Domain Adaptation by Backpropagation](https://arxiv.org/abs/1409.7495)
- [Domain-Adversarial Training of Neural Networks](https://arxiv.org/abs/1505.07818)

![dann](https://user-images.githubusercontent.com/97284065/175561529-d2e836b6-deba-42bb-8b5f-ab8f3491c248.png)


## Environment

```
python 3.7
pytorch 1.11.0
torchvision 0.12.0
```


## Usage

To train the models, follow the instructions below.

> Train the models on source-only
> ```
> python main.py --source 'mnist' --target 'mnistm' --mode 'source-only' --train
> ```
> Train the models on DANN
> ```
> python main.py --source 'mnist' --target 'mnistm' --mode 'dann' --train
> ```


To test pretrained models without training, follow the instructions below.

> Test the models on source-only
> ```
> python main.py --source 'mnist' --target 'mnistm' --mode 'source-only' --extractor 'weights_filename' --classifier 'weights_filename'
> ```
> Test the models on DANN
> ```
> python main.py --source 'mnist' --target 'mnistm' --mode 'dann' --extractor 'weights_filename' --classifier 'weights_filename'
> ```


## Experiments

`MNIST → MNIST-M`
|             | Paper  | This repo |
| :---------: | :----: | :-------: |
| Source-Only | 0.5225 | 0.6195    |
| DANN        | 0.7666 | 0.8050    |

Each reported accuracy is the average of the five runs listed below.


### Details

`MNIST → MNIST-M`
|             | Test 1 | Test 2 | Test 3 | Test 4 | Test 5 |
| :---------: | ------ | ------ | ------ | ------ | ------ |
| Source-Only | 0.6160 | 0.6251 | 0.6162 | 0.6193 | 0.6208 |
| DANN        | 0.8205 | 0.7816 | 0.8035 | 0.8281 | 0.7911 |


## Visualizations
![visualizations](https://user-images.githubusercontent.com/97284065/175574285-ef19218e-6922-434f-bd06-4913390af4f7.png)


## Reference
- [https://github.com/fungtion/DANN](https://github.com/fungtion/DANN)
- [https://github.com/NaJaeMin92/pytorch_DANN](https://github.com/NaJaeMin92/pytorch_DANN)
--------------------------------------------------------------------------------
/config.yaml:
--------------------------------------------------------------------------------
# Parameters for datasets
root: './data'
img_size: 28
batch_size: 32

# Parameters for model architecture
extractor:
  in_channels: 3

classifier:
  in_features: 768
  out_features: 10

discriminator:
  in_features: 768
  out_features: 2

# Parameters for training
epochs: 100
save: './save'
momentum: 0.9
gamma: 10
lr:
  initial_lr: 0.01
  alpha: 10
  beta: 0.75

# Parameters for visualization
visual_root: './imgs'
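For intuition, here is a minimal sketch (not part of the repo) of how the `lr` and `gamma` entries above drive the two schedules implemented in `utils.py` and `core/train.py`: the learning rate is annealed as `initial_lr / (1 + alpha * p) ** beta`, and the gradient-reversal coefficient grows as `2 / (1 + exp(-gamma * p)) - 1`, where `p` is the training progress in `[0, 1]`:

```python
import numpy as np

# Values mirror config.yaml above
initial_lr, alpha, beta, gamma = 0.01, 10, 0.75, 10

for p in (0.0, 0.5, 1.0):
    lr = initial_lr / (1. + alpha * p) ** beta      # annealed learning rate (utils.py)
    lam = 2. / (1. + np.exp(-gamma * p)) - 1        # GRL coefficient (core/train.py)
    print('p={:.1f}  lr={:.5f}  lambda={:.4f}'.format(p, lr, lam))
```

The learning rate decays from 0.01 toward roughly 0.0017 over training, while the reversal coefficient ramps from 0 toward 1, so the adversarial signal from the domain discriminator is phased in gradually.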
--------------------------------------------------------------------------------
/core/__init__.py:
--------------------------------------------------------------------------------
from .test import test
from .train import source_only, dann


__all__ = ['source_only', 'dann', 'test']
--------------------------------------------------------------------------------
/core/test.py:
--------------------------------------------------------------------------------
import torch


def test(extractor, classifier, data_loader):
    """ Estimate the model performance """

    # Load the models to GPU
    extractor = extractor.cuda()
    classifier = classifier.cuda()

    # Set the models to evaluation mode
    extractor.eval()
    classifier.eval()

    num_data = 0
    total_acc = 0.0

    # Test
    with torch.no_grad():
        for images, labels in data_loader:
            # Load images and labels to GPU
            images = images.cuda()
            labels = labels.cuda()

            # Predict the labels
            preds = classifier(extractor(images))

            # Update the total accuracy
            num_data += len(images)
            total_acc += (preds.max(1)[1] == labels).sum().item()

    total_acc = total_acc / num_data

    print('Test Accuracy: {:.4f}%'.format(total_acc * 100))
--------------------------------------------------------------------------------
/core/train.py:
--------------------------------------------------------------------------------
import yaml
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim

from utils import optimizer_scheduler, save_model


def source_only(extractor, classifier, source_loader):
    """ Train the models using only the source dataset """

    # Get the parameters for training
    config = yaml.safe_load(open('config.yaml'))

    # Load the models to GPU
    extractor = extractor.cuda()
    classifier = classifier.cuda()

    # Set up criterion and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = optim.SGD(params=list(extractor.parameters()) + list(classifier.parameters()),
                          lr=config['lr']['initial_lr'],
                          momentum=config['momentum'])

    # Training
    print('\nSource-Only Training...\n')

    for epoch in range(config['epochs']):
        # Set the models to train mode
        extractor.train()
        classifier.train()

        num_data = 0
        total_acc = 0.0
        total_loss = 0.0

        for idx, (images, labels) in enumerate(source_loader):
            # Update the learning rate
            p = (idx + epoch * len(source_loader)) / config['epochs'] / len(source_loader)
            optimizer = optimizer_scheduler(optimizer, p)

            # Load images and labels to GPU
            images = images.cuda()
            labels = labels.cuda()

            # Predict labels and compute loss
            preds = classifier(extractor(images))
            loss = criterion(preds, labels)

            # Optimize the models
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Update total loss and total accuracy
            num_data += len(images)
            total_acc += (preds.max(1)[1] == labels).sum().item()
            total_loss += loss.item()

        total_acc = total_acc / num_data
        total_loss = total_loss / len(source_loader)

        # Print log information
        print('Epoch [{:4}/{:4}] Loss: {:8.4f}, Accuracy: {:.4f}%'.format(
            epoch+1, config['epochs'], total_loss, total_acc * 100
        ))

        # Save the model parameters
        if (epoch + 1) % 10 == 0:
            save_model(extractor, 'source_extractor_{}.pt'.format(epoch+1))
            save_model(classifier, 'source_classifier_{}.pt'.format(epoch+1))

    return extractor, classifier


def dann(extractor, classifier, discriminator, source_loader, target_loader):
    """ Train the models of DANN """

    # Get the parameters for training
    config = yaml.safe_load(open('config.yaml'))

    # Load the models to GPU
    extractor = extractor.cuda()
    classifier = classifier.cuda()
    discriminator = discriminator.cuda()

    # Set up criterions and optimizer
    cls_criterion = nn.CrossEntropyLoss().cuda()
    dis_criterion = nn.CrossEntropyLoss().cuda()

    optimizer = optim.SGD(params=list(extractor.parameters()) +
                                 list(classifier.parameters()) +
                                 list(discriminator.parameters()),
                          lr=config['lr']['initial_lr'],
                          momentum=config['momentum'])

    # Training
    print('\nDANN Training...\n')

    for epoch in range(config['epochs']):
        # Set the models to train mode
        extractor.train()
        classifier.train()
        discriminator.train()

        num_data = 0
        total_acc = 0.0
        total_loss = 0.0
        total_cls_loss = 0.0
        total_dis_loss = 0.0
        len_loader = min(len(source_loader), len(target_loader))

        for idx, (src_data, tgt_data) in enumerate(zip(source_loader, target_loader)):
            src_images, src_labels = src_data
            tgt_images, _ = tgt_data

            # Compute the alpha value and update the learning rate
            p = (idx + epoch * len_loader) / config['epochs'] / len_loader
            alpha = 2. / (1. + np.exp(-config['gamma'] * p)) - 1
            optimizer = optimizer_scheduler(optimizer, p)

            # Load images and labels to GPU
            src_images, src_labels = src_images.cuda(), src_labels.cuda()
            tgt_images = tgt_images.cuda()

            # Predict class labels and compute the classification loss
            cls_preds = classifier(extractor(src_images))
            cls_loss = cls_criterion(cls_preds, src_labels)

            # Update total classification loss and total classification accuracy
            num_data += len(src_images)
            total_acc += (cls_preds.max(1)[1] == src_labels).sum().item()
            total_cls_loss += cls_loss.item()

            # Make the domain labels (0: source, 1: target)
            domain_source_labels = torch.zeros(src_images.shape[0]).type(torch.LongTensor)
            domain_target_labels = torch.ones(tgt_images.shape[0]).type(torch.LongTensor)
            domain_labels = torch.cat([domain_source_labels, domain_target_labels], 0).cuda()
            combined_images = torch.cat([src_images, tgt_images], 0)

            # Predict domain labels and compute the discrimination loss
            dis_preds = discriminator(extractor(combined_images), alpha)
            dis_loss = dis_criterion(dis_preds, domain_labels)

            # Update total discrimination loss and total loss
            loss = cls_loss + dis_loss
            total_dis_loss += dis_loss.item()
            total_loss += loss.item()

            # Optimize the models
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        total_acc = total_acc / num_data
        total_loss = total_loss / len_loader
        total_cls_loss = total_cls_loss / len_loader
        total_dis_loss = total_dis_loss / len_loader

        # Print log information
        print('Epoch [{:4}/{:4}] Loss: {:8.4f}, Class Loss: {:8.4f}, Domain Loss: {:8.4f}, Accuracy: {:.4f}%'.format(
            epoch+1, config['epochs'], total_loss, total_cls_loss, total_dis_loss, total_acc * 100
        ))

        # Save the model parameters
        if (epoch + 1) % 10 == 0:
            save_model(extractor, 'dann_extractor_{}.pt'.format(epoch+1))
            save_model(classifier, 'dann_classifier_{}.pt'.format(epoch+1))

    return extractor, classifier
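For reference, a minimal driver equivalent to what `main.py` does in DANN mode (a sketch assuming a CUDA device and the default `config.yaml`, run from the repo root):

```python
from core import dann
from datasets import get_mnist, get_mnistm
from models import Extractor, Classifier, Discriminator

# Model sizes match the 'extractor', 'classifier' and 'discriminator'
# sections of config.yaml
extractor, classifier = dann(Extractor(in_channels=3),
                             Classifier(in_features=768, out_features=10),
                             Discriminator(in_features=768, out_features=2),
                             get_mnist(train=True),
                             get_mnistm(train=True))
```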
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
from .mnist import get_mnist
from .mnistm import get_mnistm


__all__ = ['get_mnist', 'get_mnistm']
--------------------------------------------------------------------------------
/datasets/mnist.py:
--------------------------------------------------------------------------------
import yaml
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


def get_mnist(train=True):
    """ Get the MNIST data loader """

    # Get the parameters for creating the data loader
    config = yaml.safe_load(open('config.yaml'))

    # Image pre-processing
    transform = transforms.Compose([transforms.Resize(config['img_size']),
                                    transforms.ToTensor(),
                                    transforms.Normalize((0.1307,), (0.3081,)),
                                    transforms.Lambda(lambda x: x.repeat(3, 1, 1))])

    # MNIST dataset
    mnist = datasets.MNIST(root=config['root'],
                           train=train,
                           download=True,
                           transform=transform)

    # MNIST data loader
    mnist_loader = DataLoader(dataset=mnist,
                              batch_size=config['batch_size'],
                              shuffle=True)

    return mnist_loader
--------------------------------------------------------------------------------
/datasets/mnistm.py:
--------------------------------------------------------------------------------
import os
import gzip
import pickle
import requests

import yaml
import torch
from PIL import Image
from torch.utils import data
from torchvision import datasets, transforms


class MNISTM(data.Dataset):
    """ The MNIST-M dataset class """

    url = 'https://github.com/VanushVaswani/keras_mnistm/releases/download/1.0/keras_mnistm.pkl.gz'

    def __init__(self, root, train=True, download=False, transform=None):

        super(MNISTM, self).__init__()
        self.root = root
        self.data_dir = 'MNISTM'
        self.raw_dir = 'raw'
        self.train = train
        self.transform = transform

        if download:
            self.download()

    def __getitem__(self, index):
        """ Get images and target for data loader """

        if self.train:
            image, target = self.train_images[index], self.train_labels[index]
        else:
            image, target = self.test_images[index], self.test_labels[index]

        image = Image.fromarray(image.squeeze().numpy(), mode='RGB')

        # Pre-processing
        if self.transform is not None:
            image = self.transform(image)

        return image, target

    def __len__(self):
        """ Return the size of the dataset """

        if self.train:
            return len(self.train_images)
        else:
            return len(self.test_images)

    def download(self):
        """ Download the MNIST-M data """

        # Make the data directory
        os.makedirs(os.path.join(self.root, self.data_dir, self.raw_dir), exist_ok=True)

        # Download the pkl file
        filename = self.url.split('/')[-1]
        filepath = os.path.join(self.root, self.data_dir, self.raw_dir, filename)

        if not os.path.exists(filepath):
            print('Downloading {}'.format(self.url))

            response = requests.get(self.url)
            open(filepath, 'wb').write(response.content)

        # Extract the pkl file from the gz file
        with open(filepath.replace('.gz', ''), 'wb') as f:
            f.write(gzip.open(filepath, 'rb').read())

        # Load MNIST-M images from the pkl file
        with open(filepath.replace('.gz', ''), 'rb') as f:
            mnistm_data = pickle.load(f, encoding='bytes')

        self.train_images = torch.ByteTensor(mnistm_data[b'train'])
        self.test_images = torch.ByteTensor(mnistm_data[b'test'])

        # Get MNIST-M labels from the MNIST dataset
        self.train_labels = datasets.MNIST(root=self.root,
                                           train=True,
                                           download=True).targets
        self.test_labels = datasets.MNIST(root=self.root,
                                          train=False,
                                          download=True).targets


def get_mnistm(train=True):
    """ Get the MNIST-M data loader """

    # Get the parameters for creating the data loader
    config = yaml.safe_load(open('config.yaml'))

    # Image pre-processing
    transform = transforms.Compose([transforms.Resize(config['img_size']),
                                    transforms.ToTensor(),
                                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    # MNIST-M dataset
    mnistm = MNISTM(root=config['root'],
                    train=train,
                    download=True,
                    transform=transform)

    # MNIST-M data loader
    mnistm_loader = data.DataLoader(dataset=mnistm,
                                    batch_size=config['batch_size'],
                                    shuffle=True)

    return mnistm_loader
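A quick sanity check of the two loaders (a sketch, not part of the repo; both datasets are downloaded on first use). MNIST is repeated to three channels and MNIST-M is natively RGB, so both domains feed the extractor the same input shape:

```python
from datasets import get_mnist, get_mnistm

images, labels = next(iter(get_mnist(train=True)))
print(images.shape, labels.shape)    # torch.Size([32, 3, 28, 28]) torch.Size([32])

images, labels = next(iter(get_mnistm(train=True)))
print(images.shape, labels.shape)    # torch.Size([32, 3, 28, 28]) torch.Size([32])
```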
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import os
import yaml
import argparse
import warnings

from utils import load_model, visualize
from core import source_only, dann, test
from datasets import get_mnist, get_mnistm
from models import Extractor, Classifier, Discriminator


MODE_MAP = {'source-only': 'Source-Only', 'dann': 'DANN'}
DATASETS_MAP = {'mnist': get_mnist, 'mnistm': get_mnistm}


# Ignore warnings
warnings.filterwarnings(action='ignore')


def get_args():
    """ Get the arguments for training and test """

    parser = argparse.ArgumentParser()

    parser.add_argument('--source', type=str, default='mnist', choices=DATASETS_MAP.keys(), help='Source dataset')
    parser.add_argument('--target', type=str, default='mnistm', choices=DATASETS_MAP.keys(), help='Target dataset')
    parser.add_argument('--mode', type=str, default='dann', choices=MODE_MAP.keys(), help='Training mode')
    parser.add_argument('--train', action='store_true', help='Train the models')
    parser.add_argument('--extractor', type=str, default=None, help='Extractor\'s weights file')
    parser.add_argument('--classifier', type=str, default=None, help='Classifier\'s weights file')

    args = parser.parse_args()
    return args


def main(args):
    """ The main function """

    # Get the parameters
    config = yaml.safe_load(open('config.yaml'))

    # Get the datasets
    train_source_loader = DATASETS_MAP[args.source](train=True)
    train_target_loader = DATASETS_MAP[args.target](train=True)
    test_source_loader = DATASETS_MAP[args.source](train=False)
    test_target_loader = DATASETS_MAP[args.target](train=False)

    # Get the models
    extractor = Extractor(**config['extractor'])
    classifier = Classifier(**config['classifier'])
    discriminator = Discriminator(**config['discriminator'])

    # Training
    if args.train:
        if args.mode == 'source-only':
            extractor, classifier = source_only(extractor, classifier, train_source_loader)
        else:
            extractor, classifier = dann(extractor, classifier, discriminator, train_source_loader, train_target_loader)

    # Load the models
    else:
        assert args.extractor is not None, 'If --train is not given, you have to provide the weights files.'
        assert args.classifier is not None, 'If --train is not given, you have to provide the weights files.'

        ext_filepath = os.path.join(config['save'], args.extractor)
        cls_filepath = os.path.join(config['save'], args.classifier)

        assert os.path.exists(ext_filepath), 'There is no {}'.format(ext_filepath)
        assert os.path.exists(cls_filepath), 'There is no {}'.format(cls_filepath)

        extractor = load_model(extractor, args.extractor)
        classifier = load_model(classifier, args.classifier)

    # Test
    print('\nTest Result with Source Datasets on {}\n'.format(MODE_MAP[args.mode]))
    test(extractor, classifier, test_source_loader)

    print('\nTest Result with Target Datasets on {}\n'.format(MODE_MAP[args.mode]))
    test(extractor, classifier, test_target_loader)

    # Visualization
    print('\nVisualizing...\n')
    visualize(extractor, test_source_loader, test_target_loader, MODE_MAP[args.mode] + '.png')

    print('Done!')


if __name__ == '__main__':
    args = get_args()
    main(args)
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
from .model import *
from .functions import ReverseLayerF


__all__ = ['Extractor', 'Classifier', 'Discriminator']
--------------------------------------------------------------------------------
/models/functions.py:
--------------------------------------------------------------------------------
from torch.autograd import Function


class ReverseLayerF(Function):
    """ The gradient reversal layer class """

    @staticmethod
    def forward(ctx, x, alpha):
        """ The method for forward propagation """

        ctx.alpha = alpha
        return x

    @staticmethod
    def backward(ctx, grad_output):
        """ The method for backpropagation """

        output = grad_output.neg() * ctx.alpha
        return output, None
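A minimal check (not part of the repo) that `ReverseLayerF` acts as an identity in the forward pass and negates-and-scales gradients in the backward pass:

```python
import torch
from models.functions import ReverseLayerF

x = torch.ones(3, requires_grad=True)
y = ReverseLayerF.apply(x, 0.5)   # forward: y == x
y.sum().backward()                # backward: grad is -alpha * upstream grad
print(x.grad)                     # tensor([-0.5000, -0.5000, -0.5000])
```

This is what lets a single optimizer minimize `cls_loss + dis_loss` in `core/train.py` while the extractor is effectively trained to *maximize* the domain-discrimination loss.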
--------------------------------------------------------------------------------
/models/model.py:
--------------------------------------------------------------------------------
import torch.nn as nn

from .functions import ReverseLayerF


class Extractor(nn.Module):
    """ The neural network class for extracting feature maps """

    def __init__(self, in_channels):

        super(Extractor, self).__init__()
        self.in_channels = in_channels

        self.extractor = nn.Sequential(
            nn.Conv2d(in_channels=self.in_channels, out_channels=32, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(in_channels=32, out_channels=48, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

    def forward(self, x):
        """ The method for forward propagation """

        x = self.extractor(x)
        x = x.view(x.shape[0], -1)
        return x


class Classifier(nn.Module):
    """ The neural network class for classifying labels """

    def __init__(self, in_features, out_features=10):

        super(Classifier, self).__init__()
        self.in_features = in_features
        self.out_features = out_features

        self.classifier = nn.Sequential(
            nn.Linear(in_features=self.in_features, out_features=100),
            nn.ReLU(),

            nn.Linear(in_features=100, out_features=100),
            nn.ReLU(),

            nn.Linear(in_features=100, out_features=self.out_features)
        )

    def forward(self, x):
        """ The method for forward propagation """

        x = self.classifier(x)
        return x


class Discriminator(nn.Module):
    """ The neural network class for discriminating the domain label """

    def __init__(self, in_features, out_features=2):

        super(Discriminator, self).__init__()
        self.in_features = in_features
        self.out_features = out_features

        self.discriminator = nn.Sequential(
            nn.Linear(in_features=self.in_features, out_features=100),
            nn.ReLU(),

            nn.Linear(in_features=100, out_features=self.out_features)
        )

    def forward(self, x, alpha):
        """ The method for forward propagation """

        x = ReverseLayerF.apply(x, alpha)
        x = self.discriminator(x)
        return x
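The `in_features: 768` in `config.yaml` follows from the extractor's shapes: a 28×28 input becomes 24×24 after the first 5×5 convolution, 12×12 after pooling, 8×8 after the second convolution, and 4×4 after the final pooling, so the flattened feature vector has 48 × 4 × 4 = 768 entries. A one-line check (a sketch, not part of the repo):

```python
import torch
from models import Extractor

print(Extractor(in_channels=3)(torch.randn(1, 3, 28, 28)).shape)  # torch.Size([1, 768])
```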
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import os
import yaml
import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE


def optimizer_scheduler(optimizer, p):
    """ Adjust the learning rate of the optimizer """

    # Get the parameters for adjusting the learning rate
    config = yaml.safe_load(open('config.yaml'))['lr']
    initial_lr = config['initial_lr']
    alpha = config['alpha']
    beta = config['beta']

    for param_group in optimizer.param_groups:
        param_group['lr'] = initial_lr / (1. + alpha * p) ** beta

    return optimizer


def save_model(model, filename):
    """ Save the model parameters """

    # Get the directory name
    root = yaml.safe_load(open('config.yaml'))['save']

    # Make the directory for saving model parameters
    if not os.path.exists(root):
        os.makedirs(root)

    # Save the model parameters
    torch.save(model.state_dict(), os.path.join(root, filename))


def load_model(model, filename):
    """ Load the model parameters """

    # Get the directory name
    root = yaml.safe_load(open('config.yaml'))['save']
    filepath = os.path.join(root, filename)

    assert os.path.exists(filepath), 'There is no {}.'.format(filepath)

    # Load the model parameters
    model.load_state_dict(torch.load(filepath))

    return model


def _plot_graph(features, labels, domain, filename):
    """ Plot the t-SNE graph """

    # Make the visualization directory
    root = yaml.safe_load(open('config.yaml'))['visual_root']

    if not os.path.exists(root):
        os.mkdir(root)

    # Rescale the features to the [0, 1] range
    feat_max, feat_min = np.max(features, 0), np.min(features, 0)
    features = (features - feat_min) / (feat_max - feat_min)

    # Plotting
    color = {0: 'r', 1: 'b'}

    plt.figure(figsize=(10, 10))
    plt.title(filename.split('.')[0], fontsize=20)

    for i in range(features.shape[0]):
        plt.text(features[i][0], features[i][1],
                 str(labels[i]),
                 color=color[domain[i]],
                 fontdict={'weight': 'bold', 'size': 9})

    plt.xticks([])
    plt.yticks([])
    plt.xlim(-0.05, 1.05)
    plt.ylim(-0.05, 1.05)
    plt.tight_layout()
    plt.savefig(os.path.join(root, filename))


def visualize(extractor, source_loader, target_loader, filename):
    """ Visualize the data distribution using t-SNE """

    images = []
    labels = []
    domain = []

    # Get some samples from the data loaders
    for idx, (src_data, tgt_data) in enumerate(zip(source_loader, target_loader)):
        if idx >= 15:
            break

        images.extend(src_data[0].tolist())
        images.extend(tgt_data[0].tolist())

        labels.extend(src_data[1].tolist())
        labels.extend(tgt_data[1].tolist())

        domain.extend([0] * src_data[0].shape[0])
        domain.extend([1] * tgt_data[0].shape[0])

    # Load the model and images to GPU
    extractor = extractor.cuda()
    images = torch.tensor(images).cuda()

    # Extract the feature maps (no gradients needed here)
    with torch.no_grad():
        features = extractor(images)

    # Reduce the feature dimensions
    tsne = TSNE(n_components=2, perplexity=30., n_iter=3000, init='pca')
    features = tsne.fit_transform(features.cpu().numpy())

    # Plotting
    _plot_graph(features, labels, domain, filename)
--------------------------------------------------------------------------------