├── README.md ├── data_loader.py ├── dataset ├── mnist │ └── __init__.py └── mnist_m │ └── __init__.py ├── extra ├── mnist.jpg ├── mnist_m.jpg ├── mnist_m_rec_image_all.png ├── mnist_m_rec_image_private.png ├── mnist_m_rec_image_share.png ├── mnist_rec_image_all.png ├── mnist_rec_image_private.png ├── mnist_rec_image_share.png └── model.jpg ├── functions.py ├── model └── __init__.py ├── model_compat.py ├── test.py └── train.py /README.md: -------------------------------------------------------------------------------- 1 | ## This is a pytorch implementation of the model [Domain Separation Networks](https://arxiv.org/abs/1608.06019) 2 | 3 | ## Environment 4 | - Pytorch 0.4.0 5 | - Python 2.7 6 | 7 | ## Network Structure 8 | 9 | ![model](./extra/model.jpg) 10 | 11 | ## Usage 12 | 13 | `python train.py` 14 | 15 | **Note that this model is very sensitive to the loss weight, our implementation cannot perform as perfect as the 16 | original paper, so be careful when you tune parameters for other datasets. Moreover, this model may not be suitable 17 | for real nature image, cause the private and shared feature of nature image are more complicated, so that *difference 18 | loss* cannot adapt well** 19 | 20 | ## Result 21 | 22 | We only conduct the experiments from mnist to mnist_m, the target accuracy of our implementation is about 77% (original 23 | paper ~83%), and some results are shown as follows, from left to right: recovery from shared+private, shared and private 24 | features. 
import torch.utils.data as data
import os


class GetLoader(data.Dataset):
    """Dataset reading ``<image path> <label>`` pairs from a list file.

    Each non-empty line of ``data_list`` names one image file (relative to
    ``data_root``) followed by a single space and its integer class label,
    e.g. ``00000001.png 5``.

    Args:
        data_root (str): directory the image paths are relative to.
        data_list (str): path to the text file of "<path> <label>" lines.
        transform (callable, optional): applied to the PIL image.
    """

    def __init__(self, data_root, data_list, transform=None):
        self.root = data_root
        self.transform = transform

        self.img_paths = []
        self.img_labels = []

        # rsplit on the last space is robust against a missing trailing
        # newline on the final line and against multi-digit labels, both of
        # which the old fixed-offset slicing (line[:-3] / line[-2]) silently
        # mis-parsed.  The context manager guarantees the file is closed.
        with open(data_list, 'r') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue  # skip blank lines instead of mis-parsing them
                path, label = line.rsplit(' ', 1)
                self.img_paths.append(path)
                self.img_labels.append(label)

        self.n_data = len(self.img_paths)

    def __getitem__(self, item):
        """Return ``(image, label)``; label is always an int."""
        # Imported lazily so the dataset can be constructed and inspected
        # (len, paths) on machines without PIL installed.
        from PIL import Image

        img_path, label = self.img_paths[item], self.img_labels[item]
        img = Image.open(os.path.join(self.root, img_path)).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        # The original only cast the label when a transform was supplied,
        # returning a str otherwise; always cast for a consistent contract.
        return img, int(label)

    def __len__(self):
        return self.n_data
https://raw.githubusercontent.com/fungtion/DSN/63dcbc0440c7fb5f25633b6d1d35e5f9cb8772fd/extra/mnist.jpg -------------------------------------------------------------------------------- /extra/mnist_m.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fungtion/DSN/63dcbc0440c7fb5f25633b6d1d35e5f9cb8772fd/extra/mnist_m.jpg -------------------------------------------------------------------------------- /extra/mnist_m_rec_image_all.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fungtion/DSN/63dcbc0440c7fb5f25633b6d1d35e5f9cb8772fd/extra/mnist_m_rec_image_all.png -------------------------------------------------------------------------------- /extra/mnist_m_rec_image_private.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fungtion/DSN/63dcbc0440c7fb5f25633b6d1d35e5f9cb8772fd/extra/mnist_m_rec_image_private.png -------------------------------------------------------------------------------- /extra/mnist_m_rec_image_share.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fungtion/DSN/63dcbc0440c7fb5f25633b6d1d35e5f9cb8772fd/extra/mnist_m_rec_image_share.png -------------------------------------------------------------------------------- /extra/mnist_rec_image_all.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fungtion/DSN/63dcbc0440c7fb5f25633b6d1d35e5f9cb8772fd/extra/mnist_rec_image_all.png -------------------------------------------------------------------------------- /extra/mnist_rec_image_private.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fungtion/DSN/63dcbc0440c7fb5f25633b6d1d35e5f9cb8772fd/extra/mnist_rec_image_private.png 
from torch.autograd import Function
import torch.nn as nn
import torch


class ReverseLayerF(Function):
    """Gradient-reversal layer (DANN-style).

    Identity on the forward pass; on the backward pass the gradient is
    negated and scaled by the factor ``p`` captured at forward time.
    """

    @staticmethod
    def forward(ctx, x, p):
        # Stash the scaling factor for the backward pass.
        ctx.p = p
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Flip the gradient's sign and scale it; ``None`` for the ``p`` input.
        return grad_output.neg() * ctx.p, None


class MSE(nn.Module):
    """Mean squared error between ``pred`` and ``real``."""

    def forward(self, pred, real):
        residual = real - pred
        return residual.pow(2).sum() / residual.numel()


class SIMSE(nn.Module):
    """Scale-invariant MSE term: ``(sum(real - pred))^2 / n^2``."""

    def forward(self, pred, real):
        residual = real - pred
        count = residual.numel()
        return residual.sum().pow(2) / (count ** 2)


class DiffLoss(nn.Module):
    """Orthogonality ("difference") loss between two feature batches.

    Both inputs are flattened to ``(batch, features)`` and L2-normalised per
    sample (norms are detached so only directions are penalised); the loss is
    the mean squared entry of the cross-correlation matrix between them.
    """

    def forward(self, input1, input2):
        n = input1.size(0)
        flat1 = input1.view(n, -1)
        flat2 = input2.view(n, -1)

        # 1e-6 guards against division by zero for all-zero feature rows.
        unit1 = flat1.div(torch.norm(flat1, p=2, dim=1, keepdim=True).detach().expand_as(flat1) + 1e-6)
        unit2 = flat2.div(torch.norm(flat2, p=2, dim=1, keepdim=True).detach().expand_as(flat2) + 1e-6)

        return unit1.t().mm(unit2).pow(2).mean()
torch.norm(input1, p=2, dim=1, keepdim=True).detach() 58 | input1_l2 = input1.div(input1_l2_norm.expand_as(input1) + 1e-6) 59 | 60 | input2_l2_norm = torch.norm(input2, p=2, dim=1, keepdim=True).detach() 61 | input2_l2 = input2.div(input2_l2_norm.expand_as(input2) + 1e-6) 62 | 63 | diff_loss = torch.mean((input1_l2.t().mm(input2_l2)).pow(2)) 64 | 65 | return diff_loss 66 | 67 | 68 | 69 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fungtion/DSN/63dcbc0440c7fb5f25633b6d1d35e5f9cb8772fd/model/__init__.py -------------------------------------------------------------------------------- /model_compat.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from functions import ReverseLayerF 3 | 4 | 5 | class DSN(nn.Module): 6 | def __init__(self, code_size=100, n_class=10): 7 | super(DSN, self).__init__() 8 | self.code_size = code_size 9 | 10 | ########################################## 11 | # private source encoder 12 | ########################################## 13 | 14 | self.source_encoder_conv = nn.Sequential() 15 | self.source_encoder_conv.add_module('conv_pse1', nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, 16 | padding=2)) 17 | self.source_encoder_conv.add_module('ac_pse1', nn.ReLU(True)) 18 | self.source_encoder_conv.add_module('pool_pse1', nn.MaxPool2d(kernel_size=2, stride=2)) 19 | 20 | self.source_encoder_conv.add_module('conv_pse2', nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, 21 | padding=2)) 22 | self.source_encoder_conv.add_module('ac_pse2', nn.ReLU(True)) 23 | self.source_encoder_conv.add_module('pool_pse2', nn.MaxPool2d(kernel_size=2, stride=2)) 24 | 25 | self.source_encoder_fc = nn.Sequential() 26 | self.source_encoder_fc.add_module('fc_pse3', nn.Linear(in_features=7 * 7 * 64, out_features=code_size)) 
27 | self.source_encoder_fc.add_module('ac_pse3', nn.ReLU(True)) 28 | 29 | ######################################### 30 | # private target encoder 31 | ######################################### 32 | 33 | self.target_encoder_conv = nn.Sequential() 34 | self.target_encoder_conv.add_module('conv_pte1', nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, 35 | padding=2)) 36 | self.target_encoder_conv.add_module('ac_pte1', nn.ReLU(True)) 37 | self.target_encoder_conv.add_module('pool_pte1', nn.MaxPool2d(kernel_size=2, stride=2)) 38 | 39 | self.target_encoder_conv.add_module('conv_pte2', nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, 40 | padding=2)) 41 | self.target_encoder_conv.add_module('ac_pte2', nn.ReLU(True)) 42 | self.target_encoder_conv.add_module('pool_pte2', nn.MaxPool2d(kernel_size=2, stride=2)) 43 | 44 | self.target_encoder_fc = nn.Sequential() 45 | self.target_encoder_fc.add_module('fc_pte3', nn.Linear(in_features=7 * 7 * 64, out_features=code_size)) 46 | self.target_encoder_fc.add_module('ac_pte3', nn.ReLU(True)) 47 | 48 | ################################ 49 | # shared encoder (dann_mnist) 50 | ################################ 51 | 52 | self.shared_encoder_conv = nn.Sequential() 53 | self.shared_encoder_conv.add_module('conv_se1', nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, 54 | padding=2)) 55 | self.shared_encoder_conv.add_module('ac_se1', nn.ReLU(True)) 56 | self.shared_encoder_conv.add_module('pool_se1', nn.MaxPool2d(kernel_size=2, stride=2)) 57 | 58 | self.shared_encoder_conv.add_module('conv_se2', nn.Conv2d(in_channels=32, out_channels=48, kernel_size=5, 59 | padding=2)) 60 | self.shared_encoder_conv.add_module('ac_se2', nn.ReLU(True)) 61 | self.shared_encoder_conv.add_module('pool_se2', nn.MaxPool2d(kernel_size=2, stride=2)) 62 | 63 | self.shared_encoder_fc = nn.Sequential() 64 | self.shared_encoder_fc.add_module('fc_se3', nn.Linear(in_features=7 * 7 * 48, out_features=code_size)) 65 | 
self.shared_encoder_fc.add_module('ac_se3', nn.ReLU(True)) 66 | 67 | # classify 10 numbers 68 | self.shared_encoder_pred_class = nn.Sequential() 69 | self.shared_encoder_pred_class.add_module('fc_se4', nn.Linear(in_features=code_size, out_features=100)) 70 | self.shared_encoder_pred_class.add_module('relu_se4', nn.ReLU(True)) 71 | self.shared_encoder_pred_class.add_module('fc_se5', nn.Linear(in_features=100, out_features=n_class)) 72 | 73 | self.shared_encoder_pred_domain = nn.Sequential() 74 | self.shared_encoder_pred_domain.add_module('fc_se6', nn.Linear(in_features=100, out_features=100)) 75 | self.shared_encoder_pred_domain.add_module('relu_se6', nn.ReLU(True)) 76 | 77 | # classify two domain 78 | self.shared_encoder_pred_domain.add_module('fc_se7', nn.Linear(in_features=100, out_features=2)) 79 | 80 | ###################################### 81 | # shared decoder (small decoder) 82 | ###################################### 83 | 84 | self.shared_decoder_fc = nn.Sequential() 85 | self.shared_decoder_fc.add_module('fc_sd1', nn.Linear(in_features=code_size, out_features=588)) 86 | self.shared_decoder_fc.add_module('relu_sd1', nn.ReLU(True)) 87 | 88 | self.shared_decoder_conv = nn.Sequential() 89 | self.shared_decoder_conv.add_module('conv_sd2', nn.Conv2d(in_channels=3, out_channels=16, kernel_size=5, 90 | padding=2)) 91 | self.shared_decoder_conv.add_module('relu_sd2', nn.ReLU()) 92 | 93 | self.shared_decoder_conv.add_module('conv_sd3', nn.Conv2d(in_channels=16, out_channels=16, kernel_size=5, 94 | padding=2)) 95 | self.shared_decoder_conv.add_module('relu_sd3', nn.ReLU()) 96 | 97 | self.shared_decoder_conv.add_module('us_sd4', nn.Upsample(scale_factor=2)) 98 | 99 | self.shared_decoder_conv.add_module('conv_sd5', nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, 100 | padding=1)) 101 | self.shared_decoder_conv.add_module('relu_sd5', nn.ReLU(True)) 102 | 103 | self.shared_decoder_conv.add_module('conv_sd6', nn.Conv2d(in_channels=16, out_channels=3, 
kernel_size=3, 104 | padding=1)) 105 | 106 | def forward(self, input_data, mode, rec_scheme, p=0.0): 107 | 108 | result = [] 109 | 110 | if mode == 'source': 111 | 112 | # source private encoder 113 | private_feat = self.source_encoder_conv(input_data) 114 | private_feat = private_feat.view(-1, 64 * 7 * 7) 115 | private_code = self.source_encoder_fc(private_feat) 116 | 117 | elif mode == 'target': 118 | 119 | # target private encoder 120 | private_feat = self.target_encoder_conv(input_data) 121 | private_feat = private_feat.view(-1, 64 * 7 * 7) 122 | private_code = self.target_encoder_fc(private_feat) 123 | 124 | result.append(private_code) 125 | 126 | # shared encoder 127 | shared_feat = self.shared_encoder_conv(input_data) 128 | shared_feat = shared_feat.view(-1, 48 * 7 * 7) 129 | shared_code = self.shared_encoder_fc(shared_feat) 130 | result.append(shared_code) 131 | 132 | reversed_shared_code = ReverseLayerF.apply(shared_code, p) 133 | domain_label = self.shared_encoder_pred_domain(reversed_shared_code) 134 | result.append(domain_label) 135 | 136 | if mode == 'source': 137 | class_label = self.shared_encoder_pred_class(shared_code) 138 | result.append(class_label) 139 | 140 | # shared decoder 141 | 142 | if rec_scheme == 'share': 143 | union_code = shared_code 144 | elif rec_scheme == 'all': 145 | union_code = private_code + shared_code 146 | elif rec_scheme == 'private': 147 | union_code = private_code 148 | 149 | rec_vec = self.shared_decoder_fc(union_code) 150 | rec_vec = rec_vec.view(-1, 3, 14, 14) 151 | 152 | rec_code = self.shared_decoder_conv(rec_vec) 153 | result.append(rec_code) 154 | 155 | return result 156 | 157 | 158 | 159 | 160 | 161 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch.backends.cudnn as cudnn 3 | import torch.utils.data 4 | from torch.autograd import Variable 5 | from torchvision 
import os

import torch.backends.cudnn as cudnn
import torch.utils.data
import torchvision.utils as vutils
from torch.autograd import Variable
from torchvision import datasets
from torchvision import transforms

from data_loader import GetLoader
from model_compat import DSN


def test(epoch, name):
    """Evaluate the epoch's DSN checkpoint on one dataset.

    Args:
        epoch (int): checkpoint index to load from ``model/``.
        name (str): 'mnist' (source) or 'mnist_m' (target).

    Raises:
        ValueError: for an unknown dataset name (the original fell through
            and crashed later on an undefined dataloader).

    Side effects: prints the accuracy and, near the end of the loader,
    saves reconstruction grids to ``<name>_rec_image_*.png``.
    """
    ###################
    #     params      #
    ###################
    cuda = True
    cudnn.benchmark = True
    batch_size = 64
    image_size = 28

    ###################
    #    load data    #
    ###################

    img_transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])

    model_root = 'model'
    if name == 'mnist':
        mode = 'source'
        image_root = os.path.join('dataset', 'mnist')
        dataset = datasets.MNIST(
            root=image_root,
            train=False,
            transform=img_transform
        )
    elif name == 'mnist_m':
        mode = 'target'
        image_root = os.path.join('dataset', 'mnist_m', 'mnist_m_test')
        test_list = os.path.join('dataset', 'mnist_m', 'mnist_m_test_labels.txt')
        dataset = GetLoader(
            data_root=image_root,
            data_list=test_list,
            transform=img_transform
        )
    else:
        raise ValueError('error dataset name: %s' % name)

    # Identical loader settings for both datasets, built once.
    dataloader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=8
    )

    ####################
    #    load model    #
    ####################

    my_net = DSN()
    checkpoint = torch.load(os.path.join(model_root, 'dsn_mnist_mnistm_epoch_' + str(epoch) + '.pth'))
    my_net.load_state_dict(checkpoint)
    my_net.eval()

    if cuda:
        my_net = my_net.cuda()

    def tr_image(img):
        # Undo the (-1, 1) normalisation for visualisation.
        return (img + 1) / 2

    len_dataloader = len(dataloader)
    n_total = 0
    n_correct = 0

    for i, (img, label) in enumerate(dataloader):

        batch_size = len(label)

        input_img = torch.FloatTensor(batch_size, 3, image_size, image_size)
        class_label = torch.LongTensor(batch_size)

        if cuda:
            img = img.cuda()
            label = label.cuda()
            input_img = input_img.cuda()
            class_label = class_label.cuda()

        # BUG FIX: resize against the actual batch (the original resized
        # input_img against itself, a no-op).
        input_img.resize_as_(img).copy_(img)
        class_label.resize_as_(label).copy_(label)
        inputv_img = Variable(input_img)
        classv_label = Variable(class_label)

        # Class prediction always goes through the source branch: only the
        # 'source' mode returns class logits (result[3]) from the shared head.
        result = my_net(input_data=inputv_img, mode='source', rec_scheme='share')
        pred = result[3].data.max(1, keepdim=True)[1]

        result = my_net(input_data=inputv_img, mode=mode, rec_scheme='all')
        rec_img_all = tr_image(result[-1].data)

        result = my_net(input_data=inputv_img, mode=mode, rec_scheme='share')
        rec_img_share = tr_image(result[-1].data)

        result = my_net(input_data=inputv_img, mode=mode, rec_scheme='private')
        rec_img_private = tr_image(result[-1].data)

        # Snapshot reconstructions from the penultimate (full-size) batch.
        if i == len_dataloader - 2:
            vutils.save_image(rec_img_all, name + '_rec_image_all.png', nrow=8)
            vutils.save_image(rec_img_share, name + '_rec_image_share.png', nrow=8)
            vutils.save_image(rec_img_private, name + '_rec_image_private.png', nrow=8)

        n_correct += pred.eq(classv_label.data.view_as(pred)).cpu().sum()
        n_total += batch_size

    accu = n_correct * 1.0 / n_total

    print('epoch: %d, accuracy of the %s dataset: %f' % (epoch, name, accu))
import os
import random

import numpy as np
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
from torch.autograd import Variable
from torchvision import datasets
from torchvision import transforms

from data_loader import GetLoader
from functions import SIMSE, DiffLoss, MSE
from model_compat import DSN
from test import test

######################
#       params       #
######################

source_image_root = os.path.join('.', 'dataset', 'mnist')
target_image_root = os.path.join('.', 'dataset', 'mnist_m')
model_root = 'model'
cuda = True
cudnn.benchmark = True
lr = 1e-2
batch_size = 32
image_size = 28
n_epoch = 100
step_decay_weight = 0.95
lr_decay_step = 20000
active_domain_loss_step = 10000
weight_decay = 1e-6
alpha_weight = 0.01   # reconstruction (MSE + SIMSE) weight
beta_weight = 0.075   # difference (orthogonality) loss weight
gamma_weight = 0.25   # domain-adversarial loss weight
momentum = 0.9

manual_seed = random.randint(1, 10000)
random.seed(manual_seed)
torch.manual_seed(manual_seed)

# Make sure the checkpoint directory exists before the first torch.save().
if not os.path.exists(model_root):
    os.makedirs(model_root)

#######################
#      load data      #
#######################

img_transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])

dataset_source = datasets.MNIST(
    root=source_image_root,
    train=True,
    transform=img_transform
)

dataloader_source = torch.utils.data.DataLoader(
    dataset=dataset_source,
    batch_size=batch_size,
    shuffle=True,
    num_workers=8
)

train_list = os.path.join(target_image_root, 'mnist_m_train_labels.txt')

dataset_target = GetLoader(
    data_root=os.path.join(target_image_root, 'mnist_m_train'),
    data_list=train_list,
    transform=img_transform
)

dataloader_target = torch.utils.data.DataLoader(
    dataset=dataset_target,
    batch_size=batch_size,
    shuffle=True,
    num_workers=8
)

#####################
#     load model    #
#####################

my_net = DSN()

#####################
#  setup optimizer  #
#####################


def exp_lr_scheduler(optimizer, step, init_lr=lr, lr_decay_step=lr_decay_step, step_decay_weight=step_decay_weight):
    """Decay the learning rate by ``step_decay_weight`` every ``lr_decay_step`` steps.

    Floor division keeps the schedule piecewise-constant; the original
    ``step / lr_decay_step`` silently becomes continuous decay on Python 3.
    """
    current_lr = init_lr * (step_decay_weight ** (step // lr_decay_step))

    if step % lr_decay_step == 0:
        print('learning rate is set to %f' % current_lr)

    for param_group in optimizer.param_groups:
        param_group['lr'] = current_lr

    return optimizer


optimizer = optim.SGD(my_net.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)

loss_classification = torch.nn.CrossEntropyLoss()
loss_recon1 = MSE()
loss_recon2 = SIMSE()
loss_diff = DiffLoss()
loss_similarity = torch.nn.CrossEntropyLoss()

if cuda:
    my_net = my_net.cuda()
    loss_classification = loss_classification.cuda()
    loss_recon1 = loss_recon1.cuda()
    loss_recon2 = loss_recon2.cuda()
    loss_diff = loss_diff.cuda()
    loss_similarity = loss_similarity.cuda()

# Renamed from ``p`` to avoid shadowing the gradient-reversal factor below.
for param in my_net.parameters():
    param.requires_grad = True

#############################
#     training network      #
#############################

len_dataloader = min(len(dataloader_source), len(dataloader_target))
# Epoch index at which the domain-adversarial (DANN) loss switches on.
dann_epoch = np.floor(active_domain_loss_step * 1.0 / len_dataloader)

current_step = 0
for epoch in range(n_epoch):

    data_source_iter = iter(dataloader_source)
    data_target_iter = iter(dataloader_target)

    i = 0

    while i < len_dataloader:

        ###################################
        #       target data training      #
        ###################################

        t_img, t_label = next(data_target_iter)

        my_net.zero_grad()
        loss = 0
        batch_size = len(t_label)

        input_img = torch.FloatTensor(batch_size, 3, image_size, image_size)
        class_label = torch.LongTensor(batch_size)
        domain_label = torch.ones(batch_size).long()  # target domain == 1

        if cuda:
            t_img = t_img.cuda()
            t_label = t_label.cuda()
            input_img = input_img.cuda()
            class_label = class_label.cuda()
            domain_label = domain_label.cuda()

        input_img.resize_as_(t_img).copy_(t_img)
        class_label.resize_as_(t_label).copy_(t_label)
        target_inputv_img = Variable(input_img)
        target_classv_label = Variable(class_label)
        target_domainv_label = Variable(domain_label)

        if current_step > active_domain_loss_step:
            # DANN schedule: ramp p from 0 to 1 over the remaining training.
            # BUG FIX: the original expression mis-placed the parentheses
            # (dividing only the second term by both factors), so p exceeded
            # 1 after a single step and the sigmoid saturated immediately.
            p = float(i + (epoch - dann_epoch) * len_dataloader) / ((n_epoch - dann_epoch) * len_dataloader)
            p = 2. / (1. + np.exp(-10 * p)) - 1

            # activate domain loss
            result = my_net(input_data=target_inputv_img, mode='target', rec_scheme='all', p=p)
            target_privte_code, target_share_code, target_domain_label, target_rec_code = result
            target_dann = gamma_weight * loss_similarity(target_domain_label, target_domainv_label)
            loss += target_dann
        else:
            # Placeholder so the logging below always has a value; only
            # moved to the GPU when one is actually in use.
            target_dann = Variable(torch.zeros(1).float())
            if cuda:
                target_dann = target_dann.cuda()
            result = my_net(input_data=target_inputv_img, mode='target', rec_scheme='all')
            target_privte_code, target_share_code, _, target_rec_code = result

        target_diff = beta_weight * loss_diff(target_privte_code, target_share_code)
        loss += target_diff
        target_mse = alpha_weight * loss_recon1(target_rec_code, target_inputv_img)
        loss += target_mse
        target_simse = alpha_weight * loss_recon2(target_rec_code, target_inputv_img)
        loss += target_simse

        loss.backward()
        optimizer.step()

        ###################################
        #       source data training      #
        ###################################

        s_img, s_label = next(data_source_iter)

        my_net.zero_grad()
        batch_size = len(s_label)

        input_img = torch.FloatTensor(batch_size, 3, image_size, image_size)
        class_label = torch.LongTensor(batch_size)
        domain_label = torch.zeros(batch_size).long()  # source domain == 0

        loss = 0

        if cuda:
            s_img = s_img.cuda()
            s_label = s_label.cuda()
            input_img = input_img.cuda()
            class_label = class_label.cuda()
            domain_label = domain_label.cuda()

        # BUG FIX: resize against the source batch (the original resized
        # input_img against itself, a no-op), matching the target branch.
        input_img.resize_as_(s_img).copy_(s_img)
        class_label.resize_as_(s_label).copy_(s_label)
        source_inputv_img = Variable(input_img)
        source_classv_label = Variable(class_label)
        source_domainv_label = Variable(domain_label)

        if current_step > active_domain_loss_step:

            # activate domain loss (reuses the p computed for this step).
            result = my_net(input_data=source_inputv_img, mode='source', rec_scheme='all', p=p)
            source_privte_code, source_share_code, source_domain_label, source_class_label, source_rec_code = result
            source_dann = gamma_weight * loss_similarity(source_domain_label, source_domainv_label)
            loss += source_dann
        else:
            source_dann = Variable(torch.zeros(1).float())
            if cuda:
                source_dann = source_dann.cuda()
            result = my_net(input_data=source_inputv_img, mode='source', rec_scheme='all')
            source_privte_code, source_share_code, _, source_class_label, source_rec_code = result

        source_classification = loss_classification(source_class_label, source_classv_label)
        loss += source_classification

        source_diff = beta_weight * loss_diff(source_privte_code, source_share_code)
        loss += source_diff
        source_mse = alpha_weight * loss_recon1(source_rec_code, source_inputv_img)
        loss += source_mse
        source_simse = alpha_weight * loss_recon2(source_rec_code, source_inputv_img)
        loss += source_simse

        loss.backward()
        optimizer = exp_lr_scheduler(optimizer=optimizer, step=current_step)
        optimizer.step()

        i += 1
        current_step += 1
        print('source_classification: %f, source_dann: %f, source_diff: %f, '
              'source_mse: %f, source_simse: %f, target_dann: %f, target_diff: %f, '
              'target_mse: %f, target_simse: %f'
              % (source_classification.data.cpu().numpy(), source_dann.data.cpu().numpy(),
                 source_diff.data.cpu().numpy(), source_mse.data.cpu().numpy(),
                 source_simse.data.cpu().numpy(), target_dann.data.cpu().numpy(),
                 target_diff.data.cpu().numpy(), target_mse.data.cpu().numpy(),
                 target_simse.data.cpu().numpy()))

    torch.save(my_net.state_dict(), model_root + '/dsn_mnist_mnistm_epoch_' + str(epoch) + '.pth')
    test(epoch=epoch, name='mnist')
    test(epoch=epoch, name='mnist_m')

print('done')