├── .gitignore
├── LICENSE
├── README.md
├── adda.py
├── config.py
├── data.py
├── models.py
├── revgrad.py
├── task.png
├── test_model.py
├── train_source.py
├── trained_models
│   └── .gitkeep
├── utils.py
└── wdgrl.py

/.gitignore:
--------------------------------------------------------------------------------
trained_models/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site
# mypy
.mypy_cache/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2018 Joris

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# PyTorch Adversarial Domain Adaptation
A collection of implementations of adversarial unsupervised domain adaptation algorithms.

## Domain adaptation
![](task.png)
The goal of domain adaptation is to transfer the knowledge of a model to a different but related data distribution.
The model is trained on a *source* dataset and applied to a *target* dataset (usually unlabeled).
In this case, the model is trained on regular MNIST images, but we want it to perform well on MNIST-M, a variant in which each digit is blended with a random 28x28 patch cut from a BSDS500 color image, without using any MNIST-M labels.

In adversarial domain adaptation, this problem is usually solved by training an auxiliary model called the domain discriminator. The goal of this model is to classify examples as coming from the source or the target distribution. The feature extractor of the original classifier is then trained to fool the domain discriminator (i.e. to maximize its loss), much like the generator in GAN training.

## Implemented papers
**Paper**: Unsupervised Domain Adaptation by Backpropagation, Ganin & Lempitsky (2014)
**Link**: [https://arxiv.org/abs/1409.7495](https://arxiv.org/abs/1409.7495)
**Description**: Places a gradient reversal layer between the feature extractor and the domain discriminator, so the discriminator learns to separate the domains while the feature extractor receives the negated gradient; both networks are trained simultaneously.
**Implementation**: [revgrad.py](https://github.com/jvanvugt/pytorch-domain-adaptation/blob/master/revgrad.py)

---

**Paper**: Adversarial Discriminative Domain Adaptation, Tzeng et al. (2017)
**Link**: [https://arxiv.org/abs/1702.05464](https://arxiv.org/abs/1702.05464)
**Description**: Starts from a classifier pretrained on source data and adapts a copy of its feature extractor so that a domain discriminator cannot tell its target features apart from the (frozen) source features.
**Implementation**: [adda.py](https://github.com/jvanvugt/pytorch-domain-adaptation/blob/master/adda.py)

---

**Paper**: Wasserstein Distance Guided Representation Learning, Shen et al. (2017)
**Link**: [https://arxiv.org/abs/1707.01217](https://arxiv.org/abs/1707.01217)
**Description**: Trains a domain critic, regularized with a gradient penalty, to estimate the Wasserstein distance between source and target features, and trains the feature extractor to minimize that distance together with the source classification loss.
**Implementation**: [wdgrl.py](https://github.com/jvanvugt/pytorch-domain-adaptation/blob/master/wdgrl.py)


## Results

Method      | Accuracy on MNIST-M | Parameters
------------|---------------------|--------------------------
Source only | 0.33                |
RevGrad     | 0.74                | default
ADDA        | 0.76                | default
WDGRL       | 0.78                | `--k-clf 10 --wd-clf 0.1`

## Instructions
1. Download the [BSDS500 dataset](https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/resources.html#bsds500) and extract it into a data directory, so that the images end up under `BSR/BSDS500/data/images` inside it. Point the `DATA_DIR` variable in `config.py` to that directory; MNIST is downloaded automatically to `DATA_DIR/mnist`.
2. In a Python 3.6 environment, run:
   ```
   $ conda install pytorch torchvision numpy -c pytorch
   $ pip install tqdm opencv-python
   ```
3. Train a model on the source dataset with
   ```
   $ python train_source.py
   ```
4. Choose an algorithm and pass it the pretrained network, for example:
   ```
   $ python adda.py trained_models/source.pt
   ```
5. Evaluate the adapted model on MNIST-M with `python test_model.py trained_models/adda.pt`.
--------------------------------------------------------------------------------
/adda.py:
--------------------------------------------------------------------------------
"""
Implements ADDA:
Adversarial Discriminative Domain Adaptation, Tzeng et al. (2017)
"""
import argparse

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor
from tqdm import tqdm, trange

import config
from data import MNISTM
from models import Net
from utils import loop_iterable, set_requires_grad, GrayscaleToRgb


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def main(args):
    # Frozen source model: provides the source features and, at the end, the classifier head
    source_model = Net().to(device)
    source_model.load_state_dict(torch.load(args.MODEL_FILE))
    source_model.eval()
    set_requires_grad(source_model, requires_grad=False)

    clf = source_model
    source_model = source_model.feature_extractor

    # Target feature extractor, initialized from the same pretrained weights
    target_model = Net().to(device)
    target_model.load_state_dict(torch.load(args.MODEL_FILE))
    target_model = target_model.feature_extractor

    discriminator = nn.Sequential(
        nn.Linear(320, 50),
        nn.ReLU(),
        nn.Linear(50, 20),
        nn.ReLU(),
        nn.Linear(20, 1)
    ).to(device)

    half_batch = args.batch_size // 2
    source_dataset = MNIST(config.DATA_DIR/'mnist', train=True, download=True,
                           transform=Compose([GrayscaleToRgb(), ToTensor()]))
    source_loader = DataLoader(source_dataset, batch_size=half_batch,
                               shuffle=True, num_workers=1, pin_memory=True)

    target_dataset = MNISTM(train=False)
    target_loader = DataLoader(target_dataset, batch_size=half_batch,
                               shuffle=True, num_workers=1, pin_memory=True)

    discriminator_optim = torch.optim.Adam(discriminator.parameters())
    target_optim = torch.optim.Adam(target_model.parameters())
    criterion = nn.BCEWithLogitsLoss()

    for epoch in range(1, args.epochs+1):
        batch_iterator = zip(loop_iterable(source_loader), loop_iterable(target_loader))

        total_loss = 0
        total_accuracy = 0
        for _ in trange(args.iterations, leave=False):
            # Train discriminator: source features -> 1, target features -> 0
            set_requires_grad(target_model, requires_grad=False)
            set_requires_grad(discriminator, requires_grad=True)
            for _ in range(args.k_disc):
                (source_x, _), (target_x, _) = next(batch_iterator)
                source_x, target_x = source_x.to(device), target_x.to(device)

                source_features = source_model(source_x).view(source_x.shape[0], -1)
                target_features = target_model(target_x).view(target_x.shape[0], -1)

                discriminator_x = torch.cat([source_features, target_features])
                discriminator_y = torch.cat([torch.ones(source_x.shape[0], device=device),
                                             torch.zeros(target_x.shape[0], device=device)])

                preds = discriminator(discriminator_x).squeeze()
                loss = criterion(preds, discriminator_y)

                discriminator_optim.zero_grad()
                loss.backward()
                discriminator_optim.step()

                total_loss += loss.item()
                total_accuracy += ((preds > 0).long() == discriminator_y.long()).float().mean().item()

            # Train the target feature extractor to fool the discriminator
            set_requires_grad(target_model, requires_grad=True)
            set_requires_grad(discriminator, requires_grad=False)
            for _ in range(args.k_clf):
                _, (target_x, _) = next(batch_iterator)
                target_x = target_x.to(device)
                target_features = target_model(target_x).view(target_x.shape[0], -1)

                # Flipped labels: target features are labeled as source (1)
                discriminator_y = torch.ones(target_x.shape[0], device=device)

                preds = discriminator(target_features).squeeze()
                loss = criterion(preds, discriminator_y)

                target_optim.zero_grad()
                loss.backward()
                target_optim.step()

        mean_loss = total_loss / (args.iterations*args.k_disc)
        mean_accuracy = total_accuracy / (args.iterations*args.k_disc)
        tqdm.write(f'EPOCH {epoch:03d}: discriminator_loss={mean_loss:.4f}, '
                   f'discriminator_accuracy={mean_accuracy:.4f}')

    # Create the full target model and save it
    clf.feature_extractor = target_model
    torch.save(clf.state_dict(),
               'trained_models/adda.pt')


if __name__ == '__main__':
    arg_parser = argparse.ArgumentParser(description='Domain adaptation using ADDA')
    arg_parser.add_argument('MODEL_FILE', help='A model in trained_models')
    arg_parser.add_argument('--batch-size', type=int, default=64)
    arg_parser.add_argument('--iterations', type=int, default=500)
    arg_parser.add_argument('--epochs', type=int, default=5)
    arg_parser.add_argument('--k-disc', type=int, default=1)
    arg_parser.add_argument('--k-clf', type=int, default=10)
    args = arg_parser.parse_args()
    main(args)
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
from pathlib import Path

# Directory containing the extracted BSDS500 archive (BSR/...); MNIST is downloaded to DATA_DIR/mnist
DATA_DIR = Path('/home/joris/data')
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision import datasets, transforms

import config


class BSDS500(Dataset):

    def __init__(self):
        image_folder = config.DATA_DIR / 'BSR/BSDS500/data/images'
        self.image_files = list(map(str, image_folder.glob('*/*.jpg')))

    def __getitem__(self, i):
        image = cv2.imread(self.image_files[i], cv2.IMREAD_COLOR)
        tensor = torch.from_numpy(image.transpose(2, 0, 1))
        return tensor

    def __len__(self):
        return len(self.image_files)


class MNISTM(Dataset):

    def __init__(self, train=True):
        super(MNISTM, self).__init__()
        self.mnist = datasets.MNIST(config.DATA_DIR / 'mnist', train=train,
                                    download=True)
        self.bsds = BSDS500()
        # Fixed seed: the sequence of background patches is reproducible, although
        # which patch a given digit receives depends on the order of access
        self.rng = np.random.RandomState(42)

    def __getitem__(self, i):
        digit, label = self.mnist[i]
        digit = transforms.ToTensor()(digit)
        bsds_image = self._random_bsds_image()
        patch = self._random_patch(bsds_image)
        patch = patch.float() / 255
        # The 1-channel digit broadcasts over the 3 color channels of the patch
        blend = torch.abs(patch - digit)
        return blend, label

    def _random_patch(self, image, size=(28, 28)):
        _, im_height, im_width = image.shape
        x = self.rng.randint(0, im_width-size[1])
        y = self.rng.randint(0, im_height-size[0])
        return image[:, y:y+size[0], x:x+size[1]]

    def _random_bsds_image(self):
        i = self.rng.choice(len(self.bsds))
        return self.bsds[i]

    def __len__(self):
        return len(self.mnist)
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
from torch import nn


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(3, 10, kernel_size=5),
            nn.MaxPool2d(2),
            nn.ReLU(),
            nn.Conv2d(10, 20, kernel_size=5),
            nn.MaxPool2d(2),
            nn.Dropout2d(),
        )

        # For 28x28 inputs the feature extractor outputs 20 x 4 x 4 = 320 values
        self.classifier = nn.Sequential(
            nn.Linear(320, 50),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(50, 10),
        )

    def forward(self, x):
        features = self.feature_extractor(x)
        features = features.view(x.shape[0], -1)
        logits = self.classifier(features)
        return logits
--------------------------------------------------------------------------------
/revgrad.py:
--------------------------------------------------------------------------------
"""
Implements RevGrad:
Unsupervised Domain Adaptation by Backpropagation, Ganin & Lempitsky (2014)
Domain-adversarial training of neural networks, Ganin et al. (2016)
"""
import argparse

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor
from tqdm import tqdm

import config
from data import MNISTM
from models import Net
from utils import GrayscaleToRgb, GradientReversal


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def main(args):
    model = Net().to(device)
    model.load_state_dict(torch.load(args.MODEL_FILE))
    feature_extractor = model.feature_extractor
    clf = model.classifier

    # The GradientReversal layer negates the gradient flowing back into the feature extractor
    discriminator = nn.Sequential(
        GradientReversal(),
        nn.Linear(320, 50),
        nn.ReLU(),
        nn.Linear(50, 20),
        nn.ReLU(),
        nn.Linear(20, 1)
    ).to(device)

    half_batch = args.batch_size // 2
    source_dataset = MNIST(config.DATA_DIR/'mnist', train=True, download=True,
                           transform=Compose([GrayscaleToRgb(), ToTensor()]))
    source_loader = DataLoader(source_dataset, batch_size=half_batch,
                               shuffle=True, num_workers=1, pin_memory=True)

    target_dataset = MNISTM(train=False)
    target_loader = DataLoader(target_dataset, batch_size=half_batch,
                               shuffle=True, num_workers=1, pin_memory=True)

    optim = torch.optim.Adam(list(discriminator.parameters()) + list(model.parameters()))

    for epoch in range(1, args.epochs+1):
        batches = zip(source_loader, target_loader)
        n_batches = min(len(source_loader), len(target_loader))

        total_domain_loss = total_label_accuracy = 0
        for (source_x, source_labels), (target_x, _) in tqdm(batches, leave=False, total=n_batches):
            x = torch.cat([source_x, target_x])
            x = x.to(device)
            domain_y = torch.cat([torch.ones(source_x.shape[0]),
                                  torch.zeros(target_x.shape[0])])
            domain_y = domain_y.to(device)
            label_y = source_labels.to(device)

            features = feature_extractor(x).view(x.shape[0], -1)
            domain_preds = discriminator(features).squeeze()
            # Only the source half of the batch has class labels
            label_preds = clf(features[:source_x.shape[0]])

            domain_loss = F.binary_cross_entropy_with_logits(domain_preds, domain_y)
            label_loss = F.cross_entropy(label_preds, label_y)
            loss = domain_loss + label_loss

            optim.zero_grad()
            loss.backward()
            optim.step()

            total_domain_loss += domain_loss.item()
            total_label_accuracy += (label_preds.max(1)[1] == label_y).float().mean().item()

        mean_loss = total_domain_loss / n_batches
        mean_accuracy = total_label_accuracy / n_batches
        tqdm.write(f'EPOCH {epoch:03d}: domain_loss={mean_loss:.4f}, '
                   f'source_accuracy={mean_accuracy:.4f}')

        torch.save(model.state_dict(), 'trained_models/revgrad.pt')


if __name__ == '__main__':
    arg_parser = argparse.ArgumentParser(description='Domain adaptation using RevGrad')
    arg_parser.add_argument('MODEL_FILE', help='A model in trained_models')
    arg_parser.add_argument('--batch-size', type=int, default=64)
    arg_parser.add_argument('--epochs', type=int, default=15)
    args = arg_parser.parse_args()
    main(args)
--------------------------------------------------------------------------------
/task.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jvanvugt/pytorch-domain-adaptation/be63aadc18821d6b19c75df51f264ff08370a765/task.png
--------------------------------------------------------------------------------
/test_model.py:
--------------------------------------------------------------------------------
import argparse

import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor
from tqdm import tqdm

from data import MNISTM
from models import Net


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def main(args):
    dataset = MNISTM(train=False)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
                            drop_last=False, num_workers=1, pin_memory=True)

    model = Net().to(device)
    model.load_state_dict(torch.load(args.MODEL_FILE))
    model.eval()

    total_accuracy = 0
    with torch.no_grad():
        for x, y_true in tqdm(dataloader, leave=False):
            x, y_true = x.to(device), y_true.to(device)
            y_pred = model(x)
            total_accuracy += (y_pred.max(1)[1] == y_true).float().mean().item()

    mean_accuracy = total_accuracy / len(dataloader)
    print(f'Accuracy on target data: {mean_accuracy:.4f}')


if __name__ == '__main__':
    arg_parser = argparse.ArgumentParser(description='Test a model on MNIST-M')
    arg_parser.add_argument('MODEL_FILE', help='A model in trained_models')
    arg_parser.add_argument('--batch-size', type=int, default=256)
    args = arg_parser.parse_args()
    main(args)
--------------------------------------------------------------------------------
/train_source.py:
--------------------------------------------------------------------------------
import argparse

import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor
from tqdm import tqdm

import config
from models import Net
from utils import GrayscaleToRgb


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def create_dataloaders(batch_size):
    dataset = MNIST(config.DATA_DIR/'mnist', train=True, download=True,
                    transform=Compose([GrayscaleToRgb(), ToTensor()]))
    # Random 80/20 train/validation split of the MNIST training set
    shuffled_indices = np.random.permutation(len(dataset))
    train_idx = shuffled_indices[:int(0.8*len(dataset))]
    val_idx = shuffled_indices[int(0.8*len(dataset)):]

    train_loader = DataLoader(dataset, batch_size=batch_size, drop_last=True,
                              sampler=SubsetRandomSampler(train_idx),
                              num_workers=1, pin_memory=True)
    val_loader = DataLoader(dataset, batch_size=batch_size, drop_last=False,
                            sampler=SubsetRandomSampler(val_idx),
                            num_workers=1, pin_memory=True)
    return train_loader, val_loader


def do_epoch(model, dataloader, criterion, optim=None):
    total_loss = 0
    total_accuracy = 0
    for x, y_true in tqdm(dataloader, leave=False):
        x, y_true = x.to(device), y_true.to(device)
        y_pred = model(x)
        loss = criterion(y_pred, y_true)

        if optim is not None:
            optim.zero_grad()
            loss.backward()
            optim.step()

        total_loss += loss.item()
        total_accuracy += (y_pred.max(1)[1] == y_true).float().mean().item()
    mean_loss = total_loss / len(dataloader)
    mean_accuracy = total_accuracy / len(dataloader)

    return mean_loss, mean_accuracy


def main(args):
    train_loader, val_loader = create_dataloaders(args.batch_size)

    model = Net().to(device)
    optim = torch.optim.Adam(model.parameters())
    lr_schedule = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, patience=1, verbose=True)
    criterion = torch.nn.CrossEntropyLoss()

    best_accuracy = 0
    for epoch in range(1, args.epochs+1):
        model.train()
        train_loss, train_accuracy = do_epoch(model, train_loader, criterion, optim=optim)

        model.eval()
        with torch.no_grad():
            val_loss, val_accuracy = do_epoch(model, val_loader, criterion, optim=None)

        tqdm.write(f'EPOCH {epoch:03d}: train_loss={train_loss:.4f}, train_accuracy={train_accuracy:.4f}, '
                   f'val_loss={val_loss:.4f}, val_accuracy={val_accuracy:.4f}')

        if val_accuracy > best_accuracy:
            print('Saving model...')
            best_accuracy = val_accuracy
            torch.save(model.state_dict(), 'trained_models/source.pt')

        lr_schedule.step(val_loss)


if __name__ == '__main__':
    arg_parser = argparse.ArgumentParser(description='Train a network on MNIST')
    arg_parser.add_argument('--batch-size', type=int, default=64)
    arg_parser.add_argument('--epochs', type=int, default=30)
    args = arg_parser.parse_args()
    main(args)
--------------------------------------------------------------------------------
/trained_models/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jvanvugt/pytorch-domain-adaptation/be63aadc18821d6b19c75df51f264ff08370a765/trained_models/.gitkeep
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
from PIL import Image

import numpy as np
import torch
from torch.autograd import Function


def set_requires_grad(model, requires_grad=True):
    """Enable or disable gradient computation for all parameters of `model`."""
    for param in model.parameters():
        param.requires_grad = requires_grad


def loop_iterable(iterable):
    """Yield items from `iterable` forever, restarting it whenever it is exhausted."""
    while True:
        yield from iterable


class GrayscaleToRgb:
    """Convert a grayscale PIL image to RGB by repeating its channel three times"""
    def __call__(self, image):
        image = np.array(image)
        image = np.dstack([image, image, image])
        return Image.fromarray(image)


class GradientReversalFunction(Function):
    """
    Gradient Reversal Layer from:
    Unsupervised Domain Adaptation by Backpropagation (Ganin & Lempitsky, 2015)

    Forward pass is the identity function. In the backward pass,
    the upstream gradients are multiplied by -lambda (i.e. the gradient is reversed)
    """

    @staticmethod
    def forward(ctx, x, lambda_):
        ctx.lambda_ = lambda_
        return x.clone()

    @staticmethod
    def backward(ctx, grads):
        lambda_ = ctx.lambda_
        lambda_ = grads.new_tensor(lambda_)
        dx = -lambda_ * grads
        # No gradient with respect to lambda_
        return dx, None


class GradientReversal(torch.nn.Module):
    def __init__(self, lambda_=1):
        super(GradientReversal, self).__init__()
        self.lambda_ = lambda_

    def forward(self, x):
        return GradientReversalFunction.apply(x, self.lambda_)
--------------------------------------------------------------------------------
/wdgrl.py:
--------------------------------------------------------------------------------
"""
Implements WDGRL:
Wasserstein Distance Guided Representation Learning, Shen et al. (2017)
"""
import argparse

import torch
from torch import nn
from torch.autograd import grad
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor
from tqdm import tqdm, trange

import config
from data import MNISTM
from models import Net
from utils import loop_iterable, set_requires_grad, GrayscaleToRgb


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def gradient_penalty(critic, h_s, h_t):
    # based on: https://github.com/caogang/wgan-gp/blob/master/gan_cifar10.py#L116
    alpha = torch.rand(h_s.size(0), 1).to(device)
    differences = h_t - h_s
    interpolates = h_s + (alpha * differences)
    # Concatenate along the batch dimension (not stack) so that the norm below is per sample
    interpolates = torch.cat([interpolates, h_s, h_t]).requires_grad_()

    preds = critic(interpolates)
    gradients = grad(preds, interpolates,
                     grad_outputs=torch.ones_like(preds),
                     retain_graph=True, create_graph=True)[0]
    gradient_norm = gradients.norm(2, dim=1)
    gradient_penalty = ((gradient_norm - 1)**2).mean()
    return gradient_penalty


def main(args):
    clf_model = Net().to(device)
    clf_model.load_state_dict(torch.load(args.MODEL_FILE))

    feature_extractor = clf_model.feature_extractor
    classifier = clf_model.classifier

    critic = nn.Sequential(
        nn.Linear(320, 50),
        nn.ReLU(),
        nn.Linear(50, 20),
        nn.ReLU(),
        nn.Linear(20, 1)
    ).to(device)

    half_batch = args.batch_size // 2
    source_dataset = MNIST(config.DATA_DIR/'mnist', train=True, download=True,
                           transform=Compose([GrayscaleToRgb(), ToTensor()]))
    source_loader = DataLoader(source_dataset, batch_size=half_batch, drop_last=True,
                               shuffle=True, num_workers=0, pin_memory=True)

    target_dataset = MNISTM(train=False)
    target_loader = DataLoader(target_dataset, batch_size=half_batch, drop_last=True,
                               shuffle=True, num_workers=0, pin_memory=True)

    critic_optim = torch.optim.Adam(critic.parameters(), lr=1e-4)
    clf_optim = torch.optim.Adam(clf_model.parameters(), lr=1e-4)
    clf_criterion = nn.CrossEntropyLoss()

    for epoch in range(1, args.epochs+1):
        batch_iterator = zip(loop_iterable(source_loader), loop_iterable(target_loader))

        total_loss = 0
        for _ in trange(args.iterations, leave=False):
            (source_x, source_y), (target_x, _) = next(batch_iterator)
            # Train critic on fixed features
            set_requires_grad(feature_extractor, requires_grad=False)
            set_requires_grad(critic, requires_grad=True)

            source_x, target_x = source_x.to(device), target_x.to(device)
            source_y = source_y.to(device)

            with torch.no_grad():
                h_s = feature_extractor(source_x).data.view(source_x.shape[0], -1)
                h_t = feature_extractor(target_x).data.view(target_x.shape[0], -1)
            for _ in range(args.k_critic):
                gp = gradient_penalty(critic, h_s, h_t)

                critic_s = critic(h_s)
                critic_t = critic(h_t)
                wasserstein_distance = critic_s.mean() - critic_t.mean()

                # The critic maximizes the distance estimate, regularized by the gradient penalty
                critic_cost = -wasserstein_distance + args.gamma*gp

                critic_optim.zero_grad()
                critic_cost.backward()
                critic_optim.step()

                total_loss += critic_cost.item()

            # Train feature extractor and classifier: source classification loss plus distance
            set_requires_grad(feature_extractor, requires_grad=True)
            set_requires_grad(critic, requires_grad=False)
            for _ in range(args.k_clf):
                source_features = feature_extractor(source_x).view(source_x.shape[0], -1)
                target_features = feature_extractor(target_x).view(target_x.shape[0], -1)

                source_preds = classifier(source_features)
                clf_loss = clf_criterion(source_preds, source_y)
                wasserstein_distance = critic(source_features).mean() - critic(target_features).mean()

                loss = clf_loss + args.wd_clf * wasserstein_distance
                clf_optim.zero_grad()
                loss.backward()
                clf_optim.step()

        mean_loss = total_loss / (args.iterations * args.k_critic)
        tqdm.write(f'EPOCH {epoch:03d}: critic_loss={mean_loss:.4f}')
        torch.save(clf_model.state_dict(), 'trained_models/wdgrl.pt')


if __name__ == '__main__':
    arg_parser = argparse.ArgumentParser(description='Domain adaptation using WDGRL')
    arg_parser.add_argument('MODEL_FILE', help='A model in trained_models')
    arg_parser.add_argument('--batch-size', type=int, default=64)
    arg_parser.add_argument('--iterations', type=int, default=500)
    arg_parser.add_argument('--epochs', type=int, default=5)
    arg_parser.add_argument('--k-critic', type=int, default=5)
    arg_parser.add_argument('--k-clf', type=int, default=1)
    arg_parser.add_argument('--gamma', type=float, default=10)
    arg_parser.add_argument('--wd-clf', type=float, default=1)
    args = arg_parser.parse_args()
    main(args)
--------------------------------------------------------------------------------
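The gradient reversal layer that revgrad.py places in front of its discriminator (`GradientReversal` in utils.py) can be illustrated without PyTorch. The pure-Python sketch below hand-computes gradients for a hypothetical toy model: a one-weight feature extractor `h = w * x` feeding a one-weight discriminator whose output is used directly as the loss. The function names and numbers are illustrative assumptions, not code from the repository. The sketch shows the two properties documented in `GradientReversalFunction`: the forward pass is the identity, and the gradient reaching the feature extractor is multiplied by `-lambda_`, while the discriminator's own gradient is unchanged.

```python
def grl_forward(h):
    """Gradient reversal layer, forward pass: the identity."""
    return h


def grl_backward(upstream_grad, lambda_=1.0):
    """Gradient reversal layer, backward pass: scale the gradient by -lambda_."""
    return -lambda_ * upstream_grad


def toy_gradients(x, w, v, lambda_=1.0):
    """Hypothetical toy model: h = w * x, loss = v * GRL(h).

    Returns (loss, dloss/dv, dloss/dw without a GRL, dloss/dw with a GRL).
    """
    h = w * x
    loss = v * grl_forward(h)  # forward pass: the GRL leaves h unchanged
    dloss_dv = h               # discriminator weight: ordinary gradient
    dloss_dh = v               # gradient arriving at the GRL from above
    dloss_dw_plain = dloss_dh * x                           # no reversal
    dloss_dw_grl = grl_backward(dloss_dh, lambda_) * x      # reversed
    return loss, dloss_dv, dloss_dw_plain, dloss_dw_grl


print(toy_gradients(x=2.0, w=0.5, v=3.0))
```

With the default `lambda_=1` used by `GradientReversal`, one optimizer step on the summed loss therefore moves the discriminator toward separating the domains and the feature extractor in the opposite direction, which is why revgrad.py can train both with a single `Adam` optimizer.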
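adda.py trains the target feature extractor with flipped labels: its target features are labeled 1 (source) and scored with `nn.BCEWithLogitsLoss`, while the discriminator itself labels the same features 0. The sketch below reimplements binary cross-entropy on a single logit in pure Python, using the standard numerically stable form max(z, 0) - z*y + log(1 + exp(-|z|)). It is an illustrative stand-in rather than PyTorch's implementation, and the helper names are hypothetical. It compares the two losses for the same target logit `z`.

```python
import math


def bce_with_logits(logit, label):
    """Binary cross-entropy on a raw logit z with label y in {0, 1}, computed
    in the numerically stable form max(z, 0) - z * y + log(1 + exp(-|z|))."""
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))


def adda_losses(target_logit):
    """Return (discriminator loss, target-encoder loss) for one target example.

    The discriminator wants target features classified as target (label 0);
    the target encoder is trained with the flipped label 1 (source).
    """
    disc_loss = bce_with_logits(target_logit, 0.0)
    encoder_loss = bce_with_logits(target_logit, 1.0)
    return disc_loss, encoder_loss


for z in (-3.0, 0.0, 3.0):
    d, e = adda_losses(z)
    print(f'logit={z:+.1f}  discriminator_loss={d:.4f}  encoder_loss={e:.4f}')
```

For a logit that the discriminator confidently assigns to the target domain (negative `z`), the discriminator loss is small and the encoder loss is large, so the encoder's gradient pushes the logit toward the source side; at `z = 0` both losses equal log 2.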
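The critic objective in wdgrl.py, `critic_cost = -wasserstein_distance + gamma * gp`, can be worked through by hand for a hypothetical linear critic f(h) = a.h + b, whose gradient with respect to h is the constant vector `a`. The pure-Python sketch below (the critic, the toy features and `gamma` are illustrative assumptions) estimates the Wasserstein distance as mean f(source) - mean f(target) and computes the WGAN-GP penalty (||grad f||_2 - 1)^2 at random interpolates between paired source and target features. Unlike `gradient_penalty` in wdgrl.py, it does not also evaluate the penalty at the source and target points themselves.

```python
import math
import random


def linear_critic(a, b):
    """Hypothetical linear critic f(h) = a . h + b, returned with its gradient."""
    def f(h):
        return sum(ai * hi for ai, hi in zip(a, h)) + b

    def grad_f(h):
        return list(a)  # the gradient of a linear function is constant

    return f, grad_f


def gradient_penalty(grad_f, h_s, h_t, rng):
    """Mean of (||grad f(x)||_2 - 1)^2 over interpolates
    x = h_s + alpha * (h_t - h_s), with one random alpha per sample pair."""
    penalties = []
    for hs, ht in zip(h_s, h_t):
        alpha = rng.random()
        x = [s + alpha * (t - s) for s, t in zip(hs, ht)]
        norm = math.sqrt(sum(g * g for g in grad_f(x)))
        penalties.append((norm - 1.0) ** 2)
    return sum(penalties) / len(penalties)


def critic_cost(f, grad_f, h_s, h_t, gamma, rng):
    """Return (critic cost, distance estimate); the critic minimizes the cost."""
    wd = sum(map(f, h_s)) / len(h_s) - sum(map(f, h_t)) / len(h_t)
    return -wd + gamma * gradient_penalty(grad_f, h_s, h_t, rng), wd


h_s = [[1.0, 0.0], [2.0, 1.0]]   # toy "source features"
h_t = [[0.0, 0.0], [1.0, -1.0]]  # toy "target features"

f, grad_f = linear_critic(a=[0.6, 0.8], b=0.0)  # ||a|| = 1, so the penalty vanishes
print(critic_cost(f, grad_f, h_s, h_t, gamma=10.0, rng=random.Random(0)))
```

With `a = [0.6, 0.8]` the critic is 1-Lipschitz, the penalty is (numerically) zero and the cost is just the negated distance estimate of about 1.4. Scaling `a` up to `[3.0, 4.0]` raises the distance estimate to 7.0, but the penalty term of 10 * (5 - 1)^2 = 160 more than cancels that gain, which is how the penalty stops the critic from inflating the distance simply by growing its weights.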