├── models
│   ├── __init__.py
│   ├── torchvision_variants.py
│   ├── torch_future.py
│   ├── large_models.py
│   └── model.py
├── train
│   ├── __init__.py
│   ├── optim.py
│   ├── old_main.py
│   └── utils.py
├── caffenet
│   ├── __init__.py
│   └── caffenet_pytorch.py
├── dataset
│   ├── __init__.py
│   └── data_loader.py
├── README.md
├── generalization_mnistM.sh
├── generalization_svhn.sh
├── compute_dataset_statistics.py
├── convert_images.py
├── .gitignore
├── test.py
├── logger.py
└── new_main.py

/models/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/train/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/caffenet/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/dataset/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## This is a PyTorch implementation of the paper *[Hallucinating Agnostic Images to Generalize Across Domains](https://arxiv.org/abs/1808.01102)*
--------------------------------------------------------------------------------
/generalization_mnistM.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
python new_main.py --epoch 200 --source mnist svhn synth --target mnist_m --data_aug_mode simple --source_limit 20000 --use_deco --generalization --classifier multi --DANN_weight 2.0
--------------------------------------------------------------------------------
/generalization_svhn.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
python new_main.py --epoch 800 --source mnist_m mnist synth --target svhn --data_aug_mode simple-tuned --source_limit 20000 --target_limit 20000 --use_deco --generalization --classifier multi $1
--------------------------------------------------------------------------------
/compute_dataset_statistics.py:
--------------------------------------------------------------------------------
from argparse import ArgumentParser

from tqdm import tqdm
import numpy as np

from dataset import data_loader
from dataset.data_loader import get_dataloader


def get_args():
    args = ArgumentParser()
    args.add_argument("dataset_name", choices=data_loader.dataset_list)
    args.add_argument("--data_aug", default="simple-no-norm")
    return args.parse_args()


def get_stats(input_loader):
    n_batches = len(input_loader)
    std = np.zeros((n_batches, 3))
    mean = np.zeros((n_batches, 3))

    for i, (data, _) in enumerate(tqdm(input_loader)):
        std[i] = data.numpy().std(axis=(0, 2, 3))
        mean[i] = data.numpy().mean(axis=(0, 2, 3))
    print("STD:", std.mean(axis=0))
    print("MEAN:", mean.mean(axis=0))


if __name__ == "__main__":
    args = get_args()
    input_loader = get_dataloader(args.dataset_name, 1000, 28, args.data_aug, None)
    get_stats(input_loader)
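
# Example invocation (dataset name is illustrative; any entry of
# data_loader.dataset_list works):
#
#     python compute_dataset_statistics.py mnist --data_aug simple-no-norm
#
# Note that the printed std is the mean of per-batch stds, i.e. an
# approximation of the true dataset std; with batch_size=1000 it is close.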
--------------------------------------------------------------------------------
/convert_images.py:
--------------------------------------------------------------------------------
import os
from argparse import ArgumentParser
from os import path

import torch
from torchvision.utils import save_image

from dataset.data_loader import get_images_for_conversion


def get_args():
    args = ArgumentParser()
    args.add_argument("model_path")
    args.add_argument("input_path")
    args.add_argument("output_path")
    return args.parse_args()


def convert_dataset(model, input_loader, output_folder, input_prefix):
    for i, (img, im_path) in enumerate(input_loader):
        out = torch.tanh(model.deco(img.unsqueeze(0).cuda())).squeeze().data
        outpath = path.join(output_folder, im_path[input_prefix:])
        folder = path.dirname(outpath)
        if not path.exists(folder):
            os.makedirs(folder)
        save_image(out, outpath)
        if i % 100 == 0:
            print("%d/%d" % (i, len(input_loader)))


if __name__ == "__main__":
    args = get_args()
    input_folder = args.input_path
    prefix_len = len(input_folder)  # used to strip the input root from each image path
    output_folder = args.output_path
    model_path = args.model_path
    model = torch.load(model_path)
    model.eval()  # eval mode: with single-image batches, BatchNorm would otherwise use unstable batch statistics
    input_loader = get_images_for_conversion(input_folder, image_size=28)
    with torch.no_grad():
        convert_dataset(model, input_loader, output_folder, prefix_len)
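
# Example invocation (paths are hypothetical):
#
#     python convert_images.py models/run_42.pth dataset/mnist_m/mnist_m_train/ converted/mnist_m_train
#
# Pass input_path with a trailing slash: outpath is built as
# path.join(output_path, im_path[prefix_len:]), and without the slash the
# stripped remainder starts with "/", which os.path.join treats as an
# absolute path, silently discarding output_path.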
--------------------------------------------------------------------------------
/caffenet/caffenet_pytorch.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn

from models.torch_future import LocalResponseNorm, Flatten


def load_caffenet():
    model = nn.Sequential(  # Sequential,
        nn.Conv2d(3, 96, (11, 11), (4, 4)),  # 0
        nn.ReLU(),
        nn.MaxPool2d((3, 3), (2, 2), (0, 0), ceil_mode=True),
        LocalResponseNorm(5, 0.0001, 0.75),
        nn.Conv2d(96, 256, (5, 5), (1, 1), (2, 2), 1, 2),  # 4: groups=2, as in the original CaffeNet
        nn.ReLU(),
        nn.MaxPool2d((3, 3), (2, 2), (0, 0), ceil_mode=True),
        LocalResponseNorm(5, 0.0001, 0.75),
        nn.Conv2d(256, 384, (3, 3), (1, 1), (1, 1)),  # 8
        nn.ReLU(),
        nn.Conv2d(384, 384, (3, 3), (1, 1), (1, 1), 1, 2),  # 10
        nn.ReLU(),
        nn.Conv2d(384, 256, (3, 3), (1, 1), (1, 1), 1, 2),  # 12
        nn.ReLU(),
        nn.MaxPool2d((3, 3), (2, 2), (0, 0), ceil_mode=True),
        Flatten(),
        nn.Linear(9216, 4096),  # Linear, 16
        nn.ReLU(),
        nn.Dropout(0.5),
        nn.Linear(4096, 4096),  # Linear, 19
        nn.ReLU(),
        nn.Dropout(0.5),
    )
    pretrained = torch.load('caffenet/caffenet_pytorch.pth')
    # The converted Caffe weights store the two fully connected layers as nested
    # modules ("16.1", "19.1"); remap them to the flat indices used here and drop
    # the final classification layer ("22.1"), which is replaced downstream.
    pretrained["16.weight"] = pretrained["16.1.weight"]
    pretrained["16.bias"] = pretrained["16.1.bias"]
    pretrained["19.weight"] = pretrained["19.1.weight"]
    pretrained["19.bias"] = pretrained["19.1.bias"]
    del pretrained["16.1.weight"], pretrained["16.1.bias"], pretrained["19.1.weight"], pretrained["19.1.bias"], \
        pretrained["22.1.weight"], pretrained["22.1.bias"]
    model.load_state_dict(pretrained)
    return model
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
.idea/
*.pyc
*.pkl

dataset/*
*.pth

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# dotenv
.env

# virtualenv
.venv
venv/
ENV/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import torch.backends.cudnn as cudnn
import torch.utils.data

from dataset.data_loader import get_dataset, dataset_list

cache = {}


def get_dataloader(dataset_name, image_size, limit, batch_size, tune_stats=False):
    dataloader = cache.get(dataset_name, None)
    if tune_stats:
        mode = "test-tuned"
    else:
        mode = "test"
    if dataloader is None:
        dataloader = torch.utils.data.DataLoader(
            dataset=get_dataset(dataset_name, image_size, mode=mode, limit=limit),
            batch_size=batch_size,
            shuffle=False,
            num_workers=4,
            pin_memory=True
        )
        cache[dataset_name] = dataloader
    return dataloader


def test(dataset_name, epoch, model, image_size, domain, batch_size=1024, limit=None, tune_stats=False):
    assert dataset_name in dataset_list
    model.eval()
    cuda = True
    cudnn.benchmark = True
    lambda_val = 0

    n_total = 0.0
    n_correct = 0.0

    dataloader = get_dataloader(dataset_name, image_size, limit, batch_size, tune_stats=tune_stats)
    for i, (t_img, t_label) in enumerate(dataloader):
        batch_size = len(t_label)
        if cuda:
            t_img = t_img.cuda()
            t_label = t_label.cuda()
        with torch.no_grad():
            class_output, _, _ = model(input_data=t_img, lambda_val=lambda_val, domain=domain)
            pred = class_output.data.max(1, keepdim=True)[1]
            n_correct += pred.eq(t_label.view_as(pred)).cpu().sum().item()
        n_total += batch_size

    accu = n_correct / n_total

    print('epoch: %d, accuracy of the %s dataset (%d batches): %f' % (epoch, dataset_name, len(dataloader), accu))
    return accu
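
# Note: the dataloader cache above is keyed on dataset_name alone, so the
# first test() call fixes image_size, limit and tune_stats for that dataset;
# later calls with different settings silently reuse the cached loader. In
# new_main.py this is harmless because every call uses the same settings per
# dataset.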
--------------------------------------------------------------------------------
/models/torchvision_variants.py:
--------------------------------------------------------------------------------
from torch import nn
from torch.utils import model_zoo
from torchvision.models.alexnet import model_urls

from models import model as models


class SmallAlexNet(nn.Module):
    """This AlexNet variant is adapted to receive 57x57 input images instead of the usual 227x227."""

    def __init__(self, num_classes=1000):
        super(SmallAlexNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=2, padding=3),
            nn.ReLU(inplace=True),
            models.PassData(),  # identity placeholder so layer indices match the pretrained state dict
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), 256 * 6 * 6)
        x = self.classifier(x)
        return x


def small_alexnet(pretrained=False, **kwargs):
    r"""AlexNet model architecture from the
    `"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = SmallAlexNet(**kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['alexnet']))
    return model
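
# Shape check for the 57x57 variant (a sketch of why the ImageNet-pretrained
# classifier weights still fit): conv1 (k=11, s=2, p=3) maps 57 -> 27, PassData
# stands in for the usual first max pool, conv2 (k=5, p=2) keeps 27, the first
# max pool (k=3, s=2) gives 13, the three k=3/p=1 convs keep 13, and the final
# max pool gives 6, matching the 256 * 6 * 6 features the classifier expects.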
--------------------------------------------------------------------------------
/logger.py:
--------------------------------------------------------------------------------
# Code referenced from https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514
import tensorflow as tf
import numpy as np
import scipy.misc

try:
    from StringIO import StringIO  # Python 2.7
except ImportError:
    from io import BytesIO  # Python 3.x


class Logger(object):

    def __init__(self, log_dir):
        """Create a summary writer logging to log_dir."""
        self.writer = tf.summary.FileWriter(log_dir)

    def scalar_summary(self, tag, value, step):
        """Log a scalar variable."""
        summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])
        self.writer.add_summary(summary, step)

    def image_summary(self, tag, images, step):
        """Log a list of images."""

        img_summaries = []
        for i, img in enumerate(images):
            # Write the image to a string
            try:
                s = StringIO()
            except NameError:  # Python 3: StringIO was not imported, use BytesIO
                s = BytesIO()
            scipy.misc.toimage(img).save(s, format="png")

            # Create an Image object
            img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(),
                                       height=img.shape[0],
                                       width=img.shape[1])
            # Create a Summary value
            img_summaries.append(tf.Summary.Value(tag='%s/%d' % (tag, i), image=img_sum))

        # Create and write Summary
        summary = tf.Summary(value=img_summaries)
        self.writer.add_summary(summary, step)

    def histo_summary(self, tag, values, step, bins=1000):
        """Log a histogram of the tensor of values."""

        # Create a histogram using numpy
        counts, bin_edges = np.histogram(values, bins=bins)

        # Fill the fields of the histogram proto
        hist = tf.HistogramProto()
        hist.min = float(np.min(values))
        hist.max = float(np.max(values))
        hist.num = int(np.prod(values.shape))
        hist.sum = float(np.sum(values))
        hist.sum_squares = float(np.sum(values ** 2))

        # Drop the start of the first bin
        bin_edges = bin_edges[1:]

        # Add bin edges and counts
        for edge in bin_edges:
            hist.bucket_limit.append(edge)
        for c in counts:
            hist.bucket.append(c)

        # Create and write Summary
        summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)])
        self.writer.add_summary(summary, step)
        self.writer.flush()
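
# Minimal usage sketch for this Logger (run name is hypothetical; requires
# TensorFlow 1.x, where tf.summary.FileWriter and tf.Summary exist):
#
#     logger = Logger("logs/example_run")
#     for step in range(100):
#         logger.scalar_summary("loss/train", 1.0 / (step + 1), step)
#
# The resulting events file can then be inspected with TensorBoard.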
--------------------------------------------------------------------------------
/train/optim.py:
--------------------------------------------------------------------------------
from enum import Enum

import numpy as np
from torch import optim
from torch.optim.lr_scheduler import MultiStepLR, _LRScheduler

GAMMA = 0.1

base_step_down_ratio = 0.4

NESTEROV = True
WEIGHT_DECAY = 0.0005
MOMENTUM = 0.9


class Optimizers(Enum):
    adam = "adam"
    sgd = "sgd"
    rmsprop = "rms_prop"


optimizer_list = [v.value for v in Optimizers]


def get_optimizer_and_scheduler(optim_name, net, max_epochs, lr, keep_pretrained_fixed):
    if keep_pretrained_fixed:
        params = net.get_trainable_params
    else:
        params = net.parameters
    print("Number of trainable parameter groups: %d" % sum(1 for x in params()))
    if optim_name == Optimizers.adam.value:
        optimizer = optim.Adam(params(), lr=lr)
        step_down_ratio = 0.8
    elif optim_name == Optimizers.rmsprop.value:
        optimizer = optim.RMSprop(params(), lr=lr)
        step_down_ratio = 0.8
    elif optim_name == Optimizers.sgd.value:
        optimizer = optim.SGD(params(), lr=lr, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY, nesterov=NESTEROV)
        step_down_ratio = base_step_down_ratio
    else:
        raise ValueError("Unknown optimizer %s" % optim_name)
    scheduler = get_scheduler(optimizer, max_epochs, step_down_ratio)
    # scheduler = InvertedLR(optimizer, (10000.0 / max_epochs) * 0.0003, 0.75, 1)  # TODO: add option to enable it on demand
    return optimizer, scheduler


def get_scheduler(optimizer, max_epochs, step_down=base_step_down_ratio):
    steps = [max_epochs * k for k in np.arange(step_down, 1.0, step_down)]
    return MultiStepLR(optimizer, milestones=steps, gamma=GAMMA)


class InvertedLR(_LRScheduler):
    """Decays the learning rate of each parameter group with the "inv" policy,
    lr = base_lr * init_lr * (1 + gamma * iter) ** (-power), stepped once per
    iteration through step_iter() rather than once per epoch (step() is a no-op).

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        gamma (float): Multiplicative factor of learning rate decay.
        power (float): Exponent of the decay curve.
        init_lr (float): Extra multiplier applied on top of each base lr.
        last_epoch (int): The index of last epoch. Default: -1.
    """

    def __init__(self, optimizer, gamma, power, init_lr, last_epoch=-1):
        self.gamma = gamma
        self.power = power
        self.init_lr = init_lr
        self.iter = -1
        super(InvertedLR, self).__init__(optimizer, last_epoch)

    def step_iter(self):
        self.iter += 1
        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group['lr'] = lr

    def get_lr(self):
        return [base_lr * (self.init_lr * (1 + self.gamma * self.iter) ** (-self.power))
                for base_lr in self.base_lrs]

    def step(self, epoch=None):
        # Intentionally a no-op: the learning rate is updated per iteration
        # through step_iter(), not per epoch.
        pass
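
# Milestone arithmetic in get_scheduler, for reference: with max_epochs=100
# and the SGD ratio base_step_down_ratio=0.4, np.arange(0.4, 1.0, 0.4) yields
# [0.4, 0.8], so the learning rate is multiplied by GAMMA=0.1 at epochs 40 and
# 80; with the Adam/RMSprop ratio of 0.8 there is a single 10x drop at 80% of
# training.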
--------------------------------------------------------------------------------
/models/torch_future.py:
--------------------------------------------------------------------------------
import torch.nn.functional as F
from torch import nn


def local_response_norm(input, size, alpha=1e-4, beta=0.75, k=1):
    r"""Applies local response normalization over an input signal composed of
    several input planes, where channels occupy the second dimension.
    Applies normalization across channels.

    See :class:`~torch.nn.LocalResponseNorm` for details.
    """
    dim = input.dim()
    if dim < 3:
        raise ValueError('Expected 3D or higher dimensionality '
                         'input (got {} dimensions)'.format(dim))
    div = input.mul(input).unsqueeze(1)
    if dim == 3:
        div = F.pad(div, (0, 0, size // 2, (size - 1) // 2))
        div = F.avg_pool2d(div, (size, 1), stride=1).squeeze(1)
    else:
        sizes = input.size()
        div = div.view(sizes[0], 1, sizes[1], sizes[2], -1)
        div = F.pad(div, (0, 0, 0, 0, size // 2, (size - 1) // 2))
        div = F.avg_pool3d(div, (size, 1, 1), stride=1).squeeze(1)
        div = div.view(sizes)
    div = div.mul(alpha).add(k).pow(beta)

    return input / div


class LocalResponseNorm(nn.Module):
    r"""Applies local response normalization over an input signal composed
    of several input planes, where channels occupy the second dimension.
    Applies normalization across channels.

    .. math::
        b_{c} = a_{c}\left(k + \frac{\alpha}{n}
        \sum_{c'=\max(0, c-n/2)}^{\min(N-1,c+n/2)}a_{c'}^2\right)^{-\beta}

    Args:
        size: amount of neighbouring channels used for normalization
        alpha: multiplicative factor. Default: 0.0001
        beta: exponent. Default: 0.75
        k: additive factor. Default: 1

    Shape:
        - Input: :math:`(N, C, ...)`
        - Output: :math:`(N, C, ...)` (same shape as input)

    Examples::
        >>> lrn = LocalResponseNorm(2)
        >>> signal_2d = torch.randn(32, 5, 24, 24)
        >>> signal_4d = torch.randn(16, 5, 7, 7, 7, 7)
        >>> output_2d = lrn(signal_2d)
        >>> output_4d = lrn(signal_4d)
    """

    def __init__(self, size, alpha=1e-4, beta=0.75, k=1):
        super(LocalResponseNorm, self).__init__()
        self.size = size
        self.alpha = alpha
        self.beta = beta
        self.k = k

    def forward(self, input):
        return local_response_norm(input, self.size, self.alpha, self.beta,
                                   self.k)

    def __repr__(self):
        return self.__class__.__name__ + '(' + str(self.size) + ', alpha=' + str(self.alpha) + ', beta=' + str(
            self.beta) + ', k=' + str(self.k) + ')'


class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size(0), -1)
--------------------------------------------------------------------------------
/new_main.py:
--------------------------------------------------------------------------------
import random
import time

import torch.backends.cudnn as cudnn
import torch.utils.data

from dataset.data_loader import get_dataloader
from logger import Logger
from models.model import get_net
from test import test
from train.optim import get_optimizer_and_scheduler
from train.utils import get_name, get_folder_name, ensure_dir, train_epoch, get_args, do_pretraining, simple_tuned

args = get_args()
print(args)
manual_seed = random.randint(1, 1000)
run_name = get_name(args, manual_seed)
print("Working on " + run_name)
log_folder = "logs/"
if args.tmp_log:
    log_folder = "/tmp/"
folder_name = get_folder_name(args.source, args.target)
logger = Logger("{}/{}/{}".format(log_folder, folder_name, run_name))

model_root = 'models'

cuda = True
cudnn.benchmark = True
lr = args.lr
batch_size = args.batch_size
image_size = args.image_size
test_batch_size = 1000
if image_size > 100:
    test_batch_size = 256
n_epoch = args.epochs
dann_weight = args.DANN_weight
entropy_weight = args.entropy_loss_weight

source_dataset_names = args.source
target_dataset_name = args.target

random.seed(manual_seed)
torch.manual_seed(manual_seed)

args.domain_classes = 1 + len(args.source)
dataloader_source = get_dataloader(args.source, batch_size, image_size, args.data_aug_mode, args.source_limit)
dataloader_target = get_dataloader(args.target, batch_size, image_size, args.data_aug_mode, args.target_limit)
print("Len source %d, len target %d" % (len(dataloader_source), len(dataloader_target)))

# load model
my_net = get_net(args)

# setup optimizer
optimizer, scheduler = get_optimizer_and_scheduler(args.optimizer, my_net, args.epochs, args.lr,
                                                   args.keep_pretrained_fixed)

if cuda:
    my_net = my_net.cuda()

if args.deco_pretrain > 0:
    do_pretraining(args.deco_pretrain, dataloader_source, dataloader_target, my_net, logger)
start = time.time()
# training
tune_stats = args.data_aug_mode == simple_tuned
for epoch in range(n_epoch):
    scheduler.step()
    logger.scalar_summary("aux/lr", scheduler.get_lr()[0], epoch)
    train_epoch(epoch, dataloader_source, dataloader_target, optimizer, my_net, logger, n_epoch, cuda, dann_weight,
                entropy_weight, scheduler, args.generalization)
    my_net.set_deco_mode("source")
    for d, source in enumerate(source_dataset_names):
        s_acc = test(source, epoch, my_net, image_size, d, test_batch_size, limit=args.source_limit, tune_stats=tune_stats)
        if len(source_dataset_names) == 1:
            source_name = "acc/source"
        else:
            source_name = "acc/source_%s" % source
        logger.scalar_summary(source_name, s_acc, epoch)
    my_net.set_deco_mode("target")
    t_acc = test(target_dataset_name, epoch, my_net, image_size, len(args.source), test_batch_size, limit=args.target_limit, tune_stats=tune_stats)
    logger.scalar_summary("acc/target", t_acc, epoch)

    save_path = '{}/{}/{}_{}.pth'.format(model_root, folder_name, run_name, epoch)
    print("Network saved to {}".format(save_path))
    ensure_dir(save_path)
    torch.save(my_net, save_path)
print('done, it took %g' % (time.time() - start))
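
# Note on scheduler placement: calling scheduler.step() at the top of each
# epoch follows the pre-1.1 PyTorch convention this code was written against;
# on PyTorch >= 1.1 the call should instead come after the epoch's optimizer
# updates (i.e. after train_epoch) to obtain the same schedule.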
--------------------------------------------------------------------------------
/models/large_models.py:
--------------------------------------------------------------------------------
import itertools

from torch import nn as nn
from torchvision.models import resnet50, alexnet
from torchvision.models.resnet import Bottleneck

from caffenet.caffenet_pytorch import load_caffenet
from models import torchvision_variants as tv
from models.model import BasicDECO, BasicDANN
from models.torch_future import Flatten


class DECO(BasicDECO):
    def __init__(self, deco_args):
        super(DECO, self).__init__(deco_args)
        if self.deco_args.no_pool:
            self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=5, stride=4, padding=2,
                                   bias=False)
        else:
            self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=5, stride=2, padding=2,
                                   bias=False)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.bn1 = nn.BatchNorm2d(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(deco_args.block, self.inplanes, deco_args.n_layers)
        if self.deco_args.deconv:
            self.deconv = nn.ConvTranspose2d(self.inplanes * deco_args.block.expansion, deco_args.output_channels, 5,
                                             padding=2, stride=4)
        else:
            self.conv_out = nn.Conv2d(self.inplanes * deco_args.block.expansion, deco_args.output_channels, 1)
        self.init_weights()

    def forward(self, input_data):
        x = self.conv1(input_data)
        x = self.bn1(x)
        x = self.relu(x)
        if self.deco_args.no_pool is False:
            x = self.maxpool(x)

        x = self.layer1(x)
        if self.deco_args.deconv:
            x = self.deconv(x, output_size=input_data.shape)
        else:
            x = self.conv_out(x)
            x = nn.functional.upsample(x, scale_factor=4, mode='bilinear')

        return self.weighted_sum(input_data, x)


class BigDecoDANN(BasicDANN):
    def get_trainable_params(self):
        return itertools.chain(self.domain_classifier.parameters(), self.class_classifier.parameters(),
                               self.bottleneck.parameters())


class ResNet50(BigDecoDANN):
    def __init__(self, domain_classes, n_classes):
        super(ResNet50, self).__init__()
        resnet = resnet50(pretrained=True)
        self.features = nn.Sequential(
            resnet.conv1,
            resnet.bn1,
            resnet.relu,
            resnet.maxpool,
            resnet.layer1,
            resnet.layer2,
            resnet.layer3,
            resnet.layer4,
            resnet.avgpool
        )
        self.class_classifier = nn.Linear(512 * Bottleneck.expansion, n_classes)
        self.domain_classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(512 * Bottleneck.expansion, 1024),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(1024, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, domain_classes),
        )

    def get_trainable_params(self):
        return itertools.chain(self.domain_classifier.parameters(), self.class_classifier.parameters())


class AlexNet(BigDecoDANN):
    def __init__(self, domain_classes, n_classes):
        super(AlexNet, self).__init__()
        pretrained = alexnet(pretrained=True)
        self.build_self(pretrained, domain_classes, n_classes)

    def build_self(self, pretrained, domain_classes, n_classes):
        self._convs = pretrained.features
        self.bottleneck = nn.Linear(4096, 256)  # bottleneck
        self._classifier = nn.Sequential(
            Flatten(),
            nn.Dropout(),
            pretrained.classifier[1],  # nn.Linear(256 * 6 * 6, 4096)
            nn.ReLU(inplace=True),
            nn.Dropout(),
            pretrained.classifier[4],  # nn.Linear(4096, 4096)
            nn.ReLU(inplace=True),
            self.bottleneck,
            nn.ReLU(inplace=True)
        )
        self.features = nn.Sequential(self._convs, self._classifier)
        self.class_classifier = nn.Sequential(nn.Dropout(), nn.Linear(256, n_classes))
        self.domain_classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256, 1024),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(1024, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, domain_classes),
        )


class SmallAlexNet(AlexNet):
    def __init__(self, domain_classes, n_classes):
        # Deliberately skip AlexNet.__init__ (which would load the standard
        # pretrained alexnet) and initialize from the grandparent instead.
        super(AlexNet, self).__init__()
        pretrained = tv.small_alexnet(pretrained=True)
        self.build_self(pretrained, domain_classes, n_classes)


class CaffeNet(BigDecoDANN):
    def __init__(self, domain_classes, n_classes):
        super(CaffeNet, self).__init__()
        pretrained = load_caffenet()
        self._convs = nn.Sequential(*list(pretrained)[:16])
        self.bottleneck = nn.Linear(4096, 256)  # bottleneck
        self._classifier = nn.Sequential(*list(pretrained)[16:22],
                                         self.bottleneck,
                                         nn.ReLU(inplace=True))
        self.features = nn.Sequential(self._convs, self._classifier)
        self.class_classifier = nn.Linear(256, n_classes)
        self.domain_classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256, 1024),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(1024, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, domain_classes),
        )


class AlexNetNoBottleneck(BigDecoDANN):
    def __init__(self, domain_classes, n_classes):
        super(AlexNetNoBottleneck, self).__init__()
        pretrained = alexnet(pretrained=True)
        self._convs = pretrained.features
        self._classifier = nn.Sequential(
            Flatten(),
            nn.Dropout(),
            pretrained.classifier[1],  # nn.Linear(256 * 6 * 6, 4096)
            nn.ReLU(inplace=True),
            nn.Dropout(),
            pretrained.classifier[4],  # nn.Linear(4096, 4096)
            nn.ReLU(inplace=True),
        )
        self.features = nn.Sequential(self._convs, self._classifier)
        self.class_classifier = nn.Linear(4096, n_classes)
        self.domain_classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(4096, 1024),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(1024, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, domain_classes),
        )

    def get_trainable_params(self):
        return itertools.chain(self.domain_classifier.parameters(), self.class_classifier.parameters())
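
# Spatial bookkeeping for the DECO module above, as a sketch: conv1 with
# stride 2 plus the stride-2 max pool shrink the input 4x (or conv1 alone with
# stride 4 when no_pool is set), and the output side restores it, either via
# the 4x bilinear upsample after conv_out or via the stride-4 ConvTranspose2d
# whose output_size is pinned to the input shape. Either way, weighted_sum
# (defined in BasicDECO in models/model.py, not shown here) receives two
# tensors of matching spatial size.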
--------------------------------------------------------------------------------
/train/old_main.py:
--------------------------------------------------------------------------------
import argparse
import random
import os
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
from torch.autograd import Variable

from dataset import data_loader
from dataset.data_loader import get_dataset

from logger import Logger
from models.model import Combo, CNNModel
import numpy as np
from test import test
import time


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--image_size', type=int, default=28)
    parser.add_argument('--batch_size', default=128, type=int)
    parser.add_argument('--epochs', default=100, type=int)
    parser.add_argument('--DANN_weight', default=1.0, type=float)
    parser.add_argument('--use_deco', action="store_true", help="If true use deco architecture")
    parser.add_argument('--suffix', help="Will be added to end of name", default="")
    parser.add_argument('--source', default="mnist", choices=data_loader.dataset_list)
    parser.add_argument('--target', default="mnist_m", choices=data_loader.dataset_list)
    return parser.parse_args()


def get_name(args):
    name = "lr:%g_batchSize:%d_epochs:%d_DannWeight:%g_imageSize:%d" % (args.lr, args.batch_size, args.epochs,
                                                                        args.DANN_weight, args.image_size)
    if args.use_deco:
        name += "_deco"
    return name + args.suffix + "_%d" % (time.time() % 100)


def to_np(x):
    return x.data.cpu().numpy()


def to_grid(x):
    # Rearrange a batch of 9 images into a single 3x3 grid image (HWC, with a
    # leading batch axis for the TF image summary).
    channels = x.shape[1]
    y = x.swapaxes(1, 3).reshape(3, 28 * 3, 28, channels).swapaxes(1, 2).reshape(28 * 3, 28 * 3, channels).squeeze()[np.newaxis, ...]
    return y


args = get_args()
run_name = get_name(args)
logger = Logger("logs/{}_{}/{}".format(args.source, args.target, run_name))

model_root = 'models'

cuda = True
cudnn.benchmark = True
lr = args.lr
batch_size = args.batch_size
image_size = args.image_size
n_epoch = args.epochs
dann_weight = args.DANN_weight
source_dataset_name = args.source
target_dataset_name = args.target

manual_seed = random.randint(1, 10000)
random.seed(manual_seed)
torch.manual_seed(manual_seed)

dataloader_source = torch.utils.data.DataLoader(
    dataset=get_dataset(args.source, image_size),
    batch_size=batch_size,
    shuffle=True,
    num_workers=4)

dataloader_target = torch.utils.data.DataLoader(
    dataset=get_dataset(args.target, image_size),
    batch_size=batch_size,
    shuffle=True,
    num_workers=4)

# load model

if args.use_deco:
    my_net = Combo()
else:
    my_net = CNNModel()

# setup optimizer

optimizer = optim.Adam(my_net.parameters(), lr=lr)

loss_class = torch.nn.NLLLoss()
loss_domain = torch.nn.NLLLoss()

if cuda:
    my_net = my_net.cuda()
    loss_class = loss_class.cuda()
    loss_domain = loss_domain.cuda()

for p in my_net.parameters():
    p.requires_grad = True

# training

for epoch in range(n_epoch):
    my_net.train(True)
    len_dataloader = min(len(dataloader_source), len(dataloader_target))
    data_source_iter = iter(dataloader_source)
    data_target_iter = iter(dataloader_target)

    i = 0
    while i < len_dataloader:

        p = float(i + epoch * len_dataloader) / n_epoch / len_dataloader
        lambda_val = 2. / (1. + np.exp(-10 * p)) - 1

        # training model using source data
        data_source = next(data_source_iter)
        s_img, s_label = data_source

        my_net.zero_grad()
        batch_size = len(s_label)

        input_img = torch.FloatTensor(batch_size, 3, image_size, image_size)
        class_label = torch.LongTensor(batch_size)
        domain_label = torch.zeros(batch_size)
        domain_label = domain_label.long()

        if cuda:
            s_img = s_img.cuda()
            s_label = s_label.cuda()
            input_img = input_img.cuda()
            class_label = class_label.cuda()
            domain_label = domain_label.cuda()

        input_img.resize_as_(s_img).copy_(s_img)
        class_label.resize_as_(s_label).copy_(s_label)
        inputv_img = Variable(input_img)
        classv_label = Variable(class_label)
        domainv_label = Variable(domain_label)

        class_output, domain_output = my_net(input_data=inputv_img, lambda_val=lambda_val)
        err_s_label = loss_class(class_output, classv_label)
        err_s_domain = loss_domain(domain_output, domainv_label)

        # training model using target data
        data_target = next(data_target_iter)
        t_img, _ = data_target

        batch_size = len(t_img)

        input_img = torch.FloatTensor(batch_size, 3, image_size, image_size)
        domain_label = torch.ones(batch_size)
        domain_label = domain_label.long()

        if cuda:
            t_img = t_img.cuda()
            input_img = input_img.cuda()
            domain_label = domain_label.cuda()

        input_img.resize_as_(t_img).copy_(t_img)
        inputv_img = Variable(input_img)
        domainv_label = Variable(domain_label)

        _, domain_output = my_net(input_data=inputv_img, lambda_val=lambda_val)
        err_t_domain = loss_domain(domain_output, domainv_label)
        err = dann_weight * err_t_domain + dann_weight * err_s_domain + err_s_label
        err.backward()
        optimizer.step()

        if i == 0:
            if args.use_deco:
                source_images = my_net.deco(Variable(s_img[:9]))
                target_images = my_net.deco(Variable(t_img[:9]))
            else:
                source_images = Variable(s_img[:9])
                target_images = Variable(t_img[:9])
            logger.image_summary("images/source", to_grid(to_np(source_images)), i + epoch * len_dataloader)
            logger.image_summary("images/target", to_grid(to_np(target_images)), i + epoch * len_dataloader)

        i += 1

        if (i % 200) == 0:
            logger.scalar_summary("loss/source", err_s_label, i + epoch * len_dataloader)
            logger.scalar_summary("loss/domain", (err_s_domain + err_t_domain) / 2, i + epoch * len_dataloader)
            print('epoch: %d, [iter: %d / all %d], err_s_label: %f, err_s_domain: %f, err_t_domain: %f'
                  % (epoch, i, len_dataloader, err_s_label.cpu().data.numpy(),
                     err_s_domain.cpu().data.numpy(), err_t_domain.cpu().data.numpy()))

    my_net.train(False)
    s_acc = test(source_dataset_name, epoch, my_net, image_size)
    t_acc = test(target_dataset_name, epoch, my_net, image_size)
    my_net.train(True)
    logger.scalar_summary("acc/source", s_acc, i + epoch * len_dataloader)
    logger.scalar_summary("acc/target", t_acc, i + epoch * len_dataloader)
    save_path = '{}/{}_{}/{}_{}.pth'.format(model_root, args.source, args.target, run_name, epoch)
    print("Network saved to {}".format(save_path))
    torch.save(my_net, save_path)
print('done')
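
# The adaptation factor lambda_val follows the DANN schedule of Ganin &
# Lempitsky, lambda = 2 / (1 + exp(-10 p)) - 1 with p the training progress in
# [0, 1]: p = 0 -> 0.0, p = 0.1 -> ~0.46, p = 0.25 -> ~0.85, p = 0.5 -> ~0.99,
# so the gradient-reversal signal ramps up smoothly over training.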
--------------------------------------------------------------------------------
/dataset/data_loader.py:
--------------------------------------------------------------------------------
import os
import gzip
import pickle

import numpy as np
import torch.utils
import torch.utils.data as data
from PIL import Image
from torchvision import datasets
from torchvision import transforms
from torchvision.datasets import ImageFolder

from train.utils import simple_tuned

mnist = 'mnist'
mnist_m = 'mnist_m'
svhn = 'svhn'
synth = 'synth'
usps = 'usps'

synth_signs = "synth_signs"
gtsrb = "gtsrb"
webcam = "webcam"
amazon = "amazon"
dslr = "dslr"
eth80_p00 = "eth80-p-000"
eth80_p22 = "eth80-p-022"
eth80_p45 = "eth80-p-045"
eth80_p68 = "eth80-p-068"
eth80_p90 = "eth80-p-090"
eth_list = [eth80_p00, eth80_p22, eth80_p45, eth80_p68, eth80_p90]

mnist_image_root = os.path.join('dataset', 'mnist')
mnist_m_image_root = os.path.join('dataset', 'mnist_m')
synth_image_root = os.path.join('dataset', 'SynthDigits')
usps_image_root = os.path.join('dataset', 'usps')
gtsrb_image_root = os.path.join('dataset', gtsrb, "signs_")
synth_signs_image_root = os.path.join('dataset', synth_signs, "synth_signs_")

office_list = [amazon, webcam, dslr]
dataset_list = [mnist, mnist_m, svhn, synth, usps, synth_signs, gtsrb] + office_list + eth_list

dataset_std = {mnist: (0.30280363, 0.30280363, 0.30280363),
               mnist_m: (0.2384788, 0.22375608, 0.24496263),
               svhn: (0.1951134, 0.19804622, 0.19481073),
               synth: (0.29410212, 0.2939651, 0.29404707),
               usps: (0.25887518, 0.25887518, 0.25887518),
               synth_signs: (0.28315591, 0.2789663, 0.30152685),
               gtsrb: (0.26750497, 0.2494335, 0.25966593),
               eth80_p00: (0.21904688, 0.1887137, 0.18277248),
               eth80_p22: (0.23107395, 0.19354062, 0.16959815),
               eth80_p45: (0.2437287, 0.20425622, 0.17508801),
               eth80_p68: (0.24174337, 0.20555658, 0.17148162),
               eth80_p90: (0.21427298, 0.18139302, 0.173147)
               }

dataset_mean = {mnist: (0.13909429, 0.13909429, 0.13909429),
                mnist_m: (0.45920207, 0.46326601, 0.41085603),
                svhn: (0.43744073, 0.4437959, 0.4733686),
                synth: (0.46332872, 0.46316052, 0.46327512),
                usps: (0.17025368, 0.17025368, 0.17025368),
                synth_signs: (0.43471373, 0.40261434, 0.43641199),
                gtsrb: (0.36089954, 0.31854348, 0.33791215),
                eth80_p00: (0.45713681, 0.4620426, 0.58741586),
                eth80_p22: (0.45154243, 0.45350623, 0.58116992),
                eth80_p45: (0.4453549, 0.44773841, 0.58268696),
                eth80_p68: (0.4284884, 0.43210022, 0.5712227),
                eth80_p90: (0.4330995, 0.44464724, 0.58874557)
                }


def get_images_for_conversion(folder_path, image_size=228):
    img_transform = get_transform(image_size, "test", None)
    return ImageFolderWithPath(folder_path, transform=img_transform)


def get_dataset(dataset_name, image_size, mode="train", limit=None):
    img_transform = get_transform(image_size, mode, dataset_name)
    if dataset_name == mnist:
        dataset = datasets.MNIST(
            root=mnist_image_root,
            train=True,
            transform=img_transform, download=True
        )
    elif dataset_name == svhn:
        dataset = datasets.SVHN(
            root=os.path.join('dataset', 'svhn'),
            transform=img_transform, download=True
        )
    elif dataset_name == mnist_m:
        train_list = os.path.join(mnist_m_image_root, 'mnist_m_train_labels.txt')
        dataset = GetLoader(
            data_root=os.path.join(mnist_m_image_root, 'mnist_m_train'),
            data_list=train_list,
            transform=img_transform
        )
    elif dataset_name == synth:
        train_mat = os.path.join(synth_image_root, 'synth_train_32x32.mat')
        dataset = GetSynthDigits(
            data_root=synth_image_root,
            data_mat=train_mat,
            transform=img_transform
        )
    elif dataset_name == usps:
        data_file = "usps_28x28.pkl"
        dataset = GetUSPS(
            data_root=usps_image_root,
            data_file=data_file,
            transform=img_transform
        )
    elif dataset_name == gtsrb:
        dataset = GetNumpyDataset(gtsrb_image_root, mode, img_transform)
    elif dataset_name == synth_signs:
        dataset = GetNumpyDataset(synth_signs_image_root, mode, img_transform)
    elif dataset_name == amazon:
        dataset = datasets.ImageFolder('dataset/amazon', transform=img_transform)
    elif dataset_name == dslr:
        dataset = datasets.ImageFolder('dataset/dslr', transform=img_transform)
    elif dataset_name == webcam:
        dataset = datasets.ImageFolder('dataset/webcam', transform=img_transform)
    elif dataset_name in eth_list:
        dataset = datasets.ImageFolder('dataset/' + dataset_name, transform=img_transform)
    elif type(dataset_name) is list:
        return ConcatDataset([get_dataset(dset, image_size, mode, limit) for dset in dataset_name])
    else:
        raise ValueError("Unknown dataset %s" % dataset_name)
    if limit:
        indices = index_cache.get((dataset_name, limit), None)
        if indices is None:
            indices = torch.randperm(len(dataset))[:limit]
            index_cache[(dataset_name, limit)] = indices
        dataset = Subset(dataset, indices)
    return RgbWrapper(dataset)
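
# When dataset_name is a list (the multi-source case), get_dataset returns a
# ConcatDataset over the individual datasets, so each item drawn from the
# loader is a tuple with one (image, label) pair per source domain;
# train_epoch in train/utils.py relies on this when it loops over
# data_sources_batch.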
def get_transform(image_size, mode, name):
    if isinstance(name, list):
        return None
    # TODO use dataset specific mean and std
    if mode == "train":
        img_transform = transforms.Compose([
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.5, hue=0.5),
            transforms.RandomAffine(15, shear=15),
            transforms.RandomResizedCrop(image_size, scale=(0.9, 1.0)),
            transforms.ToTensor(),
            transforms.Normalize(mean=dataset_mean[name], std=dataset_std[name])
        ])
    elif mode == "office":
        img_transform = transforms.Compose([
            transforms.RandomResizedCrop(image_size),
            transforms.RandomHorizontalFlip(),
            # transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
    elif mode == "simple":
        img_transform = transforms.Compose([
            transforms.RandomResizedCrop(image_size, scale=(0.9, 1.0)),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
        ])
    elif mode == simple_tuned:
        img_transform = transforms.Compose([
            transforms.RandomResizedCrop(image_size, scale=(0.9, 1.0)),
            transforms.ToTensor(),
            transforms.Normalize(mean=dataset_mean[name], std=dataset_std[name])
        ])
    elif mode == "simple-no-norm":
        img_transform = transforms.Compose([
            transforms.RandomResizedCrop(image_size, scale=(0.9, 1.0)),
            transforms.ToTensor()
        ])
    elif mode == "test":
        if name in office_list:
            mean = [0.485, 0.456, 0.406]
            std = [0.229, 0.224, 0.225]
        else:
            mean = [0.5, 0.5, 0.5]
            std = [0.5, 0.5, 0.5]
        img_transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std)
        ])
    elif mode == "test-tuned":
        img_transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=dataset_mean[name], std=dataset_std[name])
        ])
    else:
        raise ValueError("Unknown data augmentation mode %s" % mode)
    return img_transform
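
# How the modes pair up at test time: "simple" corresponds to mode "test"
# (fixed mean/std of 0.5 for non-office datasets), while "simple-tuned" (the
# simple_tuned constant) corresponds to "test-tuned", which normalizes with
# the per-dataset statistics above; compute_dataset_statistics.py can be used
# to regenerate those values.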
index_cache = {}


class GetLoader(data.Dataset):
    def __init__(self, data_root, data_list, transform=None):
        self.root = data_root
        self.transform = transform

        f = open(data_list, 'r')
        data_list = f.readlines()
        f.close()

        self.n_data = len(data_list)

        self.img_paths = []
        self.img_labels = []

        # Each line has the form "<image path> <single-digit label>\n", so the
        # path is everything except the last three characters.
        for data in data_list:
            self.img_paths.append(data[:-3])
            self.img_labels.append(data[-2])

    def __getitem__(self, item):
        img_paths, labels = self.img_paths[item], self.img_labels[item]
        imgs = Image.open(os.path.join(self.root, img_paths)).convert('RGB')

        if self.transform is not None:
            imgs = self.transform(imgs)
            labels = int(labels)

        return imgs, labels

    def __len__(self):
        return self.n_data


class GetSynthDigits(data.Dataset):
    def __init__(self, data_root, data_mat, transform=None):
        self.root = data_root
        self.transform = transform

        import scipy.io as sio

        # reading(loading) mat file as array
        loaded_mat = sio.loadmat(data_mat)
        self.data = loaded_mat['X']
        self.data = np.transpose(self.data, (3, 2, 0, 1))
        self.labels = loaded_mat['y'].astype(np.int64).squeeze()

    def __getitem__(self, item):
        img, labels = self.data[item], self.labels[item]
        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(np.transpose(img, (1, 2, 0)))
        if self.transform is not None:
            img = self.transform(img)
            labels = int(labels)

        return img, labels

    def __len__(self):
        return len(self.data)


class GetNumpyDataset(data.Dataset):
    def __init__(self, data_root, mode, transform=None):
        self.root = data_root
        self.transform = transform

        test = np.load(self.root + "test_data.npy")
        l_test = np.load(self.root + "test_labels.npy")
        if mode in ["test", "test-tuned"]:
            self.data = test
            self.labels = l_test
        else:
            train = np.load(self.root + "train_data.npy")
            l_train = np.load(self.root + "train_labels.npy")
            self.data = np.vstack((train, test))
            self.labels = np.hstack((l_train, l_test))

        self.labels = self.labels.astype(np.int64).squeeze()

    def __getitem__(self, item):
        img, labels = self.data[item], self.labels[item]
        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(np.transpose(img, (1, 2, 0)))
        if self.transform is not None:
            img = self.transform(img)
            labels = int(labels)

        return img, labels

    def __len__(self):
        return len(self.data)
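
# Note: GetNumpyDataset builds file names by plain string concatenation, e.g.
# "dataset/gtsrb/signs_" + "test_data.npy" -> "dataset/gtsrb/signs_test_data.npy",
# which is why gtsrb_image_root and synth_signs_image_root end with a file
# prefix rather than a directory separator.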
class GetUSPS(data.Dataset):
    def __init__(self, data_root, data_file, transform=None):
        self.root = data_root
        self.filename = data_file
        # Num of Train = 7438, Num of Test = 1860
        self.transform = transform
        self.dataset_size = None
        self.data, self.labels = self.load_samples()

        total_num_samples = self.labels.shape[0]
        indices = np.arange(total_num_samples)
        np.random.shuffle(indices)
        self.data = self.data[indices[0:self.dataset_size], ::]
        self.labels = self.labels[indices[0:self.dataset_size]].astype(np.int64).squeeze()
        self.data *= 255.0
        self.data = np.repeat(self.data.astype("uint8"), 3, axis=1)
        self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC

    def __getitem__(self, index):
        img, labels = self.data[index, ::], self.labels[index]
        img = Image.fromarray(img)
        if self.transform is not None:
            img = self.transform(img)
        labels = int(labels)
        # label = torch.FloatTensor([label.item()])
        return img, labels

    def load_samples(self):
        """Load sample images from dataset."""
        filename = os.path.join(self.root, self.filename)
        f = gzip.open(filename, "rb")
        data_set = pickle.load(f, encoding="bytes")
        f.close()
        images_train = data_set[0][0]
        images_test = data_set[1][0]
        images = np.concatenate((images_train, images_test), axis=0)
        labels_train = data_set[0][1]
        labels_test = data_set[1][1]
        labels = np.concatenate((labels_train, labels_test), axis=0)
        self.dataset_size = labels.shape[0]
        return images, labels

    def __len__(self):
        """Return size of dataset."""
        return self.dataset_size
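
# Pipeline recap for the USPS pickle above: images arrive as floats in [0, 1]
# (single channel), are scaled to [0, 255] and cast to uint8, repeated to 3
# channels along axis 1, then transposed to HWC so Image.fromarray can build
# a PIL image for the torchvision transforms.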
"simple-tuned" 18 | 19 | 20 | def get_args(): 21 | parser = argparse.ArgumentParser() 22 | # optimizer 23 | parser.add_argument('--optimizer', choices=optimizer_list, default=Optimizers.adam.value) 24 | parser.add_argument('--lr', type=float, default=1e-3) 25 | parser.add_argument('--batch_size', default=128, type=int) 26 | parser.add_argument('--epochs', default=100, type=int) 27 | parser.add_argument('--keep_pretrained_fixed', action="store_true") 28 | # data 29 | parser.add_argument('--image_size', type=int, default=28) 30 | parser.add_argument('--data_aug_mode', default="train", choices=["train", "simple", simple_tuned, "office"]) 31 | parser.add_argument('--source', default=[data_loader.mnist], choices=data_loader.dataset_list, nargs='+') 32 | parser.add_argument('--target', default=data_loader.mnist_m, choices=data_loader.dataset_list) 33 | parser.add_argument('--n_classes', default=10, type=int) 34 | parser.add_argument('--target_limit', type=int, default=None, help="Number of max samples in target") 35 | parser.add_argument('--source_limit', type=int, default=None, help="Number of max samples in each source") 36 | # losses 37 | parser.add_argument('--DANN_weight', default=1.0, type=float) 38 | parser.add_argument('--entropy_loss_weight', default=0.0, type=float, help="Entropy loss on target, default is 0") 39 | # deco 40 | parser.add_argument('--use_deco', action="store_true", help="If true use deco architecture") 41 | parser.add_argument('--train_deco_weight', default=True, type=bool, help="Train the deco weight (True by default)") 42 | parser.add_argument('--train_image_weight', default=False, type=bool, 43 | help="Train the image weight (False by default)") 44 | parser.add_argument('--deco_no_residual', action="store_true", help="If set, no residual will be applied to DECO") 45 | parser.add_argument('--deco_blocks', default=4, type=int) 46 | parser.add_argument('--deco_kernels', default=64, type=int) 47 | parser.add_argument('--deco_block_type', default='basic', choices=deco_types.keys(), 48 | help="Which kind of deco block to use") 49 | parser.add_argument('--deco_output_channels', type=int, default=3, help="3 or 1") 50 | parser.add_argument('--deco_mode', default="shared", choices=deco_modes.keys()) 51 | parser.add_argument('--deco_tanh', action="store_true", help="If set, tanh will be applied to DECO output") 52 | parser.add_argument('--deco_pretrain', default=0, type=int, help="Number of epoch to pretrain DECO (default is 0)") 53 | parser.add_argument('--deco_no_pool', action="store_true") 54 | parser.add_argument('--deco_deconv', action="store_true") 55 | # misc 56 | parser.add_argument('--suffix', help="Will be added to end of name", default="") 57 | parser.add_argument('--classifier', default=None, choices=classifier_list.keys()) 58 | parser.add_argument('--tmp_log', action="store_true", help="If set, logger will save to /tmp instead") 59 | parser.add_argument('--generalization', action="store_true", 60 | help="If set, the target will not be used during training") 61 | args = parser.parse_args() 62 | args.source = sorted(args.source) 63 | return args 64 | 65 | 66 | def get_name(args, seed): 67 | name = "%s_lr:%g_BS:%d_eps:%d_IS:%d_DW:%g_DA%s" % (args.optimizer, args.lr, args.batch_size, args.epochs, 68 | args.image_size, args.DANN_weight, args.data_aug_mode) 69 | if args.source_limit: 70 | name += "_sL%d" % args.source_limit 71 | if args.target_limit: 72 | name += "_tL%d" % args.target_limit 73 | if args.keep_pretrained_fixed: 74 | name += "_freezeNet" 75 | if 
--------------------------------------------------------------------------------
/train/utils.py:
--------------------------------------------------------------------------------
import argparse

import numpy as np
import os

import torch
import torch.nn.functional as F
import torch.nn as nn
from torchvision.utils import make_grid

from dataset import data_loader
from models.model import entropy_loss, Combo, deco_types, classifier_list, deco_modes
from train.optim import optimizer_list, Optimizers, get_optimizer_and_scheduler
import itertools
import random

simple_tuned = "simple-tuned"


def get_args():
    parser = argparse.ArgumentParser()
    # optimizer
    parser.add_argument('--optimizer', choices=optimizer_list, default=Optimizers.adam.value)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--batch_size', default=128, type=int)
    parser.add_argument('--epochs', default=100, type=int)
    parser.add_argument('--keep_pretrained_fixed', action="store_true")
    # data
    parser.add_argument('--image_size', type=int, default=28)
    parser.add_argument('--data_aug_mode', default="train", choices=["train", "simple", simple_tuned, "office"])
    parser.add_argument('--source', default=[data_loader.mnist], choices=data_loader.dataset_list, nargs='+')
    parser.add_argument('--target', default=data_loader.mnist_m, choices=data_loader.dataset_list)
    parser.add_argument('--n_classes', default=10, type=int)
    parser.add_argument('--target_limit', type=int, default=None, help="Number of max samples in target")
    parser.add_argument('--source_limit', type=int, default=None, help="Number of max samples in each source")
    # losses
    parser.add_argument('--DANN_weight', default=1.0, type=float)
    parser.add_argument('--entropy_loss_weight', default=0.0, type=float, help="Entropy loss on target, default is 0")
    # deco
    parser.add_argument('--use_deco', action="store_true", help="If true use deco architecture")
    parser.add_argument('--train_deco_weight', default=True, type=bool, help="Train the deco weight (True by default)")
    parser.add_argument('--train_image_weight', default=False, type=bool,
                        help="Train the image weight (False by default)")
    parser.add_argument('--deco_no_residual', action="store_true", help="If set, no residual will be applied to DECO")
    parser.add_argument('--deco_blocks', default=4, type=int)
    parser.add_argument('--deco_kernels', default=64, type=int)
    parser.add_argument('--deco_block_type', default='basic', choices=deco_types.keys(),
                        help="Which kind of deco block to use")
    parser.add_argument('--deco_output_channels', type=int, default=3, help="3 or 1")
    parser.add_argument('--deco_mode', default="shared", choices=deco_modes.keys())
    parser.add_argument('--deco_tanh', action="store_true", help="If set, tanh will be applied to DECO output")
    parser.add_argument('--deco_pretrain', default=0, type=int, help="Number of epoch to pretrain DECO (default is 0)")
    parser.add_argument('--deco_no_pool', action="store_true")
    parser.add_argument('--deco_deconv', action="store_true")
    # misc
    parser.add_argument('--suffix', help="Will be added to end of name", default="")
    parser.add_argument('--classifier', default=None, choices=classifier_list.keys())
    parser.add_argument('--tmp_log', action="store_true", help="If set, logger will save to /tmp instead")
    parser.add_argument('--generalization', action="store_true",
                        help="If set, the target will not be used during training")
    args = parser.parse_args()
    args.source = sorted(args.source)
    return args
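
# Caveat on the type=bool arguments above (--train_deco_weight and
# --train_image_weight): argparse applies bool() to the raw string, and
# bool("False") is True, so any non-empty value enables them; only the
# defaults behave as documented. A store_true/store_false pair (or an
# explicit str-to-bool converter) would make them properly settable from the
# command line.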
174 |                     logger.image_summary("reconstruction/%s/target" % mode, to_grid(target_images), epoch)
175 |
176 |         print("%d/%d - Reconstruction loss source: %g, target %g" % (epoch, num_epochs, source_loss, target_loss))
177 |         logger.scalar_summary("reconstruction/%s/source" % mode, source_loss, epoch)
178 |         logger.scalar_summary("reconstruction/%s/target" % mode, target_loss, epoch)
179 |
180 |
181 | def softmax_list(source_target_similarity):  # despite the name, this is plain sum-normalization, not a softmax
182 |     total_sum = sum(source_target_similarity)
183 |     if total_sum == 0:
184 |         total_sum = 1
185 |     return [v / total_sum for v in source_target_similarity]
186 |
187 |
188 | def train_epoch(epoch, dataloader_source, dataloader_target, optimizer, model, logger, n_epoch, cuda,
189 |                 dann_weight, entropy_weight, scheduler, generalize, weight_sources=False):
190 |     model.train()
191 |     len_dataloader = min(len(dataloader_source), len(dataloader_target))
192 |     data_sources_iter = iter(dataloader_source)
193 |     data_target_iter = iter(dataloader_target)
194 |
195 |     batch_idx = 0
196 |     domain_error = 0
197 |     # TODO count epochs on source
198 |     past_source_target_similarity = None
199 |     weight_sources = True  # NOTE: forced on here, overriding the weight_sources argument
200 |     if generalize:
201 |         weight_sources = False
202 |     while batch_idx < len_dataloader:
203 |         try:
204 |             scheduler.step_iter()
205 |         except AttributeError:
206 |             pass
207 |         absolute_iter_count = batch_idx + epoch * len_dataloader
208 |         p = float(absolute_iter_count) / n_epoch / len_dataloader
209 |         lambda_val = 2. / (1. + np.exp(-10 * p)) - 1
210 |         if domain_error > 3.0:
211 |             print("Shutting down DANN gradient to avoid collapse (iter %d)" % absolute_iter_count)
212 |             lambda_val = 0.0
213 |         data_sources_batch = next(data_sources_iter)
214 |         # process source datasets (can be multiple)
215 |         err_s_label = 0.0
216 |         err_s_domain = 0.0
217 |         num_source_domains = len(data_sources_batch)
218 |         model.set_deco_mode("source")
219 |         source_domain_losses = []
220 |         observed_domain_losses = []
221 |         source_target_similarity = []
222 |         for v, source_data in enumerate(data_sources_batch):
223 |             s_img, s_label = source_data
224 |             class_loss, domain_loss, observation_loss, target_similarity = compute_batch_loss(cuda, lambda_val, model,
225 |                                                                                               s_img, s_label, v, num_source_domains)
226 |             if weight_sources and past_source_target_similarity is not None:
227 |                 # reweight each source by its estimated similarity to the target from the previous iteration
228 |                 class_loss = class_loss * (len(data_sources_batch) * float(past_source_target_similarity[v]))
229 |             loss = class_loss + dann_weight * domain_loss + observation_loss
230 |             loss.backward()
231 |             # used for logging only
232 |             err_s_label += class_loss.item()
233 |             source_domain_losses.append(domain_loss.data.cpu().numpy())
234 |             observed_domain_losses.append(observation_loss.data.cpu().numpy())
235 |             source_target_similarity.append(target_similarity)
236 |             err_s_domain += domain_loss.data.cpu().numpy()
237 |         past_source_target_similarity = softmax_list(source_target_similarity)
238 |
239 |         err_s_label = err_s_label / num_source_domains
240 |         err_s_domain = err_s_domain / num_source_domains
241 |
242 |         entropy_target = 0
243 |         err_t_domain = 0
244 |         domain_error = err_s_domain
245 |         if not generalize:
246 |             # training model using target data
247 |             model.set_deco_mode("target")
248 |             t_img, _ = next(data_target_iter)
249 |             entropy_target, target_domain_loss, observation_loss, _ = compute_batch_loss(cuda, lambda_val, model, t_img, None,
250 |                                                                                          num_source_domains, num_source_domains)
251 |             loss = entropy_weight * entropy_target * lambda_val + dann_weight * target_domain_loss + observation_loss
252 |             loss.backward()
253 |             err_t_domain = target_domain_loss.data.cpu().numpy()
254 |             observed_domain_losses.append(observation_loss.data.cpu().numpy())
255 |             domain_error = (err_s_domain + err_t_domain) / 2
256 |
257 |         # err = dann_weight * err_t_domain + dann_weight * err_s_domain + err_s_label + entropy_weight * entropy_target * lambda_val
258 |         optimizer.step()
259 |         optimizer.zero_grad()
260 |
261 |         # log twice per epoch
262 |         if batch_idx % (len_dataloader // 2 + 1) == 0:
263 |             logger.scalar_summary("loss/source", err_s_label, absolute_iter_count)
264 |             logger.scalar_summary("loss/domain", domain_error, absolute_iter_count)
265 |             logger.scalar_summary("loss/observer_domain", sum(observed_domain_losses) / len(observed_domain_losses), absolute_iter_count)
266 |             for k, val in enumerate(source_domain_losses):
267 |                 logger.scalar_summary("loss/domain_s%d" % k, val, absolute_iter_count)
268 |             for k, val in enumerate(past_source_target_similarity):
269 |                 logger.scalar_summary("similarity/prob/%d" % k, val, absolute_iter_count)
270 |             logger.scalar_summary("loss/entropy_target", entropy_target, absolute_iter_count)
271 |             print('epoch: %d, [iter: %d / all %d], err_s_label: %f, err_s_domain: %f, err_t_domain: %f'
272 |                   % (epoch, batch_idx, len_dataloader, err_s_label, err_s_domain, err_t_domain))
273 |         batch_idx += 1
274 |
275 |     # at the end of one training epoch, save statistics and images
276 |     if isinstance(model, Combo):
277 |         with torch.no_grad():
278 |             # get target images
279 |             target_images = random_items(iter(dataloader_target))[0][0]
280 |             target_images = target_images[:9].cuda()
281 |             model.set_deco_mode("target")
282 |             logger.image_summary("original/target", to_grid(target_images), epoch)
283 |             target_images = model.deco(target_images)
284 |             logger.image_summary("images/target", to_grid(target_images), epoch)
285 |             sources = random_items(iter(dataloader_source))[0]
286 |             model.set_deco_mode("source")
287 |             for n, (s_img, _) in enumerate(sources):
288 |                 source_images = s_img[:9].cuda()
289 |                 logger.image_summary("original/source_%d" % n, to_grid(source_images), epoch)
290 |                 source_images = model.deco(source_images)
291 |                 logger.image_summary("images/source_%d" % n, to_grid(source_images), epoch)
292 |         for name, deco in model.get_decos():
293 |             logger.scalar_summary("aux/%s/deco_to_image_ratio" % name, deco.ratio.item(), epoch)
294 |             logger.scalar_summary("aux/%s/deco_weight" % name, deco.deco_weight.item(), epoch)
295 |             logger.scalar_summary("aux/%s/image_weight" % name, deco.image_weight.item(), epoch)
296 |
297 |     logger.scalar_summary("aux/p", p, epoch)
298 |     logger.scalar_summary("aux/lambda", lambda_val, epoch)
299 |
300 |
301 | # reservoir sampling, from https://code.activestate.com/recipes/426332-picking-random-items-from-an-iterator/
302 | def random_items(iterable, k=1):
303 |     result = [None] * k
304 |     for i, item in enumerate(iterable):
305 |         if i < k:
306 |             result[i] = item
307 |         else:
308 |             j = int(random.random() * (i + 1))
309 |             if j < k:
310 |                 result[j] = item
311 |     random.shuffle(result)
312 |     return result
313 |
314 |
315 | def compute_batch_loss(cuda, lambda_val, model, img, label, _domain_label, target_label):
316 |     # one domain label per sample: v for source domain v, num_source_domains for the target
317 |     domain_label = torch.ones(img.shape[0]).long() * _domain_label
318 |     if cuda:
319 |         img = img.cuda()
320 |         if label is not None: label = label.cuda()
321 |         domain_label = domain_label.cuda()
322 |     class_output, domain_output, observer_output = model(input_data=img, lambda_val=lambda_val, domain=_domain_label)
323 |     # compute losses
324 |     if label is not None:
325 |         class_loss = F.cross_entropy(class_output, label)
326 |     else:
327 |         class_loss = entropy_loss(class_output)
328 |
329 |     domain_loss = F.cross_entropy(domain_output, domain_label)
330 |     observer_loss = F.cross_entropy(observer_output, domain_label)
331 |     if target_label < observer_output.shape[1]:
332 |         target_similarity = (F.softmax(observer_output, 1)[:, target_label].mean()).data.cpu().numpy()
333 |     else:  # generalization
334 |         target_similarity = 0.0
335 |     return class_loss, domain_loss, observer_loss, target_similarity
336 |
--------------------------------------------------------------------------------
/models/model.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import itertools
4 | import torch
5 | import torch.nn as nn
6 |
7 | from torch.autograd import Function
8 | from torch.nn import Parameter
9 | from torchvision.models.resnet import BasicBlock, Bottleneck, conv3x3
10 | import torch.nn.functional as func
11 |
12 | from models.torch_future import Flatten
13 |
14 | image_weight = 1.0
15 | deco_starting_weight = 0.0001
16 |
17 |
18 | class DecoArgs:
19 |     def __init__(self, args):
20 |         self.n_layers = args.deco_blocks
21 |         self.train_deco_weight = args.train_deco_weight
22 |         self.train_image_weight = args.train_image_weight
23 |         self.deco_kernels = args.deco_kernels
24 |         self.block = deco_types[args.deco_block_type]
25 |         self.output_channels = args.deco_output_channels
26 |         self.deco_weight = deco_starting_weight
27 |         self.no_residual = args.deco_no_residual
28 |         self.mode = args.deco_mode
29 |         self.use_tanh = args.deco_tanh
30 |         self.no_pool = args.deco_no_pool
31 |         self.deconv = args.deco_deconv
32 |
33 |
34 | def get_classifier(name, domain_classes, n_classes, generalization):
35 |     if name:
36 |         return classifier_list[name](domain_classes, n_classes, generalization)
37 |     return CNNModel(domain_classes, n_classes)
38 |
39 |
40 | def get_net(args):
41 |     domain_classes = args.domain_classes
42 |     if args.use_deco:
43 |         deco_args = DecoArgs(args)
44 |         my_net = deco_modes[deco_args.mode](deco_args, classifier=args.classifier, domain_classes=domain_classes,
45 |                                             n_classes=args.n_classes, generalization=args.generalization)
46 |     else:
47 |         my_net = get_classifier(args.classifier, domain_classes=domain_classes, n_classes=args.n_classes, generalization=args.generalization)
48 |
49 |     for p in my_net.parameters():
50 |         p.requires_grad = True
51 |     print(my_net)
52 |     return my_net
53 |
54 |
55 | def entropy_loss(x):
56 |     return torch.sum(-func.softmax(x, 1) * func.log_softmax(x, 1), 1).mean()
57 |
58 |
59 | deco_types = {'basic': BasicBlock, 'bottleneck': Bottleneck}
60 |
61 |
62 | # Utility class for the combo network: identity mapping
63 | class PassData(nn.Module):
64 |     def forward(self, input_data):
65 |         return input_data
66 |
67 |
68 | class GradientKillerLayer(Function):  # identity in the forward pass, blocks all gradients in the backward pass
69 |     @staticmethod
70 |     def forward(ctx, x):
71 |         return x.view_as(x)
72 |
73 |     @staticmethod
74 |     def backward(ctx, grad_output):
75 |         return None  # a single gradient slot, matching the single input of forward
76 |
77 |
78 | class ReverseLayerF(Function):
79 |     @staticmethod
80 |     def forward(ctx, x, lambda_val):
81 |         ctx.lambda_val = lambda_val
82 |
83 |         return x.view_as(x)
84 |
85 |     @staticmethod
86 |     def backward(ctx, grad_output):
87 |         output = grad_output.neg() * ctx.lambda_val
88 |
89 |         return output, None
90 |
91 |
92 | class Combo(nn.Module):
93 |     def __init__(self, deco_args, classifier, domain_classes=2, n_classes=10, generalization=False):
94 |         super(Combo, self).__init__()
95 |         self.net = get_classifier(classifier, domain_classes, n_classes, generalization)
96 |         from models.large_models import SmallAlexNet, BigDecoDANN, DECO
97 |         if isinstance(self.net, SmallAlexNet):
98 |             self.deco_architecture = DECO_mini
99 |         elif isinstance(self.net, BigDecoDANN):
100 |             self.deco_architecture = DECO
101 |         else:
102 |             self.deco_architecture = IncrementalDECO
103 |
104 |     def set_deco_mode(self, mode):
105 |         self.deco = self.domain_transforms[mode]
106 |
107 |     def forward(self, input_data, lambda_val, domain):
108 |         input_data = self.deco(input_data)
109 |         return self.net(input_data, lambda_val, domain)
110 |
111 |     def get_trainable_params(self):
112 |         return itertools.chain(self.get_deco_parameters(), self.net.get_trainable_params())
113 |
114 |
115 | class NoDecoCombo(Combo):
116 |     def __init__(self, deco_args, classifier, domain_classes=2, n_classes=10, generalization=False):  # every combo accepts generalization, since get_net forwards it to all modes
117 |         super(NoDecoCombo, self).__init__(deco_args, classifier, domain_classes, n_classes, generalization)
118 |         self.deco = PassData()  # instantiate the module; assigning the bare class would break forward()
119 |
120 |     def set_deco_mode(self, mode):
121 |         pass
122 |
123 |     def get_trainable_params(self):
124 |         return self.net.get_trainable_params()
125 |
126 |
127 | class SourceOnlyCombo(Combo):
128 |     def __init__(self, deco_args, classifier, domain_classes=2, n_classes=10, generalization=False):
129 |         super(SourceOnlyCombo, self).__init__(deco_args, classifier, domain_classes, n_classes, generalization)
130 |         self.source = self.deco_architecture(deco_args)
131 |         self.target = PassData()
132 |         self.domain_transforms = {"source": self.source,
133 |                                   "target": self.target}
134 |         self.deco = self.domain_transforms["source"]
135 |
136 |     def get_decos(self, mode=None):
137 |         return [("source", self.source)]
138 |
139 |     def get_deco_parameters(self):
140 |         return self.source.parameters()
141 |
142 |
143 | class SharedCombo(Combo):
144 |     def __init__(self, deco_args, classifier, domain_classes=2, n_classes=10, generalization=False):
145 |         super(SharedCombo, self).__init__(deco_args, classifier, domain_classes, n_classes, generalization)
146 |         self._deconet = self.deco_architecture(deco_args)
147 |         self.domain_transforms = {"source": self._deconet,
148 |                                   "target": self._deconet}
149 |         self.deco = self.domain_transforms["source"]
150 |
151 |     def get_decos(self, mode=None):
152 |         return [("shared", self.deco)]
153 |
154 |     def get_deco_parameters(self):
155 |         return self.deco.parameters()
156 |
157 |
158 | class TargetOnlyCombo(Combo):
159 |     def __init__(self, deco_args, classifier, domain_classes=2, n_classes=10, generalization=False):
160 |         super(TargetOnlyCombo, self).__init__(deco_args, classifier, domain_classes, n_classes, generalization)
161 |         self.source = PassData()
162 |         self.target = self.deco_architecture(deco_args)
163 |         self.domain_transforms = {"source": self.source,
164 |                                   "target": self.target}
165 |         self.deco = self.domain_transforms["source"]  # start in source mode; set_deco_mode switches decos
166 |
167 |     def get_decos(self, mode=None):
168 |         return [("target", self.target)]
169 |
170 |     def get_deco_parameters(self):
171 |         return self.target.parameters()
172 |
173 |
174 | class BothCombo(Combo):
175 |     def __init__(self, deco_args, classifier, domain_classes=2, n_classes=10, generalization=False):
176 |         super(BothCombo, self).__init__(deco_args, classifier, domain_classes, n_classes, generalization)
177 |         self.source = self.deco_architecture(deco_args)
178 |         self.target = self.deco_architecture(deco_args)
179 |         self.domain_transforms = {"source": self.source,
180 |                                   "target": self.target}
181 |         self.deco = self.domain_transforms["source"]
182 |
183 |     def get_deco_parameters(self):
184 |         return itertools.chain(self.source.parameters(), self.target.parameters())
185 |
186 |     def get_decos(self, mode=None):  # always returns a list of (name, module) pairs
187 |         if mode:
188 |             return [(mode, self.domain_transforms[mode])]
189 |         return [("source", self.source), ("target", self.target)]
190 |
191 |
192 | class IncrementalBlock(nn.Module):
193 |     def __init__(self, inplanes, outplanes, stride=1, downsample=None):  # downsample is accepted but unused
194 |         super(IncrementalBlock, self).__init__()
195 |         interm_planes = outplanes - inplanes
196 |         self.conv1 = conv3x3(inplanes, interm_planes, stride)
197 |         self.bn1 = nn.BatchNorm2d(interm_planes)
198 |         self.relu1 = nn.ReLU(inplace=True)
199 |         self.relu2 = nn.ReLU(inplace=True)
200 |         self.conv2 = conv3x3(interm_planes, interm_planes)
201 |         self.bn2 = nn.BatchNorm2d(interm_planes)
202 |         self.stride = stride
203 |
204 |     def forward(self, x):
205 |         residual = x
206 |         out = self.conv1(x)
207 |         out = self.bn1(out)
208 |         out = self.relu1(out)
209 |         out = self.conv2(out)
210 |         out = self.bn2(out)
211 |         out = self.relu2(out)
212 |         return torch.cat((residual, out), 1)  # final channel dimension = outplanes
213 |
214 | class IncrementalDECO(nn.Module):
215 |     def __init__(self, deco_args):
216 |         super(IncrementalDECO, self).__init__()
217 |         self.ratio = 1.0
218 |         self.deco_weight = 1.0
219 |         self.image_weight = 1.0
220 |         self.final_kernels = deco_args.deco_kernels
221 |         self.deco_args = deco_args
222 |         self.block1 = IncrementalBlock(3, 8)
223 |         self.block2 = IncrementalBlock(8, 16)
224 |         self.block3 = IncrementalBlock(16, 24)
225 |         self.block4 = IncrementalBlock(24, 32)
226 |         self.block5 = IncrementalBlock(32, 48)
227 |         self.block6 = IncrementalBlock(48, 64)
228 |         self.block7 = IncrementalBlock(64, 128)
229 |         # self.block8 = conv3x3(128, 64)
230 |         # self.bn8 = nn.BatchNorm2d(64)
231 |         # self.block9 = conv3x3(64, 32)
232 |         # self.bn9 = nn.BatchNorm2d(32)
233 |         # self.block10 = conv3x3(32, 16)
234 |         # self.bn10 = nn.BatchNorm2d(16)
235 |         # self.block11 = conv3x3(16, 8)
236 |         # self.bn11 = nn.BatchNorm2d(8)
237 |         self.conv_out = nn.Conv2d(128, 3, 1)  # NOTE: init_weights() below is never invoked here
238 |
239 |     def init_weights(self):
240 |         for m in self.modules():
241 |             if isinstance(m, nn.Conv2d):
242 |                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
243 |                 m.weight.data.normal_(0, math.sqrt(2. / n))
244 |             elif isinstance(m, nn.BatchNorm2d):
245 |                 m.weight.data.fill_(1)
246 |                 m.bias.data.zero_()
247 |
248 |     def forward(self, input_data):
249 |         x = self.block1(input_data)
250 |         x = self.block2(x)
251 |         x = self.block3(x)
252 |         x = self.block4(x)
253 |         x = self.block5(x)
254 |         x = self.block6(x)
255 |         x = self.block7(x)
256 |         # x = self.block8(x)
257 |         # x = self.bn8(x)
258 |         # x = self.block9(F.relu(x, inplace=True))
259 |         # x = self.bn9(x)
260 |         # x = self.block10(F.relu(x, inplace=True))
261 |         # x = self.bn10(x)
262 |         # x = self.block11(F.relu(x, inplace=True))
263 |         # x = self.bn11(x)
264 |         x = self.conv_out(x)
265 |         self.ratio = input_data.norm() / x.norm()  # tracked for logging only, not used in the computation
266 |         self.deco_weight = self.ratio  # tracked for logging only
267 |         self.image_weight = self.ratio  # tracked for logging only
268 |         return x
269 |
270 |
271 | deco_modes = {"shared": SharedCombo,
272 |               "separated": BothCombo,
273 |               "source": SourceOnlyCombo,
274 |               "target": TargetOnlyCombo}
275 |
276 |
277 | class BasicDECO(nn.Module):
278 |     def __init__(self, deco_args):
279 |         super(BasicDECO, self).__init__()
280 |         self.inplanes = deco_args.deco_kernels
281 |         self.ratio = 1.0
282 |         self.deco_args = deco_args
283 |         if self.deco_args.no_residual:
284 |             deco_args.train_deco_weight = False
285 |             deco_args.train_image_weight = False
286 |         if self.deco_args.train_deco_weight:
287 |             self.deco_weight = Parameter(torch.FloatTensor(1), requires_grad=True)
288 |         else:
289 |             self.deco_weight = torch.FloatTensor(1).cuda()
290 |         if self.deco_args.train_image_weight:
291 |             self.image_weight = Parameter(torch.FloatTensor(1), requires_grad=True)
292 |         else:
293 |             self.image_weight = torch.FloatTensor(1).cuda()
294 |         self.deco_weight.data.fill_(deco_args.deco_weight)
295 |         self.image_weight.data.fill_(image_weight)
296 |         self.use_tanh = deco_args.use_tanh
297 |
298 |     def init_weights(self):
299 |         for m in self.modules():
300 |             if isinstance(m, nn.Conv2d):
301 |                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
302 |                 m.weight.data.normal_(0, math.sqrt(2. / n))
303 |             elif isinstance(m, nn.BatchNorm2d):
304 |                 m.weight.data.fill_(1)
305 |                 m.bias.data.zero_()
306 |
307 |     def weighted_sum(self, input_data, x):  # residual mix of the DECO output with the input image
308 |         if self.deco_args.no_residual:
309 |             self.ratio = input_data.norm() / x.norm()
310 |             if self.use_tanh:
311 |                 x = torch.tanh(x)
312 |             return x
313 |         x = self.deco_weight * x
314 |         input_data = self.image_weight * input_data
315 |         self.ratio = input_data.norm() / x.norm()
316 |         if self.use_tanh:
317 |             return torch.tanh(x + input_data)
318 |         else:
319 |             return x + input_data
320 |
321 |     def _make_layer(self, block, planes, blocks, stride=1):
322 |         downsample = None
323 |         if stride != 1 or self.inplanes != planes * block.expansion:
324 |             downsample = nn.Sequential(
325 |                 nn.Conv2d(self.inplanes, planes * block.expansion,
326 |                           kernel_size=1, stride=stride, bias=False),
327 |                 nn.BatchNorm2d(planes * block.expansion),
328 |             )
329 |
330 |         layers = [block(self.inplanes, planes, stride, downsample)]
331 |         self.inplanes = planes * block.expansion
332 |         for i in range(1, blocks):
333 |             layers.append(block(self.inplanes, planes))
334 |
335 |         return nn.Sequential(*layers)
336 |
337 |
338 | class DECO_mini(BasicDECO):
339 |     def __init__(self, deco_args):
340 |         super(DECO_mini, self).__init__(deco_args)
341 |         self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=5, stride=1, padding=2,
342 |                                bias=False)
343 |         self.bn1 = nn.BatchNorm2d(self.inplanes)
344 |         self.relu = nn.ReLU(inplace=True)
345 |         self.layer1 = self._make_layer(deco_args.block, self.inplanes, deco_args.n_layers)
346 |         self.conv_out = nn.Conv2d(self.inplanes, deco_args.output_channels, 1)
347 |         self.init_weights()
348 |
349 |     def forward(self, input_data):
350 |         x = self.conv1(input_data)
351 |         x = self.bn1(x)
352 |         x = self.relu(x)
353 |
354 |         x = self.layer1(x)
355 |         x = self.conv_out(x)
356 |
357 |         return self.weighted_sum(input_data, x)
358 |
359 |
360 | class Tiny_DECO(BasicDECO):
361 |     def __init__(self, deco_args):
362 |         super(Tiny_DECO, self).__init__(deco_args)
363 |         self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=5, stride=4, padding=2, bias=False)
364 |         self.bn1 = nn.BatchNorm2d(self.inplanes)
365 |         self.relu = nn.ReLU(inplace=True)
366 |         self.layer1 = self._make_layer(deco_args.block, self.inplanes, deco_args.n_layers)
367 |         self.conv_out = nn.Conv2d(self.inplanes, deco_args.output_channels, 1)  # _make_layer already folds block.expansion into self.inplanes
368 |         self.init_weights()
369 |
370 |     def forward(self, input_data):
371 |         x = self.conv1(input_data)
372 |         x = self.bn1(x)
373 |         x = self.relu(x)
374 |
375 |         x = self.layer1(x)
376 |         x = self.conv_out(x)
377 |
378 |         return self.weighted_sum(input_data, x)
379 |
380 |
381 | class BasicDANN(nn.Module):
382 |     def __init__(self):
383 |         super(BasicDANN, self).__init__()
384 |         self.features = None
385 |         self.domain_classifier = None
386 |         self.class_classifier = None
387 |         self.observer = PassData()
388 |
389 |     def forward(self, input_data, lambda_val, domain=None):
390 |         feature = self.features(input_data)
391 |
392 |         feature = feature.view(input_data.shape[0], -1)
393 |         reverse_feature = ReverseLayerF.apply(feature, lambda_val)
394 |         class_output = self.class_classifier(feature)
395 |         domain_output = self.domain_classifier(reverse_feature)
396 |         observation = self.observer(GradientKillerLayer.apply(feature))
397 |         return class_output, domain_output, observation
398 |
399 |     def get_trainable_params(self):
400 |         return self.parameters()
401 |
402 |     # TODO: after refactoring, remove this
403 |     def set_deco_mode(self, mode):
404 |         pass
405 |
406 |
407 | class MnistModel(BasicDANN):
408 |     def __init__(self, domain_classes, n_classes, generalization=False):  # generalization is unused; accepted for a uniform constructor signature
409 |         super(MnistModel, self).__init__()
410 |         print("Using LeNet")
411 |         self.features = nn.Sequential(
412 |             nn.Conv2d(3, 32, 5),
413 |             nn.ReLU(True),
414 |             nn.MaxPool2d(2, 2),
415 |             nn.Conv2d(32, 48, 5),
416 |             nn.ReLU(True),
417 |             nn.MaxPool2d(2, 2)
418 |         )
419 |         self.domain_classifier = nn.Sequential(
420 |             nn.Linear(48 * 4 * 4, 100),
421 |             nn.ReLU(True),
422 |             nn.Linear(100, domain_classes)
423 |         )
424 |         self.class_classifier = nn.Sequential(
425 |             nn.Linear(48 * 4 * 4, 100),
426 |             nn.ReLU(True),
427 |             nn.Linear(100, 100),
428 |             nn.ReLU(True),
429 |             nn.Linear(100, n_classes)
430 |         )
431 |
432 |
433 | class SVHNModel(BasicDANN):
434 |     def __init__(self, domain_classes, n_classes, generalization):  # generalization is likewise unused here
435 |         super(SVHNModel, self).__init__()
436 |         print("Using SVHN architecture")
437 |         self.features = nn.Sequential(
438 |             nn.Conv2d(3, 64, 5, padding=2),
439 |             nn.Dropout(0.1, True),
440 |             nn.ReLU(True),
441 |             nn.MaxPool2d(3, 2, padding=1),
442 |             nn.Conv2d(64, 64, 5, padding=2),
443 |             nn.Dropout(0.25, True),
444 |             nn.ReLU(True),
445 |             nn.MaxPool2d(3, 2, padding=1),
446 |             nn.Conv2d(64, 128, 5, padding=2),
447 |             nn.Dropout(0.25, True),
448 |             nn.ReLU(True)
449 |         )
450 |         self.domain_classifier = nn.Sequential(
451 |             nn.Linear(128 * 8 * 8, 1024),
452 |             nn.Dropout(0.5, True),
453 |             nn.ReLU(True),
454 |             nn.Linear(1024, 1024),
455 |             nn.Dropout(0.5, True),
456 |             nn.ReLU(True),
457 |             nn.Linear(1024, domain_classes)
458 |         )
459 |         self.class_classifier = nn.Sequential(
460 |             nn.Linear(128 * 8 * 8, 3072),
461 |             nn.Dropout(0.5, True),
462 |             nn.ReLU(True),
463 |             nn.Linear(3072, 2048),
464 |             nn.Dropout(0.5, True),
465 |             nn.ReLU(True),
466 |             nn.Linear(2048, n_classes)
467 |         )
468 |
469 |
470 | class MultisourceModelWeighted(BasicDANN):
471 |     def __init__(self, domain_classes, n_classes, generalization):
472 |         super(MultisourceModelWeighted, self).__init__()
473 |         self.domains = domain_classes
474 |         if generalization:
475 |             self.domains -= 1  # exclude the target domain when it is never seen during training
476 |         self.generalization = generalization
477 |         self.n_classes = n_classes
478 |         self.features = nn.Sequential(
479 |             nn.Conv2d(3, 64, 3, padding=1),
480 |             nn.ReLU(True),
481 |             nn.MaxPool2d(2, 2),
482 |             nn.Conv2d(64, 128, 3, padding=2),
483 |             nn.ReLU(True),
484 |             nn.MaxPool2d(2, 2),
485 |             nn.Conv2d(128, 256, 3, padding=1),
486 |             nn.ReLU(True),
487 |             nn.MaxPool2d(2, 2)
488 |         )
489 |         self.class_classifier = nn.Sequential(
490 |             nn.Conv2d(256, 256, 3, padding=1),
491 |             nn.ReLU(True),
492 |             Flatten(),
493 |             nn.Linear(256 * 4 * 4, 2048),
494 |             nn.ReLU(True),
495 |             nn.Linear(2048, 1024),
496 |             nn.ReLU(True)
497 |         )
498 |         self.per_domain_classifier = nn.ModuleList([nn.Linear(1024, n_classes) for k in range(domain_classes - 1)])
499 |         self.domain_classifier = nn.Sequential(
500 |             Flatten(),
501 |             nn.Linear(256 * 4 * 4, 2048),
502 |             nn.ReLU(True),
503 |             nn.Linear(2048, 2048),
504 |             nn.ReLU(True),
505 |             nn.Linear(2048, self.domains)
506 |         )
507 |         self.observer = nn.Sequential(
508 |             Flatten(),
509 |             nn.Linear(256 * 4 * 4, 1024),
510 |             nn.ReLU(True),
511 |             nn.Linear(1024, 1024),
512 |             nn.ReLU(True),
513 |             nn.Linear(1024, self.domains)
514 |         )
515 |
516 |     def forward(self, input_data, lambda_val, domain):
517 |         feature = self.features(input_data)
518 |         reverse_feature = ReverseLayerF.apply(feature, lambda_val)
519 |         class_features = self.class_classifier(feature)
520 |         domain_output = self.domain_classifier(reverse_feature)
521 |         observation = self.observer(GradientKillerLayer.apply(feature))
522 |         if domain < len(self.per_domain_classifier):  # one of the source domains
523 |             class_output = self.per_domain_classifier[domain](class_features)
524 |         else:  # if target domain
525 |             class_output = torch.zeros(input_data.shape[0], self.n_classes).cuda()
526 |
527 |         if self.generalization:
528 |             softmax_obs = nn.functional.softmax(observation, 1)
529 |         else:
530 |             softmax_obs = nn.functional.softmax(observation[:, :-1], 1)
531 |         for k, predictor in enumerate(self.per_domain_classifier):
532 |             class_output = class_output + nn.functional.softmax(predictor(class_features), 1) * softmax_obs[:, k].mean().detach()
533 |         return class_output, domain_output, observation
534 |
535 |
536 | class MultisourceModel(BasicDANN):
537 |     def __init__(self, domain_classes, n_classes, generalization):
538 |         super(MultisourceModel, self).__init__()
539 |         self.domains = domain_classes
540 |         if generalization:
541 |             self.domains -= 1
542 |         self.features = nn.Sequential(
543 |             nn.Conv2d(3, 64, 3, padding=1),
544 |             nn.ReLU(True),
545 |             nn.MaxPool2d(2, 2),
546 |             nn.Conv2d(64, 128, 3, padding=2),
547 |             nn.ReLU(True),
548 |             nn.MaxPool2d(2, 2),
549 |             nn.Conv2d(128, 256, 3, padding=1),
550 |             nn.ReLU(True),
551 |             nn.MaxPool2d(2, 2)
552 |         )
553 |         self.class_classifier = nn.Sequential(
554 |             nn.Conv2d(256, 256, 3, padding=1),
555 |             nn.ReLU(True),
556 |             Flatten(),
557 |             nn.Linear(256 * 4 * 4, 2048),
558 |             nn.ReLU(True),
559 |             nn.Linear(2048, 1024),
560 |             nn.ReLU(True),
561 |             nn.Linear(1024, n_classes)
562 |         )
563 |         self.domain_classifier = nn.Sequential(
564 |             Flatten(),
565 |             nn.Linear(256 * 4 * 4, 2048),
566 |             nn.ReLU(True),
567 |             nn.Linear(2048, 2048),
568 |             nn.ReLU(True),
569 |             nn.Linear(2048, self.domains)
570 |         )
571 |         self.observer = nn.Sequential(
572 |             nn.Conv2d(3, 64, 3, padding=1),
573 |             nn.ReLU(True),
574 |             nn.MaxPool2d(2, 2),
575 |             nn.Conv2d(64, 64, 3, padding=2),
576 |             nn.ReLU(True),
577 |             nn.MaxPool2d(2, 2),
578 |             nn.Conv2d(64, 128, 3, padding=1),
579 |             nn.ReLU(True),
580 |             nn.MaxPool2d(2, 2),
581 |             Flatten(),
582 |             nn.Linear(128 * 4 * 4, 512),
583 |             nn.Dropout(),
584 |             nn.ReLU(True),
585 |             # nn.Linear(512, 512),
586 |             # nn.Dropout(),
587 |             # nn.ReLU(True),
588 |             nn.Linear(512, self.domains)
589 |         )
590 |
591 |     def forward(self, input_data, lambda_val, domain=None):
592 |         feature = self.features(input_data)
593 |         reverse_feature = ReverseLayerF.apply(feature, lambda_val)
594 |         class_output = self.class_classifier(feature)
595 |         domain_output = self.domain_classifier(reverse_feature)
596 |         # observation = self.observer(GradientKillerLayer.apply(input_data))
597 |         observation = self.observer(ReverseLayerF.apply(input_data, lambda_val / 10.0))
598 |         return class_output, domain_output, observation
599 |
600 |
601 | class CNNModel(BasicDANN):
602 |     def __init__(self, domain_classes, n_classes):
603 |         super(CNNModel, self).__init__()
604 |         self.feature = nn.Sequential()
605 |         self.feature.add_module('f_conv1', nn.Conv2d(3, 64, kernel_size=5))
606 |         self.feature.add_module('f_bn1', nn.BatchNorm2d(64))
607 |         self.feature.add_module('f_pool1', nn.MaxPool2d(2))
608 |         self.feature.add_module('f_relu1', nn.ReLU(True))
609 |         self.feature.add_module('f_conv2', nn.Conv2d(64, 50, kernel_size=5))
610 |         self.feature.add_module('f_bn2', nn.BatchNorm2d(50))
611 |         self.feature.add_module('f_drop1', nn.Dropout2d())
612 |         self.feature.add_module('f_pool2', nn.MaxPool2d(2))
613 |         self.feature.add_module('f_relu2', nn.ReLU(True))
614 |
615 |         self.class_classifier = nn.Sequential()
616 |         self.class_classifier.add_module('c_fc1', nn.Linear(50 * 4 * 4, 100))
617 |         self.class_classifier.add_module('c_bn1', nn.BatchNorm1d(100))  # 1d: these are (N, 100) activations from a Linear layer
618 |         self.class_classifier.add_module('c_relu1', nn.ReLU(True))
619 |         self.class_classifier.add_module('c_drop1', nn.Dropout())
620 |         self.class_classifier.add_module('c_fc2', nn.Linear(100, 100))
621 |         self.class_classifier.add_module('c_bn2', nn.BatchNorm1d(100))
622 |         self.class_classifier.add_module('c_relu2', nn.ReLU(True))
623 |         self.class_classifier.add_module('c_fc3', nn.Linear(100, n_classes))
624 |         # self.class_classifier.add_module('c_softmax', nn.LogSoftmax(1))
625 |
626 |         self.domain_classifier = nn.Sequential()
627 |         self.domain_classifier.add_module('d_fc1', nn.Linear(50 * 4 * 4, 100))
628 |         self.domain_classifier.add_module('d_bn1', nn.BatchNorm1d(100))
629 |         self.domain_classifier.add_module('d_relu1', nn.ReLU(True))
630 |         self.domain_classifier.add_module('d_fc2', nn.Linear(100, domain_classes))
631 |
632 |     def forward(self, input_data, lambda_val, domain=None):
633 |         feature = self.feature(input_data)
634 |         feature = feature.view(-1, 50 * 4 * 4)
635 |         reverse_feature = ReverseLayerF.apply(feature, lambda_val)
636 |         class_output = self.class_classifier(feature)
637 |         domain_output = self.domain_classifier(reverse_feature)
638 |         observation = self.observer(GradientKillerLayer.apply(feature))
639 |         return class_output, domain_output, observation  # three outputs, matching the BasicDANN interface
640 |
641 |
642 | from models.large_models import DECO, BigDecoDANN, ResNet50, AlexNet, SmallAlexNet, CaffeNet, AlexNetNoBottleneck
643 |
644 | classifier_list = {"roided_lenet": CNNModel,
645 |                    "mnist": MnistModel,
646 |                    "svhn": SVHNModel,
647 |                    "multi": MultisourceModel,
648 |                    "multi_weighted": MultisourceModelWeighted,
649 |                    "alexnet": AlexNet,
650 |                    "alexnet_no_bottleneck": AlexNetNoBottleneck,
651 |                    "caffenet": CaffeNet,
652 |                    "small_alexnet": SmallAlexNet,
653 |                    "resnet50": ResNet50}
654 |
--------------------------------------------------------------------------------
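For orientation, the sketch below shows how the pieces above could fit together into a training driver. It is hypothetical, not the repository's actual new_main.py entry point: get_dataloader's signature is inferred from compute_dataset_statistics.py, the train.optim import path and the Logger class name are assumptions, and args.domain_classes (read by get_net) is set here under the assumption of one domain label per source plus one for the target, which matches how compute_batch_loss assigns domain labels.

# Hypothetical driver, for orientation only -- new_main.py is the real entry point.
from dataset.data_loader import get_dataloader  # signature inferred from compute_dataset_statistics.py
from models.model import get_net
from new_main import get_args, get_name, do_pretraining, train_epoch
from train.optim import get_optimizer_and_scheduler, Optimizers  # assumed import path
from logger import Logger  # assumed class name

args = get_args()
# get_net reads args.domain_classes; assumed to be one label per source plus one for the
# target, matching compute_batch_loss (label v for source v, num_source_domains for target).
args.domain_classes = len(args.source) + 1

model = get_net(args).cuda()
optimizer, scheduler = get_optimizer_and_scheduler(args.optimizer, model, args.epochs, args.lr,
                                                   args.keep_pretrained_fixed)  # last argument assumed
# NOTE: the source loader must yield, at each step, a list with one batch per source dataset,
# since train_epoch iterates over data_sources_batch domain by domain.
source_loader = get_dataloader(args.source, args.batch_size, args.image_size, args.data_aug_mode, args.source_limit)
target_loader = get_dataloader(args.target, args.batch_size, args.image_size, args.data_aug_mode, args.target_limit)
logger = Logger("logs/" + get_name(args, seed=0))

if args.use_deco and args.deco_pretrain > 0:
    do_pretraining(args.deco_pretrain, source_loader, target_loader, model, logger)
for epoch in range(args.epochs):
    train_epoch(epoch, source_loader, target_loader, optimizer, model, logger, args.epochs,
                cuda=True, dann_weight=args.DANN_weight, entropy_weight=args.entropy_loss_weight,
                scheduler=scheduler, generalize=args.generalization)

With --generalization set (as in generalization_mnistM.sh and generalization_svhn.sh), train_epoch skips the target branch entirely, so the target loader is only used for the end-of-epoch image summaries.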