├── .gitignore ├── DANN ├── DANN_mnist_m.ipynb ├── experiment1 │ ├── generate_data.py │ └── main.py ├── 论文解读.md └── 论文解读_zhihu.md ├── DCN ├── DCN.PNG └── 论文解读.md ├── LICENSE ├── README.md ├── zhihu.py └── 说明.md /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | #VScode 3 | *.vscode/ 4 | #dataset 5 | */dataset 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | *.events.out.tfevents* 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | pip-wheel-metadata/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 99 | __pypackages__/ 100 | 101 | # Celery stuff 102 | celerybeat-schedule 103 | celerybeat.pid 104 | 105 | # SageMath parsed files 106 | *.sage.py 107 | 108 | # Environments 109 | .env 110 | .venv 111 | env/ 112 | venv/ 113 | ENV/ 114 | env.bak/ 115 | venv.bak/ 116 | 117 | # Spyder project settings 118 | .spyderproject 119 | .spyproject 120 | 121 | # Rope project settings 122 | .ropeproject 123 | 124 | # mkdocs documentation 125 | /site 126 | 127 | # mypy 128 | .mypy_cache/ 129 | .dmypy.json 130 | dmypy.json 131 | 132 | # Pyre type checker 133 | .pyre/ 134 | -------------------------------------------------------------------------------- /DANN/DANN_mnist_m.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# DANN 领域迁移实验\n", 8 | "\n", 9 | "从mnist数据集迁移到mnist_m数据集, mnist_m数据集由mnist数据集和BSDS500数据集部分像素块混合而成。\n", 10 | "\n", 11 | "DANN 论文链接:https://arxiv.org/abs/1505.07818\n", 12 | "\n", 13 | "论文解释:https://zhuanlan.zhihu.com/p/122571123\n", 14 | "\n", 15 | "mnist_m数据链接:https://pan.baidu.com/s/1I5QE1NxJcvlFYWC8YHa4Bg \n", 16 | "提取码:ywz4\n", 17 | "在该笔记同级目录下创建dataset文件夹,将下载的压缩包放入并解压。" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": 1, 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [ 26 | "import torch\n", 27 | "from torch import nn\n", 28 | "from torch.nn import functional as F\n", 29 | "from torch.optim import Adam\n", 30 | "from torch.utils.tensorboard import SummaryWriter\n", 31 | "from torch.utils.data import RandomSampler,Dataset,DataLoader\n", 32 | "\n", 33 | "import torchvision\n", 34 | 
"from torchvision import datasets, transforms\n", 35 | "from torchvision.utils import make_grid\n", 36 | "\n", 37 | "from PIL import Image\n", 38 | "import matplotlib.pyplot as plt\n", 39 | "from tqdm.notebook import tqdm\n", 40 | "import numpy as np\n", 41 | "import shutil\n", 42 | "import os" 43 | ] 44 | }, 45 | { 46 | "cell_type": "markdown", 47 | "metadata": {}, 48 | "source": [ 49 | "## 工具函数" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": 2, 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "def adjust_learning_rate(optimizer,epoch):\n", 59 | " lr=0.001*0.1**(epoch//10)\n", 60 | " for param_group in optimizer.param_groups:\n", 61 | " param_group['lr']=lr\n", 62 | " return lr\n", 63 | "\n", 64 | "def accuracy(output,target,topk=(1,)):\n", 65 | " maxk=max(topk)\n", 66 | " batch_size=target.size(0)\n", 67 | " _,pred=output.topk(maxk,1,True,True)\n", 68 | " pred=pred.t()\n", 69 | " correct=pred.eq(target.view(1,-1).expand_as(pred))\n", 70 | " res=[]\n", 71 | " for k in topk:\n", 72 | " correct_k=correct[:k].view(-1).float().sum(0)\n", 73 | " res.append(correct_k.mul_(100/batch_size))\n", 74 | " return res\n", 75 | "\n", 76 | "def matplotlib_imshow(img,one_channel=False):\n", 77 | " if one_channel:\n", 78 | " img=img.mean(dim=0)\n", 79 | " np_img=img.numpy()\n", 80 | " np_img=(np_img-np.min(np_img))/(np.max(np_img)-np.min(np_img))\n", 81 | " if one_channel:\n", 82 | " plt.imshow(np_img,cmap=\"Greys\")\n", 83 | " else:\n", 84 | " plt.imshow(np.transpose(np_img,(1,2,0)))\n", 85 | "\n", 86 | "class mnist_m(Dataset):\n", 87 | " def __init__(self,root,label_file):\n", 88 | " super(mnist_m,self).__init__()\n", 89 | " self.transform=transforms.Compose([\n", 90 | " transforms.Resize(image_size),\n", 91 | " transforms.ToTensor(),\n", 92 | " transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])\n", 93 | " ])\n", 94 | " with open(label_file,\"r\") as f:\n", 95 | " self.imgs=[]\n", 96 | " self.labels=[]\n", 97 | " for line in 
f.readlines():\n", 98 | " line=line.strip(\"\\n\").split(\" \")\n", 99 | " img_name,label=line[0],int(line[1])\n", 100 | " img=Image.open(root+os.sep+img_name)\n", 101 | " self.imgs.append(self.transform(img.convert(\"RGB\")))\n", 102 | " self.labels.append(label)\n", 103 | " def __len__(self):\n", 104 | " return len(self.labels)\n", 105 | " def __getitem__(self,index):\n", 106 | " return self.imgs[index],self.labels[index]\n", 107 | " def __add__(self,other):\n", 108 | " pass\n", 109 | "class AverageMeter(object):\n", 110 | " def __init__(self):\n", 111 | " self.reset()\n", 112 | " def reset(self):\n", 113 | " self.val=0\n", 114 | " self.avg=0\n", 115 | " self.sum=0\n", 116 | " self.count=0\n", 117 | " def update(self,val,n=1):\n", 118 | " self.val=val\n", 119 | " self.sum+=val*n\n", 120 | " self.count+=n\n", 121 | " self.avg=self.sum/self.count" 122 | ] 123 | }, 124 | { 125 | "cell_type": "markdown", 126 | "metadata": {}, 127 | "source": [ 128 | "## Tensorboard" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": 3, 134 | "metadata": {}, 135 | "outputs": [], 136 | "source": [ 137 | "log_dir=\"minist_experiment_1\"\n", 138 | "remove_log_dir=True\n", 139 | "if remove_log_dir and os.path.exists(log_dir):\n", 140 | " shutil.rmtree(log_dir)" 141 | ] 142 | }, 143 | { 144 | "cell_type": "markdown", 145 | "metadata": {}, 146 | "source": [ 147 | "## 读取展示数据" 148 | ] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "execution_count": 4, 153 | "metadata": {}, 154 | "outputs": [], 155 | "source": [ 156 | "image_size=28\n", 157 | "batch_size=128\n", 158 | "transform = transforms.Compose([transforms.Resize(image_size),\n", 159 | " transforms.ToTensor(),\n", 160 | " transforms.Normalize(mean=[0.5],std=[0.5])])\n", 161 | "train_ds=datasets.MNIST(root=\"mnist\",train=True,transform=transform,download=True)\n", 162 | "test_ds=datasets.MNIST(root=\"mnist\",train=False,transform=transform,download=True)\n", 163 | 
"train_dl=DataLoader(train_ds,batch_size=batch_size,shuffle=True)\n", 164 | "test_dl=DataLoader(test_ds,batch_size=batch_size,shuffle=False)\n", 165 | "root_path=os.path.join(\"dataset\",\"mnist_m\")\n", 166 | "train_m_ds=mnist_m(os.path.join(root_path,\"mnist_m_train\"),os.path.join(root_path,\"mnist_m_train_labels.txt\"))\n", 167 | "test_m_ds=mnist_m(os.path.join(root_path,\"mnist_m_test\"),os.path.join(root_path,\"mnist_m_test_labels.txt\"))\n", 168 | "train_m_dl=DataLoader(train_m_ds,batch_size=batch_size,shuffle=True)\n", 169 | "test_m_dl=DataLoader(test_m_ds,batch_size=batch_size,shuffle=False)" 170 | ] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "execution_count": 5, 175 | "metadata": {}, 176 | "outputs": [ 177 | { 178 | "data": { 179 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXAAAABOCAYAAAA5Hk1WAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAV4ElEQVR4nO2deXRU1f3AP9+AEJFFwBpAAwFBLbacCApY0caCSFDBQNVIRVQssmiBsioWEVGxHDRJlc2yKR6pIi60ngpSBUNREVwBWQXCEpay/UQCJNzfH2/uzWSDQGbeZMj3c86cmXlv5r3v3Lnv++79bleMMSiKoijRR0ykBVAURVHODlXgiqIoUYoqcEVRlChFFbiiKEqUogpcURQlSlEFriiKEqWUSYGLSCcRWSciG0VkZKiEUhRFUU6PnG0cuIhUAtYDNwPbgRXAPcaYNaETT1EURSmJsozAWwMbjTGbjTHHgblA19CIpSiKopyOymX47iVAVtD77UCbwh8SkT5AH4AqVaq0qlevXhlOqSiKUvHYtm3bPmPMLwpvL4sCLxXGmGnANIBGjRqZESNGhPuUiqIo5xQDBgzYWtz2sijwHUB80PtLA9tKRf/+/ctw6vAyadIkIDpkhOiQMxpkhOiQMxpkhOiQMxpkPBVlsYGvAJqJSGMRqQKkAu+X4XiKoijKGXDWI3BjTK6IPAJ8CFQCZhhjVodMMkVRFOWUlMkGboz5APggRLIoiqIoZ4BmYiqKokQpqsB9Jisri6FDhzJ06FAqVarkXmdlZZ3+y4pSwXj66acREdq0aUObNm04dOhQpEUqV6gCVxRFiVLCHgceLk6ePAnAsWPHCmyfPXs2R44cAWDNmjWkpaXx+OOPA/DSSy9x/vnnAzBx4kT69evno8SwY8cOrr76ag4ePAiAiJCWlubk3rt3r6/ynC1r166lQ4cOfP311wD84hdF8gsiwiuvvAJA3759OXnyJOvWrQPg8ssvj6RYUYG9jk6cOEFmZiY7dngRwb169aJyZf/VhL1GMjIyiImJYeXKlQBs27aNX//6177LUxL79u0jNzeXL774AoCuXbsSE1PyuPiBBx4AYOrUqVSqVKnM548KBX7o0CHy8vIA+Oabb1i4cKH7g6dNm1bi9xISEhgyZAjTp08HoFatWtxwww0A/O53vwuz1Pls3erF4CclJXHgwAFExMlTtWpVAP
bs2cPmzZtp1KgRQJn/3A0bNgBw4MABWrduXaZjFebzzz+nffv2IT1mWVm8eDF//vOfAdwFZNtZKR57DU2cOJH//Oc/gPffBrNjxw5Gjx7tu2zVqlUDoEuXLsyaNcv385+K7OxsXn31VcDTPydPnmTbtm2A1/dO1e/sb6lduzbjxo1z1//ZoiYURVGUKKVcj8C3b98OQGJiIgcOHCj19+wIbPr06Zx//vn07t0bgIsvvpjq1asD/kz7T5w4wdatW+nUqRNAEUdlYmIizzzzDADt2rWjWbNmbkZhZT5bFi9eDMAPP/wQshG4rVy5YcMG1q9fH5Jjhor169eTk5MTaTHYsmUL4I20/v3vfwOwYsUKAF5//XUA4uPjWbRoEQD3338/CQkJvsq4d+9e0tPTSU9PB+Do0aPuv23cuDF169Z1JoupU6fSr18/381kVapUcfKUN0aOHMmcOXPKdIwXX3yRvn37ctlll5XpOOVagdetWxeAuLi4Uyrwjh07us/Pnz/fTUuSkpLCLuOpGDZsGC+99FKJ+5csWeLs9SkpKcyfP5+vvvoqJOfOyMgA8tsmFPz0008APPfccwwcOLDc2L7XrFnDmDFj3PuWLVuycOFCLrjgAl/lWLZsGXfddRcAu3fvdkqxW7duZGVlce+997rP2n179+7l5ZdfDrtsOTk5jBs3DoDJkycXieawduUlS5aQm5tLXFyc+x2HDh3y/b+2N+NQXQ+h5Pbbby+gwBs0aMDQoUMBzzcXbAP/9NNPeeedd8ImS7lW4NbhOGvWLObNmwfAddddR/fu3d1n2rVrx3vvvQd4d+3s7Gw3sogUdqQ9Z84cguutp6Sk0L17d3chx8fH88tf/hKAESNGMG/ePM62PnthrM8glPTt29e9tnJHko0bNwLQuXNn9u/f77aPHz+eWrVq+SKDdaZv2bKFW2+91d3k7rjjDqcwmzVrRl5eHg8++CAAc+fOdd//zW9+44ucy5YtY/z48cXua968OUuXLgWgZs2a/O9///NFplNx4sQJwLs5B/PZZ5/RsGFD3/7f4khJSSnQ32JiYtzMvjAPP/ywu1asnRzgwQcfdP6usqA2cEVRlCilXI/ALddeey0tWrQAvFH28OHD+etf/wp4gf7WXgZQr149nnvuuYjICfmhguB5+UWEP/zhD4AX5rZmzRoX7paamuq87Q0aNCAmJobXXnsN8Oxs8fHxxZzh9OzcudOFgYWS4FHHzTffHPLjnyl///vfgfwZT7du3QC46aabfJPh448/BuCWW24B4O677wZgxowZBSIMMjMzC4y8rd07JSXFFzkLR3JcfvnlLhLrmWeeoWbNmm6fjZqKJDVq1ABg8ODBBcJ9+/XrR926dd1/HQliYmIKtNepWLVqFfv27SuyvWHDhiEJz4wKBQ4UuBhq167tXmdkZLjQwEiGjdk/6fnnn3f2+ri4OBo3buw6YJUqVUhMTCQxMbHE4/z8888ATJgwwdmxz5SFCxe644SKI0eO8N1337n31j8RKX7++WcmTJgAeBdU3bp1efrpp32VISMjg8GDBwNe3xs9ejS23n3h8LBBgwYVeP+Pf/wDyA+XCzeTJk3iuuuuA6BTp07ExcWV6CPYs2ePLzKVhj59+vierxEqMjMzSU9PL/ZaHDZsWEjOoSYURVGUKCVqRuDBDBo0yGU+vfPOO6xe7VWx/dWvfhUReXJzc50Xes6cOc7B8uGHH9K0aVPnkDkTfvzxx7OW5/vvv3evTzXaPxNGjRrFzp07AWjRokUBs5Wf2OSTrl0LLr86ZswYrrzySt/kmDJlCoMHD3Yj7dTUVB577DHOO+8895nc3FzASz7bsGGDc1BnZGRwzTXX+CYreCaJ0i5eYJN6yguFIzvKM0uXLmXIkCEArF69muPHjxfYb60Fofo9UanAq1Sp4uKlFy9e7C7mO+64g+uvv97ZFf0yqWzbtq1AWNFnn30G5Kdw22iaSNCmTZFlSkvFsWPHWLlypWtnO+
UHTwHFxsaGRL4z5dNPPwXgv//9r9t25513cv/99/tyfhveZosspaamAp7NO5j9+/c7e7i1kz/88MMA/PGPf/RF1tNhI7sOHz7sbi4i4mLAAW699VaaNGkSEfksp8tu9JuDBw/y5ptvAvDBBwWraS9YsKCIrBdeeCEAr776Ku3atQMocKMvC1GpwAHq1KkDeKNcmyiTlpZGWlqau5i6d+9eYnhPKBkwYIC7AFJSUs669kbwSCNU4YR2xBqMHUmfPHmSJUuWuNH+8ePH+dvf/gZ4YYgXXHCBiyOPjY11M4lIhRCuWLGCXr16ufe333474DmH/bqh2PDM3bt3A15CBng+gnnz5rkb3fLlyzl8+DDgKUUR4aGHHgKI2OzF/n87d+5k9OjRBQYdNhzS9j/rQJ85c2bUjH79YNeuXSQlJbFp06ZSf8f2086dO4dcHv1nFEVRopSoHYFbWrdu7WzggwcP5q233nIJE5s2bWLYsGEuJCkcfPXVVyxdutRNm+68886zPlbwVLEsNtJq1aq543Tp0oUrrriiwP7ly5cD3ii/cuXKbpbSpk0bZ8u/4YYbSExMdJEK8fHxLms0EhmYBw8epG3btgW2NW3aFMDXjEtbZKxevXpkZ2e7mWDhaXPDhg3d1DkrK4u4uDhatmzpm5wWO2PYvn27y0zOysqiWrVqbpSdnJzMG2+8AeRn21r7/b/+9S969OgRksp55wrGmBJnyMXZ623hq4EDB4bMJ2WJegUOUL9+fcCLde3bty8dOnQAvPjWdevWFbDfhpqcnByOHTtGgwYNAM9meCbYC8WGDP7+978HcCVwz4axY8e6GguffPJJkf3NmjUDoEePHjRt2vSU9SasjS87O9tXJ2FhJk6cWOTCsCF7fmJNNZmZmbRt29aVAG7evDk9e/bkvvvuA7ybSs+ePQFPYUYiFC4vL8+V/A32hUyaNIn27du7PnL06FG+/fZbIL8aYXZ2NuCVP23SpIn7fiRKyxZWiosWLYpYHHj9+vVZsWIFb731FuCVqijJJDZ9+nSefPLJsMpzWhOKiMSLyMciskZEVovIwMD2OiKySEQ2BJ5rn+5YiqIoSugoze00FxhijFklIjWAlSKyCLgfWGyMGS8iI4GRgP9DoiBiY2NJSkpy073c3FzeffddV9i/sCkh1OcGzshpmpuby+TJkwEYPnw4CQkJjBo1Cii7o8s6+4KdfmfDP//5T/famqb8xGaU2ogJywMPPBDRYloJCQlulFocGzZs4N133wU805jfs5e8vDzS09MZPny429ajRw8A7rvvPmJjY12CyW233eYip6pWrcqECRPcyH3mzJn89re/dUW6Ro8eXaCPX3rppWH/LYWjUF555RVXvMwW3fKTWrVqOYf0qRgyZEjYR+CnVeDGmF3ArsDr/xORtcAlQFcgKfCx2cAnREiB26iK+fPns3z5cmeWAC8N348VWex0uTRYpfT8888zadIkwFNINsW+vBKJaav1BdhMV5uyfqoqj+WBnJycAgtLJCcn+3JeG02SlpbGiBEjnP9n1qxZru1iY2PZunWrC2dcunSpq0Y4d+5crrzySrdCz6OPPsqMGTOYPXs2gAufA2jSpIkvZYWfeOIJV3bZYq+VJ554IuznP1tWrVoV9nOcURSKiCQAVwOfA3EB5Q6QDRR7KxSRPiLypYh8aR0kiqIoStkptUdCRKoDbwODjDGHg6c0xhgjIsW6ZY0x04BpAI0aNQpNcDM459HLL7/MzJkzgfwFICyVKlUiISEhrEkA1iNtiwX95S9/OeXn33jjDR599FHAW+7sT3/6E5AfT6wUxNblsKNZ67iMVCx1aYnUuo3W5DVixAiqV6/OggULAGjVqpUzJU6ZMoU5c+Zw9OhRwJvNWPOKLdJkM0xbtGhBWlqaK+EcPEv0q8/aQnaRIi8vz9UBuuqqq06bhGMX6yhLRFppKZUCF5Hz8JT368aY+YHNu0WkvjFml4jUB3ypgPPTTz+xYMECxo
4dC1DsFM5WWRs/fjytWrUKqzw2ScPePMaOHetW06lRowarV69m6tSpgJdFuGXLFuf9T01NdQq8vGOMcVXq/MrMGzp0qDMJWCJ9MZeW4MJffhKcLp+bm+t8KocOHSpQYgFw/pfevXufNlnHpoDbZz/p3r27Sx6z9cHtQKl///4ulDPU2HVlx4wZ4yLZ9u/fX6wCtzfDL774wmXnWouDLVgWjmSz0ypw8Yav04G1xpgXgna9D/QCxgee3wu5dAGOHDniSobee++9Ja7S0bFjR5566imuvfZaK3u4RCqCjbcdO3asW0S5Tp06RS7k5ORklzn6yCOP+CZfWRGRIso0XAQ7Lq1iqVq1Kk8++aTvq+ycLZs3b47IeW2Z2uzsbHJycli2bJnbZxcSufnmm0lOTnZx6tGQaWmXBVy7di3gj8y2PEPwQs8vvvhisaVk7UxnyZIlBfROt27dXG2UcDiySzMCvx7oCXwnIl8Htj2Op7jfFJHewFbgrpBLpyiKopRIaaJQMoGShrLtQytOPkePHnU1lDMzM/nhhx+K/Vznzp0ZPXo04FXeC1WRmNJy1VVX0aFDBz766CO3zZpT7Ejy4osvBrxi9KezkZdnbJW69u3D9rcD+VPP4EUpEhISIpK4c7a0bt26SH0RP7CLWS9fvpxly5a5JLe7777bTeGjMavSmhptNEykKE3NeZvU17NnT5566qmwJj+Vq0xMu6L3s88+y0cffVTiyiDVqlVzDdm/f/+IOrRq1qzJvHnzXLpsYZv2uHHjXLhWpBdBKAuhKq5VUahfv74rb7x27Vp2797tywrrwQt6R3pR71BiTUOtWrUqUC0xnFi7d0ZGBi+88EKJn2vevLkzq3Ts2NFd7/bmGU7Kv/FLURRFKZZyNQJ/++23AZwT0NKyZUvuuecewKvF0KdPn4jVoy6O6tWrO+9/aYvmRws2fGzKlCm+nfOSSy4BvLoy1jkUjaSlpQFe8tHw4cNd8lEksgejHbtISrBDMdzYLNNnn32WG2+8EYCHHnqIffv2uazkLl26kJSU5EvZ6uIoVwrcemvtsxJ5rL3brwgUyC9HYFPRoxVbvP+uu+7izTff5KKLLgIgPT293MexK/lUrlyZ2267DeCU5RMiQblS4IpyLmHt0TNnzuSKK65wfpsxY8boKFwJCWoDVxRFiVJUgStKmLFJSLm5ueTm5uroWwkZ4md4WKNGjUw0xfIqiqKUBwYMGLDSGFNkmS4dgSuKokQpqsAVRVGiFF9NKCKyFzgC7PPtpNHBRWibFEbbpCjaJkWpKG3SyBhTZAkqXxU4gIh8WZwtpyKjbVIUbZOiaJsUpaK3iZpQFEVRohRV4IqiKFFKJBT4tAics7yjbVIUbZOiaJsUpUK3ie82cEVRFCU0qAlFURQlSlEFriiKEqX4psBFpJOIrBORjSIy0q/zljdEZIuIfCciX4vIl4FtdURkkYhsCDzXjrSc4UZEZojIHhH5Pmhbse0gHhmBvvOtiLSMnOTho4Q2GSMiOwL95WsR6Ry077FAm6wTkVsiI3V4EZF4EflYRNaIyGoRGRjYXqH7isUXBS4ilYCXgWSgOXCPiDT349zllJuMMYlB8asjgcXGmGbA4sD7c51ZQKdC20pqh2SgWeDRB5jsk4x+M4uibQLwYqC/JBpjPgAIXD+pwFWB70wKXGfnGrnAEGNMc6AtMCDw2yt6XwH8G4G3BjYaYzYbY44Dc4GuPp07GugK2NVaZwN3RFAWXzDGLAX2F9pcUjt0BV41Hp8BF4pI+Bcc9JkS2qQkugJzjTHHjDE/AhvxrrNzCmPMLmPMqsDr/wPWApdQwfuKxS8FfgmQFfR+e2BbRcQAC0VkpYj0CWyLM8bsCrzOBipqvdGS2qGi959HAuaAGUHmtQrXJiKSAFwNfI72FUCdmJGgnTGmJd5Ub4CI3Bi803hxnRU+tlPbwTEZuAxIBHYBEyMrTmQQkerA28AgY8zh4H0Vua/4pcB3APFB7y8NbK
twGGN2BJ73AO/gTXt322le4HlP5CSMKCW1Q4XtP8aY3caYPGPMSeAV8s0kFaZNROQ8POX9ujFmfmCz9hX8U+ArgGYi0lhEquA5X9736dzlBhG5QERq2NdAR+B7vLboFfhYL+C9yEgYcUpqh/eB+wIRBm2BQ0HT53OaQvbbFLz+Al6bpIpIVRFpjOe0+8Jv+cKNiAgwHVhrjHkhaJf2FQBjjC8PoDOwHtgEjPLrvOXpATQBvgk8Vtt2AOriedI3AB8BdSItqw9t8QaeSeAEnp2yd0ntAAheFNMm4DvgmkjL72ObvBb4zd/iKaf6QZ8fFWiTdUBypOUPU5u0wzOPfAt8HXh0ruh9xT40lV5RFCVKUSemoihKlKIKXFEUJUpRBa4oihKlqAJXFEWJUlSBK4qiRCmqwBVFUaIUVeCKoihRyv8D3i+SBazcutcAAAAASUVORK5CYII=\n", 180 | "text/plain": [ 181 | "
" 182 | ] 183 | }, 184 | "metadata": { 185 | "needs_background": "light" 186 | }, 187 | "output_type": "display_data" 188 | } 189 | ], 190 | "source": [ 191 | "%matplotlib inline\n", 192 | "\n", 193 | "writer=SummaryWriter(log_dir)\n", 194 | "show_images=[train_ds[i][0] for i in range(8)]\n", 195 | "show_labels=[train_ds[i][1] for i in range(8)]\n", 196 | "show_img_grid=make_grid(show_images)\n", 197 | "matplotlib_imshow(show_img_grid,one_channel=True)\n", 198 | "writer.add_image(\"mnist_images\",show_img_grid)" 199 | ] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "execution_count": 6, 204 | "metadata": {}, 205 | "outputs": [ 206 | { 207 | "data": { 208 | "image/png": "\n", 209 | "text/plain": [ 210 | "
" 211 | ] 212 | }, 213 | "metadata": { 214 | "needs_background": "light" 215 | }, 216 | "output_type": "display_data" 217 | } 218 | ], 219 | "source": [ 220 | "show_images=[train_m_ds[i][0] for i in range(8)]\n", 221 | "show_labels=[train_m_ds[i][1] for i in range(8)]\n", 222 | "show_img_grid=make_grid(show_images)\n", 223 | "matplotlib_imshow(show_img_grid,one_channel=False)\n", 224 | "writer.add_image(\"mnist_m_images\",show_img_grid)" 225 | ] 226 | }, 227 | { 228 | "cell_type": "markdown", 229 | "metadata": {}, 230 | "source": [ 231 | "## 独立训练\n", 232 | "\n", 233 | "在源域上独立训练CNN模型" 234 | ] 235 | }, 236 | { 237 | "cell_type": "code", 238 | "execution_count": 7, 239 | "metadata": {}, 240 | "outputs": [], 241 | "source": [ 242 | "class CNN(nn.Module):\n", 243 | " def __init__(self,num_classes=10):\n", 244 | " super(CNN,self).__init__()\n", 245 | " self.features=nn.Sequential(\n", 246 | " nn.Conv2d(3,32,5),\n", 247 | " nn.ReLU(inplace=True),\n", 248 | " nn.MaxPool2d(2),\n", 249 | " nn.Conv2d(32,48,5),\n", 250 | " nn.ReLU(inplace=True),\n", 251 | " nn.MaxPool2d(2),\n", 252 | " )\n", 253 | " self.avgpool=nn.AdaptiveAvgPool2d((5,5))\n", 254 | " self.classifier=nn.Sequential(\n", 255 | " nn.Linear(48*5*5,100),\n", 256 | " nn.ReLU(inplace=True),\n", 257 | " nn.Linear(100,100),\n", 258 | " nn.ReLU(inplace=True),\n", 259 | " nn.Linear(100,num_classes)\n", 260 | " )\n", 261 | " def forward(self,x):\n", 262 | " x=x.expand(x.data.shape[0],3,image_size,image_size)\n", 263 | " x=self.features(x)\n", 264 | " x=self.avgpool(x)\n", 265 | " x=torch.flatten(x,1)\n", 266 | " x=self.classifier(x)\n", 267 | " return x" 268 | ] 269 | }, 270 | { 271 | "cell_type": "markdown", 272 | "metadata": {}, 273 | "source": [ 274 | "用一个5层的神经网络在mnist上使用Adam训练,准确率约为99.3%" 275 | ] 276 | }, 277 | { 278 | "cell_type": "code", 279 | "execution_count": 8, 280 | "metadata": { 281 | "scrolled": false 282 | }, 283 | "outputs": [ 284 | { 285 | "data": { 286 | "application/vnd.jupyter.widget-view+json": { 287 | 
"model_id": "bd6d0dd0ed3c40538cd30011a3925859", 288 | "version_major": 2, 289 | "version_minor": 0 290 | }, 291 | "text/plain": [ 292 | "HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))" 293 | ] 294 | }, 295 | "metadata": {}, 296 | "output_type": "display_data" 297 | }, 298 | { 299 | "name": "stdout", 300 | "output_type": "stream", 301 | "text": [ 302 | "Epoch:0[200/469],Loss:[0.122,0.404],prec[96.0938,87.5273]\n", 303 | "Epoch:0[400/469],Loss:[0.078,0.253],prec[98.4375,92.1875]\n", 304 | "\n" 305 | ] 306 | }, 307 | { 308 | "data": { 309 | "application/vnd.jupyter.widget-view+json": { 310 | "model_id": "b93026984dea4dca9c7ff5c8202b5877", 311 | "version_major": 2, 312 | "version_minor": 0 313 | }, 314 | "text/plain": [ 315 | "HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))" 316 | ] 317 | }, 318 | "metadata": {}, 319 | "output_type": "display_data" 320 | }, 321 | { 322 | "name": "stdout", 323 | "output_type": "stream", 324 | "text": [ 325 | "\n", 326 | "Epoch:0,val,Loss:[0.046],prec[98.4600]\n" 327 | ] 328 | }, 329 | { 330 | "data": { 331 | "application/vnd.jupyter.widget-view+json": { 332 | "model_id": "380860fc0a614482bec94650028bf6de", 333 | "version_major": 2, 334 | "version_minor": 0 335 | }, 336 | "text/plain": [ 337 | "HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))" 338 | ] 339 | }, 340 | "metadata": {}, 341 | "output_type": "display_data" 342 | }, 343 | { 344 | "name": "stdout", 345 | "output_type": "stream", 346 | "text": [ 347 | "Epoch:1[200/469],Loss:[0.072,0.060],prec[98.4375,98.1328]\n", 348 | "Epoch:1[400/469],Loss:[0.030,0.056],prec[99.2188,98.2031]\n", 349 | "\n" 350 | ] 351 | }, 352 | { 353 | "data": { 354 | "application/vnd.jupyter.widget-view+json": { 355 | "model_id": "a03c48666d584db8a95c7a2a98c4a5be", 356 | "version_major": 2, 357 | "version_minor": 0 358 | }, 359 | "text/plain": [ 360 | "HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))" 361 | ] 362 | }, 363 | 
"metadata": {}, 364 | "output_type": "display_data" 365 | }, 366 | { 367 | "name": "stdout", 368 | "output_type": "stream", 369 | "text": [ 370 | "\n", 371 | "Epoch:1,val,Loss:[0.045],prec[98.5700]\n" 372 | ] 373 | }, 374 | { 375 | "data": { 376 | "application/vnd.jupyter.widget-view+json": { 377 | "model_id": "bdd19d40c46d4d64a0ee17a34ea60b8b", 378 | "version_major": 2, 379 | "version_minor": 0 380 | }, 381 | "text/plain": [ 382 | "HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))" 383 | ] 384 | }, 385 | "metadata": {}, 386 | "output_type": "display_data" 387 | }, 388 | { 389 | "name": "stdout", 390 | "output_type": "stream", 391 | "text": [ 392 | "Epoch:2[200/469],Loss:[0.033,0.040],prec[99.2188,98.7461]\n", 393 | "Epoch:2[400/469],Loss:[0.006,0.039],prec[100.0000,98.7422]\n", 394 | "\n" 395 | ] 396 | }, 397 | { 398 | "data": { 399 | "application/vnd.jupyter.widget-view+json": { 400 | "model_id": "a6ec2ac21a824e80a662f240be052fd6", 401 | "version_major": 2, 402 | "version_minor": 0 403 | }, 404 | "text/plain": [ 405 | "HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))" 406 | ] 407 | }, 408 | "metadata": {}, 409 | "output_type": "display_data" 410 | }, 411 | { 412 | "name": "stdout", 413 | "output_type": "stream", 414 | "text": [ 415 | "\n", 416 | "Epoch:2,val,Loss:[0.033],prec[98.9600]\n" 417 | ] 418 | }, 419 | { 420 | "data": { 421 | "application/vnd.jupyter.widget-view+json": { 422 | "model_id": "8fe64844aed149f59f3c8a9bbfefb207", 423 | "version_major": 2, 424 | "version_minor": 0 425 | }, 426 | "text/plain": [ 427 | "HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))" 428 | ] 429 | }, 430 | "metadata": {}, 431 | "output_type": "display_data" 432 | }, 433 | { 434 | "name": "stdout", 435 | "output_type": "stream", 436 | "text": [ 437 | "Epoch:3[200/469],Loss:[0.083,0.030],prec[98.4375,99.0273]\n", 438 | "Epoch:3[400/469],Loss:[0.020,0.031],prec[99.2188,99.0312]\n", 439 | "\n" 440 | ] 441 | }, 442 | { 443 | 
"data": { 444 | "application/vnd.jupyter.widget-view+json": { 445 | "model_id": "0e9a47d31add47c0b6c62ace5d57579f", 446 | "version_major": 2, 447 | "version_minor": 0 448 | }, 449 | "text/plain": [ 450 | "HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))" 451 | ] 452 | }, 453 | "metadata": {}, 454 | "output_type": "display_data" 455 | }, 456 | { 457 | "name": "stdout", 458 | "output_type": "stream", 459 | "text": [ 460 | "\n", 461 | "Epoch:3,val,Loss:[0.035],prec[98.7400]\n" 462 | ] 463 | }, 464 | { 465 | "data": { 466 | "application/vnd.jupyter.widget-view+json": { 467 | "model_id": "c216c63d9602417da6b23999813d3aa5", 468 | "version_major": 2, 469 | "version_minor": 0 470 | }, 471 | "text/plain": [ 472 | "HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))" 473 | ] 474 | }, 475 | "metadata": {}, 476 | "output_type": "display_data" 477 | }, 478 | { 479 | "name": "stdout", 480 | "output_type": "stream", 481 | "text": [ 482 | "Epoch:4[200/469],Loss:[0.015,0.022],prec[100.0000,99.2773]\n", 483 | "Epoch:4[400/469],Loss:[0.024,0.024],prec[99.2188,99.2168]\n", 484 | "\n" 485 | ] 486 | }, 487 | { 488 | "data": { 489 | "application/vnd.jupyter.widget-view+json": { 490 | "model_id": "61dbb83eac534c888df7a599c42b27c9", 491 | "version_major": 2, 492 | "version_minor": 0 493 | }, 494 | "text/plain": [ 495 | "HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))" 496 | ] 497 | }, 498 | "metadata": {}, 499 | "output_type": "display_data" 500 | }, 501 | { 502 | "name": "stdout", 503 | "output_type": "stream", 504 | "text": [ 505 | "\n", 506 | "Epoch:4,val,Loss:[0.030],prec[99.1400]\n" 507 | ] 508 | } 509 | ], 510 | "source": [ 511 | "cnn_model=CNN()\n", 512 | "optimizer=Adam(cnn_model.parameters(),lr=0.001)\n", 513 | "Loss=nn.CrossEntropyLoss()\n", 514 | "epochs=5\n", 515 | "train_loss=AverageMeter()\n", 516 | "test_loss=AverageMeter()\n", 517 | "test_top1=AverageMeter()\n", 518 | "train_top1=AverageMeter()\n", 519 | 
"train_cnt=AverageMeter()\n", 520 | "print_freq=200\n", 521 | "cnn_model.cuda()\n", 522 | "for epoch in range(epochs):\n", 523 | " lr=adjust_learning_rate(optimizer,epoch)\n", 524 | " writer.add_scalar(\"lr\",lr,epoch)\n", 525 | " train_loss.reset()\n", 526 | " train_top1.reset()\n", 527 | " train_cnt.reset()\n", 528 | " test_top1.reset()\n", 529 | " test_loss.reset()\n", 530 | " for images,labels in tqdm(train_dl):\n", 531 | " images=images.cuda()\n", 532 | " labels=labels.cuda()\n", 533 | " optimizer.zero_grad()\n", 534 | " predict=cnn_model(images)\n", 535 | " losses=Loss(predict,labels)\n", 536 | " train_loss.update(losses.data,images.size(0))\n", 537 | " top1=accuracy(predict.data,labels,topk=(1,))[0]\n", 538 | " train_top1.update(top1,images.size(0))\n", 539 | " train_cnt.update(images.size(0),1)\n", 540 | " losses.backward()\n", 541 | " optimizer.step()\n", 542 | " if train_cnt.count%print_freq==0:\n", 543 | " print(\"Epoch:{}[{}/{}],Loss:[{:.3f},{:.3f}],prec[{:.4f},{:.4f}]\".format(epoch,train_cnt.count,len(train_dl),train_loss.val,train_loss.avg,\n", 544 | " train_top1.val,train_top1.avg))\n", 545 | " for images,labels in tqdm(test_dl):\n", 546 | " images=images.cuda()\n", 547 | " labels=labels.cuda()\n", 548 | " predict=cnn_model(images)\n", 549 | " losses=Loss(predict,labels)\n", 550 | " test_loss.update(losses.data,images.size(0))\n", 551 | " top1=accuracy(predict.data,labels,topk=(1,))[0]\n", 552 | " test_top1.update(top1,images.size(0))\n", 553 | " print(\"Epoch:{},val,Loss:[{:.3f}],prec[{:.4f}]\".format(epoch,test_loss.avg,test_top1.avg))\n", 554 | " writer.add_scalar(\"train_loss\",train_loss.avg,epoch)\n", 555 | " writer.add_scalar(\"test_loss\",test_loss.avg,epoch) \n", 556 | " writer.add_scalar(\"train_top1\",train_top1.avg,epoch)\n", 557 | " writer.add_scalar(\"test_top1\",test_top1.avg,epoch)" 558 | ] 559 | }, 560 | { 561 | "cell_type": "markdown", 562 | "metadata": {}, 563 | "source": [ 564 | "## 直接迁移\n", 565 | "\n", 566 | 
"直接用mnist数据集训练的网络识别mnist_m数据集,准确率约为58%.可以看作领域适应方法准确率的下界。" 567 | ] 568 | }, 569 | { 570 | "cell_type": "code", 571 | "execution_count": 9, 572 | "metadata": {}, 573 | "outputs": [ 574 | { 575 | "data": { 576 | "application/vnd.jupyter.widget-view+json": { 577 | "model_id": "d4cf549dbf414ee89a14fb28cfd371e0", 578 | "version_major": 2, 579 | "version_minor": 0 580 | }, 581 | "text/plain": [ 582 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 583 | ] 584 | }, 585 | "metadata": {}, 586 | "output_type": "display_data" 587 | }, 588 | { 589 | "name": "stdout", 590 | "output_type": "stream", 591 | "text": [ 592 | "\n", 593 | "Epoch:4,val,Loss:[1.303],prec[57.8491]\n" 594 | ] 595 | } 596 | ], 597 | "source": [ 598 | "test_m_top1=AverageMeter()\n", 599 | "test_m_loss=AverageMeter()\n", 600 | "for images,labels in tqdm(test_m_dl):\n", 601 | " images=images.cuda()\n", 602 | " labels=labels.cuda()\n", 603 | " predict=cnn_model(images)\n", 604 | " losses=Loss(predict,labels)\n", 605 | " test_m_loss.update(losses.data,images.size(0))\n", 606 | " top1=accuracy(predict.data,labels,topk=(1,))[0]\n", 607 | " test_m_top1.update(top1,images.size(0))\n", 608 | "print(\"Epoch:{},val,Loss:[{:.3f}],prec[{:.4f}]\".format(epoch,test_m_loss.avg,test_m_top1.avg))" 609 | ] 610 | }, 611 | { 612 | "cell_type": "markdown", 613 | "metadata": {}, 614 | "source": [ 615 | "## 直接训练\n", 616 | "\n", 617 | "直接使用mnist_m训练,准确率约为96%,可以看坐领域适应方法准确率的上界。" 618 | ] 619 | }, 620 | { 621 | "cell_type": "code", 622 | "execution_count": 10, 623 | "metadata": {}, 624 | "outputs": [ 625 | { 626 | "data": { 627 | "application/vnd.jupyter.widget-view+json": { 628 | "model_id": "7d7af153dbf542418882c328fc8069d8", 629 | "version_major": 2, 630 | "version_minor": 0 631 | }, 632 | "text/plain": [ 633 | "HBox(children=(FloatProgress(value=0.0, max=461.0), HTML(value='')))" 634 | ] 635 | }, 636 | "metadata": {}, 637 | "output_type": "display_data" 638 | }, 639 | { 640 | "name": "stdout", 641 | 
"output_type": "stream", 642 | "text": [ 643 | "Epoch:0[100/469],Loss:[0.259,0.659],prec[90.6250,79.5469]\n", 644 | "Epoch:0[200/469],Loss:[0.235,0.488],prec[91.4062,84.7852]\n", 645 | "Epoch:0[300/469],Loss:[0.142,0.405],prec[95.3125,87.3620]\n", 646 | "Epoch:0[400/469],Loss:[0.209,0.359],prec[94.5312,88.8027]\n", 647 | "\n" 648 | ] 649 | }, 650 | { 651 | "data": { 652 | "application/vnd.jupyter.widget-view+json": { 653 | "model_id": "133018a21de24b27abf290e9769c5a28", 654 | "version_major": 2, 655 | "version_minor": 0 656 | }, 657 | "text/plain": [ 658 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 659 | ] 660 | }, 661 | "metadata": {}, 662 | "output_type": "display_data" 663 | }, 664 | { 665 | "name": "stdout", 666 | "output_type": "stream", 667 | "text": [ 668 | "\n", 669 | "Epoch:0,val,Loss:[0.172],prec[94.7339]\n" 670 | ] 671 | }, 672 | { 673 | "data": { 674 | "application/vnd.jupyter.widget-view+json": { 675 | "model_id": "a68e98886b5f49148ffd18e5de966e5c", 676 | "version_major": 2, 677 | "version_minor": 0 678 | }, 679 | "text/plain": [ 680 | "HBox(children=(FloatProgress(value=0.0, max=461.0), HTML(value='')))" 681 | ] 682 | }, 683 | "metadata": {}, 684 | "output_type": "display_data" 685 | }, 686 | { 687 | "name": "stdout", 688 | "output_type": "stream", 689 | "text": [ 690 | "Epoch:1[100/469],Loss:[0.107,0.163],prec[96.0938,95.0156]\n", 691 | "Epoch:1[200/469],Loss:[0.126,0.157],prec[96.0938,95.1055]\n", 692 | "Epoch:1[300/469],Loss:[0.120,0.153],prec[96.8750,95.1927]\n", 693 | "Epoch:1[400/469],Loss:[0.096,0.149],prec[96.8750,95.2930]\n", 694 | "\n" 695 | ] 696 | }, 697 | { 698 | "data": { 699 | "application/vnd.jupyter.widget-view+json": { 700 | "model_id": "55c2ddea1e8e4190a45d452dd79f4341", 701 | "version_major": 2, 702 | "version_minor": 0 703 | }, 704 | "text/plain": [ 705 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 706 | ] 707 | }, 708 | "metadata": {}, 709 | "output_type": "display_data" 710 | 
}, 711 | { 712 | "name": "stdout", 713 | "output_type": "stream", 714 | "text": [ 715 | "\n", 716 | "Epoch:1,val,Loss:[0.145],prec[95.4561]\n" 717 | ] 718 | }, 719 | { 720 | "data": { 721 | "application/vnd.jupyter.widget-view+json": { 722 | "model_id": "fbf787029fb44ff98ec5dea5ec565664", 723 | "version_major": 2, 724 | "version_minor": 0 725 | }, 726 | "text/plain": [ 727 | "HBox(children=(FloatProgress(value=0.0, max=461.0), HTML(value='')))" 728 | ] 729 | }, 730 | "metadata": {}, 731 | "output_type": "display_data" 732 | }, 733 | { 734 | "name": "stdout", 735 | "output_type": "stream", 736 | "text": [ 737 | "Epoch:2[100/469],Loss:[0.052,0.108],prec[99.2188,96.4062]\n", 738 | "Epoch:2[200/469],Loss:[0.108,0.112],prec[96.8750,96.3359]\n", 739 | "Epoch:2[300/469],Loss:[0.152,0.111],prec[93.7500,96.4141]\n", 740 | "Epoch:2[400/469],Loss:[0.191,0.112],prec[92.9688,96.3887]\n", 741 | "\n" 742 | ] 743 | }, 744 | { 745 | "data": { 746 | "application/vnd.jupyter.widget-view+json": { 747 | "model_id": "b363c236417742eb9b40b378ff5f0d09", 748 | "version_major": 2, 749 | "version_minor": 0 750 | }, 751 | "text/plain": [ 752 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 753 | ] 754 | }, 755 | "metadata": {}, 756 | "output_type": "display_data" 757 | }, 758 | { 759 | "name": "stdout", 760 | "output_type": "stream", 761 | "text": [ 762 | "\n", 763 | "Epoch:2,val,Loss:[0.125],prec[95.9671]\n" 764 | ] 765 | }, 766 | { 767 | "data": { 768 | "application/vnd.jupyter.widget-view+json": { 769 | "model_id": "7b1405be408847a7b8753faf3067bf24", 770 | "version_major": 2, 771 | "version_minor": 0 772 | }, 773 | "text/plain": [ 774 | "HBox(children=(FloatProgress(value=0.0, max=461.0), HTML(value='')))" 775 | ] 776 | }, 777 | "metadata": {}, 778 | "output_type": "display_data" 779 | }, 780 | { 781 | "name": "stdout", 782 | "output_type": "stream", 783 | "text": [ 784 | "Epoch:3[100/469],Loss:[0.076,0.080],prec[97.6562,97.4062]\n", 785 | 
"Epoch:3[200/469],Loss:[0.157,0.084],prec[95.3125,97.3438]\n", 786 | "Epoch:3[300/469],Loss:[0.051,0.086],prec[100.0000,97.2448]\n", 787 | "Epoch:3[400/469],Loss:[0.116,0.087],prec[95.3125,97.1992]\n", 788 | "\n" 789 | ] 790 | }, 791 | { 792 | "data": { 793 | "application/vnd.jupyter.widget-view+json": { 794 | "model_id": "bb986906f2db4266a7844bf4f93345a3", 795 | "version_major": 2, 796 | "version_minor": 0 797 | }, 798 | "text/plain": [ 799 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 800 | ] 801 | }, 802 | "metadata": {}, 803 | "output_type": "display_data" 804 | }, 805 | { 806 | "name": "stdout", 807 | "output_type": "stream", 808 | "text": [ 809 | "\n", 810 | "Epoch:3,val,Loss:[0.114],prec[96.5337]\n" 811 | ] 812 | }, 813 | { 814 | "data": { 815 | "application/vnd.jupyter.widget-view+json": { 816 | "model_id": "b0eb87f0bb764515aa7d02847df79f76", 817 | "version_major": 2, 818 | "version_minor": 0 819 | }, 820 | "text/plain": [ 821 | "HBox(children=(FloatProgress(value=0.0, max=461.0), HTML(value='')))" 822 | ] 823 | }, 824 | "metadata": {}, 825 | "output_type": "display_data" 826 | }, 827 | { 828 | "name": "stdout", 829 | "output_type": "stream", 830 | "text": [ 831 | "Epoch:4[100/469],Loss:[0.073,0.069],prec[97.6562,97.6797]\n", 832 | "Epoch:4[200/469],Loss:[0.050,0.071],prec[98.4375,97.6055]\n", 833 | "Epoch:4[300/469],Loss:[0.089,0.073],prec[97.6562,97.5911]\n", 834 | "Epoch:4[400/469],Loss:[0.083,0.073],prec[96.8750,97.5703]\n", 835 | "\n" 836 | ] 837 | }, 838 | { 839 | "data": { 840 | "application/vnd.jupyter.widget-view+json": { 841 | "model_id": "3a4a7c73ba864aaf9ca9d5a0be195c2f", 842 | "version_major": 2, 843 | "version_minor": 0 844 | }, 845 | "text/plain": [ 846 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 847 | ] 848 | }, 849 | "metadata": {}, 850 | "output_type": "display_data" 851 | }, 852 | { 853 | "name": "stdout", 854 | "output_type": "stream", 855 | "text": [ 856 | "\n", 857 | 
"Epoch:4,val,Loss:[0.117],prec[96.3337]\n" 858 | ] 859 | } 860 | ], 861 | "source": [ 862 | "train_loss=AverageMeter()\n", 863 | "test_loss=AverageMeter()\n", 864 | "test_top1=AverageMeter()\n", 865 | "train_top1=AverageMeter()\n", 866 | "train_cnt=AverageMeter()\n", 867 | "print_freq=100\n", 868 | "cnn_model.cuda()\n", 869 | "epochs=5\n", 870 | "for epoch in range(epochs):\n", 871 | " lr=adjust_learning_rate(optimizer,epoch)\n", 872 | " writer.add_scalar(\"lr\",lr,epoch)\n", 873 | " train_loss.reset()\n", 874 | " train_top1.reset()\n", 875 | " train_cnt.reset()\n", 876 | " test_top1.reset()\n", 877 | " test_loss.reset()\n", 878 | " for images,labels in tqdm(train_m_dl):\n", 879 | " images=images.cuda()\n", 880 | " labels=labels.cuda()\n", 881 | " optimizer.zero_grad()\n", 882 | " predict=cnn_model(images)\n", 883 | " losses=Loss(predict,labels)\n", 884 | " train_loss.update(losses.data,images.size(0))\n", 885 | " top1=accuracy(predict.data,labels,topk=(1,))[0]\n", 886 | " train_top1.update(top1,images.size(0))\n", 887 | " train_cnt.update(images.size(0),1)\n", 888 | " losses.backward()\n", 889 | " optimizer.step()\n", 890 | " if train_cnt.count%print_freq==0:\n", 891 | " print(\"Epoch:{}[{}/{}],Loss:[{:.3f},{:.3f}],prec[{:.4f},{:.4f}]\".format(epoch,train_cnt.count,len(train_dl),train_loss.val,train_loss.avg,\n", 892 | " train_top1.val,train_top1.avg))\n", 893 | " for images,labels in tqdm(test_m_dl):\n", 894 | " images=images.cuda()\n", 895 | " labels=labels.cuda()\n", 896 | " predict=cnn_model(images)\n", 897 | " losses=Loss(predict,labels)\n", 898 | " test_loss.update(losses.data,images.size(0))\n", 899 | " top1=accuracy(predict.data,labels,topk=(1,))[0]\n", 900 | " test_top1.update(top1,images.size(0))\n", 901 | " print(\"Epoch:{},val,Loss:[{:.3f}],prec[{:.4f}]\".format(epoch,test_loss.avg,test_top1.avg))\n", 902 | " writer.add_scalar(\"train_loss\",train_loss.avg,epoch)\n", 903 | " writer.add_scalar(\"test_loss\",test_loss.avg,epoch) \n", 904 | " 
writer.add_scalar(\"train_top1\",train_top1.avg,epoch)\n", 905 | " writer.add_scalar(\"test_top1\",test_top1.avg,epoch)" 906 | ] 907 | }, 908 | { 909 | "cell_type": "markdown", 910 | "metadata": {}, 911 | "source": [ 912 | "## GRL\n", 913 | "\n", 914 | "梯度反转层,这一层正向表现为恒等变换,反向传播是改变梯度的符号,alpha用来平衡域损失的权重。" 915 | ] 916 | }, 917 | { 918 | "cell_type": "code", 919 | "execution_count": 11, 920 | "metadata": {}, 921 | "outputs": [], 922 | "source": [ 923 | "from torch.autograd import Function\n", 924 | "\n", 925 | "class GRL(Function):\n", 926 | " @staticmethod\n", 927 | " def forward(ctx, x, alpha):\n", 928 | " ctx.alpha = alpha\n", 929 | " return x.view_as(x)\n", 930 | "\n", 931 | " @staticmethod\n", 932 | " def backward(ctx, grad_output):\n", 933 | " output = grad_output.neg() * ctx.alpha\n", 934 | " return output, None" 935 | ] 936 | }, 937 | { 938 | "cell_type": "markdown", 939 | "metadata": {}, 940 | "source": [ 941 | "## DANN" 942 | ] 943 | }, 944 | { 945 | "cell_type": "code", 946 | "execution_count": 14, 947 | "metadata": {}, 948 | "outputs": [], 949 | "source": [ 950 | "class DANN(nn.Module):\n", 951 | " def __init__(self,num_classes=10):\n", 952 | " super(DANN,self).__init__()\n", 953 | " self.features=nn.Sequential(\n", 954 | " nn.Conv2d(3,32,5),\n", 955 | " nn.ReLU(inplace=True),\n", 956 | " nn.MaxPool2d(2),\n", 957 | " nn.Conv2d(32,48,5),\n", 958 | " nn.ReLU(inplace=True),\n", 959 | " nn.MaxPool2d(2),\n", 960 | " )\n", 961 | " self.avgpool=nn.AdaptiveAvgPool2d((5,5))\n", 962 | " self.task_classifier=nn.Sequential(\n", 963 | " nn.Linear(48*5*5,100),\n", 964 | " nn.ReLU(inplace=True),\n", 965 | " nn.Linear(100,100),\n", 966 | " nn.ReLU(inplace=True),\n", 967 | " nn.Linear(100,num_classes)\n", 968 | " )\n", 969 | " self.domain_classifier=nn.Sequential(\n", 970 | " nn.Linear(48*5*5,100),\n", 971 | " nn.ReLU(inplace=True),\n", 972 | " nn.Linear(100,2)\n", 973 | " )\n", 974 | " self.GRL=GRL()\n", 975 | " def forward(self,x,alpha):\n", 976 | " x = 
x.expand(x.data.shape[0], 3, image_size,image_size)\n", 977 | " x=self.features(x)\n", 978 | " x=self.avgpool(x)\n", 979 | " x=torch.flatten(x,1)\n", 980 | " task_predict=self.task_classifier(x)\n", 981 | " x=GRL.apply(x,alpha)\n", 982 | " domain_predict=self.domain_classifier(x)\n", 983 | " return task_predict,domain_predict" 984 | ] 985 | }, 986 | { 987 | "cell_type": "markdown", 988 | "metadata": {}, 989 | "source": [ 990 | "## 领域迁移训练\n", 991 | "\n", 992 | "使用DANN进行领域迁移训练,使用mnist上的有标签数据和mnist_m上的无标签数据,准确率约为83%." 993 | ] 994 | }, 995 | { 996 | "cell_type": "code", 997 | "execution_count": 16, 998 | "metadata": { 999 | "scrolled": false 1000 | }, 1001 | "outputs": [ 1002 | { 1003 | "name": "stdout", 1004 | "output_type": "stream", 1005 | "text": [ 1006 | "Epoch:0[200/469],Loss:[0.977,1.207],domain loss:[0.834,0.554],label loss:[0.045,0.404],prec[98.4375,87.3945],alpha:0.02131873182952404\n", 1007 | "Epoch:0[400/469],Loss:[0.588,1.019],domain loss:[0.354,0.536],label loss:[0.063,0.250],prec[99.2188,92.2266],alpha:0.04261809214949608\n" 1008 | ] 1009 | }, 1010 | { 1011 | "data": { 1012 | "application/vnd.jupyter.widget-view+json": { 1013 | "model_id": "0b3f23f4b5bb4c249dbdcbb3e9595258", 1014 | "version_major": 2, 1015 | "version_minor": 0 1016 | }, 1017 | "text/plain": [ 1018 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 1019 | ] 1020 | }, 1021 | "metadata": {}, 1022 | "output_type": "display_data" 1023 | }, 1024 | { 1025 | "name": "stdout", 1026 | "output_type": "stream", 1027 | "text": [ 1028 | "\n", 1029 | "Epoch:0,val,Loss:[1.235],prec[61.9376],domain_acc[77.7247]\n", 1030 | "Epoch:1[200/469],Loss:[0.590,0.574],domain loss:[0.387,0.353],label loss:[0.070,0.059],prec[97.6562,98.1211],alpha:0.07120127230882645\n", 1031 | "Epoch:1[400/469],Loss:[0.680,0.744],domain loss:[0.427,0.443],label loss:[0.134,0.070],prec[96.8750,97.8359],alpha:0.09237977862358093\n" 1032 | ] 1033 | }, 1034 | { 1035 | "data": { 1036 | 
"application/vnd.jupyter.widget-view+json": { 1037 | "model_id": "6c8d733a371445c6bd653912a945286a", 1038 | "version_major": 2, 1039 | "version_minor": 0 1040 | }, 1041 | "text/plain": [ 1042 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 1043 | ] 1044 | }, 1045 | "metadata": {}, 1046 | "output_type": "display_data" 1047 | }, 1048 | { 1049 | "name": "stdout", 1050 | "output_type": "stream", 1051 | "text": [ 1052 | "\n", 1053 | "Epoch:1,val,Loss:[1.270],prec[59.6600],domain_acc[79.5578]\n", 1054 | "Epoch:2[200/469],Loss:[1.079,0.687],domain loss:[0.667,0.402],label loss:[0.050,0.062],prec[98.4375,98.2266],alpha:0.1207301989197731\n", 1055 | "Epoch:2[400/469],Loss:[0.588,0.640],domain loss:[0.327,0.374],label loss:[0.033,0.061],prec[97.6562,98.1700],alpha:0.14168426394462585\n" 1056 | ] 1057 | }, 1058 | { 1059 | "data": { 1060 | "application/vnd.jupyter.widget-view+json": { 1061 | "model_id": "99df9a1ee493413da93bbb581e56cab4", 1062 | "version_major": 2, 1063 | "version_minor": 0 1064 | }, 1065 | "text/plain": [ 1066 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 1067 | ] 1068 | }, 1069 | "metadata": {}, 1070 | "output_type": "display_data" 1071 | }, 1072 | { 1073 | "name": "stdout", 1074 | "output_type": "stream", 1075 | "text": [ 1076 | "\n", 1077 | "Epoch:2,val,Loss:[2.051],prec[53.3163],domain_acc[81.1169]\n", 1078 | "Epoch:3[200/469],Loss:[0.423,0.677],domain loss:[0.228,0.364],label loss:[0.047,0.083],prec[98.4375,97.3984],alpha:0.16966524720191956\n", 1079 | "Epoch:3[400/469],Loss:[0.420,0.608],domain loss:[0.196,0.332],label loss:[0.021,0.068],prec[99.2188,97.9044],alpha:0.19029566645622253\n" 1080 | ] 1081 | }, 1082 | { 1083 | "data": { 1084 | "application/vnd.jupyter.widget-view+json": { 1085 | "model_id": "02ff509001c84b7998096b721d36f571", 1086 | "version_major": 2, 1087 | "version_minor": 0 1088 | }, 1089 | "text/plain": [ 1090 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 1091 | 
] 1092 | }, 1093 | "metadata": {}, 1094 | "output_type": "display_data" 1095 | }, 1096 | { 1097 | "name": "stdout", 1098 | "output_type": "stream", 1099 | "text": [ 1100 | "\n", 1101 | "Epoch:3,val,Loss:[1.395],prec[58.5602],domain_acc[81.7604]\n", 1102 | "Epoch:4[200/469],Loss:[0.604,0.528],domain loss:[0.254,0.301],label loss:[0.053,0.046],prec[97.6562,98.5703],alpha:0.21777768433094025\n", 1103 | "Epoch:4[400/469],Loss:[0.607,0.526],domain loss:[0.304,0.297],label loss:[0.155,0.048],prec[98.4375,98.5098],alpha:0.23799148201942444\n" 1104 | ] 1105 | }, 1106 | { 1107 | "data": { 1108 | "application/vnd.jupyter.widget-view+json": { 1109 | "model_id": "679149ae337647249497c4f09cbdde0f", 1110 | "version_major": 2, 1111 | "version_minor": 0 1112 | }, 1113 | "text/plain": [ 1114 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 1115 | ] 1116 | }, 1117 | "metadata": {}, 1118 | "output_type": "display_data" 1119 | }, 1120 | { 1121 | "name": "stdout", 1122 | "output_type": "stream", 1123 | "text": [ 1124 | "\n", 1125 | "Epoch:4,val,Loss:[1.358],prec[59.6711],domain_acc[83.5107]\n", 1126 | "Epoch:5[200/469],Loss:[0.468,0.484],domain loss:[0.149,0.268],label loss:[0.016,0.044],prec[99.2188,98.6211],alpha:0.2648544907569885\n", 1127 | "Epoch:5[400/469],Loss:[0.566,0.581],domain loss:[0.423,0.316],label loss:[0.009,0.056],prec[100.0000,98.2715],alpha:0.2845664620399475\n" 1128 | ] 1129 | }, 1130 | { 1131 | "data": { 1132 | "application/vnd.jupyter.widget-view+json": { 1133 | "model_id": "48fa42e13ba643f7aa428f9a22bc286c", 1134 | "version_major": 2, 1135 | "version_minor": 0 1136 | }, 1137 | "text/plain": [ 1138 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 1139 | ] 1140 | }, 1141 | "metadata": {}, 1142 | "output_type": "display_data" 1143 | }, 1144 | { 1145 | "name": "stdout", 1146 | "output_type": "stream", 1147 | "text": [ 1148 | "\n", 1149 | "Epoch:5,val,Loss:[1.644],prec[58.7268],domain_acc[85.1683]\n", 1150 | 
"Epoch:6[200/469],Loss:[0.435,0.690],domain loss:[0.172,0.349],label loss:[0.049,0.093],prec[99.2188,97.3672],alpha:0.31070175766944885\n", 1151 | "Epoch:6[400/469],Loss:[0.636,0.584],domain loss:[0.350,0.305],label loss:[0.057,0.072],prec[99.2188,97.9668],alpha:0.3298357427120209\n" 1152 | ] 1153 | }, 1154 | { 1155 | "data": { 1156 | "application/vnd.jupyter.widget-view+json": { 1157 | "model_id": "952dbd0109994ccebce5bf781fed7742", 1158 | "version_major": 2, 1159 | "version_minor": 0 1160 | }, 1161 | "text/plain": [ 1162 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 1163 | ] 1164 | }, 1165 | "metadata": {}, 1166 | "output_type": "display_data" 1167 | }, 1168 | { 1169 | "name": "stdout", 1170 | "output_type": "stream", 1171 | "text": [ 1172 | "\n", 1173 | "Epoch:6,val,Loss:[1.433],prec[59.3823],domain_acc[86.1031]\n", 1174 | "Epoch:7[200/469],Loss:[0.778,0.570],domain loss:[0.297,0.295],label loss:[0.076,0.045],prec[97.6562,98.6719],alpha:0.35514748096466064\n", 1175 | "Epoch:7[400/469],Loss:[0.679,0.569],domain loss:[0.349,0.299],label loss:[0.050,0.053],prec[98.4375,98.4337],alpha:0.37363728880882263\n" 1176 | ] 1177 | }, 1178 | { 1179 | "data": { 1180 | "application/vnd.jupyter.widget-view+json": { 1181 | "model_id": "0463245fdfb64a498553dc0124279758", 1182 | "version_major": 2, 1183 | "version_minor": 0 1184 | }, 1185 | "text/plain": [ 1186 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 1187 | ] 1188 | }, 1189 | "metadata": {}, 1190 | "output_type": "display_data" 1191 | }, 1192 | { 1193 | "name": "stdout", 1194 | "output_type": "stream", 1195 | "text": [ 1196 | "\n", 1197 | "Epoch:7,val,Loss:[2.182],prec[55.7605],domain_acc[84.8308]\n", 1198 | "Epoch:8[200/469],Loss:[0.832,0.692],domain loss:[0.316,0.347],label loss:[0.125,0.074],prec[95.3125,97.7735],alpha:0.39804354310035706\n", 1199 | "Epoch:8[400/469],Loss:[0.788,0.682],domain loss:[0.350,0.330],label 
loss:[0.109,0.068],prec[97.6562,98.0137],alpha:0.41583359241485596\n" 1200 | ] 1201 | }, 1202 | { 1203 | "data": { 1204 | "application/vnd.jupyter.widget-view+json": { 1205 | "model_id": "853191f7b5574952ab5c39f10007933f", 1206 | "version_major": 2, 1207 | "version_minor": 0 1208 | }, 1209 | "text/plain": [ 1210 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 1211 | ] 1212 | }, 1213 | "metadata": {}, 1214 | "output_type": "display_data" 1215 | }, 1216 | { 1217 | "name": "stdout", 1218 | "output_type": "stream", 1219 | "text": [ 1220 | "\n", 1221 | "Epoch:8,val,Loss:[2.493],prec[56.3937],domain_acc[84.7523]\n", 1222 | "Epoch:9[200/469],Loss:[0.866,0.985],domain loss:[0.449,0.411],label loss:[0.182,0.110],prec[95.3125,97.0898],alpha:0.4392668306827545\n", 1223 | "Epoch:9[400/469],Loss:[1.133,1.028],domain loss:[0.527,0.439],label loss:[0.245,0.114],prec[95.3125,96.8827],alpha:0.456312358379364\n" 1224 | ] 1225 | }, 1226 | { 1227 | "data": { 1228 | "application/vnd.jupyter.widget-view+json": { 1229 | "model_id": "596f9abff1c74697b0890a984de9e463", 1230 | "version_major": 2, 1231 | "version_minor": 0 1232 | }, 1233 | "text/plain": [ 1234 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 1235 | ] 1236 | }, 1237 | "metadata": {}, 1238 | "output_type": "display_data" 1239 | }, 1240 | { 1241 | "name": "stdout", 1242 | "output_type": "stream", 1243 | "text": [ 1244 | "\n", 1245 | "Epoch:9,val,Loss:[2.505],prec[59.6600],domain_acc[84.4939]\n", 1246 | "Epoch:10[200/469],Loss:[1.085,0.837],domain loss:[0.300,0.374],label loss:[0.079,0.093],prec[96.8750,97.3711],alpha:0.47871965169906616\n", 1247 | "Epoch:10[400/469],Loss:[0.775,0.984],domain loss:[0.246,0.428],label loss:[0.181,0.106],prec[95.3125,97.0684],alpha:0.49498671293258667\n" 1248 | ] 1249 | }, 1250 | { 1251 | "data": { 1252 | "application/vnd.jupyter.widget-view+json": { 1253 | "model_id": "3f51722366404308956f6804cc94507c", 1254 | "version_major": 2, 1255 | 
"version_minor": 0 1256 | }, 1257 | "text/plain": [ 1258 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 1259 | ] 1260 | }, 1261 | "metadata": {}, 1262 | "output_type": "display_data" 1263 | }, 1264 | { 1265 | "name": "stdout", 1266 | "output_type": "stream", 1267 | "text": [ 1268 | "\n", 1269 | "Epoch:10,val,Loss:[2.768],prec[58.6379],domain_acc[85.3067]\n", 1270 | "Epoch:11[200/469],Loss:[1.212,1.070],domain loss:[0.511,0.469],label loss:[0.136,0.118],prec[96.0938,96.5274],alpha:0.5163294672966003\n", 1271 | "Epoch:11[400/469],Loss:[0.963,1.081],domain loss:[0.340,0.461],label loss:[0.037,0.107],prec[99.2188,96.8262],alpha:0.5317944884300232\n" 1272 | ] 1273 | }, 1274 | { 1275 | "data": { 1276 | "application/vnd.jupyter.widget-view+json": { 1277 | "model_id": "af12310a82f34b20a566188317772eda", 1278 | "version_major": 2, 1279 | "version_minor": 0 1280 | }, 1281 | "text/plain": [ 1282 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 1283 | ] 1284 | }, 1285 | "metadata": {}, 1286 | "output_type": "display_data" 1287 | }, 1288 | { 1289 | "name": "stdout", 1290 | "output_type": "stream", 1291 | "text": [ 1292 | "\n", 1293 | "Epoch:11,val,Loss:[3.159],prec[58.7379],domain_acc[84.8026]\n", 1294 | "Epoch:12[200/469],Loss:[1.140,1.120],domain loss:[0.550,0.465],label loss:[0.093,0.098],prec[98.4375,96.9648],alpha:0.5520477890968323\n", 1295 | "Epoch:12[400/469],Loss:[1.227,1.088],domain loss:[0.678,0.459],label loss:[0.034,0.095],prec[99.2188,97.0781],alpha:0.5666970610618591\n" 1296 | ] 1297 | }, 1298 | { 1299 | "data": { 1300 | "application/vnd.jupyter.widget-view+json": { 1301 | "model_id": "e04fced654dd4a0388988b2d8ee80f51", 1302 | "version_major": 2, 1303 | "version_minor": 0 1304 | }, 1305 | "text/plain": [ 1306 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 1307 | ] 1308 | }, 1309 | "metadata": {}, 1310 | "output_type": "display_data" 1311 | }, 1312 | { 1313 | "name": "stdout", 1314 | 
"output_type": "stream", 1315 | "text": [ 1316 | "\n", 1317 | "Epoch:12,val,Loss:[3.164],prec[59.0823],domain_acc[84.4445]\n", 1318 | "Epoch:13[200/469],Loss:[1.187,1.179],domain loss:[0.584,0.492],label loss:[0.025,0.092],prec[100.0000,97.0586],alpha:0.5858488082885742\n", 1319 | "Epoch:13[400/469],Loss:[1.255,1.212],domain loss:[0.401,0.512],label loss:[0.142,0.096],prec[94.5312,96.9726],alpha:0.5996778011322021\n" 1320 | ] 1321 | }, 1322 | { 1323 | "data": { 1324 | "application/vnd.jupyter.widget-view+json": { 1325 | "model_id": "34e592e5e60f4cbdb0e0ddb28be25023", 1326 | "version_major": 2, 1327 | "version_minor": 0 1328 | }, 1329 | "text/plain": [ 1330 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 1331 | ] 1332 | }, 1333 | "metadata": {}, 1334 | "output_type": "display_data" 1335 | }, 1336 | { 1337 | "name": "stdout", 1338 | "output_type": "stream", 1339 | "text": [ 1340 | "\n", 1341 | "Epoch:13,val,Loss:[2.182],prec[60.5155],domain_acc[84.3859]\n", 1342 | "Epoch:14[200/469],Loss:[1.176,1.073],domain loss:[0.594,0.531],label loss:[0.116,0.088],prec[96.0938,97.3906],alpha:0.617727518081665\n", 1343 | "Epoch:14[400/469],Loss:[1.090,1.080],domain loss:[0.574,0.528],label loss:[0.044,0.080],prec[97.6562,97.5937],alpha:0.6307399272918701\n" 1344 | ] 1345 | }, 1346 | { 1347 | "data": { 1348 | "application/vnd.jupyter.widget-view+json": { 1349 | "model_id": "f0a03a7136024c39b1a9237cef3523fc", 1350 | "version_major": 2, 1351 | "version_minor": 0 1352 | }, 1353 | "text/plain": [ 1354 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 1355 | ] 1356 | }, 1357 | "metadata": {}, 1358 | "output_type": "display_data" 1359 | }, 1360 | { 1361 | "name": "stdout", 1362 | "output_type": "stream", 1363 | "text": [ 1364 | "\n", 1365 | "Epoch:14,val,Loss:[1.604],prec[63.7151],domain_acc[83.0374]\n", 1366 | "Epoch:15[200/469],Loss:[1.137,1.140],domain loss:[0.466,0.541],label 
loss:[0.039,0.071],prec[99.2188,97.8750],alpha:0.6476975083351135\n", 1367 | "Epoch:15[400/469],Loss:[1.162,1.169],domain loss:[0.476,0.546],label loss:[0.074,0.074],prec[97.6562,97.7735],alpha:0.6599041819572449\n" 1368 | ] 1369 | }, 1370 | { 1371 | "data": { 1372 | "application/vnd.jupyter.widget-view+json": { 1373 | "model_id": "7779de894ee1451aa22c156804a90b00", 1374 | "version_major": 2, 1375 | "version_minor": 0 1376 | }, 1377 | "text/plain": [ 1378 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 1379 | ] 1380 | }, 1381 | "metadata": {}, 1382 | "output_type": "display_data" 1383 | }, 1384 | { 1385 | "name": "stdout", 1386 | "output_type": "stream", 1387 | "text": [ 1388 | "\n", 1389 | "Epoch:15,val,Loss:[1.526],prec[63.5929],domain_acc[81.8013]\n", 1390 | "Epoch:16[200/469],Loss:[1.116,1.225],domain loss:[0.450,0.581],label loss:[0.057,0.088],prec[97.6562,97.4453],alpha:0.6757887601852417\n", 1391 | "Epoch:16[400/469],Loss:[1.375,1.208],domain loss:[0.455,0.596],label loss:[0.057,0.080],prec[96.8750,97.6153],alpha:0.6872069239616394\n" 1392 | ] 1393 | }, 1394 | { 1395 | "data": { 1396 | "application/vnd.jupyter.widget-view+json": { 1397 | "model_id": "246422088edf4ce18a05e34c11bc96f2", 1398 | "version_major": 2, 1399 | "version_minor": 0 1400 | }, 1401 | "text/plain": [ 1402 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 1403 | ] 1404 | }, 1405 | "metadata": {}, 1406 | "output_type": "display_data" 1407 | }, 1408 | { 1409 | "name": "stdout", 1410 | "output_type": "stream", 1411 | "text": [ 1412 | "\n", 1413 | "Epoch:16,val,Loss:[2.131],prec[63.8151],domain_acc[81.1773]\n", 1414 | "Epoch:17[200/469],Loss:[1.303,1.195],domain loss:[0.506,0.599],label loss:[0.081,0.061],prec[96.8750,98.1289],alpha:0.7020451426506042\n", 1415 | "Epoch:17[400/469],Loss:[1.180,1.224],domain loss:[0.581,0.595],label loss:[0.040,0.064],prec[98.4375,98.0879],alpha:0.7126971483230591\n" 1416 | ] 1417 | }, 1418 | { 1419 | "data": { 1420 | 
"application/vnd.jupyter.widget-view+json": { 1421 | "model_id": "2b4785d3018641788028ee6e854a80f6", 1422 | "version_major": 2, 1423 | "version_minor": 0 1424 | }, 1425 | "text/plain": [ 1426 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 1427 | ] 1428 | }, 1429 | "metadata": {}, 1430 | "output_type": "display_data" 1431 | }, 1432 | { 1433 | "name": "stdout", 1434 | "output_type": "stream", 1435 | "text": [ 1436 | "\n", 1437 | "Epoch:17,val,Loss:[1.716],prec[63.7596],domain_acc[80.7886]\n", 1438 | "Epoch:18[200/469],Loss:[1.187,1.161],domain loss:[0.540,0.569],label loss:[0.056,0.063],prec[97.6562,98.0937],alpha:0.7265222072601318\n", 1439 | "Epoch:18[400/469],Loss:[1.231,1.177],domain loss:[0.569,0.583],label loss:[0.080,0.067],prec[96.8750,97.9551],alpha:0.7364346385002136\n" 1440 | ] 1441 | }, 1442 | { 1443 | "data": { 1444 | "application/vnd.jupyter.widget-view+json": { 1445 | "model_id": "e8797e87c42c4036869eee2ae57dcee0", 1446 | "version_major": 2, 1447 | "version_minor": 0 1448 | }, 1449 | "text/plain": [ 1450 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 1451 | ] 1452 | }, 1453 | "metadata": {}, 1454 | "output_type": "display_data" 1455 | }, 1456 | { 1457 | "name": "stdout", 1458 | "output_type": "stream", 1459 | "text": [ 1460 | "\n", 1461 | "Epoch:18,val,Loss:[1.434],prec[66.9037],domain_acc[79.5607]\n", 1462 | "Epoch:19[200/469],Loss:[1.329,1.167],domain loss:[0.533,0.592],label loss:[0.099,0.069],prec[97.6562,97.8828],alpha:0.7492846250534058\n", 1463 | "Epoch:19[400/469],Loss:[1.335,1.199],domain loss:[0.673,0.604],label loss:[0.091,0.070],prec[97.6562,97.8516],alpha:0.7584874629974365\n" 1464 | ] 1465 | }, 1466 | { 1467 | "data": { 1468 | "application/vnd.jupyter.widget-view+json": { 1469 | "model_id": "0ccf062de4694390b030b7d3d1ceaab6", 1470 | "version_major": 2, 1471 | "version_minor": 0 1472 | }, 1473 | "text/plain": [ 1474 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 
1475 | ] 1476 | }, 1477 | "metadata": {}, 1478 | "output_type": "display_data" 1479 | }, 1480 | { 1481 | "name": "stdout", 1482 | "output_type": "stream", 1483 | "text": [ 1484 | "\n", 1485 | "Epoch:19,val,Loss:[1.422],prec[66.1704],domain_acc[78.2974]\n", 1486 | "Epoch:20[200/469],Loss:[1.086,1.225],domain loss:[0.452,0.589],label loss:[0.025,0.070],prec[99.2188,97.8164],alpha:0.7704044580459595\n", 1487 | "Epoch:20[400/469],Loss:[1.204,1.218],domain loss:[0.556,0.596],label loss:[0.056,0.066],prec[97.6562,97.9688],alpha:0.7789300084114075\n" 1488 | ] 1489 | }, 1490 | { 1491 | "data": { 1492 | "application/vnd.jupyter.widget-view+json": { 1493 | "model_id": "c0f0c456ee5f4025bf0e00ddaa89de7d", 1494 | "version_major": 2, 1495 | "version_minor": 0 1496 | }, 1497 | "text/plain": [ 1498 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 1499 | ] 1500 | }, 1501 | "metadata": {}, 1502 | "output_type": "display_data" 1503 | }, 1504 | { 1505 | "name": "stdout", 1506 | "output_type": "stream", 1507 | "text": [ 1508 | "\n", 1509 | "Epoch:20,val,Loss:[1.460],prec[67.2814],domain_acc[77.7977]\n", 1510 | "Epoch:21[200/469],Loss:[1.236,1.264],domain loss:[0.524,0.613],label loss:[0.107,0.082],prec[97.6562,97.5078],alpha:0.7899587750434875\n", 1511 | "Epoch:21[400/469],Loss:[1.133,1.229],domain loss:[0.574,0.602],label loss:[0.077,0.071],prec[97.6562,97.8496],alpha:0.7978411316871643\n" 1512 | ] 1513 | }, 1514 | { 1515 | "data": { 1516 | "application/vnd.jupyter.widget-view+json": { 1517 | "model_id": "a11e7d9797c8475f83b739fc3470f79b", 1518 | "version_major": 2, 1519 | "version_minor": 0 1520 | }, 1521 | "text/plain": [ 1522 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 1523 | ] 1524 | }, 1525 | "metadata": {}, 1526 | "output_type": "display_data" 1527 | }, 1528 | { 1529 | "name": "stdout", 1530 | "output_type": "stream", 1531 | "text": [ 1532 | "\n", 1533 | "Epoch:21,val,Loss:[1.560],prec[67.6814],domain_acc[77.4611]\n", 1534 | 
"Epoch:22[200/469],Loss:[1.147,1.206],domain loss:[0.509,0.580],label loss:[0.048,0.069],prec[99.2188,97.8867],alpha:0.8080282211303711\n", 1535 | "Epoch:22[400/469],Loss:[1.253,1.209],domain loss:[0.543,0.587],label loss:[0.050,0.064],prec[97.6562,98.0177],alpha:0.8153024911880493\n" 1536 | ] 1537 | }, 1538 | { 1539 | "data": { 1540 | "application/vnd.jupyter.widget-view+json": { 1541 | "model_id": "bc599ec130664aa9bfa95ac985b38357", 1542 | "version_major": 2, 1543 | "version_minor": 0 1544 | }, 1545 | "text/plain": [ 1546 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 1547 | ] 1548 | }, 1549 | "metadata": {}, 1550 | "output_type": "display_data" 1551 | }, 1552 | { 1553 | "name": "stdout", 1554 | "output_type": "stream", 1555 | "text": [ 1556 | "\n", 1557 | "Epoch:22,val,Loss:[1.462],prec[69.1256],domain_acc[77.3204]\n", 1558 | "Epoch:23[200/469],Loss:[1.171,1.212],domain loss:[0.594,0.602],label loss:[0.065,0.061],prec[98.4375,98.1055],alpha:0.8246954679489136\n", 1559 | "Epoch:23[400/469],Loss:[1.253,1.242],domain loss:[0.606,0.615],label loss:[0.084,0.062],prec[96.0938,98.0821],alpha:0.8313970565795898\n" 1560 | ] 1561 | }, 1562 | { 1563 | "data": { 1564 | "application/vnd.jupyter.widget-view+json": { 1565 | "model_id": "50cc42409a9e47d8a42007f048e8c9d2", 1566 | "version_major": 2, 1567 | "version_minor": 0 1568 | }, 1569 | "text/plain": [ 1570 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 1571 | ] 1572 | }, 1573 | "metadata": {}, 1574 | "output_type": "display_data" 1575 | }, 1576 | { 1577 | "name": "stdout", 1578 | "output_type": "stream", 1579 | "text": [ 1580 | "\n", 1581 | "Epoch:23,val,Loss:[1.503],prec[68.2035],domain_acc[76.7405]\n", 1582 | "Epoch:24[200/469],Loss:[1.364,1.294],domain loss:[0.612,0.623],label loss:[0.036,0.065],prec[98.4375,97.9492],alpha:0.8400437235832214\n", 1583 | "Epoch:24[400/469],Loss:[1.162,1.266],domain loss:[0.560,0.614],label 
loss:[0.016,0.067],prec[100.0000,97.9180],alpha:0.8462079763412476\n" 1584 | ] 1585 | }, 1586 | { 1587 | "data": { 1588 | "application/vnd.jupyter.widget-view+json": { 1589 | "model_id": "04403e5cf52b4f9c901128ef075cc5a4", 1590 | "version_major": 2, 1591 | "version_minor": 0 1592 | }, 1593 | "text/plain": [ 1594 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 1595 | ] 1596 | }, 1597 | "metadata": {}, 1598 | "output_type": "display_data" 1599 | }, 1600 | { 1601 | "name": "stdout", 1602 | "output_type": "stream", 1603 | "text": [ 1604 | "\n", 1605 | "Epoch:24,val,Loss:[1.255],prec[70.6477],domain_acc[76.1213]\n", 1606 | "Epoch:25[200/469],Loss:[1.339,1.272],domain loss:[0.603,0.616],label loss:[0.040,0.057],prec[98.4375,98.1484],alpha:0.8541555404663086\n", 1607 | "Epoch:25[400/469],Loss:[1.317,1.279],domain loss:[0.699,0.620],label loss:[0.068,0.064],prec[96.8750,97.9337],alpha:0.859817385673523\n" 1608 | ] 1609 | }, 1610 | { 1611 | "data": { 1612 | "application/vnd.jupyter.widget-view+json": { 1613 | "model_id": "e170060c8ccf4ca89dbc2ef39b089866", 1614 | "version_major": 2, 1615 | "version_minor": 0 1616 | }, 1617 | "text/plain": [ 1618 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 1619 | ] 1620 | }, 1621 | "metadata": {}, 1622 | "output_type": "display_data" 1623 | }, 1624 | { 1625 | "name": "stdout", 1626 | "output_type": "stream", 1627 | "text": [ 1628 | "\n", 1629 | "Epoch:25,val,Loss:[1.403],prec[68.1702],domain_acc[74.8118]\n", 1630 | "Epoch:26[200/469],Loss:[1.318,1.301],domain loss:[0.689,0.652],label loss:[0.079,0.067],prec[96.8750,98.0039],alpha:0.8671122789382935\n", 1631 | "Epoch:26[400/469],Loss:[1.306,1.289],domain loss:[0.633,0.638],label loss:[0.062,0.063],prec[97.6562,98.1134],alpha:0.8723058104515076\n" 1632 | ] 1633 | }, 1634 | { 1635 | "data": { 1636 | "application/vnd.jupyter.widget-view+json": { 1637 | "model_id": "8ed59f80e01644bbaa7d150313f30e23", 1638 | "version_major": 2, 1639 | 
"version_minor": 0 1640 | }, 1641 | "text/plain": [ 1642 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 1643 | ] 1644 | }, 1645 | "metadata": {}, 1646 | "output_type": "display_data" 1647 | }, 1648 | { 1649 | "name": "stdout", 1650 | "output_type": "stream", 1651 | "text": [ 1652 | "\n", 1653 | "Epoch:26,val,Loss:[1.301],prec[70.3700],domain_acc[73.7046]\n", 1654 | "Epoch:27[200/469],Loss:[1.378,1.292],domain loss:[0.676,0.638],label loss:[0.098,0.056],prec[97.6562,98.2617],alpha:0.8789930939674377\n", 1655 | "Epoch:27[400/469],Loss:[1.378,1.323],domain loss:[0.705,0.640],label loss:[0.103,0.065],prec[96.0938,97.9454],alpha:0.8837512135505676\n" 1656 | ] 1657 | }, 1658 | { 1659 | "data": { 1660 | "application/vnd.jupyter.widget-view+json": { 1661 | "model_id": "76f22c3d3b4b4d4fa7b65d273092f498", 1662 | "version_major": 2, 1663 | "version_minor": 0 1664 | }, 1665 | "text/plain": [ 1666 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 1667 | ] 1668 | }, 1669 | "metadata": {}, 1670 | "output_type": "display_data" 1671 | }, 1672 | { 1673 | "name": "stdout", 1674 | "output_type": "stream", 1675 | "text": [ 1676 | "\n", 1677 | "Epoch:27,val,Loss:[1.325],prec[69.5145],domain_acc[72.2828]\n", 1678 | "Epoch:28[200/469],Loss:[1.366,1.303],domain loss:[0.734,0.665],label loss:[0.111,0.062],prec[96.8750,98.0976],alpha:0.8898743987083435\n", 1679 | "Epoch:28[400/469],Loss:[1.413,1.328],domain loss:[0.796,0.649],label loss:[0.141,0.066],prec[96.0938,98.0234],alpha:0.8942286968231201\n" 1680 | ] 1681 | }, 1682 | { 1683 | "data": { 1684 | "application/vnd.jupyter.widget-view+json": { 1685 | "model_id": "a6daf4fe1c244bec84f1c050f2286716", 1686 | "version_major": 2, 1687 | "version_minor": 0 1688 | }, 1689 | "text/plain": [ 1690 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 1691 | ] 1692 | }, 1693 | "metadata": {}, 1694 | "output_type": "display_data" 1695 | }, 1696 | { 1697 | "name": "stdout", 1698 | 
"output_type": "stream", 1699 | "text": [ 1700 | "\n", 1701 | "Epoch:28,val,Loss:[1.496],prec[67.2148],domain_acc[71.9464]\n", 1702 | "Epoch:29[200/469],Loss:[1.325,1.364],domain loss:[0.698,0.676],label loss:[0.025,0.055],prec[99.2188,98.2813],alpha:0.8998293280601501\n", 1703 | "Epoch:29[400/469],Loss:[1.473,1.358],domain loss:[0.803,0.664],label loss:[0.022,0.057],prec[99.2188,98.1955],alpha:0.9038100838661194\n" 1704 | ] 1705 | }, 1706 | { 1707 | "data": { 1708 | "application/vnd.jupyter.widget-view+json": { 1709 | "model_id": "84da01902ff74f8c87d5ed0a27e787d0", 1710 | "version_major": 2, 1711 | "version_minor": 0 1712 | }, 1713 | "text/plain": [ 1714 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 1715 | ] 1716 | }, 1717 | "metadata": {}, 1718 | "output_type": "display_data" 1719 | }, 1720 | { 1721 | "name": "stdout", 1722 | "output_type": "stream", 1723 | "text": [ 1724 | "\n", 1725 | "Epoch:29,val,Loss:[1.179],prec[71.0588],domain_acc[71.7839]\n", 1726 | "Epoch:30[200/469],Loss:[1.335,1.354],domain loss:[0.708,0.659],label loss:[0.076,0.047],prec[98.4375,98.5352],alpha:0.9089277386665344\n", 1727 | "Epoch:30[400/469],Loss:[1.339,1.345],domain loss:[0.717,0.660],label loss:[0.052,0.053],prec[99.2188,98.3848],alpha:0.9125635623931885\n" 1728 | ] 1729 | }, 1730 | { 1731 | "data": { 1732 | "application/vnd.jupyter.widget-view+json": { 1733 | "model_id": "62101943ee1f40b1810583a3ad6bc047", 1734 | "version_major": 2, 1735 | "version_minor": 0 1736 | }, 1737 | "text/plain": [ 1738 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 1739 | ] 1740 | }, 1741 | "metadata": {}, 1742 | "output_type": "display_data" 1743 | }, 1744 | { 1745 | "name": "stdout", 1746 | "output_type": "stream", 1747 | "text": [ 1748 | "\n", 1749 | "Epoch:30,val,Loss:[1.310],prec[71.6476],domain_acc[71.8949]\n", 1750 | "Epoch:31[200/469],Loss:[1.433,1.362],domain loss:[0.734,0.641],label 
loss:[0.038,0.050],prec[98.4375,98.3359],alpha:0.917235791683197\n", 1751 | "Epoch:31[400/469],Loss:[1.322,1.374],domain loss:[0.742,0.649],label loss:[0.006,0.055],prec[100.0000,98.3087],alpha:0.9205537438392639\n" 1752 | ] 1753 | }, 1754 | { 1755 | "data": { 1756 | "application/vnd.jupyter.widget-view+json": { 1757 | "model_id": "6b19291389ce490881bf164ba22a3133", 1758 | "version_major": 2, 1759 | "version_minor": 0 1760 | }, 1761 | "text/plain": [ 1762 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 1763 | ] 1764 | }, 1765 | "metadata": {}, 1766 | "output_type": "display_data" 1767 | }, 1768 | { 1769 | "name": "stdout", 1770 | "output_type": "stream", 1771 | "text": [ 1772 | "\n", 1773 | "Epoch:31,val,Loss:[1.133],prec[73.9362],domain_acc[70.8841]\n", 1774 | "Epoch:32[200/469],Loss:[1.366,1.353],domain loss:[0.647,0.679],label loss:[0.087,0.052],prec[97.6562,98.3711],alpha:0.9248157739639282\n", 1775 | "Epoch:32[400/469],Loss:[1.395,1.352],domain loss:[0.747,0.663],label loss:[0.025,0.055],prec[99.2188,98.3204],alpha:0.9278412461280823\n" 1776 | ] 1777 | }, 1778 | { 1779 | "data": { 1780 | "application/vnd.jupyter.widget-view+json": { 1781 | "model_id": "c0fae860bc6e4eaeb393368a51ef6266", 1782 | "version_major": 2, 1783 | "version_minor": 0 1784 | }, 1785 | "text/plain": [ 1786 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 1787 | ] 1788 | }, 1789 | "metadata": {}, 1790 | "output_type": "display_data" 1791 | }, 1792 | { 1793 | "name": "stdout", 1794 | "output_type": "stream", 1795 | "text": [ 1796 | "\n", 1797 | "Epoch:32,val,Loss:[1.078],prec[73.2807],domain_acc[70.5430]\n", 1798 | "Epoch:33[200/469],Loss:[1.245,1.317],domain loss:[0.634,0.649],label loss:[0.003,0.057],prec[100.0000,98.2539],alpha:0.9317262768745422\n", 1799 | "Epoch:33[400/469],Loss:[1.364,1.321],domain loss:[0.676,0.649],label loss:[0.025,0.059],prec[99.2188,98.2443],alpha:0.9344831109046936\n" 1800 | ] 1801 | }, 1802 | { 1803 | "data": { 1804 
| "application/vnd.jupyter.widget-view+json": { 1805 | "model_id": "e2c2f7c596fc4ccc8402d1960a14d229", 1806 | "version_major": 2, 1807 | "version_minor": 0 1808 | }, 1809 | "text/plain": [ 1810 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 1811 | ] 1812 | }, 1813 | "metadata": {}, 1814 | "output_type": "display_data" 1815 | }, 1816 | { 1817 | "name": "stdout", 1818 | "output_type": "stream", 1819 | "text": [ 1820 | "\n", 1821 | "Epoch:33,val,Loss:[1.039],prec[73.8362],domain_acc[70.1007]\n", 1822 | "Epoch:34[200/469],Loss:[1.439,1.376],domain loss:[0.764,0.661],label loss:[0.028,0.052],prec[99.2188,98.4571],alpha:0.9380220174789429\n", 1823 | "Epoch:34[400/469],Loss:[1.385,1.362],domain loss:[0.660,0.660],label loss:[0.096,0.056],prec[97.6562,98.3126],alpha:0.9405325055122375\n" 1824 | ] 1825 | }, 1826 | { 1827 | "data": { 1828 | "application/vnd.jupyter.widget-view+json": { 1829 | "model_id": "381fdec774654f5c953eb0f792ee7d63", 1830 | "version_major": 2, 1831 | "version_minor": 0 1832 | }, 1833 | "text/plain": [ 1834 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 1835 | ] 1836 | }, 1837 | "metadata": {}, 1838 | "output_type": "display_data" 1839 | }, 1840 | { 1841 | "name": "stdout", 1842 | "output_type": "stream", 1843 | "text": [ 1844 | "\n", 1845 | "Epoch:34,val,Loss:[1.322],prec[72.0142],domain_acc[70.0262]\n", 1846 | "Epoch:35[200/469],Loss:[1.346,1.400],domain loss:[0.574,0.678],label loss:[0.050,0.065],prec[97.6562,98.0196],alpha:0.9437541365623474\n", 1847 | "Epoch:35[400/469],Loss:[1.473,1.391],domain loss:[0.654,0.661],label loss:[0.147,0.062],prec[96.8750,98.1114],alpha:0.9460389018058777\n" 1848 | ] 1849 | }, 1850 | { 1851 | "data": { 1852 | "application/vnd.jupyter.widget-view+json": { 1853 | "model_id": "1796a0eef9a947ec8d66b374ac1474ae", 1854 | "version_major": 2, 1855 | "version_minor": 0 1856 | }, 1857 | "text/plain": [ 1858 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 
1859 | ] 1860 | }, 1861 | "metadata": {}, 1862 | "output_type": "display_data" 1863 | }, 1864 | { 1865 | "name": "stdout", 1866 | "output_type": "stream", 1867 | "text": [ 1868 | "\n", 1869 | "Epoch:35,val,Loss:[1.067],prec[74.5695],domain_acc[69.7401]\n", 1870 | "Epoch:36[200/469],Loss:[1.369,1.369],domain loss:[0.673,0.655],label loss:[0.096,0.051],prec[98.4375,98.3906],alpha:0.9489700794219971\n", 1871 | "Epoch:36[400/469],Loss:[1.351,1.377],domain loss:[0.649,0.665],label loss:[0.017,0.052],prec[100.0000,98.3888],alpha:0.9510483145713806\n" 1872 | ] 1873 | }, 1874 | { 1875 | "data": { 1876 | "application/vnd.jupyter.widget-view+json": { 1877 | "model_id": "ed37add197654d7ab977f34bca23c1cc", 1878 | "version_major": 2, 1879 | "version_minor": 0 1880 | }, 1881 | "text/plain": [ 1882 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 1883 | ] 1884 | }, 1885 | "metadata": {}, 1886 | "output_type": "display_data" 1887 | }, 1888 | { 1889 | "name": "stdout", 1890 | "output_type": "stream", 1891 | "text": [ 1892 | "\n", 1893 | "Epoch:36,val,Loss:[1.088],prec[74.8361],domain_acc[69.7823]\n", 1894 | "Epoch:37[200/469],Loss:[1.310,1.405],domain loss:[0.613,0.672],label loss:[0.001,0.054],prec[100.0000,98.3867],alpha:0.9537138342857361\n", 1895 | "Epoch:37[400/469],Loss:[1.429,1.407],domain loss:[0.674,0.679],label loss:[0.071,0.052],prec[99.2188,98.4082],alpha:0.9556032419204712\n" 1896 | ] 1897 | }, 1898 | { 1899 | "data": { 1900 | "application/vnd.jupyter.widget-view+json": { 1901 | "model_id": "59aeb7bb310f4dfbb89f34ae70876e1c", 1902 | "version_major": 2, 1903 | "version_minor": 0 1904 | }, 1905 | "text/plain": [ 1906 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 1907 | ] 1908 | }, 1909 | "metadata": {}, 1910 | "output_type": "display_data" 1911 | }, 1912 | { 1913 | "name": "stdout", 1914 | "output_type": "stream", 1915 | "text": [ 1916 | "\n", 1917 | "Epoch:37,val,Loss:[1.119],prec[75.4583],domain_acc[69.6753]\n", 1918 | 
"Epoch:38[200/469],Loss:[1.348,1.359],domain loss:[0.700,0.651],label loss:[0.009,0.045],prec[100.0000,98.6133],alpha:0.958026111125946\n", 1919 | "Epoch:38[400/469],Loss:[1.480,1.368],domain loss:[0.653,0.666],label loss:[0.076,0.046],prec[97.6562,98.5723],alpha:0.9597431421279907\n" 1920 | ] 1921 | }, 1922 | { 1923 | "data": { 1924 | "application/vnd.jupyter.widget-view+json": { 1925 | "model_id": "3d83cc126a33433588ca51f2cb69235c", 1926 | "version_major": 2, 1927 | "version_minor": 0 1928 | }, 1929 | "text/plain": [ 1930 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 1931 | ] 1932 | }, 1933 | "metadata": {}, 1934 | "output_type": "display_data" 1935 | }, 1936 | { 1937 | "name": "stdout", 1938 | "output_type": "stream", 1939 | "text": [ 1940 | "\n", 1941 | "Epoch:38,val,Loss:[1.032],prec[75.5694],domain_acc[69.7851]\n", 1942 | "Epoch:39[200/469],Loss:[1.395,1.401],domain loss:[0.688,0.683],label loss:[0.086,0.050],prec[97.6562,98.5117],alpha:0.9619444608688354\n", 1943 | "Epoch:39[400/469],Loss:[1.435,1.398],domain loss:[0.684,0.681],label loss:[0.091,0.050],prec[96.8750,98.4728],alpha:0.9635041356086731\n" 1944 | ] 1945 | }, 1946 | { 1947 | "data": { 1948 | "application/vnd.jupyter.widget-view+json": { 1949 | "model_id": "23ad23b45db84fb6b12d6ee1f8b0ec02", 1950 | "version_major": 2, 1951 | "version_minor": 0 1952 | }, 1953 | "text/plain": [ 1954 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 1955 | ] 1956 | }, 1957 | "metadata": {}, 1958 | "output_type": "display_data" 1959 | }, 1960 | { 1961 | "name": "stdout", 1962 | "output_type": "stream", 1963 | "text": [ 1964 | "\n", 1965 | "Epoch:39,val,Loss:[1.097],prec[75.5249],domain_acc[69.7114]\n", 1966 | "Epoch:40[200/469],Loss:[1.443,1.339],domain loss:[0.658,0.649],label loss:[0.103,0.057],prec[96.0938,98.2422],alpha:0.965503454208374\n", 1967 | "Epoch:40[400/469],Loss:[1.379,1.353],domain loss:[0.656,0.662],label 
loss:[0.056,0.055],prec[97.6562,98.3165],alpha:0.9669197797775269\n" 1968 | ] 1969 | }, 1970 | { 1971 | "data": { 1972 | "application/vnd.jupyter.widget-view+json": { 1973 | "model_id": "4a59a35f26024dec8717caeeebb27cce", 1974 | "version_major": 2, 1975 | "version_minor": 0 1976 | }, 1977 | "text/plain": [ 1978 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 1979 | ] 1980 | }, 1981 | "metadata": {}, 1982 | "output_type": "display_data" 1983 | }, 1984 | { 1985 | "name": "stdout", 1986 | "output_type": "stream", 1987 | "text": [ 1988 | "\n", 1989 | "Epoch:40,val,Loss:[0.865],prec[78.0802],domain_acc[69.5400]\n", 1990 | "Epoch:41[200/469],Loss:[1.373,1.357],domain loss:[0.680,0.660],label loss:[0.088,0.045],prec[98.4375,98.5274],alpha:0.9687349200248718\n", 1991 | "Epoch:41[400/469],Loss:[1.394,1.371],domain loss:[0.720,0.662],label loss:[0.057,0.049],prec[96.8750,98.4669],alpha:0.970020592212677\n" 1992 | ] 1993 | }, 1994 | { 1995 | "data": { 1996 | "application/vnd.jupyter.widget-view+json": { 1997 | "model_id": "ce9f93deea264f6eb860b22473503444", 1998 | "version_major": 2, 1999 | "version_minor": 0 2000 | }, 2001 | "text/plain": [ 2002 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 2003 | ] 2004 | }, 2005 | "metadata": {}, 2006 | "output_type": "display_data" 2007 | }, 2008 | { 2009 | "name": "stdout", 2010 | "output_type": "stream", 2011 | "text": [ 2012 | "\n", 2013 | "Epoch:41,val,Loss:[1.078],prec[76.8026],domain_acc[69.6777]\n", 2014 | "Epoch:42[200/469],Loss:[1.388,1.379],domain loss:[0.692,0.665],label loss:[0.031,0.044],prec[99.2188,98.5899],alpha:0.971668004989624\n", 2015 | "Epoch:42[400/469],Loss:[1.337,1.374],domain loss:[0.611,0.664],label loss:[0.029,0.044],prec[98.4375,98.6700],alpha:0.9728347659111023\n" 2016 | ] 2017 | }, 2018 | { 2019 | "data": { 2020 | "application/vnd.jupyter.widget-view+json": { 2021 | "model_id": "751fd951e1f740309902dd2decce7aed", 2022 | "version_major": 2, 2023 | 
"version_minor": 0 2024 | }, 2025 | "text/plain": [ 2026 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 2027 | ] 2028 | }, 2029 | "metadata": {}, 2030 | "output_type": "display_data" 2031 | }, 2032 | { 2033 | "name": "stdout", 2034 | "output_type": "stream", 2035 | "text": [ 2036 | "\n", 2037 | "Epoch:42,val,Loss:[0.972],prec[76.1915],domain_acc[69.5295]\n", 2038 | "Epoch:43[200/469],Loss:[1.393,1.359],domain loss:[0.705,0.668],label loss:[0.048,0.047],prec[98.4375,98.5898],alpha:0.9743295311927795\n", 2039 | "Epoch:43[400/469],Loss:[1.407,1.379],domain loss:[0.633,0.673],label loss:[0.065,0.050],prec[98.4375,98.4435],alpha:0.9753880500793457\n" 2040 | ] 2041 | }, 2042 | { 2043 | "data": { 2044 | "application/vnd.jupyter.widget-view+json": { 2045 | "model_id": "2fa9a0752eb34fd1b6d2fe487ddac337", 2046 | "version_major": 2, 2047 | "version_minor": 0 2048 | }, 2049 | "text/plain": [ 2050 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 2051 | ] 2052 | }, 2053 | "metadata": {}, 2054 | "output_type": "display_data" 2055 | }, 2056 | { 2057 | "name": "stdout", 2058 | "output_type": "stream", 2059 | "text": [ 2060 | "\n", 2061 | "Epoch:43,val,Loss:[0.850],prec[77.3803],domain_acc[69.6466]\n", 2062 | "Epoch:44[200/469],Loss:[1.440,1.383],domain loss:[0.661,0.670],label loss:[0.060,0.044],prec[96.8750,98.6211],alpha:0.9767439961433411\n", 2063 | "Epoch:44[400/469],Loss:[1.424,1.380],domain loss:[0.641,0.667],label loss:[0.072,0.050],prec[98.4375,98.4454],alpha:0.9777040481567383\n" 2064 | ] 2065 | }, 2066 | { 2067 | "data": { 2068 | "application/vnd.jupyter.widget-view+json": { 2069 | "model_id": "db0519c3db214db7909fcb61b18cd7dc", 2070 | "version_major": 2, 2071 | "version_minor": 0 2072 | }, 2073 | "text/plain": [ 2074 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 2075 | ] 2076 | }, 2077 | "metadata": {}, 2078 | "output_type": "display_data" 2079 | }, 2080 | { 2081 | "name": "stdout", 2082 | 
"output_type": "stream", 2083 | "text": [ 2084 | "\n", 2085 | "Epoch:44,val,Loss:[0.961],prec[74.6139],domain_acc[69.9088]\n", 2086 | "Epoch:45[200/469],Loss:[1.290,1.385],domain loss:[0.677,0.680],label loss:[0.024,0.040],prec[100.0000,98.6875],alpha:0.9789338111877441\n", 2087 | "Epoch:45[400/469],Loss:[1.395,1.394],domain loss:[0.682,0.672],label loss:[0.040,0.044],prec[99.2188,98.6173],alpha:0.9798043966293335\n" 2088 | ] 2089 | }, 2090 | { 2091 | "data": { 2092 | "application/vnd.jupyter.widget-view+json": { 2093 | "model_id": "8efca1e67880409bad1047b1575a7d1f", 2094 | "version_major": 2, 2095 | "version_minor": 0 2096 | }, 2097 | "text/plain": [ 2098 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 2099 | ] 2100 | }, 2101 | "metadata": {}, 2102 | "output_type": "display_data" 2103 | }, 2104 | { 2105 | "name": "stdout", 2106 | "output_type": "stream", 2107 | "text": [ 2108 | "\n", 2109 | "Epoch:45,val,Loss:[1.007],prec[77.3692],domain_acc[70.1354]\n", 2110 | "Epoch:46[200/469],Loss:[1.347,1.399],domain loss:[0.709,0.683],label loss:[0.020,0.035],prec[99.2188,98.9259],alpha:0.9809194207191467\n", 2111 | "Epoch:46[400/469],Loss:[1.384,1.396],domain loss:[0.620,0.676],label loss:[0.005,0.038],prec[100.0000,98.8009],alpha:0.9817086458206177\n" 2112 | ] 2113 | }, 2114 | { 2115 | "data": { 2116 | "application/vnd.jupyter.widget-view+json": { 2117 | "model_id": "743246d9cdd44a9baff7cd812bb5ba32", 2118 | "version_major": 2, 2119 | "version_minor": 0 2120 | }, 2121 | "text/plain": [ 2122 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 2123 | ] 2124 | }, 2125 | "metadata": {}, 2126 | "output_type": "display_data" 2127 | }, 2128 | { 2129 | "name": "stdout", 2130 | "output_type": "stream", 2131 | "text": [ 2132 | "\n", 2133 | "Epoch:46,val,Loss:[1.063],prec[75.3916],domain_acc[69.5994]\n", 2134 | "Epoch:47[200/469],Loss:[1.356,1.400],domain loss:[0.640,0.686],label 
loss:[0.005,0.039],prec[100.0000,98.7422],alpha:0.9827194809913635\n", 2135 | "Epoch:47[400/469],Loss:[1.326,1.395],domain loss:[0.675,0.682],label loss:[0.020,0.040],prec[99.2188,98.7462],alpha:0.9834349155426025\n" 2136 | ] 2137 | }, 2138 | { 2139 | "data": { 2140 | "application/vnd.jupyter.widget-view+json": { 2141 | "model_id": "e698252b64f44318ab0c91383b290b85", 2142 | "version_major": 2, 2143 | "version_minor": 0 2144 | }, 2145 | "text/plain": [ 2146 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 2147 | ] 2148 | }, 2149 | "metadata": {}, 2150 | "output_type": "display_data" 2151 | }, 2152 | { 2153 | "name": "stdout", 2154 | "output_type": "stream", 2155 | "text": [ 2156 | "\n", 2157 | "Epoch:47,val,Loss:[1.183],prec[75.9471],domain_acc[69.7166]\n", 2158 | "Epoch:48[200/469],Loss:[1.408,1.402],domain loss:[0.653,0.668],label loss:[0.072,0.037],prec[96.8750,98.8985],alpha:0.9843510389328003\n", 2159 | "Epoch:48[400/469],Loss:[1.388,1.400],domain loss:[0.690,0.680],label loss:[0.028,0.038],prec[99.2188,98.8537],alpha:0.9849994778633118\n" 2160 | ] 2161 | }, 2162 | { 2163 | "data": { 2164 | "application/vnd.jupyter.widget-view+json": { 2165 | "model_id": "8d75cff6e9b84dbeb1875f95d1b4b8b6", 2166 | "version_major": 2, 2167 | "version_minor": 0 2168 | }, 2169 | "text/plain": [ 2170 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 2171 | ] 2172 | }, 2173 | "metadata": {}, 2174 | "output_type": "display_data" 2175 | }, 2176 | { 2177 | "name": "stdout", 2178 | "output_type": "stream", 2179 | "text": [ 2180 | "\n", 2181 | "Epoch:48,val,Loss:[1.001],prec[74.5695],domain_acc[69.6422]\n", 2182 | "Epoch:49[200/469],Loss:[1.408,1.399],domain loss:[0.704,0.683],label loss:[0.015,0.041],prec[100.0000,98.7031],alpha:0.9858297109603882\n", 2183 | "Epoch:49[400/469],Loss:[1.462,1.399],domain loss:[0.521,0.680],label loss:[0.075,0.041],prec[98.4375,98.7384],alpha:0.9864172339439392\n" 2184 | ] 2185 | }, 2186 | { 2187 | "data": { 2188 
| "application/vnd.jupyter.widget-view+json": { 2189 | "model_id": "d9c37a7dc2ac4552b930e8bc01976b7b", 2190 | "version_major": 2, 2191 | "version_minor": 0 2192 | }, 2193 | "text/plain": [ 2194 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 2195 | ] 2196 | }, 2197 | "metadata": {}, 2198 | "output_type": "display_data" 2199 | }, 2200 | { 2201 | "name": "stdout", 2202 | "output_type": "stream", 2203 | "text": [ 2204 | "\n", 2205 | "Epoch:49,val,Loss:[0.931],prec[77.0581],domain_acc[68.3082]\n", 2206 | "Epoch:50[200/469],Loss:[1.458,1.377],domain loss:[0.686,0.704],label loss:[0.077,0.034],prec[98.4375,98.9336],alpha:0.9871695041656494\n", 2207 | "Epoch:50[400/469],Loss:[1.392,1.396],domain loss:[0.658,0.690],label loss:[0.033,0.035],prec[99.2188,98.8575],alpha:0.9877018928527832\n" 2208 | ] 2209 | }, 2210 | { 2211 | "data": { 2212 | "application/vnd.jupyter.widget-view+json": { 2213 | "model_id": "75a5b26456be41cf8bfe887a15ef32fe", 2214 | "version_major": 2, 2215 | "version_minor": 0 2216 | }, 2217 | "text/plain": [ 2218 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 2219 | ] 2220 | }, 2221 | "metadata": {}, 2222 | "output_type": "display_data" 2223 | }, 2224 | { 2225 | "name": "stdout", 2226 | "output_type": "stream", 2227 | "text": [ 2228 | "\n", 2229 | "Epoch:50,val,Loss:[1.146],prec[75.4472],domain_acc[68.4752]\n", 2230 | "Epoch:51[200/469],Loss:[1.377,1.402],domain loss:[0.651,0.685],label loss:[0.009,0.030],prec[100.0000,99.0196],alpha:0.988383412361145\n", 2231 | "Epoch:51[400/469],Loss:[1.443,1.405],domain loss:[0.708,0.687],label loss:[0.087,0.033],prec[99.2188,98.9768],alpha:0.9888656735420227\n" 2232 | ] 2233 | }, 2234 | { 2235 | "data": { 2236 | "application/vnd.jupyter.widget-view+json": { 2237 | "model_id": "717d755e43ee4691b572ba3681a76a28", 2238 | "version_major": 2, 2239 | "version_minor": 0 2240 | }, 2241 | "text/plain": [ 2242 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 
2243 | ] 2244 | }, 2245 | "metadata": {}, 2246 | "output_type": "display_data" 2247 | }, 2248 | { 2249 | "name": "stdout", 2250 | "output_type": "stream", 2251 | "text": [ 2252 | "\n", 2253 | "Epoch:51,val,Loss:[1.015],prec[77.7914],domain_acc[68.4930]\n", 2254 | "Epoch:52[200/469],Loss:[1.310,1.381],domain loss:[0.632,0.679],label loss:[0.002,0.041],prec[100.0000,98.7461],alpha:0.9894830584526062\n", 2255 | "Epoch:52[400/469],Loss:[1.392,1.391],domain loss:[0.706,0.681],label loss:[0.038,0.038],prec[99.2188,98.8615],alpha:0.989919900894165\n" 2256 | ] 2257 | }, 2258 | { 2259 | "data": { 2260 | "application/vnd.jupyter.widget-view+json": { 2261 | "model_id": "cf8a242d8f42413fb9837a375de203b1", 2262 | "version_major": 2, 2263 | "version_minor": 0 2264 | }, 2265 | "text/plain": [ 2266 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 2267 | ] 2268 | }, 2269 | "metadata": {}, 2270 | "output_type": "display_data" 2271 | }, 2272 | { 2273 | "name": "stdout", 2274 | "output_type": "stream", 2275 | "text": [ 2276 | "\n", 2277 | "Epoch:52,val,Loss:[0.988],prec[77.6691],domain_acc[68.6668]\n", 2278 | "Epoch:53[200/469],Loss:[1.410,1.397],domain loss:[0.738,0.686],label loss:[0.030,0.025],prec[99.2188,99.2032],alpha:0.9904791116714478\n", 2279 | "Epoch:53[400/469],Loss:[1.373,1.397],domain loss:[0.686,0.685],label loss:[0.019,0.031],prec[99.2188,99.0295],alpha:0.9908747673034668\n" 2280 | ] 2281 | }, 2282 | { 2283 | "data": { 2284 | "application/vnd.jupyter.widget-view+json": { 2285 | "model_id": "28f937493dfb474dbef80ca2695ca7fa", 2286 | "version_major": 2, 2287 | "version_minor": 0 2288 | }, 2289 | "text/plain": [ 2290 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 2291 | ] 2292 | }, 2293 | "metadata": {}, 2294 | "output_type": "display_data" 2295 | }, 2296 | { 2297 | "name": "stdout", 2298 | "output_type": "stream", 2299 | "text": [ 2300 | "\n", 2301 | "Epoch:53,val,Loss:[1.069],prec[75.2805],domain_acc[68.4910]\n", 2302 | 
"Epoch:54[200/469],Loss:[1.453,1.401],domain loss:[0.719,0.685],label loss:[0.064,0.035],prec[98.4375,98.8438],alpha:0.9913812279701233\n", 2303 | "Epoch:54[400/469],Loss:[1.480,1.400],domain loss:[0.754,0.681],label loss:[0.097,0.040],prec[96.8750,98.7131],alpha:0.9917395710945129\n" 2304 | ] 2305 | }, 2306 | { 2307 | "data": { 2308 | "application/vnd.jupyter.widget-view+json": { 2309 | "model_id": "e8f7c1654f1848b5a2bbf12590124665", 2310 | "version_major": 2, 2311 | "version_minor": 0 2312 | }, 2313 | "text/plain": [ 2314 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 2315 | ] 2316 | }, 2317 | "metadata": {}, 2318 | "output_type": "display_data" 2319 | }, 2320 | { 2321 | "name": "stdout", 2322 | "output_type": "stream", 2323 | "text": [ 2324 | "\n", 2325 | "Epoch:54,val,Loss:[1.096],prec[75.6583],domain_acc[68.5237]\n", 2326 | "Epoch:55[200/469],Loss:[1.344,1.401],domain loss:[0.675,0.678],label loss:[0.019,0.035],prec[99.2188,98.9336],alpha:0.9921982288360596\n", 2327 | "Epoch:55[400/469],Loss:[1.430,1.401],domain loss:[0.656,0.682],label loss:[0.046,0.035],prec[97.6562,98.9611],alpha:0.9925227165222168\n" 2328 | ] 2329 | }, 2330 | { 2331 | "data": { 2332 | "application/vnd.jupyter.widget-view+json": { 2333 | "model_id": "eb2cf35e20cc4238a0e88be117d0474c", 2334 | "version_major": 2, 2335 | "version_minor": 0 2336 | }, 2337 | "text/plain": [ 2338 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 2339 | ] 2340 | }, 2341 | "metadata": {}, 2342 | "output_type": "display_data" 2343 | }, 2344 | { 2345 | "name": "stdout", 2346 | "output_type": "stream", 2347 | "text": [ 2348 | "\n", 2349 | "Epoch:55,val,Loss:[0.854],prec[78.0358],domain_acc[67.6546]\n", 2350 | "Epoch:56[200/469],Loss:[1.389,1.368],domain loss:[0.619,0.678],label loss:[0.037,0.028],prec[98.4375,99.1484],alpha:0.9929380416870117\n", 2351 | "Epoch:56[400/469],Loss:[1.428,1.387],domain loss:[0.670,0.681],label 
loss:[0.055,0.030],prec[97.6562,99.0392],alpha:0.9932318925857544\n" 2352 | ] 2353 | }, 2354 | { 2355 | "data": { 2356 | "application/vnd.jupyter.widget-view+json": { 2357 | "model_id": "06f7461b92244f899858041d49468775", 2358 | "version_major": 2, 2359 | "version_minor": 0 2360 | }, 2361 | "text/plain": [ 2362 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 2363 | ] 2364 | }, 2365 | "metadata": {}, 2366 | "output_type": "display_data" 2367 | }, 2368 | { 2369 | "name": "stdout", 2370 | "output_type": "stream", 2371 | "text": [ 2372 | "\n", 2373 | "Epoch:56,val,Loss:[0.887],prec[76.8803],domain_acc[67.8876]\n", 2374 | "Epoch:57[200/469],Loss:[1.395,1.386],domain loss:[0.659,0.680],label loss:[0.051,0.031],prec[97.6562,99.0156],alpha:0.9936079382896423\n", 2375 | "Epoch:57[400/469],Loss:[1.306,1.392],domain loss:[0.588,0.675],label loss:[0.044,0.037],prec[97.6562,98.8556],alpha:0.9938739538192749\n" 2376 | ] 2377 | }, 2378 | { 2379 | "data": { 2380 | "application/vnd.jupyter.widget-view+json": { 2381 | "model_id": "2fabcdae921c4590a13465262fbc7d04", 2382 | "version_major": 2, 2383 | "version_minor": 0 2384 | }, 2385 | "text/plain": [ 2386 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 2387 | ] 2388 | }, 2389 | "metadata": {}, 2390 | "output_type": "display_data" 2391 | }, 2392 | { 2393 | "name": "stdout", 2394 | "output_type": "stream", 2395 | "text": [ 2396 | "\n", 2397 | "Epoch:57,val,Loss:[0.955],prec[76.1138],domain_acc[67.2845]\n", 2398 | "Epoch:58[200/469],Loss:[1.356,1.386],domain loss:[0.668,0.685],label loss:[0.017,0.037],prec[99.2188,98.8789],alpha:0.9942144751548767\n", 2399 | "Epoch:58[400/469],Loss:[1.368,1.395],domain loss:[0.637,0.686],label loss:[0.007,0.042],prec[100.0000,98.7540],alpha:0.9944553375244141\n" 2400 | ] 2401 | }, 2402 | { 2403 | "data": { 2404 | "application/vnd.jupyter.widget-view+json": { 2405 | "model_id": "185bfcf21a5d4caca7df9408a76da7d5", 2406 | "version_major": 2, 2407 | 
"version_minor": 0 2408 | }, 2409 | "text/plain": [ 2410 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 2411 | ] 2412 | }, 2413 | "metadata": {}, 2414 | "output_type": "display_data" 2415 | }, 2416 | { 2417 | "name": "stdout", 2418 | "output_type": "stream", 2419 | "text": [ 2420 | "\n", 2421 | "Epoch:58,val,Loss:[1.009],prec[76.2693],domain_acc[66.7252]\n", 2422 | "Epoch:59[200/469],Loss:[1.336,1.382],domain loss:[0.621,0.679],label loss:[0.009,0.030],prec[100.0000,99.0860],alpha:0.9947636127471924\n", 2423 | "Epoch:59[400/469],Loss:[1.442,1.386],domain loss:[0.716,0.682],label loss:[0.098,0.031],prec[96.0938,98.9963],alpha:0.9949816465377808\n" 2424 | ] 2425 | }, 2426 | { 2427 | "data": { 2428 | "application/vnd.jupyter.widget-view+json": { 2429 | "model_id": "122df8c3d8f84f1bae68eb0afb9f0210", 2430 | "version_major": 2, 2431 | "version_minor": 0 2432 | }, 2433 | "text/plain": [ 2434 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 2435 | ] 2436 | }, 2437 | "metadata": {}, 2438 | "output_type": "display_data" 2439 | }, 2440 | { 2441 | "name": "stdout", 2442 | "output_type": "stream", 2443 | "text": [ 2444 | "\n", 2445 | "Epoch:59,val,Loss:[1.003],prec[77.1914],domain_acc[66.6594]\n", 2446 | "Epoch:60[200/469],Loss:[1.270,1.497],domain loss:[0.593,0.691],label loss:[0.047,0.128],prec[98.4375,96.3789],alpha:0.9952607154846191\n", 2447 | "Epoch:60[400/469],Loss:[1.309,1.406],domain loss:[0.609,0.666],label loss:[0.069,0.093],prec[99.2188,97.3869],alpha:0.9954581260681152\n" 2448 | ] 2449 | }, 2450 | { 2451 | "data": { 2452 | "application/vnd.jupyter.widget-view+json": { 2453 | "model_id": "8a4b5e577f90464ca3c6fcf1dd438be2", 2454 | "version_major": 2, 2455 | "version_minor": 0 2456 | }, 2457 | "text/plain": [ 2458 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 2459 | ] 2460 | }, 2461 | "metadata": {}, 2462 | "output_type": "display_data" 2463 | }, 2464 | { 2465 | "name": "stdout", 2466 | 
"output_type": "stream", 2467 | "text": [ 2468 | "\n", 2469 | "Epoch:60,val,Loss:[1.226],prec[76.1249],domain_acc[65.8433]\n", 2470 | "Epoch:61[200/469],Loss:[1.397,1.365],domain loss:[0.680,0.682],label loss:[0.019,0.040],prec[99.2188,98.8047],alpha:0.9957107305526733\n", 2471 | "Epoch:61[400/469],Loss:[1.417,1.376],domain loss:[0.691,0.673],label loss:[0.046,0.040],prec[97.6562,98.7735],alpha:0.99588942527771\n" 2472 | ] 2473 | }, 2474 | { 2475 | "data": { 2476 | "application/vnd.jupyter.widget-view+json": { 2477 | "model_id": "90699f1479a946b3914f214096922a76", 2478 | "version_major": 2, 2479 | "version_minor": 0 2480 | }, 2481 | "text/plain": [ 2482 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 2483 | ] 2484 | }, 2485 | "metadata": {}, 2486 | "output_type": "display_data" 2487 | }, 2488 | { 2489 | "name": "stdout", 2490 | "output_type": "stream", 2491 | "text": [ 2492 | "\n", 2493 | "Epoch:61,val,Loss:[0.785],prec[78.3135],domain_acc[65.3096]\n", 2494 | "Epoch:62[200/469],Loss:[1.357,1.382],domain loss:[0.637,0.682],label loss:[0.008,0.037],prec[100.0000,98.9414],alpha:0.9961181282997131\n", 2495 | "Epoch:62[400/469],Loss:[1.434,1.387],domain loss:[0.689,0.681],label loss:[0.047,0.036],prec[98.4375,98.8986],alpha:0.9962798953056335\n" 2496 | ] 2497 | }, 2498 | { 2499 | "data": { 2500 | "application/vnd.jupyter.widget-view+json": { 2501 | "model_id": "a7131fbd67894a359681a4d610ff9c3a", 2502 | "version_major": 2, 2503 | "version_minor": 0 2504 | }, 2505 | "text/plain": [ 2506 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 2507 | ] 2508 | }, 2509 | "metadata": {}, 2510 | "output_type": "display_data" 2511 | }, 2512 | { 2513 | "name": "stdout", 2514 | "output_type": "stream", 2515 | "text": [ 2516 | "\n", 2517 | "Epoch:62,val,Loss:[0.867],prec[78.3135],domain_acc[65.4645]\n", 2518 | "Epoch:63[200/469],Loss:[1.364,1.394],domain loss:[0.688,0.687],label 
loss:[0.007,0.026],prec[100.0000,99.1915],alpha:0.9964869022369385\n", 2519 | "Epoch:63[400/469],Loss:[1.405,1.397],domain loss:[0.641,0.683],label loss:[0.006,0.028],prec[100.0000,99.1154],alpha:0.9966332912445068\n" 2520 | ] 2521 | }, 2522 | { 2523 | "data": { 2524 | "application/vnd.jupyter.widget-view+json": { 2525 | "model_id": "8cb19c2831ca49e9a6e7dada323c9523", 2526 | "version_major": 2, 2527 | "version_minor": 0 2528 | }, 2529 | "text/plain": [ 2530 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 2531 | ] 2532 | }, 2533 | "metadata": {}, 2534 | "output_type": "display_data" 2535 | }, 2536 | { 2537 | "name": "stdout", 2538 | "output_type": "stream", 2539 | "text": [ 2540 | "\n", 2541 | "Epoch:63,val,Loss:[0.790],prec[79.5800],domain_acc[65.6970]\n", 2542 | "Epoch:64[200/469],Loss:[1.436,1.387],domain loss:[0.702,0.676],label loss:[0.099,0.024],prec[98.4375,99.2696],alpha:0.9968206882476807\n", 2543 | "Epoch:64[400/469],Loss:[1.397,1.393],domain loss:[0.609,0.684],label loss:[0.032,0.029],prec[99.2188,99.0939],alpha:0.9969531893730164\n" 2544 | ] 2545 | }, 2546 | { 2547 | "data": { 2548 | "application/vnd.jupyter.widget-view+json": { 2549 | "model_id": "404f11e67dd24c6c802b594653ae5ea7", 2550 | "version_major": 2, 2551 | "version_minor": 0 2552 | }, 2553 | "text/plain": [ 2554 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 2555 | ] 2556 | }, 2557 | "metadata": {}, 2558 | "output_type": "display_data" 2559 | }, 2560 | { 2561 | "name": "stdout", 2562 | "output_type": "stream", 2563 | "text": [ 2564 | "\n", 2565 | "Epoch:64,val,Loss:[0.901],prec[78.1691],domain_acc[65.8332]\n", 2566 | "Epoch:65[200/469],Loss:[1.374,1.396],domain loss:[0.671,0.680],label loss:[0.027,0.036],prec[98.4375,98.8477],alpha:0.9971228241920471\n", 2567 | "Epoch:65[400/469],Loss:[1.376,1.396],domain loss:[0.672,0.688],label loss:[0.026,0.029],prec[99.2188,99.0724],alpha:0.9972427487373352\n" 2568 | ] 2569 | }, 2570 | { 2571 | "data": { 2572 
| "application/vnd.jupyter.widget-view+json": { 2573 | "model_id": "1d9a6da8694f47ddbf9fe95770c32737", 2574 | "version_major": 2, 2575 | "version_minor": 0 2576 | }, 2577 | "text/plain": [ 2578 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 2579 | ] 2580 | }, 2581 | "metadata": {}, 2582 | "output_type": "display_data" 2583 | }, 2584 | { 2585 | "name": "stdout", 2586 | "output_type": "stream", 2587 | "text": [ 2588 | "\n", 2589 | "Epoch:65,val,Loss:[0.950],prec[77.6025],domain_acc[66.1125]\n", 2590 | "Epoch:66[200/469],Loss:[1.402,1.388],domain loss:[0.632,0.676],label loss:[0.050,0.025],prec[98.4375,99.1719],alpha:0.9973962306976318\n", 2591 | "Epoch:66[400/469],Loss:[1.427,1.393],domain loss:[0.701,0.683],label loss:[0.031,0.028],prec[99.2188,99.0822],alpha:0.9975048303604126\n" 2592 | ] 2593 | }, 2594 | { 2595 | "data": { 2596 | "application/vnd.jupyter.widget-view+json": { 2597 | "model_id": "d48c52a1f896423a927cc77bf3b4ef99", 2598 | "version_major": 2, 2599 | "version_minor": 0 2600 | }, 2601 | "text/plain": [ 2602 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 2603 | ] 2604 | }, 2605 | "metadata": {}, 2606 | "output_type": "display_data" 2607 | }, 2608 | { 2609 | "name": "stdout", 2610 | "output_type": "stream", 2611 | "text": [ 2612 | "\n", 2613 | "Epoch:66,val,Loss:[0.899],prec[78.4468],domain_acc[66.2366]\n", 2614 | "Epoch:67[200/469],Loss:[1.457,1.396],domain loss:[0.718,0.691],label loss:[0.081,0.026],prec[99.2188,99.1328],alpha:0.9976437091827393\n", 2615 | "Epoch:67[400/469],Loss:[1.398,1.399],domain loss:[0.689,0.687],label loss:[0.013,0.025],prec[99.2188,99.1818],alpha:0.9977419972419739\n" 2616 | ] 2617 | }, 2618 | { 2619 | "data": { 2620 | "application/vnd.jupyter.widget-view+json": { 2621 | "model_id": "cba6daac69bb4193a21f009726d2b848", 2622 | "version_major": 2, 2623 | "version_minor": 0 2624 | }, 2625 | "text/plain": [ 2626 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 
2627 | ] 2628 | }, 2629 | "metadata": {}, 2630 | "output_type": "display_data" 2631 | }, 2632 | { 2633 | "name": "stdout", 2634 | "output_type": "stream", 2635 | "text": [ 2636 | "\n", 2637 | "Epoch:67,val,Loss:[0.867],prec[79.5578],domain_acc[65.6435]\n", 2638 | "Epoch:68[200/469],Loss:[1.405,1.388],domain loss:[0.697,0.687],label loss:[0.055,0.024],prec[98.4375,99.2266],alpha:0.9978677034378052\n", 2639 | "Epoch:68[400/469],Loss:[1.365,1.386],domain loss:[0.635,0.686],label loss:[0.035,0.025],prec[99.2188,99.1682],alpha:0.9979566335678101\n" 2640 | ] 2641 | }, 2642 | { 2643 | "data": { 2644 | "application/vnd.jupyter.widget-view+json": { 2645 | "model_id": "45face23047d4da69e78621fdd7bac60", 2646 | "version_major": 2, 2647 | "version_minor": 0 2648 | }, 2649 | "text/plain": [ 2650 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 2651 | ] 2652 | }, 2653 | "metadata": {}, 2654 | "output_type": "display_data" 2655 | }, 2656 | { 2657 | "name": "stdout", 2658 | "output_type": "stream", 2659 | "text": [ 2660 | "\n", 2661 | "Epoch:68,val,Loss:[0.813],prec[79.6911],domain_acc[65.3914]\n", 2662 | "Epoch:69[200/469],Loss:[1.357,1.388],domain loss:[0.700,0.685],label loss:[0.019,0.028],prec[99.2188,99.1445],alpha:0.9980704188346863\n", 2663 | "Epoch:69[400/469],Loss:[1.396,1.390],domain loss:[0.639,0.686],label loss:[0.027,0.028],prec[99.2188,99.1329],alpha:0.9981509447097778\n" 2664 | ] 2665 | }, 2666 | { 2667 | "data": { 2668 | "application/vnd.jupyter.widget-view+json": { 2669 | "model_id": "b94700945ef94b2ba48ef44919257175", 2670 | "version_major": 2, 2671 | "version_minor": 0 2672 | }, 2673 | "text/plain": [ 2674 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 2675 | ] 2676 | }, 2677 | "metadata": {}, 2678 | "output_type": "display_data" 2679 | }, 2680 | { 2681 | "name": "stdout", 2682 | "output_type": "stream", 2683 | "text": [ 2684 | "\n", 2685 | "Epoch:69,val,Loss:[0.873],prec[78.9579],domain_acc[65.4430]\n", 2686 | 
"Epoch:70[200/469],Loss:[1.300,1.391],domain loss:[0.680,0.679],label loss:[0.012,0.032],prec[100.0000,99.0469],alpha:0.9982538819313049\n", 2687 | "Epoch:70[400/469],Loss:[1.335,1.384],domain loss:[0.706,0.676],label loss:[0.007,0.036],prec[100.0000,98.9396],alpha:0.9983267188072205\n" 2688 | ] 2689 | }, 2690 | { 2691 | "data": { 2692 | "application/vnd.jupyter.widget-view+json": { 2693 | "model_id": "5f1967e32d874edf89164ff89defa039", 2694 | "version_major": 2, 2695 | "version_minor": 0 2696 | }, 2697 | "text/plain": [ 2698 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 2699 | ] 2700 | }, 2701 | "metadata": {}, 2702 | "output_type": "display_data" 2703 | }, 2704 | { 2705 | "name": "stdout", 2706 | "output_type": "stream", 2707 | "text": [ 2708 | "\n", 2709 | "Epoch:70,val,Loss:[0.893],prec[78.1802],domain_acc[65.6498]\n", 2710 | "Epoch:71[200/469],Loss:[1.329,1.361],domain loss:[0.673,0.663],label loss:[0.051,0.040],prec[99.2188,98.7930],alpha:0.9984199404716492\n", 2711 | "Epoch:71[400/469],Loss:[1.400,1.362],domain loss:[0.656,0.666],label loss:[0.042,0.039],prec[98.4375,98.8165],alpha:0.9984858632087708\n" 2712 | ] 2713 | }, 2714 | { 2715 | "data": { 2716 | "application/vnd.jupyter.widget-view+json": { 2717 | "model_id": "13062029e85648b6b156266527c82d02", 2718 | "version_major": 2, 2719 | "version_minor": 0 2720 | }, 2721 | "text/plain": [ 2722 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 2723 | ] 2724 | }, 2725 | "metadata": {}, 2726 | "output_type": "display_data" 2727 | }, 2728 | { 2729 | "name": "stdout", 2730 | "output_type": "stream", 2731 | "text": [ 2732 | "\n", 2733 | "Epoch:71,val,Loss:[0.854],prec[79.0023],domain_acc[65.8379]\n", 2734 | "Epoch:72[200/469],Loss:[1.333,1.375],domain loss:[0.730,0.678],label loss:[0.006,0.037],prec[100.0000,98.9063],alpha:0.9985702037811279\n", 2735 | "Epoch:72[400/469],Loss:[1.376,1.378],domain loss:[0.673,0.674],label 
loss:[0.009,0.040],prec[100.0000,98.8087],alpha:0.9986298084259033\n" 2736 | ] 2737 | }, 2738 | { 2739 | "data": { 2740 | "application/vnd.jupyter.widget-view+json": { 2741 | "model_id": "956ce65bfd934c30b88b5f51f90f8f27", 2742 | "version_major": 2, 2743 | "version_minor": 0 2744 | }, 2745 | "text/plain": [ 2746 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 2747 | ] 2748 | }, 2749 | "metadata": {}, 2750 | "output_type": "display_data" 2751 | }, 2752 | { 2753 | "name": "stdout", 2754 | "output_type": "stream", 2755 | "text": [ 2756 | "\n", 2757 | "Epoch:72,val,Loss:[0.760],prec[80.7021],domain_acc[65.6793]\n", 2758 | "Epoch:73[200/469],Loss:[1.437,1.385],domain loss:[0.700,0.681],label loss:[0.075,0.035],prec[97.6562,98.9024],alpha:0.9987061619758606\n", 2759 | "Epoch:73[400/469],Loss:[1.447,1.381],domain loss:[0.768,0.674],label loss:[0.043,0.041],prec[99.2188,98.7443],alpha:0.9987601637840271\n" 2760 | ] 2761 | }, 2762 | { 2763 | "data": { 2764 | "application/vnd.jupyter.widget-view+json": { 2765 | "model_id": "946d1105c35845acad9f9afe85e9a93c", 2766 | "version_major": 2, 2767 | "version_minor": 0 2768 | }, 2769 | "text/plain": [ 2770 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 2771 | ] 2772 | }, 2773 | "metadata": {}, 2774 | "output_type": "display_data" 2775 | }, 2776 | { 2777 | "name": "stdout", 2778 | "output_type": "stream", 2779 | "text": [ 2780 | "\n", 2781 | "Epoch:73,val,Loss:[1.020],prec[76.4248],domain_acc[65.5121]\n", 2782 | "Epoch:74[200/469],Loss:[1.376,1.382],domain loss:[0.783,0.686],label loss:[0.005,0.036],prec[100.0000,98.8008],alpha:0.9988292455673218\n", 2783 | "Epoch:74[400/469],Loss:[1.359,1.377],domain loss:[0.609,0.676],label loss:[0.080,0.038],prec[98.4375,98.8127],alpha:0.9988780617713928\n" 2784 | ] 2785 | }, 2786 | { 2787 | "data": { 2788 | "application/vnd.jupyter.widget-view+json": { 2789 | "model_id": "d0a916d8694245b3a005733a4e4d5fc7", 2790 | "version_major": 2, 2791 | 
"version_minor": 0 2792 | }, 2793 | "text/plain": [ 2794 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 2795 | ] 2796 | }, 2797 | "metadata": {}, 2798 | "output_type": "display_data" 2799 | }, 2800 | { 2801 | "name": "stdout", 2802 | "output_type": "stream", 2803 | "text": [ 2804 | "\n", 2805 | "Epoch:74,val,Loss:[0.997],prec[76.9803],domain_acc[65.4960]\n", 2806 | "Epoch:75[200/469],Loss:[1.444,1.358],domain loss:[0.660,0.675],label loss:[0.033,0.034],prec[98.4375,98.9571],alpha:0.9989405870437622\n", 2807 | "Epoch:75[400/469],Loss:[1.429,1.362],domain loss:[0.699,0.666],label loss:[0.063,0.032],prec[98.4375,99.0275],alpha:0.9989847540855408\n" 2808 | ] 2809 | }, 2810 | { 2811 | "data": { 2812 | "application/vnd.jupyter.widget-view+json": { 2813 | "model_id": "45051f203db4447ba68d11951aed4e6b", 2814 | "version_major": 2, 2815 | "version_minor": 0 2816 | }, 2817 | "text/plain": [ 2818 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 2819 | ] 2820 | }, 2821 | "metadata": {}, 2822 | "output_type": "display_data" 2823 | }, 2824 | { 2825 | "name": "stdout", 2826 | "output_type": "stream", 2827 | "text": [ 2828 | "\n", 2829 | "Epoch:75,val,Loss:[0.864],prec[79.2467],domain_acc[65.3131]\n", 2830 | "Epoch:76[200/469],Loss:[1.359,1.377],domain loss:[0.662,0.679],label loss:[0.022,0.034],prec[99.2188,98.9375],alpha:0.9990413188934326\n", 2831 | "Epoch:76[400/469],Loss:[1.353,1.385],domain loss:[0.653,0.682],label loss:[0.010,0.037],prec[99.2188,98.8810],alpha:0.9990813732147217\n" 2832 | ] 2833 | }, 2834 | { 2835 | "data": { 2836 | "application/vnd.jupyter.widget-view+json": { 2837 | "model_id": "44c8d9a3a6a14d56a2caf74a96c4b48a", 2838 | "version_major": 2, 2839 | "version_minor": 0 2840 | }, 2841 | "text/plain": [ 2842 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 2843 | ] 2844 | }, 2845 | "metadata": {}, 2846 | "output_type": "display_data" 2847 | }, 2848 | { 2849 | "name": "stdout", 2850 | 
"output_type": "stream", 2851 | "text": [ 2852 | "\n", 2853 | "Epoch:76,val,Loss:[1.154],prec[75.8027],domain_acc[65.4373]\n", 2854 | "Epoch:77[200/469],Loss:[1.382,1.377],domain loss:[0.690,0.679],label loss:[0.040,0.032],prec[99.2188,98.9570],alpha:0.999132513999939\n", 2855 | "Epoch:77[400/469],Loss:[1.426,1.386],domain loss:[0.677,0.684],label loss:[0.049,0.033],prec[98.4375,98.9728],alpha:0.9991687536239624\n" 2856 | ] 2857 | }, 2858 | { 2859 | "data": { 2860 | "application/vnd.jupyter.widget-view+json": { 2861 | "model_id": "15564fbab8c54a97a97462b184e84870", 2862 | "version_major": 2, 2863 | "version_minor": 0 2864 | }, 2865 | "text/plain": [ 2866 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 2867 | ] 2868 | }, 2869 | "metadata": {}, 2870 | "output_type": "display_data" 2871 | }, 2872 | { 2873 | "name": "stdout", 2874 | "output_type": "stream", 2875 | "text": [ 2876 | "\n", 2877 | "Epoch:77,val,Loss:[0.947],prec[78.2580],domain_acc[65.4437]\n", 2878 | "Epoch:78[200/469],Loss:[1.463,1.382],domain loss:[0.666,0.675],label loss:[0.069,0.033],prec[96.8750,99.0469],alpha:0.9992150664329529\n", 2879 | "Epoch:78[400/469],Loss:[1.367,1.382],domain loss:[0.702,0.677],label loss:[0.002,0.031],prec[100.0000,99.0432],alpha:0.9992477893829346\n" 2880 | ] 2881 | }, 2882 | { 2883 | "data": { 2884 | "application/vnd.jupyter.widget-view+json": { 2885 | "model_id": "446fa989b5fa4d16848063f5f05a6aff", 2886 | "version_major": 2, 2887 | "version_minor": 0 2888 | }, 2889 | "text/plain": [ 2890 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 2891 | ] 2892 | }, 2893 | "metadata": {}, 2894 | "output_type": "display_data" 2895 | }, 2896 | { 2897 | "name": "stdout", 2898 | "output_type": "stream", 2899 | "text": [ 2900 | "\n", 2901 | "Epoch:78,val,Loss:[1.023],prec[77.1137],domain_acc[65.3608]\n", 2902 | "Epoch:79[200/469],Loss:[1.536,1.374],domain loss:[0.686,0.667],label 
loss:[0.162,0.035],prec[97.6562,98.9766],alpha:0.9992896914482117\n", 2903 | "Epoch:79[400/469],Loss:[1.404,1.374],domain loss:[0.604,0.674],label loss:[0.054,0.034],prec[98.4375,98.9474],alpha:0.9993193745613098\n" 2904 | ] 2905 | }, 2906 | { 2907 | "data": { 2908 | "application/vnd.jupyter.widget-view+json": { 2909 | "model_id": "32047388425d44d7a7977af5d469b9a1", 2910 | "version_major": 2, 2911 | "version_minor": 0 2912 | }, 2913 | "text/plain": [ 2914 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 2915 | ] 2916 | }, 2917 | "metadata": {}, 2918 | "output_type": "display_data" 2919 | }, 2920 | { 2921 | "name": "stdout", 2922 | "output_type": "stream", 2923 | "text": [ 2924 | "\n", 2925 | "Epoch:79,val,Loss:[0.971],prec[77.9580],domain_acc[65.3259]\n", 2926 | "Epoch:80[200/469],Loss:[1.495,1.364],domain loss:[0.709,0.660],label loss:[0.039,0.032],prec[99.2188,99.0274],alpha:0.999357283115387\n", 2927 | "Epoch:80[400/469],Loss:[1.362,1.371],domain loss:[0.709,0.667],label loss:[0.004,0.036],prec[100.0000,98.8966],alpha:0.9993841052055359\n" 2928 | ] 2929 | }, 2930 | { 2931 | "data": { 2932 | "application/vnd.jupyter.widget-view+json": { 2933 | "model_id": "653e04cf24bc4b39b9c789b485ef8b1e", 2934 | "version_major": 2, 2935 | "version_minor": 0 2936 | }, 2937 | "text/plain": [ 2938 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 2939 | ] 2940 | }, 2941 | "metadata": {}, 2942 | "output_type": "display_data" 2943 | }, 2944 | { 2945 | "name": "stdout", 2946 | "output_type": "stream", 2947 | "text": [ 2948 | "\n", 2949 | "Epoch:80,val,Loss:[0.902],prec[78.4357],domain_acc[65.2594]\n", 2950 | "Epoch:81[200/469],Loss:[1.364,1.356],domain loss:[0.604,0.656],label loss:[0.039,0.040],prec[98.4375,98.7539],alpha:0.9994184374809265\n", 2951 | "Epoch:81[400/469],Loss:[1.325,1.351],domain loss:[0.713,0.659],label loss:[0.009,0.036],prec[100.0000,98.8459],alpha:0.9994426965713501\n" 2952 | ] 2953 | }, 2954 | { 2955 | "data": { 2956 
| "application/vnd.jupyter.widget-view+json": { 2957 | "model_id": "37cfbfbc29384a45b1868b3aed0216e8", 2958 | "version_major": 2, 2959 | "version_minor": 0 2960 | }, 2961 | "text/plain": [ 2962 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 2963 | ] 2964 | }, 2965 | "metadata": {}, 2966 | "output_type": "display_data" 2967 | }, 2968 | { 2969 | "name": "stdout", 2970 | "output_type": "stream", 2971 | "text": [ 2972 | "\n", 2973 | "Epoch:81,val,Loss:[0.820],prec[79.7023],domain_acc[65.2285]\n", 2974 | "Epoch:82[200/469],Loss:[1.362,1.366],domain loss:[0.671,0.674],label loss:[0.043,0.037],prec[98.4375,98.8555],alpha:0.9994737505912781\n", 2975 | "Epoch:82[400/469],Loss:[1.322,1.373],domain loss:[0.675,0.672],label loss:[0.035,0.040],prec[99.2188,98.7677],alpha:0.9994957447052002\n" 2976 | ] 2977 | }, 2978 | { 2979 | "data": { 2980 | "application/vnd.jupyter.widget-view+json": { 2981 | "model_id": "764882d0b8ea44c0969ced0a8ef8af0d", 2982 | "version_major": 2, 2983 | "version_minor": 0 2984 | }, 2985 | "text/plain": [ 2986 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 2987 | ] 2988 | }, 2989 | "metadata": {}, 2990 | "output_type": "display_data" 2991 | }, 2992 | { 2993 | "name": "stdout", 2994 | "output_type": "stream", 2995 | "text": [ 2996 | "\n", 2997 | "Epoch:82,val,Loss:[0.930],prec[78.3580],domain_acc[65.1620]\n", 2998 | "Epoch:83[200/469],Loss:[1.405,1.348],domain loss:[0.659,0.657],label loss:[0.045,0.038],prec[97.6562,98.8945],alpha:0.9995238184928894\n", 2999 | "Epoch:83[400/469],Loss:[1.356,1.364],domain loss:[0.645,0.670],label loss:[0.008,0.041],prec[100.0000,98.7911],alpha:0.9995437264442444\n" 3000 | ] 3001 | }, 3002 | { 3003 | "data": { 3004 | "application/vnd.jupyter.widget-view+json": { 3005 | "model_id": "b90d988b990a45b792b413121bc23893", 3006 | "version_major": 2, 3007 | "version_minor": 0 3008 | }, 3009 | "text/plain": [ 3010 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 
3011 | ] 3012 | }, 3013 | "metadata": {}, 3014 | "output_type": "display_data" 3015 | }, 3016 | { 3017 | "name": "stdout", 3018 | "output_type": "stream", 3019 | "text": [ 3020 | "\n", 3021 | "Epoch:83,val,Loss:[0.979],prec[77.4247],domain_acc[65.1898]\n", 3022 | "Epoch:84[200/469],Loss:[1.335,1.372],domain loss:[0.679,0.669],label loss:[0.036,0.038],prec[99.2188,98.7734],alpha:0.9995691180229187\n", 3023 | "Epoch:84[400/469],Loss:[1.263,1.377],domain loss:[0.619,0.669],label loss:[0.006,0.048],prec[100.0000,98.5899],alpha:0.9995871186256409\n" 3024 | ] 3025 | }, 3026 | { 3027 | "data": { 3028 | "application/vnd.jupyter.widget-view+json": { 3029 | "model_id": "d8417f67f6d74c1690e30e589a0f28ed", 3030 | "version_major": 2, 3031 | "version_minor": 0 3032 | }, 3033 | "text/plain": [ 3034 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 3035 | ] 3036 | }, 3037 | "metadata": {}, 3038 | "output_type": "display_data" 3039 | }, 3040 | { 3041 | "name": "stdout", 3042 | "output_type": "stream", 3043 | "text": [ 3044 | "\n", 3045 | "Epoch:84,val,Loss:[0.691],prec[81.2243],domain_acc[65.2196]\n", 3046 | "Epoch:85[200/469],Loss:[1.285,1.403],domain loss:[0.634,0.684],label loss:[0.014,0.048],prec[100.0000,98.6211],alpha:0.9996101260185242\n", 3047 | "Epoch:85[400/469],Loss:[1.361,1.395],domain loss:[0.666,0.676],label loss:[0.042,0.051],prec[99.2188,98.5098],alpha:0.9996263980865479\n" 3048 | ] 3049 | }, 3050 | { 3051 | "data": { 3052 | "application/vnd.jupyter.widget-view+json": { 3053 | "model_id": "0d097072bbe04a66bd493c78c7a42388", 3054 | "version_major": 2, 3055 | "version_minor": 0 3056 | }, 3057 | "text/plain": [ 3058 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 3059 | ] 3060 | }, 3061 | "metadata": {}, 3062 | "output_type": "display_data" 3063 | }, 3064 | { 3065 | "name": "stdout", 3066 | "output_type": "stream", 3067 | "text": [ 3068 | "\n", 3069 | "Epoch:85,val,Loss:[0.906],prec[77.6136],domain_acc[65.2616]\n", 3070 | 
"Epoch:86[200/469],Loss:[1.423,1.373],domain loss:[0.752,0.662],label loss:[0.048,0.034],prec[98.4375,98.9531],alpha:0.9996472001075745\n", 3071 | "Epoch:86[400/469],Loss:[1.276,1.374],domain loss:[0.674,0.667],label loss:[0.008,0.038],prec[100.0000,98.8243],alpha:0.999661922454834\n" 3072 | ] 3073 | }, 3074 | { 3075 | "data": { 3076 | "application/vnd.jupyter.widget-view+json": { 3077 | "model_id": "62617f89be7c407082e4d942b630ff81", 3078 | "version_major": 2, 3079 | "version_minor": 0 3080 | }, 3081 | "text/plain": [ 3082 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 3083 | ] 3084 | }, 3085 | "metadata": {}, 3086 | "output_type": "display_data" 3087 | }, 3088 | { 3089 | "name": "stdout", 3090 | "output_type": "stream", 3091 | "text": [ 3092 | "\n", 3093 | "Epoch:86,val,Loss:[0.973],prec[78.2135],domain_acc[65.2598]\n", 3094 | "Epoch:87[200/469],Loss:[1.466,1.367],domain loss:[0.668,0.681],label loss:[0.060,0.030],prec[98.4375,99.1094],alpha:0.9996808171272278\n", 3095 | "Epoch:87[400/469],Loss:[1.342,1.375],domain loss:[0.646,0.678],label loss:[0.023,0.032],prec[99.2188,99.0392],alpha:0.9996941089630127\n" 3096 | ] 3097 | }, 3098 | { 3099 | "data": { 3100 | "application/vnd.jupyter.widget-view+json": { 3101 | "model_id": "70133f986c8c47bea735811fa2073cd5", 3102 | "version_major": 2, 3103 | "version_minor": 0 3104 | }, 3105 | "text/plain": [ 3106 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 3107 | ] 3108 | }, 3109 | "metadata": {}, 3110 | "output_type": "display_data" 3111 | }, 3112 | { 3113 | "name": "stdout", 3114 | "output_type": "stream", 3115 | "text": [ 3116 | "\n", 3117 | "Epoch:87,val,Loss:[0.862],prec[79.2467],domain_acc[65.5208]\n", 3118 | "Epoch:88[200/469],Loss:[1.309,1.361],domain loss:[0.670,0.666],label loss:[0.002,0.029],prec[100.0000,99.0899],alpha:0.9997111558914185\n", 3119 | "Epoch:88[400/469],Loss:[1.382,1.365],domain loss:[0.699,0.670],label 
loss:[0.050,0.033],prec[99.2188,98.9571],alpha:0.9997231960296631\n" 3120 | ] 3121 | }, 3122 | { 3123 | "data": { 3124 | "application/vnd.jupyter.widget-view+json": { 3125 | "model_id": "ad58ad4b8a3943ab9be436bd31dc9af2", 3126 | "version_major": 2, 3127 | "version_minor": 0 3128 | }, 3129 | "text/plain": [ 3130 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 3131 | ] 3132 | }, 3133 | "metadata": {}, 3134 | "output_type": "display_data" 3135 | }, 3136 | { 3137 | "name": "stdout", 3138 | "output_type": "stream", 3139 | "text": [ 3140 | "\n", 3141 | "Epoch:88,val,Loss:[0.873],prec[79.3134],domain_acc[65.5149]\n", 3142 | "Epoch:89[200/469],Loss:[1.396,1.363],domain loss:[0.728,0.671],label loss:[0.025,0.033],prec[99.2188,98.9649],alpha:0.9997386336326599\n", 3143 | "Epoch:89[400/469],Loss:[1.380,1.367],domain loss:[0.740,0.669],label loss:[0.033,0.035],prec[98.4375,98.9025],alpha:0.9997495412826538\n" 3144 | ] 3145 | }, 3146 | { 3147 | "data": { 3148 | "application/vnd.jupyter.widget-view+json": { 3149 | "model_id": "8b8dcc6f40094ede824a792d446811fd", 3150 | "version_major": 2, 3151 | "version_minor": 0 3152 | }, 3153 | "text/plain": [ 3154 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 3155 | ] 3156 | }, 3157 | "metadata": {}, 3158 | "output_type": "display_data" 3159 | }, 3160 | { 3161 | "name": "stdout", 3162 | "output_type": "stream", 3163 | "text": [ 3164 | "\n", 3165 | "Epoch:89,val,Loss:[0.965],prec[77.7691],domain_acc[65.4788]\n", 3166 | "Epoch:90[200/469],Loss:[1.369,1.447],domain loss:[0.689,0.692],label loss:[0.042,0.091],prec[98.4375,97.8516],alpha:0.9997634887695312\n", 3167 | "Epoch:90[400/469],Loss:[1.484,1.404],domain loss:[0.628,0.680],label loss:[0.123,0.069],prec[96.8750,98.2248],alpha:0.999773383140564\n" 3168 | ] 3169 | }, 3170 | { 3171 | "data": { 3172 | "application/vnd.jupyter.widget-view+json": { 3173 | "model_id": "3f8a36d7fe1a40be83e0f0491d267b3e", 3174 | "version_major": 2, 3175 | 
"version_minor": 0 3176 | }, 3177 | "text/plain": [ 3178 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 3179 | ] 3180 | }, 3181 | "metadata": {}, 3182 | "output_type": "display_data" 3183 | }, 3184 | { 3185 | "name": "stdout", 3186 | "output_type": "stream", 3187 | "text": [ 3188 | "\n", 3189 | "Epoch:90,val,Loss:[0.894],prec[78.2913],domain_acc[65.5167]\n", 3190 | "Epoch:91[200/469],Loss:[1.318,1.337],domain loss:[0.655,0.655],label loss:[0.080,0.043],prec[98.4375,98.7071],alpha:0.9997860193252563\n", 3191 | "Epoch:91[400/469],Loss:[1.217,1.349],domain loss:[0.517,0.654],label loss:[0.026,0.056],prec[99.2188,98.4826],alpha:0.9997949600219727\n" 3192 | ] 3193 | }, 3194 | { 3195 | "data": { 3196 | "application/vnd.jupyter.widget-view+json": { 3197 | "model_id": "12bb71af000a4449bb7beec21ffde87a", 3198 | "version_major": 2, 3199 | "version_minor": 0 3200 | }, 3201 | "text/plain": [ 3202 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 3203 | ] 3204 | }, 3205 | "metadata": {}, 3206 | "output_type": "display_data" 3207 | }, 3208 | { 3209 | "name": "stdout", 3210 | "output_type": "stream", 3211 | "text": [ 3212 | "\n", 3213 | "Epoch:91,val,Loss:[1.070],prec[76.6359],domain_acc[65.2487]\n", 3214 | "Epoch:92[200/469],Loss:[1.336,1.331],domain loss:[0.750,0.685],label loss:[0.038,0.035],prec[99.2188,98.8243],alpha:0.9998064041137695\n", 3215 | "Epoch:92[400/469],Loss:[1.452,1.349],domain loss:[0.637,0.669],label loss:[0.054,0.038],prec[99.2188,98.8205],alpha:0.9998144507408142\n" 3216 | ] 3217 | }, 3218 | { 3219 | "data": { 3220 | "application/vnd.jupyter.widget-view+json": { 3221 | "model_id": "cc9a7e9f37b34e9594d342d767e057cb", 3222 | "version_major": 2, 3223 | "version_minor": 0 3224 | }, 3225 | "text/plain": [ 3226 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 3227 | ] 3228 | }, 3229 | "metadata": {}, 3230 | "output_type": "display_data" 3231 | }, 3232 | { 3233 | "name": "stdout", 3234 | 
"output_type": "stream", 3235 | "text": [ 3236 | "\n", 3237 | "Epoch:92,val,Loss:[0.914],prec[78.7024],domain_acc[65.2878]\n", 3238 | "Epoch:93[200/469],Loss:[1.397,1.375],domain loss:[0.672,0.673],label loss:[0.033,0.032],prec[99.2188,98.9805],alpha:0.9998248219490051\n", 3239 | "Epoch:93[400/469],Loss:[1.362,1.374],domain loss:[0.587,0.669],label loss:[0.017,0.032],prec[99.2188,99.0197],alpha:0.9998320937156677\n" 3240 | ] 3241 | }, 3242 | { 3243 | "data": { 3244 | "application/vnd.jupyter.widget-view+json": { 3245 | "model_id": "f21a033fd9754e90933c772be66522fe", 3246 | "version_major": 2, 3247 | "version_minor": 0 3248 | }, 3249 | "text/plain": [ 3250 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 3251 | ] 3252 | }, 3253 | "metadata": {}, 3254 | "output_type": "display_data" 3255 | }, 3256 | { 3257 | "name": "stdout", 3258 | "output_type": "stream", 3259 | "text": [ 3260 | "\n", 3261 | "Epoch:93,val,Loss:[0.766],prec[81.4798],domain_acc[65.0748]\n", 3262 | "Epoch:94[200/469],Loss:[1.348,1.386],domain loss:[0.710,0.697],label loss:[0.028,0.024],prec[98.4375,99.2578],alpha:0.9998414516448975\n", 3263 | "Epoch:94[400/469],Loss:[1.477,1.380],domain loss:[0.746,0.682],label loss:[0.016,0.029],prec[99.2188,99.0763],alpha:0.9998480677604675\n" 3264 | ] 3265 | }, 3266 | { 3267 | "data": { 3268 | "application/vnd.jupyter.widget-view+json": { 3269 | "model_id": "c19172b9819540c9b8efa78690237a08", 3270 | "version_major": 2, 3271 | "version_minor": 0 3272 | }, 3273 | "text/plain": [ 3274 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 3275 | ] 3276 | }, 3277 | "metadata": {}, 3278 | "output_type": "display_data" 3279 | }, 3280 | { 3281 | "name": "stdout", 3282 | "output_type": "stream", 3283 | "text": [ 3284 | "\n", 3285 | "Epoch:94,val,Loss:[0.768],prec[81.1021],domain_acc[65.1909]\n", 3286 | "Epoch:95[200/469],Loss:[1.321,1.361],domain loss:[0.658,0.672],label 
loss:[0.009,0.027],prec[100.0000,99.1211],alpha:0.9998565316200256\n", 3287 | "Epoch:95[400/469],Loss:[1.347,1.373],domain loss:[0.745,0.678],label loss:[0.021,0.033],prec[99.2188,98.9611],alpha:0.999862551689148\n" 3288 | ] 3289 | }, 3290 | { 3291 | "data": { 3292 | "application/vnd.jupyter.widget-view+json": { 3293 | "model_id": "bcfdd02f29cd484bab08bc47a8681bee", 3294 | "version_major": 2, 3295 | "version_minor": 0 3296 | }, 3297 | "text/plain": [ 3298 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 3299 | ] 3300 | }, 3301 | "metadata": {}, 3302 | "output_type": "display_data" 3303 | }, 3304 | { 3305 | "name": "stdout", 3306 | "output_type": "stream", 3307 | "text": [ 3308 | "\n", 3309 | "Epoch:95,val,Loss:[0.907],prec[78.7357],domain_acc[65.2230]\n", 3310 | "Epoch:96[200/469],Loss:[1.330,1.350],domain loss:[0.639,0.667],label loss:[0.084,0.035],prec[99.2188,98.9063],alpha:0.9998701810836792\n", 3311 | "Epoch:96[400/469],Loss:[1.307,1.358],domain loss:[0.583,0.666],label loss:[0.010,0.038],prec[99.2188,98.8381],alpha:0.9998756051063538\n" 3312 | ] 3313 | }, 3314 | { 3315 | "data": { 3316 | "application/vnd.jupyter.widget-view+json": { 3317 | "model_id": "9956822ea2304751bc4d50c8da44a1e9", 3318 | "version_major": 2, 3319 | "version_minor": 0 3320 | }, 3321 | "text/plain": [ 3322 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 3323 | ] 3324 | }, 3325 | "metadata": {}, 3326 | "output_type": "display_data" 3327 | }, 3328 | { 3329 | "name": "stdout", 3330 | "output_type": "stream", 3331 | "text": [ 3332 | "\n", 3333 | "Epoch:96,val,Loss:[0.920],prec[78.0469],domain_acc[65.1908]\n", 3334 | "Epoch:97[200/469],Loss:[1.431,1.390],domain loss:[0.698,0.675],label loss:[0.130,0.059],prec[97.6562,98.2735],alpha:0.9998825788497925\n", 3335 | "Epoch:97[400/469],Loss:[1.409,1.375],domain loss:[0.713,0.675],label loss:[0.012,0.052],prec[99.2188,98.4317],alpha:0.9998874664306641\n" 3336 | ] 3337 | }, 3338 | { 3339 | "data": { 3340 | 
"application/vnd.jupyter.widget-view+json": { 3341 | "model_id": "0f54c00c369a4b168c3072062709cf0a", 3342 | "version_major": 2, 3343 | "version_minor": 0 3344 | }, 3345 | "text/plain": [ 3346 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 3347 | ] 3348 | }, 3349 | "metadata": {}, 3350 | "output_type": "display_data" 3351 | }, 3352 | { 3353 | "name": "stdout", 3354 | "output_type": "stream", 3355 | "text": [ 3356 | "\n", 3357 | "Epoch:97,val,Loss:[0.678],prec[81.9020],domain_acc[65.1945]\n", 3358 | "Epoch:98[200/469],Loss:[1.335,1.341],domain loss:[0.721,0.676],label loss:[0.021,0.044],prec[99.2188,98.6758],alpha:0.9998937249183655\n", 3359 | "Epoch:98[400/469],Loss:[1.369,1.345],domain loss:[0.672,0.656],label loss:[0.025,0.042],prec[99.2188,98.7579],alpha:0.9998981952667236\n" 3360 | ] 3361 | }, 3362 | { 3363 | "data": { 3364 | "application/vnd.jupyter.widget-view+json": { 3365 | "model_id": "a35cc82653be4418b5dbfb04a02d6c6f", 3366 | "version_major": 2, 3367 | "version_minor": 0 3368 | }, 3369 | "text/plain": [ 3370 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 3371 | ] 3372 | }, 3373 | "metadata": {}, 3374 | "output_type": "display_data" 3375 | }, 3376 | { 3377 | "name": "stdout", 3378 | "output_type": "stream", 3379 | "text": [ 3380 | "\n", 3381 | "Epoch:98,val,Loss:[0.805],prec[80.4911],domain_acc[65.1767]\n", 3382 | "Epoch:99[200/469],Loss:[1.340,1.367],domain loss:[0.604,0.682],label loss:[0.014,0.044],prec[99.2188,98.6680],alpha:0.9999038577079773\n", 3383 | "Epoch:99[400/469],Loss:[1.351,1.372],domain loss:[0.641,0.674],label loss:[0.075,0.044],prec[98.4375,98.6486],alpha:0.9999078512191772\n" 3384 | ] 3385 | }, 3386 | { 3387 | "data": { 3388 | "application/vnd.jupyter.widget-view+json": { 3389 | "model_id": "aad24d1a484041cbbd823349853f346a", 3390 | "version_major": 2, 3391 | "version_minor": 0 3392 | }, 3393 | "text/plain": [ 3394 | "HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))" 
3395 | ] 3396 | }, 3397 | "metadata": {}, 3398 | "output_type": "display_data" 3399 | }, 3400 | { 3401 | "name": "stdout", 3402 | "output_type": "stream", 3403 | "text": [ 3404 | "\n", 3405 | "Epoch:99,val,Loss:[0.737],prec[80.9354],domain_acc[65.1802]\n" 3406 | ] 3407 | } 3408 | ], 3409 | "source": [ 3410 | "train_loss=AverageMeter()\n", 3411 | "train_domain_loss=AverageMeter()\n", 3412 | "train_task_loss=AverageMeter()\n", 3413 | "test_loss=AverageMeter()\n", 3414 | "test_top1=AverageMeter()\n", 3415 | "test_domain_acc=AverageMeter()\n", 3416 | "train_top1=AverageMeter()\n", 3417 | "train_cnt=AverageMeter()\n", 3418 | "\n", 3419 | "print_freq=200\n", 3420 | "domain_model=DANN()\n", 3421 | "domain_model.cuda()\n", 3422 | "domain_loss=nn.CrossEntropyLoss()\n", 3423 | "task_loss=nn.CrossEntropyLoss()\n", 3424 | "lr=0.001\n", 3425 | "optimizer=Adam(domain_model.parameters(),lr=lr)\n", 3426 | "epochs=100\n", 3427 | "for epoch in range(epochs):\n", 3428 | "\n", 3429 | " #lr=adjust_learning_rate(optimizer,epoch)\n", 3430 | " writer.add_scalar(\"lr\",lr,epoch)\n", 3431 | " train_loss.reset()\n", 3432 | " train_domain_loss.reset()\n", 3433 | " train_task_loss.reset()\n", 3434 | " train_top1.reset()\n", 3435 | " train_cnt.reset()\n", 3436 | " test_top1.reset()\n", 3437 | " test_loss.reset()\n", 3438 | " for source,target in zip(train_dl,train_m_dl):\n", 3439 | " train_cnt.update(images.size(0),1)\n", 3440 | " p = float(train_cnt.count + epoch * len(train_dl)) / (epochs *len(train_dl))\n", 3441 | " alpha = torch.tensor(2. / (1. 
+ np.exp(-10 * p)) - 1)\n", 3442 | " src_imgs=source[0].cuda()\n", 3443 | " src_labels=source[1].cuda()\n", 3444 | " dst_imgs=target[0].cuda()\n", 3445 | " optimizer.zero_grad()\n", 3446 | " \n", 3447 | " src_predict,src_domains=domain_model(src_imgs,alpha)\n", 3448 | " src_label_loss=task_loss(src_predict,src_labels)\n", 3449 | " src_domain_loss=domain_loss(src_domains,torch.ones(len(src_domains)).long().cuda())\n", 3450 | " \n", 3451 | " _,dst_domains=domain_model(dst_imgs,alpha)\n", 3452 | " dst_domain_loss=domain_loss(dst_domains,torch.zeros(len(dst_domains)).long().cuda())\n", 3453 | " \n", 3454 | " losses=src_label_loss+src_domain_loss+dst_domain_loss\n", 3455 | " \n", 3456 | " train_loss.update(losses.data,images.size(0))\n", 3457 | " train_domain_loss.update(dst_domain_loss.data,images.size(0))\n", 3458 | " train_task_loss.update(src_label_loss.data,images.size(0))\n", 3459 | " top1=accuracy(src_predict.data,src_labels,topk=(1,))[0]\n", 3460 | " train_top1.update(top1,images.size(0))\n", 3461 | " \n", 3462 | " losses.backward()\n", 3463 | " optimizer.step()\n", 3464 | " if train_cnt.count%print_freq==0:\n", 3465 | " print(\"Epoch:{}[{}/{}],Loss:[{:.3f},{:.3f}],domain loss:[{:.3f},{:.3f}],label loss:[{:.3f},{:.3f}],prec[{:.4f},{:.4f}],alpha:{}\".format(\n", 3466 | " epoch,train_cnt.count,len(train_dl),train_loss.val,train_loss.avg,\n", 3467 | " train_domain_loss.val,train_domain_loss.avg,\n", 3468 | " train_task_loss.val,train_task_loss.avg,train_top1.val,train_top1.avg,alpha))\n", 3469 | " for images,labels in tqdm(test_m_dl):\n", 3470 | " images=images.cuda()\n", 3471 | " labels=labels.cuda()\n", 3472 | " predicts,domains=domain_model(images,0)\n", 3473 | " losses=task_loss(predicts,labels)\n", 3474 | " test_loss.update(losses.data,images.size(0))\n", 3475 | " top1=accuracy(predicts.data,labels,topk=(1,))[0]\n", 3476 | " domain_acc=accuracy(domains.data,torch.zeros(len(domains)).long().cuda(),topk=(1,))[0]\n", 3477 | " 
test_top1.update(top1,images.size(0))\n", 3478 | " test_domain_acc.update(domain_acc,images.size(0))\n", 3479 | " print(\"Epoch:{},val,Loss:[{:.3f}],prec[{:.4f}],domain_acc[{:.4f}]\".format(epoch,test_loss.avg,test_top1.avg,test_domain_acc.avg))\n", 3480 | " writer.add_scalar(\"train_loss\",train_loss.avg,epoch)\n", 3481 | " writer.add_scalar(\"test_loss\",test_loss.avg,epoch) \n", 3482 | " writer.add_scalar(\"train_top1\",train_top1.avg,epoch)\n", 3483 | " writer.add_scalar(\"test_top1\",test_top1.avg,epoch)\n", 3484 | " writer.add_scalar(\"test_domain\",test_domain_acc.avg,epoch)" 3485 | ] 3486 | }, 3487 | { 3488 | "cell_type": "code", 3489 | "execution_count": null, 3490 | "metadata": {}, 3491 | "outputs": [], 3492 | "source": [] 3493 | } 3494 | ], 3495 | "metadata": { 3496 | "kernelspec": { 3497 | "display_name": "Python 3", 3498 | "language": "python", 3499 | "name": "python3" 3500 | }, 3501 | "language_info": { 3502 | "codemirror_mode": { 3503 | "name": "ipython", 3504 | "version": 3 3505 | }, 3506 | "file_extension": ".py", 3507 | "mimetype": "text/x-python", 3508 | "name": "python", 3509 | "nbconvert_exporter": "python", 3510 | "pygments_lexer": "ipython3", 3511 | "version": "3.6.10" 3512 | } 3513 | }, 3514 | "nbformat": 4, 3515 | "nbformat_minor": 4 3516 | } 3517 | -------------------------------------------------------------------------------- /DANN/experiment1/generate_data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from matplotlib import pyplot as plt 3 | import random 4 | def generate_data(l,N,): 5 | theta=np.random.random(N)*np.pi 6 | p1=l+np.random.random(N) 7 | p2=l-np.random.random(N) 8 | for i in range(N): 9 | 10 | 11 | if __name__ == "__main__": 12 | 13 | plt. 
-------------------------------------------------------------------------------- /DANN/experiment1/main.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/streamer-AP/DomainAdaptionPytorch/c06799caa150eab51804f472586634c748f56d79/DANN/experiment1/main.py -------------------------------------------------------------------------------- /DANN/论文解读.md: -------------------------------------------------------------------------------- 1 | # DANN 2 | 3 | ## 摘要 4 | 5 | 这篇文章主要提出了一种基于对抗的领域适应方法。主要理论依据是:为了达成好的领域迁移,必须基于一种无法区分源域和目标域的特征进行预测。这个方法主要处理的场景是源域上的数据带有标记,而目标域上的数据没有标记。在训练的过程中,这个方法的目标是获得这样一种特征: 6 | 7 | 1. 在源域上对于原任务具有可区分性 8 | 2. 对于域之间的偏移不具有可区分性 9 | 10 | 这篇文章主要通过在标准的分类网络中添加梯度反转层(gradient reversal layer)来实现这一目标,并且添加了这种层的神经网络依然可以通过反向传播和随机梯度下降进行训练。在原文章中,作者主要进行了文本情感分析和图像分类的实验,都取得了当时的SOTA效果。同时作者还进行了行人重识别的实验。 11 | 12 | ## 原理 13 | 14 | 许多之前的领域适应方法都是在固定的特征表示上进行的,而本文将领域适应与深度学习相结合,把特征学习和领域适应统一到一个训练过程中。文章的整体理论基础是Ben-David的理论。这一理论认为,给定两个分布$D_S^X,D_T^X$,以及一个由二分类函数构成的假设空间$\Eta$,这两个分布之间的$\Eta$-divergence可以表达为: 15 | 16 | $$ 17 | d_\Eta(D_S^x,D_T^x)=2sup_{\eta\in \Eta}\left|Pr_{x\sim D_S^x}[\eta(x)=1]-Pr_{x\sim D_T^x}[\eta(x)=1]\right| 18 | $$ 19 | 20 | 如果存在判别函数$\eta$能够完全区分源域和目标域中的数据,则$\Eta$-divergence为2。反之,$\Eta$中最好的判别函数对源域和目标域的区分度越差,divergence就越小。因为通常情况下我们无法获得全部的源域数据和目标域数据,Ben-David给出了通过采样数据估计divergence的方法,其表达式为: 21 | 22 | $$ 23 | \hat{d_\Eta}(S,T)=2\left( 1-min_{\eta\in \Eta}\left[\frac{1}{n}\sum_{i=1}^nI[\eta(x_i)=0]+\frac{1}{n'}\sum_{i=n+1}^NI[\eta(x_i)=1]\right]\right) 24 | $$ 25 | 26 | 其中第$1$到第$n$个样本来自源域,第$n+1$到第$N$个样本来自目标域($n'=N-n$),即: 27 | 28 | $$ 29 | U=\left\{ (x_i,0)\right\}_{i=1}^{n}\cup\left\{ (x_i,1)\right\}_{i=n+1}^{N} 30 | $$ 31 | 32 | $I$为指示函数,当方括号内的条件成立时取1,否则取0。最优的域分类器$\eta$越能准确地区分源域和目标域样本,divergence就越大;反之则越小。 33 | 34 | 然而实际计算divergence的确困难,因为这涉及到在函数空间内找到一个最优的$\eta$,所以大多数情况下需要借助学习算法来近似。如果使用$\epsilon$表示训练得到的域分类器的错误率,则divergence的估计(即proxy A-distance)可以表示为: 35 | 36 | $$ 37 | \hat{d_\Alpha}=2(1-2\epsilon) 38 | $$ 39 | 40 | Ben-David 同时证明了,$d_\Eta(D_S^x,D_T^x)$不超过$\hat{d_\Eta}(S,T)+O(d)+O(N)$,即后者是前者的上界,其中$d$是假设空间$\Eta$的VC维。因此,在目标域上的风险满足:
41 | 42 | $$ 43 | R_{D_T}(\eta)\leq R_S(\eta)+\sqrt{\frac{4}{n}(dlog\frac{2en}{d}+log\frac{4}{\delta})}+\hat{d_\Eta}(S,T)+4\sqrt{\frac{1}{n}(dlog\frac{2n}{d}+log\frac{4}{\delta})}+\beta 44 | $$ 45 | 46 | 其中 47 | $\beta\geq inf_{\eta*\in \Eta}[R_{D_S}(\eta*)+R_{D_T}(\eta*)]$ 48 | ,即$\beta$不小于$\Eta$中最优分类器在源域和目标域上的分类误差之和(理想联合误差)。这表明,在给定VC维的情况下,要减少在目标域上的损失,需要最小化 49 | $R_S(\eta)+\beta+\hat{d_\Eta}(S,T)$。然而在未知目标域标签的情况下,我们能做的只能是优化 50 | $R_S(\eta)+\hat{d_\Eta}(S,T)$ 51 | 这篇文章接下来的方法就是构建一种网络去优化上面这个目标函数。 52 | 53 | ## 网络结构 54 | 55 | 作者首先考虑了只有一个隐藏层的浅层神经网络,其输入为m维实向量,经过隐藏层后转化为一个D维的表示。即: 56 | 57 | $$ 58 | G_f(x;W,b)=sigm(Wx+b) 59 | $$ 60 | 61 | 其中: 62 | $$ 63 | sigm(a)=\left[\frac{1}{1+\exp(-a_i)}\right]_{i=1}^{|a|} 64 | $$ 65 | 在最后的预测层则采用softmax函数进行$L$类分类,即$G_y:R^D\rightarrow [0,1]^L$,并使用正确标签的负对数似然作为损失函数,即 66 | 67 | $$ 68 | L_y(G_y(G_f(x_i)),y_i)=log\frac{1}{G_y(G_f(x_i))_{y_i}} 69 | $$ 70 | 最终在网络中进行训练的时候也可以加上正则化项。因为中间隐藏层可以看做是一种表示,所以源域和目标域的样本表示可以分别写作: 71 | $$ 72 | S(G_f)=\left\{G_f(x)|x\in S\right\} 73 | $$ 74 | $$ 75 | T(G_f)=\left\{G_f(x)|x\in T\right\} 76 | $$ 77 | 78 | 根据上面的原理,源域和目标域表示之间的经验$\Eta$-divergence可以写作: 79 | $$ 80 | \hat{d_\Eta}(S(G_f),T(G_f))=2\left( 1-min_{\eta\in \Eta}\left[\frac{1}{n}\sum_{i=1}^nI[\eta(G_f(x_i))=0]+\frac{1}{n'}\sum_{i=n+1}^NI[\eta(G_f(x_i))=1]\right]\right) 81 | $$ 82 | 83 | 在这里为了优化divergence,只需要在网络中添加一个层,这个层来对源域和目标域进行分类,就可以表示$\eta$。该层可以使用逻辑回归的形式,损失函数同样也可使用对数损失: 84 | $$ 85 | G_d(G_f(x);u,z)=sigm(u^TG_f(x)+z) 86 | $$ 87 | 对于原网络来说,我们的目的是使网络学到的特征尽可能难以被域分类器区分,所以原网络的正则化项就可以设计为: 88 | $$ 89 | R(W,b)=max_{u,z}-\left[ \frac{1}{n}\sum_{i=1}^{n}L_d^i(W,b,u,z)+\frac{1}{n'}\sum_{i=n+1}^{N}L_d^i(W,b,u,z)\right] 90 | $$ 91 | 这样,这个正则化项越小,则说明域分类器的最佳表现结果越差。最终目标函数就被导出了: 92 | 93 | $$ 94 | E(W,V,b,c,u,z)=\frac{1}{n}\sum_{i=1}^{n}L_y^i(W,b,V,c)-\lambda\left( \frac{1}{n}\sum_{i=1}^{n}L_d^i(W,b,u,z)+\frac{1}{n'}\sum_{i=n+1}^{N}L_d^i(W,b,u,z)\right) 95 | $$ 96 | 97 | 其中属于网络隐藏层和类别预测器的参数$W,V,b,c$希望目标函数最小,属于域分类器的参数$u,z$希望目标函数最大。这两者之间是一个对抗的过程,这也是网络名Domain-Adversarial Neural
Network的来源。这样的鞍点优化问题无法直接用常规的反向传播和随机梯度下降来训练,所以作者选择加入一种新的层来解决这个问题,也就是这篇文章的主要技巧GRL。 98 | 99 | GRL层的行为很简单:在前向传播时表现为恒等变换,在反向传播时将梯度取反(并乘以系数)。数学上并不存在这样的函数(它的导数与前向映射不一致),但在神经网络框架中可以很方便地定制。例如在pytorch里面: 100 | 101 | ``` python 102 | from torch.autograd import Function 103 | class ReverseLayerF(Function): 104 | 105 | @staticmethod 106 | def forward(ctx, x, alpha): 107 | ctx.alpha = alpha 108 | 109 | return x.view_as(x) 110 | 111 | @staticmethod 112 | def backward(ctx, grad_output): 113 | output = grad_output.neg() * ctx.alpha 114 | 115 | return output, None 116 | ``` 117 | 118 | ## 实验(next) -------------------------------------------------------------------------------- /DANN/论文解读_zhihu.md: -------------------------------------------------------------------------------- 1 | # DANN 2 | 3 | ## 摘要 4 | 5 | 这篇文章主要提出了一种基于对抗的领域适应方法。主要理论依据是:为了达成好的领域迁移,必须基于一种无法区分源域和目标域的特征进行预测。该方法主要处理的场景是源域上的数据带有标记,而目标域上的数据没有标记。在训练的过程中,这个方法的目标是学到同时满足以下两点的特征: 6 | 7 | 1. 在源域上对于原任务具有判别性 8 | 2. 对域之间的偏移具有不变性(无法据此区分样本来自哪个域) 9 | 10 | 这篇文章主要通过在标准的分类网络中添加梯度反转层(gradient reversal layer)来实现。并且添加了这种层的神经网络依然可以通过反向传播和随机梯度下降进行训练。在原文章中,作者主要进行了文本情感分析和图像分类的实验,都取得了当时的SOTA效果。同时作者还进行了行人重识别的实验。 11 | 12 | ## 原理 13 | 14 | 许多之前的领域适应方法都基于固定的特征表示,而本文将领域适应与深度学习相结合,把特征学习和领域适应统一到同一个训练过程中。文章的整体理论基础基于Ben-David的理论。这一理论认为,给定两个分布 15 | D_S^x,D_T^x 16 | 和一个假设空间(即二分类函数的集合) 17 | H 18 | ,则这两个分布之间的 19 | H-divergence 20 | 可以表达为: 21 | 22 | 23 | d_H(D_S^x,D_T^x)=2sup_{\eta\in H}\left|Pr_{x\sim D_S^x}[\eta(x)=1]-Pr_{x\sim D_T^x}[\eta(x)=1]\right| 24 | 25 | 26 | 如果存在判别函数 27 | \eta 28 | 能够完全区分源域和目标域中的数据,则 29 | H-divergence 30 | 为2。反之, 31 | H 32 | 中最好的判别函数对源域和目标域区分度越差,则divergence越小。因为普遍情况下我们无法获得全部的源域数据和目标域数据,Ben-David给出了通过采样数据估计divergence的方法,其表达式为: 33 | 34 | 35 | \hat{d_H}(S,T)=2\left( 1-min_{\eta\in H}\left[\frac{1}{n}\sum_{i=1}^nI[\eta(x_i)=0]+\frac{1}{n'}\sum_{i=n+1}^NI[\eta(x_i)=1]\right]\right) 36 | 37 | 38 | 其中 39 | 1\sim n 40 | 为源域中的样本, 41 | n+1\sim N 42 | 为目标域中的样本(n'=N-n),源域和目标域样本的域标签分别为0和1,即: 43 | 44 | 45 | U=\left\{ (x_i,0)\right\}_{i=1}^{n}\cup\left\{ (x_i,1)\right\}_{i=n+1}^{N} 46 | 47 | 48 | 49 | I[a] 50 |
为指示函数,条件a成立时取1,否则取0。在H对取反运算封闭(即η属于H时1-η也属于H)的假设下,方括号内的最小值等于H中最优域分类器在源域和目标域上的错误率之和。因此,最优域分类器区分两个域越准确,divergence越大。 51 | 52 | 然而实际计算divergence确实困难,因为这涉及到在函数空间内找到一个最优的 53 | \eta 54 | ,所以大多数情况需要使用学习算法。如果使用 55 | \epsilon 56 | 表示训练得到的域分类器的错误率,则divergence的估计(proxy A-distance)可以表示为: 57 | 58 | 59 | \hat{d_A}=2(1-2\epsilon) 60 | 61 | 62 | Ben-David 同时证明了, 63 | d_H(D_S^x,D_T^x) 64 | 不超过 65 | \hat{d_H}(S,T) 66 | 加上一个依赖于d和样本量、且随样本量增大而减小的复杂度项,其中d是假设空间 67 | H 68 | 的VC维。因此,在目标域上的风险满足: 69 | 70 | 71 | R_{D_T}(\eta)\leq R_S(\eta)+\sqrt{\frac{4}{n}(dlog\frac{2en}{d}+log\frac{4}{\delta})}+\hat{d_H}(S,T)+4\sqrt{\frac{1}{n}(dlog\frac{2n}{d}+log\frac{4}{\delta})}+\beta 72 | 73 | 74 | 其中 75 | 76 | \beta\geq inf_{\eta*\in H}[R_{D_S}(\eta*)+R_{D_T}(\eta*)] 77 | 78 | ,即β不小于H中最优分类器在源域和目标域上的分类误差之和(理想联合误差)。这表明,在给定VC维的情况下,要减少在目标域上的损失,需要最小化 79 | 80 | R_S(\eta)+\beta+\hat{d_H}(S,T) 81 | 。然而在未知目标域标签的情况下,我们能做的只能是优化 82 | 83 | R_S(\eta)+\hat{d_H}(S,T) 84 | 85 | 这篇文章接下来的方法就是构建一种网络去优化上面这个目标函数。 86 | 87 | ## 网络结构 88 | 89 | 作者首先考虑了只有一个隐藏层的浅层神经网络,其输入为m维实向量,经过隐藏层后转化为一个D维的表示。即: 90 | 91 | 92 | G_f(x;W,b)=sigm(Wx+b) 93 | 94 | 95 | 其中: 96 | 97 | sigm(a)=\left[\frac{1}{1+\exp(-a_i)}\right]_{i=1}^{|a|} 98 | 99 | 在最后的预测层则采用softmax函数进行L类分类,即 100 | G_y:R^D\rightarrow [0,1]^L 101 | ,并使用正确标签的负对数似然作为损失函数,即 102 | 103 | 104 | L_y(G_y(G_f(x_i)),y_i)=log\frac{1}{G_y(G_f(x_i))_{y_i}} 105 | 106 | 最终在网络中进行训练的时候也可以加上正则化项。因为中间隐藏层可以看做是一种表示,所以源域和目标域的样本表示可以分别写作: 107 | 108 | S(G_f)=\left\{G_f(x)|x\in S\right\} 109 | 110 | 111 | T(G_f)=\left\{G_f(x)|x\in T\right\} 112 | 113 | 114 | 根据上面的原理,源域和目标域表示之间的经验H-divergence可以写作: 115 | 116 | \hat{d_H}(S(G_f),T(G_f))=2\left( 1-min_{\eta\in H}\left[\frac{1}{n}\sum_{i=1}^nI[\eta(G_f(x_i))=0]+\frac{1}{n'}\sum_{i=n+1}^NI[\eta(G_f(x_i))=1]\right]\right) 117 | 118 | 119 | 在这里为了优化divergence,只需要在网络中添加一个层,这个层来对源域和目标域进行分类,就可以表示 120 | \eta 121 | 。该层可以使用逻辑回归的形式,损失函数同样也可使用对数损失: 122 | 123 | G_d(G_f(x);u,z)=sigm(u^TG_f(x)+z) 124 | 125 | 对于原网络来说,我们的目的是使网络学到的特征尽可能难以被域分类器区分,所以原网络的正则化项就可以设计为: 126 | 127 | R(W,b)=max_{u,z}-\left[ \frac{1}{n}\sum_{i=1}^{n}L_d^i(W,b,u,z)+\frac{1}{n'}\sum_{i=n+1}^{N}L_d^i(W,b,u,z)\right] 128 | 129 |
这样,这个正则化项越小,则说明域分类器的最佳表现结果越差。最终目标函数就被导出了: 130 | 131 | 132 | E(W,V,b,c,u,z)=\frac{1}{n}\sum_{i=1}^{n}L_y^i(W,b,V,c)-\lambda\left( \frac{1}{n}\sum_{i=1}^{n}L_d^i(W,b,u,z)+\frac{1}{n'}\sum_{i=n+1}^{N}L_d^i(W,b,u,z)\right) 133 | 134 | 135 | 其中属于网络隐藏层和类别预测器的参数 136 | W,V,b,c 137 | 希望目标函数最小,属于域分类器的参数 138 | u,z 139 | 希望目标函数最大。这两者之间是一个对抗的过程,这也是网络名Domain-Adversarial Neural Network的来源。这样的鞍点优化问题无法直接用常规的反向传播和随机梯度下降来训练,所以作者选择加入一种新的层来解决这个问题,也就是这篇文章的主要技巧GRL。 140 | 141 | GRL层的行为很简单:在前向传播时表现为恒等变换,在反向传播时将梯度取反(并乘以系数)。数学上并不存在这样的函数(它的导数与前向映射不一致),但在神经网络框架中可以很方便地定制。例如在pytorch里面: 142 | 143 | ``` python 144 | from torch.autograd import Function 145 | class ReverseLayerF(Function): 146 | 147 | @staticmethod 148 | def forward(ctx, x, alpha): 149 | ctx.alpha = alpha 150 | 151 | return x.view_as(x) 152 | 153 | @staticmethod 154 | def backward(ctx, grad_output): 155 | output = grad_output.neg() * ctx.alpha 156 | 157 | return output, None 158 | ``` 159 | 160 | ## 实验(next) -------------------------------------------------------------------------------- /DCN/DCN.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/streamer-AP/DomainAdaptionPytorch/c06799caa150eab51804f472586634c748f56d79/DCN/DCN.PNG -------------------------------------------------------------------------------- /DCN/论文解读.md: -------------------------------------------------------------------------------- 1 | # DDC 2 | 3 | 原论文:[Deep Domain Confusion: Maximizing for Domain Invariance](https://arxiv.org/abs/1412.3474) 4 | 5 | ## 摘要 6 | 7 | (文章发表于2014年)在深度模型的fine-tune过程中,需要大量的带标签数据,这对很多应用来说并不现实。作者提出了一种包含适应层和领域混淆损失的网络结构,能够学到既在语义上有意义、又具有领域不变性的表示,并使用一种域混淆(domain confusion)度量来选择适应层的维度和位置。文章提出的网络在视觉任务上取得了当时最好的效果。 8 | 9 | ## 介绍 10 | 11 | 数据偏差是图像识别领域的常见问题,作者指出在此之前解决这些问题的方法要么是较浅层的模型,要么使用fine-tune的方法。但如果目标域只有很少的数据,进行fine-tune是有问题的。**作者认为,进行领域不变性的优化可以认为是同时进行标签的预测任务和寻找域之间更相似的表示。** 这也是作者这篇文章的主要启发点,同时优化分类损失和域混淆损失。 12 | 13 | 基于AlexNet,作者提出了新的网络结构。![DDC网络结构](DCN.PNG) 14 | 15 | -------------------------------------------------------------------------------- /LICENSE:
-------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 cryan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DomainAdaptionPytorch 2 | 3 | Code for recent domain adaption paper implemented by pytorch. 
# 4 | 5 | [中文版说明](https://github.com/streamer-AP/DomainAdaptionPytorch/blob/master/说明.md) 6 | 7 | ## paper list 8 | 9 | * [Domain-Adversarial Training of Neural Networks(DANN)](https://arxiv.org/abs/1505.07818), [code](https://github.com/streamer-AP/DomainAdaptionPytorch/blob/master/DANN) 10 | 11 |
# -------------------------------------------------------------------------------- /zhihu.py: --------------------------------------------------------------------------------
"""Prepare a Markdown file for publishing on Zhihu.

Usage::

    python zhihu.py <file.md>

writes ``<file>_zhihu.md`` next to the input. Every ``$$...$$`` and ``$...$``
formula loses its delimiters and is placed on a line of its own, and the
``\\Eta`` / ``\\Alpha`` macros are replaced by ``H`` / ``A``.
"""
import re
import sys
import os

# Display math ``$$ ... $$`` (may span lines; newlines just inside the
# delimiters are dropped). Applied before the inline pattern, which would
# otherwise treat each ``$$`` as an empty ``$...$`` pair.
_DISPLAY_MATH = re.compile(r"\$\$\n*([\s\S]*?)\n*\$\$")
# Inline math ``$ ... $``: the formula itself cannot contain a newline, but
# newlines right after the opening or before the closing ``$`` are consumed.
_INLINE_MATH = re.compile(r"\$\n*(.*?)\n*\$")
# Replacement used for both patterns: the bare formula on its own line.
_ON_OWN_LINE = r"\n\1\n"
# Uppercase-Greek macros mapped to their Latin look-alikes.
# NOTE(review): presumably because Zhihu's math renderer rejects them — confirm.
_MACRO_SUBSTITUTIONS = ((r"\Eta", "H"), (r"\Alpha", "A"))


def replace(file_name, output_file_name):
    """Write a Zhihu-friendly copy of *file_name* to *output_file_name*.

    Both files are read/written as UTF-8. Math delimiters are stripped (display
    math first, then inline math) and ``\\Eta``/``\\Alpha`` become ``H``/``A``.

    Returns:
        True if the output file was written; False if *file_name* does not
        exist (nothing is created in that case, as before).
    """
    if not os.path.exists(file_name):
        return False
    # Read the whole input before opening the output, so a read error can never
    # leave a truncated output file behind; ``with`` closes both on any error.
    with open(file_name, "r", encoding="utf-8") as src:
        text = src.read()
    text = _DISPLAY_MATH.sub(_ON_OWN_LINE, text)
    text = _INLINE_MATH.sub(_ON_OWN_LINE, text)
    for macro, plain in _MACRO_SUBSTITUTIONS:
        text = text.replace(macro, plain)
    with open(output_file_name, "w", encoding="utf-8") as dst:
        dst.write(text)
    return True


def zhihu_output_path(file_name):
    """Return ``<file_name without its extension>_zhihu.md``.

    ``os.path.splitext`` strips only the final extension, so dots elsewhere in
    the path are kept: ``./DANN/x.md`` -> ``./DANN/x_zhihu.md``. (The previous
    ``file_name.split(".")[0]`` turned that path into ``_zhihu.md``.)
    """
    return os.path.splitext(file_name)[0] + "_zhihu.md"


def main(argv=None):
    """Command-line entry point: convert the Markdown file named by ``argv[1]``."""
    argv = sys.argv if argv is None else argv
    if len(argv) < 2:
        print("need file name")
        sys.exit(1)
    file_name = argv[1]
    if not replace(file_name, zhihu_output_path(file_name)):
        # Previously a mistyped path was silently ignored with exit status 0.
        print(f"file not found: {file_name}", file=sys.stderr)
        sys.exit(1)


if __name__ == '__main__':
    main()
# -------------------------------------------------------------------------------- /说明.md: --------------------------------------------------------------------------------
# 1 | # 领域适应代码及实现 2 | 3 | 对近年来的领域适应论文的实现 4 | 5 | [English Readme](https://github.com/streamer-AP/DomainAdaptionPytorch/blob/master/README.md) 6 | 7 | ## 论文列表 8 | 9 | * [Domain-Adversarial Training of Neural Networks(DANN)](https://arxiv.org/abs/1505.07818), [代码](https://github.com/streamer-AP/DomainAdaptionPytorch/blob/master/DANN)
--------------------------------------------------------------------------------