├── DaNN.py
├── README.md
├── data_loader.py
├── djp_mmd.py
├── main_DaNN_DJP.py
└── webcam_dslr_acc.jpg
/DaNN.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import torch.nn as nn
4 |
5 |
class DaNN(nn.Module):
    """Single-hidden-layer Domain Adaptive Neural Network.

    Source and target batches share the same input layer. The forward
    pass returns class scores for the source batch plus the hidden
    (post-ReLU) activations of both domains, which the caller feeds to
    an MMD loss.
    """

    def __init__(self, n_input=28 * 28, n_hidden=256, n_class=10):
        super(DaNN, self).__init__()
        # shared feature extractor followed by a classifier head
        self.layer_input = nn.Linear(n_input, n_hidden)
        self.dropout = nn.Dropout(p=0.5)
        self.relu = nn.ReLU()
        self.layer_hidden = nn.Linear(n_hidden, n_class)

    def _features(self, x):
        # hidden representation used for the MMD terms
        return self.relu(self.dropout(self.layer_input(x)))

    def forward(self, src, tar):
        x_src_mmd = self._features(src)
        x_tar_mmd = self._features(tar)
        y_src = self.layer_hidden(x_src_mmd)
        return y_src, x_src_mmd, x_tar_mmd
25 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Domain Adaptive Neural Networks with DJP-MMD
2 |
This repository contains the code of the DJP-MMD metric proposed at IJCNN 2020. We later extended it to deep neural networks by replacing the marginal MMD in [Domain Adaptive Neural Networks (DaNN)](https://link.springer.com/chapter/10.1007/978-3-319-13560-1_76) with DJP-MMD. Since this extension has not been published yet, please cite the original IJCNN paper if you are interested in this method.
4 |
5 | ## Requirements
6 |
7 | - [PyTorch](https://pytorch.org/) (version >= 0.4.1)
8 | - [scikit-learn](https://scikit-learn.org/stable/)
9 |
10 | ## Experiments
11 |
We evaluate DJP-MMD within Domain Adaptive Neural Networks on raw images from [Office-Caltech10](https://github.com/jindongwang/transferlearning/tree/master/data#office-caltech10); the new metric shows better convergence speed and accuracy.
13 |
14 |
15 |

16 |
17 |
18 | ## Citation
19 |
20 | This code is corresponding to our [paper](https://ieeexplore.ieee.org/document/9207365) below:
21 |
22 | ```
23 | @Inproceedings{wenz20djpmmd,
24 | title={Discriminative Joint Probability Maximum Mean Discrepancy ({DJP-MMD}) for Domain Adaptation},
25 | author={Zhang, Wen and Wu, Dongrui},
26 | booktitle={Proc. Int'l Joint Conf. on Neural Networks},
27 | year={2020},
28 | month=jul,
29 | pages={1--8},
30 | address={Glasgow, UK}
31 | }
32 | ```
33 |
34 | Please cite our paper if you like or use our work for your research, thanks!
35 |
36 |
--------------------------------------------------------------------------------
/data_loader.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import torch
4 | from torchvision import datasets, transforms
5 |
6 |
def load_train(root_dir, domain, batch_size):
    """Build a shuffled DataLoader over the training-domain images.

    Every image is converted to a single-channel 28x28 tensor normalised
    to [-1, 1]. The final incomplete batch is dropped so all batches have
    a fixed size (the training loop relies on this).
    """
    preprocess = transforms.Compose([
        transforms.Grayscale(),
        transforms.Resize([28, 28]),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])
    dataset = datasets.ImageFolder(root=root_dir + domain, transform=preprocess)
    return torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
        drop_last=True,
    )
18 |
19 |
def load_test(root_dir, domain, batch_size):
    """Build a deterministic (unshuffled) DataLoader over the test domain.

    Applies the same grayscale 28x28 preprocessing as load_train, but
    keeps the sample order fixed and retains the final partial batch.
    """
    preprocess = transforms.Compose([
        transforms.Grayscale(),
        transforms.Resize([28, 28]),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])
    dataset = datasets.ImageFolder(root=root_dir + domain, transform=preprocess)
    return torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=2,
    )
35 |
--------------------------------------------------------------------------------
/djp_mmd.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2020/7/4 19:18
3 | # @Author : wenzhang
4 | # @File : djp_mmd.py
5 |
6 |
7 | import torch as tr
8 |
9 |
def _primal_kernel(Xs, Xt):
    """Stack source/target features column-wise: (k, 2 * batch_size)."""
    return tr.cat((Xs.T, Xt.T), dim=1)
13 |
14 |
def _linear_kernel(Xs, Xt):
    """Gram matrix of the row-stacked source/target batch (linear kernel)."""
    Z = tr.cat((Xs, Xt), dim=0)
    return Z.mm(Z.T)
19 |
20 |
def _rbf_kernel(Xs, Xt, sigma):
    """RBF (Gaussian) kernel over the row-stacked source/target batch.

    Pairwise squared distances are obtained from the Gram matrix via
    ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2, then mapped through
    exp(-d^2 / (2 sigma^2)).
    """
    Z = tr.cat((Xs, Xt), dim=0)
    gram = Z.mm(Z.T)
    sq_norms = gram.diag().unsqueeze(1).expand_as(gram)
    sq_dists = sq_norms - 2 * gram + sq_norms.T
    return tr.exp(-sq_dists / (2 * sigma ** 2))
29 |
30 |
31 | # functions to compute the marginal MMD with rbf kernel
def rbf_mmd(Xs, Xt, sigma):
    """Marginal MMD between source and target batches with an RBF kernel.

    Args:
        Xs, Xt: feature batches of equal size (m, k); equal batch sizes are
            assumed (the training loader uses drop_last=True).
        sigma: RBF bandwidth.

    Returns:
        Scalar tensor trace(K M K^T) on the same device as the inputs.
        Previously the result was unconditionally moved to CUDA, which
        crashed on CPU-only machines; the CPU round-trip is also removed.
    """
    K = _rbf_kernel(Xs, Xt, sigma)
    m = Xs.size(0)  # assume Xs, Xt have the same batch size
    e = tr.cat((1 / m * tr.ones(m, 1), -1 / m * tr.ones(m, 1)), 0)
    M = (e * e.T).to(K.device)
    return tr.trace(tr.mm(tr.mm(K, M), K.T))
40 |
41 |
42 | # functions to compute rbf kernel JMMD
def rbf_jmmd(Xs, Ys, Xt, Yt0, sigma):
    """Joint MMD with an RBF kernel (marginal + per-class terms).

    Args:
        Xs, Xt: source / target feature batches of equal size (m, k).
        Ys: true source labels.
        Yt0: pseudo labels of the target batch.
        sigma: RBF bandwidth.

    Returns:
        Scalar tensor trace(K M K^T) on the same device as the inputs.
        Previously the result was unconditionally moved to CUDA (crashing
        on CPU-only machines); label masks are now taken on CPU so they
        can index the CPU-built M regardless of the inputs' device.
    """
    K = _rbf_kernel(Xs, Xt, sigma)
    n = K.size(0)
    m = Xs.size(0)  # assume Xs, Xt have the same batch size
    ys, yt0 = Ys.cpu(), Yt0.cpu()  # masks built on CPU alongside M

    # marginal term, weighted by the number of source classes
    e = tr.cat((1 / m * tr.ones(m, 1), -1 / m * tr.ones(m, 1)), 0)
    C = len(tr.unique(ys))
    M = e * e.T * C

    # per-class terms; a class with no pseudo-labelled target samples
    # contributes only its source half (target entries stay zero)
    for c in tr.unique(ys):
        e = tr.zeros(n, 1)
        e[:m][ys == c] = 1 / len(ys[ys == c])
        if len(yt0[yt0 == c]) > 0:
            e[m:][yt0 == c] = -1 / len(yt0[yt0 == c])
        M = M + e * e.T
    M = M / tr.norm(M, p='fro')  # can reduce the training loss only for jmmd
    M = M.to(K.device)
    return tr.trace(tr.mm(tr.mm(K, M), K.T))
62 |
63 |
64 | # functions to compute rbf kernel JPMMD
def rbf_jpmmd(Xs, Ys, Xt, Yt0, sigma):
    """Joint-probability MMD with an RBF kernel.

    Args:
        Xs, Xt: source / target feature batches of equal size (m, k).
        Ys: true source labels.
        Yt0: pseudo labels of the target batch.
        sigma: RBF bandwidth.

    Returns:
        Scalar tensor trace(K M K^T) on the same device as the inputs.
        Previously the result was unconditionally moved to CUDA, which
        crashed on CPU-only machines.
    """
    K = _rbf_kernel(Xs, Xt, sigma)
    n = K.size(0)
    m = Xs.size(0)  # assume Xs, Xt have the same batch size
    ys, yt0 = Ys.cpu(), Yt0.cpu()
    M = tr.zeros(n, n)
    for c in tr.unique(ys):
        e = tr.zeros(n, 1)
        e[:m] = 1 / len(ys)
        # classes absent from the pseudo labels leave the target half at zero
        if len(yt0[yt0 == c]) > 0:
            e[m:] = -1 / len(yt0)
        M = M + e * e.T
    M = M.to(K.device)
    return tr.trace(tr.mm(tr.mm(K, M), K.T))
81 |
82 |
83 | # functions to compute rbf kernel DJP-MMD
def rbf_djpmmd(Xs, Ys, Xt, Yt0, sigma, n_class=10):
    """Discriminative Joint-Probability MMD (DJP-MMD) with an RBF kernel.

    Builds a transferability term (Rmin: same-class source/target pairs)
    and a discriminability term (Rmax: different-class pairs) and returns
    trace(K (Rmin - 0.1 * Rmax) K^T).

    Args:
        Xs, Xt: source / target feature batches of equal size (m, k).
        Ys: true source labels in [0, n_class).
        Yt0: pseudo labels of the target batch.
        sigma: RBF bandwidth.
        n_class: number of classes (was a hard-coded 10; default keeps the
            original behaviour).

    Returns:
        Scalar tensor on the same device as the inputs. Previously the
        result was unconditionally moved to CUDA, crashing on CPU-only
        machines.
    """
    K = _rbf_kernel(Xs, Xt, sigma)
    # K = _linear_kernel(Xs, Xt)  # bad performance
    m = Xs.size(0)
    C = n_class

    # For transferability: per-class one-hot indicator matrices scaled by 1/m.
    Ns = 1 / m * tr.zeros(m, C).scatter_(1, Ys.unsqueeze(1).cpu(), 1)
    Nt = tr.zeros(m, C)
    # NOTE(review): Nt is filled only when the pseudo labels contain a single
    # unique class (e.g. the all-zero initialisation). Condition kept as in
    # the original — confirm against the paper whether this was meant to be
    # ">= 1" (i.e. always).
    if len(tr.unique(Yt0)) == 1:
        Nt = 1 / m * tr.zeros(m, C).scatter_(1, Yt0.unsqueeze(1).cpu(), 1)
    Rmin_1 = tr.cat((tr.mm(Ns, Ns.T), tr.mm(-Ns, Nt.T)), 0)
    Rmin_2 = tr.cat((tr.mm(-Nt, Ns.T), tr.mm(Nt, Nt.T)), 0)
    Rmin = tr.cat((Rmin_1, Rmin_2), 1)

    # For discriminability: pair each source class with every *other*
    # target class.
    Ms = tr.empty(m, (C - 1) * C)
    Mt = tr.empty(m, (C - 1) * C)
    for i in range(0, C):
        idx = tr.arange((C - 1) * i, (C - 1) * (i + 1))
        Ms[:, idx] = Ns[:, i].repeat(C - 1, 1).T
        tmp = tr.arange(0, C)
        Mt[:, idx] = Nt[:, tmp[tmp != i]]
    Rmax_1 = tr.cat((tr.mm(Ms, Ms.T), tr.mm(-Ms, Mt.T)), 0)
    Rmax_2 = tr.cat((tr.mm(-Mt, Ms.T), tr.mm(Mt, Mt.T)), 0)
    Rmax = tr.cat((Rmax_1, Rmax_2), 1)

    # combined coefficient matrix, moved to the kernel's device
    M = (Rmin - 0.1 * Rmax).to(K.device)
    return tr.trace(tr.mm(tr.mm(K, M), K.T))
115 |
--------------------------------------------------------------------------------
/main_DaNN_DJP.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2020/7/4 19:18
3 | # @Author : wenzhang
4 | # @File : main_DaNN_DJP.py
5 |
6 | import numpy as np
7 | import torch as tr
8 | import torch.nn as nn
9 | import torch.optim as optim
10 | from tqdm import tqdm
11 | import djp_mmd, data_loader, DaNN
12 | import time
13 |
14 | import matplotlib as mpl
15 | import matplotlib.pyplot as plt
16 |
17 | import os
18 |
# pin GPU enumeration order and expose only the first GPU
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# fall back to CPU when CUDA is unavailable (NOTE(review): some code below
# still calls .cuda() directly — see main and djp_mmd — which bypasses this)
DEVICE = tr.device('cuda' if tr.cuda.is_available() else 'cpu')

# parameters of the network
LEARNING_RATE = 0.001  # default 0.001
DROPOUT = 0.5  # NOTE(review): unused — DaNN hard-codes p=0.5 internally
N_EPOCH = 100
BATCH_SIZE = [64, 64]  # batch sizes of the source and target domains

# parameters of the loss function
# accommodate small values of MMD gradient compared to NNs for each iteration
GAMMA = 1000  # 1000: more weight to transferability
SIGMA = 1  # RBF bandwidth, default 1
33 |
34 |
35 | # MMD, JMMD, JPMMD, DJP-MMD
def mmd_loss(x_src, y_src, x_tar, y_pseudo, mmd_type):
    """Dispatch to the selected MMD variant (MMD, JMMD, JPMMD, DJP-MMD).

    Args:
        x_src, x_tar: hidden features of the source / target batches.
        y_src: true source labels.
        y_pseudo: pseudo labels of the target batch.
        mmd_type: one of 'mmd', 'jmmd', 'jpmmd', 'djpmmd'.

    Returns:
        The scalar MMD loss tensor.

    Raises:
        ValueError: on an unknown mmd_type (the original silently fell
        through and returned None, which would fail later with an opaque
        TypeError).
    """
    if mmd_type == 'mmd':
        return djp_mmd.rbf_mmd(x_src, x_tar, SIGMA)
    if mmd_type == 'jmmd':
        return djp_mmd.rbf_jmmd(x_src, y_src, x_tar, y_pseudo, SIGMA)
    if mmd_type == 'jpmmd':
        return djp_mmd.rbf_jpmmd(x_src, y_src, x_tar, y_pseudo, SIGMA)
    if mmd_type == 'djpmmd':
        return djp_mmd.rbf_djpmmd(x_src, y_src, x_tar, y_pseudo, SIGMA)
    raise ValueError('unknown mmd_type: %s' % mmd_type)
45 |
46 |
def model_train(model, optimizer, epoch, data_src, data_tar, y_pse, mmd_type):
    """Train for one epoch: source cross-entropy + GAMMA-weighted MMD loss.

    Args:
        model: DaNN returning (source logits, source feats, target feats).
        optimizer: optimizer over model.parameters().
        epoch: current epoch number (used for logging only).
        data_src: source-domain DataLoader (shuffled, fixed batch size).
        data_tar: target-domain DataLoader.
        y_pse: LongTensor of shape (n_source_batches, batch_size) holding
            target pseudo labels; row batch_id is read by the MMD loss and
            then refreshed in place from the current model's predictions.
        mmd_type: one of 'mmd' / 'jmmd' / 'jpmmd' / 'djpmmd'.

    Returns:
        (train_acc, train_loss, model) — accuracy in percent and mean loss
        over the epoch as numpy scalars, plus the (mutated) model.
    """
    tmp_train_loss = 0
    correct = 0
    # NOTE(review): batch_j is never incremented below, so every source batch
    # is paired with target batch 0 only — confirm this is intended.
    batch_j = 0
    criterion = nn.CrossEntropyLoss()
    # list_src is unused; list_tar materialises the target loader so batches
    # can be indexed by position
    list_src, list_tar = list(enumerate(data_src)), list(enumerate(data_tar))

    # print('***********', len(list_src), len(list_tar))
    for batch_id, (x_src, y_src) in enumerate(data_src):
        optimizer.zero_grad()
        # flatten images to vectors for the fully-connected network
        x_src, y_src = x_src.detach().view(-1, 28 * 28).to(DEVICE), y_src.to(DEVICE)
        _, (x_tar, y_tar) = list_tar[batch_j]  # y_tar is unused (unsupervised target)
        x_tar = x_tar.view(-1, 28 * 28).to(DEVICE)
        model.train()
        ypred, x_src_mmd, x_tar_mmd = model(x_src, x_tar)

        # print('x_src: ', x_src.shape, '\t x_tar', x_tar.shape) # both torch.Size([64, 784])
        loss_ce = criterion(ypred, y_src)
        # MMD uses the pseudo labels stored for this batch on the previous pass
        loss_mmd = mmd_loss(x_src_mmd, y_src, x_tar_mmd, y_pse[batch_id, :], mmd_type)
        pred = ypred.detach().max(1)[1]  # get the index of the max log-probability

        # refresh the pseudo labels of the target batch with the current model
        model.eval()
        pred_pse, _, _ = model(x_tar, x_tar)
        y_pse[batch_id, :] = pred_pse.detach().max(1)[1]

        # accumulate training statistics
        correct += pred.eq(y_src.detach().view_as(pred)).cpu().sum()
        loss = loss_ce + GAMMA * loss_mmd

        # error backward
        loss.backward()
        optimizer.step()
        tmp_train_loss += loss.detach()

    tmp_train_loss /= len(data_src)
    tmp_train_acc = correct * 100. / len(data_src.dataset)
    train_loss = tmp_train_loss.detach().cpu().numpy()
    train_acc = tmp_train_acc.numpy()

    tim = time.strftime("%H:%M:%S", time.localtime())
    res_e = '{:s}, epoch: {}/{}, train loss: {:.4f}, train acc: {:.4f}'.format(
        tim, epoch, N_EPOCH, tmp_train_loss, tmp_train_acc)
    tqdm.write(res_e)
    return train_acc, train_loss, model
92 |
93 |
def model_test(model, data_tar, epoch):
    """Evaluate the model on the target domain.

    Runs one gradient-free pass over `data_tar` and reports the mean
    cross-entropy loss and classification accuracy (in percent).
    `epoch` is accepted for interface symmetry with model_train but is
    not used here.

    Returns:
        (test_acc, test_loss) as numpy scalars.
    """
    criterion = nn.CrossEntropyLoss()
    loss_total = 0
    n_correct = 0
    model.eval()
    with tr.no_grad():
        for x_tar, y_tar in data_tar:
            x_tar = x_tar.view(-1, 28 * 28).to(DEVICE)
            y_tar = y_tar.to(DEVICE)
            # feed the target batch as both inputs; MMD features are ignored
            logits, _, _ = model(x_tar, x_tar)
            loss_total += criterion(logits, y_tar).detach()
            pred = logits.detach().max(1)[1]  # index of the max log-probability
            n_correct += pred.eq(y_tar.detach().view_as(pred)).cpu().sum()

    mean_loss = loss_total / len(data_tar)
    acc = n_correct * 100. / len(data_tar.dataset)
    test_loss = mean_loss.detach().cpu().numpy()
    test_acc = acc.numpy()

    tqdm.write('test loss: {:.4f}, test acc: {:.4f}'.format(mean_loss, acc))
    return test_acc, test_loss
116 |
117 |
def main():
    """Train and evaluate DaNN on webcam -> dslr for every MMD variant.

    For each MMD type a fresh model is trained for N_EPOCH epochs; the
    train/test accuracy curves of all variants are drawn into a single
    figure saved as '<source>_<target>_acc.jpg'.
    """
    rootdir = "/mnt/xxx/dataset/office_caltech_10/"
    tr.manual_seed(1)  # reproducible runs
    domain_str = ['webcam', 'dslr']
    X_s = data_loader.load_train(root_dir=rootdir, domain=domain_str[0], batch_size=BATCH_SIZE[0])
    X_t = data_loader.load_test(root_dir=rootdir, domain=domain_str[1], batch_size=BATCH_SIZE[1])

    # train and test
    start = time.time()
    mmd_type = ['mmd', 'jmmd', 'jpmmd', 'djpmmd']
    for mt in mmd_type:
        print('-' * 10 + domain_str[0] + ' --> ' + domain_str[1] + '-' * 10)
        print('MMD loss type: ' + mt + '\n')
        acc, loss = {}, {}
        train_acc = []
        test_acc = []
        train_loss = []
        test_loss = []
        # one row of pseudo labels per source batch, sized from the loader
        # instead of the previous hard-coded 14 x 64, and placed on DEVICE
        # instead of an unconditional .cuda() (which crashed without a GPU)
        y_pse = tr.zeros(len(X_s), BATCH_SIZE[0]).long().to(DEVICE)

        mdl = DaNN.DaNN(n_input=28 * 28, n_hidden=256, n_class=10)
        mdl = mdl.to(DEVICE)

        # optimization
        opt_Adam = optim.Adam(mdl.parameters(), lr=LEARNING_RATE)

        for ep in tqdm(range(1, N_EPOCH + 1)):
            tmp_train_acc, tmp_train_loss, mdl = \
                model_train(model=mdl, optimizer=opt_Adam, epoch=ep, data_src=X_s, data_tar=X_t, y_pse=y_pse,
                            mmd_type=mt)
            tmp_test_acc, tmp_test_loss = model_test(mdl, X_t, ep)
            train_acc.append(tmp_train_acc)
            test_acc.append(tmp_test_acc)
            train_loss.append(tmp_train_loss)
            test_loss.append(tmp_test_loss)
        acc['train'], acc['test'] = train_acc, test_acc
        loss['train'], loss['test'] = train_loss, test_loss

        # visualize: accuracy curves of this MMD variant
        plt.plot(acc['train'], label='train-' + mt)
        plt.plot(acc['test'], label='test-' + mt, ls='--')

    plt.title(domain_str[0] + ' to ' + domain_str[1])
    # plain int instead of np.int8: int8 overflows for N_EPOCH > 127
    plt.xticks(np.linspace(1, N_EPOCH, num=5, dtype=int))
    plt.xlim(1, N_EPOCH)
    plt.ylim(0, 100)
    plt.legend(loc='upper right')
    plt.xlabel("epochs")
    plt.ylabel("accuracy")
    plt.savefig(domain_str[0] + '_' + domain_str[1] + "_acc.jpg")
    plt.close()

    # time and save model
    end = time.time()
    print("Total run time: %.2f" % float(end - start))


if __name__ == '__main__':
    main()
177 |
--------------------------------------------------------------------------------
/webcam_dslr_acc.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chamwen/DaNN_DJP/28d30cdac84407da6d885d66dce8ad7e9e5fb39c/webcam_dslr_acc.jpg
--------------------------------------------------------------------------------