# -*- coding: utf-8 -*-

import torch.nn as nn


class DaNN(nn.Module):
    """Single-hidden-layer Domain Adaptive Neural Network.

    One shared feature extractor (linear -> dropout -> ReLU) is run on both
    a source batch and a target batch.  The hidden activations of the two
    batches are returned so the caller can compute an MMD-style transfer
    loss; class scores are produced for the source batch only.
    """

    def __init__(self, n_input=28 * 28, n_hidden=256, n_class=10):
        super(DaNN, self).__init__()
        # single feedforward hidden layer shared by source and target
        self.layer_input = nn.Linear(n_input, n_hidden)
        self.dropout = nn.Dropout(p=0.5)
        self.relu = nn.ReLU()
        self.layer_hidden = nn.Linear(n_hidden, n_class)

    def forward(self, src, tar):
        """Return (source class scores, source features, target features)."""
        # source features first, then target — dropout draws in this order
        x_src_mmd = self.relu(self.dropout(self.layer_input(src)))
        x_tar_mmd = self.relu(self.dropout(self.layer_input(tar)))
        y_src = self.layer_hidden(x_src_mmd)
        return y_src, x_src_mmd, x_tar_mmd
 4 | 5 | ## Requirements 6 | 7 | - [PyTorch](https://pytorch.org/) (version >= 0.4.1) 8 | - [scikit-learn](https://scikit-learn.org/stable/) 9 | 10 | ## Experiments 11 | 12 | We apply DJP-MMD within Domain Adaptive Neural Networks to raw images from [Office-Caltech10](https://github.com/jindongwang/transferlearning/tree/master/data#office-caltech10), and this new metric shows better convergence speed and accuracy. 13 | 14 |
15 | 16 |
# -*- coding: utf-8 -*-

import torch
from torchvision import datasets, transforms


def _build_transform():
    """28x28 grayscale tensor pipeline, normalized to roughly [-1, 1]."""
    return transforms.Compose([
        transforms.Grayscale(),
        transforms.Resize([28, 28]),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])


def load_train(root_dir, domain, batch_size):
    """DataLoader over root_dir + domain for training (shuffled, drop_last)."""
    dataset = datasets.ImageFolder(root=root_dir + domain,
                                   transform=_build_transform())
    return torch.utils.data.DataLoader(dataset=dataset,
                                       batch_size=batch_size,
                                       shuffle=True,
                                       num_workers=2,
                                       drop_last=True)


def load_test(root_dir, domain, batch_size):
    """DataLoader over root_dir + domain for evaluation (fixed order)."""
    dataset = datasets.ImageFolder(root=root_dir + domain,
                                   transform=_build_transform())
    return torch.utils.data.DataLoader(dataset=dataset,
                                       batch_size=batch_size,
                                       shuffle=False,
                                       num_workers=2)
# -*- coding: utf-8 -*-
# @Time : 2020/7/4 19:18
# @Author : wenzhang
# @File : djp_mmd.py


import torch as tr


def _primal_kernel(Xs, Xt):
    """Stack source/target features column-wise for primal formulations."""
    Z = tr.cat((Xs.T, Xt.T), 1)  # Xs / Xt: batch_size * k
    return Z


def _linear_kernel(Xs, Xt):
    """Linear Gram matrix of the stacked source/target batch."""
    Z = tr.cat((Xs, Xt), 0)  # Xs / Xt: batch_size * k
    K = tr.mm(Z, Z.T)
    return K


def _rbf_kernel(Xs, Xt, sigma):
    """RBF (Gaussian) Gram matrix of the stacked source/target batch."""
    Z = tr.cat((Xs, Xt), 0)
    ZZT = tr.mm(Z, Z.T)
    diag_ZZT = tr.diag(ZZT).unsqueeze(1)
    Z_norm_sqr = diag_ZZT.expand_as(ZZT)
    # squared pairwise distances: ||zi||^2 - 2 zi.zj + ||zj||^2
    exponent = Z_norm_sqr - 2 * ZZT + Z_norm_sqr.T
    K = tr.exp(-exponent / (2 * sigma ** 2))
    return K


# marginal MMD with rbf kernel
def rbf_mmd(Xs, Xt, sigma):
    """Marginal MMD loss tr(K M K^T).

    Fixed: computed on the kernel's own device instead of hard-coded
    .cpu()/.cuda() moves, which crashed on CPU-only machines.
    """
    K = _rbf_kernel(Xs, Xt, sigma)
    m = Xs.size(0)  # assume Xs, Xt have the same batch size
    e = tr.cat((1 / m * tr.ones(m, 1), -1 / m * tr.ones(m, 1)), 0).to(K.device)
    M = e * e.T
    loss = tr.trace(tr.mm(tr.mm(K, M), K.T))
    return loss


# rbf kernel JMMD (joint marginal + per-class conditional terms)
def rbf_jmmd(Xs, Ys, Xt, Yt0, sigma):
    """Joint MMD: marginal term plus one conditional term per source class.

    Yt0 holds pseudo labels of the target batch.  Device-safe: all
    intermediates stay on the kernel's device.
    """
    K = _rbf_kernel(Xs, Xt, sigma)
    dev = K.device
    n = K.size(0)
    m = Xs.size(0)  # assume Xs, Xt have the same batch size
    e = tr.cat((1 / m * tr.ones(m, 1), -1 / m * tr.ones(m, 1)), 0).to(dev)
    C = len(tr.unique(Ys))
    M = e * e.T * C
    for c in tr.unique(Ys):
        e = tr.zeros(n, 1, device=dev)
        e[:m][Ys == c] = 1 / len(Ys[Ys == c])
        if len(Yt0[Yt0 == c]) == 0:
            # no target sample carries this pseudo label -> term drops out
            e[m:][Yt0 == c] = 0
        else:
            e[m:][Yt0 == c] = -1 / len(Yt0[Yt0 == c])
        M = M + e * e.T
    M = M / tr.norm(M, p='fro')  # can reduce the training loss only for jmmd
    loss = tr.trace(tr.mm(tr.mm(K, M), K.T))
    return loss


# rbf kernel JPMMD
def rbf_jpmmd(Xs, Ys, Xt, Yt0, sigma):
    """Joint-probability MMD, one rank-one term per source class (device-safe)."""
    K = _rbf_kernel(Xs, Xt, sigma)
    dev = K.device
    n = K.size(0)
    m = Xs.size(0)  # assume Xs, Xt have the same batch size
    M = 0
    for c in tr.unique(Ys):
        e = tr.zeros(n, 1, device=dev)
        e[:m] = 1 / len(Ys)
        if len(Yt0[Yt0 == c]) == 0:
            e[m:] = 0
        else:
            e[m:] = -1 / len(Yt0)
        M = M + e * e.T
    loss = tr.trace(tr.mm(tr.mm(K, M), K.T))
    return loss


# rbf kernel DJP-MMD
def rbf_djpmmd(Xs, Ys, Xt, Yt0, sigma):
    """Discriminative joint-probability MMD: transferability - 0.1 * discriminability.

    Fixed: one-hot indicator matrices and the final trace are built on the
    kernel's device; the original forced .cpu()/.cuda() and failed without
    a GPU.
    """
    K = _rbf_kernel(Xs, Xt, sigma)
    # K = _linear_kernel(Xs, Xt)  # bad performance
    dev = K.device
    m = Xs.size(0)
    C = 10  # class count fixed by the Office-Caltech10 setup; len(tr.unique(Ys))

    # For transferability: normalized one-hot label indicators (m x C)
    Ns = 1 / m * tr.zeros(m, C, device=dev).scatter_(1, Ys.unsqueeze(1), 1)
    Nt = tr.zeros(m, C, device=dev)
    # NOTE(review): Nt is only populated when the pseudo labels collapse to a
    # single class; otherwise the target indicator stays all-zero.  This
    # preserves the original behavior, but the condition looks inverted —
    # confirm against the reference DJP-MMD implementation.
    if len(tr.unique(Yt0)) == 1:
        Nt = 1 / m * tr.zeros(m, C, device=dev).scatter_(1, Yt0.unsqueeze(1), 1)
    Rmin_1 = tr.cat((tr.mm(Ns, Ns.T), tr.mm(-Ns, Nt.T)), 0)
    Rmin_2 = tr.cat((tr.mm(-Nt, Ns.T), tr.mm(Nt, Nt.T)), 0)
    Rmin = tr.cat((Rmin_1, Rmin_2), 1)

    # For discriminability: pair each class with the C-1 other classes
    Ms = tr.empty(m, (C - 1) * C, device=dev)
    Mt = tr.empty(m, (C - 1) * C, device=dev)
    for i in range(0, C):
        idx = tr.arange((C - 1) * i, (C - 1) * (i + 1))
        Ms[:, idx] = Ns[:, i].repeat(C - 1, 1).T
        tmp = tr.arange(0, C)
        Mt[:, idx] = Nt[:, tmp[tmp != i]]
    Rmax_1 = tr.cat((tr.mm(Ms, Ms.T), tr.mm(-Ms, Mt.T)), 0)
    Rmax_2 = tr.cat((tr.mm(-Mt, Ms.T), tr.mm(Mt, Mt.T)), 0)
    Rmax = tr.cat((Rmax_1, Rmax_2), 1)
    M = Rmin - 0.1 * Rmax
    loss = tr.trace(tr.mm(tr.mm(K, M), K.T))

    return loss
# -*- coding: utf-8 -*-
# @Time : 2020/7/4 19:18
# @Author : wenzhang
# @File : main_DaNN_DJP.py

import numpy as np
import torch as tr
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import djp_mmd, data_loader, DaNN
import time

import matplotlib as mpl
import matplotlib.pyplot as plt

import os

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
DEVICE = tr.device('cuda' if tr.cuda.is_available() else 'cpu')

# para of the network
LEARNING_RATE = 0.001  # 0.001
DROPOUT = 0.5
N_EPOCH = 100
BATCH_SIZE = [64, 64]  # batch size of source and target domain

# para of the loss function
# accommodate small values of MMD gradient compared to NNs for each iteration
GAMMA = 1000  # 1000 more weight to transferability
SIGMA = 1  # default 1


def mmd_loss(x_src, y_src, x_tar, y_pseudo, mmd_type):
    """Dispatch to the selected transfer metric: MMD, JMMD, JPMMD or DJP-MMD."""
    if mmd_type == 'mmd':
        return djp_mmd.rbf_mmd(x_src, x_tar, SIGMA)
    elif mmd_type == 'jmmd':
        return djp_mmd.rbf_jmmd(x_src, y_src, x_tar, y_pseudo, SIGMA)
    elif mmd_type == 'jpmmd':
        return djp_mmd.rbf_jpmmd(x_src, y_src, x_tar, y_pseudo, SIGMA)
    elif mmd_type == 'djpmmd':
        return djp_mmd.rbf_djpmmd(x_src, y_src, x_tar, y_pseudo, SIGMA)


def model_train(model, optimizer, epoch, data_src, data_tar, y_pse, mmd_type):
    """Run one training epoch.

    Loss per batch is CE(source) + GAMMA * MMD(source, target).  y_pse is a
    (num_source_batches, target_batch_size) buffer of target pseudo labels,
    read for the MMD term and refreshed in place each iteration so the next
    epoch uses the latest predictions.
    Returns (train accuracy, train loss, model).
    """
    tmp_train_loss = 0
    correct = 0
    batch_j = 0
    criterion = nn.CrossEntropyLoss()
    list_src, list_tar = list(enumerate(data_src)), list(enumerate(data_tar))

    for batch_id, (x_src, y_src) in enumerate(data_src):
        optimizer.zero_grad()
        x_src, y_src = x_src.detach().view(-1, 28 * 28).to(DEVICE), y_src.to(DEVICE)
        # NOTE(review): batch_j is never advanced, so every source batch is
        # paired with the FIRST target batch only; the y_pse buffer shape
        # relies on this — confirm whether cycling the target loader was
        # intended before changing it.
        _, (x_tar, y_tar) = list_tar[batch_j]
        x_tar = x_tar.view(-1, 28 * 28).to(DEVICE)
        model.train()
        ypred, x_src_mmd, x_tar_mmd = model(x_src, x_tar)

        loss_ce = criterion(ypred, y_src)
        loss_mmd = mmd_loss(x_src_mmd, y_src, x_tar_mmd, y_pse[batch_id, :], mmd_type)
        pred = ypred.detach().max(1)[1]  # index of the max log-probability

        # refresh pseudo labels of the target for the next iteration
        model.eval()
        pred_pse, _, _ = model(x_tar, x_tar)
        y_pse[batch_id, :] = pred_pse.detach().max(1)[1]

        # accumulate training statistics
        correct += pred.eq(y_src.detach().view_as(pred)).cpu().sum()
        loss = loss_ce + GAMMA * loss_mmd

        # error backward
        loss.backward()
        optimizer.step()
        tmp_train_loss += loss.detach()

    tmp_train_loss /= len(data_src)
    tmp_train_acc = correct * 100. / len(data_src.dataset)
    train_loss = tmp_train_loss.detach().cpu().numpy()
    train_acc = tmp_train_acc.numpy()

    tim = time.strftime("%H:%M:%S", time.localtime())
    res_e = '{:s}, epoch: {}/{}, train loss: {:.4f}, train acc: {:.4f}'.format(
        tim, epoch, N_EPOCH, tmp_train_loss, tmp_train_acc)
    tqdm.write(res_e)
    return train_acc, train_loss, model


def model_test(model, data_tar, epoch):
    """Evaluate the model on the target domain; returns (accuracy, loss)."""
    tmp_test_loss = 0
    correct = 0
    criterion = nn.CrossEntropyLoss()
    with tr.no_grad():
        for batch_id, (x_tar, y_tar) in enumerate(data_tar):
            x_tar, y_tar = x_tar.view(-1, 28 * 28).to(DEVICE), y_tar.to(DEVICE)
            model.eval()
            # the model expects two inputs; feed the target twice at test time
            ypred, _, _ = model(x_tar, x_tar)
            loss = criterion(ypred, y_tar)
            pred = ypred.detach().max(1)[1]  # index of the max log-probability
            correct += pred.eq(y_tar.detach().view_as(pred)).cpu().sum()
            tmp_test_loss += loss.detach()

    tmp_test_loss /= len(data_tar)
    tmp_test_acc = correct * 100. / len(data_tar.dataset)
    test_loss = tmp_test_loss.detach().cpu().numpy()
    test_acc = tmp_test_acc.numpy()

    res = 'test loss: {:.4f}, test acc: {:.4f}'.format(tmp_test_loss, tmp_test_acc)
    tqdm.write(res)
    return test_acc, test_loss


def main():
    """Train DaNN on webcam -> dslr with each MMD variant and plot accuracy."""
    rootdir = "/mnt/xxx/dataset/office_caltech_10/"
    tr.manual_seed(1)
    domain_str = ['webcam', 'dslr']
    X_s = data_loader.load_train(root_dir=rootdir, domain=domain_str[0], batch_size=BATCH_SIZE[0])
    X_t = data_loader.load_test(root_dir=rootdir, domain=domain_str[1], batch_size=BATCH_SIZE[1])

    # train and test
    start = time.time()
    mmd_type = ['mmd', 'jmmd', 'jpmmd', 'djpmmd']
    for mt in mmd_type:
        print('-' * 10 + domain_str[0] + ' --> ' + domain_str[1] + '-' * 10)
        print('MMD loss type: ' + mt + '\n')
        acc, loss = {}, {}
        train_acc = []
        test_acc = []
        train_loss = []
        test_loss = []
        # pseudo-label buffer: one row per source batch, one column per
        # target sample.  Fixed: derive the shape from the loader instead of
        # hard-coding (14, 64), and respect DEVICE instead of forcing .cuda()
        # (which crashed on CPU-only machines).
        y_pse = tr.zeros(len(X_s), BATCH_SIZE[1]).long().to(DEVICE)

        mdl = DaNN.DaNN(n_input=28 * 28, n_hidden=256, n_class=10)
        mdl = mdl.to(DEVICE)

        # optimization
        opt_Adam = optim.Adam(mdl.parameters(), lr=LEARNING_RATE)

        for ep in tqdm(range(1, N_EPOCH + 1)):
            tmp_train_acc, tmp_train_loss, mdl = \
                model_train(model=mdl, optimizer=opt_Adam, epoch=ep, data_src=X_s, data_tar=X_t, y_pse=y_pse,
                            mmd_type=mt)
            tmp_test_acc, tmp_test_loss = model_test(mdl, X_t, ep)
            train_acc.append(tmp_train_acc)
            test_acc.append(tmp_test_acc)
            train_loss.append(tmp_train_loss)
            test_loss.append(tmp_test_loss)
        acc['train'], acc['test'] = train_acc, test_acc
        loss['train'], loss['test'] = train_loss, test_loss

        # visualize
        plt.plot(acc['train'], label='train-' + mt)
        plt.plot(acc['test'], label='test-' + mt, ls='--')

    plt.title(domain_str[0] + ' to ' + domain_str[1])
    # Fixed: int8 would overflow for epochs > 127; plain int is always safe
    plt.xticks(np.linspace(1, N_EPOCH, num=5, dtype=int))
    plt.xlim(1, N_EPOCH)
    plt.ylim(0, 100)
    plt.legend(loc='upper right')
    plt.xlabel("epochs")
    plt.ylabel("accuracy")
    plt.savefig(domain_str[0] + '_' + domain_str[1] + "_acc.jpg")
    plt.close()

    # time and save model
    end = time.time()
    print("Total run time: %.2f" % float(end - start))


if __name__ == '__main__':
    main()