├── DANN_torch.ipynb ├── MDAN_torch.ipynb ├── README.md └── images ├── 48,1000,d_acc.png ├── 48,1000,loss.png ├── 48,1000,p_acc.png ├── DC_DA_10.png ├── DC_original_10.png ├── dann.png ├── init └── visulization.png /DANN_torch.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from __future__ import absolute_import, division, print_function, unicode_literals\n", 10 | "\n", 11 | "import torchvision\n", 12 | "import torch.optim as optim\n", 13 | "\n", 14 | "import numpy as np\n", 15 | "\n", 16 | "from models import models\n", 17 | "from sklearn.manifold import TSNE\n", 18 | "\n", 19 | "import argparse, sys, os\n", 20 | "\n", 21 | "import torch\n", 22 | "import torch.nn as nn\n", 23 | "import torch.nn.functional as F\n", 24 | "from torch.autograd import Variable\n", 25 | "import torch.nn.init as init\n", 26 | "from torch.utils.data import DataLoader\n", 27 | "from torchvision import datasets, transforms\n", 28 | "\n", 29 | "import time\n", 30 | "from collections import Counter\n", 31 | "import matplotlib.pyplot as plt " 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 2, 37 | "metadata": {}, 38 | "outputs": [ 39 | { 40 | "name": "stdout", 41 | "output_type": "stream", 42 | "text": [ 43 | "(23758, 62, 5) (23758,)\n", 44 | "(23758, 62, 5) (23758,)\n", 45 | "Counter({1: 8190, 2: 7840, 0: 7728})\n", 46 | "(47516, 62, 5) (47516,)\n" 47 | ] 48 | } 49 | ], 50 | "source": [ 51 | "data2 = np.load(r'C:\\Users\\dingk\\Desktop\\experiment\\DL-model with MMD\\ExtractedFeatures\\3session012\\session2.npy')\n", 52 | "label2 = np.load(r'C:\\Users\\dingk\\Desktop\\experiment\\DL-model with MMD\\ExtractedFeatures\\3session012\\label2.npy')\n", 53 | "data1 = np.load(r'C:\\Users\\dingk\\Desktop\\experiment\\DL-model with MMD\\ExtractedFeatures\\3session012\\session1.npy')\n", 54 | "label1 = 
np.load(r'C:\\Users\\dingk\\Desktop\\experiment\\DL-model with MMD\\ExtractedFeatures\\3session012\\label1.npy')\n", 55 | "data3 = np.load(r'C:\\Users\\dingk\\Desktop\\experiment\\DL-model with MMD\\ExtractedFeatures\\3session012\\session3.npy')\n", 56 | "label3 = np.load(r'C:\\Users\\dingk\\Desktop\\experiment\\DL-model with MMD\\ExtractedFeatures\\3session012\\label3.npy')\n", 57 | "\n", 58 | "print(data2.shape,label2.shape)\n", 59 | "print(data1.shape,label1.shape)\n", 60 | "\n", 61 | "print(Counter(label1))\n", 62 | "\n", 63 | "train_data = np.concatenate((data1,data3),axis = 0)\n", 64 | "label_data = np.concatenate((label1,label3),axis = 0)\n", 65 | "\n", 66 | "test = data2\n", 67 | "label_test = label2\n", 68 | "print(train_data.shape,label_data.shape)" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": 3, 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "def get_train_loader(data,label,batch_size,shuffle=True):\n", 78 | " \"\"\"\n", 79 | " Get train dataloader of source domain or target domain\n", 80 | " :return: dataloader\n", 81 | " \"\"\"\n", 82 | " tensor_x = torch.Tensor(data) # transform to torch tensor\n", 83 | " tensor_y = torch.Tensor(label)\n", 84 | "\n", 85 | " my_dataset = torch.utils.data.TensorDataset(tensor_x,tensor_y) # create your datset\n", 86 | " my_dataloader = torch.utils.data.DataLoader(my_dataset,batch_size=batch_size,shuffle=True,drop_last=True,\n", 87 | " num_workers=8, pin_memory=True) # create your dataloader\n", 88 | " \n", 89 | " return my_dataloader\n", 90 | "\n", 91 | "def get_test_loader(data,label,batch_size,shuffle=True):\n", 92 | " \"\"\"\n", 93 | " Get test dataloader of source domain or target domain\n", 94 | " :return: dataloader\n", 95 | " \"\"\"\n", 96 | " tensor_x = torch.Tensor(data) # transform to torch tensor\n", 97 | " tensor_y = torch.Tensor(label)\n", 98 | "\n", 99 | " my_dataset = torch.utils.data.TensorDataset(tensor_x,tensor_y) # create your datset\n", 100 | " my_dataloader 
= torch.utils.data.DataLoader(my_dataset,batch_size=batch_size,shuffle=True,drop_last=True,\n", 101 | " num_workers=8, pin_memory=True) # create your dataloader\n", 102 | " \n", 103 | " return my_dataloader\n", 104 | "\n", 105 | "\n", 106 | "def optimizer_scheduler(optimizer, p):\n", 107 | " \"\"\"\n", 108 | " Adjust the learning rate of optimizer\n", 109 | " :param optimizer: optimizer for updating parameters\n", 110 | " :param p: a variable for adjusting learning rate\n", 111 | " :return: optimizer\n", 112 | " \"\"\"\n", 113 | " for param_group in optimizer.param_groups:\n", 114 | " param_group['lr'] = 0.01 / (1. + 10 * p) ** 0.75\n", 115 | "\n", 116 | " return optimizer" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": 4, 122 | "metadata": {}, 123 | "outputs": [], 124 | "source": [ 125 | "class GradReverse(torch.autograd.Function):\n", 126 | " \"\"\"\n", 127 | " Extension of grad reverse layer\n", 128 | " \"\"\"\n", 129 | " @staticmethod\n", 130 | " def forward(ctx, x, constant):\n", 131 | " ctx.constant = constant\n", 132 | " return x.view_as(x)\n", 133 | "\n", 134 | " @staticmethod\n", 135 | " def backward(ctx, grad_output):\n", 136 | " grad_output = grad_output.neg() * ctx.constant\n", 137 | " return grad_output, None\n", 138 | "\n", 139 | " def grad_reverse(x, constant):\n", 140 | " return GradReverse.apply(x, constant)\n", 141 | "\n", 142 | "class Extractor(nn.Module):\n", 143 | "\n", 144 | " def __init__(self):\n", 145 | " super(Extractor, self).__init__()\n", 146 | " self.conv1 = nn.Conv1d(5, 64, 5, 1) \n", 147 | " self.bn1 = nn.BatchNorm1d(64)\n", 148 | " self.conv2 = nn.Conv1d(64, 64, 5, 1) \n", 149 | " self.bn2 = nn.BatchNorm1d(64)\n", 150 | " self.conv3 = nn.Conv1d(64, 64, 5, 1) \n", 151 | " self.bn3 = nn.BatchNorm1d(64)\n", 152 | " self.pool = nn.AvgPool1d(5, stride=5)\n", 153 | "# self.pool = nn.MaxPool1d(5, stride=5)\n", 154 | " \n", 155 | "\n", 156 | " def forward(self, input):\n", 157 | " x = input.permute(0,2,1)\n", 
158 | "# print('after permute ',x.shape) # 64,5,62\n", 159 | " x = F.relu(self.bn1(self.conv1(x))) \n", 160 | "# print('after CNN1 ',x.shape) # 64, 64, 58\n", 161 | " x = F.relu(self.bn2(self.conv2(x))) \n", 162 | "# print('after CNN2 ',x.shape) #64, 64, 54\n", 163 | "# x = F.avg_pool1d(x,kernel_size=5) \n", 164 | " x = F.relu(self.bn3(self.conv3(x))) \n", 165 | " x = self.pool(x)\n", 166 | " \n", 167 | "# print('before fc ',x.shape) #64, 64, 10\n", 168 | " x = x.view(-1, 64 * 10) \n", 169 | " \n", 170 | " return x\n", 171 | "\n", 172 | "\n", 173 | "class Class_classifier(nn.Module):\n", 174 | "\n", 175 | " def __init__(self):\n", 176 | " super(Class_classifier, self).__init__()\n", 177 | " # self.fc1 = nn.Linear(50 * 4 * 4, 100)\n", 178 | " # self.bn1 = nn.BatchNorm1d(100)\n", 179 | " # self.fc2 = nn.Linear(100, 100)\n", 180 | " # self.bn2 = nn.BatchNorm1d(100)\n", 181 | " # self.fc3 = nn.Linear(100, 10)\n", 182 | " self.fc1 = nn.Linear(64 * 10, 128)\n", 183 | " self.fc2 = nn.Linear(128, 128)\n", 184 | " self.fc3 = nn.Linear(128, 3)\n", 185 | "\n", 186 | " def forward(self, input):\n", 187 | " # logits = F.relu(self.bn1(self.fc1(input)))\n", 188 | " # logits = self.fc2(F.dropout(logits))\n", 189 | " # logits = F.relu(self.bn2(logits))\n", 190 | " # logits = self.fc3(logits)\n", 191 | " logits = F.relu(self.fc1(input))\n", 192 | " logits = self.fc2(F.dropout(logits))\n", 193 | " logits = F.relu(logits)\n", 194 | " logits = self.fc3(logits)\n", 195 | "\n", 196 | " return F.log_softmax(logits, 1)\n", 197 | "\n", 198 | "class Domain_classifier(nn.Module):\n", 199 | "\n", 200 | " def __init__(self):\n", 201 | " super(Domain_classifier, self).__init__()\n", 202 | " # self.fc1 = nn.Linear(50 * 4 * 4, 100)\n", 203 | " # self.bn1 = nn.BatchNorm1d(100)\n", 204 | " # self.fc2 = nn.Linear(100, 2)\n", 205 | " self.fc1 = nn.Linear(64 * 10, 128)\n", 206 | " self.fc2 = nn.Linear(128, 2)\n", 207 | "\n", 208 | " def forward(self, input, constant):\n", 209 | " input = 
GradReverse.grad_reverse(input, constant)\n", 210 | " # logits = F.relu(self.bn1(self.fc1(input)))\n", 211 | " # logits = F.log_softmax(self.fc2(logits), 1)\n", 212 | " logits = F.relu(self.fc1(input))\n", 213 | " logits = F.log_softmax(self.fc2(logits), 1)\n", 214 | "\n", 215 | " return logits\n" 216 | ] 217 | }, 218 | { 219 | "cell_type": "code", 220 | "execution_count": 5, 221 | "metadata": {}, 222 | "outputs": [], 223 | "source": [ 224 | "def train(training_mode, feature_extractor, class_classifier, domain_classifier, class_criterion, domain_criterion,\n", 225 | " source_dataloader, target_dataloader, optimizer, epoch):\n", 226 | " \"\"\"\n", 227 | " Execute target domain adaptation\n", 228 | " :param training_mode:\n", 229 | " :param feature_extractor:\n", 230 | " :param class_classifier:\n", 231 | " :param domain_classifier:\n", 232 | " :param class_criterion:\n", 233 | " :param domain_criterion:\n", 234 | " :param source_dataloader:\n", 235 | " :param target_dataloader:\n", 236 | " :param optimizer:\n", 237 | " :return:\n", 238 | " \"\"\"\n", 239 | "\n", 240 | " # setup models\n", 241 | " feature_extractor.train()\n", 242 | " class_classifier.train()\n", 243 | " domain_classifier.train()\n", 244 | "\n", 245 | " # steps\n", 246 | " start_steps = epoch * len(source_dataloader)\n", 247 | " total_steps = 10 * len(source_dataloader)\n", 248 | "\n", 249 | " for batch_idx, (sdata, tdata) in enumerate(zip(source_dataloader, target_dataloader)):\n", 250 | "\n", 251 | " if training_mode == 'dann':\n", 252 | " # setup hyperparameters\n", 253 | " p = float(batch_idx + start_steps) / total_steps\n", 254 | " constant = 2. / (1. 
+ np.exp(-gamma * p)) - 1\n", 255 | "\n", 256 | " # prepare the data\n", 257 | " input1, label1 = sdata\n", 258 | " input2, label2 = tdata\n", 259 | " size = min((input1.shape[0], input2.shape[0]))\n", 260 | " input1, label1 = input1[0:size, :, :], label1[0:size]\n", 261 | " input2, label2 = input2[0:size, :, :], label2[0:size]\n", 262 | " \n", 263 | " input1, label1 = Variable(input1.cuda()), Variable(label1.cuda().long())\n", 264 | " input2, label2 = Variable(input2.cuda()), Variable(label2.cuda().long())\n", 265 | "\n", 266 | "\n", 267 | " # setup optimizer\n", 268 | " optimizer = optimizer_scheduler(optimizer, p)\n", 269 | " optimizer.zero_grad()\n", 270 | "\n", 271 | " # prepare domain labels\n", 272 | " \n", 273 | " source_labels = Variable(torch.zeros((input1.size()[0])).type(torch.LongTensor).cuda())\n", 274 | " target_labels = Variable(torch.ones((input2.size()[0])).type(torch.LongTensor).cuda())\n", 275 | "\n", 276 | " # compute the output of source domain and target domain\n", 277 | " src_feature = feature_extractor(input1)\n", 278 | " tgt_feature = feature_extractor(input2)\n", 279 | "\n", 280 | " # compute the class loss of src_feature\n", 281 | " class_preds = class_classifier(src_feature)\n", 282 | " class_loss = class_criterion(class_preds, label1)\n", 283 | "\n", 284 | " # compute the domain loss of src_feature and target_feature\n", 285 | " tgt_preds = domain_classifier(tgt_feature, constant)\n", 286 | " src_preds = domain_classifier(src_feature, constant)\n", 287 | " tgt_loss = domain_criterion(tgt_preds, target_labels)\n", 288 | " src_loss = domain_criterion(src_preds, source_labels)\n", 289 | " domain_loss = tgt_loss + src_loss\n", 290 | "\n", 291 | " loss = class_loss + theta * domain_loss\n", 292 | " loss.backward()\n", 293 | " optimizer.step()\n", 294 | "\n", 295 | " # print loss\n", 296 | " if (batch_idx + 1) % 100 == 0:\n", 297 | " print('[{}/{} ({:.0f}%)]\\tLoss: {:.6f}\\tClass Loss: {:.6f}\\tDomain Loss: {:.6f}'.format(\n", 298 | " 
batch_idx * len(input2), len(target_dataloader.dataset),\n", 299 | " 100. * batch_idx / len(target_dataloader), loss.item(), class_loss.item(),\n", 300 | " domain_loss.item()\n", 301 | " ))\n", 302 | " \n", 303 | " total_loss.append(loss.item())\n", 304 | " c_loss.append( class_loss.item())\n", 305 | " d_loss.append(domain_loss.item())\n", 306 | "\n", 307 | "\n", 308 | " elif training_mode == 'source':\n", 309 | " # prepare the data\n", 310 | " input1, label1 = sdata\n", 311 | " size = input1.shape[0]\n", 312 | " input1, label1 = input1[0:size, :, :], label1[0:size]\n", 313 | " input1, label1 = Variable(input1.cuda()), Variable(label1.cuda().long())\n", 314 | " \n", 315 | "\n", 316 | " # setup optimizer\n", 317 | " optimizer = optim.SGD(list(feature_extractor.parameters())+list(class_classifier.parameters()), lr=0.01, momentum=0.9)\n", 318 | " optimizer.zero_grad()\n", 319 | "\n", 320 | " # compute the output of source domain and target domain\n", 321 | " src_feature = feature_extractor(input1)\n", 322 | "\n", 323 | " # compute the class loss of src_feature\n", 324 | " class_preds = class_classifier(src_feature)\n", 325 | " class_loss = class_criterion(class_preds, label1)\n", 326 | "\n", 327 | " class_loss.backward()\n", 328 | " optimizer.step()\n", 329 | "\n", 330 | " # print loss\n", 331 | " if (batch_idx + 1) % 10 == 0:\n", 332 | " print('[{}/{} ({:.0f}%)]\\tClass Loss: {:.6f}'.format(\n", 333 | " batch_idx * len(input1), len(source_dataloader.dataset),\n", 334 | " 100. 
* batch_idx / len(source_dataloader), class_loss.item()\n", 335 | " ))\n", 336 | "\n", 337 | " elif training_mode == 'target':\n", 338 | " # prepare the data\n", 339 | " input2, label2 = tdata\n", 340 | " size = input2.shape[0]\n", 341 | " input2, label2 = input2[0:size, :, :], label2[0:size]\n", 342 | " input2, label2 = Variable(input2.cuda()), Variable(label2.cuda())\n", 343 | "\n", 344 | " # setup optimizer\n", 345 | " optimizer = optim.SGD(list(feature_extractor.parameters()) + list(class_classifier.parameters()), lr=0.01,\n", 346 | " momentum=0.9)\n", 347 | " optimizer.zero_grad()\n", 348 | "\n", 349 | " # compute the output of source domain and target domain\n", 350 | " tgt_feature = feature_extractor(input2)\n", 351 | "\n", 352 | " # compute the class loss of src_feature\n", 353 | " class_preds = class_classifier(tgt_feature)\n", 354 | " class_loss = class_criterion(class_preds, label2)\n", 355 | "\n", 356 | " class_loss.backward()\n", 357 | " optimizer.step()\n", 358 | "\n", 359 | " # print loss\n", 360 | " if (batch_idx + 1) % 10 == 0:\n", 361 | " print('[{}/{} ({:.0f}%)]\\tClass Loss: {:.6f}'.format(\n", 362 | " batch_idx * len(input2), len(target_dataloader.dataset),\n", 363 | " 100. 
* batch_idx / len(target_dataloader), class_loss.item()\n", 364 | " ))" 365 | ] 366 | }, 367 | { 368 | "cell_type": "code", 369 | "execution_count": 6, 370 | "metadata": {}, 371 | "outputs": [], 372 | "source": [ 373 | "def test(feature_extractor, class_classifier, domain_classifier, source_dataloader, target_dataloader):\n", 374 | " \"\"\"\n", 375 | " Test the performance of the model\n", 376 | " :param feature_extractor: network used to extract feature from target samples\n", 377 | " :param class_classifier: network used to predict labels\n", 378 | " :param domain_classifier: network used to predict domain\n", 379 | " :param source_dataloader: test dataloader of source domain\n", 380 | " :param target_dataloader: test dataloader of target domain\n", 381 | " :return: None\n", 382 | " \"\"\"\n", 383 | " # setup the network\n", 384 | " feature_extractor.eval()\n", 385 | " class_classifier.eval()\n", 386 | " domain_classifier.eval()\n", 387 | " source_correct = 0.0\n", 388 | " target_correct = 0.0\n", 389 | " domain_correct = 0.0\n", 390 | " tgt_correct = 0.0\n", 391 | " src_correct = 0.0\n", 392 | "\n", 393 | " for batch_idx, sdata in enumerate(source_dataloader):\n", 394 | " # setup hyperparameters\n", 395 | " p = float(batch_idx) / len(source_dataloader)\n", 396 | " constant = 2. / (1. 
+ np.exp(-10 * p)) - 1.\n", 397 | "\n", 398 | " input1, label1 = sdata\n", 399 | " input1, label1 = Variable(input1.cuda()), Variable(label1.cuda().long())\n", 400 | " src_labels = Variable(torch.zeros((input1.size()[0])).type(torch.LongTensor).cuda())\n", 401 | "\n", 402 | " output1 = class_classifier(feature_extractor(input1))\n", 403 | " pred1 = output1.data.max(1, keepdim = True)[1]\n", 404 | " source_correct += pred1.eq(label1.data.view_as(pred1)).cpu().sum()\n", 405 | "\n", 406 | " src_preds = domain_classifier(feature_extractor(input1), constant)\n", 407 | " src_preds = src_preds.data.max(1, keepdim= True)[1]\n", 408 | " src_correct += src_preds.eq(src_labels.data.view_as(src_preds)).cpu().sum()\n", 409 | "\n", 410 | " for batch_idx, tdata in enumerate(target_dataloader):\n", 411 | " # setup hyperparameters\n", 412 | " p = float(batch_idx) / len(source_dataloader)\n", 413 | " constant = 2. / (1. + np.exp(-10 * p)) - 1\n", 414 | "\n", 415 | " input2, label2 = tdata\n", 416 | " input2, label2 = Variable(input2.cuda()), Variable(label2.cuda().long())\n", 417 | " tgt_labels = Variable(torch.ones((input2.size()[0])).type(torch.LongTensor).cuda())\n", 418 | "\n", 419 | " output2 = class_classifier(feature_extractor(input2))\n", 420 | " pred2 = output2.data.max(1, keepdim=True)[1]\n", 421 | " target_correct += pred2.eq(label2.data.view_as(pred2)).cpu().sum()\n", 422 | "\n", 423 | " tgt_preds = domain_classifier(feature_extractor(input2), constant)\n", 424 | " tgt_preds = tgt_preds.data.max(1, keepdim=True)[1]\n", 425 | " tgt_correct += tgt_preds.eq(tgt_labels.data.view_as(tgt_preds)).cpu().sum()\n", 426 | "\n", 427 | " domain_correct = tgt_correct + src_correct\n", 428 | "\n", 429 | " print('\\nSource Accuracy: {}/{} ({:.4f}%)\\nTarget Accuracy: {}/{} ({:.4f}%)\\n'\n", 430 | " 'Domain Accuracy: {}/{} ({:.4f}%)\\n'.\n", 431 | " format(\n", 432 | " source_correct, len(source_dataloader.dataset), 100. 
* float(source_correct) / len(source_dataloader.dataset),\n", 433 | " target_correct, len(target_dataloader.dataset), 100. * float(target_correct) / len(target_dataloader.dataset),\n", 434 | " domain_correct, len(source_dataloader.dataset) + len(target_dataloader.dataset),\n", 435 | " 100. * float(domain_correct) / (len(source_dataloader.dataset) + len(target_dataloader.dataset))\n", 436 | " ))\n", 437 | " acc_list1.append(100. * float(source_correct) / len(source_dataloader.dataset))\n", 438 | " acc_list2.append(100. * float(target_correct) / len(target_dataloader.dataset))\n", 439 | " acc_list3.append(100. * float(domain_correct) / (len(source_dataloader.dataset) + len(target_dataloader.dataset)))\n" 440 | ] 441 | }, 442 | { 443 | "cell_type": "code", 444 | "execution_count": 7, 445 | "metadata": { 446 | "scrolled": true 447 | }, 448 | "outputs": [ 449 | { 450 | "name": "stdout", 451 | "output_type": "stream", 452 | "text": [ 453 | "Epoch: 0\n", 454 | "[9900/23758 (42%)]\tLoss: 1.790843\tClass Loss: 0.649918\tDomain Loss: 1.140924\n", 455 | "[19900/23758 (84%)]\tLoss: 1.928555\tClass Loss: 0.359018\tDomain Loss: 1.569537\n", 456 | "\n", 457 | "Source Accuracy: 15021.0/23758 (63.2250%)\n", 458 | "Target Accuracy: 12669.0/23758 (53.3252%)\n", 459 | "Domain Accuracy: 27316.0/47516 (57.4880%)\n", 460 | "\n", 461 | "Epoch: 1\n", 462 | "[9900/23758 (42%)]\tLoss: 1.417862\tClass Loss: 0.219204\tDomain Loss: 1.198658\n", 463 | "[19900/23758 (84%)]\tLoss: 1.479155\tClass Loss: 0.186296\tDomain Loss: 1.292859\n", 464 | "\n", 465 | "Source Accuracy: 7711.0/23758 (32.4564%)\n", 466 | "Target Accuracy: 7712.0/23758 (32.4606%)\n", 467 | "Domain Accuracy: 23700.0/47516 (49.8779%)\n", 468 | "\n", 469 | "Epoch: 2\n", 470 | "[9900/23758 (42%)]\tLoss: 1.485815\tClass Loss: 0.096645\tDomain Loss: 1.389169\n", 471 | "[19900/23758 (84%)]\tLoss: 1.428321\tClass Loss: 0.046468\tDomain Loss: 1.381853\n", 472 | "\n", 473 | "Source Accuracy: 16596.0/23758 (69.8544%)\n", 474 | "Target 
Accuracy: 14933.0/23758 (62.8546%)\n", 475 | "Domain Accuracy: 31245.0/47516 (65.7568%)\n", 476 | "\n", 477 | "Epoch: 3\n", 478 | "[9900/23758 (42%)]\tLoss: 1.589309\tClass Loss: 0.079787\tDomain Loss: 1.509522\n", 479 | "[19900/23758 (84%)]\tLoss: 1.663020\tClass Loss: 0.186908\tDomain Loss: 1.476113\n", 480 | "\n", 481 | "Source Accuracy: 15765.0/23758 (66.3566%)\n", 482 | "Target Accuracy: 12356.0/23758 (52.0077%)\n", 483 | "Domain Accuracy: 27057.0/47516 (56.9429%)\n", 484 | "\n", 485 | "Epoch: 4\n", 486 | "[9900/23758 (42%)]\tLoss: 1.474490\tClass Loss: 0.119472\tDomain Loss: 1.355018\n", 487 | "[19900/23758 (84%)]\tLoss: 1.438168\tClass Loss: 0.040553\tDomain Loss: 1.397614\n", 488 | "\n", 489 | "Source Accuracy: 14429.0/23758 (60.7332%)\n", 490 | "Target Accuracy: 12994.0/23758 (54.6932%)\n", 491 | "Domain Accuracy: 29714.0/47516 (62.5347%)\n", 492 | "\n", 493 | "Epoch: 5\n", 494 | "[9900/23758 (42%)]\tLoss: 1.412591\tClass Loss: 0.033921\tDomain Loss: 1.378671\n", 495 | "[19900/23758 (84%)]\tLoss: 1.417555\tClass Loss: 0.042087\tDomain Loss: 1.375468\n", 496 | "\n", 497 | "Source Accuracy: 18065.0/23758 (76.0375%)\n", 498 | "Target Accuracy: 16876.0/23758 (71.0329%)\n", 499 | "Domain Accuracy: 27996.0/47516 (58.9191%)\n", 500 | "\n", 501 | "Epoch: 6\n", 502 | "[9900/23758 (42%)]\tLoss: 1.475151\tClass Loss: 0.048026\tDomain Loss: 1.427125\n", 503 | "[19900/23758 (84%)]\tLoss: 1.398542\tClass Loss: 0.026189\tDomain Loss: 1.372354\n", 504 | "\n", 505 | "Source Accuracy: 17915.0/23758 (75.4062%)\n", 506 | "Target Accuracy: 13788.0/23758 (58.0352%)\n", 507 | "Domain Accuracy: 23785.0/47516 (50.0568%)\n", 508 | "\n", 509 | "Epoch: 7\n", 510 | "[9900/23758 (42%)]\tLoss: 1.411167\tClass Loss: 0.009974\tDomain Loss: 1.401193\n", 511 | "[19900/23758 (84%)]\tLoss: 1.398504\tClass Loss: 0.016725\tDomain Loss: 1.381778\n", 512 | "\n", 513 | "Source Accuracy: 20377.0/23758 (85.7690%)\n", 514 | "Target Accuracy: 16865.0/23758 (70.9866%)\n", 515 | "Domain Accuracy: 
26668.0/47516 (56.1243%)\n", 516 | "\n", 517 | "Epoch: 8\n", 518 | "[9900/23758 (42%)]\tLoss: 1.388465\tClass Loss: 0.004175\tDomain Loss: 1.384290\n", 519 | "[19900/23758 (84%)]\tLoss: 1.407421\tClass Loss: 0.005365\tDomain Loss: 1.402056\n", 520 | "\n", 521 | "Source Accuracy: 12963.0/23758 (54.5627%)\n", 522 | "Target Accuracy: 13614.0/23758 (57.3028%)\n", 523 | "Domain Accuracy: 25071.0/47516 (52.7633%)\n", 524 | "\n", 525 | "Epoch: 9\n", 526 | "[9900/23758 (42%)]\tLoss: 1.401463\tClass Loss: 0.010243\tDomain Loss: 1.391219\n", 527 | "[19900/23758 (84%)]\tLoss: 1.403482\tClass Loss: 0.030464\tDomain Loss: 1.373018\n", 528 | "\n", 529 | "Source Accuracy: 23295.0/23758 (98.0512%)\n", 530 | "Target Accuracy: 17519.0/23758 (73.7394%)\n", 531 | "Domain Accuracy: 24944.0/47516 (52.4960%)\n", 532 | "\n", 533 | "Epoch: 10\n", 534 | "[9900/23758 (42%)]\tLoss: 1.401866\tClass Loss: 0.021802\tDomain Loss: 1.380064\n", 535 | "[19900/23758 (84%)]\tLoss: 1.424058\tClass Loss: 0.006759\tDomain Loss: 1.417299\n", 536 | "\n", 537 | "Source Accuracy: 22754.0/23758 (95.7741%)\n", 538 | "Target Accuracy: 17867.0/23758 (75.2041%)\n", 539 | "Domain Accuracy: 27500.0/47516 (57.8752%)\n", 540 | "\n", 541 | "Epoch: 11\n", 542 | "[9900/23758 (42%)]\tLoss: 1.412953\tClass Loss: 0.012205\tDomain Loss: 1.400748\n", 543 | "[19900/23758 (84%)]\tLoss: 1.375268\tClass Loss: 0.006671\tDomain Loss: 1.368596\n", 544 | "\n", 545 | "Source Accuracy: 16110.0/23758 (67.8087%)\n", 546 | "Target Accuracy: 14437.0/23758 (60.7669%)\n", 547 | "Domain Accuracy: 24550.0/47516 (51.6668%)\n", 548 | "\n", 549 | "Epoch: 12\n", 550 | "[9900/23758 (42%)]\tLoss: 1.438763\tClass Loss: 0.037560\tDomain Loss: 1.401203\n", 551 | "[19900/23758 (84%)]\tLoss: 1.407230\tClass Loss: 0.006404\tDomain Loss: 1.400826\n", 552 | "\n", 553 | "Source Accuracy: 9335.0/23758 (39.2920%)\n", 554 | "Target Accuracy: 9608.0/23758 (40.4411%)\n", 555 | "Domain Accuracy: 24567.0/47516 (51.7026%)\n", 556 | "\n", 557 | "Epoch: 13\n", 558 | 
"[9900/23758 (42%)]\tLoss: 1.376336\tClass Loss: 0.003661\tDomain Loss: 1.372675\n", 559 | "[19900/23758 (84%)]\tLoss: 1.390265\tClass Loss: 0.007209\tDomain Loss: 1.383056\n", 560 | "\n", 561 | "Source Accuracy: 22854.0/23758 (96.1950%)\n", 562 | "Target Accuracy: 15857.0/23758 (66.7438%)\n", 563 | "Domain Accuracy: 25662.0/47516 (54.0071%)\n", 564 | "\n", 565 | "Epoch: 14\n", 566 | "[9900/23758 (42%)]\tLoss: 1.391270\tClass Loss: 0.001315\tDomain Loss: 1.389955\n", 567 | "[19900/23758 (84%)]\tLoss: 1.374239\tClass Loss: 0.002301\tDomain Loss: 1.371938\n", 568 | "\n", 569 | "Source Accuracy: 22773.0/23758 (95.8540%)\n", 570 | "Target Accuracy: 18244.0/23758 (76.7910%)\n", 571 | "Domain Accuracy: 27106.0/47516 (57.0460%)\n", 572 | "\n", 573 | "Epoch: 15\n", 574 | "[9900/23758 (42%)]\tLoss: 1.388542\tClass Loss: 0.001410\tDomain Loss: 1.387131\n", 575 | "[19900/23758 (84%)]\tLoss: 1.385997\tClass Loss: 0.000551\tDomain Loss: 1.385446\n", 576 | "\n", 577 | "Source Accuracy: 23134.0/23758 (97.3735%)\n", 578 | "Target Accuracy: 17127.0/23758 (72.0894%)\n", 579 | "Domain Accuracy: 26987.0/47516 (56.7956%)\n", 580 | "\n", 581 | "Epoch: 16\n", 582 | "[9900/23758 (42%)]\tLoss: 1.381949\tClass Loss: 0.001925\tDomain Loss: 1.380024\n", 583 | "[19900/23758 (84%)]\tLoss: 1.374534\tClass Loss: 0.001234\tDomain Loss: 1.373300\n", 584 | "\n", 585 | "Source Accuracy: 23559.0/23758 (99.1624%)\n", 586 | "Target Accuracy: 17963.0/23758 (75.6082%)\n", 587 | "Domain Accuracy: 26129.0/47516 (54.9899%)\n", 588 | "\n", 589 | "Epoch: 17\n", 590 | "[9900/23758 (42%)]\tLoss: 1.395145\tClass Loss: 0.001361\tDomain Loss: 1.393784\n", 591 | "[19900/23758 (84%)]\tLoss: 1.394285\tClass Loss: 0.007991\tDomain Loss: 1.386294\n", 592 | "\n", 593 | "Source Accuracy: 21263.0/23758 (89.4983%)\n", 594 | "Target Accuracy: 15727.0/23758 (66.1966%)\n", 595 | "Domain Accuracy: 22819.0/47516 (48.0238%)\n", 596 | "\n", 597 | "Epoch: 18\n", 598 | "[9900/23758 (42%)]\tLoss: 1.384957\tClass Loss: 
0.006425\tDomain Loss: 1.378532\n", 599 | "[19900/23758 (84%)]\tLoss: 1.374072\tClass Loss: 0.005166\tDomain Loss: 1.368906\n", 600 | "\n", 601 | "Source Accuracy: 23132.0/23758 (97.3651%)\n", 602 | "Target Accuracy: 17796.0/23758 (74.9053%)\n", 603 | "Domain Accuracy: 27080.0/47516 (56.9913%)\n", 604 | "\n", 605 | "Epoch: 19\n", 606 | "[9900/23758 (42%)]\tLoss: 1.399660\tClass Loss: 0.009474\tDomain Loss: 1.390186\n", 607 | "[19900/23758 (84%)]\tLoss: 1.385671\tClass Loss: 0.001255\tDomain Loss: 1.384416\n", 608 | "\n", 609 | "Source Accuracy: 23505.0/23758 (98.9351%)\n", 610 | "Target Accuracy: 17417.0/23758 (73.3100%)\n", 611 | "Domain Accuracy: 21798.0/47516 (45.8751%)\n", 612 | "\n", 613 | "Epoch: 20\n", 614 | "[9900/23758 (42%)]\tLoss: 1.377080\tClass Loss: 0.001352\tDomain Loss: 1.375727\n", 615 | "[19900/23758 (84%)]\tLoss: 1.385924\tClass Loss: 0.001943\tDomain Loss: 1.383981\n", 616 | "\n", 617 | "Source Accuracy: 23234.0/23758 (97.7944%)\n", 618 | "Target Accuracy: 17842.0/23758 (75.0989%)\n", 619 | "Domain Accuracy: 24946.0/47516 (52.5002%)\n", 620 | "\n", 621 | "Epoch: 21\n", 622 | "[9900/23758 (42%)]\tLoss: 1.393147\tClass Loss: 0.004352\tDomain Loss: 1.388795\n", 623 | "[19900/23758 (84%)]\tLoss: 1.390955\tClass Loss: 0.001478\tDomain Loss: 1.389477\n", 624 | "\n", 625 | "Source Accuracy: 23225.0/23758 (97.7565%)\n", 626 | "Target Accuracy: 17484.0/23758 (73.5921%)\n", 627 | "Domain Accuracy: 24765.0/47516 (52.1193%)\n", 628 | "\n", 629 | "Epoch: 22\n", 630 | "[9900/23758 (42%)]\tLoss: 1.376348\tClass Loss: 0.000912\tDomain Loss: 1.375436\n", 631 | "[19900/23758 (84%)]\tLoss: 1.415650\tClass Loss: 0.005795\tDomain Loss: 1.409856\n", 632 | "\n", 633 | "Source Accuracy: 23591.0/23758 (99.2971%)\n", 634 | "Target Accuracy: 16991.0/23758 (71.5170%)\n", 635 | "Domain Accuracy: 21509.0/47516 (45.2669%)\n", 636 | "\n", 637 | "Epoch: 23\n", 638 | "[9900/23758 (42%)]\tLoss: 1.384212\tClass Loss: 0.001536\tDomain Loss: 1.382676\n", 639 | "[19900/23758 
(84%)]\tLoss: 1.391564\tClass Loss: 0.004375\tDomain Loss: 1.387188\n", 640 | "\n", 641 | "Source Accuracy: 23528.0/23758 (99.0319%)\n", 642 | "Target Accuracy: 17153.0/23758 (72.1988%)\n", 643 | "Domain Accuracy: 28963.0/47516 (60.9542%)\n", 644 | "\n", 645 | "Epoch: 24\n", 646 | "[9900/23758 (42%)]\tLoss: 1.404573\tClass Loss: 0.001363\tDomain Loss: 1.403210\n", 647 | "[19900/23758 (84%)]\tLoss: 1.370133\tClass Loss: 0.001086\tDomain Loss: 1.369047\n", 648 | "\n", 649 | "Source Accuracy: 23483.0/23758 (98.8425%)\n", 650 | "Target Accuracy: 17730.0/23758 (74.6275%)\n", 651 | "Domain Accuracy: 30929.0/47516 (65.0918%)\n", 652 | "\n", 653 | "Epoch: 25\n", 654 | "[9900/23758 (42%)]\tLoss: 1.416145\tClass Loss: 0.000760\tDomain Loss: 1.415386\n", 655 | "[19900/23758 (84%)]\tLoss: 1.350325\tClass Loss: 0.001310\tDomain Loss: 1.349015\n", 656 | "\n", 657 | "Source Accuracy: 21940.0/23758 (92.3478%)\n", 658 | "Target Accuracy: 16160.0/23758 (68.0192%)\n", 659 | "Domain Accuracy: 21827.0/47516 (45.9361%)\n", 660 | "\n", 661 | "Epoch: 26\n", 662 | "[9900/23758 (42%)]\tLoss: 1.389704\tClass Loss: 0.001939\tDomain Loss: 1.387766\n", 663 | "[19900/23758 (84%)]\tLoss: 1.374779\tClass Loss: 0.001999\tDomain Loss: 1.372780\n", 664 | "\n", 665 | "Source Accuracy: 23417.0/23758 (98.5647%)\n", 666 | "Target Accuracy: 16904.0/23758 (71.1508%)\n", 667 | "Domain Accuracy: 24467.0/47516 (51.4921%)\n", 668 | "\n", 669 | "Epoch: 27\n", 670 | "[9900/23758 (42%)]\tLoss: 1.394412\tClass Loss: 0.001201\tDomain Loss: 1.393211\n", 671 | "[19900/23758 (84%)]\tLoss: 1.355685\tClass Loss: 0.001147\tDomain Loss: 1.354538\n", 672 | "\n", 673 | "Source Accuracy: 23488.0/23758 (98.8635%)\n", 674 | "Target Accuracy: 17162.0/23758 (72.2367%)\n", 675 | "Domain Accuracy: 24211.0/47516 (50.9534%)\n", 676 | "\n", 677 | "Epoch: 28\n" 678 | ] 679 | }, 680 | { 681 | "name": "stdout", 682 | "output_type": "stream", 683 | "text": [ 684 | "[9900/23758 (42%)]\tLoss: 1.423513\tClass Loss: 0.011672\tDomain Loss: 
1.411841\n", 685 | "[19900/23758 (84%)]\tLoss: 1.341217\tClass Loss: 0.001062\tDomain Loss: 1.340155\n", 686 | "\n", 687 | "Source Accuracy: 22660.0/23758 (95.3784%)\n", 688 | "Target Accuracy: 17877.0/23758 (75.2462%)\n", 689 | "Domain Accuracy: 28519.0/47516 (60.0198%)\n", 690 | "\n", 691 | "Epoch: 29\n", 692 | "[9900/23758 (42%)]\tLoss: 1.412507\tClass Loss: 0.001725\tDomain Loss: 1.410782\n", 693 | "[19900/23758 (84%)]\tLoss: 1.366369\tClass Loss: 0.000886\tDomain Loss: 1.365483\n", 694 | "\n", 695 | "Source Accuracy: 23681.0/23758 (99.6759%)\n", 696 | "Target Accuracy: 17921.0/23758 (75.4314%)\n", 697 | "Domain Accuracy: 28139.0/47516 (59.2201%)\n", 698 | "\n", 699 | "Epoch: 30\n", 700 | "[9900/23758 (42%)]\tLoss: 1.400881\tClass Loss: 0.002656\tDomain Loss: 1.398225\n", 701 | "[19900/23758 (84%)]\tLoss: 1.383933\tClass Loss: 0.000655\tDomain Loss: 1.383278\n", 702 | "\n", 703 | "Source Accuracy: 22285.0/23758 (93.8000%)\n", 704 | "Target Accuracy: 18028.0/23758 (75.8818%)\n", 705 | "Domain Accuracy: 27237.0/47516 (57.3217%)\n", 706 | "\n", 707 | "Epoch: 31\n", 708 | "[9900/23758 (42%)]\tLoss: 1.405074\tClass Loss: 0.004948\tDomain Loss: 1.400125\n", 709 | "[19900/23758 (84%)]\tLoss: 1.379068\tClass Loss: 0.000869\tDomain Loss: 1.378199\n", 710 | "\n", 711 | "Source Accuracy: 23670.0/23758 (99.6296%)\n", 712 | "Target Accuracy: 16784.0/23758 (70.6457%)\n", 713 | "Domain Accuracy: 27311.0/47516 (57.4775%)\n", 714 | "\n", 715 | "Epoch: 32\n", 716 | "[9900/23758 (42%)]\tLoss: 1.398373\tClass Loss: 0.001787\tDomain Loss: 1.396586\n", 717 | "[19900/23758 (84%)]\tLoss: 1.374377\tClass Loss: 0.001199\tDomain Loss: 1.373178\n", 718 | "\n", 719 | "Source Accuracy: 23646.0/23758 (99.5286%)\n", 720 | "Target Accuracy: 17712.0/23758 (74.5517%)\n", 721 | "Domain Accuracy: 25043.0/47516 (52.7044%)\n", 722 | "\n", 723 | "Epoch: 33\n", 724 | "[9900/23758 (42%)]\tLoss: 1.401699\tClass Loss: 0.014343\tDomain Loss: 1.387356\n", 725 | "[19900/23758 (84%)]\tLoss: 1.391921\tClass 
Loss: 0.001369\tDomain Loss: 1.390552\n", 726 | "\n", 727 | "Source Accuracy: 23583.0/23758 (99.2634%)\n", 728 | "Target Accuracy: 17449.0/23758 (73.4447%)\n", 729 | "Domain Accuracy: 19275.0/47516 (40.5653%)\n", 730 | "\n", 731 | "Epoch: 34\n", 732 | "[9900/23758 (42%)]\tLoss: 1.387967\tClass Loss: 0.000460\tDomain Loss: 1.387507\n", 733 | "[19900/23758 (84%)]\tLoss: 1.379375\tClass Loss: 0.000867\tDomain Loss: 1.378509\n", 734 | "\n", 735 | "Source Accuracy: 23684.0/23758 (99.6885%)\n", 736 | "Target Accuracy: 17109.0/23758 (72.0136%)\n", 737 | "Domain Accuracy: 26634.0/47516 (56.0527%)\n", 738 | "\n", 739 | "Epoch: 35\n", 740 | "[9900/23758 (42%)]\tLoss: 1.386999\tClass Loss: 0.000604\tDomain Loss: 1.386395\n", 741 | "[19900/23758 (84%)]\tLoss: 1.392844\tClass Loss: 0.000761\tDomain Loss: 1.392083\n", 742 | "\n", 743 | "Source Accuracy: 23688.0/23758 (99.7054%)\n", 744 | "Target Accuracy: 17309.0/23758 (72.8555%)\n", 745 | "Domain Accuracy: 23075.0/47516 (48.5626%)\n", 746 | "\n", 747 | "Epoch: 36\n", 748 | "[9900/23758 (42%)]\tLoss: 1.389037\tClass Loss: 0.000959\tDomain Loss: 1.388077\n", 749 | "[19900/23758 (84%)]\tLoss: 1.386272\tClass Loss: 0.000531\tDomain Loss: 1.385741\n", 750 | "\n", 751 | "Source Accuracy: 23630.0/23758 (99.4612%)\n", 752 | "Target Accuracy: 18431.0/23758 (77.5781%)\n", 753 | "Domain Accuracy: 23826.0/47516 (50.1431%)\n", 754 | "\n", 755 | "Epoch: 37\n", 756 | "[9900/23758 (42%)]\tLoss: 1.392360\tClass Loss: 0.000903\tDomain Loss: 1.391458\n", 757 | "[19900/23758 (84%)]\tLoss: 1.388736\tClass Loss: 0.005508\tDomain Loss: 1.383228\n", 758 | "\n", 759 | "Source Accuracy: 23673.0/23758 (99.6422%)\n", 760 | "Target Accuracy: 17266.0/23758 (72.6745%)\n", 761 | "Domain Accuracy: 26610.0/47516 (56.0022%)\n", 762 | "\n", 763 | "Epoch: 38\n", 764 | "[9900/23758 (42%)]\tLoss: 1.397081\tClass Loss: 0.016118\tDomain Loss: 1.380963\n", 765 | "[19900/23758 (84%)]\tLoss: 1.387939\tClass Loss: 0.000906\tDomain Loss: 1.387033\n", 766 | "\n", 767 | 
"Source Accuracy: 23594.0/23758 (99.3097%)\n", 768 | "Target Accuracy: 17257.0/23758 (72.6366%)\n", 769 | "Domain Accuracy: 18903.0/47516 (39.7824%)\n", 770 | "\n", 771 | "Epoch: 39\n", 772 | "[9900/23758 (42%)]\tLoss: 1.379895\tClass Loss: 0.000647\tDomain Loss: 1.379248\n", 773 | "[19900/23758 (84%)]\tLoss: 1.396873\tClass Loss: 0.011238\tDomain Loss: 1.385635\n", 774 | "\n", 775 | "Source Accuracy: 23570.0/23758 (99.2087%)\n", 776 | "Target Accuracy: 17669.0/23758 (74.3707%)\n", 777 | "Domain Accuracy: 23236.0/47516 (48.9014%)\n", 778 | "\n", 779 | "Epoch: 40\n", 780 | "[9900/23758 (42%)]\tLoss: 1.385690\tClass Loss: 0.000763\tDomain Loss: 1.384927\n", 781 | "[19900/23758 (84%)]\tLoss: 1.392005\tClass Loss: 0.001265\tDomain Loss: 1.390740\n", 782 | "\n", 783 | "Source Accuracy: 23636.0/23758 (99.4865%)\n", 784 | "Target Accuracy: 17944.0/23758 (75.5282%)\n", 785 | "Domain Accuracy: 22567.0/47516 (47.4935%)\n", 786 | "\n", 787 | "Epoch: 41\n", 788 | "[9900/23758 (42%)]\tLoss: 1.386314\tClass Loss: 0.000209\tDomain Loss: 1.386104\n", 789 | "[19900/23758 (84%)]\tLoss: 1.384973\tClass Loss: 0.001330\tDomain Loss: 1.383643\n", 790 | "\n", 791 | "Source Accuracy: 23690.0/23758 (99.7138%)\n", 792 | "Target Accuracy: 18011.0/23758 (75.8103%)\n", 793 | "Domain Accuracy: 24049.0/47516 (50.6124%)\n", 794 | "\n", 795 | "Epoch: 42\n", 796 | "[9900/23758 (42%)]\tLoss: 1.392823\tClass Loss: 0.000770\tDomain Loss: 1.392054\n", 797 | "[19900/23758 (84%)]\tLoss: 1.382176\tClass Loss: 0.000783\tDomain Loss: 1.381393\n", 798 | "\n", 799 | "Source Accuracy: 23552.0/23758 (99.1329%)\n", 800 | "Target Accuracy: 17667.0/23758 (74.3623%)\n", 801 | "Domain Accuracy: 24942.0/47516 (52.4918%)\n", 802 | "\n", 803 | "Epoch: 43\n", 804 | "[9900/23758 (42%)]\tLoss: 1.404949\tClass Loss: 0.002852\tDomain Loss: 1.402097\n", 805 | "[19900/23758 (84%)]\tLoss: 1.385293\tClass Loss: 0.001881\tDomain Loss: 1.383412\n", 806 | "\n", 807 | "Source Accuracy: 23639.0/23758 (99.4991%)\n", 808 | "Target 
Accuracy: 17278.0/23758 (72.7250%)\n", 809 | "Domain Accuracy: 28033.0/47516 (58.9970%)\n", 810 | "\n", 811 | "Epoch: 44\n", 812 | "[9900/23758 (42%)]\tLoss: 1.391430\tClass Loss: 0.001058\tDomain Loss: 1.390372\n", 813 | "[19900/23758 (84%)]\tLoss: 1.390473\tClass Loss: 0.003863\tDomain Loss: 1.386610\n", 814 | "\n", 815 | "Source Accuracy: 23656.0/23758 (99.5707%)\n", 816 | "Target Accuracy: 17903.0/23758 (75.3557%)\n", 817 | "Domain Accuracy: 23413.0/47516 (49.2739%)\n", 818 | "\n", 819 | "Epoch: 45\n", 820 | "[9900/23758 (42%)]\tLoss: 1.385846\tClass Loss: 0.000672\tDomain Loss: 1.385174\n", 821 | "[19900/23758 (84%)]\tLoss: 1.392202\tClass Loss: 0.003446\tDomain Loss: 1.388756\n", 822 | "\n", 823 | "Source Accuracy: 23418.0/23758 (98.5689%)\n", 824 | "Target Accuracy: 17416.0/23758 (73.3058%)\n", 825 | "Domain Accuracy: 22124.0/47516 (46.5612%)\n", 826 | "\n", 827 | "Epoch: 46\n", 828 | "[9900/23758 (42%)]\tLoss: 1.387516\tClass Loss: 0.000902\tDomain Loss: 1.386614\n", 829 | "[19900/23758 (84%)]\tLoss: 1.386322\tClass Loss: 0.000267\tDomain Loss: 1.386055\n", 830 | "\n", 831 | "Source Accuracy: 23692.0/23758 (99.7222%)\n", 832 | "Target Accuracy: 18133.0/23758 (76.3238%)\n", 833 | "Domain Accuracy: 26866.0/47516 (56.5410%)\n", 834 | "\n", 835 | "Epoch: 47\n", 836 | "[9900/23758 (42%)]\tLoss: 1.387192\tClass Loss: 0.001346\tDomain Loss: 1.385846\n", 837 | "[19900/23758 (84%)]\tLoss: 1.386524\tClass Loss: 0.000356\tDomain Loss: 1.386169\n", 838 | "\n", 839 | "Source Accuracy: 23696.0/23758 (99.7390%)\n", 840 | "Target Accuracy: 18438.0/23758 (77.6075%)\n", 841 | "Domain Accuracy: 23383.0/47516 (49.2108%)\n", 842 | "\n", 843 | "Epoch: 48\n", 844 | "[9900/23758 (42%)]\tLoss: 1.390566\tClass Loss: 0.001955\tDomain Loss: 1.388611\n", 845 | "[19900/23758 (84%)]\tLoss: 1.387532\tClass Loss: 0.000610\tDomain Loss: 1.386922\n", 846 | "\n", 847 | "Source Accuracy: 23689.0/23758 (99.7096%)\n", 848 | "Target Accuracy: 18116.0/23758 (76.2522%)\n", 849 | "Domain Accuracy: 
22719.0/47516 (47.8134%)\n", 850 | "\n", 851 | "Epoch: 49\n", 852 | "[9900/23758 (42%)]\tLoss: 1.388918\tClass Loss: 0.000496\tDomain Loss: 1.388422\n", 853 | "[19900/23758 (84%)]\tLoss: 1.393028\tClass Loss: 0.006346\tDomain Loss: 1.386682\n", 854 | "\n", 855 | "Source Accuracy: 23692.0/23758 (99.7222%)\n", 856 | "Target Accuracy: 18427.0/23758 (77.5612%)\n", 857 | "Domain Accuracy: 21416.0/47516 (45.0711%)\n", 858 | "\n", 859 | "Epoch: 50\n", 860 | "[9900/23758 (42%)]\tLoss: 1.383915\tClass Loss: 0.000870\tDomain Loss: 1.383045\n", 861 | "[19900/23758 (84%)]\tLoss: 1.389495\tClass Loss: 0.003497\tDomain Loss: 1.385998\n", 862 | "\n", 863 | "Source Accuracy: 23699.0/23758 (99.7517%)\n", 864 | "Target Accuracy: 18496.0/23758 (77.8517%)\n", 865 | "Domain Accuracy: 23885.0/47516 (50.2673%)\n", 866 | "\n", 867 | "Epoch: 51\n", 868 | "[9900/23758 (42%)]\tLoss: 1.385185\tClass Loss: 0.000137\tDomain Loss: 1.385048\n", 869 | "[19900/23758 (84%)]\tLoss: 1.385996\tClass Loss: 0.000896\tDomain Loss: 1.385101\n", 870 | "\n", 871 | "Source Accuracy: 23677.0/23758 (99.6591%)\n", 872 | "Target Accuracy: 18105.0/23758 (76.2059%)\n", 873 | "Domain Accuracy: 25450.0/47516 (53.5609%)\n", 874 | "\n", 875 | "Epoch: 52\n", 876 | "[9900/23758 (42%)]\tLoss: 1.390782\tClass Loss: 0.001514\tDomain Loss: 1.389268\n", 877 | "[19900/23758 (84%)]\tLoss: 1.387977\tClass Loss: 0.002324\tDomain Loss: 1.385653\n", 878 | "\n", 879 | "Source Accuracy: 23677.0/23758 (99.6591%)\n", 880 | "Target Accuracy: 18470.0/23758 (77.7422%)\n", 881 | "Domain Accuracy: 19504.0/47516 (41.0472%)\n", 882 | "\n", 883 | "Epoch: 53\n", 884 | "[9900/23758 (42%)]\tLoss: 1.383165\tClass Loss: 0.000447\tDomain Loss: 1.382718\n", 885 | "[19900/23758 (84%)]\tLoss: 1.386243\tClass Loss: 0.000968\tDomain Loss: 1.385275\n", 886 | "\n", 887 | "Source Accuracy: 23470.0/23758 (98.7878%)\n", 888 | "Target Accuracy: 18412.0/23758 (77.4981%)\n", 889 | "Domain Accuracy: 21540.0/47516 (45.3321%)\n", 890 | "\n", 891 | "Epoch: 54\n", 
892 | "[9900/23758 (42%)]\tLoss: 1.388162\tClass Loss: 0.000381\tDomain Loss: 1.387782\n", 893 | "[19900/23758 (84%)]\tLoss: 1.385729\tClass Loss: 0.000937\tDomain Loss: 1.384793\n", 894 | "\n", 895 | "Source Accuracy: 23632.0/23758 (99.4697%)\n", 896 | "Target Accuracy: 17982.0/23758 (75.6882%)\n", 897 | "Domain Accuracy: 26433.0/47516 (55.6297%)\n", 898 | "\n", 899 | "Epoch: 55\n", 900 | "[9900/23758 (42%)]\tLoss: 1.390172\tClass Loss: 0.000103\tDomain Loss: 1.390068\n", 901 | "[19900/23758 (84%)]\tLoss: 1.397405\tClass Loss: 0.000566\tDomain Loss: 1.396839\n", 902 | "\n", 903 | "Source Accuracy: 23667.0/23758 (99.6170%)\n", 904 | "Target Accuracy: 17792.0/23758 (74.8885%)\n", 905 | "Domain Accuracy: 21291.0/47516 (44.8081%)\n", 906 | "\n", 907 | "Epoch: 56\n" 908 | ] 909 | }, 910 | { 911 | "name": "stdout", 912 | "output_type": "stream", 913 | "text": [ 914 | "[9900/23758 (42%)]\tLoss: 1.423285\tClass Loss: 0.013461\tDomain Loss: 1.409823\n", 915 | "[19900/23758 (84%)]\tLoss: 1.385545\tClass Loss: 0.004150\tDomain Loss: 1.381396\n", 916 | "\n", 917 | "Source Accuracy: 23320.0/23758 (98.1564%)\n", 918 | "Target Accuracy: 16557.0/23758 (69.6902%)\n", 919 | "Domain Accuracy: 24780.0/47516 (52.1509%)\n", 920 | "\n", 921 | "Epoch: 57\n", 922 | "[9900/23758 (42%)]\tLoss: 1.397056\tClass Loss: 0.001016\tDomain Loss: 1.396040\n", 923 | "[19900/23758 (84%)]\tLoss: 1.388483\tClass Loss: 0.005153\tDomain Loss: 1.383330\n", 924 | "\n", 925 | "Source Accuracy: 23690.0/23758 (99.7138%)\n", 926 | "Target Accuracy: 17582.0/23758 (74.0045%)\n", 927 | "Domain Accuracy: 23527.0/47516 (49.5138%)\n", 928 | "\n", 929 | "Epoch: 58\n", 930 | "[9900/23758 (42%)]\tLoss: 1.379666\tClass Loss: 0.000552\tDomain Loss: 1.379114\n", 931 | "[19900/23758 (84%)]\tLoss: 1.377842\tClass Loss: 0.000836\tDomain Loss: 1.377006\n", 932 | "\n", 933 | "Source Accuracy: 23648.0/23758 (99.5370%)\n", 934 | "Target Accuracy: 17720.0/23758 (74.5854%)\n", 935 | "Domain Accuracy: 24101.0/47516 (50.7219%)\n", 
936 | "\n", 937 | "Epoch: 59\n", 938 | "[9900/23758 (42%)]\tLoss: 1.393548\tClass Loss: 0.004813\tDomain Loss: 1.388736\n", 939 | "[19900/23758 (84%)]\tLoss: 1.389772\tClass Loss: 0.001039\tDomain Loss: 1.388732\n", 940 | "\n", 941 | "Source Accuracy: 23697.0/23758 (99.7432%)\n", 942 | "Target Accuracy: 17617.0/23758 (74.1519%)\n", 943 | "Domain Accuracy: 22155.0/47516 (46.6264%)\n", 944 | "\n", 945 | "Epoch: 60\n", 946 | "[9900/23758 (42%)]\tLoss: 1.391455\tClass Loss: 0.001419\tDomain Loss: 1.390036\n", 947 | "[19900/23758 (84%)]\tLoss: 1.386772\tClass Loss: 0.001127\tDomain Loss: 1.385644\n", 948 | "\n", 949 | "Source Accuracy: 23696.0/23758 (99.7390%)\n", 950 | "Target Accuracy: 17131.0/23758 (72.1062%)\n", 951 | "Domain Accuracy: 22442.0/47516 (47.2304%)\n", 952 | "\n", 953 | "Epoch: 61\n", 954 | "[9900/23758 (42%)]\tLoss: 1.384464\tClass Loss: 0.001719\tDomain Loss: 1.382745\n", 955 | "[19900/23758 (84%)]\tLoss: 1.382799\tClass Loss: 0.000321\tDomain Loss: 1.382478\n", 956 | "\n", 957 | "Source Accuracy: 23599.0/23758 (99.3308%)\n", 958 | "Target Accuracy: 16576.0/23758 (69.7702%)\n", 959 | "Domain Accuracy: 25357.0/47516 (53.3652%)\n", 960 | "\n", 961 | "Epoch: 62\n", 962 | "[9900/23758 (42%)]\tLoss: 1.386681\tClass Loss: 0.001861\tDomain Loss: 1.384820\n", 963 | "[19900/23758 (84%)]\tLoss: 1.389572\tClass Loss: 0.000776\tDomain Loss: 1.388795\n", 964 | "\n", 965 | "Source Accuracy: 23689.0/23758 (99.7096%)\n", 966 | "Target Accuracy: 17520.0/23758 (73.7436%)\n", 967 | "Domain Accuracy: 23726.0/47516 (49.9327%)\n", 968 | "\n", 969 | "Epoch: 63\n", 970 | "[9900/23758 (42%)]\tLoss: 1.383709\tClass Loss: 0.000419\tDomain Loss: 1.383290\n", 971 | "[19900/23758 (84%)]\tLoss: 1.390180\tClass Loss: 0.002337\tDomain Loss: 1.387843\n", 972 | "\n", 973 | "Source Accuracy: 23688.0/23758 (99.7054%)\n", 974 | "Target Accuracy: 17496.0/23758 (73.6426%)\n", 975 | "Domain Accuracy: 26305.0/47516 (55.3603%)\n", 976 | "\n", 977 | "Epoch: 64\n", 978 | "[9900/23758 
(42%)]\tLoss: 1.388847\tClass Loss: 0.000854\tDomain Loss: 1.387993\n", 979 | "[19900/23758 (84%)]\tLoss: 1.386437\tClass Loss: 0.001053\tDomain Loss: 1.385384\n", 980 | "\n", 981 | "Source Accuracy: 23694.0/23758 (99.7306%)\n", 982 | "Target Accuracy: 18211.0/23758 (76.6521%)\n", 983 | "Domain Accuracy: 24766.0/47516 (52.1214%)\n", 984 | "\n", 985 | "Epoch: 65\n", 986 | "[9900/23758 (42%)]\tLoss: 1.387684\tClass Loss: 0.000930\tDomain Loss: 1.386754\n", 987 | "[19900/23758 (84%)]\tLoss: 1.392374\tClass Loss: 0.000418\tDomain Loss: 1.391956\n", 988 | "\n", 989 | "Source Accuracy: 23651.0/23758 (99.5496%)\n", 990 | "Target Accuracy: 18397.0/23758 (77.4350%)\n", 991 | "Domain Accuracy: 21347.0/47516 (44.9259%)\n", 992 | "\n", 993 | "Epoch: 66\n", 994 | "[9900/23758 (42%)]\tLoss: 1.391102\tClass Loss: 0.001445\tDomain Loss: 1.389657\n", 995 | "[19900/23758 (84%)]\tLoss: 1.388028\tClass Loss: 0.000639\tDomain Loss: 1.387389\n", 996 | "\n", 997 | "Source Accuracy: 23698.0/23758 (99.7475%)\n", 998 | "Target Accuracy: 18348.0/23758 (77.2287%)\n", 999 | "Domain Accuracy: 23836.0/47516 (50.1642%)\n", 1000 | "\n", 1001 | "Epoch: 67\n", 1002 | "[9900/23758 (42%)]\tLoss: 1.389322\tClass Loss: 0.001797\tDomain Loss: 1.387525\n", 1003 | "[19900/23758 (84%)]\tLoss: 1.390065\tClass Loss: 0.000240\tDomain Loss: 1.389825\n", 1004 | "\n", 1005 | "Source Accuracy: 23695.0/23758 (99.7348%)\n", 1006 | "Target Accuracy: 18101.0/23758 (76.1891%)\n", 1007 | "Domain Accuracy: 21122.0/47516 (44.4524%)\n", 1008 | "\n", 1009 | "Epoch: 68\n", 1010 | "[9900/23758 (42%)]\tLoss: 1.388727\tClass Loss: 0.000581\tDomain Loss: 1.388146\n", 1011 | "[19900/23758 (84%)]\tLoss: 1.385180\tClass Loss: 0.000080\tDomain Loss: 1.385100\n", 1012 | "\n", 1013 | "Source Accuracy: 23687.0/23758 (99.7012%)\n", 1014 | "Target Accuracy: 17998.0/23758 (75.7555%)\n", 1015 | "Domain Accuracy: 21351.0/47516 (44.9343%)\n", 1016 | "\n", 1017 | "Epoch: 69\n", 1018 | "[9900/23758 (42%)]\tLoss: 1.387236\tClass Loss: 
0.001813\tDomain Loss: 1.385423\n", 1019 | "[19900/23758 (84%)]\tLoss: 1.385226\tClass Loss: 0.000419\tDomain Loss: 1.384807\n", 1020 | "\n", 1021 | "Source Accuracy: 23694.0/23758 (99.7306%)\n", 1022 | "Target Accuracy: 18222.0/23758 (76.6984%)\n", 1023 | "Domain Accuracy: 24598.0/47516 (51.7678%)\n", 1024 | "\n", 1025 | "Epoch: 70\n", 1026 | "[9900/23758 (42%)]\tLoss: 1.386982\tClass Loss: 0.000377\tDomain Loss: 1.386604\n", 1027 | "[19900/23758 (84%)]\tLoss: 1.389788\tClass Loss: 0.000276\tDomain Loss: 1.389512\n", 1028 | "\n", 1029 | "Source Accuracy: 23677.0/23758 (99.6591%)\n", 1030 | "Target Accuracy: 18449.0/23758 (77.6538%)\n", 1031 | "Domain Accuracy: 22756.0/47516 (47.8912%)\n", 1032 | "\n", 1033 | "Epoch: 71\n", 1034 | "[9900/23758 (42%)]\tLoss: 1.385960\tClass Loss: 0.000379\tDomain Loss: 1.385581\n", 1035 | "[19900/23758 (84%)]\tLoss: 1.385307\tClass Loss: 0.000323\tDomain Loss: 1.384984\n", 1036 | "\n", 1037 | "Source Accuracy: 23685.0/23758 (99.6927%)\n", 1038 | "Target Accuracy: 17950.0/23758 (75.5535%)\n", 1039 | "Domain Accuracy: 23700.0/47516 (49.8779%)\n", 1040 | "\n", 1041 | "Epoch: 72\n", 1042 | "[9900/23758 (42%)]\tLoss: 1.388822\tClass Loss: 0.002081\tDomain Loss: 1.386740\n", 1043 | "[19900/23758 (84%)]\tLoss: 1.383132\tClass Loss: 0.000107\tDomain Loss: 1.383025\n", 1044 | "\n", 1045 | "Source Accuracy: 23695.0/23758 (99.7348%)\n", 1046 | "Target Accuracy: 17788.0/23758 (74.8716%)\n", 1047 | "Domain Accuracy: 26696.0/47516 (56.1832%)\n", 1048 | "\n", 1049 | "Epoch: 73\n", 1050 | "[9900/23758 (42%)]\tLoss: 1.381318\tClass Loss: 0.000290\tDomain Loss: 1.381028\n", 1051 | "[19900/23758 (84%)]\tLoss: 1.386827\tClass Loss: 0.001851\tDomain Loss: 1.384976\n", 1052 | "\n", 1053 | "Source Accuracy: 23699.0/23758 (99.7517%)\n", 1054 | "Target Accuracy: 17854.0/23758 (75.1494%)\n", 1055 | "Domain Accuracy: 24211.0/47516 (50.9534%)\n", 1056 | "\n", 1057 | "Epoch: 74\n", 1058 | "[9900/23758 (42%)]\tLoss: 1.394059\tClass Loss: 0.006660\tDomain Loss: 
1.387399\n", 1059 | "[19900/23758 (84%)]\tLoss: 1.385212\tClass Loss: 0.000113\tDomain Loss: 1.385099\n", 1060 | "\n", 1061 | "Source Accuracy: 23688.0/23758 (99.7054%)\n", 1062 | "Target Accuracy: 17941.0/23758 (75.5156%)\n", 1063 | "Domain Accuracy: 23144.0/47516 (48.7078%)\n", 1064 | "\n", 1065 | "Epoch: 75\n", 1066 | "[9900/23758 (42%)]\tLoss: 1.388199\tClass Loss: 0.000258\tDomain Loss: 1.387941\n", 1067 | "[19900/23758 (84%)]\tLoss: 1.388942\tClass Loss: 0.000262\tDomain Loss: 1.388680\n", 1068 | "\n", 1069 | "Source Accuracy: 23696.0/23758 (99.7390%)\n", 1070 | "Target Accuracy: 17787.0/23758 (74.8674%)\n", 1071 | "Domain Accuracy: 19796.0/47516 (41.6618%)\n", 1072 | "\n", 1073 | "Epoch: 76\n", 1074 | "[9900/23758 (42%)]\tLoss: 1.381775\tClass Loss: 0.000100\tDomain Loss: 1.381675\n", 1075 | "[19900/23758 (84%)]\tLoss: 1.384601\tClass Loss: 0.000455\tDomain Loss: 1.384146\n", 1076 | "\n", 1077 | "Source Accuracy: 23698.0/23758 (99.7475%)\n", 1078 | "Target Accuracy: 18020.0/23758 (75.8481%)\n", 1079 | "Domain Accuracy: 27055.0/47516 (56.9387%)\n", 1080 | "\n", 1081 | "Epoch: 77\n", 1082 | "[9900/23758 (42%)]\tLoss: 1.392953\tClass Loss: 0.000437\tDomain Loss: 1.392516\n", 1083 | "[19900/23758 (84%)]\tLoss: 1.387936\tClass Loss: 0.000265\tDomain Loss: 1.387671\n", 1084 | "\n", 1085 | "Source Accuracy: 23700.0/23758 (99.7559%)\n", 1086 | "Target Accuracy: 18073.0/23758 (76.0712%)\n", 1087 | "Domain Accuracy: 25901.0/47516 (54.5101%)\n", 1088 | "\n", 1089 | "Epoch: 78\n", 1090 | "[9900/23758 (42%)]\tLoss: 1.385173\tClass Loss: 0.000475\tDomain Loss: 1.384698\n", 1091 | "[19900/23758 (84%)]\tLoss: 1.382721\tClass Loss: 0.000118\tDomain Loss: 1.382603\n", 1092 | "\n", 1093 | "Source Accuracy: 23698.0/23758 (99.7475%)\n", 1094 | "Target Accuracy: 17794.0/23758 (74.8969%)\n", 1095 | "Domain Accuracy: 19768.0/47516 (41.6028%)\n", 1096 | "\n", 1097 | "Epoch: 79\n", 1098 | "[9900/23758 (42%)]\tLoss: 1.388034\tClass Loss: 0.000330\tDomain Loss: 1.387705\n", 1099 | 
"[19900/23758 (84%)]\tLoss: 1.388049\tClass Loss: 0.001137\tDomain Loss: 1.386912\n", 1100 | "\n", 1101 | "Source Accuracy: 23698.0/23758 (99.7475%)\n", 1102 | "Target Accuracy: 17905.0/23758 (75.3641%)\n", 1103 | "Domain Accuracy: 24916.0/47516 (52.4371%)\n", 1104 | "\n", 1105 | "Epoch: 80\n", 1106 | "[9900/23758 (42%)]\tLoss: 1.385367\tClass Loss: 0.000155\tDomain Loss: 1.385212\n", 1107 | "[19900/23758 (84%)]\tLoss: 1.387840\tClass Loss: 0.001047\tDomain Loss: 1.386793\n", 1108 | "\n", 1109 | "Source Accuracy: 23689.0/23758 (99.7096%)\n", 1110 | "Target Accuracy: 17780.0/23758 (74.8379%)\n", 1111 | "Domain Accuracy: 26826.0/47516 (56.4568%)\n", 1112 | "\n", 1113 | "Epoch: 81\n", 1114 | "[9900/23758 (42%)]\tLoss: 1.386810\tClass Loss: 0.000156\tDomain Loss: 1.386654\n", 1115 | "[19900/23758 (84%)]\tLoss: 1.386763\tClass Loss: 0.000430\tDomain Loss: 1.386333\n", 1116 | "\n", 1117 | "Source Accuracy: 23698.0/23758 (99.7475%)\n", 1118 | "Target Accuracy: 17850.0/23758 (75.1326%)\n", 1119 | "Domain Accuracy: 24431.0/47516 (51.4164%)\n", 1120 | "\n", 1121 | "Epoch: 82\n", 1122 | "[9900/23758 (42%)]\tLoss: 1.384042\tClass Loss: 0.000143\tDomain Loss: 1.383898\n", 1123 | "[19900/23758 (84%)]\tLoss: 1.390027\tClass Loss: 0.000313\tDomain Loss: 1.389714\n", 1124 | "\n", 1125 | "Source Accuracy: 23698.0/23758 (99.7475%)\n", 1126 | "Target Accuracy: 17824.0/23758 (75.0232%)\n", 1127 | "Domain Accuracy: 20699.0/47516 (43.5622%)\n", 1128 | "\n", 1129 | "Epoch: 83\n", 1130 | "[9900/23758 (42%)]\tLoss: 1.388460\tClass Loss: 0.000152\tDomain Loss: 1.388307\n", 1131 | "[19900/23758 (84%)]\tLoss: 1.384058\tClass Loss: 0.000475\tDomain Loss: 1.383582\n", 1132 | "\n", 1133 | "Source Accuracy: 23695.0/23758 (99.7348%)\n", 1134 | "Target Accuracy: 18037.0/23758 (75.9197%)\n", 1135 | "Domain Accuracy: 28904.0/47516 (60.8300%)\n", 1136 | "\n", 1137 | "Epoch: 84\n" 1138 | ] 1139 | }, 1140 | { 1141 | "name": "stdout", 1142 | "output_type": "stream", 1143 | "text": [ 1144 | "[9900/23758 
(42%)]\tLoss: 1.386890\tClass Loss: 0.000391\tDomain Loss: 1.386499\n", 1145 | "[19900/23758 (84%)]\tLoss: 1.386707\tClass Loss: 0.000136\tDomain Loss: 1.386571\n", 1146 | "\n", 1147 | "Source Accuracy: 23681.0/23758 (99.6759%)\n", 1148 | "Target Accuracy: 18251.0/23758 (76.8204%)\n", 1149 | "Domain Accuracy: 24435.0/47516 (51.4248%)\n", 1150 | "\n", 1151 | "Epoch: 85\n", 1152 | "[9900/23758 (42%)]\tLoss: 1.388055\tClass Loss: 0.000140\tDomain Loss: 1.387915\n", 1153 | "[19900/23758 (84%)]\tLoss: 1.387120\tClass Loss: 0.000065\tDomain Loss: 1.387055\n", 1154 | "\n", 1155 | "Source Accuracy: 23700.0/23758 (99.7559%)\n", 1156 | "Target Accuracy: 17882.0/23758 (75.2673%)\n", 1157 | "Domain Accuracy: 23750.0/47516 (49.9832%)\n", 1158 | "\n", 1159 | "Epoch: 86\n", 1160 | "[9900/23758 (42%)]\tLoss: 1.385075\tClass Loss: 0.000114\tDomain Loss: 1.384961\n", 1161 | "[19900/23758 (84%)]\tLoss: 1.387286\tClass Loss: 0.000068\tDomain Loss: 1.387218\n", 1162 | "\n", 1163 | "Source Accuracy: 23700.0/23758 (99.7559%)\n", 1164 | "Target Accuracy: 17782.0/23758 (74.8464%)\n", 1165 | "Domain Accuracy: 21224.0/47516 (44.6671%)\n", 1166 | "\n", 1167 | "Epoch: 87\n", 1168 | "[9900/23758 (42%)]\tLoss: 1.387931\tClass Loss: 0.000085\tDomain Loss: 1.387846\n", 1169 | "[19900/23758 (84%)]\tLoss: 1.388863\tClass Loss: 0.000272\tDomain Loss: 1.388591\n", 1170 | "\n", 1171 | "Source Accuracy: 23695.0/23758 (99.7348%)\n", 1172 | "Target Accuracy: 18178.0/23758 (76.5132%)\n", 1173 | "Domain Accuracy: 25371.0/47516 (53.3946%)\n", 1174 | "\n", 1175 | "Epoch: 88\n", 1176 | "[9900/23758 (42%)]\tLoss: 1.383432\tClass Loss: 0.000168\tDomain Loss: 1.383264\n", 1177 | "[19900/23758 (84%)]\tLoss: 1.390494\tClass Loss: 0.000377\tDomain Loss: 1.390117\n", 1178 | "\n", 1179 | "Source Accuracy: 23696.0/23758 (99.7390%)\n", 1180 | "Target Accuracy: 18351.0/23758 (77.2414%)\n", 1181 | "Domain Accuracy: 23158.0/47516 (48.7373%)\n", 1182 | "\n", 1183 | "Epoch: 89\n", 1184 | "[9900/23758 (42%)]\tLoss: 
1.387791\tClass Loss: 0.000102\tDomain Loss: 1.387690\n", 1185 | "[19900/23758 (84%)]\tLoss: 1.387672\tClass Loss: 0.000246\tDomain Loss: 1.387427\n", 1186 | "\n", 1187 | "Source Accuracy: 23700.0/23758 (99.7559%)\n", 1188 | "Target Accuracy: 18127.0/23758 (76.2985%)\n", 1189 | "Domain Accuracy: 23331.0/47516 (49.1014%)\n", 1190 | "\n", 1191 | "Epoch: 90\n", 1192 | "[9900/23758 (42%)]\tLoss: 1.382092\tClass Loss: 0.001599\tDomain Loss: 1.380493\n", 1193 | "[19900/23758 (84%)]\tLoss: 1.388073\tClass Loss: 0.000203\tDomain Loss: 1.387871\n", 1194 | "\n", 1195 | "Source Accuracy: 23696.0/23758 (99.7390%)\n", 1196 | "Target Accuracy: 18096.0/23758 (76.1680%)\n", 1197 | "Domain Accuracy: 22155.0/47516 (46.6264%)\n", 1198 | "\n", 1199 | "Epoch: 91\n", 1200 | "[9900/23758 (42%)]\tLoss: 1.389501\tClass Loss: 0.000175\tDomain Loss: 1.389326\n", 1201 | "[19900/23758 (84%)]\tLoss: 1.388921\tClass Loss: 0.000348\tDomain Loss: 1.388572\n", 1202 | "\n", 1203 | "Source Accuracy: 23700.0/23758 (99.7559%)\n", 1204 | "Target Accuracy: 18031.0/23758 (75.8944%)\n", 1205 | "Domain Accuracy: 26206.0/47516 (55.1519%)\n", 1206 | "\n", 1207 | "Epoch: 92\n", 1208 | "[9900/23758 (42%)]\tLoss: 1.387808\tClass Loss: 0.000209\tDomain Loss: 1.387600\n", 1209 | "[19900/23758 (84%)]\tLoss: 1.384339\tClass Loss: 0.000166\tDomain Loss: 1.384173\n", 1210 | "\n", 1211 | "Source Accuracy: 23699.0/23758 (99.7517%)\n", 1212 | "Target Accuracy: 18029.0/23758 (75.8860%)\n", 1213 | "Domain Accuracy: 26219.0/47516 (55.1793%)\n", 1214 | "\n", 1215 | "Epoch: 93\n", 1216 | "[9900/23758 (42%)]\tLoss: 1.389283\tClass Loss: 0.000151\tDomain Loss: 1.389132\n", 1217 | "[19900/23758 (84%)]\tLoss: 1.388346\tClass Loss: 0.000238\tDomain Loss: 1.388108\n", 1218 | "\n", 1219 | "Source Accuracy: 23699.0/23758 (99.7517%)\n", 1220 | "Target Accuracy: 18213.0/23758 (76.6605%)\n", 1221 | "Domain Accuracy: 21317.0/47516 (44.8628%)\n", 1222 | "\n", 1223 | "Epoch: 94\n", 1224 | "[9900/23758 (42%)]\tLoss: 1.384834\tClass Loss: 
0.000087\tDomain Loss: 1.384746\n", 1225 | "[19900/23758 (84%)]\tLoss: 1.385881\tClass Loss: 0.000071\tDomain Loss: 1.385810\n", 1226 | "\n", 1227 | "Source Accuracy: 23699.0/23758 (99.7517%)\n", 1228 | "Target Accuracy: 18220.0/23758 (76.6900%)\n", 1229 | "Domain Accuracy: 22597.0/47516 (47.5566%)\n", 1230 | "\n", 1231 | "Epoch: 95\n", 1232 | "[9900/23758 (42%)]\tLoss: 1.384257\tClass Loss: 0.000177\tDomain Loss: 1.384079\n", 1233 | "[19900/23758 (84%)]\tLoss: 1.389578\tClass Loss: 0.000259\tDomain Loss: 1.389319\n", 1234 | "\n", 1235 | "Source Accuracy: 23690.0/23758 (99.7138%)\n", 1236 | "Target Accuracy: 18204.0/23758 (76.6226%)\n", 1237 | "Domain Accuracy: 20446.0/47516 (43.0297%)\n", 1238 | "\n", 1239 | "Epoch: 96\n", 1240 | "[9900/23758 (42%)]\tLoss: 1.388124\tClass Loss: 0.000282\tDomain Loss: 1.387842\n", 1241 | "[19900/23758 (84%)]\tLoss: 1.381596\tClass Loss: 0.000195\tDomain Loss: 1.381401\n", 1242 | "\n", 1243 | "Source Accuracy: 23698.0/23758 (99.7475%)\n", 1244 | "Target Accuracy: 18120.0/23758 (76.2690%)\n", 1245 | "Domain Accuracy: 27461.0/47516 (57.7932%)\n", 1246 | "\n", 1247 | "Epoch: 97\n", 1248 | "[9900/23758 (42%)]\tLoss: 1.387494\tClass Loss: 0.000096\tDomain Loss: 1.387398\n", 1249 | "[19900/23758 (84%)]\tLoss: 1.389289\tClass Loss: 0.000118\tDomain Loss: 1.389171\n", 1250 | "\n", 1251 | "Source Accuracy: 23699.0/23758 (99.7517%)\n", 1252 | "Target Accuracy: 18043.0/23758 (75.9449%)\n", 1253 | "Domain Accuracy: 23975.0/47516 (50.4567%)\n", 1254 | "\n", 1255 | "Epoch: 98\n", 1256 | "[9900/23758 (42%)]\tLoss: 1.388064\tClass Loss: 0.000264\tDomain Loss: 1.387800\n", 1257 | "[19900/23758 (84%)]\tLoss: 1.388845\tClass Loss: 0.003562\tDomain Loss: 1.385282\n", 1258 | "\n", 1259 | "Source Accuracy: 23697.0/23758 (99.7432%)\n", 1260 | "Target Accuracy: 18204.0/23758 (76.6226%)\n", 1261 | "Domain Accuracy: 26050.0/47516 (54.8236%)\n", 1262 | "\n", 1263 | "Epoch: 99\n", 1264 | "[9900/23758 (42%)]\tLoss: 1.382862\tClass Loss: 0.000112\tDomain Loss: 
1.382750\n", 1265 | "[19900/23758 (84%)]\tLoss: 1.386950\tClass Loss: 0.000206\tDomain Loss: 1.386744\n", 1266 | "\n", 1267 | "Source Accuracy: 23698.0/23758 (99.7475%)\n", 1268 | "Target Accuracy: 18063.0/23758 (76.0291%)\n", 1269 | "Domain Accuracy: 20906.0/47516 (43.9978%)\n", 1270 | "\n", 1271 | "total run time: (min) 18.27643547852834\n" 1272 | ] 1273 | } 1274 | ], 1275 | "source": [ 1276 | "def main():\n", 1277 | "\n", 1278 | " # prepare the source data and target data\n", 1279 | " \n", 1280 | " src_train_dataloader = get_train_loader(data1,label1,batch_size=batch_size,shuffle=False)\n", 1281 | " src_test_dataloader = get_test_loader(data1,label1,batch_size=batch_size,shuffle=False)\n", 1282 | " tgt_train_dataloader = get_train_loader(data2,label2,batch_size=batch_size,shuffle=False)\n", 1283 | " tgt_test_dataloader = get_test_loader(data2,label2,batch_size=batch_size,shuffle=False)\n", 1284 | "\n", 1285 | " # init models\n", 1286 | " feature_extractor = Extractor()\n", 1287 | " class_classifier = Class_classifier()\n", 1288 | " domain_classifier = Domain_classifier()\n", 1289 | "\n", 1290 | " feature_extractor.cuda()\n", 1291 | " class_classifier.cuda()\n", 1292 | " domain_classifier.cuda()\n", 1293 | "\n", 1294 | " # init criterions\n", 1295 | " class_criterion = nn.NLLLoss()\n", 1296 | " domain_criterion = nn.NLLLoss()\n", 1297 | "\n", 1298 | " # init optimizer\n", 1299 | " optimizer = optim.SGD([\n", 1300 | " {'params': feature_extractor.parameters()},\n", 1301 | " {'params': class_classifier.parameters()},\n", 1302 | " {'params': domain_classifier.parameters()}\n", 1303 | " ], lr= 0.01, momentum= 0.9)\n", 1304 | "\n", 1305 | " for epoch in range(100):\n", 1306 | " \n", 1307 | " print('Epoch: {}'.format(epoch))\n", 1308 | " train('dann', feature_extractor, class_classifier, domain_classifier, class_criterion, domain_criterion,\n", 1309 | " src_train_dataloader, tgt_train_dataloader, optimizer, epoch)\n", 1310 | " test(feature_extractor, class_classifier, 
domain_classifier, src_test_dataloader, tgt_test_dataloader)\n", 1311 | "\n", 1312 | " \n", 1313 | "total_loss, d_loss, c_loss = [],[],[]\n", 1314 | "acc_list1, acc_list2, acc_list3 = [],[],[]\n", 1315 | "if __name__ == '__main__':\n", 1316 | " gamma = 10\n", 1317 | " theta = 1\n", 1318 | " batch_size = 100\n", 1319 | " time_start=time.time()\n", 1320 | " main()\n", 1321 | " time_end=time.time()\n", 1322 | " print('total run time: (min)',(time_end-time_start)/60.)" 1323 | ] 1324 | }, 1325 | { 1326 | "cell_type": "code", 1327 | "execution_count": 8, 1328 | "metadata": {}, 1329 | "outputs": [ 1330 | { 1331 | "data": { 1332 | "image/png": "\n", 1333 | "text/plain": [ 1334 | "
" 1335 | ] 1336 | }, 1337 | "metadata": { 1338 | "needs_background": "light" 1339 | }, 1340 | "output_type": "display_data" 1341 | } 1342 | ], 1343 | "source": [ 1344 | "plt.plot(range(len(total_loss)),total_loss,c='r',label='total_loss')\n", 1345 | "plt.plot(d_loss,c='b',label='domain_loss')\n", 1346 | "plt.plot(c_loss,c='y',label='clf_loss')\n", 1347 | "plt.title('target domain: session2')\n", 1348 | "plt.legend(loc='best')\n", 1349 | "plt.show()" 1350 | ] 1351 | }, 1352 | { 1353 | "cell_type": "code", 1354 | "execution_count": 9, 1355 | "metadata": {}, 1356 | "outputs": [ 1357 | { 1358 | "name": "stdout", 1359 | "output_type": "stream", 1360 | "text": [ 1361 | "max target accuracy: 77.8516710160788\n" 1362 | ] 1363 | }, 1364 | { 1365 | "data": { 1366 | "image/png": "\n", 1367 | "text/plain": [ 1368 | "
" 1369 | ] 1370 | }, 1371 | "metadata": { 1372 | "needs_background": "light" 1373 | }, 1374 | "output_type": "display_data" 1375 | } 1376 | ], 1377 | "source": [ 1378 | "print('max target accuracy: ',max(acc_list2))\n", 1379 | "plt.plot(range(len(acc_list1)),acc_list1,c='r',label='source_acc')\n", 1380 | "plt.plot(acc_list2,c='b',label='target_acc')\n", 1381 | "plt.plot(acc_list3,c='y',label='domain_acc')\n", 1382 | "plt.axhline(max(acc_list2),c='b',linestyle='--')\n", 1383 | "plt.title('target domain: session2')\n", 1384 | "plt.legend(loc='best')\n", 1385 | "plt.show()" 1386 | ] 1387 | }, 1388 | { 1389 | "cell_type": "code", 1390 | "execution_count": null, 1391 | "metadata": {}, 1392 | "outputs": [], 1393 | "source": [] 1394 | } 1395 | ], 1396 | "metadata": { 1397 | "kernelspec": { 1398 | "display_name": "pytorch", 1399 | "language": "python", 1400 | "name": "pytorch" 1401 | }, 1402 | "language_info": { 1403 | "codemirror_mode": { 1404 | "name": "ipython", 1405 | "version": 3 1406 | }, 1407 | "file_extension": ".py", 1408 | "mimetype": "text/x-python", 1409 | "name": "python", 1410 | "nbconvert_exporter": "python", 1411 | "pygments_lexer": "ipython3", 1412 | "version": "3.6.10" 1413 | } 1414 | }, 1415 | "nbformat": 4, 1416 | "nbformat_minor": 4 1417 | } 1418 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MSDA 2 | ### Introduction 3 | 4 | Multiple Source Domain Adaptation with Adversarial Learning 5 | GANs are widely applied in Domain Adaptation, which tries to align data drawn from different distributions into a common distribution. The idea of this research is quite similar to that of the paper [Multiple Source Domain Adaptation with Adversarial Learning (arXiv:1705.09684)](https://arxiv.org/abs/1705.09684), with a similar architecture. 6 | 7 | The code is adapted from the repositories https://github.com/daoyuan98/MSDA and https://github.com/pumpikano/tf-dann. DANN (single-source domain adaptation) is used as a baseline method.
8 | 9 | The dataset is not provided here because it is too large. The dataset I used is SEED, an EEG dataset for emotion classification developed by Shanghai Jiao Tong University. It is an open-source dataset; see the [SEED website](http://bcmi.sjtu.edu.cn/home/seed/). 10 | All code is provided as Jupyter notebooks with their executed outputs. 11 | ### Results 12 | 13 | 14 | ##### 1. Classification accuracy, domain accuracy, and loss (1000 epochs) 15 | ![](images/48,1000,p_acc.png)![](images/48,1000,d_acc.png)![](images/48,1000,loss.png) 16 | 17 | ##### 2. Comparison with DANN (100 epochs) 18 | ![](images/dann.png) 19 | ##### 3. Visualization of the proposed method 20 | Dimensionality reduction into a two-dimensional space using both t-SNE and PCA. 21 | ![](images/visulization.png) 22 | -------------------------------------------------------------------------------- /images/48,1000,d_acc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dingkmC/MSDA/2b997b40b37cc268c316a8ae3e98e80862c2ec4f/images/48,1000,d_acc.png -------------------------------------------------------------------------------- /images/48,1000,loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dingkmC/MSDA/2b997b40b37cc268c316a8ae3e98e80862c2ec4f/images/48,1000,loss.png -------------------------------------------------------------------------------- /images/48,1000,p_acc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dingkmC/MSDA/2b997b40b37cc268c316a8ae3e98e80862c2ec4f/images/48,1000,p_acc.png -------------------------------------------------------------------------------- /images/DC_DA_10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dingkmC/MSDA/2b997b40b37cc268c316a8ae3e98e80862c2ec4f/images/DC_DA_10.png
-------------------------------------------------------------------------------- /images/DC_original_10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dingkmC/MSDA/2b997b40b37cc268c316a8ae3e98e80862c2ec4f/images/DC_original_10.png -------------------------------------------------------------------------------- /images/dann.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dingkmC/MSDA/2b997b40b37cc268c316a8ae3e98e80862c2ec4f/images/dann.png -------------------------------------------------------------------------------- /images/init: -------------------------------------------------------------------------------- 1 | init 2 | -------------------------------------------------------------------------------- /images/visulization.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dingkmC/MSDA/2b997b40b37cc268c316a8ae3e98e80862c2ec4f/images/visulization.png --------------------------------------------------------------------------------