├── .gitignore ├── mnist-nn-modified.ipynb ├── mnist-visualize-arcface-loss.ipynb ├── mnist-visualize-arcface3_fc7-loss.ipynb ├── mnist-visualize-arcface4_fc7-loss.ipynb ├── mnist-visualize-arcface5_fc7-loss.ipynb ├── mnist-visualize-arcface6_fc7-loss.ipynb ├── mnist-visualize-arcface7_fc7-features-s64-m0_5.ipynb ├── mnist-visualize-arcface7_fc7-features.ipynb ├── mnist-visualize-center-loss.ipynb ├── mnist-visualize-center-loss2.ipynb ├── mnist-visualize-center-loss3.ipynb ├── mnist-visualize-center-loss4.ipynb ├── mnist-visualize-cosface.ipynb ├── mnist-visualize-cosface2.ipynb ├── mnist-visualize-cosface3.ipynb ├── mnist-visualize-cosface4.ipynb ├── mnist-visualize-cosface5.ipynb ├── mnist-visualize-cosface5_fc7.ipynb ├── mnist-visualize-cosface6-features-s64-m0_5.ipynb ├── mnist-visualize-cosface6-features.ipynb ├── mnist-visualize-softmax.ipynb ├── mnist-visualize-softmax_custom.ipynb ├── mnist.ipynb ├── mnist_arcface.py ├── mnist_arcface2_fc7.py ├── mnist_arcface3_fc7.py ├── mnist_arcface4_fc7.py ├── mnist_arcface5_fc7.py ├── mnist_arcface6_fc7.py ├── mnist_cnn-cosface.pt ├── mnist_cosface.py ├── mnist_cosface2.py ├── mnist_cosface3.py ├── mnist_cosface4.py ├── mnist_cosface5_fc7.py ├── mnist_loss-cosface.pt ├── mnist_softmax.py ├── mnist_softmax_custom.py ├── mnist_test.ipynb ├── playground.py ├── plot_to_gif.py └── test_arcface_mnist.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | data/ 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | env/ 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | 29 | # data 30 | #data/datasets 31 | #data/datasets/lfw_mtcnnpy_160_2/ 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | .hypothesis/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | 62 | # Flask stuff: 63 | instance/ 64 | .webassets-cache 65 | 66 | # Scrapy stuff: 67 | .scrapy 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | 72 | # PyBuilder 73 | target/ 74 | 75 | # Jupyter Notebook 76 | .ipynb_checkpoints 77 | 78 | # pyenv 79 | .python-version 80 | 81 | 82 | # virtualenv 83 | .venv 84 | venv/ 85 | ENV/ 86 | -------------------------------------------------------------------------------- /mnist.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "colab_type": "text", 7 | "id": "ohFs-uX44uZS" 8 | }, 9 | "source": [ 10 | "## Install PyTorch if needed" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 0, 16 | "metadata": { 17 | "colab": {}, 18 | "colab_type": "code", 19 | "id": "doPtfTAl4uZU" 20 | }, 21 | "outputs": [], 22 | "source": [ 23 | "# http://pytorch.org/\n", 24 | "\n", 25 | "\n", 26 | "# from os.path import exists\n", 27 | "# from wheel.pep425tags import get_abbr_impl, get_impl_ver, get_abi_tag\n", 28 | "# platform = '{}{}-{}'.format(get_abbr_impl(), get_impl_ver(), get_abi_tag())\n", 29 | "# cuda_output = !ldconfig -p|grep cudart.so|sed -e 's/.*\\.\\([0-9]*\\)\\.\\([0-9]*\\)$/cu\\1\\2/'\n", 30 | "# accelerator = cuda_output[0] if exists('/dev/nvidia0') else 'cpu'\n", 31 | "\n", 32 | "# !pip install -q http://download.pytorch.org/whl/{accelerator}/torch-0.4.1-{platform}-linux_x86_64.whl torchvision\n", 33 | "# import torch" 34 | ] 35 | }, 36 | { 37 | "cell_type": "markdown", 38 | "metadata": { 39 | "colab_type": "text", 40 | "id": "Cq7PhAqo4uZX" 41 | }, 42 | "source": [ 43 | "## Import modules" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 126, 49 | "metadata": { 50 | "ExecuteTime": { 51 | "end_time": "2017-04-23T13:00:45.951502", 52 | "start_time": "2017-04-23T13:00:44.296636" 53 | }, 54 | "colab": { 55 | "base_uri": "https://localhost:8080/", 56 | "height": 34 57 | }, 58 | "colab_type": "code", 59 | "id": "5-jCEYf94uZX", 60 | "outputId": "a083a119-1f3a-4fe7-e643-06e87639153e" 61 | }, 62 | "outputs": [ 63 | { 64 | "name": "stdout", 65 | "output_type": "stream", 66 | "text": [ 67 | "PyTorch version: 1.0.0\n" 68 | ] 69 | } 70 | ], 71 | "source": [ 72 | "import argparse\n", 73 | "import torch\n", 74 | "import torch.nn as nn\n", 75 | "import torch.nn.functional as F\n", 76 | "import torch.optim as optim\n", 77 | "from torchvision import datasets, transforms\n", 78 | "\n", 79 | "print(\"PyTorch version: \" + str(torch.__version__))" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": 127, 85 | "metadata": { 86 | "colab": {}, 87 | "colab_type": "code", 88 | "id": "Z46xBTdm4uZc" 89 | }, 90 | "outputs": [], 91 | "source": [ 92 | "BATCH_SIZE = 64\n", 93 | "BATCH_SIZE_TEST = 1000\n", 94 | "EPOCHS = 10\n", 95 | "LOG_INTERVAL = 10" 96 | ] 97 | }, 98 | { 99 | "cell_type": "markdown", 100 | "metadata": { 101 | "colab_type": "text", 102 | "id": "eMSEzd-l4uZf" 103 | }, 104 | "source": [ 105 | "## Model setup" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": 129, 111 | 
"metadata": { 112 | "colab": {}, 113 | "colab_type": "code", 114 | "id": "TgKqoF-v4uZg" 115 | }, 116 | "outputs": [], 117 | "source": [ 118 | "class Net(nn.Module):\n", 119 | " def __init__(self):\n", 120 | " super(Net, self).__init__()\n", 121 | " self.conv1 = nn.Conv2d(1, 20, 5, 1)\n", 122 | " self.conv2 = nn.Conv2d(20, 50, 5, 1)\n", 123 | " self.fc1 = nn.Linear(4*4*50, 500)\n", 124 | " self.fc2 = nn.Linear(500, 10)\n", 125 | "\n", 126 | " def forward(self, x):\n", 127 | " x = F.relu(self.conv1(x))\n", 128 | " x = F.max_pool2d(x, 2, 2)\n", 129 | " x = F.relu(self.conv2(x))\n", 130 | " x = F.max_pool2d(x, 2, 2)\n", 131 | " x = x.view(-1, 4*4*50)\n", 132 | " x = F.relu(self.fc1(x))\n", 133 | " x = self.fc2(x)\n", 134 | " return F.log_softmax(x, dim=1)\n", 135 | " \n", 136 | "\n" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": 130, 142 | "metadata": {}, 143 | "outputs": [ 144 | { 145 | "data": { 146 | "text/plain": [ 147 | "Net(\n", 148 | " (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))\n", 149 | " (conv2): Conv2d(20, 50, kernel_size=(5, 5), stride=(1, 1))\n", 150 | " (fc1): Linear(in_features=800, out_features=500, bias=True)\n", 151 | " (fc2): Linear(in_features=500, out_features=10, bias=True)\n", 152 | ")" 153 | ] 154 | }, 155 | "execution_count": 130, 156 | "metadata": {}, 157 | "output_type": "execute_result" 158 | } 159 | ], 160 | "source": [ 161 | "model = Net()\n", 162 | "model.eval()\n" 163 | ] 164 | }, 165 | { 166 | "cell_type": "markdown", 167 | "metadata": { 168 | "colab_type": "text", 169 | "id": "17zxLDwV4uZj" 170 | }, 171 | "source": [ 172 | "## Train setup" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": 4, 178 | "metadata": { 179 | "colab": {}, 180 | "colab_type": "code", 181 | "id": "NdQVgFKk4uZl" 182 | }, 183 | "outputs": [], 184 | "source": [ 185 | "def train(model, device, train_loader, optimizer, epoch):\n", 186 | " model.train()\n", 187 | " for batch_idx, (data, target) in enumerate(train_loader):\n", 188 | " data, target = data.to(device), target.to(device)\n", 189 | " optimizer.zero_grad()\n", 190 | " output = model(data)\n", 191 | " loss = F.nll_loss(output, target)\n", 192 | " loss.backward()\n", 193 | " optimizer.step()\n", 194 | " if batch_idx % LOG_INTERVAL == 0:\n", 195 | " print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n", 196 | " epoch, batch_idx * len(data), len(train_loader.dataset),\n", 197 | " 100. * batch_idx / len(train_loader), loss.item()))\n", 198 | "\n", 199 | "def test(model, device, test_loader):\n", 200 | " model.eval()\n", 201 | " test_loss = 0\n", 202 | " correct = 0\n", 203 | " with torch.no_grad():\n", 204 | " for data, target in test_loader:\n", 205 | " data, target = data.to(device), target.to(device)\n", 206 | " output = model(data)\n", 207 | " test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss\n", 208 | " pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability\n", 209 | " correct += pred.eq(target.view_as(pred)).sum().item()\n", 210 | "\n", 211 | " test_loss /= len(test_loader.dataset)\n", 212 | "\n", 213 | " print('\\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n'.format(\n", 214 | " test_loss, correct, len(test_loader.dataset),\n", 215 | " 100. 
* correct / len(test_loader.dataset)))" 216 | ] 217 | }, 218 | { 219 | "cell_type": "markdown", 220 | "metadata": { 221 | "ExecuteTime": { 222 | "end_time": "2017-04-23T13:00:44.295728", 223 | "start_time": "2017-04-23T13:00:44.293871" 224 | }, 225 | "colab_type": "text", 226 | "collapsed": true, 227 | "id": "h0FbNgzK4uZp" 228 | }, 229 | "source": [ 230 | "## Dataset setup" 231 | ] 232 | }, 233 | { 234 | "cell_type": "code", 235 | "execution_count": 84, 236 | "metadata": { 237 | "colab": {}, 238 | "colab_type": "code", 239 | "id": "AtqNFTcA4uZp" 240 | }, 241 | "outputs": [], 242 | "source": [ 243 | "\n", 244 | "\n", 245 | "use_cuda = torch.cuda.is_available()\n", 246 | "\n", 247 | "torch.manual_seed(1)\n", 248 | "\n", 249 | "device = torch.device(\"cuda\" if use_cuda else \"cpu\")\n", 250 | "\n", 251 | "kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}\n", 252 | "train_loader = torch.utils.data.DataLoader(\n", 253 | " datasets.MNIST('../data', train=True, download=True,\n", 254 | " transform=transforms.Compose([\n", 255 | " transforms.ToTensor(),\n", 256 | " transforms.Normalize((0.1307,), (0.3081,))\n", 257 | " ])),\n", 258 | " batch_size=BATCH_SIZE, shuffle=True, **kwargs)\n", 259 | "test_loader = torch.utils.data.DataLoader(\n", 260 | " datasets.MNIST('../data', train=False, transform=transforms.Compose([\n", 261 | " transforms.ToTensor(),\n", 262 | " transforms.Normalize((0.1307,), (0.3081,))\n", 263 | " ])),\n", 264 | " batch_size=BATCH_SIZE_TEST, shuffle=True, **kwargs)\n", 265 | "\n", 266 | "\n" 267 | ] 268 | }, 269 | { 270 | "cell_type": "markdown", 271 | "metadata": { 272 | "colab_type": "text", 273 | "id": "Rzz0r5To4uZs" 274 | }, 275 | "source": [ 276 | "## Train process" 277 | ] 278 | }, 279 | { 280 | "cell_type": "code", 281 | "execution_count": 10, 282 | "metadata": { 283 | "colab": { 284 | "base_uri": "https://localhost:8080/", 285 | "height": 16507 286 | }, 287 | "colab_type": "code", 288 | "id": "_GFUR8jN4uZu", 289 | "outputId": "5ddebb92-b06d-4ee2-c1dd-81a527b98e83" 290 | }, 291 | "outputs": [ 292 | { 293 | "name": "stdout", 294 | "output_type": "stream", 295 | "text": [ 296 | "Train Epoch: 1 [0/60000 (0%)]\tLoss: 2.300039\n", 297 | "Train Epoch: 1 [640/60000 (1%)]\tLoss: 2.213470\n", 298 | "Train Epoch: 1 [1280/60000 (2%)]\tLoss: 2.170460\n", 299 | "Train Epoch: 1 [1920/60000 (3%)]\tLoss: 2.076729\n", 300 | "Train Epoch: 1 [2560/60000 (4%)]\tLoss: 1.868197\n", 301 | "Train Epoch: 1 [3200/60000 (5%)]\tLoss: 1.414118\n", 302 | "Train Epoch: 1 [3840/60000 (6%)]\tLoss: 1.000784\n", 303 | "Train Epoch: 1 [4480/60000 (7%)]\tLoss: 0.775921\n", 304 | "Train Epoch: 1 [5120/60000 (9%)]\tLoss: 0.460183\n", 305 | "Train Epoch: 1 [5760/60000 (10%)]\tLoss: 0.486091\n", 306 | "Train Epoch: 1 [6400/60000 (11%)]\tLoss: 0.437529\n", 307 | "Train Epoch: 1 [7040/60000 (12%)]\tLoss: 0.408766\n", 308 | "Train Epoch: 1 [7680/60000 (13%)]\tLoss: 0.461533\n", 309 | "Train Epoch: 1 [8320/60000 (14%)]\tLoss: 0.429015\n", 310 | "Train Epoch: 1 [8960/60000 (15%)]\tLoss: 0.398074\n", 311 | "Train Epoch: 1 [9600/60000 (16%)]\tLoss: 0.386383\n", 312 | "Train Epoch: 1 [10240/60000 (17%)]\tLoss: 0.298286\n", 313 | "Train Epoch: 1 [10880/60000 (18%)]\tLoss: 0.502210\n", 314 | "Train Epoch: 1 [11520/60000 (19%)]\tLoss: 0.522743\n", 315 | "Train Epoch: 1 [12160/60000 (20%)]\tLoss: 0.338115\n", 316 | "Train Epoch: 1 [12800/60000 (21%)]\tLoss: 0.365848\n", 317 | "Train Epoch: 1 [13440/60000 (22%)]\tLoss: 0.448824\n", 318 | "Train Epoch: 1 [14080/60000 (23%)]\tLoss: 0.303673\n", 319 | "Train 
Epoch: 1 [14720/60000 (25%)]\tLoss: 0.357254\n", 320 | "Train Epoch: 1 [15360/60000 (26%)]\tLoss: 0.329915\n", 321 | "Train Epoch: 1 [16000/60000 (27%)]\tLoss: 0.437525\n", 322 | "Train Epoch: 1 [16640/60000 (28%)]\tLoss: 0.364413\n", 323 | "Train Epoch: 1 [17280/60000 (29%)]\tLoss: 0.315719\n", 324 | "Train Epoch: 1 [17920/60000 (30%)]\tLoss: 0.200114\n", 325 | "Train Epoch: 1 [18560/60000 (31%)]\tLoss: 0.496398\n", 326 | "Train Epoch: 1 [19200/60000 (32%)]\tLoss: 0.327925\n", 327 | "Train Epoch: 1 [19840/60000 (33%)]\tLoss: 0.118464\n", 328 | "Train Epoch: 1 [20480/60000 (34%)]\tLoss: 0.189515\n", 329 | "Train Epoch: 1 [21120/60000 (35%)]\tLoss: 0.140906\n", 330 | "Train Epoch: 1 [21760/60000 (36%)]\tLoss: 0.313298\n", 331 | "Train Epoch: 1 [22400/60000 (37%)]\tLoss: 0.150065\n", 332 | "Train Epoch: 1 [23040/60000 (38%)]\tLoss: 0.288942\n", 333 | "Train Epoch: 1 [23680/60000 (39%)]\tLoss: 0.469057\n", 334 | "Train Epoch: 1 [24320/60000 (41%)]\tLoss: 0.215919\n", 335 | "Train Epoch: 1 [24960/60000 (42%)]\tLoss: 0.152756\n", 336 | "Train Epoch: 1 [25600/60000 (43%)]\tLoss: 0.225649\n", 337 | "Train Epoch: 1 [26240/60000 (44%)]\tLoss: 0.262905\n", 338 | "Train Epoch: 1 [26880/60000 (45%)]\tLoss: 0.233346\n", 339 | "Train Epoch: 1 [27520/60000 (46%)]\tLoss: 0.262019\n", 340 | "Train Epoch: 1 [28160/60000 (47%)]\tLoss: 0.211576\n", 341 | "Train Epoch: 1 [28800/60000 (48%)]\tLoss: 0.132476\n", 342 | "Train Epoch: 1 [29440/60000 (49%)]\tLoss: 0.278274\n", 343 | "Train Epoch: 1 [30080/60000 (50%)]\tLoss: 0.094361\n", 344 | "Train Epoch: 1 [30720/60000 (51%)]\tLoss: 0.128920\n", 345 | "Train Epoch: 1 [31360/60000 (52%)]\tLoss: 0.246514\n", 346 | "Train Epoch: 1 [32000/60000 (53%)]\tLoss: 0.338069\n", 347 | "Train Epoch: 1 [32640/60000 (54%)]\tLoss: 0.153890\n", 348 | "Train Epoch: 1 [33280/60000 (55%)]\tLoss: 0.089979\n", 349 | "Train Epoch: 1 [33920/60000 (57%)]\tLoss: 0.145261\n", 350 | "Train Epoch: 1 [34560/60000 (58%)]\tLoss: 0.197968\n", 351 | "Train Epoch: 1 [35200/60000 (59%)]\tLoss: 0.217369\n", 352 | "Train Epoch: 1 [35840/60000 (60%)]\tLoss: 0.063571\n", 353 | "Train Epoch: 1 [36480/60000 (61%)]\tLoss: 0.136347\n", 354 | "Train Epoch: 1 [37120/60000 (62%)]\tLoss: 0.114946\n", 355 | "Train Epoch: 1 [37760/60000 (63%)]\tLoss: 0.237716\n", 356 | "Train Epoch: 1 [38400/60000 (64%)]\tLoss: 0.063597\n", 357 | "Train Epoch: 1 [39040/60000 (65%)]\tLoss: 0.105856\n", 358 | "Train Epoch: 1 [39680/60000 (66%)]\tLoss: 0.160526\n", 359 | "Train Epoch: 1 [40320/60000 (67%)]\tLoss: 0.109833\n", 360 | "Train Epoch: 1 [40960/60000 (68%)]\tLoss: 0.177891\n", 361 | "Train Epoch: 1 [41600/60000 (69%)]\tLoss: 0.229265\n", 362 | "Train Epoch: 1 [42240/60000 (70%)]\tLoss: 0.074022\n", 363 | "Train Epoch: 1 [42880/60000 (71%)]\tLoss: 0.155059\n", 364 | "Train Epoch: 1 [43520/60000 (72%)]\tLoss: 0.278508\n", 365 | "Train Epoch: 1 [44160/60000 (74%)]\tLoss: 0.143101\n", 366 | "Train Epoch: 1 [44800/60000 (75%)]\tLoss: 0.114906\n", 367 | "Train Epoch: 1 [45440/60000 (76%)]\tLoss: 0.120892\n", 368 | "Train Epoch: 1 [46080/60000 (77%)]\tLoss: 0.076811\n", 369 | "Train Epoch: 1 [46720/60000 (78%)]\tLoss: 0.193659\n", 370 | "Train Epoch: 1 [47360/60000 (79%)]\tLoss: 0.070295\n", 371 | "Train Epoch: 1 [48000/60000 (80%)]\tLoss: 0.207994\n", 372 | "Train Epoch: 1 [48640/60000 (81%)]\tLoss: 0.135138\n", 373 | "Train Epoch: 1 [49280/60000 (82%)]\tLoss: 0.093357\n", 374 | "Train Epoch: 1 [49920/60000 (83%)]\tLoss: 0.106430\n", 375 | "Train Epoch: 1 [50560/60000 (84%)]\tLoss: 0.120083\n", 376 | "Train Epoch: 1 
[51200/60000 (85%)]\tLoss: 0.143429\n", 377 | "Train Epoch: 1 [51840/60000 (86%)]\tLoss: 0.067632\n", 378 | "Train Epoch: 1 [52480/60000 (87%)]\tLoss: 0.023525\n", 379 | "Train Epoch: 1 [53120/60000 (88%)]\tLoss: 0.262577\n", 380 | "Train Epoch: 1 [53760/60000 (90%)]\tLoss: 0.092559\n", 381 | "Train Epoch: 1 [54400/60000 (91%)]\tLoss: 0.127233\n", 382 | "Train Epoch: 1 [55040/60000 (92%)]\tLoss: 0.190678\n", 383 | "Train Epoch: 1 [55680/60000 (93%)]\tLoss: 0.034510\n", 384 | "Train Epoch: 1 [56320/60000 (94%)]\tLoss: 0.036593\n", 385 | "Train Epoch: 1 [56960/60000 (95%)]\tLoss: 0.076275\n", 386 | "Train Epoch: 1 [57600/60000 (96%)]\tLoss: 0.118792\n", 387 | "Train Epoch: 1 [58240/60000 (97%)]\tLoss: 0.192908\n", 388 | "Train Epoch: 1 [58880/60000 (98%)]\tLoss: 0.204256\n", 389 | "Train Epoch: 1 [59520/60000 (99%)]\tLoss: 0.064322\n", 390 | "\n", 391 | "Test set: Average loss: 0.1019, Accuracy: 9658/10000 (97%)\n", 392 | "\n", 393 | "Train Epoch: 2 [0/60000 (0%)]\tLoss: 0.146209\n", 394 | "Train Epoch: 2 [640/60000 (1%)]\tLoss: 0.120489\n", 395 | "Train Epoch: 2 [1280/60000 (2%)]\tLoss: 0.103908\n", 396 | "Train Epoch: 2 [1920/60000 (3%)]\tLoss: 0.067228\n", 397 | "Train Epoch: 2 [2560/60000 (4%)]\tLoss: 0.105147\n", 398 | "Train Epoch: 2 [3200/60000 (5%)]\tLoss: 0.115444\n", 399 | "Train Epoch: 2 [3840/60000 (6%)]\tLoss: 0.098371\n", 400 | "Train Epoch: 2 [4480/60000 (7%)]\tLoss: 0.089602\n", 401 | "Train Epoch: 2 [5120/60000 (9%)]\tLoss: 0.190875\n", 402 | "Train Epoch: 2 [5760/60000 (10%)]\tLoss: 0.095843\n", 403 | "Train Epoch: 2 [6400/60000 (11%)]\tLoss: 0.098459\n", 404 | "Train Epoch: 2 [7040/60000 (12%)]\tLoss: 0.070366\n", 405 | "Train Epoch: 2 [7680/60000 (13%)]\tLoss: 0.079403\n", 406 | "Train Epoch: 2 [8320/60000 (14%)]\tLoss: 0.077454\n", 407 | "Train Epoch: 2 [8960/60000 (15%)]\tLoss: 0.117428\n", 408 | "Train Epoch: 2 [9600/60000 (16%)]\tLoss: 0.043413\n", 409 | "Train Epoch: 2 [10240/60000 (17%)]\tLoss: 0.012356\n", 410 | "Train Epoch: 2 [10880/60000 (18%)]\tLoss: 0.248271\n", 411 | "Train Epoch: 2 [11520/60000 (19%)]\tLoss: 0.145766\n", 412 | "Train Epoch: 2 [12160/60000 (20%)]\tLoss: 0.091724\n", 413 | "Train Epoch: 2 [12800/60000 (21%)]\tLoss: 0.060102\n", 414 | "Train Epoch: 2 [13440/60000 (22%)]\tLoss: 0.053024\n", 415 | "Train Epoch: 2 [14080/60000 (23%)]\tLoss: 0.074615\n", 416 | "Train Epoch: 2 [14720/60000 (25%)]\tLoss: 0.057848\n", 417 | "Train Epoch: 2 [15360/60000 (26%)]\tLoss: 0.025632\n", 418 | "Train Epoch: 2 [16000/60000 (27%)]\tLoss: 0.168881\n", 419 | "Train Epoch: 2 [16640/60000 (28%)]\tLoss: 0.059074\n", 420 | "Train Epoch: 2 [17280/60000 (29%)]\tLoss: 0.062707\n", 421 | "Train Epoch: 2 [17920/60000 (30%)]\tLoss: 0.137210\n", 422 | "Train Epoch: 2 [18560/60000 (31%)]\tLoss: 0.112047\n", 423 | "Train Epoch: 2 [19200/60000 (32%)]\tLoss: 0.066597\n", 424 | "Train Epoch: 2 [19840/60000 (33%)]\tLoss: 0.097211\n", 425 | "Train Epoch: 2 [20480/60000 (34%)]\tLoss: 0.074622\n", 426 | "Train Epoch: 2 [21120/60000 (35%)]\tLoss: 0.117305\n", 427 | "Train Epoch: 2 [21760/60000 (36%)]\tLoss: 0.093575\n", 428 | "Train Epoch: 2 [22400/60000 (37%)]\tLoss: 0.275831\n", 429 | "Train Epoch: 2 [23040/60000 (38%)]\tLoss: 0.064055\n", 430 | "Train Epoch: 2 [23680/60000 (39%)]\tLoss: 0.029043\n", 431 | "Train Epoch: 2 [24320/60000 (41%)]\tLoss: 0.191024\n", 432 | "Train Epoch: 2 [24960/60000 (42%)]\tLoss: 0.109829\n", 433 | "Train Epoch: 2 [25600/60000 (43%)]\tLoss: 0.118210\n", 434 | "Train Epoch: 2 [26240/60000 (44%)]\tLoss: 0.021921\n", 435 | "Train Epoch: 2 [26880/60000 
(45%)]\tLoss: 0.072457\n", 436 | "Train Epoch: 2 [27520/60000 (46%)]\tLoss: 0.147058\n", 437 | "Train Epoch: 2 [28160/60000 (47%)]\tLoss: 0.051846\n", 438 | "Train Epoch: 2 [28800/60000 (48%)]\tLoss: 0.132390\n", 439 | "Train Epoch: 2 [29440/60000 (49%)]\tLoss: 0.198021\n", 440 | "Train Epoch: 2 [30080/60000 (50%)]\tLoss: 0.058813\n", 441 | "Train Epoch: 2 [30720/60000 (51%)]\tLoss: 0.131270\n", 442 | "Train Epoch: 2 [31360/60000 (52%)]\tLoss: 0.146507\n", 443 | "Train Epoch: 2 [32000/60000 (53%)]\tLoss: 0.034004\n", 444 | "Train Epoch: 2 [32640/60000 (54%)]\tLoss: 0.128595\n", 445 | "Train Epoch: 2 [33280/60000 (55%)]\tLoss: 0.076431\n", 446 | "Train Epoch: 2 [33920/60000 (57%)]\tLoss: 0.136983\n", 447 | "Train Epoch: 2 [34560/60000 (58%)]\tLoss: 0.147081\n", 448 | "Train Epoch: 2 [35200/60000 (59%)]\tLoss: 0.028694\n", 449 | "Train Epoch: 2 [35840/60000 (60%)]\tLoss: 0.167012\n", 450 | "Train Epoch: 2 [36480/60000 (61%)]\tLoss: 0.028680\n", 451 | "Train Epoch: 2 [37120/60000 (62%)]\tLoss: 0.039678\n", 452 | "Train Epoch: 2 [37760/60000 (63%)]\tLoss: 0.060088\n", 453 | "Train Epoch: 2 [38400/60000 (64%)]\tLoss: 0.034208\n", 454 | "Train Epoch: 2 [39040/60000 (65%)]\tLoss: 0.060537\n", 455 | "Train Epoch: 2 [39680/60000 (66%)]\tLoss: 0.136136\n", 456 | "Train Epoch: 2 [40320/60000 (67%)]\tLoss: 0.094202\n", 457 | "Train Epoch: 2 [40960/60000 (68%)]\tLoss: 0.077091\n", 458 | "Train Epoch: 2 [41600/60000 (69%)]\tLoss: 0.033310\n", 459 | "Train Epoch: 2 [42240/60000 (70%)]\tLoss: 0.037923\n", 460 | "Train Epoch: 2 [42880/60000 (71%)]\tLoss: 0.020340\n", 461 | "Train Epoch: 2 [43520/60000 (72%)]\tLoss: 0.030696\n", 462 | "Train Epoch: 2 [44160/60000 (74%)]\tLoss: 0.042431\n", 463 | "Train Epoch: 2 [44800/60000 (75%)]\tLoss: 0.036715\n", 464 | "Train Epoch: 2 [45440/60000 (76%)]\tLoss: 0.150002\n", 465 | "Train Epoch: 2 [46080/60000 (77%)]\tLoss: 0.103709\n", 466 | "Train Epoch: 2 [46720/60000 (78%)]\tLoss: 0.135341\n", 467 | "Train Epoch: 2 [47360/60000 (79%)]\tLoss: 0.138806\n", 468 | "Train Epoch: 2 [48000/60000 (80%)]\tLoss: 0.050912\n", 469 | "Train Epoch: 2 [48640/60000 (81%)]\tLoss: 0.052124\n", 470 | "Train Epoch: 2 [49280/60000 (82%)]\tLoss: 0.028371\n", 471 | "Train Epoch: 2 [49920/60000 (83%)]\tLoss: 0.070579\n", 472 | "Train Epoch: 2 [50560/60000 (84%)]\tLoss: 0.100813\n", 473 | "Train Epoch: 2 [51200/60000 (85%)]\tLoss: 0.026794\n", 474 | "Train Epoch: 2 [51840/60000 (86%)]\tLoss: 0.039561\n", 475 | "Train Epoch: 2 [52480/60000 (87%)]\tLoss: 0.025829\n", 476 | "Train Epoch: 2 [53120/60000 (88%)]\tLoss: 0.040187\n", 477 | "Train Epoch: 2 [53760/60000 (90%)]\tLoss: 0.192775\n", 478 | "Train Epoch: 2 [54400/60000 (91%)]\tLoss: 0.061120\n", 479 | "Train Epoch: 2 [55040/60000 (92%)]\tLoss: 0.044438\n", 480 | "Train Epoch: 2 [55680/60000 (93%)]\tLoss: 0.021667\n", 481 | "Train Epoch: 2 [56320/60000 (94%)]\tLoss: 0.069259\n", 482 | "Train Epoch: 2 [56960/60000 (95%)]\tLoss: 0.083077\n", 483 | "Train Epoch: 2 [57600/60000 (96%)]\tLoss: 0.038155\n", 484 | "Train Epoch: 2 [58240/60000 (97%)]\tLoss: 0.165688\n", 485 | "Train Epoch: 2 [58880/60000 (98%)]\tLoss: 0.034593\n", 486 | "Train Epoch: 2 [59520/60000 (99%)]\tLoss: 0.068206\n", 487 | "\n", 488 | "Test set: Average loss: 0.0605, Accuracy: 9828/10000 (98%)\n", 489 | "\n", 490 | "Train Epoch: 3 [0/60000 (0%)]\tLoss: 0.050979\n", 491 | "Train Epoch: 3 [640/60000 (1%)]\tLoss: 0.055319\n", 492 | "Train Epoch: 3 [1280/60000 (2%)]\tLoss: 0.033622\n", 493 | "Train Epoch: 3 [1920/60000 (3%)]\tLoss: 0.055642\n", 494 | "Train Epoch: 3 [2560/60000 
(4%)]\tLoss: 0.027476\n", 495 | "Train Epoch: 3 [3200/60000 (5%)]\tLoss: 0.124311\n", 496 | "Train Epoch: 3 [3840/60000 (6%)]\tLoss: 0.027328\n", 497 | "Train Epoch: 3 [4480/60000 (7%)]\tLoss: 0.153117\n", 498 | "Train Epoch: 3 [5120/60000 (9%)]\tLoss: 0.081922\n", 499 | "Train Epoch: 3 [5760/60000 (10%)]\tLoss: 0.016511\n", 500 | "Train Epoch: 3 [6400/60000 (11%)]\tLoss: 0.097670\n", 501 | "Train Epoch: 3 [7040/60000 (12%)]\tLoss: 0.027996\n", 502 | "Train Epoch: 3 [7680/60000 (13%)]\tLoss: 0.010163\n", 503 | "Train Epoch: 3 [8320/60000 (14%)]\tLoss: 0.035425\n", 504 | "Train Epoch: 3 [8960/60000 (15%)]\tLoss: 0.043074\n", 505 | "Train Epoch: 3 [9600/60000 (16%)]\tLoss: 0.017631\n", 506 | "Train Epoch: 3 [10240/60000 (17%)]\tLoss: 0.051241\n", 507 | "Train Epoch: 3 [10880/60000 (18%)]\tLoss: 0.070071\n", 508 | "Train Epoch: 3 [11520/60000 (19%)]\tLoss: 0.031208\n", 509 | "Train Epoch: 3 [12160/60000 (20%)]\tLoss: 0.036031\n", 510 | "Train Epoch: 3 [12800/60000 (21%)]\tLoss: 0.032816\n", 511 | "Train Epoch: 3 [13440/60000 (22%)]\tLoss: 0.117220\n", 512 | "Train Epoch: 3 [14080/60000 (23%)]\tLoss: 0.031576\n", 513 | "Train Epoch: 3 [14720/60000 (25%)]\tLoss: 0.065685\n", 514 | "Train Epoch: 3 [15360/60000 (26%)]\tLoss: 0.024636\n", 515 | "Train Epoch: 3 [16000/60000 (27%)]\tLoss: 0.080977\n", 516 | "Train Epoch: 3 [16640/60000 (28%)]\tLoss: 0.064801\n", 517 | "Train Epoch: 3 [17280/60000 (29%)]\tLoss: 0.064981\n", 518 | "Train Epoch: 3 [17920/60000 (30%)]\tLoss: 0.047142\n", 519 | "Train Epoch: 3 [18560/60000 (31%)]\tLoss: 0.022806\n", 520 | "Train Epoch: 3 [19200/60000 (32%)]\tLoss: 0.067183\n", 521 | "Train Epoch: 3 [19840/60000 (33%)]\tLoss: 0.066337\n", 522 | "Train Epoch: 3 [20480/60000 (34%)]\tLoss: 0.093048\n", 523 | "Train Epoch: 3 [21120/60000 (35%)]\tLoss: 0.042808\n", 524 | "Train Epoch: 3 [21760/60000 (36%)]\tLoss: 0.032521\n", 525 | "Train Epoch: 3 [22400/60000 (37%)]\tLoss: 0.104981\n", 526 | "Train Epoch: 3 [23040/60000 (38%)]\tLoss: 0.154438\n", 527 | "Train Epoch: 3 [23680/60000 (39%)]\tLoss: 0.034937\n", 528 | "Train Epoch: 3 [24320/60000 (41%)]\tLoss: 0.088588\n", 529 | "Train Epoch: 3 [24960/60000 (42%)]\tLoss: 0.040093\n", 530 | "Train Epoch: 3 [25600/60000 (43%)]\tLoss: 0.034436\n", 531 | "Train Epoch: 3 [26240/60000 (44%)]\tLoss: 0.118453\n", 532 | "Train Epoch: 3 [26880/60000 (45%)]\tLoss: 0.024535\n", 533 | "Train Epoch: 3 [27520/60000 (46%)]\tLoss: 0.028956\n", 534 | "Train Epoch: 3 [28160/60000 (47%)]\tLoss: 0.059100\n", 535 | "Train Epoch: 3 [28800/60000 (48%)]\tLoss: 0.063484\n", 536 | "Train Epoch: 3 [29440/60000 (49%)]\tLoss: 0.007458\n", 537 | "Train Epoch: 3 [30080/60000 (50%)]\tLoss: 0.082076\n", 538 | "Train Epoch: 3 [30720/60000 (51%)]\tLoss: 0.080864\n", 539 | "Train Epoch: 3 [31360/60000 (52%)]\tLoss: 0.091670\n", 540 | "Train Epoch: 3 [32000/60000 (53%)]\tLoss: 0.028935\n", 541 | "Train Epoch: 3 [32640/60000 (54%)]\tLoss: 0.021896\n", 542 | "Train Epoch: 3 [33280/60000 (55%)]\tLoss: 0.026369\n", 543 | "Train Epoch: 3 [33920/60000 (57%)]\tLoss: 0.044139\n", 544 | "Train Epoch: 3 [34560/60000 (58%)]\tLoss: 0.041852\n", 545 | "Train Epoch: 3 [35200/60000 (59%)]\tLoss: 0.075658\n", 546 | "Train Epoch: 3 [35840/60000 (60%)]\tLoss: 0.049239\n", 547 | "Train Epoch: 3 [36480/60000 (61%)]\tLoss: 0.027003\n", 548 | "Train Epoch: 3 [37120/60000 (62%)]\tLoss: 0.020415\n", 549 | "Train Epoch: 3 [37760/60000 (63%)]\tLoss: 0.052795\n", 550 | "Train Epoch: 3 [38400/60000 (64%)]\tLoss: 0.022723\n", 551 | "Train Epoch: 3 [39040/60000 (65%)]\tLoss: 0.027924\n", 552 | 
"Train Epoch: 3 [39680/60000 (66%)]\tLoss: 0.076430\n", 553 | "Train Epoch: 3 [40320/60000 (67%)]\tLoss: 0.084330\n", 554 | "Train Epoch: 3 [40960/60000 (68%)]\tLoss: 0.014167\n", 555 | "Train Epoch: 3 [41600/60000 (69%)]\tLoss: 0.142501\n", 556 | "Train Epoch: 3 [42240/60000 (70%)]\tLoss: 0.085138\n", 557 | "Train Epoch: 3 [42880/60000 (71%)]\tLoss: 0.030728\n", 558 | "Train Epoch: 3 [43520/60000 (72%)]\tLoss: 0.040884\n", 559 | "Train Epoch: 3 [44160/60000 (74%)]\tLoss: 0.037333\n", 560 | "Train Epoch: 3 [44800/60000 (75%)]\tLoss: 0.065329\n", 561 | "Train Epoch: 3 [45440/60000 (76%)]\tLoss: 0.167014\n", 562 | "Train Epoch: 3 [46080/60000 (77%)]\tLoss: 0.037039\n", 563 | "Train Epoch: 3 [46720/60000 (78%)]\tLoss: 0.051752\n", 564 | "Train Epoch: 3 [47360/60000 (79%)]\tLoss: 0.133369\n", 565 | "Train Epoch: 3 [48000/60000 (80%)]\tLoss: 0.120845\n", 566 | "Train Epoch: 3 [48640/60000 (81%)]\tLoss: 0.043990\n", 567 | "Train Epoch: 3 [49280/60000 (82%)]\tLoss: 0.056167\n", 568 | "Train Epoch: 3 [49920/60000 (83%)]\tLoss: 0.134476\n", 569 | "Train Epoch: 3 [50560/60000 (84%)]\tLoss: 0.009573\n", 570 | "Train Epoch: 3 [51200/60000 (85%)]\tLoss: 0.046407\n", 571 | "Train Epoch: 3 [51840/60000 (86%)]\tLoss: 0.061389\n", 572 | "Train Epoch: 3 [52480/60000 (87%)]\tLoss: 0.033624\n", 573 | "Train Epoch: 3 [53120/60000 (88%)]\tLoss: 0.014880\n", 574 | "Train Epoch: 3 [53760/60000 (90%)]\tLoss: 0.028547\n", 575 | "Train Epoch: 3 [54400/60000 (91%)]\tLoss: 0.060354\n", 576 | "Train Epoch: 3 [55040/60000 (92%)]\tLoss: 0.054741\n", 577 | "Train Epoch: 3 [55680/60000 (93%)]\tLoss: 0.018225\n", 578 | "Train Epoch: 3 [56320/60000 (94%)]\tLoss: 0.073061\n", 579 | "Train Epoch: 3 [56960/60000 (95%)]\tLoss: 0.005020\n", 580 | "Train Epoch: 3 [57600/60000 (96%)]\tLoss: 0.015339\n", 581 | "Train Epoch: 3 [58240/60000 (97%)]\tLoss: 0.022689\n", 582 | "Train Epoch: 3 [58880/60000 (98%)]\tLoss: 0.036816\n", 583 | "Train Epoch: 3 [59520/60000 (99%)]\tLoss: 0.025764\n", 584 | "\n", 585 | "Test set: Average loss: 0.0558, Accuracy: 9812/10000 (98%)\n", 586 | "\n", 587 | "Train Epoch: 4 [0/60000 (0%)]\tLoss: 0.020510\n", 588 | "Train Epoch: 4 [640/60000 (1%)]\tLoss: 0.060883\n", 589 | "Train Epoch: 4 [1280/60000 (2%)]\tLoss: 0.049800\n", 590 | "Train Epoch: 4 [1920/60000 (3%)]\tLoss: 0.052840\n", 591 | "Train Epoch: 4 [2560/60000 (4%)]\tLoss: 0.037942\n", 592 | "Train Epoch: 4 [3200/60000 (5%)]\tLoss: 0.066400\n", 593 | "Train Epoch: 4 [3840/60000 (6%)]\tLoss: 0.009949\n", 594 | "Train Epoch: 4 [4480/60000 (7%)]\tLoss: 0.056832\n", 595 | "Train Epoch: 4 [5120/60000 (9%)]\tLoss: 0.029494\n", 596 | "Train Epoch: 4 [5760/60000 (10%)]\tLoss: 0.063136\n", 597 | "Train Epoch: 4 [6400/60000 (11%)]\tLoss: 0.024719\n", 598 | "Train Epoch: 4 [7040/60000 (12%)]\tLoss: 0.046360\n", 599 | "Train Epoch: 4 [7680/60000 (13%)]\tLoss: 0.024510\n", 600 | "Train Epoch: 4 [8320/60000 (14%)]\tLoss: 0.056534\n", 601 | "Train Epoch: 4 [8960/60000 (15%)]\tLoss: 0.045378\n", 602 | "Train Epoch: 4 [9600/60000 (16%)]\tLoss: 0.011533\n", 603 | "Train Epoch: 4 [10240/60000 (17%)]\tLoss: 0.055759\n", 604 | "Train Epoch: 4 [10880/60000 (18%)]\tLoss: 0.085540\n", 605 | "Train Epoch: 4 [11520/60000 (19%)]\tLoss: 0.081261\n", 606 | "Train Epoch: 4 [12160/60000 (20%)]\tLoss: 0.027479\n", 607 | "Train Epoch: 4 [12800/60000 (21%)]\tLoss: 0.075687\n", 608 | "Train Epoch: 4 [13440/60000 (22%)]\tLoss: 0.021930\n", 609 | "Train Epoch: 4 [14080/60000 (23%)]\tLoss: 0.026600\n", 610 | "Train Epoch: 4 [14720/60000 (25%)]\tLoss: 0.012663\n", 611 | "Train Epoch: 4 
[15360/60000 (26%)]\tLoss: 0.031141\n", 612 | "Train Epoch: 4 [16000/60000 (27%)]\tLoss: 0.223421\n", 613 | "Train Epoch: 4 [16640/60000 (28%)]\tLoss: 0.068093\n", 614 | "Train Epoch: 4 [17280/60000 (29%)]\tLoss: 0.033539\n", 615 | "Train Epoch: 4 [17920/60000 (30%)]\tLoss: 0.017097\n", 616 | "Train Epoch: 4 [18560/60000 (31%)]\tLoss: 0.023774\n", 617 | "Train Epoch: 4 [19200/60000 (32%)]\tLoss: 0.024617\n", 618 | "Train Epoch: 4 [19840/60000 (33%)]\tLoss: 0.011677\n", 619 | "Train Epoch: 4 [20480/60000 (34%)]\tLoss: 0.090776\n", 620 | "Train Epoch: 4 [21120/60000 (35%)]\tLoss: 0.012044\n", 621 | "Train Epoch: 4 [21760/60000 (36%)]\tLoss: 0.018728\n", 622 | "Train Epoch: 4 [22400/60000 (37%)]\tLoss: 0.029524\n", 623 | "Train Epoch: 4 [23040/60000 (38%)]\tLoss: 0.042491\n", 624 | "Train Epoch: 4 [23680/60000 (39%)]\tLoss: 0.020629\n", 625 | "Train Epoch: 4 [24320/60000 (41%)]\tLoss: 0.025224\n", 626 | "Train Epoch: 4 [24960/60000 (42%)]\tLoss: 0.005827\n", 627 | "Train Epoch: 4 [25600/60000 (43%)]\tLoss: 0.016418\n", 628 | "Train Epoch: 4 [26240/60000 (44%)]\tLoss: 0.027710\n", 629 | "Train Epoch: 4 [26880/60000 (45%)]\tLoss: 0.039013\n", 630 | "Train Epoch: 4 [27520/60000 (46%)]\tLoss: 0.016448\n", 631 | "Train Epoch: 4 [28160/60000 (47%)]\tLoss: 0.042803\n", 632 | "Train Epoch: 4 [28800/60000 (48%)]\tLoss: 0.026534\n", 633 | "Train Epoch: 4 [29440/60000 (49%)]\tLoss: 0.006225\n", 634 | "Train Epoch: 4 [30080/60000 (50%)]\tLoss: 0.057307\n", 635 | "Train Epoch: 4 [30720/60000 (51%)]\tLoss: 0.008768\n", 636 | "Train Epoch: 4 [31360/60000 (52%)]\tLoss: 0.035020\n", 637 | "Train Epoch: 4 [32000/60000 (53%)]\tLoss: 0.015098\n", 638 | "Train Epoch: 4 [32640/60000 (54%)]\tLoss: 0.011490\n", 639 | "Train Epoch: 4 [33280/60000 (55%)]\tLoss: 0.124877\n", 640 | "Train Epoch: 4 [33920/60000 (57%)]\tLoss: 0.060335\n", 641 | "Train Epoch: 4 [34560/60000 (58%)]\tLoss: 0.087088\n", 642 | "Train Epoch: 4 [35200/60000 (59%)]\tLoss: 0.076318\n", 643 | "Train Epoch: 4 [35840/60000 (60%)]\tLoss: 0.029793\n", 644 | "Train Epoch: 4 [36480/60000 (61%)]\tLoss: 0.129605\n", 645 | "Train Epoch: 4 [37120/60000 (62%)]\tLoss: 0.008943\n", 646 | "Train Epoch: 4 [37760/60000 (63%)]\tLoss: 0.031335\n", 647 | "Train Epoch: 4 [38400/60000 (64%)]\tLoss: 0.010181\n", 648 | "Train Epoch: 4 [39040/60000 (65%)]\tLoss: 0.011609\n", 649 | "Train Epoch: 4 [39680/60000 (66%)]\tLoss: 0.053345\n", 650 | "Train Epoch: 4 [40320/60000 (67%)]\tLoss: 0.151075\n", 651 | "Train Epoch: 4 [40960/60000 (68%)]\tLoss: 0.011086\n", 652 | "Train Epoch: 4 [41600/60000 (69%)]\tLoss: 0.033610\n", 653 | "Train Epoch: 4 [42240/60000 (70%)]\tLoss: 0.006295\n", 654 | "Train Epoch: 4 [42880/60000 (71%)]\tLoss: 0.011702\n", 655 | "Train Epoch: 4 [43520/60000 (72%)]\tLoss: 0.025307\n", 656 | "Train Epoch: 4 [44160/60000 (74%)]\tLoss: 0.014591\n", 657 | "Train Epoch: 4 [44800/60000 (75%)]\tLoss: 0.068241\n", 658 | "Train Epoch: 4 [45440/60000 (76%)]\tLoss: 0.029382\n", 659 | "Train Epoch: 4 [46080/60000 (77%)]\tLoss: 0.041087\n", 660 | "Train Epoch: 4 [46720/60000 (78%)]\tLoss: 0.009940\n", 661 | "Train Epoch: 4 [47360/60000 (79%)]\tLoss: 0.074880\n", 662 | "Train Epoch: 4 [48000/60000 (80%)]\tLoss: 0.024332\n", 663 | "Train Epoch: 4 [48640/60000 (81%)]\tLoss: 0.088959\n", 664 | "Train Epoch: 4 [49280/60000 (82%)]\tLoss: 0.018842\n", 665 | "Train Epoch: 4 [49920/60000 (83%)]\tLoss: 0.009611\n", 666 | "Train Epoch: 4 [50560/60000 (84%)]\tLoss: 0.018992\n", 667 | "Train Epoch: 4 [51200/60000 (85%)]\tLoss: 0.013367\n", 668 | "Train Epoch: 4 [51840/60000 
(86%)]\tLoss: 0.018389\n", 669 | "Train Epoch: 4 [52480/60000 (87%)]\tLoss: 0.029065\n", 670 | "Train Epoch: 4 [53120/60000 (88%)]\tLoss: 0.089412\n", 671 | "Train Epoch: 4 [53760/60000 (90%)]\tLoss: 0.020477\n", 672 | "Train Epoch: 4 [54400/60000 (91%)]\tLoss: 0.021283\n", 673 | "Train Epoch: 4 [55040/60000 (92%)]\tLoss: 0.132218\n", 674 | "Train Epoch: 4 [55680/60000 (93%)]\tLoss: 0.011182\n", 675 | "Train Epoch: 4 [56320/60000 (94%)]\tLoss: 0.043813\n", 676 | "Train Epoch: 4 [56960/60000 (95%)]\tLoss: 0.037521\n", 677 | "Train Epoch: 4 [57600/60000 (96%)]\tLoss: 0.064621\n", 678 | "Train Epoch: 4 [58240/60000 (97%)]\tLoss: 0.077974\n", 679 | "Train Epoch: 4 [58880/60000 (98%)]\tLoss: 0.025912\n", 680 | "Train Epoch: 4 [59520/60000 (99%)]\tLoss: 0.032939\n", 681 | "\n", 682 | "Test set: Average loss: 0.0408, Accuracy: 9867/10000 (99%)\n", 683 | "\n", 684 | "Train Epoch: 5 [0/60000 (0%)]\tLoss: 0.010725\n", 685 | "Train Epoch: 5 [640/60000 (1%)]\tLoss: 0.008210\n", 686 | "Train Epoch: 5 [1280/60000 (2%)]\tLoss: 0.013886\n", 687 | "Train Epoch: 5 [1920/60000 (3%)]\tLoss: 0.012739\n", 688 | "Train Epoch: 5 [2560/60000 (4%)]\tLoss: 0.021050\n", 689 | "Train Epoch: 5 [3200/60000 (5%)]\tLoss: 0.020811\n", 690 | "Train Epoch: 5 [3840/60000 (6%)]\tLoss: 0.006079\n", 691 | "Train Epoch: 5 [4480/60000 (7%)]\tLoss: 0.048679\n", 692 | "Train Epoch: 5 [5120/60000 (9%)]\tLoss: 0.167226\n", 693 | "Train Epoch: 5 [5760/60000 (10%)]\tLoss: 0.001919\n", 694 | "Train Epoch: 5 [6400/60000 (11%)]\tLoss: 0.057660\n", 695 | "Train Epoch: 5 [7040/60000 (12%)]\tLoss: 0.046041\n", 696 | "Train Epoch: 5 [7680/60000 (13%)]\tLoss: 0.034117\n", 697 | "Train Epoch: 5 [8320/60000 (14%)]\tLoss: 0.011748\n", 698 | "Train Epoch: 5 [8960/60000 (15%)]\tLoss: 0.072014\n", 699 | "Train Epoch: 5 [9600/60000 (16%)]\tLoss: 0.021567\n", 700 | "Train Epoch: 5 [10240/60000 (17%)]\tLoss: 0.076995\n", 701 | "Train Epoch: 5 [10880/60000 (18%)]\tLoss: 0.016878\n", 702 | "Train Epoch: 5 [11520/60000 (19%)]\tLoss: 0.010503\n", 703 | "Train Epoch: 5 [12160/60000 (20%)]\tLoss: 0.011180\n", 704 | "Train Epoch: 5 [12800/60000 (21%)]\tLoss: 0.064112\n", 705 | "Train Epoch: 5 [13440/60000 (22%)]\tLoss: 0.071387\n", 706 | "Train Epoch: 5 [14080/60000 (23%)]\tLoss: 0.015708\n", 707 | "Train Epoch: 5 [14720/60000 (25%)]\tLoss: 0.014479\n", 708 | "Train Epoch: 5 [15360/60000 (26%)]\tLoss: 0.064526\n", 709 | "Train Epoch: 5 [16000/60000 (27%)]\tLoss: 0.032539\n", 710 | "Train Epoch: 5 [16640/60000 (28%)]\tLoss: 0.009585\n", 711 | "Train Epoch: 5 [17280/60000 (29%)]\tLoss: 0.042810\n", 712 | "Train Epoch: 5 [17920/60000 (30%)]\tLoss: 0.037855\n", 713 | "Train Epoch: 5 [18560/60000 (31%)]\tLoss: 0.006764\n", 714 | "Train Epoch: 5 [19200/60000 (32%)]\tLoss: 0.053102\n", 715 | "Train Epoch: 5 [19840/60000 (33%)]\tLoss: 0.032614\n", 716 | "Train Epoch: 5 [20480/60000 (34%)]\tLoss: 0.067056\n", 717 | "Train Epoch: 5 [21120/60000 (35%)]\tLoss: 0.005527\n", 718 | "Train Epoch: 5 [21760/60000 (36%)]\tLoss: 0.026380\n", 719 | "Train Epoch: 5 [22400/60000 (37%)]\tLoss: 0.018374\n", 720 | "Train Epoch: 5 [23040/60000 (38%)]\tLoss: 0.060447\n", 721 | "Train Epoch: 5 [23680/60000 (39%)]\tLoss: 0.008874\n", 722 | "Train Epoch: 5 [24320/60000 (41%)]\tLoss: 0.019862\n", 723 | "Train Epoch: 5 [24960/60000 (42%)]\tLoss: 0.026592\n", 724 | "Train Epoch: 5 [25600/60000 (43%)]\tLoss: 0.024289\n", 725 | "Train Epoch: 5 [26240/60000 (44%)]\tLoss: 0.110426\n", 726 | "Train Epoch: 5 [26880/60000 (45%)]\tLoss: 0.005285\n", 727 | "Train Epoch: 5 [27520/60000 (46%)]\tLoss: 
0.058936\n", 728 | "Train Epoch: 5 [28160/60000 (47%)]\tLoss: 0.086281\n", 729 | "Train Epoch: 5 [28800/60000 (48%)]\tLoss: 0.004462\n", 730 | "Train Epoch: 5 [29440/60000 (49%)]\tLoss: 0.169852\n", 731 | "Train Epoch: 5 [30080/60000 (50%)]\tLoss: 0.069522\n", 732 | "Train Epoch: 5 [30720/60000 (51%)]\tLoss: 0.014954\n", 733 | "Train Epoch: 5 [31360/60000 (52%)]\tLoss: 0.026138\n", 734 | "Train Epoch: 5 [32000/60000 (53%)]\tLoss: 0.041306\n", 735 | "Train Epoch: 5 [32640/60000 (54%)]\tLoss: 0.010898\n", 736 | "Train Epoch: 5 [33280/60000 (55%)]\tLoss: 0.029110\n", 737 | "Train Epoch: 5 [33920/60000 (57%)]\tLoss: 0.105546\n", 738 | "Train Epoch: 5 [34560/60000 (58%)]\tLoss: 0.023982\n", 739 | "Train Epoch: 5 [35200/60000 (59%)]\tLoss: 0.027237\n", 740 | "Train Epoch: 5 [35840/60000 (60%)]\tLoss: 0.017833\n", 741 | "Train Epoch: 5 [36480/60000 (61%)]\tLoss: 0.045222\n", 742 | "Train Epoch: 5 [37120/60000 (62%)]\tLoss: 0.020116\n", 743 | "Train Epoch: 5 [37760/60000 (63%)]\tLoss: 0.177658\n", 744 | "Train Epoch: 5 [38400/60000 (64%)]\tLoss: 0.055709\n", 745 | "Train Epoch: 5 [39040/60000 (65%)]\tLoss: 0.026385\n", 746 | "Train Epoch: 5 [39680/60000 (66%)]\tLoss: 0.023140\n", 747 | "Train Epoch: 5 [40320/60000 (67%)]\tLoss: 0.064874\n", 748 | "Train Epoch: 5 [40960/60000 (68%)]\tLoss: 0.023232\n", 749 | "Train Epoch: 5 [41600/60000 (69%)]\tLoss: 0.027616\n", 750 | "Train Epoch: 5 [42240/60000 (70%)]\tLoss: 0.008820\n", 751 | "Train Epoch: 5 [42880/60000 (71%)]\tLoss: 0.008302\n", 752 | "Train Epoch: 5 [43520/60000 (72%)]\tLoss: 0.015962\n", 753 | "Train Epoch: 5 [44160/60000 (74%)]\tLoss: 0.004702\n", 754 | "Train Epoch: 5 [44800/60000 (75%)]\tLoss: 0.177704\n", 755 | "Train Epoch: 5 [45440/60000 (76%)]\tLoss: 0.016430\n", 756 | "Train Epoch: 5 [46080/60000 (77%)]\tLoss: 0.023789\n", 757 | "Train Epoch: 5 [46720/60000 (78%)]\tLoss: 0.019787\n", 758 | "Train Epoch: 5 [47360/60000 (79%)]\tLoss: 0.031269\n", 759 | "Train Epoch: 5 [48000/60000 (80%)]\tLoss: 0.032767\n", 760 | "Train Epoch: 5 [48640/60000 (81%)]\tLoss: 0.022449\n", 761 | "Train Epoch: 5 [49280/60000 (82%)]\tLoss: 0.017720\n", 762 | "Train Epoch: 5 [49920/60000 (83%)]\tLoss: 0.021786\n", 763 | "Train Epoch: 5 [50560/60000 (84%)]\tLoss: 0.011756\n", 764 | "Train Epoch: 5 [51200/60000 (85%)]\tLoss: 0.072087\n", 765 | "Train Epoch: 5 [51840/60000 (86%)]\tLoss: 0.104906\n", 766 | "Train Epoch: 5 [52480/60000 (87%)]\tLoss: 0.196716\n", 767 | "Train Epoch: 5 [53120/60000 (88%)]\tLoss: 0.011484\n", 768 | "Train Epoch: 5 [53760/60000 (90%)]\tLoss: 0.002492\n", 769 | "Train Epoch: 5 [54400/60000 (91%)]\tLoss: 0.026741\n", 770 | "Train Epoch: 5 [55040/60000 (92%)]\tLoss: 0.028591\n", 771 | "Train Epoch: 5 [55680/60000 (93%)]\tLoss: 0.015796\n", 772 | "Train Epoch: 5 [56320/60000 (94%)]\tLoss: 0.013676\n", 773 | "Train Epoch: 5 [56960/60000 (95%)]\tLoss: 0.016341\n", 774 | "Train Epoch: 5 [57600/60000 (96%)]\tLoss: 0.030674\n", 775 | "Train Epoch: 5 [58240/60000 (97%)]\tLoss: 0.015367\n", 776 | "Train Epoch: 5 [58880/60000 (98%)]\tLoss: 0.017187\n", 777 | "Train Epoch: 5 [59520/60000 (99%)]\tLoss: 0.015880\n", 778 | "\n", 779 | "Test set: Average loss: 0.0382, Accuracy: 9866/10000 (99%)\n", 780 | "\n", 781 | "Train Epoch: 6 [0/60000 (0%)]\tLoss: 0.129179\n", 782 | "Train Epoch: 6 [640/60000 (1%)]\tLoss: 0.052901\n", 783 | "Train Epoch: 6 [1280/60000 (2%)]\tLoss: 0.105881\n", 784 | "Train Epoch: 6 [1920/60000 (3%)]\tLoss: 0.035592\n", 785 | "Train Epoch: 6 [2560/60000 (4%)]\tLoss: 0.006246\n", 786 | "Train Epoch: 6 [3200/60000 (5%)]\tLoss: 
0.047103\n", 787 | "Train Epoch: 6 [3840/60000 (6%)]\tLoss: 0.007077\n", 788 | "Train Epoch: 6 [4480/60000 (7%)]\tLoss: 0.024196\n", 789 | "Train Epoch: 6 [5120/60000 (9%)]\tLoss: 0.039204\n", 790 | "Train Epoch: 6 [5760/60000 (10%)]\tLoss: 0.006895\n", 791 | "Train Epoch: 6 [6400/60000 (11%)]\tLoss: 0.023380\n", 792 | "Train Epoch: 6 [7040/60000 (12%)]\tLoss: 0.014487\n", 793 | "Train Epoch: 6 [7680/60000 (13%)]\tLoss: 0.040916\n", 794 | "Train Epoch: 6 [8320/60000 (14%)]\tLoss: 0.011987\n", 795 | "Train Epoch: 6 [8960/60000 (15%)]\tLoss: 0.054450\n", 796 | "Train Epoch: 6 [9600/60000 (16%)]\tLoss: 0.022615\n", 797 | "Train Epoch: 6 [10240/60000 (17%)]\tLoss: 0.004019\n", 798 | "Train Epoch: 6 [10880/60000 (18%)]\tLoss: 0.060761\n", 799 | "Train Epoch: 6 [11520/60000 (19%)]\tLoss: 0.034817\n", 800 | "Train Epoch: 6 [12160/60000 (20%)]\tLoss: 0.015127\n", 801 | "Train Epoch: 6 [12800/60000 (21%)]\tLoss: 0.035487\n", 802 | "Train Epoch: 6 [13440/60000 (22%)]\tLoss: 0.015562\n", 803 | "Train Epoch: 6 [14080/60000 (23%)]\tLoss: 0.027827\n", 804 | "Train Epoch: 6 [14720/60000 (25%)]\tLoss: 0.006514\n", 805 | "Train Epoch: 6 [15360/60000 (26%)]\tLoss: 0.011376\n", 806 | "Train Epoch: 6 [16000/60000 (27%)]\tLoss: 0.016648\n", 807 | "Train Epoch: 6 [16640/60000 (28%)]\tLoss: 0.012704\n", 808 | "Train Epoch: 6 [17280/60000 (29%)]\tLoss: 0.012249\n", 809 | "Train Epoch: 6 [17920/60000 (30%)]\tLoss: 0.057439\n", 810 | "Train Epoch: 6 [18560/60000 (31%)]\tLoss: 0.023276\n", 811 | "Train Epoch: 6 [19200/60000 (32%)]\tLoss: 0.017810\n", 812 | "Train Epoch: 6 [19840/60000 (33%)]\tLoss: 0.032417\n", 813 | "Train Epoch: 6 [20480/60000 (34%)]\tLoss: 0.038640\n", 814 | "Train Epoch: 6 [21120/60000 (35%)]\tLoss: 0.002132\n", 815 | "Train Epoch: 6 [21760/60000 (36%)]\tLoss: 0.061984\n", 816 | "Train Epoch: 6 [22400/60000 (37%)]\tLoss: 0.029548\n", 817 | "Train Epoch: 6 [23040/60000 (38%)]\tLoss: 0.003669\n", 818 | "Train Epoch: 6 [23680/60000 (39%)]\tLoss: 0.018836\n", 819 | "Train Epoch: 6 [24320/60000 (41%)]\tLoss: 0.009828\n", 820 | "Train Epoch: 6 [24960/60000 (42%)]\tLoss: 0.023059\n", 821 | "Train Epoch: 6 [25600/60000 (43%)]\tLoss: 0.065207\n", 822 | "Train Epoch: 6 [26240/60000 (44%)]\tLoss: 0.048029\n", 823 | "Train Epoch: 6 [26880/60000 (45%)]\tLoss: 0.024769\n", 824 | "Train Epoch: 6 [27520/60000 (46%)]\tLoss: 0.006746\n", 825 | "Train Epoch: 6 [28160/60000 (47%)]\tLoss: 0.003532\n", 826 | "Train Epoch: 6 [28800/60000 (48%)]\tLoss: 0.007101\n", 827 | "Train Epoch: 6 [29440/60000 (49%)]\tLoss: 0.011812\n", 828 | "Train Epoch: 6 [30080/60000 (50%)]\tLoss: 0.010295\n", 829 | "Train Epoch: 6 [30720/60000 (51%)]\tLoss: 0.022183\n", 830 | "Train Epoch: 6 [31360/60000 (52%)]\tLoss: 0.030085\n", 831 | "Train Epoch: 6 [32000/60000 (53%)]\tLoss: 0.013581\n", 832 | "Train Epoch: 6 [32640/60000 (54%)]\tLoss: 0.039882\n", 833 | "Train Epoch: 6 [33280/60000 (55%)]\tLoss: 0.007422\n", 834 | "Train Epoch: 6 [33920/60000 (57%)]\tLoss: 0.013132\n", 835 | "Train Epoch: 6 [34560/60000 (58%)]\tLoss: 0.089966\n", 836 | "Train Epoch: 6 [35200/60000 (59%)]\tLoss: 0.059206\n", 837 | "Train Epoch: 6 [35840/60000 (60%)]\tLoss: 0.004116\n", 838 | "Train Epoch: 6 [36480/60000 (61%)]\tLoss: 0.012886\n", 839 | "Train Epoch: 6 [37120/60000 (62%)]\tLoss: 0.007551\n", 840 | "Train Epoch: 6 [37760/60000 (63%)]\tLoss: 0.031933\n", 841 | "Train Epoch: 6 [38400/60000 (64%)]\tLoss: 0.043634\n", 842 | "Train Epoch: 6 [39040/60000 (65%)]\tLoss: 0.004200\n", 843 | "Train Epoch: 6 [39680/60000 (66%)]\tLoss: 0.071981\n", 844 | "Train Epoch: 
6 [40320/60000 (67%)]\tLoss: 0.008260\n", 845 | "Train Epoch: 6 [40960/60000 (68%)]\tLoss: 0.080835\n", 846 | "Train Epoch: 6 [41600/60000 (69%)]\tLoss: 0.010787\n", 847 | "Train Epoch: 6 [42240/60000 (70%)]\tLoss: 0.002323\n", 848 | "Train Epoch: 6 [42880/60000 (71%)]\tLoss: 0.009086\n", 849 | "Train Epoch: 6 [43520/60000 (72%)]\tLoss: 0.025457\n", 850 | "Train Epoch: 6 [44160/60000 (74%)]\tLoss: 0.054194\n", 851 | "Train Epoch: 6 [44800/60000 (75%)]\tLoss: 0.007660\n", 852 | "Train Epoch: 6 [45440/60000 (76%)]\tLoss: 0.082376\n", 853 | "Train Epoch: 6 [46080/60000 (77%)]\tLoss: 0.051475\n", 854 | "Train Epoch: 6 [46720/60000 (78%)]\tLoss: 0.015938\n", 855 | "Train Epoch: 6 [47360/60000 (79%)]\tLoss: 0.032405\n", 856 | "Train Epoch: 6 [48000/60000 (80%)]\tLoss: 0.027912\n", 857 | "Train Epoch: 6 [48640/60000 (81%)]\tLoss: 0.054042\n", 858 | "Train Epoch: 6 [49280/60000 (82%)]\tLoss: 0.012704\n", 859 | "Train Epoch: 6 [49920/60000 (83%)]\tLoss: 0.051956\n", 860 | "Train Epoch: 6 [50560/60000 (84%)]\tLoss: 0.014005\n", 861 | "Train Epoch: 6 [51200/60000 (85%)]\tLoss: 0.033439\n", 862 | "Train Epoch: 6 [51840/60000 (86%)]\tLoss: 0.006111\n", 863 | "Train Epoch: 6 [52480/60000 (87%)]\tLoss: 0.017442\n", 864 | "Train Epoch: 6 [53120/60000 (88%)]\tLoss: 0.021951\n", 865 | "Train Epoch: 6 [53760/60000 (90%)]\tLoss: 0.003026\n", 866 | "Train Epoch: 6 [54400/60000 (91%)]\tLoss: 0.027849\n", 867 | "Train Epoch: 6 [55040/60000 (92%)]\tLoss: 0.023539\n", 868 | "Train Epoch: 6 [55680/60000 (93%)]\tLoss: 0.019511\n", 869 | "Train Epoch: 6 [56320/60000 (94%)]\tLoss: 0.016041\n", 870 | "Train Epoch: 6 [56960/60000 (95%)]\tLoss: 0.023572\n", 871 | "Train Epoch: 6 [57600/60000 (96%)]\tLoss: 0.006037\n", 872 | "Train Epoch: 6 [58240/60000 (97%)]\tLoss: 0.029409\n", 873 | "Train Epoch: 6 [58880/60000 (98%)]\tLoss: 0.021295\n", 874 | "Train Epoch: 6 [59520/60000 (99%)]\tLoss: 0.002066\n", 875 | "\n", 876 | "Test set: Average loss: 0.0338, Accuracy: 9893/10000 (99%)\n", 877 | "\n", 878 | "Train Epoch: 7 [0/60000 (0%)]\tLoss: 0.030842\n", 879 | "Train Epoch: 7 [640/60000 (1%)]\tLoss: 0.018875\n", 880 | "Train Epoch: 7 [1280/60000 (2%)]\tLoss: 0.021560\n", 881 | "Train Epoch: 7 [1920/60000 (3%)]\tLoss: 0.036105\n", 882 | "Train Epoch: 7 [2560/60000 (4%)]\tLoss: 0.005390\n", 883 | "Train Epoch: 7 [3200/60000 (5%)]\tLoss: 0.006128\n", 884 | "Train Epoch: 7 [3840/60000 (6%)]\tLoss: 0.004727\n", 885 | "Train Epoch: 7 [4480/60000 (7%)]\tLoss: 0.002214\n", 886 | "Train Epoch: 7 [5120/60000 (9%)]\tLoss: 0.103841\n", 887 | "Train Epoch: 7 [5760/60000 (10%)]\tLoss: 0.063472\n", 888 | "Train Epoch: 7 [6400/60000 (11%)]\tLoss: 0.006702\n", 889 | "Train Epoch: 7 [7040/60000 (12%)]\tLoss: 0.012966\n", 890 | "Train Epoch: 7 [7680/60000 (13%)]\tLoss: 0.011896\n", 891 | "Train Epoch: 7 [8320/60000 (14%)]\tLoss: 0.004966\n", 892 | "Train Epoch: 7 [8960/60000 (15%)]\tLoss: 0.020521\n", 893 | "Train Epoch: 7 [9600/60000 (16%)]\tLoss: 0.053214\n", 894 | "Train Epoch: 7 [10240/60000 (17%)]\tLoss: 0.019352\n", 895 | "Train Epoch: 7 [10880/60000 (18%)]\tLoss: 0.002213\n", 896 | "Train Epoch: 7 [11520/60000 (19%)]\tLoss: 0.010669\n", 897 | "Train Epoch: 7 [12160/60000 (20%)]\tLoss: 0.006014\n", 898 | "Train Epoch: 7 [12800/60000 (21%)]\tLoss: 0.079333\n", 899 | "Train Epoch: 7 [13440/60000 (22%)]\tLoss: 0.006914\n", 900 | "Train Epoch: 7 [14080/60000 (23%)]\tLoss: 0.049042\n", 901 | "Train Epoch: 7 [14720/60000 (25%)]\tLoss: 0.014304\n", 902 | "Train Epoch: 7 [15360/60000 (26%)]\tLoss: 0.006564\n", 903 | "Train Epoch: 7 [16000/60000 
(27%)]\tLoss: 0.037733\n", 904 | "Train Epoch: 7 [16640/60000 (28%)]\tLoss: 0.023643\n", 905 | "Train Epoch: 7 [17280/60000 (29%)]\tLoss: 0.038910\n", 906 | "Train Epoch: 7 [17920/60000 (30%)]\tLoss: 0.144011\n", 907 | "Train Epoch: 7 [18560/60000 (31%)]\tLoss: 0.035657\n", 908 | "Train Epoch: 7 [19200/60000 (32%)]\tLoss: 0.026567\n", 909 | "Train Epoch: 7 [19840/60000 (33%)]\tLoss: 0.034334\n", 910 | "Train Epoch: 7 [20480/60000 (34%)]\tLoss: 0.018104\n", 911 | "Train Epoch: 7 [21120/60000 (35%)]\tLoss: 0.044123\n", 912 | "Train Epoch: 7 [21760/60000 (36%)]\tLoss: 0.009869\n", 913 | "Train Epoch: 7 [22400/60000 (37%)]\tLoss: 0.026679\n", 914 | "Train Epoch: 7 [23040/60000 (38%)]\tLoss: 0.011242\n", 915 | "Train Epoch: 7 [23680/60000 (39%)]\tLoss: 0.010542\n", 916 | "Train Epoch: 7 [24320/60000 (41%)]\tLoss: 0.022665\n", 917 | "Train Epoch: 7 [24960/60000 (42%)]\tLoss: 0.004571\n", 918 | "Train Epoch: 7 [25600/60000 (43%)]\tLoss: 0.005623\n", 919 | "Train Epoch: 7 [26240/60000 (44%)]\tLoss: 0.010772\n", 920 | "Train Epoch: 7 [26880/60000 (45%)]\tLoss: 0.020150\n", 921 | "Train Epoch: 7 [27520/60000 (46%)]\tLoss: 0.062882\n", 922 | "Train Epoch: 7 [28160/60000 (47%)]\tLoss: 0.094671\n", 923 | "Train Epoch: 7 [28800/60000 (48%)]\tLoss: 0.022337\n", 924 | "Train Epoch: 7 [29440/60000 (49%)]\tLoss: 0.008545\n", 925 | "Train Epoch: 7 [30080/60000 (50%)]\tLoss: 0.013422\n", 926 | "Train Epoch: 7 [30720/60000 (51%)]\tLoss: 0.038039\n", 927 | "Train Epoch: 7 [31360/60000 (52%)]\tLoss: 0.002636\n", 928 | "Train Epoch: 7 [32000/60000 (53%)]\tLoss: 0.014368\n", 929 | "Train Epoch: 7 [32640/60000 (54%)]\tLoss: 0.122236\n", 930 | "Train Epoch: 7 [33280/60000 (55%)]\tLoss: 0.040202\n", 931 | "Train Epoch: 7 [33920/60000 (57%)]\tLoss: 0.003397\n", 932 | "Train Epoch: 7 [34560/60000 (58%)]\tLoss: 0.006169\n", 933 | "Train Epoch: 7 [35200/60000 (59%)]\tLoss: 0.022380\n", 934 | "Train Epoch: 7 [35840/60000 (60%)]\tLoss: 0.010948\n", 935 | "Train Epoch: 7 [36480/60000 (61%)]\tLoss: 0.008686\n", 936 | "Train Epoch: 7 [37120/60000 (62%)]\tLoss: 0.096220\n", 937 | "Train Epoch: 7 [37760/60000 (63%)]\tLoss: 0.018495\n", 938 | "Train Epoch: 7 [38400/60000 (64%)]\tLoss: 0.032429\n", 939 | "Train Epoch: 7 [39040/60000 (65%)]\tLoss: 0.162318\n", 940 | "Train Epoch: 7 [39680/60000 (66%)]\tLoss: 0.027695\n", 941 | "Train Epoch: 7 [40320/60000 (67%)]\tLoss: 0.014180\n", 942 | "Train Epoch: 7 [40960/60000 (68%)]\tLoss: 0.004929\n", 943 | "Train Epoch: 7 [41600/60000 (69%)]\tLoss: 0.047359\n", 944 | "Train Epoch: 7 [42240/60000 (70%)]\tLoss: 0.015176\n", 945 | "Train Epoch: 7 [42880/60000 (71%)]\tLoss: 0.117097\n", 946 | "Train Epoch: 7 [43520/60000 (72%)]\tLoss: 0.024032\n", 947 | "Train Epoch: 7 [44160/60000 (74%)]\tLoss: 0.053941\n", 948 | "Train Epoch: 7 [44800/60000 (75%)]\tLoss: 0.005273\n", 949 | "Train Epoch: 7 [45440/60000 (76%)]\tLoss: 0.003344\n", 950 | "Train Epoch: 7 [46080/60000 (77%)]\tLoss: 0.001879\n", 951 | "Train Epoch: 7 [46720/60000 (78%)]\tLoss: 0.029248\n", 952 | "Train Epoch: 7 [47360/60000 (79%)]\tLoss: 0.093488\n", 953 | "Train Epoch: 7 [48000/60000 (80%)]\tLoss: 0.025508\n", 954 | "Train Epoch: 7 [48640/60000 (81%)]\tLoss: 0.053003\n", 955 | "Train Epoch: 7 [49280/60000 (82%)]\tLoss: 0.040633\n", 956 | "Train Epoch: 7 [49920/60000 (83%)]\tLoss: 0.041014\n", 957 | "Train Epoch: 7 [50560/60000 (84%)]\tLoss: 0.012292\n", 958 | "Train Epoch: 7 [51200/60000 (85%)]\tLoss: 0.075577\n", 959 | "Train Epoch: 7 [51840/60000 (86%)]\tLoss: 0.033290\n", 960 | "Train Epoch: 7 [52480/60000 (87%)]\tLoss: 
0.003156\n", 961 | "Train Epoch: 7 [53120/60000 (88%)]\tLoss: 0.103615\n", 962 | "Train Epoch: 7 [53760/60000 (90%)]\tLoss: 0.011532\n", 963 | "Train Epoch: 7 [54400/60000 (91%)]\tLoss: 0.004137\n", 964 | "Train Epoch: 7 [55040/60000 (92%)]\tLoss: 0.070434\n", 965 | "Train Epoch: 7 [55680/60000 (93%)]\tLoss: 0.034054\n", 966 | "Train Epoch: 7 [56320/60000 (94%)]\tLoss: 0.045853\n", 967 | "Train Epoch: 7 [56960/60000 (95%)]\tLoss: 0.003249\n", 968 | "Train Epoch: 7 [57600/60000 (96%)]\tLoss: 0.010081\n", 969 | "Train Epoch: 7 [58240/60000 (97%)]\tLoss: 0.049628\n", 970 | "Train Epoch: 7 [58880/60000 (98%)]\tLoss: 0.012825\n", 971 | "Train Epoch: 7 [59520/60000 (99%)]\tLoss: 0.034033\n", 972 | "\n", 973 | "Test set: Average loss: 0.0346, Accuracy: 9870/10000 (99%)\n", 974 | "\n", 975 | "Train Epoch: 8 [0/60000 (0%)]\tLoss: 0.004954\n", 976 | "Train Epoch: 8 [640/60000 (1%)]\tLoss: 0.005838\n", 977 | "Train Epoch: 8 [1280/60000 (2%)]\tLoss: 0.008048\n", 978 | "Train Epoch: 8 [1920/60000 (3%)]\tLoss: 0.012437\n", 979 | "Train Epoch: 8 [2560/60000 (4%)]\tLoss: 0.055547\n", 980 | "Train Epoch: 8 [3200/60000 (5%)]\tLoss: 0.031063\n", 981 | "Train Epoch: 8 [3840/60000 (6%)]\tLoss: 0.024472\n", 982 | "Train Epoch: 8 [4480/60000 (7%)]\tLoss: 0.005790\n", 983 | "Train Epoch: 8 [5120/60000 (9%)]\tLoss: 0.002776\n", 984 | "Train Epoch: 8 [5760/60000 (10%)]\tLoss: 0.009141\n", 985 | "Train Epoch: 8 [6400/60000 (11%)]\tLoss: 0.032461\n", 986 | "Train Epoch: 8 [7040/60000 (12%)]\tLoss: 0.003172\n", 987 | "Train Epoch: 8 [7680/60000 (13%)]\tLoss: 0.009444\n", 988 | "Train Epoch: 8 [8320/60000 (14%)]\tLoss: 0.001078\n", 989 | "Train Epoch: 8 [8960/60000 (15%)]\tLoss: 0.002386\n", 990 | "Train Epoch: 8 [9600/60000 (16%)]\tLoss: 0.061214\n", 991 | "Train Epoch: 8 [10240/60000 (17%)]\tLoss: 0.001249\n", 992 | "Train Epoch: 8 [10880/60000 (18%)]\tLoss: 0.006502\n", 993 | "Train Epoch: 8 [11520/60000 (19%)]\tLoss: 0.036804\n", 994 | "Train Epoch: 8 [12160/60000 (20%)]\tLoss: 0.006823\n", 995 | "Train Epoch: 8 [12800/60000 (21%)]\tLoss: 0.048486\n", 996 | "Train Epoch: 8 [13440/60000 (22%)]\tLoss: 0.024336\n", 997 | "Train Epoch: 8 [14080/60000 (23%)]\tLoss: 0.004238\n", 998 | "Train Epoch: 8 [14720/60000 (25%)]\tLoss: 0.055290\n", 999 | "Train Epoch: 8 [15360/60000 (26%)]\tLoss: 0.008394\n", 1000 | "Train Epoch: 8 [16000/60000 (27%)]\tLoss: 0.001927\n", 1001 | "Train Epoch: 8 [16640/60000 (28%)]\tLoss: 0.006265\n", 1002 | "Train Epoch: 8 [17280/60000 (29%)]\tLoss: 0.018604\n", 1003 | "Train Epoch: 8 [17920/60000 (30%)]\tLoss: 0.004004\n", 1004 | "Train Epoch: 8 [18560/60000 (31%)]\tLoss: 0.011987\n", 1005 | "Train Epoch: 8 [19200/60000 (32%)]\tLoss: 0.016144\n", 1006 | "Train Epoch: 8 [19840/60000 (33%)]\tLoss: 0.047296\n", 1007 | "Train Epoch: 8 [20480/60000 (34%)]\tLoss: 0.016332\n", 1008 | "Train Epoch: 8 [21120/60000 (35%)]\tLoss: 0.003445\n", 1009 | "Train Epoch: 8 [21760/60000 (36%)]\tLoss: 0.003513\n", 1010 | "Train Epoch: 8 [22400/60000 (37%)]\tLoss: 0.076792\n", 1011 | "Train Epoch: 8 [23040/60000 (38%)]\tLoss: 0.004755\n", 1012 | "Train Epoch: 8 [23680/60000 (39%)]\tLoss: 0.012643\n", 1013 | "Train Epoch: 8 [24320/60000 (41%)]\tLoss: 0.021600\n", 1014 | "Train Epoch: 8 [24960/60000 (42%)]\tLoss: 0.037292\n", 1015 | "Train Epoch: 8 [25600/60000 (43%)]\tLoss: 0.019060\n", 1016 | "Train Epoch: 8 [26240/60000 (44%)]\tLoss: 0.000584\n", 1017 | "Train Epoch: 8 [26880/60000 (45%)]\tLoss: 0.029867\n", 1018 | "Train Epoch: 8 [27520/60000 (46%)]\tLoss: 0.036785\n", 1019 | "Train Epoch: 8 [28160/60000 
(47%)]\tLoss: 0.003309\n", 1020 | "Train Epoch: 8 [28800/60000 (48%)]\tLoss: 0.040396\n", 1021 | "Train Epoch: 8 [29440/60000 (49%)]\tLoss: 0.005590\n", 1022 | "Train Epoch: 8 [30080/60000 (50%)]\tLoss: 0.004038\n", 1023 | "Train Epoch: 8 [30720/60000 (51%)]\tLoss: 0.006873\n", 1024 | "Train Epoch: 8 [31360/60000 (52%)]\tLoss: 0.007503\n", 1025 | "Train Epoch: 8 [32000/60000 (53%)]\tLoss: 0.014477\n", 1026 | "Train Epoch: 8 [32640/60000 (54%)]\tLoss: 0.055405\n", 1027 | "Train Epoch: 8 [33280/60000 (55%)]\tLoss: 0.030415\n", 1028 | "Train Epoch: 8 [33920/60000 (57%)]\tLoss: 0.012079\n", 1029 | "Train Epoch: 8 [34560/60000 (58%)]\tLoss: 0.007796\n", 1030 | "Train Epoch: 8 [35200/60000 (59%)]\tLoss: 0.025386\n", 1031 | "Train Epoch: 8 [35840/60000 (60%)]\tLoss: 0.002304\n", 1032 | "Train Epoch: 8 [36480/60000 (61%)]\tLoss: 0.021220\n", 1033 | "Train Epoch: 8 [37120/60000 (62%)]\tLoss: 0.002316\n", 1034 | "Train Epoch: 8 [37760/60000 (63%)]\tLoss: 0.013260\n", 1035 | "Train Epoch: 8 [38400/60000 (64%)]\tLoss: 0.088766\n", 1036 | "Train Epoch: 8 [39040/60000 (65%)]\tLoss: 0.277053\n", 1037 | "Train Epoch: 8 [39680/60000 (66%)]\tLoss: 0.016274\n", 1038 | "Train Epoch: 8 [40320/60000 (67%)]\tLoss: 0.004528\n", 1039 | "Train Epoch: 8 [40960/60000 (68%)]\tLoss: 0.001379\n", 1040 | "Train Epoch: 8 [41600/60000 (69%)]\tLoss: 0.003366\n", 1041 | "Train Epoch: 8 [42240/60000 (70%)]\tLoss: 0.002621\n", 1042 | "Train Epoch: 8 [42880/60000 (71%)]\tLoss: 0.089309\n", 1043 | "Train Epoch: 8 [43520/60000 (72%)]\tLoss: 0.010128\n", 1044 | "Train Epoch: 8 [44160/60000 (74%)]\tLoss: 0.177295\n", 1045 | "Train Epoch: 8 [44800/60000 (75%)]\tLoss: 0.007664\n", 1046 | "Train Epoch: 8 [45440/60000 (76%)]\tLoss: 0.014377\n", 1047 | "Train Epoch: 8 [46080/60000 (77%)]\tLoss: 0.013430\n", 1048 | "Train Epoch: 8 [46720/60000 (78%)]\tLoss: 0.015274\n", 1049 | "Train Epoch: 8 [47360/60000 (79%)]\tLoss: 0.003127\n", 1050 | "Train Epoch: 8 [48000/60000 (80%)]\tLoss: 0.047396\n", 1051 | "Train Epoch: 8 [48640/60000 (81%)]\tLoss: 0.005713\n", 1052 | "Train Epoch: 8 [49280/60000 (82%)]\tLoss: 0.008856\n", 1053 | "Train Epoch: 8 [49920/60000 (83%)]\tLoss: 0.022905\n", 1054 | "Train Epoch: 8 [50560/60000 (84%)]\tLoss: 0.032476\n", 1055 | "Train Epoch: 8 [51200/60000 (85%)]\tLoss: 0.071746\n", 1056 | "Train Epoch: 8 [51840/60000 (86%)]\tLoss: 0.012006\n", 1057 | "Train Epoch: 8 [52480/60000 (87%)]\tLoss: 0.022184\n", 1058 | "Train Epoch: 8 [53120/60000 (88%)]\tLoss: 0.004263\n", 1059 | "Train Epoch: 8 [53760/60000 (90%)]\tLoss: 0.013653\n", 1060 | "Train Epoch: 8 [54400/60000 (91%)]\tLoss: 0.014716\n", 1061 | "Train Epoch: 8 [55040/60000 (92%)]\tLoss: 0.013155\n", 1062 | "Train Epoch: 8 [55680/60000 (93%)]\tLoss: 0.007262\n", 1063 | "Train Epoch: 8 [56320/60000 (94%)]\tLoss: 0.005386\n", 1064 | "Train Epoch: 8 [56960/60000 (95%)]\tLoss: 0.035712\n", 1065 | "Train Epoch: 8 [57600/60000 (96%)]\tLoss: 0.009545\n", 1066 | "Train Epoch: 8 [58240/60000 (97%)]\tLoss: 0.040252\n", 1067 | "Train Epoch: 8 [58880/60000 (98%)]\tLoss: 0.002436\n", 1068 | "Train Epoch: 8 [59520/60000 (99%)]\tLoss: 0.121468\n", 1069 | "\n", 1070 | "Test set: Average loss: 0.0383, Accuracy: 9876/10000 (99%)\n", 1071 | "\n", 1072 | "Train Epoch: 9 [0/60000 (0%)]\tLoss: 0.088630\n", 1073 | "Train Epoch: 9 [640/60000 (1%)]\tLoss: 0.034617\n", 1074 | "Train Epoch: 9 [1280/60000 (2%)]\tLoss: 0.008059\n", 1075 | "Train Epoch: 9 [1920/60000 (3%)]\tLoss: 0.023889\n", 1076 | "Train Epoch: 9 [2560/60000 (4%)]\tLoss: 0.002230\n", 1077 | "Train Epoch: 9 [3200/60000 
(5%)]\tLoss: 0.031798\n", 1078 | "Train Epoch: 9 [3840/60000 (6%)]\tLoss: 0.055308\n", 1079 | "Train Epoch: 9 [4480/60000 (7%)]\tLoss: 0.002510\n", 1080 | "Train Epoch: 9 [5120/60000 (9%)]\tLoss: 0.017774\n", 1081 | "Train Epoch: 9 [5760/60000 (10%)]\tLoss: 0.004589\n", 1082 | "Train Epoch: 9 [6400/60000 (11%)]\tLoss: 0.007489\n", 1083 | "Train Epoch: 9 [7040/60000 (12%)]\tLoss: 0.011340\n", 1084 | "Train Epoch: 9 [7680/60000 (13%)]\tLoss: 0.021789\n", 1085 | "Train Epoch: 9 [8320/60000 (14%)]\tLoss: 0.031297\n", 1086 | "Train Epoch: 9 [8960/60000 (15%)]\tLoss: 0.004788\n", 1087 | "Train Epoch: 9 [9600/60000 (16%)]\tLoss: 0.019253\n", 1088 | "Train Epoch: 9 [10240/60000 (17%)]\tLoss: 0.028170\n", 1089 | "Train Epoch: 9 [10880/60000 (18%)]\tLoss: 0.000594\n", 1090 | "Train Epoch: 9 [11520/60000 (19%)]\tLoss: 0.009160\n", 1091 | "Train Epoch: 9 [12160/60000 (20%)]\tLoss: 0.003196\n", 1092 | "Train Epoch: 9 [12800/60000 (21%)]\tLoss: 0.011406\n", 1093 | "Train Epoch: 9 [13440/60000 (22%)]\tLoss: 0.014823\n", 1094 | "Train Epoch: 9 [14080/60000 (23%)]\tLoss: 0.006777\n", 1095 | "Train Epoch: 9 [14720/60000 (25%)]\tLoss: 0.010902\n", 1096 | "Train Epoch: 9 [15360/60000 (26%)]\tLoss: 0.002905\n", 1097 | "Train Epoch: 9 [16000/60000 (27%)]\tLoss: 0.008538\n", 1098 | "Train Epoch: 9 [16640/60000 (28%)]\tLoss: 0.004745\n", 1099 | "Train Epoch: 9 [17280/60000 (29%)]\tLoss: 0.029966\n", 1100 | "Train Epoch: 9 [17920/60000 (30%)]\tLoss: 0.012189\n", 1101 | "Train Epoch: 9 [18560/60000 (31%)]\tLoss: 0.040613\n", 1102 | "Train Epoch: 9 [19200/60000 (32%)]\tLoss: 0.005084\n", 1103 | "Train Epoch: 9 [19840/60000 (33%)]\tLoss: 0.015768\n", 1104 | "Train Epoch: 9 [20480/60000 (34%)]\tLoss: 0.011173\n", 1105 | "Train Epoch: 9 [21120/60000 (35%)]\tLoss: 0.016716\n", 1106 | "Train Epoch: 9 [21760/60000 (36%)]\tLoss: 0.013562\n", 1107 | "Train Epoch: 9 [22400/60000 (37%)]\tLoss: 0.011166\n", 1108 | "Train Epoch: 9 [23040/60000 (38%)]\tLoss: 0.016691\n", 1109 | "Train Epoch: 9 [23680/60000 (39%)]\tLoss: 0.052693\n", 1110 | "Train Epoch: 9 [24320/60000 (41%)]\tLoss: 0.033988\n", 1111 | "Train Epoch: 9 [24960/60000 (42%)]\tLoss: 0.014907\n", 1112 | "Train Epoch: 9 [25600/60000 (43%)]\tLoss: 0.006444\n", 1113 | "Train Epoch: 9 [26240/60000 (44%)]\tLoss: 0.031594\n", 1114 | "Train Epoch: 9 [26880/60000 (45%)]\tLoss: 0.030838\n", 1115 | "Train Epoch: 9 [27520/60000 (46%)]\tLoss: 0.017357\n", 1116 | "Train Epoch: 9 [28160/60000 (47%)]\tLoss: 0.011702\n", 1117 | "Train Epoch: 9 [28800/60000 (48%)]\tLoss: 0.007437\n", 1118 | "Train Epoch: 9 [29440/60000 (49%)]\tLoss: 0.005403\n", 1119 | "Train Epoch: 9 [30080/60000 (50%)]\tLoss: 0.009932\n", 1120 | "Train Epoch: 9 [30720/60000 (51%)]\tLoss: 0.005162\n", 1121 | "Train Epoch: 9 [31360/60000 (52%)]\tLoss: 0.040360\n", 1122 | "Train Epoch: 9 [32000/60000 (53%)]\tLoss: 0.005927\n", 1123 | "Train Epoch: 9 [32640/60000 (54%)]\tLoss: 0.028381\n", 1124 | "Train Epoch: 9 [33280/60000 (55%)]\tLoss: 0.092976\n", 1125 | "Train Epoch: 9 [33920/60000 (57%)]\tLoss: 0.003711\n", 1126 | "Train Epoch: 9 [34560/60000 (58%)]\tLoss: 0.034436\n", 1127 | "Train Epoch: 9 [35200/60000 (59%)]\tLoss: 0.016767\n", 1128 | "Train Epoch: 9 [35840/60000 (60%)]\tLoss: 0.009220\n", 1129 | "Train Epoch: 9 [36480/60000 (61%)]\tLoss: 0.002478\n", 1130 | "Train Epoch: 9 [37120/60000 (62%)]\tLoss: 0.002932\n", 1131 | "Train Epoch: 9 [37760/60000 (63%)]\tLoss: 0.014523\n", 1132 | "Train Epoch: 9 [38400/60000 (64%)]\tLoss: 0.005226\n", 1133 | "Train Epoch: 9 [39040/60000 (65%)]\tLoss: 0.035034\n", 1134 | "Train 
Epoch: 9 [39680/60000 (66%)]\tLoss: 0.035984\n", 1135 | "Train Epoch: 9 [40320/60000 (67%)]\tLoss: 0.019027\n", 1136 | "Train Epoch: 9 [40960/60000 (68%)]\tLoss: 0.003227\n", 1137 | "Train Epoch: 9 [41600/60000 (69%)]\tLoss: 0.006627\n", 1138 | "Train Epoch: 9 [42240/60000 (70%)]\tLoss: 0.004965\n", 1139 | "Train Epoch: 9 [42880/60000 (71%)]\tLoss: 0.025554\n", 1140 | "Train Epoch: 9 [43520/60000 (72%)]\tLoss: 0.004482\n", 1141 | "Train Epoch: 9 [44160/60000 (74%)]\tLoss: 0.002664\n", 1142 | "Train Epoch: 9 [44800/60000 (75%)]\tLoss: 0.035892\n", 1143 | "Train Epoch: 9 [45440/60000 (76%)]\tLoss: 0.022228\n", 1144 | "Train Epoch: 9 [46080/60000 (77%)]\tLoss: 0.066382\n", 1145 | "Train Epoch: 9 [46720/60000 (78%)]\tLoss: 0.049771\n", 1146 | "Train Epoch: 9 [47360/60000 (79%)]\tLoss: 0.001005\n", 1147 | "Train Epoch: 9 [48000/60000 (80%)]\tLoss: 0.018655\n", 1148 | "Train Epoch: 9 [48640/60000 (81%)]\tLoss: 0.000931\n", 1149 | "Train Epoch: 9 [49280/60000 (82%)]\tLoss: 0.019791\n", 1150 | "Train Epoch: 9 [49920/60000 (83%)]\tLoss: 0.002442\n", 1151 | "Train Epoch: 9 [50560/60000 (84%)]\tLoss: 0.008955\n", 1152 | "Train Epoch: 9 [51200/60000 (85%)]\tLoss: 0.004925\n", 1153 | "Train Epoch: 9 [51840/60000 (86%)]\tLoss: 0.005218\n", 1154 | "Train Epoch: 9 [52480/60000 (87%)]\tLoss: 0.029394\n", 1155 | "Train Epoch: 9 [53120/60000 (88%)]\tLoss: 0.004908\n", 1156 | "Train Epoch: 9 [53760/60000 (90%)]\tLoss: 0.037266\n", 1157 | "Train Epoch: 9 [54400/60000 (91%)]\tLoss: 0.002073\n", 1158 | "Train Epoch: 9 [55040/60000 (92%)]\tLoss: 0.059806\n", 1159 | "Train Epoch: 9 [55680/60000 (93%)]\tLoss: 0.013109\n", 1160 | "Train Epoch: 9 [56320/60000 (94%)]\tLoss: 0.008051\n", 1161 | "Train Epoch: 9 [56960/60000 (95%)]\tLoss: 0.070158\n", 1162 | "Train Epoch: 9 [57600/60000 (96%)]\tLoss: 0.086570\n", 1163 | "Train Epoch: 9 [58240/60000 (97%)]\tLoss: 0.011028\n", 1164 | "Train Epoch: 9 [58880/60000 (98%)]\tLoss: 0.078466\n", 1165 | "Train Epoch: 9 [59520/60000 (99%)]\tLoss: 0.010095\n", 1166 | "\n", 1167 | "Test set: Average loss: 0.0290, Accuracy: 9911/10000 (99%)\n", 1168 | "\n", 1169 | "Train Epoch: 10 [0/60000 (0%)]\tLoss: 0.121816\n", 1170 | "Train Epoch: 10 [640/60000 (1%)]\tLoss: 0.011049\n", 1171 | "Train Epoch: 10 [1280/60000 (2%)]\tLoss: 0.007778\n", 1172 | "Train Epoch: 10 [1920/60000 (3%)]\tLoss: 0.030580\n", 1173 | "Train Epoch: 10 [2560/60000 (4%)]\tLoss: 0.006801\n", 1174 | "Train Epoch: 10 [3200/60000 (5%)]\tLoss: 0.015644\n", 1175 | "Train Epoch: 10 [3840/60000 (6%)]\tLoss: 0.009172\n", 1176 | "Train Epoch: 10 [4480/60000 (7%)]\tLoss: 0.002725\n", 1177 | "Train Epoch: 10 [5120/60000 (9%)]\tLoss: 0.004766\n", 1178 | "Train Epoch: 10 [5760/60000 (10%)]\tLoss: 0.013712\n", 1179 | "Train Epoch: 10 [6400/60000 (11%)]\tLoss: 0.029625\n", 1180 | "Train Epoch: 10 [7040/60000 (12%)]\tLoss: 0.019506\n", 1181 | "Train Epoch: 10 [7680/60000 (13%)]\tLoss: 0.005904\n", 1182 | "Train Epoch: 10 [8320/60000 (14%)]\tLoss: 0.003161\n", 1183 | "Train Epoch: 10 [8960/60000 (15%)]\tLoss: 0.030412\n", 1184 | "Train Epoch: 10 [9600/60000 (16%)]\tLoss: 0.007489\n", 1185 | "Train Epoch: 10 [10240/60000 (17%)]\tLoss: 0.002637\n", 1186 | "Train Epoch: 10 [10880/60000 (18%)]\tLoss: 0.007939\n", 1187 | "Train Epoch: 10 [11520/60000 (19%)]\tLoss: 0.007080\n", 1188 | "Train Epoch: 10 [12160/60000 (20%)]\tLoss: 0.001840\n", 1189 | "Train Epoch: 10 [12800/60000 (21%)]\tLoss: 0.001410\n", 1190 | "Train Epoch: 10 [13440/60000 (22%)]\tLoss: 0.001227\n", 1191 | "Train Epoch: 10 [14080/60000 (23%)]\tLoss: 0.012916\n", 1192 | 
"Train Epoch: 10 [14720/60000 (25%)]\tLoss: 0.009867\n", 1193 | "Train Epoch: 10 [15360/60000 (26%)]\tLoss: 0.009356\n", 1194 | "Train Epoch: 10 [16000/60000 (27%)]\tLoss: 0.016766\n", 1195 | "Train Epoch: 10 [16640/60000 (28%)]\tLoss: 0.018208\n", 1196 | "Train Epoch: 10 [17280/60000 (29%)]\tLoss: 0.043678\n", 1197 | "Train Epoch: 10 [17920/60000 (30%)]\tLoss: 0.004466\n", 1198 | "Train Epoch: 10 [18560/60000 (31%)]\tLoss: 0.052181\n", 1199 | "Train Epoch: 10 [19200/60000 (32%)]\tLoss: 0.033881\n", 1200 | "Train Epoch: 10 [19840/60000 (33%)]\tLoss: 0.005014\n", 1201 | "Train Epoch: 10 [20480/60000 (34%)]\tLoss: 0.010579\n", 1202 | "Train Epoch: 10 [21120/60000 (35%)]\tLoss: 0.046698\n", 1203 | "Train Epoch: 10 [21760/60000 (36%)]\tLoss: 0.001839\n", 1204 | "Train Epoch: 10 [22400/60000 (37%)]\tLoss: 0.001649\n", 1205 | "Train Epoch: 10 [23040/60000 (38%)]\tLoss: 0.012382\n", 1206 | "Train Epoch: 10 [23680/60000 (39%)]\tLoss: 0.007124\n", 1207 | "Train Epoch: 10 [24320/60000 (41%)]\tLoss: 0.010925\n", 1208 | "Train Epoch: 10 [24960/60000 (42%)]\tLoss: 0.010823\n", 1209 | "Train Epoch: 10 [25600/60000 (43%)]\tLoss: 0.002183\n", 1210 | "Train Epoch: 10 [26240/60000 (44%)]\tLoss: 0.004735\n", 1211 | "Train Epoch: 10 [26880/60000 (45%)]\tLoss: 0.037791\n", 1212 | "Train Epoch: 10 [27520/60000 (46%)]\tLoss: 0.005325\n", 1213 | "Train Epoch: 10 [28160/60000 (47%)]\tLoss: 0.005741\n", 1214 | "Train Epoch: 10 [28800/60000 (48%)]\tLoss: 0.006020\n", 1215 | "Train Epoch: 10 [29440/60000 (49%)]\tLoss: 0.015782\n", 1216 | "Train Epoch: 10 [30080/60000 (50%)]\tLoss: 0.025564\n", 1217 | "Train Epoch: 10 [30720/60000 (51%)]\tLoss: 0.109878\n", 1218 | "Train Epoch: 10 [31360/60000 (52%)]\tLoss: 0.061099\n", 1219 | "Train Epoch: 10 [32000/60000 (53%)]\tLoss: 0.013350\n", 1220 | "Train Epoch: 10 [32640/60000 (54%)]\tLoss: 0.031633\n", 1221 | "Train Epoch: 10 [33280/60000 (55%)]\tLoss: 0.061641\n", 1222 | "Train Epoch: 10 [33920/60000 (57%)]\tLoss: 0.053409\n", 1223 | "Train Epoch: 10 [34560/60000 (58%)]\tLoss: 0.007698\n", 1224 | "Train Epoch: 10 [35200/60000 (59%)]\tLoss: 0.008173\n", 1225 | "Train Epoch: 10 [35840/60000 (60%)]\tLoss: 0.009077\n", 1226 | "Train Epoch: 10 [36480/60000 (61%)]\tLoss: 0.016080\n", 1227 | "Train Epoch: 10 [37120/60000 (62%)]\tLoss: 0.093979\n", 1228 | "Train Epoch: 10 [37760/60000 (63%)]\tLoss: 0.019608\n", 1229 | "Train Epoch: 10 [38400/60000 (64%)]\tLoss: 0.016654\n", 1230 | "Train Epoch: 10 [39040/60000 (65%)]\tLoss: 0.014212\n", 1231 | "Train Epoch: 10 [39680/60000 (66%)]\tLoss: 0.007410\n", 1232 | "Train Epoch: 10 [40320/60000 (67%)]\tLoss: 0.004689\n", 1233 | "Train Epoch: 10 [40960/60000 (68%)]\tLoss: 0.003324\n", 1234 | "Train Epoch: 10 [41600/60000 (69%)]\tLoss: 0.011062\n", 1235 | "Train Epoch: 10 [42240/60000 (70%)]\tLoss: 0.022511\n", 1236 | "Train Epoch: 10 [42880/60000 (71%)]\tLoss: 0.007078\n", 1237 | "Train Epoch: 10 [43520/60000 (72%)]\tLoss: 0.010904\n", 1238 | "Train Epoch: 10 [44160/60000 (74%)]\tLoss: 0.003113\n", 1239 | "Train Epoch: 10 [44800/60000 (75%)]\tLoss: 0.016834\n", 1240 | "Train Epoch: 10 [45440/60000 (76%)]\tLoss: 0.009184\n", 1241 | "Train Epoch: 10 [46080/60000 (77%)]\tLoss: 0.013544\n", 1242 | "Train Epoch: 10 [46720/60000 (78%)]\tLoss: 0.032994\n", 1243 | "Train Epoch: 10 [47360/60000 (79%)]\tLoss: 0.008709\n", 1244 | "Train Epoch: 10 [48000/60000 (80%)]\tLoss: 0.006710\n", 1245 | "Train Epoch: 10 [48640/60000 (81%)]\tLoss: 0.011893\n", 1246 | "Train Epoch: 10 [49280/60000 (82%)]\tLoss: 0.003608\n", 1247 | "Train Epoch: 10 [49920/60000 
(83%)]\tLoss: 0.005951\n", 1248 | "Train Epoch: 10 [50560/60000 (84%)]\tLoss: 0.045160\n", 1249 | "Train Epoch: 10 [51200/60000 (85%)]\tLoss: 0.023386\n", 1250 | "Train Epoch: 10 [51840/60000 (86%)]\tLoss: 0.034269\n", 1251 | "Train Epoch: 10 [52480/60000 (87%)]\tLoss: 0.006814\n", 1252 | "Train Epoch: 10 [53120/60000 (88%)]\tLoss: 0.005255\n", 1253 | "Train Epoch: 10 [53760/60000 (90%)]\tLoss: 0.034382\n", 1254 | "Train Epoch: 10 [54400/60000 (91%)]\tLoss: 0.010701\n", 1255 | "Train Epoch: 10 [55040/60000 (92%)]\tLoss: 0.005467\n", 1256 | "Train Epoch: 10 [55680/60000 (93%)]\tLoss: 0.000586\n", 1257 | "Train Epoch: 10 [56320/60000 (94%)]\tLoss: 0.006687\n", 1258 | "Train Epoch: 10 [56960/60000 (95%)]\tLoss: 0.090429\n", 1259 | "Train Epoch: 10 [57600/60000 (96%)]\tLoss: 0.005081\n", 1260 | "Train Epoch: 10 [58240/60000 (97%)]\tLoss: 0.172134\n", 1261 | "Train Epoch: 10 [58880/60000 (98%)]\tLoss: 0.006218\n", 1262 | "Train Epoch: 10 [59520/60000 (99%)]\tLoss: 0.006959\n", 1263 | "\n", 1264 | "Test set: Average loss: 0.0321, Accuracy: 9895/10000 (99%)\n", 1265 | "\n" 1266 | ] 1267 | } 1268 | ], 1269 | "source": [ 1270 | "model = Net().to(device)\n", 1271 | "optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)\n", 1272 | "\n", 1273 | "for epoch in range(1, EPOCHS + 1):\n", 1274 | " train(model, device, train_loader, optimizer, epoch)\n", 1275 | " test(model, device, test_loader)\n", 1276 | "\n", 1277 | "torch.save(model.state_dict(),\"mnist_cnn.pt\")\n" 1278 | ] 1279 | }, 1280 | { 1281 | "cell_type": "code", 1282 | "execution_count": 0, 1283 | "metadata": { 1284 | "colab": {}, 1285 | "colab_type": "code", 1286 | "id": "S-f_NZgj4uZy" 1287 | }, 1288 | "outputs": [], 1289 | "source": [] 1290 | }, 1291 | { 1292 | "cell_type": "code", 1293 | "execution_count": 0, 1294 | "metadata": { 1295 | "colab": {}, 1296 | "colab_type": "code", 1297 | "id": "LTBjTbqF59Cl" 1298 | }, 1299 | "outputs": [], 1300 | "source": [] 1301 | }, 1302 | { 1303 | "cell_type": "code", 1304 | "execution_count": 0, 1305 | "metadata": { 1306 | "colab": {}, 1307 | "colab_type": "code", 1308 | "id": "AXpn-upI7SdT" 1309 | }, 1310 | "outputs": [], 1311 | "source": [ 1312 | "# # Download from COLAB\n", 1313 | "# from google.colab import files\n", 1314 | "# files.download('mnist_cnn.pt') \n" 1315 | ] 1316 | }, 1317 | { 1318 | "cell_type": "code", 1319 | "execution_count": null, 1320 | "metadata": {}, 1321 | "outputs": [], 1322 | "source": [] 1323 | }, 1324 | { 1325 | "cell_type": "markdown", 1326 | "metadata": {}, 1327 | "source": [ 1328 | "## Load Model" 1329 | ] 1330 | }, 1331 | { 1332 | "cell_type": "code", 1333 | "execution_count": 118, 1334 | "metadata": {}, 1335 | "outputs": [ 1336 | { 1337 | "data": { 1338 | "text/plain": [ 1339 | "Net(\n", 1340 | " (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))\n", 1341 | " (conv2): Conv2d(20, 50, kernel_size=(5, 5), stride=(1, 1))\n", 1342 | " (fc1): Linear(in_features=800, out_features=500, bias=True)\n", 1343 | " (fc2): Linear(in_features=500, out_features=10, bias=True)\n", 1344 | ")" 1345 | ] 1346 | }, 1347 | "execution_count": 118, 1348 | "metadata": {}, 1349 | "output_type": "execute_result" 1350 | } 1351 | ], 1352 | "source": [ 1353 | "device = torch.device(\"cuda\" if use_cuda else \"cpu\")\n", 1354 | "model = Net()\n", 1355 | "model.eval()\n", 1356 | "model.load_state_dict(torch.load(\"mnist_cnn.pt\", map_location='cpu'))\n", 1357 | "model.to(device)" 1358 | ] 1359 | }, 1360 | { 1361 | "cell_type": "code", 1362 | "execution_count": 119, 1363 | "metadata": {}, 
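A note on the two inference cells around this point: the Load Model cell above calls model.eval() before load_state_dict(), which happens to work because eval() only flips the train/eval flag, but the conventional order is to load the weights first and then switch to eval mode. Separately, the prediction cell further down calls prediction.detach().numpy(); Tensor.numpy() raises a TypeError for CUDA tensors, so that cell only runs as written when device is the CPU. A device-safe sketch of the same prediction, assuming the notebook's own ind, model, test_loader and device objects:

import torch

image_tensor, _ = test_loader.dataset[ind]
image_tensor = image_tensor.reshape(1, 1, 28, 28).to(device)
with torch.no_grad():                                # pure inference, no autograd graph
    logits = model(image_tensor)
prediction = torch.argmax(logits, dim=1).item()      # argmax on-device, no NumPy round trip
print("PREDICTION : " + str(prediction))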
1364 | "outputs": [], 1365 | "source": [ 1366 | "import matplotlib.pyplot as plt\n", 1367 | "import numpy as np\n", 1368 | "%matplotlib inline" 1369 | ] 1370 | }, 1371 | { 1372 | "cell_type": "code", 1373 | "execution_count": 125, 1374 | "metadata": {}, 1375 | "outputs": [ 1376 | { 1377 | "name": "stdout", 1378 | "output_type": "stream", 1379 | "text": [ 1380 | "\u001b[92mPREDICTION : 9\u001b[0m\n" 1381 | ] 1382 | }, 1383 | { 1384 | "data": { 1385 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAEICAYAAACQ6CLfAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAEGBJREFUeJzt3X+s1fV9x/HnS0tbq2wqRAaWelt0P4pkwAgxKSJqayhZgy2JkThh0Y0aS7fGLps/xoqStEZblRjbBecPxF8ltAaMitq7Tdc/7EBEBZzCWqyyC+gc0ysaFd7743zpbuGezzmcX99z+bweyc095/v+/nhzwut+v+d8zjkfRQRmlp+jym7AzMrh8JtlyuE3y5TDb5Yph98sUw6/WaYc/i4hqUdSSPpYlfrVkv6pgf1uljSj6QbtiCOP85dD0nbgLyLiZ8X9HuBXwLCI+Ki8zrqbpOOBpcCXi0U/jIjF5XU0dA16ljFrlqRPAJ+MiP9t8a5vBj4F9AAnAb2SXo2Iu1p8nCOeL/tLIGkF8BngYUn9kv52QPkiSb+W9KakawZss1jSvcXtT0q6V9J/S9ojaZ2kUVWOtV3SF4vbUyWtl/S2pF2SbmrjP3Mk8Jqk+yR9UVKr/q99BbghIvZGxHbgDuCSFu07Kw5/CSLiYuDXwFci4riIuGFAeRrwB8C5wD9I+qNBdjEf+F1gLDACuAx4r45DLwWWRsTvAOOAlY3/K9IiYgfw+8BzVM7Wv5J0naTPtWD3Ouj26S3YZ3Yc/u5zbUS8FxHPA88DfzzIOh9SCf2pEbEvIp6NiLfr2PeHwKmSRkZEf0Q808K+DxEROyPi+xExAfgacDzwjKR/lTTYv6sea4ErJQ2XdCqVs/6nWtRyVhz+7rNzwO29wHGDrLMCeBx4UNJ/SbpB0rA69n0plbPxfxRPFf60noaKEYP+4ufMYuThwP1/lPSZAff7q+xmK5U/ZtuAP6Tyh6ARf0XlKmcrsBp4AHi9wX1lzS/4lafhYZaI+BC4Fri2GCV4FHiZyvPf1HZbgbnF8++vAaskjYiId2tsN/6gRf8GfPegZYf8kZJ0NHAeMI/Kq/P/DHwPeKzREY2IeAu4aMAxvgv8eyP7yp3DX55dQEPPfyWdDbwJbAHepnI5v7+O7f4MeDwi3pC0p1hcc7sGezyJypl+J3A38M2IeLMF+x0H7Cl+zgMWAGc1u98c+bK/PN8D/r54tf5vDnPb3wNWUQn+S8BTVJ4K1DIT2Fxcmi8FLoyIel4obMReYGZETIqIpa0IfuFPgBeBd6g8hhdFxOYW7TsrfpOPWaZ85jfLlMNvlimH3yxTDr9Zpjo61CfJry6atVlEqPZaTZ75Jc2U9LKkbZKubGZfZtZZDQ/1Fe/eegX4EpW3V64D5kbElsQ2PvObtVknzvxTgW0R8cuI+AB4EJjdxP7MrIOaCf/JwGsD7r9eLPstkhYUnyFf38SxzKzF2v6CX0QsA5aBL/vNukkzZ/4dVL5M4oBPF8vMbAhoJvzrgNMkfVbSx4ELgTWtacvM2q3hy/6I+EjSQipfKnE0cKc/XWU2dHT0U31+zm/Wfh15k4+ZDV0Ov1mmHH6zTDn8Zply+M0y5fCbZcrhN8uUw2+WKYffLFMOv1mmHH6zTDn8Zply+M0y5fCbZcrhN8uUw2+WKYffLFMOv1mmHH6zTDn8Zply+M0y5fCbZcrhN8uUw2+WKYffLFMOv1mmHH6zTDn8Zply+M0y5fCbZepjzWwsaTvwDrAP+CgiprSiKTNrv6bCXzg7It5swX7MrIN82W+WqWbDH8ATkp6VtGCwFSQtkLRe0vomj2VmLaSIaHxj6eSI2CHpJOBJ4JsR8XRi/cYPZmZ1iQjVs15TZ/6I2FH83g08BExtZn9m1jkNh1/SsZKGH7gNnAdsalVjZtZezbzaPwp4SNKB/dwfEWtb0pWZtV1Tz/kP+2B+zm/Wdh15zm9mQ5fDb5Yph98sUw6/WaYcfrNMteKDPdbFJk6cmKwvWbIkWZ81a1ayftRR6fPH/v37q9ZWrVqV3Paaa65J1vv6+pL1s88+u2qtt7c3ue17772XrB8JfOY3y5TDb5Yph98sUw6/WaYcfrNMOfxmmXL4zTLlcf4hYNiwYcn6WWedVbV21113JbcdPXp0sl7rU5+pcfxa28+ZMye5ba2x9rFjxybrM2bMqFqbP39+ctt77703WT8S+MxvlimH3yxTDr9Zphx+s0w5/GaZcvjNMuXwm2XK4/xDwOTJk5P1tWsb/8b0Wp+JX7hwYbK+d+/eho99yimnJOvvvvtusn7rrbcm6x988EHVWq1/dw585jfLlMNvlimH3yxTDr9Zphx+s0w5/GaZcvjNMuVx/i4wfvz4ZH3NmjUN77vW99NfddVVyfqGDRsaPnYtY8aMSdZXr16drB9//PHJ+o033li1VutxyUHNM7+kOyXtlrRpwLITJT0paWvx+4T2tmlmrVbPZf/dwMyDll0J9EbEaUBvcd/MhpCa4Y+Ip4G3Dlo8G1he3F4OnN/ivsyszRp9zj8qIg68OXonMKraipIWAAsaPI6ZtUnTL/hFREiq+i2NEbEMWAaQWs/MOqvRob5dkkYDFL93t64lM+uERsO/Bjjw3cfzgfSYjJl1nZqX/ZIeAGYAIyW9DnwHuB5YKelS4FXggnY2eaRbtGhRsj5y5Mhk/ZFHHqlau+KKK5Lbbtu2LVlvp9NPPz1ZnzRpUlP7b+Z7DnJQM/wRMbdK6dwW92JmHeS395plyuE3y5TDb5Yph98sUw6/WaZUawrmlh4s03f43X777cn6JZdckqzX+grrM844o2pty5YtyW3bLTW9+BNPPJHcdvr06cn6U089layfc845yfqRKiJUz3o+85tlyuE3y5TDb5Yph98sUw6/WaYcfrNMOfxmmfJXd3fAlClTkvVa77Xo7+9P1sscy0+N4wMsWbKkau3MM89MblvrcbnuuuuSdUvzmd8sUw6/WaYcfrNMOfxmmXL4zTLl8JtlyuE3y5TH+S2pp6cnWb/88suT9VpfHZ7S19eXrG/cuLHhfZvP/GbZcvjNMuX
[remainder of the base64-encoded "image/png" payload omitted: the inline matplotlib rendering of the 28x28 test digit displayed by the prediction cell below]\n", 1386 | "text/plain": [ 1387 | "
" 1388 | ] 1389 | }, 1390 | "metadata": {}, 1391 | "output_type": "display_data" 1392 | } 1393 | ], 1394 | "source": [ 1395 | "ind = 12\n", 1396 | "\n", 1397 | "image = test_loader.dataset[ind][0].numpy().reshape(28,28)\n", 1398 | "lbl = test_loader.dataset[ind][1].numpy()\n", 1399 | "plt.title('this is ---> ' + str(lbl))\n", 1400 | "plt.imshow(image, cmap='gray')\n", 1401 | "\n", 1402 | "\n", 1403 | "image_tensor, label_tensor = test_loader.dataset[ind]\n", 1404 | "image_tensor = image_tensor.reshape(1,1,28,28)\n", 1405 | "image_tensor, label_tensor = image_tensor.to(device), label_tensor.to(device)\n", 1406 | "\n", 1407 | "prediction = model(image_tensor)\n", 1408 | "prediction = np.argmax(prediction.detach().numpy())\n", 1409 | "print (\"\\033[92m\" + \"PREDICTION : \" + str(prediction) + \"\\033[0m\")" 1410 | ] 1411 | }, 1412 | { 1413 | "cell_type": "markdown", 1414 | "metadata": {}, 1415 | "source": [ 1416 | "-----------" 1417 | ] 1418 | }, 1419 | { 1420 | "cell_type": "code", 1421 | "execution_count": null, 1422 | "metadata": {}, 1423 | "outputs": [], 1424 | "source": [] 1425 | }, 1426 | { 1427 | "cell_type": "code", 1428 | "execution_count": null, 1429 | "metadata": {}, 1430 | "outputs": [], 1431 | "source": [] 1432 | } 1433 | ], 1434 | "metadata": { 1435 | "accelerator": "GPU", 1436 | "colab": { 1437 | "collapsed_sections": [], 1438 | "name": "mnist.ipynb", 1439 | "provenance": [], 1440 | "version": "0.3.2" 1441 | }, 1442 | "kernelspec": { 1443 | "display_name": "Python 3", 1444 | "language": "python", 1445 | "name": "python3" 1446 | }, 1447 | "language_info": { 1448 | "codemirror_mode": { 1449 | "name": "ipython", 1450 | "version": 3 1451 | }, 1452 | "file_extension": ".py", 1453 | "mimetype": "text/x-python", 1454 | "name": "python", 1455 | "nbconvert_exporter": "python", 1456 | "pygments_lexer": "ipython3", 1457 | "version": "3.6.5" 1458 | } 1459 | }, 1460 | "nbformat": 4, 1461 | "nbformat_minor": 1 1462 | } 1463 | -------------------------------------------------------------------------------- /mnist_arcface.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.optim as optim 5 | from torchvision import datasets, transforms 6 | import numpy as np 7 | 8 | print("Pytorch version: " + str(torch.__version__)) 9 | use_cuda = torch.cuda.is_available() 10 | print("Use CUDA: " + str(use_cuda)) 11 | 12 | # Cosface 13 | from torch.autograd import Variable 14 | from torch.utils.data import DataLoader 15 | import torch.optim.lr_scheduler as lr_scheduler 16 | from torch.autograd.function import Function 17 | import math 18 | 19 | from pdb import set_trace as bp 20 | 21 | BATCH_SIZE = 100 22 | FEATURES_DIM = 3 23 | NUM_CLASSES = 10 24 | 25 | BATCH_SIZE_TEST = 1000 26 | EPOCHS = 20 27 | LOG_INTERVAL = 10 28 | 29 | class Net(nn.Module): 30 | def __init__(self): 31 | super(Net, self).__init__() 32 | krnl_sz=3 33 | strd = 1 34 | 35 | self.conv1 = nn.Conv2d(in_channels=1, out_channels=20, kernel_size=krnl_sz, stride=strd, padding=1) 36 | self.conv2 = nn.Conv2d(in_channels=20, out_channels=50, kernel_size=krnl_sz, stride=strd, padding=1) 37 | self.prelu1_1 = nn.PReLU() 38 | self.prelu1_2 = nn.PReLU() 39 | 40 | self.conv3 = nn.Conv2d(in_channels=50, out_channels=64, kernel_size=krnl_sz, stride=strd, padding=1) 41 | self.conv4 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=krnl_sz, stride=strd, padding=1) 42 | self.prelu2_1 = nn.PReLU() 43 | self.prelu2_2 = 
nn.PReLU() 44 | 45 | self.conv5 = nn.Conv2d(in_channels=128, out_channels=512, kernel_size=krnl_sz, stride=strd, padding=1) 46 | self.conv6 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=krnl_sz, stride=strd, padding=1) 47 | self.prelu3_1 = nn.PReLU() 48 | self.prelu3_2 = nn.PReLU() 49 | 50 | self.prelu_weight = nn.Parameter(torch.Tensor(1).fill_(0.25)) 51 | 52 | self.fc1 = nn.Linear(3*3*512, 3) 53 | # self.fc2 = nn.Linear(3, 2) 54 | self.fc3 = nn.Linear(3, 10) 55 | 56 | def forward(self, x): 57 | mp_ks=2 58 | mp_strd=2 59 | 60 | x = self.prelu1_1(self.conv1(x)) 61 | x = self.prelu1_2(self.conv2(x)) 62 | x = F.max_pool2d(x, kernel_size=mp_ks, stride=mp_strd) 63 | 64 | x = self.prelu2_1(self.conv3(x)) 65 | x = self.prelu2_2(self.conv4(x)) 66 | x = F.max_pool2d(x, kernel_size=mp_ks, stride=mp_strd) 67 | 68 | x = self.prelu3_1(self.conv5(x)) 69 | x = self.prelu3_2(self.conv6(x)) 70 | x = F.max_pool2d(x, kernel_size=mp_ks, stride=mp_strd) 71 | 72 | x = x.view(-1, 3*3*512) # Flatten 73 | features3d = F.prelu(self.fc1(x), self.prelu_weight) 74 | x = self.fc3(features3d) 75 | 76 | return features3d, x 77 | 78 | 79 | # s >= C-1/C * log( (C-1)*P) / (1-P) ) 80 | # s >= 10-1/10 * log( (10-1)*0.8) / (1-0.8) )= 0.9 * log(36) = 0.9 * 3.583 = 3.22516 81 | class Arcface_loss(nn.Module): 82 | r"""Implement of large margin arc distance: : 83 | Args: 84 | in_features: size of each input sample 85 | out_features: size of each output sample 86 | s: norm of input feature 87 | m: margin 88 | 89 | cos(theta + m) 90 | """ 91 | # def __init__(self, num_classes=NUM_CLASSES, feat_dim=FEATURES_DIM, s=7.00, m=0.2): 92 | def __init__(self, feat_dim, num_classes, device, s=3.3, m=0.2, easy_margin=False): 93 | super(Arcface_loss, self).__init__() 94 | self.feat_dim = feat_dim 95 | self.num_classes = num_classes 96 | self.s = s 97 | self.m = m 98 | self.weight = nn.Parameter(torch.FloatTensor(num_classes, feat_dim)) 99 | nn.init.xavier_uniform_(self.weight) 100 | 101 | self.easy_margin = easy_margin 102 | self.cos_m = math.cos(m) 103 | self.sin_m = math.sin(m) 104 | self.th = math.cos(math.pi - m) 105 | self.mm = math.sin(math.pi - m) * m 106 | 107 | self.device = device 108 | 109 | def forward(self, input, label): 110 | # --------------------------- cos(theta) & phi(theta) --------------------------- 111 | cosine = F.linear(F.normalize(input), F.normalize(self.weight)) 112 | cosine = cosine.clamp(-1,1) # for numerical stability 113 | 114 | sine = torch.sqrt(1.0 - torch.pow(cosine, 2)) 115 | phi = cosine * self.cos_m - sine * self.sin_m 116 | if self.easy_margin: 117 | phi = torch.where(cosine > 0, phi, cosine) 118 | else: 119 | phi = torch.where(cosine > self.th, phi, cosine - self.mm) 120 | # --------------------------- convert label to one-hot --------------------------- 121 | # one_hot = torch.zeros(cosine.size(), requires_grad=True, device='cuda') 122 | one_hot = torch.zeros(cosine.size()).to(device) 123 | one_hot.scatter_(1, label.view(-1, 1).long(), 1) 124 | # -------------torch.where(out_i = {x_i if condition_i else y_i) ------------- 125 | output = (one_hot * phi) + ((1.0 - one_hot) * cosine) # you can use torch.where if your torch.__version__ is 0.4 126 | output *= self.s 127 | # print(output) 128 | 129 | return output 130 | 131 | 132 | # class Arcface_loss(nn.Module): 133 | # def __init__(self, num_classes=NUM_CLASSES, feat_dim=FEATURES_DIM, s=7.00, m=0.2): 134 | 135 | # super(Arcface_loss, self).__init__() 136 | # self.num_classes = num_classes 137 | # self.kernel = 
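# Aside: the scale bound sketched in the comment above Arcface_loss reads more
# clearly with explicit parentheses. For C classes, asking that the target
# class be able to reach softmax probability P after scaling gives the lower
# bound that comment is quoting:
#
#     s >= ((C - 1) / C) * log((C - 1) * P / (1 - P))
#
# A quick numerical check of the worked example (C = 10, P = 0.8), matching
# the ~3.225 figure above and the s=3.3 default chosen for this loss:
import math
print(round(((10 - 1) / 10) * math.log((10 - 1) * 0.8 / (1 - 0.8)), 4))  # 3.2252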
nn.Parameter(torch.Tensor(feat_dim, num_classes)) 138 | # # initial kernel 139 | # self.kernel.data.uniform_(-1, 1).renorm_(2,1,1e-5).mul_(1e5) 140 | # self.m = m # the margin value, default is 0.5 141 | # self.s = s # scalar value default is 64, see normface https://arxiv.org/abs/1704.06369 142 | # self.cos_m = math.cos(m) 143 | # self.sin_m = math.sin(m) 144 | # self.mm = self.sin_m * m # issue 1 145 | # self.threshold = math.cos(math.pi - m) 146 | # def forward(self, embbedings, label): 147 | # # weights norm 148 | # nB = len(embbedings) 149 | # kernel_norm = l2_norm(self.kernel,axis=0) 150 | # # cos(theta+m) 151 | # cos_theta = torch.mm(embbedings,kernel_norm) 152 | # # output = torch.mm(embbedings,kernel_norm) 153 | # cos_theta = cos_theta.clamp(-1,1) # for numerical stability 154 | # cos_theta_2 = torch.pow(cos_theta, 2) 155 | # sin_theta_2 = 1 - cos_theta_2 156 | # sin_theta = torch.sqrt(sin_theta_2) 157 | # cos_theta_m = (cos_theta * self.cos_m - sin_theta * self.sin_m) 158 | # # this condition controls the theta+m should in range [0, pi] 159 | # # 0<=theta+m<=pi 160 | # # -m<=theta<=pi-m 161 | # cond_v = cos_theta - self.threshold 162 | # cond_mask = cond_v <= 0 163 | # keep_val = (cos_theta - self.mm) # when theta not in [0,pi], use cosface instead 164 | # cos_theta_m[cond_mask] = keep_val[cond_mask] 165 | # output = cos_theta * 1.0 # a little bit hacky way to prevent in_place operation on cos_theta 166 | # idx_ = torch.arange(0, nB, dtype=torch.long) 167 | # output[idx_, label] = cos_theta_m[idx_, label] 168 | # output *= self.s # scale up in order to make softmax work, first introduced in normface 169 | # return output 170 | # def l2_norm(input,axis=1): 171 | # norm = torch.norm(input,2,axis,True) 172 | # output = torch.div(input, norm) 173 | # return output 174 | 175 | 176 | 177 | def train(model, device, train_loader, loss_softmax, loss_arcface, optimizer_nn, optimzer_arcface, epoch): 178 | model.train() 179 | for batch_idx, (data, target) in enumerate(train_loader): 180 | data, target = data.to(device), target.to(device) 181 | 182 | # optimizer.zero_grad() 183 | # output,_,_ = model(data) 184 | 185 | # loss = loss_function(output, target) 186 | 187 | # loss.backward() 188 | # optimizer.step() 189 | 190 | features, _ = model(data) 191 | logits = loss_arcface(features, target) 192 | loss = loss_softmax(logits, target) 193 | 194 | _, predicted = torch.max(logits.data, 1) 195 | accuracy = (target.data == predicted).float().mean() 196 | 197 | optimizer_nn.zero_grad() 198 | optimzer_arcface.zero_grad() 199 | 200 | loss.backward() 201 | 202 | optimizer_nn.step() 203 | optimzer_arcface.step() 204 | 205 | if batch_idx % LOG_INTERVAL == 0: 206 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 207 | epoch, batch_idx * len(data), len(train_loader.dataset), 208 | 100. 
* batch_idx / len(train_loader), loss.item())) 209 | 210 | 211 | def test(model, device, test_loader, loss_softmax, loss_arcface): 212 | model.eval() 213 | # test_loss = 0 214 | correct = 0 215 | total = 0 216 | with torch.no_grad(): 217 | for data, target in test_loader: 218 | data, target = data.to(device), target.to(device) 219 | 220 | feats, _ = model(data) 221 | logits = loss_arcface(feats, target) 222 | _, predicted = torch.max(logits.data, 1) 223 | total += target.size(0) 224 | correct += (predicted == target.data).sum() 225 | 226 | # print('Test Accuracy of the model on the 10000 test images: %f %%' % (100 * correct / total)) 227 | 228 | 229 | print('\nTest set:, Accuracy: {}/{} ({:.0f}%)\n'.format( 230 | correct, len(test_loader.dataset), 231 | 100. * correct / len(test_loader.dataset))) 232 | 233 | # output,_,_ = model(data) 234 | 235 | # test_loss += loss_function(output, target).item() # sum up batch loss 236 | 237 | # pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability 238 | # correct += pred.eq(target.view_as(pred)).sum().item() 239 | 240 | # test_loss /= len(test_loader.dataset) 241 | 242 | # print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( 243 | # test_loss, correct, len(test_loader.dataset), 244 | # 100. * correct / len(test_loader.dataset))) 245 | 246 | ################################################################### 247 | 248 | torch.manual_seed(1) 249 | device = torch.device("cuda" if use_cuda else "cpu") 250 | 251 | ####### Data setup 252 | 253 | kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {} 254 | train_loader = torch.utils.data.DataLoader( 255 | datasets.MNIST('./data', train=True, download=True, 256 | transform=transforms.Compose([ 257 | transforms.ToTensor(), 258 | transforms.Normalize((0.1307,), (0.3081,)) 259 | ])), 260 | batch_size=BATCH_SIZE, shuffle=True, **kwargs) 261 | test_loader = torch.utils.data.DataLoader( 262 | datasets.MNIST('./data', train=False, transform=transforms.Compose([ 263 | transforms.ToTensor(), 264 | transforms.Normalize((0.1307,), (0.3081,)) 265 | ])), 266 | batch_size=BATCH_SIZE_TEST, shuffle=True, **kwargs) 267 | 268 | ####### Model setup 269 | 270 | model = Net().to(device) 271 | loss_softmax = nn.CrossEntropyLoss().to(device) 272 | loss_arcface = Arcface_loss(num_classes=NUM_CLASSES, feat_dim=FEATURES_DIM, device=device).to(device) 273 | 274 | # optimzer nn 275 | optimizer_nn = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=0.0005) 276 | sheduler_nn = lr_scheduler.StepLR(optimizer_nn, 20, gamma=0.5) 277 | 278 | # optimzer cosface or arcface 279 | optimzer_arcface = optim.SGD(loss_arcface.parameters(), lr=0.01) 280 | sheduler_arcface = lr_scheduler.StepLR(optimzer_arcface, 20, gamma=0.5) 281 | 282 | 283 | for epoch in range(1, EPOCHS + 1): 284 | sheduler_nn.step() 285 | sheduler_arcface.step() 286 | 287 | train(model, device, train_loader, loss_softmax, loss_arcface, optimizer_nn, optimzer_arcface, epoch) 288 | test(model, device, test_loader, loss_softmax, loss_arcface) 289 | 290 | torch.save(model.state_dict(),"mnist_cnn-arcface.pt") 291 | torch.save(loss_arcface.state_dict(),"mnist_loss-arcface.pt") 292 | -------------------------------------------------------------------------------- /mnist_arcface2_fc7.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.optim as optim 5 | from torchvision import 
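# Aside on the training loop shared by these mnist_arcface*.py scripts: the
# margin loss owns trainable class weights, so each script drives two SGD
# optimizers (one for the net, one for the loss module) from a single
# loss.backward(). The scripts also call the StepLR schedulers at the top of
# the epoch loop, which was the convention up to PyTorch 1.0 (the version
# these scripts print); from PyTorch 1.1 on, scheduler.step() should come
# after the epoch's optimizer steps, or the schedule shifts by one epoch.
# A sketch of the adjusted loop, with the names defined in mnist_arcface.py:
#
#     for epoch in range(1, EPOCHS + 1):
#         train(model, device, train_loader, loss_softmax, loss_arcface,
#               optimizer_nn, optimzer_arcface, epoch)
#         test(model, device, test_loader, loss_softmax, loss_arcface)
#         sheduler_nn.step()        # once per epoch, after training
#         sheduler_arcface.step()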
datasets, transforms 6 | import numpy as np 7 | 8 | print("Pytorch version: " + str(torch.__version__)) 9 | use_cuda = torch.cuda.is_available() 10 | print("Use CUDA: " + str(use_cuda)) 11 | 12 | # Cosface 13 | from torch.autograd import Variable 14 | from torch.utils.data import DataLoader 15 | import torch.optim.lr_scheduler as lr_scheduler 16 | from torch.autograd.function import Function 17 | import math 18 | 19 | from pdb import set_trace as bp 20 | 21 | BATCH_SIZE = 100 22 | FEATURES_DIM = 3 23 | NUM_OF_CLASSES = 10 24 | 25 | BATCH_SIZE_TEST = 1000 26 | EPOCHS = 20 27 | LOG_INTERVAL = 10 28 | 29 | class Net(nn.Module): 30 | def __init__(self): 31 | super(Net, self).__init__() 32 | krnl_sz=3 33 | strd = 1 34 | 35 | self.conv1 = nn.Conv2d(in_channels=1, out_channels=20, kernel_size=krnl_sz, stride=strd, padding=1) 36 | self.conv2 = nn.Conv2d(in_channels=20, out_channels=50, kernel_size=krnl_sz, stride=strd, padding=1) 37 | self.prelu1_1 = nn.PReLU() 38 | self.prelu1_2 = nn.PReLU() 39 | 40 | self.conv3 = nn.Conv2d(in_channels=50, out_channels=64, kernel_size=krnl_sz, stride=strd, padding=1) 41 | self.conv4 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=krnl_sz, stride=strd, padding=1) 42 | self.prelu2_1 = nn.PReLU() 43 | self.prelu2_2 = nn.PReLU() 44 | 45 | self.conv5 = nn.Conv2d(in_channels=128, out_channels=512, kernel_size=krnl_sz, stride=strd, padding=1) 46 | self.conv6 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=krnl_sz, stride=strd, padding=1) 47 | self.prelu3_1 = nn.PReLU() 48 | self.prelu3_2 = nn.PReLU() 49 | 50 | self.prelu_weight = nn.Parameter(torch.Tensor(1).fill_(0.25)) 51 | 52 | self.fc1 = nn.Linear(3*3*512, 3) 53 | # self.fc2 = nn.Linear(3, 2) 54 | self.fc3 = nn.Linear(3, 10) 55 | 56 | def forward(self, x): 57 | mp_ks=2 58 | mp_strd=2 59 | 60 | x = self.prelu1_1(self.conv1(x)) 61 | x = self.prelu1_2(self.conv2(x)) 62 | x = F.max_pool2d(x, kernel_size=mp_ks, stride=mp_strd) 63 | 64 | x = self.prelu2_1(self.conv3(x)) 65 | x = self.prelu2_2(self.conv4(x)) 66 | x = F.max_pool2d(x, kernel_size=mp_ks, stride=mp_strd) 67 | 68 | x = self.prelu3_1(self.conv5(x)) 69 | x = self.prelu3_2(self.conv6(x)) 70 | x = F.max_pool2d(x, kernel_size=mp_ks, stride=mp_strd) 71 | 72 | x = x.view(-1, 3*3*512) # Flatten 73 | features3d = F.prelu(self.fc1(x), self.prelu_weight) 74 | x = self.fc3(features3d) 75 | 76 | return features3d, x 77 | 78 | class LMCL_loss(nn.Module): 79 | 80 | def __init__(self, num_classes, feat_dim, device, s=7.0, m=0.2): 81 | super(LMCL_loss, self).__init__() 82 | self.feat_dim = feat_dim 83 | self.num_classes = num_classes 84 | self.s = s 85 | self.m = m 86 | self.weights = nn.Parameter(torch.randn(num_classes, feat_dim)) 87 | self.device = device 88 | 89 | self.cos_m = math.cos(m) 90 | self.sin_m = math.sin(m) 91 | self.mm = math.sin(math.pi-m)*m 92 | self.threshold = math.cos(math.pi-m) 93 | 94 | def forward(self, feat, label, easy_margin=False): 95 | batch_size = feat.shape[0] 96 | norms = torch.norm(feat, p=2, dim=-1, keepdim=True) 97 | feat_l2norm = torch.div(feat, norms) 98 | feat_l2norm = feat_l2norm * self.s 99 | 100 | norms_w = torch.norm(self.weights, p=2, dim=-1, keepdim=True) 101 | weights_l2norm = torch.div(self.weights, norms_w) 102 | 103 | fc7 = torch.matmul(feat_l2norm, torch.transpose(weights_l2norm, 0, 1)) 104 | 105 | # y_onehot = torch.FloatTensor(batch_size, self.num_classes).to(self.device) 106 | # y_onehot.zero_() 107 | # y_onehot = Variable(y_onehot) 108 | # y_onehot.scatter_(1, torch.unsqueeze(label, dim=-1), self.s_m) 109 
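# Aside: despite the LMCL (CosFace) class name, the active code below applies
# ArcFace's angular margin cos(theta + m) through the cos/sin expansion; the
# block commented out just above is the additive-cosine margin that LMCL
# actually denotes. For contrast, a self-contained sketch of those CosFace
# logits (hypothetical helper, same tensor shapes as this forward()):
def _cosface_logits_sketch(feat, label, weights, s=7.0, m=0.2):
    cosine = F.linear(F.normalize(feat), F.normalize(weights))   # cos(theta), shape (B, C)
    one_hot = torch.zeros_like(cosine).scatter_(1, label.view(-1, 1), 1.0)
    return s * (cosine - m * one_hot)    # margin subtracted on the target class only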
| # output = fc7 - y_onehot 110 | 111 | 112 | # zy = mx.sym.pick(fc7, gt_label, axis=1) 113 | label = label.cpu() 114 | fc7 = fc7.cpu() 115 | 116 | target_one_hot = torch.zeros(len(label), NUM_OF_CLASSES).scatter_(1, label.unsqueeze(1), 1.) 117 | zy = torch.addcmul(torch.zeros(fc7.size()), 1., fc7, target_one_hot) 118 | # bp() 119 | zy = zy.sum(-1) 120 | 121 | cos_t = zy/self.s 122 | # cos_m = math.cos(self.m) 123 | # sin_m = math.sin(m) 124 | # mm = math.sin(math.pi-m)*m 125 | # threshold = math.cos(math.pi-m) 126 | if easy_margin: 127 | cond = F.relu(cos_t) 128 | else: 129 | cond_v = cos_t - self.threshold 130 | cond = F.relu(cond_v) 131 | 132 | 133 | body = cos_t*cos_t 134 | body = 1.0-body 135 | sin_t = torch.sqrt(body) 136 | new_zy = cos_t*self.cos_m 137 | b = sin_t*self.sin_m 138 | new_zy = new_zy - b 139 | new_zy = new_zy*self.s 140 | if easy_margin: 141 | zy_keep = zy 142 | else: 143 | zy_keep = zy - self.s*self.mm 144 | 145 | # bp() 146 | new_zy = torch.where(cond.byte(), new_zy, zy_keep) 147 | 148 | diff = new_zy - zy 149 | # diff = mx.sym.expand_dims(diff, 1) 150 | diff = diff.unsqueeze(1) 151 | 152 | # gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0) 153 | # body = mx.sym.broadcast_mul(gt_one_hot, diff) 154 | body = torch.addcmul(torch.zeros(diff.size()), 1., diff, target_one_hot) 155 | 156 | output = fc7+body 157 | 158 | 159 | return output.to(self.device) 160 | 161 | 162 | # def loss_function(output, target): 163 | # return F.nll_loss(F.log_softmax(output, dim=1), target) 164 | 165 | 166 | def train(model, device, train_loader, loss_softmax, loss_lmcl, optimizer_nn, optimzer_lmcl, epoch): 167 | model.train() 168 | for batch_idx, (data, target) in enumerate(train_loader): 169 | data, target = data.to(device), target.to(device) 170 | 171 | # optimizer.zero_grad() 172 | # output,_,_ = model(data) 173 | 174 | # loss = loss_function(output, target) 175 | 176 | # loss.backward() 177 | # optimizer.step() 178 | 179 | features, _ = model(data) 180 | logits = loss_lmcl(features, target) 181 | loss = loss_softmax(logits, target) 182 | 183 | _, predicted = torch.max(logits.data, 1) 184 | accuracy = (target.data == predicted).float().mean() 185 | 186 | optimizer_nn.zero_grad() 187 | optimzer_lmcl.zero_grad() 188 | 189 | loss.backward() 190 | 191 | optimizer_nn.step() 192 | optimzer_lmcl.step() 193 | 194 | if batch_idx % LOG_INTERVAL == 0: 195 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 196 | epoch, batch_idx * len(data), len(train_loader.dataset), 197 | 100. * batch_idx / len(train_loader), loss.item())) 198 | 199 | 200 | def test(model, device, test_loader, loss_softmax, loss_lmcl): 201 | model.eval() 202 | # test_loss = 0 203 | correct = 0 204 | total = 0 205 | with torch.no_grad(): 206 | for data, target in test_loader: 207 | data, target = data.to(device), target.to(device) 208 | 209 | feats, _ = model(data) 210 | logits = loss_lmcl(feats, target) 211 | _, predicted = torch.max(logits.data, 1) 212 | total += target.size(0) 213 | correct += (predicted == target.data).sum() 214 | 215 | # print('Test Accuracy of the model on the 10000 test images: %f %%' % (100 * correct / total)) 216 | 217 | 218 | print('\nTest set:, Accuracy: {}/{} ({:.0f}%)\n'.format( 219 | correct, len(test_loader.dataset), 220 | 100. 
* correct / len(test_loader.dataset))) 221 | 222 | # output,_,_ = model(data) 223 | 224 | # test_loss += loss_function(output, target).item() # sum up batch loss 225 | 226 | # pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability 227 | # correct += pred.eq(target.view_as(pred)).sum().item() 228 | 229 | # test_loss /= len(test_loader.dataset) 230 | 231 | # print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( 232 | # test_loss, correct, len(test_loader.dataset), 233 | # 100. * correct / len(test_loader.dataset))) 234 | 235 | ################################################################### 236 | 237 | torch.manual_seed(1) 238 | device = torch.device("cuda" if use_cuda else "cpu") 239 | 240 | ####### Data setup 241 | 242 | kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {} 243 | train_loader = torch.utils.data.DataLoader( 244 | datasets.MNIST('./data', train=True, download=True, 245 | transform=transforms.Compose([ 246 | transforms.ToTensor(), 247 | transforms.Normalize((0.1307,), (0.3081,)) 248 | ])), 249 | batch_size=BATCH_SIZE, shuffle=True, **kwargs) 250 | test_loader = torch.utils.data.DataLoader( 251 | datasets.MNIST('./data', train=False, transform=transforms.Compose([ 252 | transforms.ToTensor(), 253 | transforms.Normalize((0.1307,), (0.3081,)) 254 | ])), 255 | batch_size=BATCH_SIZE_TEST, shuffle=True, **kwargs) 256 | 257 | ####### Model setup 258 | 259 | model = Net().to(device) 260 | loss_softmax = nn.CrossEntropyLoss().to(device) 261 | loss_lmcl = LMCL_loss(num_classes=10, feat_dim=FEATURES_DIM, device=device).to(device) 262 | 263 | # optimzer nn 264 | optimizer_nn = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=0.0005) 265 | sheduler_nn = lr_scheduler.StepLR(optimizer_nn, 20, gamma=0.5) 266 | 267 | # optimzer cosface or lmcl 268 | optimzer_lmcl = optim.SGD(loss_lmcl.parameters(), lr=0.01) 269 | sheduler_lmcl = lr_scheduler.StepLR(optimzer_lmcl, 20, gamma=0.5) 270 | 271 | 272 | for epoch in range(1, EPOCHS + 1): 273 | sheduler_nn.step() 274 | sheduler_lmcl.step() 275 | 276 | train(model, device, train_loader, loss_softmax, loss_lmcl, optimizer_nn, optimzer_lmcl, epoch) 277 | test(model, device, test_loader, loss_softmax, loss_lmcl) 278 | 279 | torch.save(model.state_dict(),"mnist_cnn-cosface.pt") 280 | torch.save(loss_lmcl.state_dict(),"mnist_loss-cosface.pt") 281 | -------------------------------------------------------------------------------- /mnist_arcface3_fc7.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.optim as optim 5 | from torchvision import datasets, transforms 6 | import numpy as np 7 | 8 | print("Pytorch version: " + str(torch.__version__)) 9 | use_cuda = torch.cuda.is_available() 10 | print("Use CUDA: " + str(use_cuda)) 11 | 12 | # Cosface 13 | from torch.autograd import Variable 14 | from torch.utils.data import DataLoader 15 | import torch.optim.lr_scheduler as lr_scheduler 16 | from torch.autograd.function import Function 17 | import math 18 | 19 | from pdb import set_trace as bp 20 | 21 | BATCH_SIZE = 100 22 | FEATURES_DIM = 3 23 | NUM_OF_CLASSES = 10 24 | 25 | BATCH_SIZE_TEST = 1000 26 | EPOCHS = 20 27 | LOG_INTERVAL = 10 28 | 29 | class Net(nn.Module): 30 | def __init__(self): 31 | super(Net, self).__init__() 32 | krnl_sz=3 33 | strd = 1 34 | 35 | self.conv1 = nn.Conv2d(in_channels=1, out_channels=20, kernel_size=krnl_sz, stride=strd, 
padding=1) 36 | self.conv2 = nn.Conv2d(in_channels=20, out_channels=50, kernel_size=krnl_sz, stride=strd, padding=1) 37 | self.prelu1_1 = nn.PReLU() 38 | self.prelu1_2 = nn.PReLU() 39 | 40 | self.conv3 = nn.Conv2d(in_channels=50, out_channels=64, kernel_size=krnl_sz, stride=strd, padding=1) 41 | self.conv4 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=krnl_sz, stride=strd, padding=1) 42 | self.prelu2_1 = nn.PReLU() 43 | self.prelu2_2 = nn.PReLU() 44 | 45 | self.conv5 = nn.Conv2d(in_channels=128, out_channels=512, kernel_size=krnl_sz, stride=strd, padding=1) 46 | self.conv6 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=krnl_sz, stride=strd, padding=1) 47 | self.prelu3_1 = nn.PReLU() 48 | self.prelu3_2 = nn.PReLU() 49 | 50 | self.prelu_weight = nn.Parameter(torch.Tensor(1).fill_(0.25)) 51 | 52 | self.fc1 = nn.Linear(3*3*512, 3) 53 | # self.fc2 = nn.Linear(3, 2) 54 | self.fc3 = nn.Linear(3, 10) 55 | 56 | def forward(self, x): 57 | mp_ks=2 58 | mp_strd=2 59 | 60 | x = self.prelu1_1(self.conv1(x)) 61 | x = self.prelu1_2(self.conv2(x)) 62 | x = F.max_pool2d(x, kernel_size=mp_ks, stride=mp_strd) 63 | 64 | x = self.prelu2_1(self.conv3(x)) 65 | x = self.prelu2_2(self.conv4(x)) 66 | x = F.max_pool2d(x, kernel_size=mp_ks, stride=mp_strd) 67 | 68 | x = self.prelu3_1(self.conv5(x)) 69 | x = self.prelu3_2(self.conv6(x)) 70 | x = F.max_pool2d(x, kernel_size=mp_ks, stride=mp_strd) 71 | 72 | x = x.view(-1, 3*3*512) # Flatten 73 | features3d = F.prelu(self.fc1(x), self.prelu_weight) 74 | x = self.fc3(features3d) 75 | 76 | return features3d, x 77 | 78 | class LMCL_loss(nn.Module): 79 | 80 | def __init__(self, num_classes, feat_dim, device, s=7.0, m=0.2): 81 | super(LMCL_loss, self).__init__() 82 | self.feat_dim = feat_dim 83 | self.num_classes = num_classes 84 | self.s = s 85 | self.m = m 86 | self.weights = nn.Parameter(torch.randn(num_classes, feat_dim)) 87 | self.device = device 88 | 89 | self.cos_m = math.cos(m) 90 | self.sin_m = math.sin(m) 91 | self.mm = math.sin(math.pi-m)*m 92 | self.threshold = math.cos(math.pi-m) 93 | 94 | def forward(self, feat, label, easy_margin=False): 95 | eps = 1e-4 96 | batch_size = feat.shape[0] 97 | norms = torch.norm(feat, p=2, dim=-1, keepdim=True) 98 | feat_l2norm = torch.div(feat, norms) 99 | feat_l2norm = feat_l2norm.clamp(min=-1+eps, max=1-eps) # for numerical stability 100 | feat_l2norm = feat_l2norm * self.s 101 | 102 | norms_w = torch.norm(self.weights, p=2, dim=-1, keepdim=True) 103 | weights_l2norm = torch.div(self.weights, norms_w) 104 | weights_l2norm = weights_l2norm.clamp(min=-1+eps, max=1-eps) # for numerical stability 105 | 106 | fc7 = torch.matmul(feat_l2norm, torch.transpose(weights_l2norm, 0, 1)) 107 | 108 | # zy = mx.sym.pick(fc7, gt_label, axis=1) 109 | label = label.cpu() 110 | fc7 = fc7.cpu() 111 | 112 | target_one_hot = torch.zeros(len(label), NUM_OF_CLASSES).scatter_(1, label.unsqueeze(1), 1.) 
113 | zy = torch.addcmul(torch.zeros(fc7.size()), 1., fc7, target_one_hot) 114 | zy = zy.sum(-1) 115 | 116 | cos_t = zy/self.s 117 | cos_t = cos_t.clamp(min=-1+eps, max=1-eps) # for numerical stability 118 | 119 | t = torch.acos(cos_t) 120 | # t = t.clamp(min=-1+eps, max=1-eps) # for numerical stability 121 | t = t+self.m 122 | 123 | body = torch.cos(t) 124 | # body = body.clamp(min=-1+eps, max=1-eps) # for numerical stability 125 | new_zy = body*self.s 126 | 127 | 128 | diff = new_zy - zy 129 | # diff = mx.sym.expand_dims(diff, 1) 130 | diff = diff.unsqueeze(1) 131 | 132 | # gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0) 133 | # body = mx.sym.broadcast_mul(gt_one_hot, diff) 134 | body = torch.addcmul(torch.zeros(diff.size()), 1., diff, target_one_hot) 135 | # body = body.clamp(min=-1+eps, max=1-eps) # for numerical stability 136 | 137 | output = fc7+body 138 | 139 | return output.to(self.device) 140 | 141 | 142 | # def loss_function(output, target): 143 | # return F.nll_loss(F.log_softmax(output, dim=1), target) 144 | 145 | 146 | def train(model, device, train_loader, loss_softmax, loss_lmcl, optimizer_nn, optimzer_lmcl, epoch): 147 | model.train() 148 | for batch_idx, (data, target) in enumerate(train_loader): 149 | data, target = data.to(device), target.to(device) 150 | 151 | # optimizer.zero_grad() 152 | # output,_,_ = model(data) 153 | 154 | # loss = loss_function(output, target) 155 | 156 | # loss.backward() 157 | # optimizer.step() 158 | 159 | features, _ = model(data) 160 | logits = loss_lmcl(features, target) 161 | loss = loss_softmax(logits, target) 162 | 163 | _, predicted = torch.max(logits.data, 1) 164 | accuracy = (target.data == predicted).float().mean() 165 | 166 | optimizer_nn.zero_grad() 167 | optimzer_lmcl.zero_grad() 168 | 169 | loss.backward() 170 | 171 | optimizer_nn.step() 172 | optimzer_lmcl.step() 173 | 174 | if batch_idx % LOG_INTERVAL == 0: 175 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 176 | epoch, batch_idx * len(data), len(train_loader.dataset), 177 | 100. * batch_idx / len(train_loader), loss.item())) 178 | 179 | 180 | def test(model, device, test_loader, loss_softmax, loss_lmcl): 181 | model.eval() 182 | # test_loss = 0 183 | correct = 0 184 | total = 0 185 | with torch.no_grad(): 186 | for data, target in test_loader: 187 | data, target = data.to(device), target.to(device) 188 | 189 | feats, _ = model(data) 190 | logits = loss_lmcl(feats, target) 191 | _, predicted = torch.max(logits.data, 1) 192 | total += target.size(0) 193 | correct += (predicted == target.data).sum() 194 | 195 | # print('Test Accuracy of the model on the 10000 test images: %f %%' % (100 * correct / total)) 196 | 197 | 198 | print('\nTest set:, Accuracy: {}/{} ({:.0f}%)\n'.format( 199 | correct, len(test_loader.dataset), 200 | 100. * correct / len(test_loader.dataset))) 201 | 202 | # output,_,_ = model(data) 203 | 204 | # test_loss += loss_function(output, target).item() # sum up batch loss 205 | 206 | # pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability 207 | # correct += pred.eq(target.view_as(pred)).sum().item() 208 | 209 | # test_loss /= len(test_loader.dataset) 210 | 211 | # print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( 212 | # test_loss, correct, len(test_loader.dataset), 213 | # 100. 
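# Aside: the eps clamps in this file are load-bearing for the acos-based
# margin above. d/dx acos(x) = -1 / sqrt(1 - x**2) diverges as |x| -> 1, so a
# perfectly aligned (or anti-aligned) cosine produces an infinite gradient;
# clamping to [-1 + eps, 1 - eps] caps it. A standalone illustration with
# assumed values:
#
#     x = torch.tensor([1.0], requires_grad=True)
#     torch.acos(x).backward()
#     x.grad                    # tensor([-inf]): unclamped perfect alignment
#
#     y = torch.tensor([1.0 - 1e-4], requires_grad=True)   # the clamp ceiling
#     torch.acos(y).backward()
#     y.grad                    # ~ tensor([-70.7]): large but finite
#
# The trade-off versus mnist_arcface.py: acos yields cos(theta + m) exactly,
# with no easy_margin/threshold branch, but theta + m can pass pi, where the
# penalty stops being monotonic in theta.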
* correct / len(test_loader.dataset))) 214 | 215 | ################################################################### 216 | 217 | torch.manual_seed(1) 218 | device = torch.device("cuda" if use_cuda else "cpu") 219 | 220 | ####### Data setup 221 | 222 | kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {} 223 | train_loader = torch.utils.data.DataLoader( 224 | datasets.MNIST('./data', train=True, download=True, 225 | transform=transforms.Compose([ 226 | transforms.ToTensor(), 227 | transforms.Normalize((0.1307,), (0.3081,)) 228 | ])), 229 | batch_size=BATCH_SIZE, shuffle=True, **kwargs) 230 | test_loader = torch.utils.data.DataLoader( 231 | datasets.MNIST('./data', train=False, transform=transforms.Compose([ 232 | transforms.ToTensor(), 233 | transforms.Normalize((0.1307,), (0.3081,)) 234 | ])), 235 | batch_size=BATCH_SIZE_TEST, shuffle=True, **kwargs) 236 | 237 | ####### Model setup 238 | 239 | model = Net().to(device) 240 | loss_softmax = nn.CrossEntropyLoss().to(device) 241 | loss_lmcl = LMCL_loss(num_classes=10, feat_dim=FEATURES_DIM, device=device).to(device) 242 | 243 | # optimzer nn 244 | optimizer_nn = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=0.0005) 245 | sheduler_nn = lr_scheduler.StepLR(optimizer_nn, 20, gamma=0.5) 246 | 247 | # optimzer cosface or lmcl 248 | optimzer_lmcl = optim.SGD(loss_lmcl.parameters(), lr=0.01) 249 | sheduler_lmcl = lr_scheduler.StepLR(optimzer_lmcl, 20, gamma=0.5) 250 | 251 | 252 | for epoch in range(1, EPOCHS + 1): 253 | sheduler_nn.step() 254 | sheduler_lmcl.step() 255 | 256 | train(model, device, train_loader, loss_softmax, loss_lmcl, optimizer_nn, optimzer_lmcl, epoch) 257 | test(model, device, test_loader, loss_softmax, loss_lmcl) 258 | 259 | torch.save(model.state_dict(),"mnist_cnn-cosface.pt") 260 | torch.save(loss_lmcl.state_dict(),"mnist_loss-cosface.pt") 261 | -------------------------------------------------------------------------------- /mnist_arcface4_fc7.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.optim as optim 5 | from torchvision import datasets, transforms 6 | import numpy as np 7 | 8 | print("Pytorch version: " + str(torch.__version__)) 9 | use_cuda = torch.cuda.is_available() 10 | print("Use CUDA: " + str(use_cuda)) 11 | 12 | # Cosface 13 | from torch.autograd import Variable 14 | from torch.utils.data import DataLoader 15 | import torch.optim.lr_scheduler as lr_scheduler 16 | from torch.autograd.function import Function 17 | import math 18 | 19 | from pdb import set_trace as bp 20 | 21 | BATCH_SIZE = 100 22 | FEATURES_DIM = 3 23 | NUM_OF_CLASSES = 10 24 | 25 | BATCH_SIZE_TEST = 1000 26 | EPOCHS = 20 27 | LOG_INTERVAL = 10 28 | 29 | class Net(nn.Module): 30 | def __init__(self): 31 | super(Net, self).__init__() 32 | krnl_sz=3 33 | strd = 1 34 | 35 | self.conv1 = nn.Conv2d(in_channels=1, out_channels=20, kernel_size=krnl_sz, stride=strd, padding=1) 36 | self.conv2 = nn.Conv2d(in_channels=20, out_channels=50, kernel_size=krnl_sz, stride=strd, padding=1) 37 | self.prelu1_1 = nn.PReLU() 38 | self.prelu1_2 = nn.PReLU() 39 | 40 | self.conv3 = nn.Conv2d(in_channels=50, out_channels=64, kernel_size=krnl_sz, stride=strd, padding=1) 41 | self.conv4 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=krnl_sz, stride=strd, padding=1) 42 | self.prelu2_1 = nn.PReLU() 43 | self.prelu2_2 = nn.PReLU() 44 | 45 | self.conv5 = nn.Conv2d(in_channels=128, out_channels=512, 
kernel_size=krnl_sz, stride=strd, padding=1) 46 | self.conv6 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=krnl_sz, stride=strd, padding=1) 47 | self.prelu3_1 = nn.PReLU() 48 | self.prelu3_2 = nn.PReLU() 49 | 50 | self.prelu_weight = nn.Parameter(torch.Tensor(1).fill_(0.25)) 51 | 52 | self.fc1 = nn.Linear(3*3*512, 3) 53 | # self.fc2 = nn.Linear(3, 2) 54 | self.fc3 = nn.Linear(3, 10) 55 | 56 | def forward(self, x): 57 | mp_ks=2 58 | mp_strd=2 59 | 60 | x = self.prelu1_1(self.conv1(x)) 61 | x = self.prelu1_2(self.conv2(x)) 62 | x = F.max_pool2d(x, kernel_size=mp_ks, stride=mp_strd) 63 | 64 | x = self.prelu2_1(self.conv3(x)) 65 | x = self.prelu2_2(self.conv4(x)) 66 | x = F.max_pool2d(x, kernel_size=mp_ks, stride=mp_strd) 67 | 68 | x = self.prelu3_1(self.conv5(x)) 69 | x = self.prelu3_2(self.conv6(x)) 70 | x = F.max_pool2d(x, kernel_size=mp_ks, stride=mp_strd) 71 | 72 | x = x.view(-1, 3*3*512) # Flatten 73 | features3d = F.prelu(self.fc1(x), self.prelu_weight) 74 | x = self.fc3(features3d) 75 | 76 | return features3d, x 77 | 78 | class Arcface_loss(nn.Module): 79 | 80 | def __init__(self, num_classes, feat_dim, device, s=7.0, m=0.2): 81 | super(Arcface_loss, self).__init__() 82 | self.feat_dim = feat_dim 83 | self.num_classes = num_classes 84 | self.s = s 85 | self.m = m 86 | self.weights = nn.Parameter(torch.randn(num_classes, feat_dim)) 87 | self.device = device 88 | 89 | self.cos_m = math.cos(m) 90 | self.sin_m = math.sin(m) 91 | self.mm = math.sin(math.pi-m)*m 92 | self.threshold = math.cos(math.pi-m) 93 | 94 | def forward(self, feat, label, easy_margin=False): 95 | eps = 1e-4 96 | batch_size = feat.shape[0] 97 | norms = torch.norm(feat, p=2, dim=-1, keepdim=True) 98 | feat_l2norm = torch.div(feat, norms) 99 | feat_l2norm = feat_l2norm.clamp(min=-1+eps, max=1-eps) # for numerical stability 100 | feat_l2norm = feat_l2norm * self.s 101 | 102 | norms_w = torch.norm(self.weights, p=2, dim=-1, keepdim=True) 103 | weights_l2norm = torch.div(self.weights, norms_w) 104 | weights_l2norm = weights_l2norm.clamp(min=-1+eps, max=1-eps) # for numerical stability 105 | 106 | fc7 = torch.matmul(feat_l2norm, torch.transpose(weights_l2norm, 0, 1)) 107 | 108 | # zy = mx.sym.pick(fc7, gt_label, axis=1) 109 | label = label.cpu() 110 | fc7 = fc7.cpu() 111 | 112 | target_one_hot = torch.zeros(len(label), NUM_OF_CLASSES).scatter_(1, label.unsqueeze(1), 1.) 
113 | zy = torch.addcmul(torch.zeros(fc7.size()), 1., fc7, target_one_hot) 114 | zy = zy.sum(-1) 115 | 116 | cos_t = zy/self.s 117 | cos_t = cos_t.clamp(min=-1+eps, max=1-eps) # for numerical stability 118 | 119 | t = torch.acos(cos_t) 120 | t = t+self.m 121 | 122 | body = torch.cos(t) 123 | new_zy = body*self.s 124 | 125 | diff = new_zy - zy 126 | # diff = mx.sym.expand_dims(diff, 1) 127 | diff = diff.unsqueeze(1) 128 | 129 | # gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0) 130 | # body = mx.sym.broadcast_mul(gt_one_hot, diff) 131 | body = torch.addcmul(torch.zeros(diff.size()), 1., diff, target_one_hot) 132 | 133 | output = fc7+body 134 | 135 | return output.to(self.device) 136 | 137 | 138 | # def loss_function(output, target): 139 | # return F.nll_loss(F.log_softmax(output, dim=1), target) 140 | 141 | 142 | def train(model, device, train_loader, loss_softmax, loss_arcface, optimizer_nn, optimzer_arcface, epoch): 143 | model.train() 144 | for batch_idx, (data, target) in enumerate(train_loader): 145 | data, target = data.to(device), target.to(device) 146 | 147 | # optimizer.zero_grad() 148 | # output,_,_ = model(data) 149 | 150 | # loss = loss_function(output, target) 151 | 152 | # loss.backward() 153 | # optimizer.step() 154 | 155 | features, _ = model(data) 156 | logits = loss_arcface(features, target) 157 | loss = loss_softmax(logits, target) 158 | 159 | _, predicted = torch.max(logits.data, 1) 160 | accuracy = (target.data == predicted).float().mean() 161 | 162 | optimizer_nn.zero_grad() 163 | optimzer_arcface.zero_grad() 164 | 165 | loss.backward() 166 | 167 | optimizer_nn.step() 168 | optimzer_arcface.step() 169 | 170 | if batch_idx % LOG_INTERVAL == 0: 171 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 172 | epoch, batch_idx * len(data), len(train_loader.dataset), 173 | 100. * batch_idx / len(train_loader), loss.item())) 174 | 175 | 176 | def test(model, device, test_loader, loss_softmax, loss_arcface): 177 | model.eval() 178 | # test_loss = 0 179 | correct = 0 180 | total = 0 181 | with torch.no_grad(): 182 | for data, target in test_loader: 183 | data, target = data.to(device), target.to(device) 184 | 185 | feats, _ = model(data) 186 | logits = loss_arcface(feats, target) 187 | _, predicted = torch.max(logits.data, 1) 188 | total += target.size(0) 189 | correct += (predicted == target.data).sum() 190 | 191 | # print('Test Accuracy of the model on the 10000 test images: %f %%' % (100 * correct / total)) 192 | 193 | 194 | print('\nTest set:, Accuracy: {}/{} ({:.0f}%)\n'.format( 195 | correct, len(test_loader.dataset), 196 | 100. * correct / len(test_loader.dataset))) 197 | 198 | # output,_,_ = model(data) 199 | 200 | # test_loss += loss_function(output, target).item() # sum up batch loss 201 | 202 | # pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability 203 | # correct += pred.eq(target.view_as(pred)).sum().item() 204 | 205 | # test_loss /= len(test_loader.dataset) 206 | 207 | # print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( 208 | # test_loss, correct, len(test_loader.dataset), 209 | # 100. 
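# Aside: the zeros + one-hot + addcmul + sum chain above is a literal port of
# mxnet's mx.sym.pick (see the commented mx.sym lines). In PyTorch, picking
# the target-class logit is a single gather, which skips the one-hot buffer
# and stays on whatever device fc7 already occupies; a sketch with this
# file's names:
#
#     zy = fc7.gather(1, label.unsqueeze(1)).squeeze(1)    # shape (batch,)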
* correct / len(test_loader.dataset))) 210 | 211 | ################################################################### 212 | 213 | torch.manual_seed(1) 214 | device = torch.device("cuda" if use_cuda else "cpu") 215 | 216 | ####### Data setup 217 | 218 | kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {} 219 | train_loader = torch.utils.data.DataLoader( 220 | datasets.MNIST('./data', train=True, download=True, 221 | transform=transforms.Compose([ 222 | transforms.ToTensor(), 223 | transforms.Normalize((0.1307,), (0.3081,)) 224 | ])), 225 | batch_size=BATCH_SIZE, shuffle=True, **kwargs) 226 | test_loader = torch.utils.data.DataLoader( 227 | datasets.MNIST('./data', train=False, transform=transforms.Compose([ 228 | transforms.ToTensor(), 229 | transforms.Normalize((0.1307,), (0.3081,)) 230 | ])), 231 | batch_size=BATCH_SIZE_TEST, shuffle=True, **kwargs) 232 | 233 | ####### Model setup 234 | 235 | model = Net().to(device) 236 | loss_softmax = nn.CrossEntropyLoss().to(device) 237 | loss_arcface = Arcface_loss(num_classes=10, feat_dim=FEATURES_DIM, device=device).to(device) 238 | 239 | # optimzer nn 240 | optimizer_nn = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=0.0005) 241 | sheduler_nn = lr_scheduler.StepLR(optimizer_nn, 20, gamma=0.1) 242 | 243 | # optimzer cosface or arcface 244 | optimzer_arcface = optim.SGD(loss_arcface.parameters(), lr=0.01) 245 | sheduler_arcface = lr_scheduler.StepLR(optimzer_arcface, 20, gamma=0.1) 246 | 247 | 248 | for epoch in range(1, EPOCHS + 1): 249 | sheduler_nn.step() 250 | sheduler_arcface.step() 251 | 252 | train(model, device, train_loader, loss_softmax, loss_arcface, optimizer_nn, optimzer_arcface, epoch) 253 | test(model, device, test_loader, loss_softmax, loss_arcface) 254 | 255 | torch.save(model.state_dict(),"mnist_cnn-arcface.pt") 256 | torch.save(loss_arcface.state_dict(),"mnist_loss-arcface.pt") 257 | -------------------------------------------------------------------------------- /mnist_arcface5_fc7.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.optim as optim 5 | from torchvision import datasets, transforms 6 | import numpy as np 7 | 8 | print("Pytorch version: " + str(torch.__version__)) 9 | use_cuda = torch.cuda.is_available() 10 | print("Use CUDA: " + str(use_cuda)) 11 | 12 | from torch.autograd import Variable 13 | from torch.utils.data import DataLoader 14 | import torch.optim.lr_scheduler as lr_scheduler 15 | from torch.autograd.function import Function 16 | import math 17 | 18 | from pdb import set_trace as bp 19 | 20 | BATCH_SIZE = 100 21 | FEATURES_DIM = 3 22 | NUM_OF_CLASSES = 10 23 | BATCH_SIZE_TEST = 1000 24 | EPOCHS = 20 25 | LOG_INTERVAL = 10 26 | 27 | class Net(nn.Module): 28 | def __init__(self): 29 | super(Net, self).__init__() 30 | krnl_sz=3 31 | strd = 1 32 | 33 | self.conv1 = nn.Conv2d(in_channels=1, out_channels=20, kernel_size=krnl_sz, stride=strd, padding=1) 34 | self.conv2 = nn.Conv2d(in_channels=20, out_channels=50, kernel_size=krnl_sz, stride=strd, padding=1) 35 | self.prelu1_1 = nn.PReLU() 36 | self.prelu1_2 = nn.PReLU() 37 | 38 | self.conv3 = nn.Conv2d(in_channels=50, out_channels=64, kernel_size=krnl_sz, stride=strd, padding=1) 39 | self.conv4 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=krnl_sz, stride=strd, padding=1) 40 | self.prelu2_1 = nn.PReLU() 41 | self.prelu2_2 = nn.PReLU() 42 | 43 | self.conv5 = nn.Conv2d(in_channels=128, 
out_channels=512, kernel_size=krnl_sz, stride=strd, padding=1) 44 | self.conv6 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=krnl_sz, stride=strd, padding=1) 45 | self.prelu3_1 = nn.PReLU() 46 | self.prelu3_2 = nn.PReLU() 47 | 48 | self.prelu_weight = nn.Parameter(torch.Tensor(1).fill_(0.25)) 49 | 50 | self.fc1 = nn.Linear(3*3*512, 3) 51 | self.fc3 = nn.Linear(3, 10) 52 | 53 | def forward(self, x): 54 | mp_ks=2 55 | mp_strd=2 56 | 57 | x = self.prelu1_1(self.conv1(x)) 58 | x = self.prelu1_2(self.conv2(x)) 59 | x = F.max_pool2d(x, kernel_size=mp_ks, stride=mp_strd) 60 | 61 | x = self.prelu2_1(self.conv3(x)) 62 | x = self.prelu2_2(self.conv4(x)) 63 | x = F.max_pool2d(x, kernel_size=mp_ks, stride=mp_strd) 64 | 65 | x = self.prelu3_1(self.conv5(x)) 66 | x = self.prelu3_2(self.conv6(x)) 67 | x = F.max_pool2d(x, kernel_size=mp_ks, stride=mp_strd) 68 | 69 | x = x.view(-1, 3*3*512) # Flatten 70 | features3d = F.prelu(self.fc1(x), self.prelu_weight) 71 | x = self.fc3(features3d) 72 | 73 | return features3d, x 74 | 75 | class Arcface_loss(nn.Module): 76 | 77 | def __init__(self, num_classes, feat_dim, device, s=7.0, m=0.2): 78 | super(Arcface_loss, self).__init__() 79 | self.feat_dim = feat_dim 80 | self.num_classes = num_classes 81 | self.s = s 82 | self.m = m 83 | self.weights = nn.Parameter(torch.randn(num_classes, feat_dim)) 84 | self.device = device 85 | 86 | self.cos_m = math.cos(m) 87 | self.sin_m = math.sin(m) 88 | self.mm = math.sin(math.pi-m)*m 89 | self.threshold = math.cos(math.pi-m) 90 | 91 | def forward(self, feat, label): 92 | eps = 1e-4 93 | batch_size = feat.shape[0] 94 | norms = torch.norm(feat, p=2, dim=-1, keepdim=True) 95 | feat_l2norm = torch.div(feat, norms) 96 | # feat_l2norm = feat_l2norm.clamp(min=-1+eps, max=1-eps) # for numerical stability 97 | feat_l2norm = feat_l2norm * self.s 98 | 99 | norms_w = torch.norm(self.weights, p=2, dim=-1, keepdim=True) 100 | weights_l2norm = torch.div(self.weights, norms_w) 101 | # weights_l2norm = weights_l2norm.clamp(min=-1+eps, max=1-eps) # for numerical stability 102 | 103 | fc7 = torch.matmul(feat_l2norm, torch.transpose(weights_l2norm, 0, 1)) 104 | 105 | if torch.cuda.is_available(): 106 | label = label.cuda() 107 | fc7 = fc7.cuda() 108 | else: 109 | label = label.cpu() 110 | fc7 = fc7.cpu() 111 | 112 | ## zy = mx.sym.pick(fc7, gt_label, axis=1) 113 | # label = label.cpu() 114 | # fc7 = fc7.cpu() 115 | 116 | target_one_hot = torch.zeros(len(label), NUM_OF_CLASSES).to(self.device) 117 | target_one_hot = target_one_hot.scatter_(1, label.unsqueeze(1), 1.) 
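        # Note: unlike the previous variant, the clamps on the normalized
        # features and weights are commented out here; only cos_t is clamped
        # below. That is the clamp that actually matters: torch.acos is
        # defined only on [-1, 1] and its gradient diverges at the endpoints,
        # so cos_t is pulled into [-1+eps, 1-eps] before the margin is added.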
118 | zy = torch.addcmul(torch.zeros(fc7.size()).to(self.device), 1., fc7, target_one_hot) 119 | zy = zy.sum(-1) 120 | 121 | cos_t = zy/self.s 122 | cos_t = cos_t.clamp(min=-1+eps, max=1-eps) # for numerical stability 123 | 124 | t = torch.acos(cos_t) 125 | t = t+self.m 126 | 127 | body = torch.cos(t) 128 | new_zy = body*self.s 129 | 130 | diff = new_zy - zy 131 | # diff = mx.sym.expand_dims(diff, 1) 132 | diff = diff.unsqueeze(1) 133 | 134 | # gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0) 135 | # body = mx.sym.broadcast_mul(gt_one_hot, diff) 136 | body = torch.addcmul(torch.zeros(diff.size()).to(self.device), 1., diff, target_one_hot) 137 | 138 | output = fc7+body 139 | 140 | return output.to(self.device) 141 | 142 | 143 | def train(model, device, train_loader, loss_softmax, loss_arcface, optimizer_nn, optimzer_arcface, epoch): 144 | model.train() 145 | for batch_idx, (data, target) in enumerate(train_loader): 146 | data, target = data.to(device), target.to(device) 147 | 148 | features, _ = model(data) 149 | logits = loss_arcface(features, target) 150 | loss = loss_softmax(logits, target) 151 | 152 | _, predicted = torch.max(logits.data, 1) 153 | accuracy = (target.data == predicted).float().mean() 154 | 155 | optimizer_nn.zero_grad() 156 | optimzer_arcface.zero_grad() 157 | 158 | loss.backward() 159 | 160 | optimizer_nn.step() 161 | optimzer_arcface.step() 162 | 163 | if batch_idx % LOG_INTERVAL == 0: 164 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 165 | epoch, batch_idx * len(data), len(train_loader.dataset), 166 | 100. * batch_idx / len(train_loader), loss.item())) 167 | 168 | 169 | def test(model, device, test_loader, loss_softmax, loss_arcface): 170 | model.eval() 171 | correct = 0 172 | total = 0 173 | with torch.no_grad(): 174 | for data, target in test_loader: 175 | data, target = data.to(device), target.to(device) 176 | 177 | feats, _ = model(data) 178 | logits = loss_arcface(feats, target) 179 | _, predicted = torch.max(logits.data, 1) 180 | total += target.size(0) 181 | correct += (predicted == target.data).sum() 182 | 183 | print('\nTest set:, Accuracy: {}/{} ({:.0f}%)\n'.format( 184 | correct, len(test_loader.dataset), 185 | 100. 
* correct / len(test_loader.dataset))) 186 | 187 | ################################################################### 188 | 189 | torch.manual_seed(1) 190 | device = torch.device("cuda" if use_cuda else "cpu") 191 | 192 | ####### Data setup 193 | 194 | kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {} 195 | train_loader = torch.utils.data.DataLoader( 196 | datasets.MNIST('./data', train=True, download=True, 197 | transform=transforms.Compose([ 198 | transforms.ToTensor(), 199 | transforms.Normalize((0.1307,), (0.3081,)) 200 | ])), 201 | batch_size=BATCH_SIZE, shuffle=True, **kwargs) 202 | test_loader = torch.utils.data.DataLoader( 203 | datasets.MNIST('./data', train=False, transform=transforms.Compose([ 204 | transforms.ToTensor(), 205 | transforms.Normalize((0.1307,), (0.3081,)) 206 | ])), 207 | batch_size=BATCH_SIZE_TEST, shuffle=True, **kwargs) 208 | 209 | ####### Model setup 210 | 211 | model = Net().to(device) 212 | loss_softmax = nn.CrossEntropyLoss().to(device) 213 | loss_arcface = Arcface_loss(num_classes=10, feat_dim=FEATURES_DIM, device=device).to(device) 214 | 215 | # optimzer nn 216 | optimizer_nn = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=0.0005) 217 | sheduler_nn = lr_scheduler.StepLR(optimizer_nn, 20, gamma=0.1) 218 | 219 | # optimzer cosface or arcface 220 | optimzer_arcface = optim.SGD(loss_arcface.parameters(), lr=0.01) 221 | sheduler_arcface = lr_scheduler.StepLR(optimzer_arcface, 20, gamma=0.1) 222 | 223 | 224 | for epoch in range(1, EPOCHS + 1): 225 | sheduler_nn.step() 226 | sheduler_arcface.step() 227 | 228 | train(model, device, train_loader, loss_softmax, loss_arcface, optimizer_nn, optimzer_arcface, epoch) 229 | test(model, device, test_loader, loss_softmax, loss_arcface) 230 | 231 | torch.save(model.state_dict(),"mnist_cnn-arcface.pt") 232 | torch.save(loss_arcface.state_dict(),"mnist_loss-arcface.pt") 233 | -------------------------------------------------------------------------------- /mnist_arcface6_fc7.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.optim as optim 5 | from torchvision import datasets, transforms 6 | import numpy as np 7 | 8 | print("Pytorch version: " + str(torch.__version__)) 9 | use_cuda = torch.cuda.is_available() 10 | print("Use CUDA: " + str(use_cuda)) 11 | 12 | from torch.autograd import Variable 13 | from torch.utils.data import DataLoader 14 | import torch.optim.lr_scheduler as lr_scheduler 15 | from torch.autograd.function import Function 16 | import math 17 | 18 | from pdb import set_trace as bp 19 | 20 | BATCH_SIZE = 100 21 | FEATURES_DIM = 3 22 | NUM_OF_CLASSES = 10 23 | BATCH_SIZE_TEST = 1000 24 | EPOCHS = 20 25 | LOG_INTERVAL = 10 26 | 27 | class Net(nn.Module): 28 | def __init__(self): 29 | super(Net, self).__init__() 30 | krnl_sz=3 31 | strd = 1 32 | 33 | self.conv1 = nn.Conv2d(in_channels=1, out_channels=20, kernel_size=krnl_sz, stride=strd, padding=1) 34 | self.conv2 = nn.Conv2d(in_channels=20, out_channels=50, kernel_size=krnl_sz, stride=strd, padding=1) 35 | self.prelu1_1 = nn.PReLU() 36 | self.prelu1_2 = nn.PReLU() 37 | 38 | self.conv3 = nn.Conv2d(in_channels=50, out_channels=64, kernel_size=krnl_sz, stride=strd, padding=1) 39 | self.conv4 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=krnl_sz, stride=strd, padding=1) 40 | self.prelu2_1 = nn.PReLU() 41 | self.prelu2_2 = nn.PReLU() 42 | 43 | self.conv5 = nn.Conv2d(in_channels=128, 
out_channels=512, kernel_size=krnl_sz, stride=strd, padding=1) 44 | self.conv6 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=krnl_sz, stride=strd, padding=1) 45 | self.prelu3_1 = nn.PReLU() 46 | self.prelu3_2 = nn.PReLU() 47 | 48 | self.prelu_weight = nn.Parameter(torch.Tensor(1).fill_(0.25)) 49 | 50 | self.fc1 = nn.Linear(3*3*512, 3) 51 | self.fc3 = nn.Linear(3, 10) 52 | 53 | def forward(self, x): 54 | mp_ks=2 55 | mp_strd=2 56 | 57 | x = self.prelu1_1(self.conv1(x)) 58 | x = self.prelu1_2(self.conv2(x)) 59 | x = F.max_pool2d(x, kernel_size=mp_ks, stride=mp_strd) 60 | 61 | x = self.prelu2_1(self.conv3(x)) 62 | x = self.prelu2_2(self.conv4(x)) 63 | x = F.max_pool2d(x, kernel_size=mp_ks, stride=mp_strd) 64 | 65 | x = self.prelu3_1(self.conv5(x)) 66 | x = self.prelu3_2(self.conv6(x)) 67 | x = F.max_pool2d(x, kernel_size=mp_ks, stride=mp_strd) 68 | 69 | x = x.view(-1, 3*3*512) # Flatten 70 | features3d = F.prelu(self.fc1(x), self.prelu_weight) 71 | x = self.fc3(features3d) 72 | 73 | return features3d, x 74 | 75 | class Arcface_loss(nn.Module): 76 | def __init__(self, num_classes, feat_dim, device, s=7.0, m=0.2): 77 | super(Arcface_loss, self).__init__() 78 | self.feat_dim = feat_dim 79 | self.num_classes = num_classes 80 | self.s = s 81 | self.m = m 82 | self.weights = nn.Parameter(torch.randn(num_classes, feat_dim)) 83 | self.device = device 84 | 85 | self.cos_m = math.cos(m) 86 | self.sin_m = math.sin(m) 87 | self.mm = math.sin(math.pi-m)*m 88 | self.threshold = math.cos(math.pi-m) 89 | 90 | def forward(self, feat, label): 91 | eps = 1e-4 92 | batch_size = feat.shape[0] 93 | norms = torch.norm(feat, p=2, dim=-1, keepdim=True) 94 | feat_l2norm = torch.div(feat, norms) 95 | feat_l2norm = feat_l2norm * self.s 96 | 97 | norms_w = torch.norm(self.weights, p=2, dim=-1, keepdim=True) 98 | weights_l2norm = torch.div(self.weights, norms_w) 99 | 100 | fc7 = torch.matmul(feat_l2norm, torch.transpose(weights_l2norm, 0, 1)) 101 | 102 | if torch.cuda.is_available(): 103 | label = label.cuda() 104 | fc7 = fc7.cuda() 105 | else: 106 | label = label.cpu() 107 | fc7 = fc7.cpu() 108 | 109 | target_one_hot = torch.zeros(len(label), NUM_OF_CLASSES).to(self.device) 110 | target_one_hot = target_one_hot.scatter_(1, label.unsqueeze(1), 1.) 
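        # A compact equivalent of the margin arithmetic below, sketched with
        # gather/scatter (a hypothetical rewrite, not code used by this file):
        #     zy     = fc7.gather(1, label.unsqueeze(1)).squeeze(1)  # s*cos(theta_y)
        #     theta  = torch.acos((zy / self.s).clamp(-1 + eps, 1 - eps)) + self.m
        #     output = fc7.scatter(1, label.unsqueeze(1),
        #                          (torch.cos(theta) * self.s).unsqueeze(1))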
111 | zy = torch.addcmul(torch.zeros(fc7.size()).to(self.device), 1., fc7, target_one_hot) 112 | zy = zy.sum(-1) 113 | 114 | cos_theta = zy/self.s 115 | cos_theta = cos_theta.clamp(min=-1+eps, max=1-eps) # for numerical stability 116 | 117 | theta = torch.acos(cos_theta) 118 | theta = theta+self.m 119 | 120 | body = torch.cos(theta) 121 | new_zy = body*self.s 122 | 123 | diff = new_zy - zy 124 | diff = diff.unsqueeze(1) 125 | 126 | body = torch.addcmul(torch.zeros(diff.size()).to(self.device), 1., diff, target_one_hot) 127 | output = fc7+body 128 | 129 | return output.to(self.device) 130 | 131 | 132 | def train(model, device, train_loader, loss_softmax, loss_arcface, optimizer_nn, optimzer_arcface, epoch): 133 | model.train() 134 | for batch_idx, (data, target) in enumerate(train_loader): 135 | data, target = data.to(device), target.to(device) 136 | 137 | features, _ = model(data) 138 | logits = loss_arcface(features, target) 139 | loss = loss_softmax(logits, target) 140 | 141 | _, predicted = torch.max(logits.data, 1) 142 | accuracy = (target.data == predicted).float().mean() 143 | 144 | optimizer_nn.zero_grad() 145 | optimzer_arcface.zero_grad() 146 | 147 | loss.backward() 148 | 149 | optimizer_nn.step() 150 | optimzer_arcface.step() 151 | 152 | if batch_idx % LOG_INTERVAL == 0: 153 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 154 | epoch, batch_idx * len(data), len(train_loader.dataset), 155 | 100. * batch_idx / len(train_loader), loss.item())) 156 | 157 | 158 | def test(model, device, test_loader, loss_softmax, loss_arcface): 159 | model.eval() 160 | correct = 0 161 | total = 0 162 | with torch.no_grad(): 163 | for data, target in test_loader: 164 | data, target = data.to(device), target.to(device) 165 | 166 | feats, _ = model(data) 167 | logits = loss_arcface(feats, target) 168 | _, predicted = torch.max(logits.data, 1) 169 | total += target.size(0) 170 | correct += (predicted == target.data).sum() 171 | 172 | print('\nTest set:, Accuracy: {}/{} ({:.0f}%)\n'.format( 173 | correct, len(test_loader.dataset), 174 | 100. 
* correct / len(test_loader.dataset))) 175 | 176 | ################################################################### 177 | 178 | torch.manual_seed(1) 179 | device = torch.device("cuda" if use_cuda else "cpu") 180 | 181 | ####### Data setup 182 | 183 | kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {} 184 | train_loader = torch.utils.data.DataLoader( 185 | datasets.MNIST('./data', train=True, download=True, 186 | transform=transforms.Compose([ 187 | transforms.ToTensor(), 188 | transforms.Normalize((0.1307,), (0.3081,)) 189 | ])), 190 | batch_size=BATCH_SIZE, shuffle=True, **kwargs) 191 | test_loader = torch.utils.data.DataLoader( 192 | datasets.MNIST('./data', train=False, transform=transforms.Compose([ 193 | transforms.ToTensor(), 194 | transforms.Normalize((0.1307,), (0.3081,)) 195 | ])), 196 | batch_size=BATCH_SIZE_TEST, shuffle=True, **kwargs) 197 | 198 | ####### Model setup 199 | 200 | model = Net().to(device) 201 | loss_softmax = nn.CrossEntropyLoss().to(device) 202 | loss_arcface = Arcface_loss(num_classes=10, feat_dim=FEATURES_DIM, device=device).to(device) 203 | 204 | # optimzer nn 205 | optimizer_nn = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=0.0005) 206 | sheduler_nn = lr_scheduler.StepLR(optimizer_nn, 20, gamma=0.1) 207 | 208 | # optimzer cosface or arcface 209 | optimzer_arcface = optim.SGD(loss_arcface.parameters(), lr=0.01) 210 | sheduler_arcface = lr_scheduler.StepLR(optimzer_arcface, 20, gamma=0.1) 211 | 212 | 213 | for epoch in range(1, EPOCHS + 1): 214 | sheduler_nn.step() 215 | sheduler_arcface.step() 216 | 217 | train(model, device, train_loader, loss_softmax, loss_arcface, optimizer_nn, optimzer_arcface, epoch) 218 | test(model, device, test_loader, loss_softmax, loss_arcface) 219 | 220 | torch.save(model.state_dict(),"mnist_cnn-arcface.pt") 221 | torch.save(loss_arcface.state_dict(),"mnist_loss-arcface.pt") 222 | -------------------------------------------------------------------------------- /mnist_cnn-cosface.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/egcode/pytorch-losses/03e7cd11900e3678a34fe9d1d32964993ab2a926/mnist_cnn-cosface.pt -------------------------------------------------------------------------------- /mnist_cosface.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | from torchvision import datasets, transforms 5 | from torch.autograd import Variable 6 | from torchvision import datasets 7 | from torch.utils.data import DataLoader 8 | import torch.optim.lr_scheduler as lr_scheduler 9 | 10 | from pdb import set_trace as bp 11 | 12 | from torch.autograd.function import Function 13 | import torch.nn.functional as F 14 | import numpy as np 15 | 16 | BATCH_SIZE = 100 17 | FEATURES_DIM = 3 18 | 19 | # class Net(nn.Module): 20 | # def __init__(self): 21 | # super(Net, self).__init__() 22 | # self.conv1_1 = nn.Conv2d(1, 32, kernel_size=5, padding=2) 23 | # self.prelu1_1 = nn.PReLU() 24 | # self.conv1_2 = nn.Conv2d(32, 32, kernel_size=5, padding=2) 25 | # self.prelu1_2 = nn.PReLU() 26 | # self.conv2_1 = nn.Conv2d(32, 64, kernel_size=5, padding=2) 27 | # self.prelu2_1 = nn.PReLU() 28 | # self.conv2_2 = nn.Conv2d(64, 64, kernel_size=5, padding=2) 29 | # self.prelu2_2 = nn.PReLU() 30 | # self.conv3_1 = nn.Conv2d(64, 128, kernel_size=5, padding=2) 31 | # self.prelu3_1 = nn.PReLU() 32 | # self.conv3_2 = nn.Conv2d(128, 128, kernel_size=5, 
padding=2) 33 | # self.prelu3_2 = nn.PReLU() 34 | # self.preluip1 = nn.PReLU() 35 | # self.ip1 = nn.Linear(128 * 3 * 3, FEATURES_DIM) 36 | # self.ip2 = nn.Linear(FEATURES_DIM, 10) 37 | 38 | # def forward(self, x): 39 | # x = self.prelu1_1(self.conv1_1(x)) 40 | # x = self.prelu1_2(self.conv1_2(x)) 41 | # x = F.max_pool2d(x, 2) 42 | # x = self.prelu2_1(self.conv2_1(x)) 43 | # x = self.prelu2_2(self.conv2_2(x)) 44 | # x = F.max_pool2d(x, 2) 45 | # x = self.prelu3_1(self.conv3_1(x)) 46 | # x = self.prelu3_2(self.conv3_2(x)) 47 | # x = F.max_pool2d(x, 2) 48 | # x = x.view(-1, 128 * 3 * 3) 49 | # ip1 = self.preluip1(self.ip1(x)) 50 | # ip2 = self.ip2(ip1) 51 | # return ip1, ip2 52 | 53 | class Net(nn.Module): 54 | def __init__(self): 55 | super(Net, self).__init__() 56 | krnl_sz=3 57 | strd = 1 58 | 59 | self.conv1 = nn.Conv2d(in_channels=1, out_channels=20, kernel_size=krnl_sz, stride=strd, padding=1) 60 | self.conv2 = nn.Conv2d(in_channels=20, out_channels=50, kernel_size=krnl_sz, stride=strd, padding=1) 61 | self.prelu1_1 = nn.PReLU() 62 | self.prelu1_2 = nn.PReLU() 63 | 64 | self.conv3 = nn.Conv2d(in_channels=50, out_channels=64, kernel_size=krnl_sz, stride=strd, padding=1) 65 | self.conv4 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=krnl_sz, stride=strd, padding=1) 66 | self.prelu2_1 = nn.PReLU() 67 | self.prelu2_2 = nn.PReLU() 68 | 69 | self.conv5 = nn.Conv2d(in_channels=128, out_channels=512, kernel_size=krnl_sz, stride=strd, padding=1) 70 | self.conv6 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=krnl_sz, stride=strd, padding=1) 71 | self.prelu3_1 = nn.PReLU() 72 | self.prelu3_2 = nn.PReLU() 73 | 74 | self.prelu_weight = nn.Parameter(torch.Tensor(1).fill_(0.25)) 75 | 76 | self.fc1 = nn.Linear(3*3*512, 3) 77 | # self.fc2 = nn.Linear(3, 2) 78 | self.fc3 = nn.Linear(3, 10) 79 | 80 | def forward(self, x): 81 | mp_ks=2 82 | mp_strd=2 83 | 84 | x = self.prelu1_1(self.conv1(x)) 85 | x = self.prelu1_2(self.conv2(x)) 86 | x = F.max_pool2d(x, kernel_size=mp_ks, stride=mp_strd) 87 | 88 | x = self.prelu2_1(self.conv3(x)) 89 | x = self.prelu2_2(self.conv4(x)) 90 | x = F.max_pool2d(x, kernel_size=mp_ks, stride=mp_strd) 91 | 92 | x = self.prelu3_1(self.conv5(x)) 93 | x = self.prelu3_2(self.conv6(x)) 94 | x = F.max_pool2d(x, kernel_size=mp_ks, stride=mp_strd) 95 | 96 | x = x.view(-1, 3*3*512) # Flatten 97 | 98 | features3d = self.fc1(x) 99 | # features2d = self.fc2(features3d) 100 | x = F.prelu(features3d, self.prelu_weight) 101 | 102 | x = self.fc3(x) 103 | 104 | return features3d, x 105 | 106 | 107 | 108 | class LMCL_loss(nn.Module): 109 | """ 110 | Refer to paper: 111 | Hao Wang, Yitong Wang, Zheng Zhou, Xing Ji, Dihong Gong, Jingchao Zhou,Zhifeng Li, and Wei Liu 112 | CosFace: Large Margin Cosine Loss for Deep Face Recognition. 
CVPR2018
113 |     re-implement by yirong mao
114 |     2018 07/02
115 |     """
116 | 
117 |     def __init__(self, num_classes, feat_dim, s=7.00, m=0.2):
118 |         super(LMCL_loss, self).__init__()
119 |         self.feat_dim = feat_dim
120 |         self.num_classes = num_classes
121 |         self.s = s
122 |         self.m = m
123 |         self.centers = nn.Parameter(torch.randn(num_classes, feat_dim))
124 | 
125 |     def forward(self, feat, label):
126 |         batch_size = feat.shape[0]
127 |         norms = torch.norm(feat, p=2, dim=-1, keepdim=True)
128 |         nfeat = torch.div(feat, norms)
129 | 
130 |         norms_c = torch.norm(self.centers, p=2, dim=-1, keepdim=True)
131 |         ncenters = torch.div(self.centers, norms_c)
132 |         logits = torch.matmul(nfeat.cpu(), torch.transpose(ncenters, 0, 1))
133 | 
134 |         y_onehot = torch.FloatTensor(batch_size, self.num_classes)
135 |         y_onehot.zero_()
136 |         y_onehot = Variable(y_onehot).cpu()
137 |         y_onehot.scatter_(1, torch.unsqueeze(label, dim=-1), self.m)
138 |         margin_logits = self.s * (logits - y_onehot)
139 | 
140 |         return logits, margin_logits
141 | 
142 | 
143 | 
144 | # def visualize(feat, labels, epoch):
145 | #     plt.ion()
146 | #     c = ['#ff0000', '#ffff00', '#00ff00', '#00ffff', '#0000ff',
147 | #          '#ff00ff', '#990000', '#999900', '#009900', '#009999']
148 | #     plt.clf()
149 | #     for i in range(10):
150 | #         plt.plot(feat[labels == i, 0], feat[labels == i, 1], '.', c=c[i])
151 | #     plt.legend(['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'], loc='upper right')
152 | #     # plt.xlim(xmin=-5,xmax=5)
153 | #     # plt.ylim(ymin=-5,ymax=5)
154 | #     plt.text(-4.8, 4.6, "epoch=%d" % epoch)
155 | #     plt.savefig('./images/LMCL_loss_u_epoch=%d.jpg' % epoch)
156 | #     # plt.draw()
157 | #     # plt.pause(0.001)
158 | #     plt.close()
159 | 
160 | def test(test_loader, criterion, model, use_cuda):
161 |     correct = 0
162 |     total = 0
163 |     for i, (data, target) in enumerate(test_loader):
164 |         if use_cuda:
165 |             data = data.cuda()
166 |             target = target.cuda()
167 |         data, target = Variable(data), Variable(target).cpu()
168 | 
169 |         feats, _ = model(data)
170 |         logits, mlogits = criterion[1](feats, target)
171 |         _, predicted = torch.max(logits.data, 1)
172 |         total += target.size(0)
173 |         correct += (predicted == target.data).sum()
174 | 
175 |     print('Test Accuracy of the model on the 10000 test images: %f %%' % (100. * correct.item() / total))
176 | 
177 | 
178 | def train(train_loader, model, criterion, optimizer, epoch, loss_weight, use_cuda):
179 |     ip1_loader = []
180 |     idx_loader = []
181 |     for i, (data, target) in enumerate(train_loader):
182 |         if use_cuda:
183 |             data = data.cuda()
184 |             target = target.cuda()
185 |         data, target = Variable(data), Variable(target).cpu()
186 | 
187 |         feats, _ = model(data)
188 |         logits, mlogits = criterion[1](feats, target)
189 |         # cross_entropy = criterion[0](logits, target)
190 |         loss = criterion[0](mlogits, target)
191 | 
192 |         _, predicted = torch.max(logits.data, 1)
193 |         accuracy = (target.data == predicted).float().mean()
194 | 
195 |         optimizer[0].zero_grad()
196 |         optimizer[1].zero_grad()
197 | 
198 |         loss.backward()
199 | 
200 |         optimizer[0].step()
201 |         optimizer[1].step()
202 | 
203 |         ip1_loader.append(feats)
204 |         idx_loader.append(target)
205 |         if (i + 1) % 50 == 0:
206 |             print('Epoch [%d], Iter [%d/%d] Loss: %.4f Acc %.4f'
207 |                   % (epoch, i + 1, len(train_loader), loss.item(), accuracy.item()))
208 | 
209 |     feat = torch.cat(ip1_loader, 0)
210 |     labels = torch.cat(idx_loader, 0)
211 |     # visualize(feat.data.cpu().numpy(), labels.data.cpu().numpy(), epoch)
212 | 
213 | 
214 | if torch.cuda.is_available():
215 |     use_cuda = True
216 | else:
217 |     use_cuda = False
218 | # Dataset
219 | trainset = datasets.MNIST('./data/', download=True, train=True, transform=transforms.Compose([
220 |     transforms.ToTensor(),
221 |     transforms.Normalize((0.1307,), (0.3081,))]))
222 | train_loader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
223 | 
224 | testset = datasets.MNIST('./data/', download=True, train=False, transform=transforms.Compose([
225 |     transforms.ToTensor(),
226 |     transforms.Normalize((0.1307,), (0.3081,))]))
227 | test_loader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
228 | 
229 | # Model
230 | model = Net()
231 | 
232 | # Softmax cross-entropy loss
233 | nllloss = nn.CrossEntropyLoss()
234 | # CosFace (LMCL) loss
235 | loss_weight = 0.1
236 | lmcl_loss = LMCL_loss(num_classes=10, feat_dim=FEATURES_DIM)
237 | if use_cuda:
238 |     nllloss = nllloss.cuda()
239 |     # coco_loss = lmcl_loss.cuda()
240 |     model = model.cuda()
241 | criterion = [nllloss, lmcl_loss]
242 | # optimizer for the CNN
243 | optimizer4nn = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=0.0005)
244 | scheduler_4nn = lr_scheduler.StepLR(optimizer4nn, 20, gamma=0.5)
245 | 
246 | # optimizer for the LMCL loss
247 | optimizer4center = optim.SGD(lmcl_loss.parameters(), lr=0.01)
248 | scheduler_4center = lr_scheduler.StepLR(optimizer4center, 20, gamma=0.5)
249 | for epoch in range(20):
250 |     scheduler_4nn.step()
251 |     scheduler_4center.step()
252 |     # print optimizer4nn.param_groups[0]['lr']
253 |     train(train_loader, model, criterion, [optimizer4nn, optimizer4center], epoch + 1, loss_weight, use_cuda)
254 |     test(test_loader, criterion, model, use_cuda)
255 | 
256 | torch.save(model.state_dict(),"mnist_cnn-cosface.pt")
257 | torch.save(lmcl_loss.state_dict(),"mnist_loss-cosface.pt")
258 | 
259 | 
-------------------------------------------------------------------------------- /mnist_cosface2.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | from torchvision import datasets, transforms
5 | from torch.autograd import Variable
6 | from torchvision import datasets
7 | from torch.utils.data import DataLoader
8 | import torch.optim.lr_scheduler as lr_scheduler
9 | 
10 | from pdb import set_trace as bp
11 | 
12 | from torch.autograd.function import Function
13 | import torch.nn.functional as F
14 | import numpy as np
15 | 
16 | BATCH_SIZE = 100
17 | FEATURES_DIM = 3
18 | 
19 | # class Net(nn.Module):
20 | #     def __init__(self):
21 | #         super(Net, self).__init__()
22 | #         self.conv1_1 = nn.Conv2d(1, 32, kernel_size=5, padding=2)
23 | #         self.prelu1_1 = nn.PReLU()
24 | #         self.conv1_2 = nn.Conv2d(32, 32, kernel_size=5, padding=2)
25 | #         self.prelu1_2 = nn.PReLU()
26 | #         self.conv2_1 = nn.Conv2d(32, 64, kernel_size=5, padding=2)
27 | #         self.prelu2_1 = nn.PReLU()
28 | #         self.conv2_2 = nn.Conv2d(64, 64, kernel_size=5, padding=2)
29 | #         self.prelu2_2 = nn.PReLU()
30 | #         self.conv3_1 = nn.Conv2d(64, 128, kernel_size=5, padding=2)
31 | #         self.prelu3_1 = nn.PReLU()
32 | #         self.conv3_2 = nn.Conv2d(128, 128, kernel_size=5, padding=2)
33 | #         self.prelu3_2 = nn.PReLU()
34 | #         self.preluip1 = nn.PReLU()
35 | #         self.ip1 = nn.Linear(128 * 3 * 3, FEATURES_DIM)
36 | #         self.ip2 = nn.Linear(FEATURES_DIM, 10)
37 | 
38 | #     def forward(self, x):
39 | #         x = self.prelu1_1(self.conv1_1(x))
40 | #         x = self.prelu1_2(self.conv1_2(x))
41 | #         x = F.max_pool2d(x, 2)
42 | #         x = self.prelu2_1(self.conv2_1(x))
43 | #         x = self.prelu2_2(self.conv2_2(x))
44 | #         x = F.max_pool2d(x, 2)
45 | #         x = 
self.prelu3_1(self.conv3_1(x)) 46 | # x = self.prelu3_2(self.conv3_2(x)) 47 | # x = F.max_pool2d(x, 2) 48 | # x = x.view(-1, 128 * 3 * 3) 49 | # ip1 = self.preluip1(self.ip1(x)) 50 | # ip2 = self.ip2(ip1) 51 | # return ip1, ip2 52 | 53 | class Net(nn.Module): 54 | def __init__(self): 55 | super(Net, self).__init__() 56 | krnl_sz=3 57 | strd = 1 58 | 59 | self.conv1 = nn.Conv2d(in_channels=1, out_channels=20, kernel_size=krnl_sz, stride=strd, padding=1) 60 | self.conv2 = nn.Conv2d(in_channels=20, out_channels=50, kernel_size=krnl_sz, stride=strd, padding=1) 61 | self.prelu1_1 = nn.PReLU() 62 | self.prelu1_2 = nn.PReLU() 63 | 64 | self.conv3 = nn.Conv2d(in_channels=50, out_channels=64, kernel_size=krnl_sz, stride=strd, padding=1) 65 | self.conv4 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=krnl_sz, stride=strd, padding=1) 66 | self.prelu2_1 = nn.PReLU() 67 | self.prelu2_2 = nn.PReLU() 68 | 69 | self.conv5 = nn.Conv2d(in_channels=128, out_channels=512, kernel_size=krnl_sz, stride=strd, padding=1) 70 | self.conv6 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=krnl_sz, stride=strd, padding=1) 71 | self.prelu3_1 = nn.PReLU() 72 | self.prelu3_2 = nn.PReLU() 73 | 74 | self.prelu_weight = nn.Parameter(torch.Tensor(1).fill_(0.25)) 75 | 76 | self.fc1 = nn.Linear(3*3*512, 3) 77 | # self.fc2 = nn.Linear(3, 2) 78 | self.fc3 = nn.Linear(3, 10) 79 | 80 | def forward(self, x): 81 | mp_ks=2 82 | mp_strd=2 83 | 84 | x = self.prelu1_1(self.conv1(x)) 85 | x = self.prelu1_2(self.conv2(x)) 86 | x = F.max_pool2d(x, kernel_size=mp_ks, stride=mp_strd) 87 | 88 | x = self.prelu2_1(self.conv3(x)) 89 | x = self.prelu2_2(self.conv4(x)) 90 | x = F.max_pool2d(x, kernel_size=mp_ks, stride=mp_strd) 91 | 92 | x = self.prelu3_1(self.conv5(x)) 93 | x = self.prelu3_2(self.conv6(x)) 94 | x = F.max_pool2d(x, kernel_size=mp_ks, stride=mp_strd) 95 | 96 | x = x.view(-1, 3*3*512) # Flatten 97 | 98 | features3d = self.fc1(x) 99 | # features2d = self.fc2(features3d) 100 | x = F.prelu(features3d, self.prelu_weight) 101 | 102 | x = self.fc3(x) 103 | 104 | return features3d, x 105 | 106 | 107 | 108 | class LMCL_loss(nn.Module): 109 | """ 110 | Refer to paper: 111 | Hao Wang, Yitong Wang, Zheng Zhou, Xing Ji, Dihong Gong, Jingchao Zhou,Zhifeng Li, and Wei Liu 112 | CosFace: Large Margin Cosine Loss for Deep Face Recognition. 
CVPR2018
113 |     re-implement by yirong mao
114 |     2018 07/02
115 |     """
116 | 
117 |     def __init__(self, num_classes, feat_dim, s=7.00, m=0.2):
118 |         super(LMCL_loss, self).__init__()
119 |         self.feat_dim = feat_dim
120 |         self.num_classes = num_classes
121 |         self.s = s
122 |         self.m = m
123 |         self.centers = nn.Parameter(torch.randn(num_classes, feat_dim))
124 | 
125 |     def forward(self, feat, label):
126 |         batch_size = feat.shape[0]
127 |         norms = torch.norm(feat, p=2, dim=-1, keepdim=True)
128 |         nfeat = torch.div(feat, norms)
129 | 
130 |         norms_c = torch.norm(self.centers, p=2, dim=-1, keepdim=True)
131 |         ncenters = torch.div(self.centers, norms_c)
132 |         logits = torch.matmul(nfeat.cpu(), torch.transpose(ncenters, 0, 1))
133 | 
134 |         y_onehot = torch.FloatTensor(batch_size, self.num_classes)
135 |         y_onehot.zero_()
136 |         y_onehot = Variable(y_onehot).cpu()
137 |         y_onehot.scatter_(1, torch.unsqueeze(label, dim=-1), self.m)
138 |         margin_logits = self.s * (logits - y_onehot)
139 | 
140 |         return logits, margin_logits
141 | 
142 | 
143 | 
144 | # def visualize(feat, labels, epoch):
145 | #     plt.ion()
146 | #     c = ['#ff0000', '#ffff00', '#00ff00', '#00ffff', '#0000ff',
147 | #          '#ff00ff', '#990000', '#999900', '#009900', '#009999']
148 | #     plt.clf()
149 | #     for i in range(10):
150 | #         plt.plot(feat[labels == i, 0], feat[labels == i, 1], '.', c=c[i])
151 | #     plt.legend(['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'], loc='upper right')
152 | #     # plt.xlim(xmin=-5,xmax=5)
153 | #     # plt.ylim(ymin=-5,ymax=5)
154 | #     plt.text(-4.8, 4.6, "epoch=%d" % epoch)
155 | #     plt.savefig('./images/LMCL_loss_u_epoch=%d.jpg' % epoch)
156 | #     # plt.draw()
157 | #     # plt.pause(0.001)
158 | #     plt.close()
159 | 
160 | def test(test_loader, criterion, model, use_cuda):
161 |     correct = 0
162 |     total = 0
163 |     for i, (data, target) in enumerate(test_loader):
164 |         if use_cuda:
165 |             data = data.cuda()
166 |             target = target.cuda()
167 |         data, target = Variable(data), Variable(target).cpu()
168 | 
169 |         feats, _ = model(data)
170 |         logits, mlogits = criterion[1](feats, target)
171 |         _, predicted = torch.max(logits.data, 1)
172 |         total += target.size(0)
173 |         correct += (predicted == target.data).sum()
174 | 
175 |     print('Test Accuracy of the model on the 10000 test images: %f %%' % (100. * correct.item() / total))
176 | 
177 | 
178 | def train(train_loader, model, criterion, optimizer, epoch, loss_weight, use_cuda):
179 |     ip1_loader = []
180 |     idx_loader = []
181 |     for i, (data, target) in enumerate(train_loader):
182 |         if use_cuda:
183 |             data = data.cuda()
184 |             target = target.cuda()
185 |         data, target = Variable(data), Variable(target).cpu()
186 | 
187 |         feats, _ = model(data)
188 |         logits, mlogits = criterion[1](feats, target)
189 |         # cross_entropy = criterion[0](logits, target)
190 |         loss = criterion[0](mlogits, target)
191 | 
192 |         _, predicted = torch.max(logits.data, 1)
193 |         accuracy = (target.data == predicted).float().mean()
194 | 
195 |         optimizer[0].zero_grad()
196 |         optimizer[1].zero_grad()
197 | 
198 |         loss.backward()
199 | 
200 |         optimizer[0].step()
201 |         optimizer[1].step()
202 | 
203 |         ip1_loader.append(feats)
204 |         idx_loader.append(target)
205 |         if (i + 1) % 50 == 0:
206 |             print('Epoch [%d], Iter [%d/%d] Loss: %.4f Acc %.4f'
207 |                   % (epoch, i + 1, len(train_loader), loss.item(), accuracy.item()))
208 | 
209 |     feat = torch.cat(ip1_loader, 0)
210 |     labels = torch.cat(idx_loader, 0)
211 |     # visualize(feat.data.cpu().numpy(), labels.data.cpu().numpy(), epoch)
212 | 
213 | 
214 | if torch.cuda.is_available():
215 |     use_cuda = True
216 | else:
217 |     use_cuda = False
218 | # Dataset
219 | trainset = datasets.MNIST('./data/', download=True, train=True, transform=transforms.Compose([
220 |     transforms.ToTensor(),
221 |     transforms.Normalize((0.1307,), (0.3081,))]))
222 | train_loader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
223 | 
224 | testset = datasets.MNIST('./data/', download=True, train=False, transform=transforms.Compose([
225 |     transforms.ToTensor(),
226 |     transforms.Normalize((0.1307,), (0.3081,))]))
227 | test_loader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
228 | 
229 | # Model
230 | model = Net()
231 | 
232 | # Softmax cross-entropy loss
233 | nllloss = nn.CrossEntropyLoss()
234 | # CosFace (LMCL) loss
235 | loss_weight = 0.1
236 | lmcl_loss = LMCL_loss(num_classes=10, feat_dim=FEATURES_DIM)
237 | if use_cuda:
238 |     nllloss = nllloss.cuda()
239 |     # coco_loss = lmcl_loss.cuda()
240 |     model = model.cuda()
241 | criterion = [nllloss, lmcl_loss]
242 | # optimizer for the CNN
243 | optimizer4nn = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=0.0005)
244 | scheduler_4nn = lr_scheduler.StepLR(optimizer4nn, 20, gamma=0.5)
245 | 
246 | # optimizer for the LMCL loss
247 | optimizer4center = optim.SGD(lmcl_loss.parameters(), lr=0.01)
248 | scheduler_4center = lr_scheduler.StepLR(optimizer4center, 20, gamma=0.5)
249 | for epoch in range(20):
250 |     scheduler_4nn.step()
251 |     scheduler_4center.step()
252 |     # print optimizer4nn.param_groups[0]['lr']
253 |     train(train_loader, model, criterion, [optimizer4nn, optimizer4center], epoch + 1, loss_weight, use_cuda)
254 |     test(test_loader, criterion, model, use_cuda)
255 | 
256 | torch.save(model.state_dict(),"mnist_cnn-cosface.pt")
257 | torch.save(lmcl_loss.state_dict(),"mnist_loss-cosface.pt")
258 | 
259 | 
-------------------------------------------------------------------------------- /mnist_cosface3.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch.optim as optim
5 | from torchvision import datasets, transforms
6 | import numpy as np
7 | 
8 | print("Pytorch version: " + str(torch.__version__))
9 | use_cuda = torch.cuda.is_available()
10 | print("Use CUDA: " + str(use_cuda))
11 | 
12 | # Cosface
13 | from torch.autograd import Variable
14 | from torch.utils.data import DataLoader
15 | import torch.optim.lr_scheduler as lr_scheduler
16 | from torch.autograd.function import Function
17 | 
18 | 
19 | from pdb import set_trace as bp
20 | 
21 | BATCH_SIZE = 100
22 | FEATURES_DIM = 3
23 | 
24 | BATCH_SIZE_TEST = 1000
25 | EPOCHS = 20
26 | LOG_INTERVAL = 10
27 | 
28 | class Net(nn.Module):
29 |     def __init__(self):
30 |         super(Net, self).__init__()
31 |         krnl_sz=3
32 |         strd = 1
33 | 
34 |         self.conv1 = nn.Conv2d(in_channels=1, out_channels=20, kernel_size=krnl_sz, stride=strd, padding=1)
35 |         self.conv2 = nn.Conv2d(in_channels=20, out_channels=50, kernel_size=krnl_sz, stride=strd, padding=1)
36 |         self.prelu1_1 = nn.PReLU()
37 |         self.prelu1_2 = nn.PReLU()
38 | 
39 |         self.conv3 = nn.Conv2d(in_channels=50, out_channels=64, kernel_size=krnl_sz, stride=strd, padding=1)
40 |         self.conv4 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=krnl_sz, stride=strd, padding=1)
41 |         self.prelu2_1 = nn.PReLU()
42 |         self.prelu2_2 = nn.PReLU()
43 | 
44 |         self.conv5 = nn.Conv2d(in_channels=128, out_channels=512, kernel_size=krnl_sz, stride=strd, padding=1)
45 |         self.conv6 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=krnl_sz, stride=strd, padding=1)
46 | 
self.prelu3_1 = nn.PReLU() 47 | self.prelu3_2 = nn.PReLU() 48 | 49 | self.prelu_weight = nn.Parameter(torch.Tensor(1).fill_(0.25)) 50 | 51 | self.fc1 = nn.Linear(3*3*512, 3) 52 | # self.fc2 = nn.Linear(3, 2) 53 | self.fc3 = nn.Linear(3, 10) 54 | 55 | def forward(self, x): 56 | mp_ks=2 57 | mp_strd=2 58 | 59 | x = self.prelu1_1(self.conv1(x)) 60 | x = self.prelu1_2(self.conv2(x)) 61 | x = F.max_pool2d(x, kernel_size=mp_ks, stride=mp_strd) 62 | 63 | x = self.prelu2_1(self.conv3(x)) 64 | x = self.prelu2_2(self.conv4(x)) 65 | x = F.max_pool2d(x, kernel_size=mp_ks, stride=mp_strd) 66 | 67 | x = self.prelu3_1(self.conv5(x)) 68 | x = self.prelu3_2(self.conv6(x)) 69 | x = F.max_pool2d(x, kernel_size=mp_ks, stride=mp_strd) 70 | 71 | x = x.view(-1, 3*3*512) # Flatten 72 | 73 | features3d = self.fc1(x) 74 | # features2d = self.fc2(features3d) 75 | x = F.prelu(features3d, self.prelu_weight) 76 | 77 | x = self.fc3(x) 78 | 79 | return features3d, x 80 | 81 | class LMCL_loss(nn.Module): 82 | """ 83 | Refer to paper: 84 | Hao Wang, Yitong Wang, Zheng Zhou, Xing Ji, Dihong Gong, Jingchao Zhou,Zhifeng Li, and Wei Liu 85 | CosFace: Large Margin Cosine Loss for Deep Face Recognition. CVPR2018 86 | re-implement by yirong mao 87 | 2018 07/02 88 | """ 89 | 90 | def __init__(self, num_classes, feat_dim, device, s=7.00, m=0.2): 91 | super(LMCL_loss, self).__init__() 92 | self.feat_dim = feat_dim 93 | self.num_classes = num_classes 94 | self.s = s 95 | self.m = m 96 | self.centers = nn.Parameter(torch.randn(num_classes, feat_dim)) 97 | self.device = device 98 | 99 | def forward(self, feat, label): 100 | batch_size = feat.shape[0] 101 | norms = torch.norm(feat, p=2, dim=-1, keepdim=True) 102 | nfeat = torch.div(feat, norms) 103 | 104 | norms_c = torch.norm(self.centers, p=2, dim=-1, keepdim=True) 105 | ncenters = torch.div(self.centers, norms_c) 106 | logits = torch.matmul(nfeat, torch.transpose(ncenters, 0, 1)) 107 | 108 | y_onehot = torch.FloatTensor(batch_size, self.num_classes).to(self.device) 109 | y_onehot.zero_() 110 | y_onehot = Variable(y_onehot) 111 | y_onehot.scatter_(1, torch.unsqueeze(label, dim=-1), self.m) 112 | margin_logits = self.s * (logits - y_onehot) 113 | 114 | return logits, margin_logits 115 | 116 | 117 | # def loss_function(output, target): 118 | # return F.nll_loss(F.log_softmax(output, dim=1), target) 119 | 120 | 121 | def train(model, device, train_loader, loss_softmax, loss_lmcl, optimizer_nn, optimzer_lmcl, epoch): 122 | model.train() 123 | for batch_idx, (data, target) in enumerate(train_loader): 124 | data, target = data.to(device), target.to(device) 125 | 126 | # optimizer.zero_grad() 127 | # output,_,_ = model(data) 128 | 129 | # loss = loss_function(output, target) 130 | 131 | # loss.backward() 132 | # optimizer.step() 133 | 134 | features, _ = model(data) 135 | logits, mlogits = loss_lmcl(features, target) 136 | loss = loss_softmax(mlogits, target) 137 | 138 | _, predicted = torch.max(logits.data, 1) 139 | accuracy = (target.data == predicted).float().mean() 140 | 141 | optimizer_nn.zero_grad() 142 | optimzer_lmcl.zero_grad() 143 | 144 | loss.backward() 145 | 146 | optimizer_nn.step() 147 | optimzer_lmcl.step() 148 | 149 | if batch_idx % LOG_INTERVAL == 0: 150 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 151 | epoch, batch_idx * len(data), len(train_loader.dataset), 152 | 100. 
* batch_idx / len(train_loader), loss.item())) 153 | 154 | 155 | def test(model, device, test_loader, loss_softmax, loss_lmcl): 156 | model.eval() 157 | # test_loss = 0 158 | correct = 0 159 | total = 0 160 | with torch.no_grad(): 161 | for data, target in test_loader: 162 | data, target = data.to(device), target.to(device) 163 | 164 | feats, _ = model(data) 165 | logits, mlogits = loss_lmcl(feats, target) 166 | _, predicted = torch.max(logits.data, 1) 167 | total += target.size(0) 168 | correct += (predicted == target.data).sum() 169 | 170 | # print('Test Accuracy of the model on the 10000 test images: %f %%' % (100 * correct / total)) 171 | 172 | 173 | print('\nTest set:, Accuracy: {}/{} ({:.0f}%)\n'.format( 174 | correct, len(test_loader.dataset), 175 | 100. * correct / len(test_loader.dataset))) 176 | 177 | # output,_,_ = model(data) 178 | 179 | # test_loss += loss_function(output, target).item() # sum up batch loss 180 | 181 | # pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability 182 | # correct += pred.eq(target.view_as(pred)).sum().item() 183 | 184 | # test_loss /= len(test_loader.dataset) 185 | 186 | # print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( 187 | # test_loss, correct, len(test_loader.dataset), 188 | # 100. * correct / len(test_loader.dataset))) 189 | 190 | ################################################################### 191 | 192 | torch.manual_seed(1) 193 | device = torch.device("cuda" if use_cuda else "cpu") 194 | 195 | ####### Data setup 196 | 197 | kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {} 198 | train_loader = torch.utils.data.DataLoader( 199 | datasets.MNIST('./data', train=True, download=True, 200 | transform=transforms.Compose([ 201 | transforms.ToTensor(), 202 | transforms.Normalize((0.1307,), (0.3081,)) 203 | ])), 204 | batch_size=BATCH_SIZE, shuffle=True, **kwargs) 205 | test_loader = torch.utils.data.DataLoader( 206 | datasets.MNIST('./data', train=False, transform=transforms.Compose([ 207 | transforms.ToTensor(), 208 | transforms.Normalize((0.1307,), (0.3081,)) 209 | ])), 210 | batch_size=BATCH_SIZE_TEST, shuffle=True, **kwargs) 211 | 212 | ####### Model setup 213 | 214 | model = Net().to(device) 215 | loss_softmax = nn.CrossEntropyLoss().to(device) 216 | loss_lmcl = LMCL_loss(num_classes=10, feat_dim=FEATURES_DIM, device=device).to(device) 217 | 218 | # optimzer nn 219 | optimizer_nn = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=0.0005) 220 | sheduler_nn = lr_scheduler.StepLR(optimizer_nn, 10, gamma=0.5) 221 | 222 | # optimzer cosface or lmcl 223 | optimzer_lmcl = optim.SGD(loss_lmcl.parameters(), lr=0.01) 224 | sheduler_lmcl = lr_scheduler.StepLR(optimzer_lmcl, 10, gamma=0.5) 225 | 226 | 227 | for epoch in range(1, EPOCHS + 1): 228 | sheduler_nn.step() 229 | sheduler_lmcl.step() 230 | 231 | train(model, device, train_loader, loss_softmax, loss_lmcl, optimizer_nn, optimzer_lmcl, epoch) 232 | test(model, device, test_loader, loss_softmax, loss_lmcl) 233 | 234 | torch.save(model.state_dict(),"mnist_cnn-cosface.pt") 235 | torch.save(loss_lmcl.state_dict(),"mnist_loss-cosface.pt") 236 | -------------------------------------------------------------------------------- /mnist_cosface4.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.optim as optim 5 | from torchvision import datasets, transforms 6 | import numpy as np 7 | 8 | 
print("Pytorch version: " + str(torch.__version__)) 9 | use_cuda = torch.cuda.is_available() 10 | print("Use CUDA: " + str(use_cuda)) 11 | 12 | # Cosface 13 | from torch.autograd import Variable 14 | from torch.utils.data import DataLoader 15 | import torch.optim.lr_scheduler as lr_scheduler 16 | from torch.autograd.function import Function 17 | 18 | 19 | from pdb import set_trace as bp 20 | 21 | BATCH_SIZE = 100 22 | FEATURES_DIM = 3 23 | 24 | BATCH_SIZE_TEST = 1000 25 | EPOCHS = 20 26 | LOG_INTERVAL = 10 27 | 28 | class Net(nn.Module): 29 | def __init__(self): 30 | super(Net, self).__init__() 31 | krnl_sz=3 32 | strd = 1 33 | 34 | self.conv1 = nn.Conv2d(in_channels=1, out_channels=20, kernel_size=krnl_sz, stride=strd, padding=1) 35 | self.conv2 = nn.Conv2d(in_channels=20, out_channels=50, kernel_size=krnl_sz, stride=strd, padding=1) 36 | self.prelu1_1 = nn.PReLU() 37 | self.prelu1_2 = nn.PReLU() 38 | 39 | self.conv3 = nn.Conv2d(in_channels=50, out_channels=64, kernel_size=krnl_sz, stride=strd, padding=1) 40 | self.conv4 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=krnl_sz, stride=strd, padding=1) 41 | self.prelu2_1 = nn.PReLU() 42 | self.prelu2_2 = nn.PReLU() 43 | 44 | self.conv5 = nn.Conv2d(in_channels=128, out_channels=512, kernel_size=krnl_sz, stride=strd, padding=1) 45 | self.conv6 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=krnl_sz, stride=strd, padding=1) 46 | self.prelu3_1 = nn.PReLU() 47 | self.prelu3_2 = nn.PReLU() 48 | 49 | self.prelu_weight = nn.Parameter(torch.Tensor(1).fill_(0.25)) 50 | 51 | self.fc1 = nn.Linear(3*3*512, 3) 52 | # self.fc2 = nn.Linear(3, 2) 53 | self.fc3 = nn.Linear(3, 10) 54 | 55 | def forward(self, x): 56 | mp_ks=2 57 | mp_strd=2 58 | 59 | x = self.prelu1_1(self.conv1(x)) 60 | x = self.prelu1_2(self.conv2(x)) 61 | x = F.max_pool2d(x, kernel_size=mp_ks, stride=mp_strd) 62 | 63 | x = self.prelu2_1(self.conv3(x)) 64 | x = self.prelu2_2(self.conv4(x)) 65 | x = F.max_pool2d(x, kernel_size=mp_ks, stride=mp_strd) 66 | 67 | x = self.prelu3_1(self.conv5(x)) 68 | x = self.prelu3_2(self.conv6(x)) 69 | x = F.max_pool2d(x, kernel_size=mp_ks, stride=mp_strd) 70 | 71 | x = x.view(-1, 3*3*512) # Flatten 72 | features3d = F.prelu(self.fc1(x), self.prelu_weight) 73 | x = self.fc3(features3d) 74 | 75 | return features3d, x 76 | 77 | class LMCL_loss(nn.Module): 78 | """ 79 | Refer to paper: 80 | Hao Wang, Yitong Wang, Zheng Zhou, Xing Ji, Dihong Gong, Jingchao Zhou,Zhifeng Li, and Wei Liu 81 | CosFace: Large Margin Cosine Loss for Deep Face Recognition. 
CVPR2018 82 | re-implement by yirong mao 83 | 2018 07/02 84 | """ 85 | 86 | def __init__(self, num_classes, feat_dim, device, s=7.00, m=0.2): 87 | super(LMCL_loss, self).__init__() 88 | self.feat_dim = feat_dim 89 | self.num_classes = num_classes 90 | self.s = s 91 | self.m = m 92 | self.centers = nn.Parameter(torch.randn(num_classes, feat_dim)) 93 | self.device = device 94 | 95 | def forward(self, feat, label): 96 | batch_size = feat.shape[0] 97 | norms = torch.norm(feat, p=2, dim=-1, keepdim=True) 98 | nfeat = torch.div(feat, norms) 99 | 100 | norms_c = torch.norm(self.centers, p=2, dim=-1, keepdim=True) 101 | ncenters = torch.div(self.centers, norms_c) 102 | logits = torch.matmul(nfeat, torch.transpose(ncenters, 0, 1)) 103 | 104 | y_onehot = torch.FloatTensor(batch_size, self.num_classes).to(self.device) 105 | y_onehot.zero_() 106 | y_onehot = Variable(y_onehot) 107 | y_onehot.scatter_(1, torch.unsqueeze(label, dim=-1), self.m) 108 | margin_logits = self.s * (logits - y_onehot) 109 | 110 | return logits, margin_logits 111 | 112 | 113 | # def loss_function(output, target): 114 | # return F.nll_loss(F.log_softmax(output, dim=1), target) 115 | 116 | 117 | def train(model, device, train_loader, loss_softmax, loss_lmcl, optimizer_nn, optimzer_lmcl, epoch): 118 | model.train() 119 | for batch_idx, (data, target) in enumerate(train_loader): 120 | data, target = data.to(device), target.to(device) 121 | 122 | # optimizer.zero_grad() 123 | # output,_,_ = model(data) 124 | 125 | # loss = loss_function(output, target) 126 | 127 | # loss.backward() 128 | # optimizer.step() 129 | 130 | features, _ = model(data) 131 | logits, mlogits = loss_lmcl(features, target) 132 | loss = loss_softmax(mlogits, target) 133 | 134 | _, predicted = torch.max(logits.data, 1) 135 | accuracy = (target.data == predicted).float().mean() 136 | 137 | optimizer_nn.zero_grad() 138 | optimzer_lmcl.zero_grad() 139 | 140 | loss.backward() 141 | 142 | optimizer_nn.step() 143 | optimzer_lmcl.step() 144 | 145 | if batch_idx % LOG_INTERVAL == 0: 146 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 147 | epoch, batch_idx * len(data), len(train_loader.dataset), 148 | 100. * batch_idx / len(train_loader), loss.item())) 149 | 150 | 151 | def test(model, device, test_loader, loss_softmax, loss_lmcl): 152 | model.eval() 153 | # test_loss = 0 154 | correct = 0 155 | total = 0 156 | with torch.no_grad(): 157 | for data, target in test_loader: 158 | data, target = data.to(device), target.to(device) 159 | 160 | feats, _ = model(data) 161 | logits, mlogits = loss_lmcl(feats, target) 162 | _, predicted = torch.max(logits.data, 1) 163 | total += target.size(0) 164 | correct += (predicted == target.data).sum() 165 | 166 | # print('Test Accuracy of the model on the 10000 test images: %f %%' % (100 * correct / total)) 167 | 168 | 169 | print('\nTest set:, Accuracy: {}/{} ({:.0f}%)\n'.format( 170 | correct, len(test_loader.dataset), 171 | 100. * correct / len(test_loader.dataset))) 172 | 173 | # output,_,_ = model(data) 174 | 175 | # test_loss += loss_function(output, target).item() # sum up batch loss 176 | 177 | # pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability 178 | # correct += pred.eq(target.view_as(pred)).sum().item() 179 | 180 | # test_loss /= len(test_loader.dataset) 181 | 182 | # print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( 183 | # test_loss, correct, len(test_loader.dataset), 184 | # 100. 
* correct / len(test_loader.dataset))) 185 | 186 | ################################################################### 187 | 188 | torch.manual_seed(1) 189 | device = torch.device("cuda" if use_cuda else "cpu") 190 | 191 | ####### Data setup 192 | 193 | kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {} 194 | train_loader = torch.utils.data.DataLoader( 195 | datasets.MNIST('./data', train=True, download=True, 196 | transform=transforms.Compose([ 197 | transforms.ToTensor(), 198 | transforms.Normalize((0.1307,), (0.3081,)) 199 | ])), 200 | batch_size=BATCH_SIZE, shuffle=True, **kwargs) 201 | test_loader = torch.utils.data.DataLoader( 202 | datasets.MNIST('./data', train=False, transform=transforms.Compose([ 203 | transforms.ToTensor(), 204 | transforms.Normalize((0.1307,), (0.3081,)) 205 | ])), 206 | batch_size=BATCH_SIZE_TEST, shuffle=True, **kwargs) 207 | 208 | ####### Model setup 209 | 210 | model = Net().to(device) 211 | loss_softmax = nn.CrossEntropyLoss().to(device) 212 | loss_lmcl = LMCL_loss(num_classes=10, feat_dim=FEATURES_DIM, device=device).to(device) 213 | 214 | # optimzer nn 215 | optimizer_nn = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=0.0005) 216 | sheduler_nn = lr_scheduler.StepLR(optimizer_nn, 20, gamma=0.5) 217 | 218 | # optimzer cosface or lmcl 219 | optimzer_lmcl = optim.SGD(loss_lmcl.parameters(), lr=0.01) 220 | sheduler_lmcl = lr_scheduler.StepLR(optimzer_lmcl, 20, gamma=0.5) 221 | 222 | 223 | for epoch in range(1, EPOCHS + 1): 224 | sheduler_nn.step() 225 | sheduler_lmcl.step() 226 | 227 | train(model, device, train_loader, loss_softmax, loss_lmcl, optimizer_nn, optimzer_lmcl, epoch) 228 | test(model, device, test_loader, loss_softmax, loss_lmcl) 229 | 230 | torch.save(model.state_dict(),"mnist_cnn-cosface.pt") 231 | torch.save(loss_lmcl.state_dict(),"mnist_loss-cosface.pt") 232 | -------------------------------------------------------------------------------- /mnist_cosface5_fc7.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.optim as optim 5 | from torchvision import datasets, transforms 6 | import numpy as np 7 | 8 | print("Pytorch version: " + str(torch.__version__)) 9 | use_cuda = torch.cuda.is_available() 10 | print("Use CUDA: " + str(use_cuda)) 11 | 12 | # Cosface 13 | from torch.autograd import Variable 14 | from torch.utils.data import DataLoader 15 | import torch.optim.lr_scheduler as lr_scheduler 16 | from torch.autograd.function import Function 17 | 18 | 19 | from pdb import set_trace as bp 20 | 21 | BATCH_SIZE = 100 22 | FEATURES_DIM = 3 23 | 24 | BATCH_SIZE_TEST = 1000 25 | EPOCHS = 20 26 | LOG_INTERVAL = 10 27 | 28 | class Net(nn.Module): 29 | def __init__(self): 30 | super(Net, self).__init__() 31 | krnl_sz=3 32 | strd = 1 33 | 34 | self.conv1 = nn.Conv2d(in_channels=1, out_channels=20, kernel_size=krnl_sz, stride=strd, padding=1) 35 | self.conv2 = nn.Conv2d(in_channels=20, out_channels=50, kernel_size=krnl_sz, stride=strd, padding=1) 36 | self.prelu1_1 = nn.PReLU() 37 | self.prelu1_2 = nn.PReLU() 38 | 39 | self.conv3 = nn.Conv2d(in_channels=50, out_channels=64, kernel_size=krnl_sz, stride=strd, padding=1) 40 | self.conv4 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=krnl_sz, stride=strd, padding=1) 41 | self.prelu2_1 = nn.PReLU() 42 | self.prelu2_2 = nn.PReLU() 43 | 44 | self.conv5 = nn.Conv2d(in_channels=128, out_channels=512, kernel_size=krnl_sz, stride=strd, 
padding=1) 45 | self.conv6 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=krnl_sz, stride=strd, padding=1) 46 | self.prelu3_1 = nn.PReLU() 47 | self.prelu3_2 = nn.PReLU() 48 | 49 | self.prelu_weight = nn.Parameter(torch.Tensor(1).fill_(0.25)) 50 | 51 | self.fc1 = nn.Linear(3*3*512, 3) 52 | # self.fc2 = nn.Linear(3, 2) 53 | self.fc3 = nn.Linear(3, 10) 54 | 55 | def forward(self, x): 56 | mp_ks=2 57 | mp_strd=2 58 | 59 | x = self.prelu1_1(self.conv1(x)) 60 | x = self.prelu1_2(self.conv2(x)) 61 | x = F.max_pool2d(x, kernel_size=mp_ks, stride=mp_strd) 62 | 63 | x = self.prelu2_1(self.conv3(x)) 64 | x = self.prelu2_2(self.conv4(x)) 65 | x = F.max_pool2d(x, kernel_size=mp_ks, stride=mp_strd) 66 | 67 | x = self.prelu3_1(self.conv5(x)) 68 | x = self.prelu3_2(self.conv6(x)) 69 | x = F.max_pool2d(x, kernel_size=mp_ks, stride=mp_strd) 70 | 71 | x = x.view(-1, 3*3*512) # Flatten 72 | features3d = F.prelu(self.fc1(x), self.prelu_weight) 73 | x = self.fc3(features3d) 74 | 75 | return features3d, x 76 | 77 | class LMCL_loss(nn.Module): 78 | 79 | def __init__(self, num_classes, feat_dim, device, s=7.00, m=0.2): 80 | super(LMCL_loss, self).__init__() 81 | self.feat_dim = feat_dim 82 | self.num_classes = num_classes 83 | self.s = s 84 | self.m = m 85 | self.weights = nn.Parameter(torch.randn(num_classes, feat_dim)) 86 | self.device = device 87 | self.s_m = s*m 88 | 89 | def forward(self, feat, label): 90 | batch_size = feat.shape[0] 91 | norms = torch.norm(feat, p=2, dim=-1, keepdim=True) 92 | feat_l2norm = torch.div(feat, norms) 93 | feat_l2norm = feat_l2norm * self.s 94 | 95 | norms_w = torch.norm(self.weights, p=2, dim=-1, keepdim=True) 96 | weights_l2norm = torch.div(self.weights, norms_w) 97 | 98 | fc7 = torch.matmul(feat_l2norm, torch.transpose(weights_l2norm, 0, 1)) 99 | 100 | y_onehot = torch.FloatTensor(batch_size, self.num_classes).to(self.device) 101 | y_onehot.zero_() 102 | y_onehot = Variable(y_onehot) 103 | y_onehot.scatter_(1, torch.unsqueeze(label, dim=-1), self.s_m) 104 | output = fc7 - y_onehot 105 | 106 | return output 107 | 108 | 109 | # def loss_function(output, target): 110 | # return F.nll_loss(F.log_softmax(output, dim=1), target) 111 | 112 | 113 | def train(model, device, train_loader, loss_softmax, loss_lmcl, optimizer_nn, optimzer_lmcl, epoch): 114 | model.train() 115 | for batch_idx, (data, target) in enumerate(train_loader): 116 | data, target = data.to(device), target.to(device) 117 | 118 | # optimizer.zero_grad() 119 | # output,_,_ = model(data) 120 | 121 | # loss = loss_function(output, target) 122 | 123 | # loss.backward() 124 | # optimizer.step() 125 | 126 | features, _ = model(data) 127 | logits = loss_lmcl(features, target) 128 | loss = loss_softmax(logits, target) 129 | 130 | _, predicted = torch.max(logits.data, 1) 131 | accuracy = (target.data == predicted).float().mean() 132 | 133 | optimizer_nn.zero_grad() 134 | optimzer_lmcl.zero_grad() 135 | 136 | loss.backward() 137 | 138 | optimizer_nn.step() 139 | optimzer_lmcl.step() 140 | 141 | if batch_idx % LOG_INTERVAL == 0: 142 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 143 | epoch, batch_idx * len(data), len(train_loader.dataset), 144 | 100. 
145 | 
146 | 
147 | def test(model, device, test_loader, loss_softmax, loss_lmcl):
148 |     model.eval()
149 |     # test_loss = 0
150 |     correct = 0
151 |     total = 0
152 |     with torch.no_grad():
153 |         for data, target in test_loader:
154 |             data, target = data.to(device), target.to(device)
155 | 
156 |             feats, _ = model(data)
157 |             logits = loss_lmcl(feats, target)
158 |             _, predicted = torch.max(logits.data, 1)
159 |             total += target.size(0)
160 |             correct += (predicted == target.data).sum().item()
161 | 
162 |     # print('Test Accuracy of the model on the 10000 test images: %f %%' % (100 * correct / total))
163 | 
164 | 
165 |     print('\nTest set: Accuracy: {}/{} ({:.0f}%)\n'.format(
166 |         correct, len(test_loader.dataset),
167 |         100. * correct / len(test_loader.dataset)))
168 | 
169 |             # output,_,_ = model(data)
170 | 
171 |             # test_loss += loss_function(output, target).item() # sum up batch loss
172 | 
173 |             # pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
174 |             # correct += pred.eq(target.view_as(pred)).sum().item()
175 | 
176 |     # test_loss /= len(test_loader.dataset)
177 | 
178 |     # print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
179 |     #     test_loss, correct, len(test_loader.dataset),
180 |     #     100. * correct / len(test_loader.dataset)))
181 | 
182 | ###################################################################
183 | 
184 | torch.manual_seed(1)
185 | device = torch.device("cuda" if use_cuda else "cpu")
186 | 
187 | ####### Data setup
188 | 
189 | kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
190 | train_loader = torch.utils.data.DataLoader(
191 |     datasets.MNIST('./data', train=True, download=True,
192 |                    transform=transforms.Compose([
193 |                        transforms.ToTensor(),
194 |                        transforms.Normalize((0.1307,), (0.3081,))
195 |                    ])),
196 |     batch_size=BATCH_SIZE, shuffle=True, **kwargs)
197 | test_loader = torch.utils.data.DataLoader(
198 |     datasets.MNIST('./data', train=False, transform=transforms.Compose([
199 |         transforms.ToTensor(),
200 |         transforms.Normalize((0.1307,), (0.3081,))
201 |     ])),
202 |     batch_size=BATCH_SIZE_TEST, shuffle=True, **kwargs)
203 | 
204 | ####### Model setup
205 | 
206 | model = Net().to(device)
207 | loss_softmax = nn.CrossEntropyLoss().to(device)
208 | loss_lmcl = LMCL_loss(num_classes=10, feat_dim=FEATURES_DIM, device=device).to(device)
209 | 
210 | # optimizer for the network
211 | optimizer_nn = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=0.0005)
212 | scheduler_nn = lr_scheduler.StepLR(optimizer_nn, 20, gamma=0.5)
213 | 
214 | # optimizer for the CosFace (LMCL) loss weights
215 | optimizer_lmcl = optim.SGD(loss_lmcl.parameters(), lr=0.01)
216 | scheduler_lmcl = lr_scheduler.StepLR(optimizer_lmcl, 20, gamma=0.5)
217 | 
218 | 
219 | for epoch in range(1, EPOCHS + 1):
220 |     scheduler_nn.step() # pre-1.1 ordering; in PyTorch >= 1.1 step schedulers after the epoch's training
221 |     scheduler_lmcl.step()
222 | 
223 |     train(model, device, train_loader, loss_softmax, loss_lmcl, optimizer_nn, optimizer_lmcl, epoch)
224 |     test(model, device, test_loader, loss_softmax, loss_lmcl)
225 | 
226 | torch.save(model.state_dict(),"mnist_cnn-cosface.pt")
227 | torch.save(loss_lmcl.state_dict(),"mnist_loss-cosface.pt")
228 | 
--------------------------------------------------------------------------------
/mnist_loss-cosface.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/egcode/pytorch-losses/03e7cd11900e3678a34fe9d1d32964993ab2a926/mnist_loss-cosface.pt
--------------------------------------------------------------------------------
/mnist_softmax.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torch.optim as optim
6 | from torchvision import datasets, transforms
7 | import numpy as np
8 | 
9 | print("Pytorch version: " + str(torch.__version__))
10 | 
11 | use_cuda = torch.cuda.is_available()
12 | 
13 | print("Use CUDA: " + str(use_cuda))
14 | 
15 | 
16 | BATCH_SIZE = 64
17 | BATCH_SIZE_TEST = 1000
18 | EPOCHS = 20
19 | LOG_INTERVAL = 10
20 | 
21 | 
22 | class Net(nn.Module):
23 |     def __init__(self):
24 |         super(Net, self).__init__()
25 |         krnl_sz=3
26 |         strd = 1
27 | 
28 |         self.conv1 = nn.Conv2d(in_channels=1, out_channels=20, kernel_size=krnl_sz, stride=strd, padding=1)
29 |         self.conv2 = nn.Conv2d(in_channels=20, out_channels=50, kernel_size=krnl_sz, stride=strd, padding=1)
30 | 
31 |         self.conv3 = nn.Conv2d(in_channels=50, out_channels=64, kernel_size=krnl_sz, stride=strd, padding=1)
32 |         self.conv4 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=krnl_sz, stride=strd, padding=1)
33 | 
34 |         self.conv5 = nn.Conv2d(in_channels=128, out_channels=512, kernel_size=krnl_sz, stride=strd, padding=1)
35 |         self.conv6 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=krnl_sz, stride=strd, padding=1) # NOTE: defined but never used in forward()
36 | 
37 |         self.prelu_weight = nn.Parameter(torch.Tensor(1).fill_(0.25))
38 | 
39 |         self.fc1 = nn.Linear(3*3*512, 3)
40 |         self.fc2 = nn.Linear(3, 2)
41 |         self.fc3 = nn.Linear(2, 10)
42 | 
43 |     def forward(self, x):
44 |         mp_ks=2
45 |         mp_strd=2
46 | 
47 |         x = F.relu(self.conv1(x))
48 |         x = F.relu(self.conv2(x))
49 |         x = F.max_pool2d(x, kernel_size=mp_ks, stride=mp_strd)
50 | 
51 |         x = F.relu(self.conv3(x))
52 |         x = F.relu(self.conv4(x))
53 |         x = F.max_pool2d(x, kernel_size=mp_ks, stride=mp_strd)
54 | 
55 |         x = F.relu(self.conv5(x))
56 |         x = F.max_pool2d(x, kernel_size=mp_ks, stride=mp_strd)
57 | 
58 |         x = x.view(-1, 3*3*512) # Flatten
59 | 
60 |         features3d = self.fc1(x)
61 |         features2d = self.fc2(features3d)
62 |         x = F.prelu(features2d, self.prelu_weight)
63 | 
64 |         x = self.fc3(x)
65 | 
66 |         return x, features3d, features2d
67 | 
68 | 
69 | 
70 | def loss_function(output, target):
71 |     return F.nll_loss(F.log_softmax(output, dim=1), target)
72 | 
73 | 
74 | def train(model, device, train_loader, optimizer, epoch):
75 |     model.train()
76 |     for batch_idx, (data, target) in enumerate(train_loader):
77 |         data, target = data.to(device), target.to(device)
78 |         optimizer.zero_grad()
79 |         output,_,_ = model(data)
80 | 
81 |         # loss = F.nll_loss(F.log_softmax(output, dim=1), target)
82 |         loss = loss_function(output, target)
83 | 
84 |         loss.backward()
85 |         optimizer.step()
86 |         if batch_idx % LOG_INTERVAL == 0:
87 |             print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
88 |                 epoch, batch_idx * len(data), len(train_loader.dataset),
89 |                 100. * batch_idx / len(train_loader), loss.item()))
90 | 
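# --- Editor's sketch (not part of the original mnist_softmax.py) -------------
# loss_function above composes F.log_softmax and F.nll_loss, which is exactly
# what nn.CrossEntropyLoss computes; a quick numerical check with illustrative
# random tensors:

_logits = torch.randn(5, 10)
_targets = torch.randint(0, 10, (5,))
assert torch.allclose(loss_function(_logits, _targets),
                      nn.CrossEntropyLoss()(_logits, _targets))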
91 | def test(model, device, test_loader):
92 |     model.eval()
93 |     test_loss = 0
94 |     correct = 0
95 |     with torch.no_grad():
96 |         for data, target in test_loader:
97 |             data, target = data.to(device), target.to(device)
98 |             output,_,_ = model(data)
99 | 
100 |             # test_loss += F.nll_loss(F.log_softmax(output, dim=1), target).item() # sum up batch loss
101 | 
102 |             test_loss += loss_function(output, target).item() * data.size(0) # loss_function returns the batch mean, so weight it by the batch size
103 | 
104 |             pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
105 |             correct += pred.eq(target.view_as(pred)).sum().item()
106 | 
107 |     test_loss /= len(test_loader.dataset)
108 | 
109 |     print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
110 |         test_loss, correct, len(test_loader.dataset),
111 |         100. * correct / len(test_loader.dataset)))
112 | 
113 | 
114 | 
115 | 
116 | torch.manual_seed(1)
117 | 
118 | device = torch.device("cuda" if use_cuda else "cpu")
119 | 
120 | kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
121 | train_loader = torch.utils.data.DataLoader(
122 |     datasets.MNIST('./data', train=True, download=True,
123 |                    transform=transforms.Compose([
124 |                        transforms.ToTensor(),
125 |                        transforms.Normalize((0.1307,), (0.3081,))
126 |                    ])),
127 |     batch_size=BATCH_SIZE, shuffle=True, **kwargs)
128 | test_loader = torch.utils.data.DataLoader(
129 |     datasets.MNIST('./data', train=False, transform=transforms.Compose([
130 |         transforms.ToTensor(),
131 |         transforms.Normalize((0.1307,), (0.3081,))
132 |     ])),
133 |     batch_size=BATCH_SIZE_TEST, shuffle=True, **kwargs)
134 | 
135 | 
136 | 
137 | 
138 | model = Net().to(device)
139 | optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
140 | 
141 | for epoch in range(1, EPOCHS + 1):
142 |     train(model, device, train_loader, optimizer, epoch)
143 |     test(model, device, test_loader)
144 | 
145 | torch.save(model.state_dict(),"mnist_cnn-softmax.pt")
146 | 
--------------------------------------------------------------------------------
/mnist_softmax_custom.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torch.optim as optim
6 | from torchvision import datasets, transforms
7 | import numpy as np
8 | print("Pytorch version: " + str(torch.__version__))
9 | use_cuda = torch.cuda.is_available()
10 | print("Use CUDA: " + str(use_cuda))
11 | from pdb import set_trace as bp
12 | torch.set_printoptions(threshold=1000000)
13 | 
14 | BATCH_SIZE = 64
15 | BATCH_SIZE_TEST = 1000
16 | EPOCHS = 20
17 | LOG_INTERVAL = 10
18 | NUM_OF_CLASSES = 10
19 | 
20 | 
21 | torch.manual_seed(1)
22 | 
23 | device = torch.device("cuda" if use_cuda else "cpu")
24 | 
25 | kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
26 | train_loader = torch.utils.data.DataLoader(
27 |     datasets.MNIST('./data', train=True, download=True,
28 |                    transform=transforms.Compose([
29 |                        transforms.ToTensor(),
30 |                        transforms.Normalize((0.1307,), (0.3081,))
31 |                    ])),
32 |     batch_size=BATCH_SIZE, shuffle=True, **kwargs)
33 | test_loader = torch.utils.data.DataLoader(
34 |     datasets.MNIST('./data', train=False, transform=transforms.Compose([
35 |         transforms.ToTensor(),
36 |         transforms.Normalize((0.1307,), (0.3081,))
37 |     ])),
38 |     batch_size=BATCH_SIZE_TEST, shuffle=True, **kwargs)
39 | 
40 | 
41 | class Net(nn.Module):
42 |     def __init__(self):
43 |         super(Net, self).__init__()
44 |         krnl_sz=3
45 |         strd = 1
46 | 
47 |         self.conv1 = nn.Conv2d(in_channels=1, out_channels=20, kernel_size=krnl_sz, stride=strd, padding=1)
48 |         self.conv2 = nn.Conv2d(in_channels=20, out_channels=50, kernel_size=krnl_sz, stride=strd, padding=1)
49 | 
50 |         self.conv3 = nn.Conv2d(in_channels=50, out_channels=64, kernel_size=krnl_sz, stride=strd, padding=1)
51 |         self.conv4 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=krnl_sz, stride=strd, padding=1)
52 | 
53 |         self.conv5 = nn.Conv2d(in_channels=128, out_channels=512, kernel_size=krnl_sz, stride=strd, padding=1)
54 |         self.conv6 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=krnl_sz, stride=strd, padding=1) # NOTE: defined but never used in forward()
55 | 
56 |         self.prelu_weight = nn.Parameter(torch.Tensor(1).fill_(0.25))
57 | 
58 |         self.fc1 = nn.Linear(3*3*512, 3)
59 |         self.fc2 = nn.Linear(3, 2)
60 |         self.fc3 = nn.Linear(2, 10)
61 | 
62 |     def forward(self, x):
63 |         mp_ks=2
64 |         mp_strd=2
65 | 
66 |         x = F.relu(self.conv1(x))
67 |         x = F.relu(self.conv2(x))
68 |         x = F.max_pool2d(x, kernel_size=mp_ks, stride=mp_strd)
69 | 
70 |         x = F.relu(self.conv3(x))
71 |         x = F.relu(self.conv4(x))
72 |         x = F.max_pool2d(x, kernel_size=mp_ks, stride=mp_strd)
73 | 
74 |         x = F.relu(self.conv5(x))
75 |         x = F.max_pool2d(x, kernel_size=mp_ks, stride=mp_strd)
76 | 
77 |         x = x.view(-1, 3*3*512) # Flatten
78 | 
79 |         features3d = self.fc1(x)
80 |         features2d = self.fc2(features3d)
81 |         x = F.prelu(features2d, self.prelu_weight)
82 | 
83 |         x = self.fc3(x)
84 | 
85 |         return x, features3d, features2d
86 | 
87 | 
88 | class CrossEntropyCustom(nn.Module):
89 | 
90 |     def __init__(self):
91 |         super(CrossEntropyCustom, self).__init__()
92 |         self.sm = nn.Softmax(dim=1) # unused
93 |         self.lsm = nn.LogSoftmax(dim=1) # only used by the commented-out fallback below
94 | 
95 |         self.nll = nn.NLLLoss() # unused
96 | 
97 |     def forward(self, input, target):
98 | 
99 |         ######################################################
100 |         #################### Log Softmax #########################
101 | 
102 |         # Numerically stable log-softmax (log-sum-exp trick)
103 |         b = torch.max(input) # subtracting the max before exp() prevents overflow
104 |         presum = torch.exp(input - b)
105 |         prelog = presum.sum(-1).unsqueeze(-1)
106 |         prelog = prelog.clamp(min=1e-33, max=1e+33) # for numerical stability
107 |         log = torch.log(prelog)
108 |         log_probabilities = (input - b) - log
109 | 
110 |         # --- Pytorch WORKING Logsoftmax
111 |         # log_probabilities = self.lsm(input).cpu()
112 | 
113 | 
114 |         ######################################################
115 |         #################### NLLLoss #########################
116 | 
117 |         target = target.cpu() ## To remove error on gpu
118 |         log_probabilities = log_probabilities.cpu() ## To remove error on gpu
119 | 
120 |         m = target.shape[0] # batch size
121 | 
122 |         ## NLLLoss V1: pick out the log-probability of each sample's target class
123 |         cross_entropy = torch.zeros(log_probabilities.size())
124 |         for i in range(m):
125 |             value = log_probabilities[i,target[i].long()]
126 |             cross_entropy[i,target[i].long()] = value
127 | 
128 |         ## NLLLoss V2
129 |         # target_one_hot = torch.zeros(len(target), NUM_OF_CLASSES).scatter_(1, target.unsqueeze(1), 1.)
130 |         # cross_entropy = torch.addcmul(torch.zeros(log_probabilities.size()), 1., log_probabilities, target_one_hot)
131 | 
132 |         loss = -(1./m) * torch.sum(cross_entropy) # mean negative log-likelihood over the batch
133 |         return loss
134 | 
135 | 
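# --- Editor's sketch (not part of the original mnist_softmax_custom.py) ------
# CrossEntropyCustom is meant to reproduce nn.CrossEntropyLoss (mean
# reduction); a quick numerical check on illustrative random tensors:

_logits = torch.randn(8, NUM_OF_CLASSES)
_targets = torch.randint(0, NUM_OF_CLASSES, (8,))
assert torch.allclose(CrossEntropyCustom()(_logits, _targets),
                      nn.CrossEntropyLoss()(_logits, _targets), atol=1e-6)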
136 | def train(model, loss_custom, device, train_loader, optimizer, epoch):
137 |     model.train()
138 |     for batch_idx, (data, target) in enumerate(train_loader):
139 |         data, target = data.to(device), target.to(device)
140 | 
141 |         output,_,_ = model(data)
142 | 
143 |         loss = loss_custom(output, target)
144 | 
145 |         optimizer.zero_grad() # clear previous gradients
146 |         loss.backward() # compute gradients of all variables wrt loss
147 | 
148 | 
149 |         optimizer.step() # perform updates using calculated gradients
150 | 
151 |         if batch_idx % LOG_INTERVAL == 0:
152 |             print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
153 |                 epoch, batch_idx * len(data), len(train_loader.dataset),
154 |                 100. * batch_idx / len(train_loader), loss.item()))
155 | 
156 | def test(model, loss_custom, device, test_loader):
157 |     model.eval()
158 |     test_loss = 0
159 |     correct = 0
160 |     with torch.no_grad():
161 |         for data, target in test_loader:
162 |             data, target = data.to(device), target.to(device)
163 |             output,_,_ = model(data)
164 | 
165 |             # test_loss += loss_function(output, target).item() # sum up batch loss
166 |             test_loss += loss_custom(output, target).item() * data.size(0) # loss_custom returns the batch mean, so weight it by the batch size
167 | 
168 |             pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
169 |             correct += pred.eq(target.view_as(pred)).sum().item()
170 | 
171 |     test_loss /= len(test_loader.dataset)
172 | 
173 |     print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
174 |         test_loss, correct, len(test_loader.dataset),
175 |         100. * correct / len(test_loader.dataset)))
176 | 
177 | 
178 | 
179 | model = Net()
180 | # model.load_state_dict(torch.load("mnist_cnn-softmax2.pt"))
181 | model = model.to(device)
182 | 
183 | optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
184 | 
185 | loss_custom = CrossEntropyCustom().to(device)
186 | 
187 | for epoch in range(1, EPOCHS + 1):
188 |     train(model, loss_custom, device, train_loader, optimizer, epoch)
189 |     test(model, loss_custom, device, test_loader)
190 |     torch.save(model.state_dict(),"mnist_cnn-softmax2.pt")
191 | 
--------------------------------------------------------------------------------
/playground.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torch.optim as optim
6 | from torchvision import datasets, transforms
7 | from pdb import set_trace as bp
8 | 
9 | # a = torch.Tensor([float('NaN'), 1, float('NaN'), 2, 3])
10 | # print(a)
11 | # a[a != a] = 0
12 | # print(a)
13 | 
14 | 
15 | 
16 | 
17 | 
18 | x = torch.Tensor([[1., 2., 2., 3., 4., 5., 6., 7., 8., 9.], [-8.2592, -7.9788, -5.2605, -4.8818, -3.7099, -2.5116, -1.2812, -0.7652, -0.1487, -0.8805]])
19 | indices = torch.Tensor([6, 2])
20 | 
21 | print(x)
22 | print("x shape: " + str(x.shape))
23 | print(indices)
24 | print("indices shape: " + str(indices.shape))
25 | 
26 | result = torch.zeros(x.size())
27 | for i in range(x.shape[0]):
28 |     # print("\n")
29 |     # print("iter: " + str(i))
30 |     value = x[i,indices[i].long()]
31 |     # print(value)
32 |     # value = value.clamp(min=1e-12, max=1e+12) # for numerical stability
33 |     result[i,indices[i].long()] = value
34 |     # print("\n")
35 | 
36 | # print("result: ")
37 | # print(result)
38 | 
--------------------------------------------------------------------------------
/plot_to_gif.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | from matplotlib import cm
3 | from mpl_toolkits.mplot3d import axes3d
4 | import os, sys
5 | import numpy as np
6 | 
7 | from mpl_toolkits import mplot3d
8 | 
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | import torch.optim as optim
13 | from torchvision import datasets, transforms
14 | 
15 | print("Pytorch version: " + str(torch.__version__))
16 | use_cuda = torch.cuda.is_available()
17 | print("Use CUDA: " + str(use_cuda))
18 | 
19 | # Cosface
20 | from torch.autograd import Variable
21 | from torch.utils.data import DataLoader
22 | import torch.optim.lr_scheduler as lr_scheduler
23 | from torch.autograd.function import Function
24 | 
25 | 
26 | from pdb import set_trace as bp
27 | 
28 | BATCH_SIZE = 100
29 | FEATURES_DIM = 3
30 | 
31 | BATCH_SIZE_TEST = 1000
32 | EPOCHS = 20
33 | LOG_INTERVAL = 10
34 | 
35 | ##### TO CREATE A SERIES OF PICTURES
36 | 
37 | def make_views(ax,angles,elevation=None, width=16, height = 9,
38 |                prefix='tmprot_',**kwargs):
39 |     """
40 |     Makes jpeg pictures of the given 3d ax, with different angles.
41 |     Args:
42 |         ax (3D axis): the ax
43 |         angles (list): the list of angles (in degree) under which to
44 |                        take the picture.
45 |         width,height (float): size, in inches, of the output images.
46 |         prefix (str): prefix for the files created.
47 | 
48 |     Returns: the list of files created (for later removal)
49 |     """
50 | 
51 |     files = []
52 |     ax.figure.set_size_inches(width,height)
53 | 
54 |     for i,angle in enumerate(angles):
55 | 
56 |         ax.view_init(elev = elevation, azim=angle)
57 |         fname = '%s%03d.jpeg'%(prefix,i)
58 |         ax.figure.savefig(fname)
59 |         files.append(fname)
60 | 
61 |     return files
62 | 
63 | 
64 | 
65 | ##### TO TRANSFORM THE SERIES OF PICTURE INTO AN ANIMATION
66 | 
67 | def make_movie(files,output, fps=10,bitrate=1800,**kwargs):
68 |     """
69 |     Uses mencoder, produces a .mp4/.ogv/... movie from a list of
70 |     picture files.
71 |     """
72 | 
73 |     output_name, output_ext = os.path.splitext(output)
74 |     command = { '.mp4' : 'mencoder "mf://%s" -mf fps=%d -o %s.mp4 -ovc lavc\
75 |                          -lavcopts vcodec=msmpeg4v2:vbitrate=%d'
76 |                          %(",".join(files),fps,output_name,bitrate)}
77 | 
78 |     command['.ogv'] = command['.mp4'] + '; ffmpeg -i %s.mp4 -r %d %s'%(output_name,fps,output)
79 | 
80 |     print(command[output_ext])
81 | 
82 |     os.system(command[output_ext])
83 | 
84 | 
85 | 
86 | def make_gif(files,output,delay=100, repeat=True,**kwargs):
87 |     """
88 |     Uses ImageMagick to produce an animated .gif from a list of
89 |     picture files.
90 |     """
91 | 
92 |     loop = 0 if repeat else 1 # ImageMagick: -loop 0 repeats forever, -loop 1 plays once
93 |     os.system('convert -delay %d -loop %d %s %s'
94 |               %(delay,loop," ".join(files),output))
95 | 
96 | 
97 | 
98 | 
99 | def make_strip(files,output,**kwargs):
100 |     """
101 |     Uses ImageMagick to produce a .jpeg strip from a list of
102 |     picture files.
103 |     """
104 | 
105 |     os.system('montage -tile 1x -geometry +0+0 %s %s'%(" ".join(files),output))
106 | 
107 | 
108 | 
109 | ##### MAIN FUNCTION
110 | 
111 | def rotanimate(ax, angles, output, **kwargs):
112 |     """
113 |     Produces an animation (.mp4,.ogv,.gif,.jpeg,.png) from a 3D plot on
114 |     a 3D ax
115 | 
116 |     Args:
117 |         ax (3D axis): the ax containing the plot of interest
118 |         angles (list): the list of angles (in degree) under which to
119 |                        show the plot.
120 |         output : name of the output file. The extension determines the
121 |                  kind of animation used.
122 |         **kwargs:
123 |             - width : in inches
124 |             - height : in inches
125 |             - fps : frames per second (movie outputs)
126 |             - delay : delay between frames in 1/100 s (.gif only)
127 |             - repeat : True or False (.gif only)
128 |     """
129 | 
130 |     output_ext = os.path.splitext(output)[1]
131 | 
132 |     files = make_views(ax,angles, **kwargs)
133 | 
134 |     D = { '.mp4' : make_movie,
135 |           '.ogv' : make_movie,
136 |           '.gif': make_gif ,
137 |           '.jpeg': make_strip,
138 |           '.png':make_strip}
139 | 
140 |     D[output_ext](files,output,**kwargs)
141 | 
142 |     for f in files:
143 |         os.remove(f)
144 | 
145 | 
146 | class Net(nn.Module):
147 |     def __init__(self):
148 |         super(Net, self).__init__()
149 |         krnl_sz=3
150 |         strd = 1
151 | 
152 |         self.conv1 = nn.Conv2d(in_channels=1, out_channels=20, kernel_size=krnl_sz, stride=strd, padding=1)
153 |         self.conv2 = nn.Conv2d(in_channels=20, out_channels=50, kernel_size=krnl_sz, stride=strd, padding=1)
154 |         self.prelu1_1 = nn.PReLU()
155 |         self.prelu1_2 = nn.PReLU()
156 | 
157 |         self.conv3 = nn.Conv2d(in_channels=50, out_channels=64, kernel_size=krnl_sz, stride=strd, padding=1)
158 |         self.conv4 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=krnl_sz, stride=strd, padding=1)
159 |         self.prelu2_1 = nn.PReLU()
160 |         self.prelu2_2 = nn.PReLU()
161 | 
162 |         self.conv5 = nn.Conv2d(in_channels=128, out_channels=512, kernel_size=krnl_sz, stride=strd, padding=1)
163 |         self.conv6 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=krnl_sz, stride=strd, padding=1)
164 |         self.prelu3_1 = nn.PReLU()
165 |         self.prelu3_2 = nn.PReLU()
166 | 
167 |         self.prelu_weight = nn.Parameter(torch.Tensor(1).fill_(0.25))
168 | 
169 |         self.fc1 = nn.Linear(3*3*512, 3)
170 |         # self.fc2 = nn.Linear(3, 2)
171 |         self.fc3 = nn.Linear(3, 10)
172 | 
173 |     def forward(self, x):
174 |         mp_ks=2
175 |         mp_strd=2
176 | 
177 |         x = self.prelu1_1(self.conv1(x))
178 |         x = self.prelu1_2(self.conv2(x))
179 |         x = F.max_pool2d(x, kernel_size=mp_ks, stride=mp_strd)
180 | 
181 |         x = self.prelu2_1(self.conv3(x))
182 |         x = self.prelu2_2(self.conv4(x))
183 |         x = F.max_pool2d(x, kernel_size=mp_ks, stride=mp_strd)
184 | 
185 |         x = self.prelu3_1(self.conv5(x))
186 |         x = self.prelu3_2(self.conv6(x))
187 |         x = F.max_pool2d(x, kernel_size=mp_ks, stride=mp_strd)
188 | 
189 |         x = x.view(-1, 3*3*512) # Flatten
190 |         features3d = F.prelu(self.fc1(x), self.prelu_weight)
191 |         x = self.fc3(features3d)
192 | 
193 |         return features3d, x
194 | 
195 | class LMCL_loss(nn.Module):
196 | 
197 |     def __init__(self, num_classes, feat_dim, device, s=7.00, m=0.2):
198 |         super(LMCL_loss, self).__init__()
199 |         self.feat_dim = feat_dim
200 |         self.num_classes = num_classes
201 |         self.s = s
202 |         self.m = m
203 |         self.weights = nn.Parameter(torch.randn(num_classes, feat_dim))
204 |         self.device = device
205 |         self.s_m = s*m
206 | 
207 |     def forward(self, feat, label):
208 |         batch_size = feat.shape[0]
209 |         norms = torch.norm(feat, p=2, dim=-1, keepdim=True)
210 |         feat_l2norm = torch.div(feat, norms)
211 |         feat_l2norm = feat_l2norm * self.s
212 | 
213 |         norms_w = torch.norm(self.weights, p=2, dim=-1, keepdim=True)
214 |         weights_l2norm = torch.div(self.weights, norms_w)
215 | 
216 |         fc7 = torch.matmul(feat_l2norm, torch.transpose(weights_l2norm, 0, 1))
217 | 
218 |         y_onehot = torch.FloatTensor(batch_size, self.num_classes).to(self.device)
219 |         y_onehot.zero_()
220 |         y_onehot = Variable(y_onehot)
221 |         y_onehot.scatter_(1, torch.unsqueeze(label, dim=-1), self.s_m)
222 |         output = fc7 - y_onehot
223 | 
224 |         return output
225 | 
226 | 
227 | 
228 | ##### EXAMPLE
229 | 
230 | if __name__ == '__main__':
231 | 
232 |     # fig = plt.figure()
233 |     # ax = fig.add_subplot(111, projection='3d')
234 |     # X, Y, Z = axes3d.get_test_data(0.05)
235 |     # s = ax.plot_surface(X, Y, Z, cmap=cm.jet)
236 |     # plt.axis('off') # remove axes for visual appeal
237 | 
238 |     # angles = np.linspace(0,360,21)[:-1] # Take 20 angles between 0 and 360
239 | 
240 |     # # create an animated gif (delay=20, i.e. 200 ms between frames in ImageMagick 1/100 s units)
241 |     # rotanimate(ax, angles,'movie.gif',delay=20)
242 | 
243 |     torch.manual_seed(1)
244 | 
245 |     ####### Data setup
246 | 
247 |     kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
248 |     train_loader = torch.utils.data.DataLoader(
249 |         datasets.MNIST('./data', train=True, download=True,
250 |                        transform=transforms.Compose([
251 |                            transforms.ToTensor(),
252 |                            transforms.Normalize((0.1307,), (0.3081,))
253 |                        ])),
254 |         batch_size=BATCH_SIZE, shuffle=True, **kwargs)
255 |     test_loader = torch.utils.data.DataLoader(
256 |         datasets.MNIST('./data', train=False, transform=transforms.Compose([
257 |             transforms.ToTensor(),
258 |             transforms.Normalize((0.1307,), (0.3081,))
259 |         ])),
260 |         batch_size=BATCH_SIZE_TEST, shuffle=True, **kwargs)
261 | 
262 | 
263 | 
264 | 
265 |     device = torch.device("cuda" if use_cuda else "cpu")
266 |     model = Net()
267 |     model.eval()
268 |     model.load_state_dict(torch.load("mnist_cnn-cosface.pt", map_location='cpu'))
269 |     model.to(device)
270 | 
271 |     ind = 142
272 | 
273 |     image = test_loader.dataset[ind][0].numpy().reshape(28,28)
274 |     lbl = test_loader.dataset[ind][1].numpy()
275 | 
276 | 
277 |     image_tensor, label_tensor = test_loader.dataset[ind]
278 |     image_tensor = image_tensor.reshape(1,1,28,28)
279 |     image_tensor, label_tensor = image_tensor.to(device), label_tensor.to(device)
280 | 
281 |     lmcl_loss = LMCL_loss(num_classes=10, feat_dim=FEATURES_DIM, device=device)
282 |     lmcl_loss.eval()
283 |     lmcl_loss.load_state_dict(torch.load("mnist_loss-cosface.pt", map_location='cpu'))
284 |     lmcl_loss.to(device)
285 | 
286 |     features3d, pr = model(image_tensor)
287 |     logits = lmcl_loss(features3d, torch.unsqueeze(label_tensor, dim=-1))
288 |     _, prediction = torch.max(logits.data, 1)
289 |     prediction = prediction.cpu().detach().numpy()[0]
290 | 
291 |     # print ("PREDICTION : " + str(prediction) )
292 | 
293 |     f3d = []
294 |     # f2d = []
295 |     lbls = []
296 |     for i in range(10000):
297 |         image_tensor, label_tensor = test_loader.dataset[i]
298 |         image_tensor = image_tensor.reshape(1,1,28,28)
299 |         image_tensor, label_tensor = image_tensor.to(device), label_tensor.to(device)
300 | 
301 |         features3d, pr = model(image_tensor)
302 |         logits = lmcl_loss(features3d, torch.unsqueeze(label_tensor, dim=-1))
303 |         _, prediction = torch.max(logits.data, 1)
304 | 
305 |         f3d.append(features3d[0].cpu().detach().numpy())
306 |         # f2d.append(features2d[0].cpu().detach().numpy())
307 | 
308 |         prediction = prediction.cpu().detach().numpy()[0]
309 |         lbls.append(prediction)
310 | 
311 |         # print("features3d: " + str(features3d[0].detach().numpy()))
312 |         # print("features2d: " + str(features2d[0].detach().numpy()))
313 | 
314 |     # feat3d = np.array(f3d)
315 |     # print("3d features shape: " + str(feat3d.shape))
316 | 
317 |     feat3d = np.array(f3d)
318 |     print("3d features shape: " + str(feat3d.shape))
319 | 
320 |     lbls = np.array(lbls)
321 |     print("labels shape: " + str(lbls.shape))
322 | 
323 | 
324 | 
325 |     fig = plt.figure(figsize=(16,9))
326 |     ax = plt.axes(projection='3d')
327 | 
328 |     for i in range(10):
329 |         # Data for three-dimensional scattered points
330 |         xdata = feat3d[lbls==i,2].flatten()
331 |         ydata = feat3d[lbls==i,0].flatten()
332 |         zdata = feat3d[lbls==i,1].flatten()
333 |         ax.scatter3D(xdata, ydata, zdata)
334 |     ax.legend(['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'],loc='center left', bbox_to_anchor=(1, 0.5))
335 | 
336 |     # plt.show()
337 | 
338 |     # angles = np.linspace(0,360,21)[:-1] # Take 20 angles between 0 and 360
339 |     angles = np.linspace(0,360,181)[:-1] # Take 180 angles between 0 and 360
340 | 
341 |     # create an animated gif (delay=10, i.e. 100 ms between frames in ImageMagick 1/100 s units)
342 |     rotanimate(ax, angles,'movie.gif',delay=10)
343 | 
344 | 
--------------------------------------------------------------------------------
/test_arcface_mnist.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch.optim as optim
5 | from torchvision import datasets, transforms
6 | import numpy as np
7 | from datetime import datetime, timedelta
8 | import time
9 | 
10 | print("Pytorch version: " + str(torch.__version__))
11 | use_cuda = torch.cuda.is_available()
12 | print("Use CUDA: " + str(use_cuda))
13 | 
14 | from torch.autograd import Variable
15 | from torch.utils.data import DataLoader
16 | import torch.optim.lr_scheduler as lr_scheduler
17 | from torch.autograd.function import Function
18 | import math
19 | 
20 | from pdb import set_trace as bp
21 | 
22 | BATCH_SIZE = 100
23 | FEATURES_DIM = 3
24 | NUM_OF_CLASSES = 10
25 | BATCH_SIZE_TEST = 1000
26 | EPOCHS = 20
27 | LOG_INTERVAL = 10
28 | 
29 | class Net(nn.Module):
30 |     def __init__(self):
31 |         super(Net, self).__init__()
32 |         krnl_sz=3
33 |         strd = 1
34 | 
35 |         self.conv1 = nn.Conv2d(in_channels=1, out_channels=20, kernel_size=krnl_sz, stride=strd, padding=1)
36 |         self.conv2 = nn.Conv2d(in_channels=20, out_channels=50, kernel_size=krnl_sz, stride=strd, padding=1)
37 |         self.prelu1_1 = nn.PReLU()
38 |         self.prelu1_2 = nn.PReLU()
39 | 
40 |         self.conv3 = nn.Conv2d(in_channels=50, out_channels=64, kernel_size=krnl_sz, stride=strd, padding=1)
41 |         self.conv4 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=krnl_sz, stride=strd, padding=1)
42 |         self.prelu2_1 = nn.PReLU()
43 |         self.prelu2_2 = nn.PReLU()
44 | 
45 |         self.conv5 = nn.Conv2d(in_channels=128, out_channels=512, kernel_size=krnl_sz, stride=strd, padding=1)
46 |         self.conv6 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=krnl_sz, stride=strd, padding=1)
47 |         self.prelu3_1 = nn.PReLU()
48 |         self.prelu3_2 = nn.PReLU()
49 | 
50 |         self.prelu_weight = nn.Parameter(torch.Tensor(1).fill_(0.25))
51 | 
52 |         self.fc1 = nn.Linear(3*3*512, 3)
53 |         self.fc3 = nn.Linear(3, 10)
54 | 
55 |     def forward(self, x):
56 |         mp_ks=2
57 |         mp_strd=2
58 | 
59 |         x = self.prelu1_1(self.conv1(x))
60 |         x = self.prelu1_2(self.conv2(x))
61 |         x = F.max_pool2d(x, kernel_size=mp_ks, stride=mp_strd)
62 | 
63 |         x = self.prelu2_1(self.conv3(x))
64 |         x = self.prelu2_2(self.conv4(x))
65 |         x = F.max_pool2d(x, kernel_size=mp_ks, stride=mp_strd)
66 | 
67 |         x = self.prelu3_1(self.conv5(x))
68 |         x = self.prelu3_2(self.conv6(x))
69 |         x = F.max_pool2d(x, kernel_size=mp_ks, stride=mp_strd)
70 | 
71 |         x = x.view(-1, 3*3*512) # Flatten
72 |         features3d = F.prelu(self.fc1(x), self.prelu_weight)
73 |         x = self.fc3(features3d)
74 | 
75 |         return features3d, x
76 | 
77 | class Arcface_loss(nn.Module):
78 |     def __init__(self, num_classes, feat_dim, device, s=7.0, m=0.2):
79 |         super(Arcface_loss, self).__init__()
80 |         self.feat_dim = feat_dim
81 |         self.num_classes = num_classes
82 |         self.s = s
83 |         self.m = m
84 |         self.weights = nn.Parameter(torch.randn(num_classes, feat_dim))
85 |         self.device = device
86 | 
87 |         self.cos_m = math.cos(m) # precomputed for the cos(theta + m) expansion
88 |         self.sin_m = math.sin(m)
89 |         self.mm = math.sin(math.pi-m)*m # boundary-case constant; not used below
90 |         self.threshold = math.cos(math.pi-m) # boundary-case constant; not used below
91 | 
92 |     def forward(self, feat, label):
93 |         eps = 1e-4
94 |         batch_size = feat.shape[0]
95 |         norms = torch.norm(feat, p=2, dim=-1, keepdim=True)
96 |         feat_l2norm = torch.div(feat, norms)
97 |         feat_l2norm = feat_l2norm * self.s # scaled, unit-length features
98 | 
99 |         norms_w = torch.norm(self.weights, p=2, dim=-1, keepdim=True)
100 |         weights_l2norm = torch.div(self.weights, norms_w) # unit-length class weights
101 | 
102 |         fc7 = torch.matmul(feat_l2norm, torch.transpose(weights_l2norm, 0, 1)) # fc7[i, j] = s * cos(theta_ij)
103 | 
104 |         if torch.cuda.is_available(): # redundant: both tensors are already on the right device
105 |             label = label.cuda()
106 |             fc7 = fc7.cuda()
107 |         else:
108 |             label = label.cpu()
109 |             fc7 = fc7.cpu()
110 | 
111 |         target_one_hot = torch.zeros(len(label), NUM_OF_CLASSES).to(self.device)
112 |         target_one_hot = target_one_hot.scatter_(1, label.unsqueeze(1), 1.)
113 |         zy = torch.addcmul(torch.zeros(fc7.size()).to(self.device), 1., fc7, target_one_hot)
114 |         zy = zy.sum(-1) # zy[i] = s * cos(theta) for the target class
115 | 
116 |         cos_theta = zy/self.s
117 |         cos_theta = cos_theta.clamp(min=-1+eps, max=1-eps) # for numerical stability
118 | 
119 |         theta = torch.acos(cos_theta)
120 |         theta = theta+self.m # additive angular margin
121 | 
122 |         body = torch.cos(theta)
123 |         new_zy = body*self.s # s * cos(theta + m)
124 | 
125 |         diff = new_zy - zy
126 |         diff = diff.unsqueeze(1)
127 | 
128 |         body = torch.addcmul(torch.zeros(diff.size()).to(self.device), 1., diff, target_one_hot)
129 |         output = fc7+body # only the target-class logit is changed
130 | 
131 |         return output.to(self.device)
132 | 
133 | 
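# --- Editor's sketch (not part of the original test_arcface_mnist.py) --------
# A small check of the additive angular margin implemented above: the target
# logit returned by Arcface_loss should equal s * cos(theta + m). Variable
# names prefixed with "_" are illustrative only.

_dev = torch.device("cuda" if use_cuda else "cpu")
_arc = Arcface_loss(num_classes=NUM_OF_CLASSES, feat_dim=FEATURES_DIM, device=_dev).to(_dev)
_feat = torch.randn(4, FEATURES_DIM, device=_dev)
_label = torch.tensor([0, 3, 7, 9], device=_dev)

_out = _arc(_feat, _label)
_cos = F.normalize(_feat, dim=1) @ F.normalize(_arc.weights, dim=1).t()
_idx = torch.arange(4, device=_dev)
_theta = torch.acos(_cos[_idx, _label].clamp(-1 + 1e-4, 1 - 1e-4)) # same eps clamp as forward()
assert torch.allclose(_out[_idx, _label], _arc.s * torch.cos(_theta + _arc.m), atol=1e-4)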
134 | def train(model, device, train_loader, loss_softmax, loss_arcface, optimizer_nn, optimizer_arcface, epoch):
135 |     model.train()
136 |     for batch_idx, (data, target) in enumerate(train_loader):
137 |         data, target = data.to(device), target.to(device)
138 | 
139 |         features, _ = model(data)
140 |         logits = loss_arcface(features, target)
141 |         loss = loss_softmax(logits, target)
142 | 
143 |         _, predicted = torch.max(logits.data, 1)
144 |         accuracy = (target.data == predicted).float().mean() # computed but not logged
145 | 
146 |         optimizer_nn.zero_grad()
147 |         optimizer_arcface.zero_grad()
148 | 
149 |         loss.backward()
150 | 
151 |         optimizer_nn.step()
152 |         optimizer_arcface.step()
153 | 
154 |         if batch_idx % LOG_INTERVAL == 0:
155 |             print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
156 |                 epoch, batch_idx * len(data), len(train_loader.dataset),
157 |                 100. * batch_idx / len(train_loader), loss.item()))
158 | 
159 | 
160 | def test(model, device, test_loader, loss_softmax, loss_arcface):
161 |     model.eval()
162 |     correct = 0
163 |     total = 0
164 |     with torch.no_grad():
165 |         for data, target in test_loader:
166 |             data, target = data.to(device), target.to(device)
167 | 
168 |             feats, _ = model(data)
169 |             logits = loss_arcface(feats, target)
170 |             _, predicted = torch.max(logits.data, 1)
171 |             total += target.size(0)
172 |             correct += (predicted == target.data).sum().item()
173 | 
174 |     print('\nTest set: Accuracy: {}/{} ({:.0f}%)\n'.format(
175 |         correct, len(test_loader.dataset),
176 |         100. * correct / len(test_loader.dataset)))
177 | 
178 | ###################################################################
179 | 
180 | torch.manual_seed(1)
181 | device = torch.device("cuda" if use_cuda else "cpu")
182 | 
183 | ####### Data setup
184 | 
185 | kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
186 | train_loader = torch.utils.data.DataLoader(
187 |     datasets.MNIST('./data', train=True, download=True,
188 |                    transform=transforms.Compose([
189 |                        transforms.ToTensor(),
190 |                        transforms.Normalize((0.1307,), (0.3081,))
191 |                    ])),
192 |     batch_size=BATCH_SIZE, shuffle=True, **kwargs)
193 | test_loader = torch.utils.data.DataLoader(
194 |     datasets.MNIST('./data', train=False, transform=transforms.Compose([
195 |         transforms.ToTensor(),
196 |         transforms.Normalize((0.1307,), (0.3081,))
197 |     ])),
198 |     batch_size=BATCH_SIZE_TEST, shuffle=True, **kwargs)
199 | 
200 | ####### Model setup
201 | 
202 | model = Net().to(device)
203 | loss_softmax = nn.CrossEntropyLoss().to(device)
204 | loss_arcface = Arcface_loss(num_classes=10, feat_dim=FEATURES_DIM, device=device).to(device)
205 | 
206 | # optimizer for the network
207 | optimizer_nn = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=0.0005)
208 | scheduler_nn = lr_scheduler.StepLR(optimizer_nn, 20, gamma=0.1)
209 | 
210 | # optimizer for the ArcFace loss weights
211 | optimizer_arcface = optim.SGD(loss_arcface.parameters(), lr=0.01)
212 | scheduler_arcface = lr_scheduler.StepLR(optimizer_arcface, 20, gamma=0.1)
213 | 
214 | t = time.time()
215 | 
216 | for epoch in range(1, EPOCHS + 1):
217 |     scheduler_nn.step() # pre-1.1 ordering; in PyTorch >= 1.1 step schedulers after the epoch's training
218 |     scheduler_arcface.step()
219 | 
220 |     train(model, device, train_loader, loss_softmax, loss_arcface, optimizer_nn, optimizer_arcface, epoch)
221 |     test(model, device, test_loader, loss_softmax, loss_arcface)
222 | 
223 | total_time = int(time.time() - t)
224 | print('Total time: {}'.format(timedelta(seconds=total_time)))
225 | 
226 | torch.save(model.state_dict(),"mnist_cnn-arcface.pt")
227 | torch.save(loss_arcface.state_dict(),"mnist_loss-arcface.pt")
228 | 
--------------------------------------------------------------------------------