├── .ipynb_checkpoints
│   ├── CNN-checkpoint.ipynb
│   ├── VGG-Sigmoid-checkpoint.ipynb
│   ├── VGG_LeakyReLU-checkpoint.ipynb
│   ├── VGG_NonTrainable_ProbAct-checkpoint.ipynb
│   ├── VGG_PReLU-checkpoint.ipynb
│   ├── VGG_ProActiv-NonTrainableSigma-Copy1-checkpoint.ipynb
│   ├── VGG_ProActiv-NonTrainableSigma-Copy2-checkpoint.ipynb
│   ├── VGG_ProActiv-NonTrainableSigma-Copy3-checkpoint.ipynb
│   ├── VGG_ProActiv-NonTrainableSigma-checkpoint.ipynb
│   ├── VGG_ProActiv-OneTrainableSigma-checkpoint.ipynb
│   ├── VGG_ProActiv-ReLU-Testing-checkpoint.ipynb
│   ├── VGG_ReLU-checkpoint.ipynb
│   ├── VGG_ReLU-with dropout-checkpoint.ipynb
│   ├── VGG_Swish-checkpoint.ipynb
│   ├── VGG_TrainableProbAct-checkpoint.ipynb
│   ├── VGG_main-checkpoint.ipynb
│   └── distributions_trial-checkpoint.ipynb
├── Others.ipynb
├── README.md
├── VGG_LeakyReLU.ipynb
├── VGG_NonTrainable_ProbAct.ipynb
├── VGG_PReLU.ipynb
├── VGG_ProActiv-ReLU-Testing.ipynb
├── VGG_ReLU.ipynb
├── VGG_Sigmoid.ipynb
├── VGG_Swish.ipynb
├── VGG_TrainableProbAct.ipynb
├── Visualization
│   ├── ComparisonResultsProbAct.png
│   ├── LayerWiseSigma.pdf
│   ├── OneTrainableSigma.pdf
│   ├── OverfittingCIFAR100.png
│   ├── ProbAct.png
│   ├── TestAcc.pdf
│   ├── TestAccCIFAR10.pdf
│   ├── TrainAcc.pdf
│   └── TrainAccCIFAR10.pdf
└── models
    ├── ProbAct.py
    ├── swish.py
    ├── vgg.py
    ├── vgg_probact_non-trainable.py
    ├── vgg_probact_trainable.py
    └── vgg_swish.py
/.ipynb_checkpoints/CNN-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from __future__ import print_function\n", 10 | "\n", 11 | "import os\n", 12 | "import sys\n", 13 | "import time\n", 14 | "import argparse\n", 15 | "import datetime\n", 16 | "import math\n", 17 | "import pickle\n", 18 | "\n", 19 | "\n", 20 | "import torchvision\n", 21 | "import torchvision.transforms as transforms\n", 22 | "from utils.autoaugment import CIFAR10Policy\n", 23 | "\n", 24 | "import torch\n", 25 | "import torch.utils.data as data\n", 26 | "import numpy as np\n", 27 | "import torch.nn as nn\n", 28 | "import torch.optim as optim\n", 29 | "import torch.nn.functional as F\n", 30 | "import torch.backends.cudnn as cudnn\n", 31 | "from torch.autograd import Variable\n", 32 | "from torchvision import datasets\n", 33 | "from torch.utils.data.sampler import SubsetRandomSampler" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 2, 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "from models import conv_init\n", 43 | "from models.vgg import VGG" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 3, 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "net_type = 'VGG'\n", 53 | "dataset = 'CIFAR10'\n", 54 | "outputs = 10\n", 55 | "inputs = 3\n", 56 | "n_epochs = 50\n", 57 | "lr = 0.001\n", 58 | "resize=32" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": 4, 64 | "metadata": {}, 65 | "outputs": [], 66 | "source": [ 67 | "# Hyper Parameter settings\n", 68 | "use_cuda = torch.cuda.is_available()\n", 69 | "torch.cuda.set_device(0)" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": 5, 75 | "metadata": {}, 76 | "outputs": [], 77 | "source": [ 78 | "# number of subprocesses to use for data loading\n", 79 | "num_workers = 0\n", 80 | "# how many samples per batch to load\n", 81 | "batch_size = 64\n", 82 | "# percentage of training set to use as validation\n", 83 | "valid_size = 0.2" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88
| "execution_count": 6, 89 | "metadata": {}, 90 | "outputs": [ 91 | { 92 | "name": "stdout", 93 | "output_type": "stream", 94 | "text": [ 95 | "Files already downloaded and verified\n", 96 | "Files already downloaded and verified\n" 97 | ] 98 | } 99 | ], 100 | "source": [ 101 | "# convert data to a normalized torch.FloatTensor\n", 102 | "transform = transforms.Compose([\n", 103 | " transforms.ToTensor(),\n", 104 | " transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n", 105 | " ])\n", 106 | "\n", 107 | "# choose the training and test datasets\n", 108 | "train_data = datasets.CIFAR10('data', train=True,\n", 109 | " download=True, transform=transform)\n", 110 | "test_data = datasets.CIFAR10('data', train=False,\n", 111 | " download=True, transform=transform)" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": 7, 117 | "metadata": {}, 118 | "outputs": [], 119 | "source": [ 120 | "# obtain training indices that will be used for validation\n", 121 | "num_train = len(train_data)\n", 122 | "indices = list(range(num_train))\n", 123 | "np.random.shuffle(indices)\n", 124 | "split = int(np.floor(valid_size * num_train))\n", 125 | "train_idx, valid_idx = indices[split:], indices[:split]" 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": 8, 131 | "metadata": {}, 132 | "outputs": [], 133 | "source": [ 134 | "# define samplers for obtaining training and validation batches\n", 135 | "train_sampler = SubsetRandomSampler(train_idx)\n", 136 | "valid_sampler = SubsetRandomSampler(valid_idx)" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": 9, 142 | "metadata": {}, 143 | "outputs": [], 144 | "source": [ 145 | "# prepare data loaders (combine dataset and sampler)\n", 146 | "train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size,\n", 147 | " sampler=train_sampler, num_workers=num_workers)\n", 148 | "valid_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, \n", 149 | " sampler=valid_sampler, num_workers=num_workers)\n", 150 | "test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, \n", 151 | " num_workers=num_workers)" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": 10, 157 | "metadata": {}, 158 | "outputs": [], 159 | "source": [ 160 | "# specify the image classes\n", 161 | "classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',\n", 162 | " 'dog', 'frog', 'horse', 'ship', 'truck']" 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": 11, 168 | "metadata": {}, 169 | "outputs": [ 170 | { 171 | "name": "stdout", 172 | "output_type": "stream", 173 | "text": [ 174 | "AlexNet(\n", 175 | " (features): Sequential(\n", 176 | " (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(5, 5))\n", 177 | " (1): ReLU(inplace)\n", 178 | " (2): Dropout(p=0.5)\n", 179 | " (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 180 | " (4): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n", 181 | " (5): ReLU(inplace)\n", 182 | " (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 183 | " (7): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 184 | " (8): ReLU(inplace)\n", 185 | " (9): Dropout(p=0.5)\n", 186 | " (10): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 187 | " (11): ReLU(inplace)\n", 188 | " (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 189 | " (13): 
ReLU(inplace)\n", 190 | " (14): Dropout(p=0.5)\n", 191 | " (15): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 192 | " )\n", 193 | " (classifier): Linear(in_features=256, out_features=10, bias=True)\n", 194 | ")\n" 195 | ] 196 | } 197 | ], 198 | "source": [ 199 | "# Architecture\n", 200 | "if (net_type == 'lenet'):\n", 201 | " net = LeNet(outputs,inputs)\n", 202 | "elif (net_type == 'alexnet'):\n", 203 | " net = AlexNet(outputs,inputs)\n", 204 | "elif (net_type == '3conv3fc'):\n", 205 | " net = ThreeConvThreeFC(outputs,inputs)\n", 206 | "else:\n", 207 | " print('Error : Network should be either [LeNet / AlexNet / 3Conv3FC')\n", 208 | "\n", 209 | "print(net)\n" 210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": 12, 215 | "metadata": {}, 216 | "outputs": [], 217 | "source": [ 218 | "# move tensors to GPU if CUDA is available\n", 219 | "if use_cuda:\n", 220 | " net.cuda()" 221 | ] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "execution_count": 13, 226 | "metadata": {}, 227 | "outputs": [], 228 | "source": [ 229 | "# specify loss function (categorical cross-entropy)\n", 230 | "criterion = nn.CrossEntropyLoss()\n", 231 | "\n", 232 | "# specify optimizer\n", 233 | "optimizer = optim.Adam(net.parameters(), lr=lr)" 234 | ] 235 | }, 236 | { 237 | "cell_type": "code", 238 | "execution_count": 14, 239 | "metadata": {}, 240 | "outputs": [ 241 | { 242 | "data": { 243 | "text/plain": [ 244 | "'model_alexnet_CIFAR10_frequentist.pt'" 245 | ] 246 | }, 247 | "execution_count": 14, 248 | "metadata": {}, 249 | "output_type": "execute_result" 250 | } 251 | ], 252 | "source": [ 253 | "ckpt_name = f'model_{net_type}_{dataset}_frequentist.pt'\n", 254 | "ckpt_name" 255 | ] 256 | }, 257 | { 258 | "cell_type": "code", 259 | "execution_count": 15, 260 | "metadata": {}, 261 | "outputs": [ 262 | { 263 | "name": "stdout", 264 | "output_type": "stream", 265 | "text": [ 266 | "Epoch: 1 \tTraining Loss: 1.431397 \tValidation Loss: 0.327709\n", 267 | "Validation loss decreased (inf --> 0.327709). Saving model ...\n", 268 | "Epoch: 2 \tTraining Loss: 1.196846 \tValidation Loss: 0.303414\n", 269 | "Validation loss decreased (0.327709 --> 0.303414). Saving model ...\n", 270 | "Epoch: 3 \tTraining Loss: 1.109582 \tValidation Loss: 0.284857\n", 271 | "Validation loss decreased (0.303414 --> 0.284857). Saving model ...\n", 272 | "Epoch: 4 \tTraining Loss: 1.058778 \tValidation Loss: 0.282861\n", 273 | "Validation loss decreased (0.284857 --> 0.282861). Saving model ...\n", 274 | "Epoch: 5 \tTraining Loss: 1.015377 \tValidation Loss: 0.272893\n", 275 | "Validation loss decreased (0.282861 --> 0.272893). Saving model ...\n", 276 | "Epoch: 6 \tTraining Loss: 0.985682 \tValidation Loss: 0.264794\n", 277 | "Validation loss decreased (0.272893 --> 0.264794). Saving model ...\n", 278 | "Epoch: 7 \tTraining Loss: 0.961130 \tValidation Loss: 0.257850\n", 279 | "Validation loss decreased (0.264794 --> 0.257850). Saving model ...\n", 280 | "Epoch: 8 \tTraining Loss: 0.942639 \tValidation Loss: 0.253482\n", 281 | "Validation loss decreased (0.257850 --> 0.253482). Saving model ...\n", 282 | "Epoch: 9 \tTraining Loss: 0.921052 \tValidation Loss: 0.251446\n", 283 | "Validation loss decreased (0.253482 --> 0.251446). Saving model ...\n", 284 | "Epoch: 10 \tTraining Loss: 0.892646 \tValidation Loss: 0.244815\n", 285 | "Validation loss decreased (0.251446 --> 0.244815). 
Saving model ...\n", 286 | "Epoch: 11 \tTraining Loss: 0.884392 \tValidation Loss: 0.250876\n", 287 | "Epoch: 12 \tTraining Loss: 0.864065 \tValidation Loss: 0.239296\n", 288 | "Validation loss decreased (0.244815 --> 0.239296). Saving model ...\n", 289 | "Epoch: 13 \tTraining Loss: 0.844942 \tValidation Loss: 0.239256\n", 290 | "Validation loss decreased (0.239296 --> 0.239256). Saving model ...\n", 291 | "Epoch: 14 \tTraining Loss: 0.840974 \tValidation Loss: 0.238582\n", 292 | "Validation loss decreased (0.239256 --> 0.238582). Saving model ...\n", 293 | "Epoch: 15 \tTraining Loss: 0.829902 \tValidation Loss: 0.234766\n", 294 | "Validation loss decreased (0.238582 --> 0.234766). Saving model ...\n", 295 | "Epoch: 16 \tTraining Loss: 0.814025 \tValidation Loss: 0.230650\n", 296 | "Validation loss decreased (0.234766 --> 0.230650). Saving model ...\n", 297 | "Epoch: 17 \tTraining Loss: 0.806663 \tValidation Loss: 0.227011\n", 298 | "Validation loss decreased (0.230650 --> 0.227011). Saving model ...\n", 299 | "Epoch: 18 \tTraining Loss: 0.792850 \tValidation Loss: 0.224854\n", 300 | "Validation loss decreased (0.227011 --> 0.224854). Saving model ...\n", 301 | "Epoch: 19 \tTraining Loss: 0.787480 \tValidation Loss: 0.229408\n", 302 | "Epoch: 20 \tTraining Loss: 0.782817 \tValidation Loss: 0.232852\n", 303 | "CPU times: user 2min 49s, sys: 10.9 s, total: 3min\n", 304 | "Wall time: 3min\n" 305 | ] 306 | } 307 | ], 308 | "source": [ 309 | "%%time\n", 310 | "\n", 311 | "valid_loss_min = np.Inf # track change in validation loss\n", 312 | "\n", 313 | "for epoch in range(1, n_epochs+1):\n", 314 | "\n", 315 | " # keep track of training and validation loss\n", 316 | " train_loss = 0.0\n", 317 | " valid_loss = 0.0\n", 318 | " \n", 319 | " ###################\n", 320 | " # train the model #\n", 321 | " ###################\n", 322 | " net.train()\n", 323 | " for data, target in train_loader:\n", 324 | " # move tensors to GPU if CUDA is available\n", 325 | " if use_cuda:\n", 326 | " data, target = data.cuda(), target.cuda()\n", 327 | " # clear the gradients of all optimized variables\n", 328 | " optimizer.zero_grad()\n", 329 | " # forward pass: compute predicted outputs by passing inputs to the model\n", 330 | " output = net(data)\n", 331 | " # calculate the batch loss\n", 332 | " loss = criterion(output, target)\n", 333 | " # backward pass: compute gradient of the loss with respect to model parameters\n", 334 | " loss.backward()\n", 335 | " # perform a single optimization step (parameter update)\n", 336 | " optimizer.step()\n", 337 | " # update training loss\n", 338 | " train_loss += loss.item()*data.size(0)\n", 339 | " \n", 340 | " ###################### \n", 341 | " # validate the model #\n", 342 | " ######################\n", 343 | " net.eval()\n", 344 | " for data, target in valid_loader:\n", 345 | " # move tensors to GPU if CUDA is available\n", 346 | " if use_cuda:\n", 347 | " data, target = data.cuda(), target.cuda()\n", 348 | " # forward pass: compute predicted outputs by passing inputs to the model\n", 349 | " output = net(data)\n", 350 | " # calculate the batch loss\n", 351 | " loss = criterion(output, target)\n", 352 | " # update average validation loss \n", 353 | " valid_loss += loss.item()*data.size(0)\n", 354 | " \n", 355 | " # calculate average losses\n", 356 | " train_loss = train_loss/len(train_loader.dataset)\n", 357 | " valid_loss = valid_loss/len(valid_loader.dataset)\n", 358 | " \n", 359 | " # print training/validation statistics \n", 360 | " print('Epoch: {} \\tTraining Loss: 
{:.6f} \tValidation Loss: {:.6f}'.format(\n", 361 | " epoch, train_loss, valid_loss))\n", 362 | " \n", 363 | " # save model if validation loss has decreased\n", 364 | " if valid_loss <= valid_loss_min:\n", 365 | " print('Validation loss decreased ({:.6f} --> {:.6f}). Saving model ...'.format(\n", 366 | " valid_loss_min,\n", 367 | " valid_loss))\n", 368 | " torch.save(net.state_dict(), ckpt_name)\n", 369 | " valid_loss_min = valid_loss\n" 370 | ] 371 | }, 372 | { 373 | "cell_type": "code", 374 | "execution_count": 16, 375 | "metadata": {}, 376 | "outputs": [], 377 | "source": [ 378 | "net.load_state_dict(torch.load(ckpt_name))" 379 | ] 380 | }, 381 | { 382 | "cell_type": "code", 383 | "execution_count": 17, 384 | "metadata": {}, 385 | "outputs": [ 386 | { 387 | "name": "stdout", 388 | "output_type": "stream", 389 | "text": [ 390 | "Test Loss: 1.126139\n", 391 | "\n", 392 | "Test Accuracy of airplane: 67% (672/1000)\n", 393 | "Test Accuracy of automobile: 71% (717/1000)\n", 394 | "Test Accuracy of bird: 47% (476/1000)\n", 395 | "Test Accuracy of cat: 47% (478/1000)\n", 396 | "Test Accuracy of deer: 55% (550/1000)\n", 397 | "Test Accuracy of dog: 49% (497/1000)\n", 398 | "Test Accuracy of frog: 82% (822/1000)\n", 399 | "Test Accuracy of horse: 66% (667/1000)\n", 400 | "Test Accuracy of ship: 79% (796/1000)\n", 401 | "Test Accuracy of truck: 72% (723/1000)\n", 402 | "\n", 403 | "Test Accuracy (Overall): 63% (6398/10000)\n", 404 | "CPU times: user 1.49 s, sys: 79.8 ms, total: 1.57 s\n", 405 | "Wall time: 1.57 s\n" 406 | ] 407 | } 408 | ], 409 | "source": [ 410 | "%%time\n", 411 | "\n", 412 | "# track test loss\n", 413 | "test_loss = 0.0\n", 414 | "class_correct = list(0. for i in range(10))\n", 415 | "class_total = list(0. for i in range(10))\n", 416 | "\n", 417 | "net.eval()\n", 418 | "# iterate over test data\n", 419 | "for data, target in test_loader:\n", 420 | " # move tensors to GPU if CUDA is available\n", 421 | " if use_cuda:\n", 422 | " data, target = data.cuda(), target.cuda()\n", 423 | " # forward pass: compute predicted outputs by passing inputs to the model\n", 424 | " output = net(data)\n", 425 | " # calculate the batch loss\n", 426 | " loss = criterion(output, target)\n", 427 | " # update test loss \n", 428 | " test_loss += loss.item()*data.size(0)\n", 429 | " # convert output probabilities to predicted class\n", 430 | " _, pred = torch.max(output, 1) \n", 431 | " # compare predictions to true label\n", 432 | " correct_tensor = pred.eq(target.data.view_as(pred))\n", 433 | " correct = np.squeeze(correct_tensor.numpy()) if not use_cuda else np.squeeze(correct_tensor.cpu().numpy())\n", 434 | " # calculate test accuracy for each object class\n", 435 | " for i in range(batch_size):\n", 436 | " if i >= target.data.shape[0]: # the last batch may contain fewer than batch_size images\n", 437 | " break\n", 438 | " label = target.data[i]\n", 439 | " class_correct[label] += correct[i].item()\n", 440 | " class_total[label] += 1\n", 441 | "\n", 442 | "# average test loss\n", 443 | "test_loss = test_loss/len(test_loader.dataset)\n", 444 | "print('Test Loss: {:.6f}\\n'.format(test_loss))\n", 445 | "\n", 446 | "for i in range(10):\n", 447 | " if class_total[i] > 0:\n", 448 | " print('Test Accuracy of %5s: %2d%% (%2d/%2d)' % (\n", 449 | " classes[i], 100 * class_correct[i] / class_total[i],\n", 450 | " np.sum(class_correct[i]), np.sum(class_total[i])))\n", 451 | " else:\n", 452 | " print('Test Accuracy of %5s: N/A (no training examples)' % (classes[i]))\n", 453 | "\n", 454 | "print('\\nTest
Accuracy (Overall): %2d%% (%2d/%2d)' % (\n", 455 | " 100. * np.sum(class_correct) / np.sum(class_total),\n", 456 | " np.sum(class_correct), np.sum(class_total)))" 457 | ] 458 | }, 459 | { 460 | "cell_type": "code", 461 | "execution_count": null, 462 | "metadata": {}, 463 | "outputs": [], 464 | "source": [] 465 | } 466 | ], 467 | "metadata": { 468 | "kernelspec": { 469 | "display_name": "Python 3", 470 | "language": "python", 471 | "name": "python3" 472 | }, 473 | "language_info": { 474 | "codemirror_mode": { 475 | "name": "ipython", 476 | "version": 3 477 | }, 478 | "file_extension": ".py", 479 | "mimetype": "text/x-python", 480 | "name": "python", 481 | "nbconvert_exporter": "python", 482 | "pygments_lexer": "ipython3", 483 | "version": "3.7.3" 484 | } 485 | }, 486 | "nbformat": 4, 487 | "nbformat_minor": 2 488 | } 489 | -------------------------------------------------------------------------------- /.ipynb_checkpoints/VGG-Sigmoid-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "import torch.nn as nn\n", 11 | "import torch.nn.functional as F\n", 12 | "import torch.backends.cudnn as cudnn\n", 13 | "from torch.optim import Adam, SGD\n", 14 | "import torchvision\n", 15 | "import torchvision.transforms as transforms\n", 16 | "\n", 17 | "import sys, os, math\n", 18 | "import argparse" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 2, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "lr=0.01\n", 28 | "data='cifar10'\n", 29 | "root='./data/'\n", 30 | "model='vgg'\n", 31 | "model_out='./checkpoint/cifar10_vgg_ReLU.pth'\n", 32 | "resume = False" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 3, 38 | "metadata": {}, 39 | "outputs": [ 40 | { 41 | "name": "stdout", 42 | "output_type": "stream", 43 | "text": [ 44 | "Files already downloaded and verified\n", 45 | "Files already downloaded and verified\n" 46 | ] 47 | } 48 | ], 49 | "source": [ 50 | "if data == 'cifar10':\n", 51 | " nclass = 10\n", 52 | " img_width = 32\n", 53 | " transform_train = transforms.Compose([\n", 54 | " transforms.RandomCrop(32, padding=4),\n", 55 | " transforms.RandomHorizontalFlip(),\n", 56 | " transforms.ToTensor(),\n", 57 | " ])\n", 58 | " transform_test = transforms.Compose([\n", 59 | " transforms.ToTensor(),\n", 60 | " ])\n", 61 | " trainset = torchvision.datasets.CIFAR10(root=root, train=True, download=True, transform=transform_train)\n", 62 | " trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)\n", 63 | " testset = torchvision.datasets.CIFAR10(root=root, train=False, download=True, transform=transform_test)\n", 64 | " testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": 4, 70 | "metadata": {}, 71 | "outputs": [ 72 | { 73 | "data": { 74 | "text/plain": [ 75 | "DataParallel(\n", 76 | " (module): VGG(\n", 77 | " (features): Sequential(\n", 78 | " (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 79 | " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 80 | " (2): ReLU(inplace)\n", 81 | " (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 82 | " (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, 
track_running_stats=True)\n", 83 | " (5): ReLU(inplace)\n", 84 | " (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 85 | " (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 86 | " (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 87 | " (9): ReLU(inplace)\n", 88 | " (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 89 | " (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 90 | " (12): ReLU(inplace)\n", 91 | " (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 92 | " (14): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 93 | " (15): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 94 | " (16): ReLU(inplace)\n", 95 | " (17): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 96 | " (18): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 97 | " (19): ReLU(inplace)\n", 98 | " (20): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 99 | " (21): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 100 | " (22): ReLU(inplace)\n", 101 | " (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 102 | " (24): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 103 | " (25): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 104 | " (26): ReLU(inplace)\n", 105 | " (27): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 106 | " (28): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 107 | " (29): ReLU(inplace)\n", 108 | " (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 109 | " (31): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 110 | " (32): ReLU(inplace)\n", 111 | " (33): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 112 | " (34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 113 | " (35): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 114 | " (36): ReLU(inplace)\n", 115 | " (37): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 116 | " (38): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 117 | " (39): ReLU(inplace)\n", 118 | " (40): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 119 | " (41): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 120 | " (42): ReLU(inplace)\n", 121 | " (43): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 122 | " (44): AvgPool2d(kernel_size=1, stride=1, padding=0)\n", 123 | " )\n", 124 | " (classifier): Linear(in_features=512, out_features=10, bias=True)\n", 125 | " )\n", 126 | ")" 127 | ] 128 | }, 129 | "execution_count": 4, 130 | "metadata": {}, 131 | "output_type": "execute_result" 132 | } 133 | ], 134 | "source": [ 135 | "if model == 'vgg':\n", 136 | " from models.vgg import VGG\n", 137 | " net = nn.DataParallel(VGG('VGG16', nclass, img_width=img_width).cuda())\n", 138 | " \n", 139 | "net" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": 5, 145 | "metadata": {}, 146 | "outputs": [], 147 | "source": [ 148 | "if resume:\n", 
149 | " print(f'==> Resuming from {model_out}')\n", 150 | " net.load_state_dict(torch.load(model_out))" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": 6, 156 | "metadata": {}, 157 | "outputs": [], 158 | "source": [ 159 | "cudnn.benchmark = True" 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": 7, 165 | "metadata": {}, 166 | "outputs": [], 167 | "source": [ 168 | "criterion = nn.CrossEntropyLoss()" 169 | ] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "execution_count": 8, 174 | "metadata": {}, 175 | "outputs": [], 176 | "source": [ 177 | "def train(epoch):\n", 178 | " print('Epoch: %d' % epoch)\n", 179 | " net.train()\n", 180 | " train_loss = 0\n", 181 | " correct = 0\n", 182 | " total = 0\n", 183 | " for batch_idx, (inputs, targets) in enumerate(trainloader):\n", 184 | " inputs, targets = inputs.cuda(), targets.cuda()\n", 185 | " optimizer.zero_grad()\n", 186 | " outputs, _ = net(inputs)\n", 187 | " loss = criterion(outputs, targets)\n", 188 | " loss.backward()\n", 189 | " optimizer.step()\n", 190 | " pred = torch.max(outputs, dim=1)[1]\n", 191 | " correct += torch.sum(pred.eq(targets)).item()\n", 192 | " total += targets.numel()\n", 193 | " print(f'[TRAIN] Acc: {100.*correct/total:.3f}')" 194 | ] 195 | }, 196 | { 197 | "cell_type": "code", 198 | "execution_count": null, 199 | "metadata": {}, 200 | "outputs": [], 201 | "source": [ 202 | "def test(epoch):\n", 203 | " net.eval()\n", 204 | " test_loss = 0\n", 205 | " correct = 0\n", 206 | " total = 0\n", 207 | " with torch.no_grad():\n", 208 | " for batch_idx, (inputs, targets) in enumerate(testloader):\n", 209 | " inputs, targets = inputs.cuda(), targets.cuda()\n", 210 | " outputs, _ = net(inputs)\n", 211 | " loss = criterion(outputs, targets)\n", 212 | " test_loss += loss.item()\n", 213 | " _, predicted = outputs.max(1)\n", 214 | " total += targets.size(0)\n", 215 | " correct += predicted.eq(targets).sum().item()\n", 216 | " print(f'[TEST] Acc: {100.*correct/total:.3f}')\n", 217 | "\n", 218 | " # Save checkpoint after each epoch\n", 219 | " torch.save(net.state_dict(), model_out)" 220 | ] 221 | }, 222 | { 223 | "cell_type": "code", 224 | "execution_count": null, 225 | "metadata": {}, 226 | "outputs": [], 227 | "source": [ 228 | "if data == 'cifar10':\n", 229 | " epochs = [30, 20, 10]" 230 | ] 231 | }, 232 | { 233 | "cell_type": "code", 234 | "execution_count": null, 235 | "metadata": {}, 236 | "outputs": [], 237 | "source": [ 238 | "count = 0" 239 | ] 240 | }, 241 | { 242 | "cell_type": "code", 243 | "execution_count": null, 244 | "metadata": {}, 245 | "outputs": [ 246 | { 247 | "name": "stdout", 248 | "output_type": "stream", 249 | "text": [ 250 | "Epoch: 0\n", 251 | "[TRAIN] Acc: 19.048\n", 252 | "[TEST] Acc: 23.660\n", 253 | "Epoch: 1\n", 254 | "[TRAIN] Acc: 37.478\n", 255 | "[TEST] Acc: 44.140\n", 256 | "Epoch: 2\n" 257 | ] 258 | } 259 | ], 260 | "source": [ 261 | "for epoch in epochs:\n", 262 | " optimizer = Adam(net.parameters(), lr=lr)\n", 263 | " for _ in range(epoch):\n", 264 | " train(count)\n", 265 | " test(count)\n", 266 | " count += 1\n", 267 | " lr /= 10" 268 | ] 269 | }, 270 | { 271 | "cell_type": "code", 272 | "execution_count": null, 273 | "metadata": {}, 274 | "outputs": [], 275 | "source": [] 276 | } 277 | ], 278 | "metadata": { 279 | "kernelspec": { 280 | "display_name": "Python (bayesian)", 281 | "language": "python", 282 | "name": "bayesian" 283 | }, 284 | "language_info": { 285 | "codemirror_mode": { 286 | "name": "ipython", 287 | "version": 3 288 | }, 289 
| "file_extension": ".py", 290 | "mimetype": "text/x-python", 291 | "name": "python", 292 | "nbconvert_exporter": "python", 293 | "pygments_lexer": "ipython3", 294 | "version": "3.7.3" 295 | } 296 | }, 297 | "nbformat": 4, 298 | "nbformat_minor": 2 299 | } 300 | -------------------------------------------------------------------------------- /.ipynb_checkpoints/VGG_ProActiv-ReLU-Testing-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "import torch.nn as nn\n", 11 | "import torch.nn.functional as F\n", 12 | "import torch.backends.cudnn as cudnn\n", 13 | "from torch.optim import Adam, SGD\n", 14 | "import torchvision\n", 15 | "import torchvision.transforms as transforms\n", 16 | "\n", 17 | "import sys, os, math, csv\n", 18 | "import argparse\n", 19 | "import pandas as pd" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 2, 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "lr=0.01\n", 29 | "data='cifar10'\n", 30 | "root='./data/'\n", 31 | "model='vgg'\n", 32 | "model_out='./checkpoint/cifar10_vgg_ProActiv_OneTrainableSigma.pth'\n", 33 | "resume = True" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 3, 39 | "metadata": {}, 40 | "outputs": [ 41 | { 42 | "name": "stdout", 43 | "output_type": "stream", 44 | "text": [ 45 | "Files already downloaded and verified\n", 46 | "Files already downloaded and verified\n" 47 | ] 48 | } 49 | ], 50 | "source": [ 51 | "if data == 'cifar10':\n", 52 | " nclass = 10\n", 53 | " img_width = 32\n", 54 | " transform_train = transforms.Compose([\n", 55 | " transforms.RandomCrop(32, padding=4),\n", 56 | " transforms.RandomHorizontalFlip(),\n", 57 | " transforms.ToTensor(),\n", 58 | "# transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))\n", 59 | " ])\n", 60 | " transform_test = transforms.Compose([\n", 61 | " transforms.ToTensor(),\n", 62 | " ])\n", 63 | " trainset = torchvision.datasets.CIFAR10(root=root, train=True, download=True, transform=transform_train)\n", 64 | " testset = torchvision.datasets.CIFAR10(root=root, train=False, download=True, transform=transform_test)\n", 65 | " testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=8)\n", 66 | " \n", 67 | "elif data == 'cifar100':\n", 68 | " nclass = 100\n", 69 | " img_width = 32\n", 70 | " transform_train = transforms.Compose([\n", 71 | " transforms.RandomCrop(32, padding=4),\n", 72 | " transforms.RandomHorizontalFlip(),\n", 73 | " transforms.ToTensor(),\n", 74 | " ])\n", 75 | " transform_test = transforms.Compose([\n", 76 | " transforms.ToTensor(),\n", 77 | " ])\n", 78 | " trainset = torchvision.datasets.CIFAR100(root=root, train=True, download=True, transform=transform_train)\n", 79 | " testset = torchvision.datasets.CIFAR100(root=root, train=False, download=True, transform=transform_test)\n", 80 | " testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=8)\n", 81 | " \n", 82 | "elif data == 'stl10':\n", 83 | " nclass = 10\n", 84 | " img_width = 32\n", 85 | " transform_train = transforms.Compose([\n", 86 | "# transforms.RandomCrop(32, padding=4),\n", 87 | " transforms.RandomHorizontalFlip(),\n", 88 | " transforms.Resize((img_width,img_width)),\n", 89 | " transforms.ToTensor(),\n", 90 | " ])\n", 91 | " transform_test = transforms.Compose([\n", 92 | " 
transforms.Resize((img_width,img_width)),\n", 93 | " transforms.ToTensor(),\n", 94 | " ])\n", 95 | " trainset = torchvision.datasets.STL10(root=root, split='train', transform=transform_train, target_transform=None, download=True)\n", 96 | " testset = torchvision.datasets.STL10(root=root, split='test', transform=transform_test, target_transform=None, download=True)\n", 97 | " testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=8)" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": 4, 103 | "metadata": {}, 104 | "outputs": [ 105 | { 106 | "data": { 107 | "text/plain": [ 108 | "DataParallel(\n", 109 | " (module): VGG_Dist(\n", 110 | " (conv1): Sequential(\n", 111 | " (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 112 | " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 113 | " )\n", 114 | " (conv2): Sequential(\n", 115 | " (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 116 | " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 117 | " )\n", 118 | " (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 119 | " (conv3): Sequential(\n", 120 | " (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 121 | " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 122 | " )\n", 123 | " (conv4): Sequential(\n", 124 | " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 125 | " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 126 | " )\n", 127 | " (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 128 | " (conv5): Sequential(\n", 129 | " (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 130 | " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 131 | " )\n", 132 | " (conv6): Sequential(\n", 133 | " (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 134 | " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 135 | " )\n", 136 | " (conv7): Sequential(\n", 137 | " (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 138 | " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 139 | " )\n", 140 | " (pool3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 141 | " (conv8): Sequential(\n", 142 | " (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 143 | " (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 144 | " )\n", 145 | " (conv9): Sequential(\n", 146 | " (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 147 | " (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 148 | " )\n", 149 | " (conv10): Sequential(\n", 150 | " (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 151 | " (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 152 | " )\n", 153 | " (pool4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 154 | " (conv11): Sequential(\n", 155 | " (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 156 | " (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, 
affine=True, track_running_stats=True)\n", 157 | " )\n", 158 | " (conv12): Sequential(\n", 159 | " (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 160 | " (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 161 | " )\n", 162 | " (conv13): Sequential(\n", 163 | " (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 164 | " (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 165 | " )\n", 166 | " (pool5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 167 | " (dropout): Dropout(p=0.5)\n", 168 | " (ProbActAF): TrainableSigma(num_parameters=1)\n", 169 | " (SwishAF): Swish()\n", 170 | " (globalAvgpool): AvgPool2d(kernel_size=1, stride=1, padding=0)\n", 171 | " (classifier): Sequential(\n", 172 | " (0): Dropout(p=0.5)\n", 173 | " (1): Linear(in_features=512, out_features=100, bias=True)\n", 174 | " )\n", 175 | " )\n", 176 | ")" 177 | ] 178 | }, 179 | "execution_count": 4, 180 | "metadata": {}, 181 | "output_type": "execute_result" 182 | } 183 | ], 184 | "source": [ 185 | "if model == 'vgg':\n", 186 | " from models.vgg_dist import VGG_Dist\n", 187 | " net = nn.DataParallel(VGG_Dist(nclass, img_width=img_width).cuda())\n", 188 | " \n", 189 | "net" 190 | ] 191 | }, 192 | { 193 | "cell_type": "code", 194 | "execution_count": 5, 195 | "metadata": {}, 196 | "outputs": [ 197 | { 198 | "name": "stdout", 199 | "output_type": "stream", 200 | "text": [ 201 | "==> Resuming from ./checkpoint/cifar100_vgg_ProActiv_OneTrainableSigma.pth\n" 202 | ] 203 | } 204 | ], 205 | "source": [ 206 | "if resume:\n", 207 | " print(f'==> Resuming from {model_out}')\n", 208 | " net.load_state_dict(torch.load(model_out))" 209 | ] 210 | }, 211 | { 212 | "cell_type": "code", 213 | "execution_count": 6, 214 | "metadata": {}, 215 | "outputs": [], 216 | "source": [ 217 | "cudnn.benchmark = True" 218 | ] 219 | }, 220 | { 221 | "cell_type": "code", 222 | "execution_count": 7, 223 | "metadata": {}, 224 | "outputs": [], 225 | "source": [ 226 | "criterion = nn.CrossEntropyLoss()" 227 | ] 228 | }, 229 | { 230 | "cell_type": "code", 231 | "execution_count": 8, 232 | "metadata": {}, 233 | "outputs": [], 234 | "source": [ 235 | "def test(epoch):\n", 236 | " net.eval()\n", 237 | " test_loss = 0\n", 238 | " correct = 0\n", 239 | " total = 0\n", 240 | " with torch.no_grad():\n", 241 | " for batch_idx, (inputs, targets) in enumerate(testloader):\n", 242 | " inputs, targets = inputs.cuda(), targets.cuda()\n", 243 | " outputs = net(inputs)\n", 244 | " loss = criterion(outputs, targets)\n", 245 | " test_loss += loss.item()\n", 246 | " _, predicted = outputs.max(1)\n", 247 | " total += targets.size(0)\n", 248 | " correct += predicted.eq(targets).sum().item()\n", 249 | " print(f'[TEST] Acc: {100.*correct/total:.3f}')" 250 | ] 251 | }, 252 | { 253 | "cell_type": "code", 254 | "execution_count": 9, 255 | "metadata": {}, 256 | "outputs": [], 257 | "source": [ 258 | "if data == 'cifar10':\n", 259 | " epochs = [3]\n", 260 | "elif data == 'cifar100':\n", 261 | " epochs = [3]\n", 262 | "elif data == 'stl10':\n", 263 | " epochs = [3]" 264 | ] 265 | }, 266 | { 267 | "cell_type": "code", 268 | "execution_count": 10, 269 | "metadata": {}, 270 | "outputs": [], 271 | "source": [ 272 | "count = 0" 273 | ] 274 | }, 275 | { 276 | "cell_type": "code", 277 | "execution_count": 11, 278 | "metadata": {}, 279 | "outputs": [ 280 | { 281 | "name": "stdout", 282 | "output_type": "stream", 283 | "text": [ 284 | 
"[TEST] Acc: 87.060\n", 285 | "[TEST] Acc: 87.060\n", 286 | "[TEST] Acc: 87.060\n" 287 | ] 288 | } 289 | ], 290 | "source": [ 291 | "for epoch in epochs:\n", 292 | " optimizer = Adam(net.parameters(), lr=lr)\n", 293 | " for _ in range(epoch):\n", 294 | " test(count)\n", 295 | " count += 1\n", 296 | " lr /= 10" 297 | ] 298 | }, 299 | { 300 | "cell_type": "code", 301 | "execution_count": 2, 302 | "metadata": {}, 303 | "outputs": [], 304 | "source": [ 305 | "lr=0.01\n", 306 | "data='stl10'\n", 307 | "root='./data/'\n", 308 | "model='vgg'\n", 309 | "model_out='./checkpoint/stl10_vgg_ProActiv_OneTrainableSigma.pth'\n", 310 | "resume = True" 311 | ] 312 | }, 313 | { 314 | "cell_type": "code", 315 | "execution_count": 11, 316 | "metadata": {}, 317 | "outputs": [ 318 | { 319 | "name": "stdout", 320 | "output_type": "stream", 321 | "text": [ 322 | "[TEST] Acc: 62.038\n", 323 | "[TEST] Acc: 62.038\n", 324 | "[TEST] Acc: 62.038\n" 325 | ] 326 | } 327 | ], 328 | "source": [ 329 | "for epoch in epochs:\n", 330 | " optimizer = Adam(net.parameters(), lr=lr)\n", 331 | " for _ in range(epoch):\n", 332 | " test(count)\n", 333 | " count += 1\n", 334 | " lr /= 10" 335 | ] 336 | }, 337 | { 338 | "cell_type": "code", 339 | "execution_count": 2, 340 | "metadata": {}, 341 | "outputs": [], 342 | "source": [ 343 | "lr=0.01\n", 344 | "data='cifar100'\n", 345 | "root='./data/'\n", 346 | "model='vgg'\n", 347 | "model_out='./checkpoint/cifar100_vgg_ProActiv_OneTrainableSigma.pth'\n", 348 | "resume = True" 349 | ] 350 | }, 351 | { 352 | "cell_type": "code", 353 | "execution_count": 11, 354 | "metadata": {}, 355 | "outputs": [ 356 | { 357 | "name": "stdout", 358 | "output_type": "stream", 359 | "text": [ 360 | "[TEST] Acc: 50.750\n", 361 | "[TEST] Acc: 50.750\n", 362 | "[TEST] Acc: 50.750\n" 363 | ] 364 | } 365 | ], 366 | "source": [ 367 | "for epoch in epochs:\n", 368 | " optimizer = Adam(net.parameters(), lr=lr)\n", 369 | " for _ in range(epoch):\n", 370 | " test(count)\n", 371 | " count += 1\n", 372 | " lr /= 10" 373 | ] 374 | }, 375 | { 376 | "cell_type": "code", 377 | "execution_count": null, 378 | "metadata": {}, 379 | "outputs": [], 380 | "source": [] 381 | } 382 | ], 383 | "metadata": { 384 | "kernelspec": { 385 | "display_name": "Python (bayesian)", 386 | "language": "python", 387 | "name": "bayesian" 388 | }, 389 | "language_info": { 390 | "codemirror_mode": { 391 | "name": "ipython", 392 | "version": 3 393 | }, 394 | "file_extension": ".py", 395 | "mimetype": "text/x-python", 396 | "name": "python", 397 | "nbconvert_exporter": "python", 398 | "pygments_lexer": "ipython3", 399 | "version": "3.7.3" 400 | } 401 | }, 402 | "nbformat": 4, 403 | "nbformat_minor": 2 404 | } 405 | -------------------------------------------------------------------------------- /.ipynb_checkpoints/VGG_main-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [], 3 | "metadata": {}, 4 | "nbformat": 4, 5 | "nbformat_minor": 2 6 | } 7 | -------------------------------------------------------------------------------- /.ipynb_checkpoints/distributions_trial-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 15, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from torch.distributions import normal\n", 10 | "import torch" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 16, 16 | "metadata": {}, 17 | "outputs": [], 18 | 
"source": [ 19 | "x = torch.Tensor([[1,2],[3,4]])" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 29, 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "m = normal.Normal(x, 0.1)" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 30, 34 | "metadata": {}, 35 | "outputs": [ 36 | { 37 | "data": { 38 | "text/plain": [ 39 | "Normal(loc: torch.Size([2, 2]), scale: torch.Size([2, 2]))" 40 | ] 41 | }, 42 | "execution_count": 30, 43 | "metadata": {}, 44 | "output_type": "execute_result" 45 | } 46 | ], 47 | "source": [ 48 | "m" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 31, 54 | "metadata": {}, 55 | "outputs": [], 56 | "source": [ 57 | "a = m.sample((5,))" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": 32, 63 | "metadata": {}, 64 | "outputs": [ 65 | { 66 | "data": { 67 | "text/plain": [ 68 | "tensor([[[1.1799, 2.0784],\n", 69 | " [2.8817, 4.0973]],\n", 70 | "\n", 71 | " [[1.0271, 1.9580],\n", 72 | " [3.0553, 4.0144]],\n", 73 | "\n", 74 | " [[1.1359, 1.8070],\n", 75 | " [2.9763, 4.0606]],\n", 76 | "\n", 77 | " [[0.9979, 2.1031],\n", 78 | " [2.9931, 4.1643]],\n", 79 | "\n", 80 | " [[1.0098, 2.0095],\n", 81 | " [2.9654, 4.1030]]])" 82 | ] 83 | }, 84 | "execution_count": 32, 85 | "metadata": {}, 86 | "output_type": "execute_result" 87 | } 88 | ], 89 | "source": [ 90 | "a" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": 33, 96 | "metadata": {}, 97 | "outputs": [], 98 | "source": [ 99 | "x1 = m.sample((5,)).mean(0)" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": 34, 105 | "metadata": {}, 106 | "outputs": [ 107 | { 108 | "data": { 109 | "text/plain": [ 110 | "tensor([[0.9699, 2.0559],\n", 111 | " [3.0271, 4.0124]])" 112 | ] 113 | }, 114 | "execution_count": 34, 115 | "metadata": {}, 116 | "output_type": "execute_result" 117 | } 118 | ], 119 | "source": [ 120 | "x1" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": null, 126 | "metadata": {}, 127 | "outputs": [], 128 | "source": [] 129 | } 130 | ], 131 | "metadata": { 132 | "kernelspec": { 133 | "display_name": "Python (bayesian)", 134 | "language": "python", 135 | "name": "bayesian" 136 | }, 137 | "language_info": { 138 | "codemirror_mode": { 139 | "name": "ipython", 140 | "version": 3 141 | }, 142 | "file_extension": ".py", 143 | "mimetype": "text/x-python", 144 | "name": "python", 145 | "nbconvert_exporter": "python", 146 | "pygments_lexer": "ipython3", 147 | "version": "3.7.3" 148 | } 149 | }, 150 | "nbformat": 4, 151 | "nbformat_minor": 2 152 | } 153 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ProbAct-Probabilistic-Activation-Function 2 | 3 | Official PyTorch implementation of the paper : ![ProbAct: A Probabilistic Activation Function for Deep Neural Networks](https://arxiv.org/abs/1905.10761). 4 | ![ProbAct](Visualization/ProbAct.png) 5 | 6 | ## Why ProbAct 7 | 8 | Most of the activation functions currently used are deterministic in nature, whose input-output relationship is fixed. In this work, we propose a probabilistic activation function, called *ProbAct*. The output value of ProbAct is sampled from a normal distribution, with the mean value same as the output of ReLU and with a fixed or trainable variance for each element. 
In the trainable ProbAct, the variance of the activation distribution is trained through back-propagation. We also show that the stochastic perturbation through ProbAct is a viable generalization technique that can prevent overfitting. 9 | 10 | ## Accuracy Comparison 11 | 12 | ![Comparison with other activation Functions](Visualization/ComparisonResultsProbAct.png) 13 | 14 | ## Overfitting Comparison 15 | 16 | ![Test-Train Comparison on CIFAR100](Visualization/OverfittingCIFAR100.png) 17 | 18 | 19 | Cite the authors if you find the work useful: 20 | 21 | ``` 22 | @article{lee2019probact, 23 | title={ProbAct: A Probabilistic Activation Function for Deep Neural Networks}, 24 | author={Lee, Joonho and Shridhar, Kumar and Hayashi, Hideaki and Iwana, Brian Kenji and Kang, Seokjun and Uchida, Seiichi}, 25 | journal={arXiv preprint arXiv:1905.10761}, 26 | year={2019} 27 | } 28 | 29 | ``` 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | -------------------------------------------------------------------------------- /VGG_PReLU.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "import torch.nn as nn\n", 11 | "import torch.nn.functional as F\n", 12 | "import torch.backends.cudnn as cudnn\n", 13 | "from torch.optim import Adam, SGD\n", 14 | "import torchvision\n", 15 | "import torchvision.transforms as transforms\n", 16 | "\n", 17 | "import sys, os, math\n", 18 | "import argparse" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 2, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "lr=0.01\n", 28 | "data='cifar10'\n", 29 | "root='./data/'\n", 30 | "model='vgg'\n", 31 | "model_out='./checkpoint/cifar10_vgg_PReLU.pth'\n", 32 | "resume = False" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 3, 38 | "metadata": {}, 39 | "outputs": [ 40 | { 41 | "name": "stdout", 42 | "output_type": "stream", 43 | "text": [ 44 | "Files already downloaded and verified\n", 45 | "Files already downloaded and verified\n" 46 | ] 47 | } 48 | ], 49 | "source": [ 50 | "if data == 'cifar10':\n", 51 | " nclass = 10\n", 52 | " img_width = 32\n", 53 | " transform_train = transforms.Compose([\n", 54 | "# transforms.RandomCrop(32, padding=4),\n", 55 | "# transforms.RandomHorizontalFlip(),\n", 56 | " transforms.ToTensor(),\n", 57 | " ])\n", 58 | " transform_test = transforms.Compose([\n", 59 | " transforms.ToTensor(),\n", 60 | " ])\n", 61 | " trainset = torchvision.datasets.CIFAR10(root=root, train=True, download=True, transform=transform_train)\n", 62 | " trainloader = torch.utils.data.DataLoader(trainset, batch_size=256, shuffle=True, num_workers=8)\n", 63 | " testset = torchvision.datasets.CIFAR10(root=root, train=False, download=True, transform=transform_test)\n", 64 | " testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=8)\n", 65 | " \n", 66 | "elif data == 'cifar100':\n", 67 | " nclass = 100\n", 68 | " img_width = 32\n", 69 | " transform_train = transforms.Compose([\n", 70 | "# transforms.RandomCrop(32, padding=4),\n", 71 | "# transforms.RandomHorizontalFlip(),\n", 72 | " transforms.ToTensor(),\n", 73 | " ])\n", 74 | " transform_test = transforms.Compose([\n", 75 | " transforms.ToTensor(),\n", 76 | " ])\n", 77 | " trainset = torchvision.datasets.CIFAR100(root=root, train=True, download=True, transform=transform_train)\n", 78 | " 
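The actual implementation lives in `models/ProbAct.py`, which is not reproduced in this dump. As a rough orientation, here is a minimal sketch of the activation the README describes; the class name, the single shared sigma, and the eval-time choice of returning just the ReLU mean are assumptions of this sketch rather than details taken from the paper:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ProbActSketch(nn.Module):
    """Illustrative ProbAct: y = ReLU(x) + sigma * eps, with eps ~ N(0, 1)."""

    def __init__(self, sigma=1.0, trainable=False):
        super().__init__()
        if trainable:
            # one shared sigma, learned by back-propagation
            self.sigma = nn.Parameter(torch.tensor(float(sigma)))
        else:
            self.register_buffer('sigma', torch.tensor(float(sigma)))

    def forward(self, x):
        mean = F.relu(x)
        if self.training:
            # reparameterized sample keeps the graph differentiable w.r.t. sigma
            return mean + self.sigma * torch.randn_like(mean)
        return mean  # assumed deterministic at eval time
```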
trainloader = torch.utils.data.DataLoader(trainset, batch_size=256, shuffle=True, num_workers=8)\n", 79 | " testset = torchvision.datasets.CIFAR100(root=root, train=False, download=True, transform=transform_test)\n", 80 | " testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=8)\n", 81 | " \n", 82 | "elif data == 'stl10':\n", 83 | " nclass = 10\n", 84 | " img_width = 32\n", 85 | " transform_train = transforms.Compose([\n", 86 | "# transforms.RandomCrop(32, padding=4),\n", 87 | "# transforms.RandomHorizontalFlip(),\n", 88 | " transforms.Resize((img_width,img_width)),\n", 89 | " transforms.ToTensor(),\n", 90 | " ])\n", 91 | " transform_test = transforms.Compose([\n", 92 | " transforms.Resize((img_width,img_width)),\n", 93 | " transforms.ToTensor(),\n", 94 | " ])\n", 95 | " trainset = torchvision.datasets.STL10(root=root, split='train', transform=transform_train, target_transform=None, download=True)\n", 96 | " testset = torchvision.datasets.STL10(root=root, split='test', transform=transform_test, target_transform=None, download=True)\n", 97 | " trainloader = torch.utils.data.DataLoader(trainset, batch_size=256, shuffle=True, num_workers=8)\n", 98 | " testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=8)" 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": 4, 104 | "metadata": {}, 105 | "outputs": [ 106 | { 107 | "data": { 108 | "text/plain": [ 109 | "DataParallel(\n", 110 | " (module): VGG_PReLU(\n", 111 | " (features): Sequential(\n", 112 | " (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 113 | " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 114 | " (2): PReLU(num_parameters=1)\n", 115 | " (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 116 | " (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 117 | " (5): PReLU(num_parameters=1)\n", 118 | " (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 119 | " (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 120 | " (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 121 | " (9): PReLU(num_parameters=1)\n", 122 | " (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 123 | " (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 124 | " (12): PReLU(num_parameters=1)\n", 125 | " (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 126 | " (14): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 127 | " (15): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 128 | " (16): PReLU(num_parameters=1)\n", 129 | " (17): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 130 | " (18): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 131 | " (19): PReLU(num_parameters=1)\n", 132 | " (20): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 133 | " (21): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 134 | " (22): PReLU(num_parameters=1)\n", 135 | " (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 136 | " (24): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 137 | " (25): BatchNorm2d(512, eps=1e-05, 
momentum=0.1, affine=True, track_running_stats=True)\n", 138 | " (26): PReLU(num_parameters=1)\n", 139 | " (27): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 140 | " (28): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 141 | " (29): PReLU(num_parameters=1)\n", 142 | " (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 143 | " (31): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 144 | " (32): PReLU(num_parameters=1)\n", 145 | " (33): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 146 | " (34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 147 | " (35): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 148 | " (36): PReLU(num_parameters=1)\n", 149 | " (37): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 150 | " (38): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 151 | " (39): PReLU(num_parameters=1)\n", 152 | " (40): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 153 | " (41): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 154 | " (42): PReLU(num_parameters=1)\n", 155 | " (43): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 156 | " (44): AvgPool2d(kernel_size=1, stride=1, padding=0)\n", 157 | " )\n", 158 | " (classifier): Linear(in_features=512, out_features=10, bias=True)\n", 159 | " )\n", 160 | ")" 161 | ] 162 | }, 163 | "execution_count": 4, 164 | "metadata": {}, 165 | "output_type": "execute_result" 166 | } 167 | ], 168 | "source": [ 169 | "if model == 'vgg':\n", 170 | " from models.vgg import VGG_PReLU\n", 171 | " net = nn.DataParallel(VGG_PReLU('VGG16', nclass, img_width=img_width).cuda())\n", 172 | " \n", 173 | "net" 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": 5, 179 | "metadata": {}, 180 | "outputs": [], 181 | "source": [ 182 | "if resume:\n", 183 | " print(f'==> Resuming from {model_out}')\n", 184 | " net.load_state_dict(torch.load(model_out))" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": 6, 190 | "metadata": {}, 191 | "outputs": [], 192 | "source": [ 193 | "cudnn.benchmark = True" 194 | ] 195 | }, 196 | { 197 | "cell_type": "code", 198 | "execution_count": 7, 199 | "metadata": {}, 200 | "outputs": [], 201 | "source": [ 202 | "criterion = nn.CrossEntropyLoss()" 203 | ] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "execution_count": 8, 208 | "metadata": {}, 209 | "outputs": [], 210 | "source": [ 211 | "def train(epoch):\n", 212 | " print('Epoch: %d' % epoch)\n", 213 | " net.train()\n", 214 | " train_loss = 0\n", 215 | " correct = 0\n", 216 | " total = 0\n", 217 | " for batch_idx, (inputs, targets) in enumerate(trainloader):\n", 218 | " inputs, targets = inputs.cuda(), targets.cuda()\n", 219 | " optimizer.zero_grad()\n", 220 | " outputs, _ = net(inputs)\n", 221 | " loss = criterion(outputs, targets)\n", 222 | " loss.backward()\n", 223 | " optimizer.step()\n", 224 | " pred = torch.max(outputs, dim=1)[1]\n", 225 | " correct += torch.sum(pred.eq(targets)).item()\n", 226 | " total += targets.numel()\n", 227 | " print(f'[TRAIN] Acc: {100.*correct/total:.3f}')" 228 | ] 229 | }, 230 | { 231 | "cell_type": "code", 232 | "execution_count": 9, 233 | "metadata": {}, 234 | "outputs": [], 235 | "source": [ 236 | "def test(epoch):\n", 237 | " 
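A note on the activation printed above: `PReLU(num_parameters=1)` (PyTorch's default) learns a single negative slope shared by the whole layer. The layer also supports one learnable slope per channel when `num_parameters` matches the preceding convolution's output channels, e.g.:

```python
import torch.nn as nn

act_shared  = nn.PReLU()                   # one learnable slope for the entire layer
act_channel = nn.PReLU(num_parameters=64)  # one slope per channel, e.g. after Conv2d(3, 64, ...)
```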
net.eval()\n", 238 | " test_loss = 0\n", 239 | " correct = 0\n", 240 | " total = 0\n", 241 | " with torch.no_grad():\n", 242 | " for batch_idx, (inputs, targets) in enumerate(testloader):\n", 243 | " inputs, targets = inputs.cuda(), targets.cuda()\n", 244 | " outputs, _ = net(inputs)\n", 245 | " loss = criterion(outputs, targets)\n", 246 | " test_loss += loss.item()\n", 247 | " _, predicted = outputs.max(1)\n", 248 | " total += targets.size(0)\n", 249 | " correct += predicted.eq(targets).sum().item()\n", 250 | " print(f'[TEST] Acc: {100.*correct/total:.3f}')\n", 251 | "\n", 252 | " # Save checkpoint after each epoch\n", 253 | " torch.save(net.state_dict(), model_out)" 254 | ] 255 | }, 256 | { 257 | "cell_type": "code", 258 | "execution_count": 10, 259 | "metadata": {}, 260 | "outputs": [], 261 | "source": [ 262 | "if data == 'cifar10':\n", 263 | " epochs = [1]\n", 264 | "elif data == 'cifar100':\n", 265 | " epochs = [50, 50, 50, 50]\n", 266 | "elif data == 'stl10':\n", 267 | " epochs = [50, 50, 50, 50]" 268 | ] 269 | }, 270 | { 271 | "cell_type": "code", 272 | "execution_count": 11, 273 | "metadata": {}, 274 | "outputs": [], 275 | "source": [ 276 | "count = 0" 277 | ] 278 | }, 279 | { 280 | "cell_type": "code", 281 | "execution_count": 13, 282 | "metadata": { 283 | "scrolled": true 284 | }, 285 | "outputs": [ 286 | { 287 | "name": "stdout", 288 | "output_type": "stream", 289 | "text": [ 290 | "Epoch: 1\n", 291 | "[TRAIN] Acc: 42.098\n", 292 | "CPU times: user 10.1 s, sys: 3.65 s, total: 13.8 s\n", 293 | "Wall time: 13.9 s\n", 294 | "[TEST] Acc: 46.730\n", 295 | "CPU times: user 727 ms, sys: 490 ms, total: 1.22 s\n", 296 | "Wall time: 1.33 s\n" 297 | ] 298 | } 299 | ], 300 | "source": [ 301 | "for epoch in epochs:\n", 302 | " optimizer = Adam(net.parameters(), lr=lr)\n", 303 | " for _ in range(epoch):\n", 304 | " %time train(count)\n", 305 | " %time test(count)\n", 306 | " count += 1\n", 307 | " lr /= 10" 308 | ] 309 | }, 310 | { 311 | "cell_type": "code", 312 | "execution_count": 2, 313 | "metadata": {}, 314 | "outputs": [], 315 | "source": [ 316 | "lr=0.01\n", 317 | "data='stl10'\n", 318 | "root='./data/'\n", 319 | "model='vgg'\n", 320 | "model_out='./checkpoint/stl10_vgg_PReLU.pth'\n", 321 | "resume = False" 322 | ] 323 | }, 324 | { 325 | "cell_type": "code", 326 | "execution_count": 12, 327 | "metadata": {}, 328 | "outputs": [ 329 | { 330 | "name": "stdout", 331 | "output_type": "stream", 332 | "text": [ 333 | "Epoch: 0\n", 334 | "[TRAIN] Acc: 11.380\n", 335 | "[TEST] Acc: 13.988\n", 336 | "Epoch: 1\n", 337 | "[TRAIN] Acc: 13.000\n", 338 | "[TEST] Acc: 16.913\n", 339 | "Epoch: 2\n", 340 | "[TRAIN] Acc: 15.300\n", 341 | "[TEST] Acc: 16.788\n", 342 | "Epoch: 3\n", 343 | "[TRAIN] Acc: 17.840\n", 344 | "[TEST] Acc: 11.162\n", 345 | "Epoch: 4\n", 346 | "[TRAIN] Acc: 20.760\n", 347 | "[TEST] Acc: 19.288\n", 348 | "Epoch: 5\n", 349 | "[TRAIN] Acc: 24.000\n", 350 | "[TEST] Acc: 18.587\n", 351 | "Epoch: 6\n", 352 | "[TRAIN] Acc: 27.640\n", 353 | "[TEST] Acc: 24.637\n", 354 | "Epoch: 7\n", 355 | "[TRAIN] Acc: 30.240\n", 356 | "[TEST] Acc: 32.062\n", 357 | "Epoch: 8\n", 358 | "[TRAIN] Acc: 34.420\n", 359 | "[TEST] Acc: 32.138\n", 360 | "Epoch: 9\n", 361 | "[TRAIN] Acc: 34.820\n", 362 | "[TEST] Acc: 28.400\n", 363 | "Epoch: 10\n", 364 | "[TRAIN] Acc: 36.380\n", 365 | "[TEST] Acc: 34.675\n", 366 | "Epoch: 11\n", 367 | "[TRAIN] Acc: 38.660\n", 368 | "[TEST] Acc: 33.425\n", 369 | "Epoch: 12\n", 370 | "[TRAIN] Acc: 40.900\n", 371 | "[TEST] Acc: 40.638\n", 372 | "Epoch: 13\n", 373 | "[TRAIN] Acc: 
43.620\n", 374 | "[TEST] Acc: 34.862\n", 375 | "Epoch: 14\n", 376 | "[TRAIN] Acc: 47.300\n", 377 | "[TEST] Acc: 35.312\n", 378 | "Epoch: 15\n", 379 | "[TRAIN] Acc: 48.200\n", 380 | "[TEST] Acc: 30.962\n", 381 | "Epoch: 16\n", 382 | "[TRAIN] Acc: 51.420\n", 383 | "[TEST] Acc: 33.225\n", 384 | "Epoch: 17\n", 385 | "[TRAIN] Acc: 54.340\n", 386 | "[TEST] Acc: 31.212\n", 387 | "Epoch: 18\n", 388 | "[TRAIN] Acc: 57.280\n", 389 | "[TEST] Acc: 44.125\n", 390 | "Epoch: 19\n", 391 | "[TRAIN] Acc: 61.300\n", 392 | "[TEST] Acc: 45.462\n", 393 | "Epoch: 20\n", 394 | "[TRAIN] Acc: 64.320\n", 395 | "[TEST] Acc: 44.212\n", 396 | "Epoch: 21\n", 397 | "[TRAIN] Acc: 67.720\n", 398 | "[TEST] Acc: 38.525\n", 399 | "Epoch: 22\n", 400 | "[TRAIN] Acc: 70.620\n", 401 | "[TEST] Acc: 40.188\n", 402 | "Epoch: 23\n", 403 | "[TRAIN] Acc: 71.520\n", 404 | "[TEST] Acc: 46.038\n", 405 | "Epoch: 24\n", 406 | "[TRAIN] Acc: 76.200\n", 407 | "[TEST] Acc: 40.325\n", 408 | "Epoch: 25\n", 409 | "[TRAIN] Acc: 77.980\n", 410 | "[TEST] Acc: 34.525\n", 411 | "Epoch: 26\n", 412 | "[TRAIN] Acc: 81.640\n", 413 | "[TEST] Acc: 38.062\n", 414 | "Epoch: 27\n", 415 | "[TRAIN] Acc: 83.620\n", 416 | "[TEST] Acc: 41.462\n", 417 | "Epoch: 28\n", 418 | "[TRAIN] Acc: 86.100\n", 419 | "[TEST] Acc: 46.300\n", 420 | "Epoch: 29\n", 421 | "[TRAIN] Acc: 87.280\n", 422 | "[TEST] Acc: 40.175\n", 423 | "Epoch: 30\n", 424 | "[TRAIN] Acc: 87.500\n", 425 | "[TEST] Acc: 43.000\n", 426 | "Epoch: 31\n", 427 | "[TRAIN] Acc: 90.080\n", 428 | "[TEST] Acc: 48.250\n", 429 | "Epoch: 32\n", 430 | "[TRAIN] Acc: 91.620\n", 431 | "[TEST] Acc: 35.550\n", 432 | "Epoch: 33\n", 433 | "[TRAIN] Acc: 91.680\n", 434 | "[TEST] Acc: 50.925\n", 435 | "Epoch: 34\n", 436 | "[TRAIN] Acc: 92.880\n", 437 | "[TEST] Acc: 52.812\n", 438 | "Epoch: 35\n", 439 | "[TRAIN] Acc: 95.620\n", 440 | "[TEST] Acc: 44.737\n", 441 | "Epoch: 36\n", 442 | "[TRAIN] Acc: 94.460\n", 443 | "[TEST] Acc: 46.575\n", 444 | "Epoch: 37\n", 445 | "[TRAIN] Acc: 94.480\n", 446 | "[TEST] Acc: 46.125\n", 447 | "Epoch: 38\n", 448 | "[TRAIN] Acc: 95.700\n", 449 | "[TEST] Acc: 46.000\n", 450 | "Epoch: 39\n", 451 | "[TRAIN] Acc: 95.600\n", 452 | "[TEST] Acc: 41.700\n", 453 | "Epoch: 40\n", 454 | "[TRAIN] Acc: 95.900\n", 455 | "[TEST] Acc: 38.600\n", 456 | "Epoch: 41\n", 457 | "[TRAIN] Acc: 93.780\n", 458 | "[TEST] Acc: 35.825\n", 459 | "Epoch: 42\n", 460 | "[TRAIN] Acc: 95.080\n", 461 | "[TEST] Acc: 48.775\n", 462 | "Epoch: 43\n", 463 | "[TRAIN] Acc: 97.460\n", 464 | "[TEST] Acc: 47.650\n", 465 | "Epoch: 44\n", 466 | "[TRAIN] Acc: 97.600\n", 467 | "[TEST] Acc: 42.500\n", 468 | "Epoch: 45\n", 469 | "[TRAIN] Acc: 97.220\n", 470 | "[TEST] Acc: 50.987\n", 471 | "Epoch: 46\n", 472 | "[TRAIN] Acc: 97.240\n", 473 | "[TEST] Acc: 51.788\n", 474 | "Epoch: 47\n", 475 | "[TRAIN] Acc: 97.180\n", 476 | "[TEST] Acc: 46.100\n", 477 | "Epoch: 48\n", 478 | "[TRAIN] Acc: 97.660\n", 479 | "[TEST] Acc: 50.700\n", 480 | "Epoch: 49\n", 481 | "[TRAIN] Acc: 97.760\n", 482 | "[TEST] Acc: 50.413\n", 483 | "Epoch: 50\n", 484 | "[TRAIN] Acc: 99.080\n", 485 | "[TEST] Acc: 57.138\n", 486 | "Epoch: 51\n", 487 | "[TRAIN] Acc: 99.780\n", 488 | "[TEST] Acc: 59.750\n", 489 | "Epoch: 52\n", 490 | "[TRAIN] Acc: 99.960\n", 491 | "[TEST] Acc: 59.350\n", 492 | "Epoch: 53\n", 493 | "[TRAIN] Acc: 99.940\n", 494 | "[TEST] Acc: 59.612\n", 495 | "Epoch: 54\n", 496 | "[TRAIN] Acc: 100.000\n", 497 | "[TEST] Acc: 59.562\n", 498 | "Epoch: 55\n", 499 | "[TRAIN] Acc: 100.000\n", 500 | "[TEST] Acc: 59.750\n", 501 | "Epoch: 56\n", 502 | "[TRAIN] Acc: 100.000\n", 503 | "[TEST] 
Acc: 59.513\n", 504 | "Epoch: 57\n", 505 | "[TRAIN] Acc: 100.000\n", 506 | "[TEST] Acc: 59.850\n", 507 | "Epoch: 58\n", 508 | "[TRAIN] Acc: 100.000\n", 509 | "[TEST] Acc: 59.750\n", 510 | "Epoch: 59\n", 511 | "[TRAIN] Acc: 99.980\n", 512 | "[TEST] Acc: 60.013\n", 513 | "Epoch: 60\n", 514 | "[TRAIN] Acc: 99.940\n", 515 | "[TEST] Acc: 59.125\n", 516 | "Epoch: 61\n", 517 | "[TRAIN] Acc: 99.960\n", 518 | "[TEST] Acc: 58.700\n", 519 | "Epoch: 62\n", 520 | "[TRAIN] Acc: 99.980\n", 521 | "[TEST] Acc: 59.475\n", 522 | "Epoch: 63\n", 523 | "[TRAIN] Acc: 99.960\n", 524 | "[TEST] Acc: 58.700\n", 525 | "Epoch: 64\n", 526 | "[TRAIN] Acc: 99.900\n", 527 | "[TEST] Acc: 58.562\n", 528 | "Epoch: 65\n", 529 | "[TRAIN] Acc: 99.780\n", 530 | "[TEST] Acc: 58.638\n", 531 | "Epoch: 66\n", 532 | "[TRAIN] Acc: 99.920\n", 533 | "[TEST] Acc: 59.050\n", 534 | "Epoch: 67\n", 535 | "[TRAIN] Acc: 99.920\n", 536 | "[TEST] Acc: 58.700\n", 537 | "Epoch: 68\n", 538 | "[TRAIN] Acc: 99.840\n", 539 | "[TEST] Acc: 59.600\n", 540 | "Epoch: 69\n", 541 | "[TRAIN] Acc: 99.980\n", 542 | "[TEST] Acc: 58.725\n", 543 | "Epoch: 70\n", 544 | "[TRAIN] Acc: 99.880\n", 545 | "[TEST] Acc: 58.612\n", 546 | "Epoch: 71\n", 547 | "[TRAIN] Acc: 99.940\n", 548 | "[TEST] Acc: 58.462\n", 549 | "Epoch: 72\n", 550 | "[TRAIN] Acc: 100.000\n", 551 | "[TEST] Acc: 58.975\n", 552 | "Epoch: 73\n", 553 | "[TRAIN] Acc: 99.960\n", 554 | "[TEST] Acc: 58.500\n", 555 | "Epoch: 74\n", 556 | "[TRAIN] Acc: 99.960\n", 557 | "[TEST] Acc: 59.275\n", 558 | "Epoch: 75\n", 559 | "[TRAIN] Acc: 99.980\n", 560 | "[TEST] Acc: 58.850\n", 561 | "Epoch: 76\n", 562 | "[TRAIN] Acc: 99.980\n", 563 | "[TEST] Acc: 58.538\n", 564 | "Epoch: 77\n", 565 | "[TRAIN] Acc: 99.980\n", 566 | "[TEST] Acc: 59.050\n", 567 | "Epoch: 78\n", 568 | "[TRAIN] Acc: 99.960\n", 569 | "[TEST] Acc: 58.775\n", 570 | "Epoch: 79\n", 571 | "[TRAIN] Acc: 100.000\n", 572 | "[TEST] Acc: 58.750\n", 573 | "Epoch: 80\n", 574 | "[TRAIN] Acc: 100.000\n", 575 | "[TEST] Acc: 59.388\n", 576 | "Epoch: 81\n", 577 | "[TRAIN] Acc: 100.000\n", 578 | "[TEST] Acc: 59.362\n", 579 | "Epoch: 82\n", 580 | "[TRAIN] Acc: 100.000\n", 581 | "[TEST] Acc: 59.337\n", 582 | "Epoch: 83\n", 583 | "[TRAIN] Acc: 100.000\n", 584 | "[TEST] Acc: 59.263\n", 585 | "Epoch: 84\n", 586 | "[TRAIN] Acc: 100.000\n", 587 | "[TEST] Acc: 59.188\n", 588 | "Epoch: 85\n", 589 | "[TRAIN] Acc: 100.000\n", 590 | "[TEST] Acc: 59.325\n", 591 | "Epoch: 86\n", 592 | "[TRAIN] Acc: 100.000\n", 593 | "[TEST] Acc: 59.337\n", 594 | "Epoch: 87\n", 595 | "[TRAIN] Acc: 99.960\n", 596 | "[TEST] Acc: 55.300\n", 597 | "Epoch: 88\n", 598 | "[TRAIN] Acc: 99.960\n", 599 | "[TEST] Acc: 58.487\n", 600 | "Epoch: 89\n", 601 | "[TRAIN] Acc: 99.980\n", 602 | "[TEST] Acc: 57.862\n", 603 | "Epoch: 90\n", 604 | "[TRAIN] Acc: 99.740\n", 605 | "[TEST] Acc: 56.938\n", 606 | "Epoch: 91\n", 607 | "[TRAIN] Acc: 99.620\n", 608 | "[TEST] Acc: 56.275\n", 609 | "Epoch: 92\n", 610 | "[TRAIN] Acc: 99.780\n", 611 | "[TEST] Acc: 55.587\n", 612 | "Epoch: 93\n", 613 | "[TRAIN] Acc: 99.860\n", 614 | "[TEST] Acc: 58.562\n", 615 | "Epoch: 94\n", 616 | "[TRAIN] Acc: 99.920\n", 617 | "[TEST] Acc: 58.362\n", 618 | "Epoch: 95\n", 619 | "[TRAIN] Acc: 99.920\n", 620 | "[TEST] Acc: 58.788\n", 621 | "Epoch: 96\n", 622 | "[TRAIN] Acc: 99.980\n", 623 | "[TEST] Acc: 58.513\n", 624 | "Epoch: 97\n", 625 | "[TRAIN] Acc: 100.000\n", 626 | "[TEST] Acc: 59.525\n", 627 | "Epoch: 98\n", 628 | "[TRAIN] Acc: 100.000\n", 629 | "[TEST] Acc: 59.375\n", 630 | "Epoch: 99\n", 631 | "[TRAIN] Acc: 100.000\n", 632 | "[TEST] Acc: 
59.650\n", 633 | "Epoch: 100\n", 634 | "[TRAIN] Acc: 100.000\n", 635 | "[TEST] Acc: 59.575\n", 636 | "Epoch: 101\n", 637 | "[TRAIN] Acc: 100.000\n", 638 | "[TEST] Acc: 59.425\n", 639 | "Epoch: 102\n", 640 | "[TRAIN] Acc: 100.000\n", 641 | "[TEST] Acc: 59.638\n", 642 | "Epoch: 103\n", 643 | "[TRAIN] Acc: 100.000\n", 644 | "[TEST] Acc: 59.413\n", 645 | "Epoch: 104\n", 646 | "[TRAIN] Acc: 100.000\n", 647 | "[TEST] Acc: 59.587\n", 648 | "Epoch: 105\n", 649 | "[TRAIN] Acc: 100.000\n", 650 | "[TEST] Acc: 59.575\n", 651 | "Epoch: 106\n", 652 | "[TRAIN] Acc: 100.000\n", 653 | "[TEST] Acc: 59.538\n", 654 | "Epoch: 107\n", 655 | "[TRAIN] Acc: 100.000\n", 656 | "[TEST] Acc: 59.312\n", 657 | "Epoch: 108\n", 658 | "[TRAIN] Acc: 100.000\n", 659 | "[TEST] Acc: 59.288\n", 660 | "Epoch: 109\n", 661 | "[TRAIN] Acc: 100.000\n", 662 | "[TEST] Acc: 59.462\n", 663 | "Epoch: 110\n", 664 | "[TRAIN] Acc: 100.000\n", 665 | "[TEST] Acc: 59.550\n", 666 | "Epoch: 111\n", 667 | "[TRAIN] Acc: 100.000\n", 668 | "[TEST] Acc: 59.663\n", 669 | "Epoch: 112\n", 670 | "[TRAIN] Acc: 100.000\n", 671 | "[TEST] Acc: 59.525\n", 672 | "Epoch: 113\n", 673 | "[TRAIN] Acc: 100.000\n", 674 | "[TEST] Acc: 59.600\n", 675 | "Epoch: 114\n", 676 | "[TRAIN] Acc: 100.000\n", 677 | "[TEST] Acc: 59.688\n", 678 | "Epoch: 115\n", 679 | "[TRAIN] Acc: 100.000\n", 680 | "[TEST] Acc: 59.625\n", 681 | "Epoch: 116\n", 682 | "[TRAIN] Acc: 100.000\n", 683 | "[TEST] Acc: 59.712\n", 684 | "Epoch: 117\n", 685 | "[TRAIN] Acc: 100.000\n", 686 | "[TEST] Acc: 59.725\n", 687 | "Epoch: 118\n", 688 | "[TRAIN] Acc: 100.000\n", 689 | "[TEST] Acc: 59.875\n", 690 | "Epoch: 119\n", 691 | "[TRAIN] Acc: 100.000\n", 692 | "[TEST] Acc: 59.600\n", 693 | "Epoch: 120\n", 694 | "[TRAIN] Acc: 100.000\n", 695 | "[TEST] Acc: 59.763\n", 696 | "Epoch: 121\n", 697 | "[TRAIN] Acc: 100.000\n", 698 | "[TEST] Acc: 59.650\n", 699 | "Epoch: 122\n", 700 | "[TRAIN] Acc: 100.000\n", 701 | "[TEST] Acc: 59.625\n", 702 | "Epoch: 123\n", 703 | "[TRAIN] Acc: 100.000\n", 704 | "[TEST] Acc: 59.525\n", 705 | "Epoch: 124\n", 706 | "[TRAIN] Acc: 100.000\n", 707 | "[TEST] Acc: 59.737\n", 708 | "Epoch: 125\n", 709 | "[TRAIN] Acc: 99.980\n", 710 | "[TEST] Acc: 59.788\n", 711 | "Epoch: 126\n", 712 | "[TRAIN] Acc: 100.000\n", 713 | "[TEST] Acc: 59.800\n", 714 | "Epoch: 127\n", 715 | "[TRAIN] Acc: 100.000\n", 716 | "[TEST] Acc: 59.712\n", 717 | "Epoch: 128\n", 718 | "[TRAIN] Acc: 100.000\n", 719 | "[TEST] Acc: 59.550\n", 720 | "Epoch: 129\n", 721 | "[TRAIN] Acc: 100.000\n", 722 | "[TEST] Acc: 59.625\n", 723 | "Epoch: 130\n", 724 | "[TRAIN] Acc: 100.000\n", 725 | "[TEST] Acc: 59.750\n", 726 | "Epoch: 131\n", 727 | "[TRAIN] Acc: 100.000\n", 728 | "[TEST] Acc: 59.850\n", 729 | "Epoch: 132\n", 730 | "[TRAIN] Acc: 100.000\n", 731 | "[TEST] Acc: 59.950\n", 732 | "Epoch: 133\n", 733 | "[TRAIN] Acc: 100.000\n", 734 | "[TEST] Acc: 59.837\n", 735 | "Epoch: 134\n", 736 | "[TRAIN] Acc: 100.000\n", 737 | "[TEST] Acc: 59.900\n", 738 | "Epoch: 135\n", 739 | "[TRAIN] Acc: 100.000\n", 740 | "[TEST] Acc: 59.938\n", 741 | "Epoch: 136\n", 742 | "[TRAIN] Acc: 100.000\n", 743 | "[TEST] Acc: 59.750\n", 744 | "Epoch: 137\n", 745 | "[TRAIN] Acc: 100.000\n", 746 | "[TEST] Acc: 59.862\n", 747 | "Epoch: 138\n", 748 | "[TRAIN] Acc: 100.000\n", 749 | "[TEST] Acc: 59.875\n", 750 | "Epoch: 139\n", 751 | "[TRAIN] Acc: 100.000\n", 752 | "[TEST] Acc: 59.763\n", 753 | "Epoch: 140\n", 754 | "[TRAIN] Acc: 100.000\n", 755 | "[TEST] Acc: 59.737\n", 756 | "Epoch: 141\n", 757 | "[TRAIN] Acc: 100.000\n", 758 | "[TEST] Acc: 59.825\n", 759 | "Epoch: 
142\n", 760 | "[TRAIN] Acc: 100.000\n", 761 | "[TEST] Acc: 59.850\n", 762 | "Epoch: 143\n", 763 | "[TRAIN] Acc: 100.000\n", 764 | "[TEST] Acc: 59.775\n", 765 | "Epoch: 144\n", 766 | "[TRAIN] Acc: 100.000\n", 767 | "[TEST] Acc: 59.688\n", 768 | "Epoch: 145\n", 769 | "[TRAIN] Acc: 100.000\n", 770 | "[TEST] Acc: 59.812\n", 771 | "Epoch: 146\n", 772 | "[TRAIN] Acc: 100.000\n", 773 | "[TEST] Acc: 59.800\n", 774 | "Epoch: 147\n", 775 | "[TRAIN] Acc: 100.000\n", 776 | "[TEST] Acc: 59.750\n", 777 | "Epoch: 148\n", 778 | "[TRAIN] Acc: 100.000\n", 779 | "[TEST] Acc: 59.725\n", 780 | "Epoch: 149\n", 781 | "[TRAIN] Acc: 100.000\n", 782 | "[TEST] Acc: 59.725\n", 783 | "Epoch: 150\n", 784 | "[TRAIN] Acc: 100.000\n", 785 | "[TEST] Acc: 59.788\n", 786 | "Epoch: 151\n", 787 | "[TRAIN] Acc: 100.000\n", 788 | "[TEST] Acc: 59.825\n", 789 | "Epoch: 152\n", 790 | "[TRAIN] Acc: 100.000\n", 791 | "[TEST] Acc: 59.725\n", 792 | "Epoch: 153\n", 793 | "[TRAIN] Acc: 100.000\n", 794 | "[TEST] Acc: 59.750\n", 795 | "Epoch: 154\n", 796 | "[TRAIN] Acc: 100.000\n", 797 | "[TEST] Acc: 59.837\n", 798 | "Epoch: 155\n", 799 | "[TRAIN] Acc: 100.000\n", 800 | "[TEST] Acc: 59.675\n", 801 | "Epoch: 156\n", 802 | "[TRAIN] Acc: 100.000\n", 803 | "[TEST] Acc: 59.788\n", 804 | "Epoch: 157\n", 805 | "[TRAIN] Acc: 100.000\n", 806 | "[TEST] Acc: 59.812\n", 807 | "Epoch: 158\n", 808 | "[TRAIN] Acc: 100.000\n", 809 | "[TEST] Acc: 59.775\n", 810 | "Epoch: 159\n", 811 | "[TRAIN] Acc: 100.000\n", 812 | "[TEST] Acc: 59.763\n", 813 | "Epoch: 160\n", 814 | "[TRAIN] Acc: 100.000\n", 815 | "[TEST] Acc: 59.763\n", 816 | "Epoch: 161\n", 817 | "[TRAIN] Acc: 100.000\n", 818 | "[TEST] Acc: 59.750\n", 819 | "Epoch: 162\n", 820 | "[TRAIN] Acc: 100.000\n", 821 | "[TEST] Acc: 59.750\n", 822 | "Epoch: 163\n", 823 | "[TRAIN] Acc: 99.960\n", 824 | "[TEST] Acc: 59.750\n", 825 | "Epoch: 164\n", 826 | "[TRAIN] Acc: 100.000\n" 827 | ] 828 | }, 829 | { 830 | "name": "stdout", 831 | "output_type": "stream", 832 | "text": [ 833 | "[TEST] Acc: 59.837\n", 834 | "Epoch: 165\n", 835 | "[TRAIN] Acc: 100.000\n", 836 | "[TEST] Acc: 59.913\n", 837 | "Epoch: 166\n", 838 | "[TRAIN] Acc: 100.000\n", 839 | "[TEST] Acc: 59.975\n", 840 | "Epoch: 167\n", 841 | "[TRAIN] Acc: 100.000\n", 842 | "[TEST] Acc: 60.050\n", 843 | "Epoch: 168\n", 844 | "[TRAIN] Acc: 100.000\n", 845 | "[TEST] Acc: 59.900\n", 846 | "Epoch: 169\n", 847 | "[TRAIN] Acc: 100.000\n", 848 | "[TEST] Acc: 59.875\n", 849 | "Epoch: 170\n", 850 | "[TRAIN] Acc: 100.000\n", 851 | "[TEST] Acc: 59.712\n", 852 | "Epoch: 171\n", 853 | "[TRAIN] Acc: 100.000\n", 854 | "[TEST] Acc: 59.763\n", 855 | "Epoch: 172\n", 856 | "[TRAIN] Acc: 100.000\n", 857 | "[TEST] Acc: 59.775\n", 858 | "Epoch: 173\n", 859 | "[TRAIN] Acc: 100.000\n", 860 | "[TEST] Acc: 59.825\n", 861 | "Epoch: 174\n", 862 | "[TRAIN] Acc: 100.000\n", 863 | "[TEST] Acc: 59.888\n", 864 | "Epoch: 175\n", 865 | "[TRAIN] Acc: 100.000\n", 866 | "[TEST] Acc: 59.888\n", 867 | "Epoch: 176\n", 868 | "[TRAIN] Acc: 100.000\n", 869 | "[TEST] Acc: 59.862\n", 870 | "Epoch: 177\n", 871 | "[TRAIN] Acc: 100.000\n", 872 | "[TEST] Acc: 60.000\n", 873 | "Epoch: 178\n", 874 | "[TRAIN] Acc: 100.000\n", 875 | "[TEST] Acc: 59.888\n", 876 | "Epoch: 179\n", 877 | "[TRAIN] Acc: 100.000\n", 878 | "[TEST] Acc: 59.788\n", 879 | "Epoch: 180\n", 880 | "[TRAIN] Acc: 100.000\n", 881 | "[TEST] Acc: 59.750\n", 882 | "Epoch: 181\n", 883 | "[TRAIN] Acc: 100.000\n", 884 | "[TEST] Acc: 59.825\n", 885 | "Epoch: 182\n", 886 | "[TRAIN] Acc: 100.000\n", 887 | "[TEST] Acc: 59.950\n", 888 | "Epoch: 183\n", 889 | 
"[TRAIN] Acc: 100.000\n", 890 | "[TEST] Acc: 59.975\n", 891 | "Epoch: 184\n", 892 | "[TRAIN] Acc: 100.000\n", 893 | "[TEST] Acc: 59.788\n", 894 | "Epoch: 185\n", 895 | "[TRAIN] Acc: 100.000\n", 896 | "[TEST] Acc: 59.675\n", 897 | "Epoch: 186\n", 898 | "[TRAIN] Acc: 100.000\n", 899 | "[TEST] Acc: 59.837\n", 900 | "Epoch: 187\n", 901 | "[TRAIN] Acc: 100.000\n", 902 | "[TEST] Acc: 59.862\n", 903 | "Epoch: 188\n", 904 | "[TRAIN] Acc: 100.000\n", 905 | "[TEST] Acc: 59.763\n", 906 | "Epoch: 189\n", 907 | "[TRAIN] Acc: 100.000\n", 908 | "[TEST] Acc: 59.800\n", 909 | "Epoch: 190\n", 910 | "[TRAIN] Acc: 100.000\n", 911 | "[TEST] Acc: 59.925\n", 912 | "Epoch: 191\n", 913 | "[TRAIN] Acc: 100.000\n", 914 | "[TEST] Acc: 59.900\n", 915 | "Epoch: 192\n", 916 | "[TRAIN] Acc: 100.000\n", 917 | "[TEST] Acc: 59.688\n", 918 | "Epoch: 193\n", 919 | "[TRAIN] Acc: 100.000\n", 920 | "[TEST] Acc: 59.775\n", 921 | "Epoch: 194\n", 922 | "[TRAIN] Acc: 100.000\n", 923 | "[TEST] Acc: 59.850\n", 924 | "Epoch: 195\n", 925 | "[TRAIN] Acc: 100.000\n", 926 | "[TEST] Acc: 59.925\n", 927 | "Epoch: 196\n", 928 | "[TRAIN] Acc: 100.000\n", 929 | "[TEST] Acc: 59.925\n", 930 | "Epoch: 197\n", 931 | "[TRAIN] Acc: 100.000\n", 932 | "[TEST] Acc: 59.862\n", 933 | "Epoch: 198\n", 934 | "[TRAIN] Acc: 100.000\n", 935 | "[TEST] Acc: 59.837\n", 936 | "Epoch: 199\n", 937 | "[TRAIN] Acc: 100.000\n", 938 | "[TEST] Acc: 59.825\n" 939 | ] 940 | } 941 | ], 942 | "source": [ 943 | "for epoch in epochs:\n", 944 | " optimizer = Adam(net.parameters(), lr=lr)\n", 945 | " for _ in range(epoch):\n", 946 | " train(count)\n", 947 | " test(count)\n", 948 | " count += 1\n", 949 | " lr /= 10" 950 | ] 951 | }, 952 | { 953 | "cell_type": "code", 954 | "execution_count": 2, 955 | "metadata": {}, 956 | "outputs": [], 957 | "source": [ 958 | "lr=0.01\n", 959 | "data='cifar100'\n", 960 | "root='./data/'\n", 961 | "model='vgg'\n", 962 | "model_out='./checkpoint/cifar100_vgg_PReLU.pth'\n", 963 | "resume = False" 964 | ] 965 | }, 966 | { 967 | "cell_type": "code", 968 | "execution_count": 13, 969 | "metadata": {}, 970 | "outputs": [ 971 | { 972 | "name": "stdout", 973 | "output_type": "stream", 974 | "text": [ 975 | "Epoch: 0\n", 976 | "[TRAIN] Acc: 3.426\n", 977 | "[TEST] Acc: 3.960\n", 978 | "Epoch: 1\n", 979 | "[TRAIN] Acc: 5.782\n", 980 | "[TEST] Acc: 6.310\n", 981 | "Epoch: 2\n", 982 | "[TRAIN] Acc: 9.578\n", 983 | "[TEST] Acc: 6.440\n", 984 | "Epoch: 3\n", 985 | "[TRAIN] Acc: 13.704\n", 986 | "[TEST] Acc: 12.660\n", 987 | "Epoch: 4\n", 988 | "[TRAIN] Acc: 18.472\n", 989 | "[TEST] Acc: 15.680\n", 990 | "Epoch: 5\n", 991 | "[TRAIN] Acc: 22.984\n", 992 | "[TEST] Acc: 21.550\n", 993 | "Epoch: 6\n", 994 | "[TRAIN] Acc: 27.154\n", 995 | "[TEST] Acc: 20.270\n", 996 | "Epoch: 7\n", 997 | "[TRAIN] Acc: 30.962\n", 998 | "[TEST] Acc: 23.910\n", 999 | "Epoch: 8\n", 1000 | "[TRAIN] Acc: 35.440\n", 1001 | "[TEST] Acc: 18.690\n", 1002 | "Epoch: 9\n", 1003 | "[TRAIN] Acc: 39.474\n", 1004 | "[TEST] Acc: 30.100\n", 1005 | "Epoch: 10\n", 1006 | "[TRAIN] Acc: 43.768\n", 1007 | "[TEST] Acc: 27.830\n", 1008 | "Epoch: 11\n", 1009 | "[TRAIN] Acc: 47.416\n", 1010 | "[TEST] Acc: 29.520\n", 1011 | "Epoch: 12\n", 1012 | "[TRAIN] Acc: 51.422\n", 1013 | "[TEST] Acc: 32.580\n", 1014 | "Epoch: 13\n", 1015 | "[TRAIN] Acc: 55.230\n", 1016 | "[TEST] Acc: 32.330\n", 1017 | "Epoch: 14\n", 1018 | "[TRAIN] Acc: 59.860\n", 1019 | "[TEST] Acc: 32.610\n", 1020 | "Epoch: 15\n", 1021 | "[TRAIN] Acc: 62.998\n", 1022 | "[TEST] Acc: 34.090\n", 1023 | "Epoch: 16\n", 1024 | "[TRAIN] Acc: 67.052\n", 1025 | 
"[TEST] Acc: 35.710\n", 1026 | "Epoch: 17\n", 1027 | "[TRAIN] Acc: 70.554\n", 1028 | "[TEST] Acc: 30.890\n", 1029 | "Epoch: 18\n", 1030 | "[TRAIN] Acc: 73.552\n", 1031 | "[TEST] Acc: 33.040\n", 1032 | "Epoch: 19\n", 1033 | "[TRAIN] Acc: 76.548\n", 1034 | "[TEST] Acc: 34.830\n", 1035 | "Epoch: 20\n", 1036 | "[TRAIN] Acc: 79.182\n", 1037 | "[TEST] Acc: 35.210\n", 1038 | "Epoch: 21\n", 1039 | "[TRAIN] Acc: 81.338\n", 1040 | "[TEST] Acc: 33.840\n", 1041 | "Epoch: 22\n", 1042 | "[TRAIN] Acc: 83.018\n", 1043 | "[TEST] Acc: 34.420\n", 1044 | "Epoch: 23\n", 1045 | "[TRAIN] Acc: 85.204\n", 1046 | "[TEST] Acc: 34.960\n", 1047 | "Epoch: 24\n", 1048 | "[TRAIN] Acc: 86.512\n", 1049 | "[TEST] Acc: 35.720\n", 1050 | "Epoch: 25\n", 1051 | "[TRAIN] Acc: 87.698\n", 1052 | "[TEST] Acc: 35.840\n", 1053 | "Epoch: 26\n", 1054 | "[TRAIN] Acc: 88.442\n", 1055 | "[TEST] Acc: 36.420\n", 1056 | "Epoch: 27\n", 1057 | "[TRAIN] Acc: 89.740\n", 1058 | "[TEST] Acc: 36.600\n", 1059 | "Epoch: 28\n", 1060 | "[TRAIN] Acc: 90.728\n", 1061 | "[TEST] Acc: 36.020\n", 1062 | "Epoch: 29\n", 1063 | "[TRAIN] Acc: 91.216\n", 1064 | "[TEST] Acc: 35.760\n", 1065 | "Epoch: 30\n", 1066 | "[TRAIN] Acc: 91.280\n", 1067 | "[TEST] Acc: 35.850\n", 1068 | "Epoch: 31\n", 1069 | "[TRAIN] Acc: 92.200\n", 1070 | "[TEST] Acc: 35.210\n", 1071 | "Epoch: 32\n", 1072 | "[TRAIN] Acc: 92.666\n", 1073 | "[TEST] Acc: 35.670\n", 1074 | "Epoch: 33\n", 1075 | "[TRAIN] Acc: 93.156\n", 1076 | "[TEST] Acc: 37.090\n", 1077 | "Epoch: 34\n", 1078 | "[TRAIN] Acc: 92.074\n", 1079 | "[TEST] Acc: 36.980\n", 1080 | "Epoch: 35\n", 1081 | "[TRAIN] Acc: 93.802\n", 1082 | "[TEST] Acc: 36.650\n", 1083 | "Epoch: 36\n", 1084 | "[TRAIN] Acc: 93.864\n", 1085 | "[TEST] Acc: 36.960\n", 1086 | "Epoch: 37\n", 1087 | "[TRAIN] Acc: 93.334\n", 1088 | "[TEST] Acc: 36.090\n", 1089 | "Epoch: 38\n", 1090 | "[TRAIN] Acc: 94.326\n", 1091 | "[TEST] Acc: 37.340\n", 1092 | "Epoch: 39\n", 1093 | "[TRAIN] Acc: 93.526\n", 1094 | "[TEST] Acc: 38.060\n", 1095 | "Epoch: 40\n", 1096 | "[TRAIN] Acc: 93.810\n", 1097 | "[TEST] Acc: 38.000\n", 1098 | "Epoch: 41\n", 1099 | "[TRAIN] Acc: 95.286\n", 1100 | "[TEST] Acc: 37.170\n", 1101 | "Epoch: 42\n", 1102 | "[TRAIN] Acc: 95.218\n", 1103 | "[TEST] Acc: 36.570\n", 1104 | "Epoch: 43\n", 1105 | "[TRAIN] Acc: 94.766\n", 1106 | "[TEST] Acc: 34.950\n", 1107 | "Epoch: 44\n", 1108 | "[TRAIN] Acc: 95.224\n", 1109 | "[TEST] Acc: 35.060\n", 1110 | "Epoch: 45\n", 1111 | "[TRAIN] Acc: 94.754\n", 1112 | "[TEST] Acc: 38.200\n", 1113 | "Epoch: 46\n", 1114 | "[TRAIN] Acc: 95.072\n", 1115 | "[TEST] Acc: 37.820\n", 1116 | "Epoch: 47\n", 1117 | "[TRAIN] Acc: 95.924\n", 1118 | "[TEST] Acc: 36.260\n", 1119 | "Epoch: 48\n", 1120 | "[TRAIN] Acc: 95.234\n", 1121 | "[TEST] Acc: 37.580\n", 1122 | "Epoch: 49\n", 1123 | "[TRAIN] Acc: 95.684\n", 1124 | "[TEST] Acc: 37.680\n", 1125 | "Epoch: 50\n", 1126 | "[TRAIN] Acc: 98.530\n", 1127 | "[TEST] Acc: 42.640\n", 1128 | "Epoch: 51\n", 1129 | "[TRAIN] Acc: 99.658\n", 1130 | "[TEST] Acc: 42.830\n", 1131 | "Epoch: 52\n", 1132 | "[TRAIN] Acc: 99.826\n", 1133 | "[TEST] Acc: 42.940\n", 1134 | "Epoch: 53\n", 1135 | "[TRAIN] Acc: 99.878\n", 1136 | "[TEST] Acc: 42.580\n", 1137 | "Epoch: 54\n", 1138 | "[TRAIN] Acc: 99.910\n", 1139 | "[TEST] Acc: 43.020\n", 1140 | "Epoch: 55\n", 1141 | "[TRAIN] Acc: 99.922\n", 1142 | "[TEST] Acc: 43.120\n", 1143 | "Epoch: 56\n", 1144 | "[TRAIN] Acc: 99.936\n", 1145 | "[TEST] Acc: 42.940\n", 1146 | "Epoch: 57\n", 1147 | "[TRAIN] Acc: 99.934\n", 1148 | "[TEST] Acc: 42.950\n", 1149 | "Epoch: 58\n", 1150 | "[TRAIN] Acc: 
99.936\n", 1151 | "[TEST] Acc: 43.380\n", 1152 | "Epoch: 59\n", 1153 | "[TRAIN] Acc: 99.920\n", 1154 | "[TEST] Acc: 43.020\n", 1155 | "Epoch: 60\n", 1156 | "[TRAIN] Acc: 99.936\n", 1157 | "[TEST] Acc: 43.000\n", 1158 | "Epoch: 61\n", 1159 | "[TRAIN] Acc: 99.916\n", 1160 | "[TEST] Acc: 42.860\n", 1161 | "Epoch: 62\n", 1162 | "[TRAIN] Acc: 99.952\n", 1163 | "[TEST] Acc: 42.560\n", 1164 | "Epoch: 63\n", 1165 | "[TRAIN] Acc: 99.946\n", 1166 | "[TEST] Acc: 42.580\n", 1167 | "Epoch: 64\n", 1168 | "[TRAIN] Acc: 99.934\n", 1169 | "[TEST] Acc: 42.820\n", 1170 | "Epoch: 65\n", 1171 | "[TRAIN] Acc: 99.952\n", 1172 | "[TEST] Acc: 42.770\n", 1173 | "Epoch: 66\n", 1174 | "[TRAIN] Acc: 99.962\n", 1175 | "[TEST] Acc: 42.880\n", 1176 | "Epoch: 67\n", 1177 | "[TRAIN] Acc: 99.952\n", 1178 | "[TEST] Acc: 42.730\n", 1179 | "Epoch: 68\n", 1180 | "[TRAIN] Acc: 99.946\n", 1181 | "[TEST] Acc: 42.870\n", 1182 | "Epoch: 69\n", 1183 | "[TRAIN] Acc: 99.932\n", 1184 | "[TEST] Acc: 42.910\n", 1185 | "Epoch: 70\n", 1186 | "[TRAIN] Acc: 99.918\n", 1187 | "[TEST] Acc: 42.990\n", 1188 | "Epoch: 71\n", 1189 | "[TRAIN] Acc: 99.904\n", 1190 | "[TEST] Acc: 42.920\n", 1191 | "Epoch: 72\n", 1192 | "[TRAIN] Acc: 99.840\n", 1193 | "[TEST] Acc: 42.830\n", 1194 | "Epoch: 73\n", 1195 | "[TRAIN] Acc: 99.818\n", 1196 | "[TEST] Acc: 42.020\n", 1197 | "Epoch: 74\n", 1198 | "[TRAIN] Acc: 99.944\n", 1199 | "[TEST] Acc: 42.690\n", 1200 | "Epoch: 75\n", 1201 | "[TRAIN] Acc: 99.932\n", 1202 | "[TEST] Acc: 42.500\n", 1203 | "Epoch: 76\n", 1204 | "[TRAIN] Acc: 99.936\n", 1205 | "[TEST] Acc: 42.850\n", 1206 | "Epoch: 77\n", 1207 | "[TRAIN] Acc: 99.886\n", 1208 | "[TEST] Acc: 42.820\n", 1209 | "Epoch: 78\n", 1210 | "[TRAIN] Acc: 99.924\n", 1211 | "[TEST] Acc: 42.910\n", 1212 | "Epoch: 79\n", 1213 | "[TRAIN] Acc: 99.934\n", 1214 | "[TEST] Acc: 42.890\n", 1215 | "Epoch: 80\n", 1216 | "[TRAIN] Acc: 99.926\n", 1217 | "[TEST] Acc: 42.790\n", 1218 | "Epoch: 81\n", 1219 | "[TRAIN] Acc: 99.926\n", 1220 | "[TEST] Acc: 42.980\n", 1221 | "Epoch: 82\n", 1222 | "[TRAIN] Acc: 99.898\n", 1223 | "[TEST] Acc: 43.000\n", 1224 | "Epoch: 83\n", 1225 | "[TRAIN] Acc: 99.928\n", 1226 | "[TEST] Acc: 43.020\n", 1227 | "Epoch: 84\n", 1228 | "[TRAIN] Acc: 99.944\n", 1229 | "[TEST] Acc: 42.730\n", 1230 | "Epoch: 85\n", 1231 | "[TRAIN] Acc: 99.950\n", 1232 | "[TEST] Acc: 43.110\n", 1233 | "Epoch: 86\n", 1234 | "[TRAIN] Acc: 99.896\n", 1235 | "[TEST] Acc: 42.890\n", 1236 | "Epoch: 87\n", 1237 | "[TRAIN] Acc: 99.914\n", 1238 | "[TEST] Acc: 42.610\n", 1239 | "Epoch: 88\n", 1240 | "[TRAIN] Acc: 99.946\n", 1241 | "[TEST] Acc: 42.900\n", 1242 | "Epoch: 89\n", 1243 | "[TRAIN] Acc: 99.942\n", 1244 | "[TEST] Acc: 42.360\n", 1245 | "Epoch: 90\n", 1246 | "[TRAIN] Acc: 99.942\n", 1247 | "[TEST] Acc: 41.990\n", 1248 | "Epoch: 91\n", 1249 | "[TRAIN] Acc: 99.938\n", 1250 | "[TEST] Acc: 42.390\n", 1251 | "Epoch: 92\n", 1252 | "[TRAIN] Acc: 99.946\n", 1253 | "[TEST] Acc: 42.470\n", 1254 | "Epoch: 93\n", 1255 | "[TRAIN] Acc: 99.904\n", 1256 | "[TEST] Acc: 42.540\n", 1257 | "Epoch: 94\n", 1258 | "[TRAIN] Acc: 99.882\n", 1259 | "[TEST] Acc: 42.730\n", 1260 | "Epoch: 95\n", 1261 | "[TRAIN] Acc: 99.930\n", 1262 | "[TEST] Acc: 42.910\n", 1263 | "Epoch: 96\n", 1264 | "[TRAIN] Acc: 99.958\n", 1265 | "[TEST] Acc: 43.120\n", 1266 | "Epoch: 97\n", 1267 | "[TRAIN] Acc: 99.956\n", 1268 | "[TEST] Acc: 42.810\n", 1269 | "Epoch: 98\n", 1270 | "[TRAIN] Acc: 99.938\n", 1271 | "[TEST] Acc: 43.220\n", 1272 | "Epoch: 99\n", 1273 | "[TRAIN] Acc: 99.936\n", 1274 | "[TEST] Acc: 42.630\n", 1275 | "Epoch: 100\n", 1276 
| "[TRAIN] Acc: 99.946\n", 1277 | "[TEST] Acc: 42.580\n", 1278 | "Epoch: 101\n", 1279 | "[TRAIN] Acc: 99.960\n", 1280 | "[TEST] Acc: 42.910\n", 1281 | "Epoch: 102\n", 1282 | "[TRAIN] Acc: 99.966\n", 1283 | "[TEST] Acc: 42.840\n", 1284 | "Epoch: 103\n", 1285 | "[TRAIN] Acc: 99.974\n", 1286 | "[TEST] Acc: 42.870\n", 1287 | "Epoch: 104\n", 1288 | "[TRAIN] Acc: 99.974\n", 1289 | "[TEST] Acc: 42.890\n", 1290 | "Epoch: 105\n", 1291 | "[TRAIN] Acc: 99.974\n", 1292 | "[TEST] Acc: 42.770\n", 1293 | "Epoch: 106\n", 1294 | "[TRAIN] Acc: 99.966\n", 1295 | "[TEST] Acc: 43.000\n", 1296 | "Epoch: 107\n", 1297 | "[TRAIN] Acc: 99.974\n", 1298 | "[TEST] Acc: 43.050\n", 1299 | "Epoch: 108\n", 1300 | "[TRAIN] Acc: 99.970\n", 1301 | "[TEST] Acc: 42.880\n", 1302 | "Epoch: 109\n", 1303 | "[TRAIN] Acc: 99.974\n", 1304 | "[TEST] Acc: 43.040\n", 1305 | "Epoch: 110\n", 1306 | "[TRAIN] Acc: 99.978\n", 1307 | "[TEST] Acc: 42.710\n", 1308 | "Epoch: 111\n", 1309 | "[TRAIN] Acc: 99.972\n", 1310 | "[TEST] Acc: 42.940\n", 1311 | "Epoch: 112\n", 1312 | "[TRAIN] Acc: 99.970\n", 1313 | "[TEST] Acc: 43.010\n", 1314 | "Epoch: 113\n", 1315 | "[TRAIN] Acc: 99.976\n", 1316 | "[TEST] Acc: 42.940\n", 1317 | "Epoch: 114\n", 1318 | "[TRAIN] Acc: 99.982\n", 1319 | "[TEST] Acc: 42.890\n", 1320 | "Epoch: 115\n", 1321 | "[TRAIN] Acc: 99.986\n", 1322 | "[TEST] Acc: 42.810\n", 1323 | "Epoch: 116\n", 1324 | "[TRAIN] Acc: 99.978\n", 1325 | "[TEST] Acc: 42.820\n", 1326 | "Epoch: 117\n", 1327 | "[TRAIN] Acc: 99.970\n", 1328 | "[TEST] Acc: 42.870\n", 1329 | "Epoch: 118\n", 1330 | "[TRAIN] Acc: 99.972\n", 1331 | "[TEST] Acc: 42.850\n", 1332 | "Epoch: 119\n", 1333 | "[TRAIN] Acc: 99.974\n", 1334 | "[TEST] Acc: 43.050\n", 1335 | "Epoch: 120\n", 1336 | "[TRAIN] Acc: 99.976\n", 1337 | "[TEST] Acc: 42.850\n", 1338 | "Epoch: 121\n", 1339 | "[TRAIN] Acc: 99.974\n", 1340 | "[TEST] Acc: 42.890\n", 1341 | "Epoch: 122\n", 1342 | "[TRAIN] Acc: 99.980\n", 1343 | "[TEST] Acc: 43.110\n", 1344 | "Epoch: 123\n", 1345 | "[TRAIN] Acc: 99.978\n", 1346 | "[TEST] Acc: 42.910\n", 1347 | "Epoch: 124\n", 1348 | "[TRAIN] Acc: 99.976\n", 1349 | "[TEST] Acc: 42.980\n", 1350 | "Epoch: 125\n", 1351 | "[TRAIN] Acc: 99.980\n", 1352 | "[TEST] Acc: 42.910\n", 1353 | "Epoch: 126\n", 1354 | "[TRAIN] Acc: 99.974\n", 1355 | "[TEST] Acc: 43.150\n", 1356 | "Epoch: 127\n", 1357 | "[TRAIN] Acc: 99.980\n", 1358 | "[TEST] Acc: 42.990\n", 1359 | "Epoch: 128\n", 1360 | "[TRAIN] Acc: 99.974\n", 1361 | "[TEST] Acc: 43.080\n", 1362 | "Epoch: 129\n", 1363 | "[TRAIN] Acc: 99.978\n", 1364 | "[TEST] Acc: 43.020\n", 1365 | "Epoch: 130\n", 1366 | "[TRAIN] Acc: 99.982\n", 1367 | "[TEST] Acc: 43.070\n", 1368 | "Epoch: 131\n", 1369 | "[TRAIN] Acc: 99.970\n", 1370 | "[TEST] Acc: 43.280\n", 1371 | "Epoch: 132\n", 1372 | "[TRAIN] Acc: 99.962\n", 1373 | "[TEST] Acc: 42.800\n", 1374 | "Epoch: 133\n", 1375 | "[TRAIN] Acc: 99.974\n", 1376 | "[TEST] Acc: 42.890\n", 1377 | "Epoch: 134\n", 1378 | "[TRAIN] Acc: 99.976\n", 1379 | "[TEST] Acc: 43.020\n", 1380 | "Epoch: 135\n", 1381 | "[TRAIN] Acc: 99.970\n", 1382 | "[TEST] Acc: 42.920\n", 1383 | "Epoch: 136\n", 1384 | "[TRAIN] Acc: 99.976\n", 1385 | "[TEST] Acc: 43.080\n", 1386 | "Epoch: 137\n", 1387 | "[TRAIN] Acc: 99.966\n", 1388 | "[TEST] Acc: 43.150\n", 1389 | "Epoch: 138\n", 1390 | "[TRAIN] Acc: 99.966\n", 1391 | "[TEST] Acc: 43.160\n", 1392 | "Epoch: 139\n", 1393 | "[TRAIN] Acc: 99.980\n", 1394 | "[TEST] Acc: 43.030\n", 1395 | "Epoch: 140\n", 1396 | "[TRAIN] Acc: 99.974\n", 1397 | "[TEST] Acc: 43.290\n", 1398 | "Epoch: 141\n", 1399 | "[TRAIN] Acc: 99.976\n", 
1400 | "[TEST] Acc: 43.120\n", 1401 | "Epoch: 142\n", 1402 | "[TRAIN] Acc: 99.976\n", 1403 | "[TEST] Acc: 43.200\n", 1404 | "Epoch: 143\n", 1405 | "[TRAIN] Acc: 99.978\n", 1406 | "[TEST] Acc: 43.050\n", 1407 | "Epoch: 144\n", 1408 | "[TRAIN] Acc: 99.984\n", 1409 | "[TEST] Acc: 43.260\n", 1410 | "Epoch: 145\n", 1411 | "[TRAIN] Acc: 99.978\n", 1412 | "[TEST] Acc: 43.060\n", 1413 | "Epoch: 146\n", 1414 | "[TRAIN] Acc: 99.976\n", 1415 | "[TEST] Acc: 43.250\n", 1416 | "Epoch: 147\n", 1417 | "[TRAIN] Acc: 99.972\n", 1418 | "[TEST] Acc: 43.190\n", 1419 | "Epoch: 148\n", 1420 | "[TRAIN] Acc: 99.976\n", 1421 | "[TEST] Acc: 43.170\n", 1422 | "Epoch: 149\n", 1423 | "[TRAIN] Acc: 99.976\n", 1424 | "[TEST] Acc: 43.210\n", 1425 | "Epoch: 150\n", 1426 | "[TRAIN] Acc: 99.974\n", 1427 | "[TEST] Acc: 43.130\n", 1428 | "Epoch: 151\n", 1429 | "[TRAIN] Acc: 99.982\n", 1430 | "[TEST] Acc: 43.210\n", 1431 | "Epoch: 152\n", 1432 | "[TRAIN] Acc: 99.982\n", 1433 | "[TEST] Acc: 42.990\n", 1434 | "Epoch: 153\n", 1435 | "[TRAIN] Acc: 99.984\n", 1436 | "[TEST] Acc: 43.070\n", 1437 | "Epoch: 154\n", 1438 | "[TRAIN] Acc: 99.980\n", 1439 | "[TEST] Acc: 42.980\n", 1440 | "Epoch: 155\n", 1441 | "[TRAIN] Acc: 99.972\n", 1442 | "[TEST] Acc: 43.140\n", 1443 | "Epoch: 156\n", 1444 | "[TRAIN] Acc: 99.984\n", 1445 | "[TEST] Acc: 43.140\n", 1446 | "Epoch: 157\n", 1447 | "[TRAIN] Acc: 99.970\n", 1448 | "[TEST] Acc: 43.310\n", 1449 | "Epoch: 158\n", 1450 | "[TRAIN] Acc: 99.978\n", 1451 | "[TEST] Acc: 43.280\n", 1452 | "Epoch: 159\n", 1453 | "[TRAIN] Acc: 99.974\n", 1454 | "[TEST] Acc: 43.110\n", 1455 | "Epoch: 160\n", 1456 | "[TRAIN] Acc: 99.984\n", 1457 | "[TEST] Acc: 43.170\n", 1458 | "Epoch: 161\n", 1459 | "[TRAIN] Acc: 99.976\n", 1460 | "[TEST] Acc: 43.230\n", 1461 | "Epoch: 162\n", 1462 | "[TRAIN] Acc: 99.984\n", 1463 | "[TEST] Acc: 43.300\n", 1464 | "Epoch: 163\n", 1465 | "[TRAIN] Acc: 99.984\n", 1466 | "[TEST] Acc: 43.150\n", 1467 | "Epoch: 164\n", 1468 | "[TRAIN] Acc: 99.986\n", 1469 | "[TEST] Acc: 43.280\n", 1470 | "Epoch: 165\n", 1471 | "[TRAIN] Acc: 99.974\n", 1472 | "[TEST] Acc: 43.140\n", 1473 | "Epoch: 166\n" 1474 | ] 1475 | }, 1476 | { 1477 | "name": "stdout", 1478 | "output_type": "stream", 1479 | "text": [ 1480 | "[TRAIN] Acc: 99.978\n", 1481 | "[TEST] Acc: 43.320\n", 1482 | "Epoch: 167\n", 1483 | "[TRAIN] Acc: 99.978\n", 1484 | "[TEST] Acc: 43.220\n", 1485 | "Epoch: 168\n", 1486 | "[TRAIN] Acc: 99.986\n", 1487 | "[TEST] Acc: 43.010\n", 1488 | "Epoch: 169\n", 1489 | "[TRAIN] Acc: 99.984\n", 1490 | "[TEST] Acc: 43.100\n", 1491 | "Epoch: 170\n", 1492 | "[TRAIN] Acc: 99.982\n", 1493 | "[TEST] Acc: 43.190\n", 1494 | "Epoch: 171\n", 1495 | "[TRAIN] Acc: 99.976\n", 1496 | "[TEST] Acc: 43.040\n", 1497 | "Epoch: 172\n", 1498 | "[TRAIN] Acc: 99.972\n", 1499 | "[TEST] Acc: 43.310\n", 1500 | "Epoch: 173\n", 1501 | "[TRAIN] Acc: 99.994\n", 1502 | "[TEST] Acc: 43.260\n", 1503 | "Epoch: 174\n", 1504 | "[TRAIN] Acc: 99.976\n", 1505 | "[TEST] Acc: 43.130\n", 1506 | "Epoch: 175\n", 1507 | "[TRAIN] Acc: 99.978\n", 1508 | "[TEST] Acc: 43.100\n", 1509 | "Epoch: 176\n", 1510 | "[TRAIN] Acc: 99.972\n", 1511 | "[TEST] Acc: 43.130\n", 1512 | "Epoch: 177\n", 1513 | "[TRAIN] Acc: 99.978\n", 1514 | "[TEST] Acc: 43.180\n", 1515 | "Epoch: 178\n", 1516 | "[TRAIN] Acc: 99.986\n", 1517 | "[TEST] Acc: 42.980\n", 1518 | "Epoch: 179\n", 1519 | "[TRAIN] Acc: 99.978\n", 1520 | "[TEST] Acc: 43.300\n", 1521 | "Epoch: 180\n", 1522 | "[TRAIN] Acc: 99.982\n", 1523 | "[TEST] Acc: 43.100\n", 1524 | "Epoch: 181\n", 1525 | "[TRAIN] Acc: 99.976\n", 1526 | 
"[TEST] Acc: 43.200\n", 1527 | "Epoch: 182\n", 1528 | "[TRAIN] Acc: 99.984\n", 1529 | "[TEST] Acc: 43.070\n", 1530 | "Epoch: 183\n", 1531 | "[TRAIN] Acc: 99.986\n", 1532 | "[TEST] Acc: 43.140\n", 1533 | "Epoch: 184\n", 1534 | "[TRAIN] Acc: 99.982\n", 1535 | "[TEST] Acc: 43.040\n", 1536 | "Epoch: 185\n", 1537 | "[TRAIN] Acc: 99.986\n", 1538 | "[TEST] Acc: 43.120\n", 1539 | "Epoch: 186\n", 1540 | "[TRAIN] Acc: 99.980\n", 1541 | "[TEST] Acc: 43.240\n", 1542 | "Epoch: 187\n", 1543 | "[TRAIN] Acc: 99.978\n", 1544 | "[TEST] Acc: 43.220\n", 1545 | "Epoch: 188\n", 1546 | "[TRAIN] Acc: 99.980\n", 1547 | "[TEST] Acc: 43.210\n", 1548 | "Epoch: 189\n", 1549 | "[TRAIN] Acc: 99.986\n", 1550 | "[TEST] Acc: 43.110\n", 1551 | "Epoch: 190\n", 1552 | "[TRAIN] Acc: 99.978\n", 1553 | "[TEST] Acc: 43.230\n", 1554 | "Epoch: 191\n", 1555 | "[TRAIN] Acc: 99.980\n", 1556 | "[TEST] Acc: 43.080\n", 1557 | "Epoch: 192\n", 1558 | "[TRAIN] Acc: 99.982\n", 1559 | "[TEST] Acc: 43.270\n", 1560 | "Epoch: 193\n", 1561 | "[TRAIN] Acc: 99.974\n", 1562 | "[TEST] Acc: 43.080\n", 1563 | "Epoch: 194\n", 1564 | "[TRAIN] Acc: 99.986\n", 1565 | "[TEST] Acc: 42.880\n", 1566 | "Epoch: 195\n", 1567 | "[TRAIN] Acc: 99.966\n", 1568 | "[TEST] Acc: 42.970\n", 1569 | "Epoch: 196\n", 1570 | "[TRAIN] Acc: 99.980\n", 1571 | "[TEST] Acc: 42.950\n", 1572 | "Epoch: 197\n", 1573 | "[TRAIN] Acc: 99.980\n", 1574 | "[TEST] Acc: 43.310\n", 1575 | "Epoch: 198\n", 1576 | "[TRAIN] Acc: 99.976\n", 1577 | "[TEST] Acc: 43.110\n", 1578 | "Epoch: 199\n", 1579 | "[TRAIN] Acc: 99.978\n", 1580 | "[TEST] Acc: 43.110\n" 1581 | ] 1582 | } 1583 | ], 1584 | "source": [ 1585 | "for epoch in epochs:\n", 1586 | " optimizer = Adam(net.parameters(), lr=lr)\n", 1587 | " for _ in range(epoch):\n", 1588 | " train(count)\n", 1589 | " test(count)\n", 1590 | " count += 1\n", 1591 | " lr /= 10" 1592 | ] 1593 | }, 1594 | { 1595 | "cell_type": "code", 1596 | "execution_count": null, 1597 | "metadata": {}, 1598 | "outputs": [], 1599 | "source": [] 1600 | } 1601 | ], 1602 | "metadata": { 1603 | "kernelspec": { 1604 | "display_name": "Python (bayesian)", 1605 | "language": "python", 1606 | "name": "bayesian" 1607 | }, 1608 | "language_info": { 1609 | "codemirror_mode": { 1610 | "name": "ipython", 1611 | "version": 3 1612 | }, 1613 | "file_extension": ".py", 1614 | "mimetype": "text/x-python", 1615 | "name": "python", 1616 | "nbconvert_exporter": "python", 1617 | "pygments_lexer": "ipython3", 1618 | "version": "3.7.3" 1619 | } 1620 | }, 1621 | "nbformat": 4, 1622 | "nbformat_minor": 2 1623 | } 1624 | -------------------------------------------------------------------------------- /VGG_ReLU.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "import torch.nn as nn\n", 11 | "import torch.nn.functional as F\n", 12 | "import torch.backends.cudnn as cudnn\n", 13 | "from torch.optim import Adam, SGD\n", 14 | "import torchvision\n", 15 | "import torchvision.transforms as transforms\n", 16 | "\n", 17 | "import sys, os, math\n", 18 | "import argparse" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 2, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "lr=0.01\n", 28 | "data='cifar10'\n", 29 | "root='./data/'\n", 30 | "model='vgg'\n", 31 | "model_out='./checkpoint/cifar10_vgg_ReLU.pth'\n", 32 | "resume = False" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | 
"execution_count": 3, 38 | "metadata": {}, 39 | "outputs": [ 40 | { 41 | "name": "stdout", 42 | "output_type": "stream", 43 | "text": [ 44 | "Files already downloaded and verified\n", 45 | "Files already downloaded and verified\n" 46 | ] 47 | } 48 | ], 49 | "source": [ 50 | "if data == 'cifar10':\n", 51 | " nclass = 10\n", 52 | " img_width = 32\n", 53 | " transform_train = transforms.Compose([\n", 54 | "# transforms.RandomCrop(32, padding=4),\n", 55 | "# transforms.RandomHorizontalFlip(),\n", 56 | " transforms.ToTensor(),\n", 57 | " ])\n", 58 | " transform_test = transforms.Compose([\n", 59 | " transforms.ToTensor(),\n", 60 | " ])\n", 61 | " trainset = torchvision.datasets.CIFAR10(root=root, train=True, download=True, transform=transform_train)\n", 62 | " trainloader = torch.utils.data.DataLoader(trainset, batch_size=256, shuffle=True, num_workers=8)\n", 63 | " testset = torchvision.datasets.CIFAR10(root=root, train=False, download=True, transform=transform_test)\n", 64 | " testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=8)\n", 65 | " \n", 66 | "elif data == 'cifar100':\n", 67 | " nclass = 100\n", 68 | " img_width = 32\n", 69 | " transform_train = transforms.Compose([\n", 70 | "# transforms.RandomCrop(32, padding=4),\n", 71 | "# transforms.RandomHorizontalFlip(),\n", 72 | " transforms.ToTensor(),\n", 73 | " ])\n", 74 | " transform_test = transforms.Compose([\n", 75 | " transforms.ToTensor(),\n", 76 | " ])\n", 77 | " trainset = torchvision.datasets.CIFAR100(root=root, train=True, download=True, transform=transform_train)\n", 78 | " trainloader = torch.utils.data.DataLoader(trainset, batch_size=256, shuffle=True, num_workers=8)\n", 79 | " testset = torchvision.datasets.CIFAR100(root=root, train=False, download=True, transform=transform_test)\n", 80 | " testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=8)\n", 81 | " \n", 82 | "elif data == 'stl10':\n", 83 | " nclass = 10\n", 84 | " img_width = 32\n", 85 | " transform_train = transforms.Compose([\n", 86 | "# transforms.RandomCrop(32, padding=4),\n", 87 | "# transforms.RandomHorizontalFlip(),\n", 88 | " transforms.Resize((img_width,img_width)),\n", 89 | " transforms.ToTensor(),\n", 90 | " ])\n", 91 | " transform_test = transforms.Compose([\n", 92 | " transforms.Resize((img_width,img_width)),\n", 93 | " transforms.ToTensor(),\n", 94 | " ])\n", 95 | " trainset = torchvision.datasets.STL10(root=root, split='train', transform=transform_train, target_transform=None, download=True)\n", 96 | " testset = torchvision.datasets.STL10(root=root, split='test', transform=transform_test, target_transform=None, download=True)\n", 97 | " trainloader = torch.utils.data.DataLoader(trainset, batch_size=256, shuffle=True, num_workers=8)\n", 98 | " testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=8)" 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": 8, 104 | "metadata": {}, 105 | "outputs": [ 106 | { 107 | "data": { 108 | "text/plain": [ 109 | "DataParallel(\n", 110 | " (module): VGG_ReLU(\n", 111 | " (features): Sequential(\n", 112 | " (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 113 | " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 114 | " (2): ReLU(inplace)\n", 115 | " (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 116 | " (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 
117 | " (5): ReLU(inplace)\n", 118 | " (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 119 | " (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 120 | " (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 121 | " (9): ReLU(inplace)\n", 122 | " (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 123 | " (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 124 | " (12): ReLU(inplace)\n", 125 | " (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 126 | " (14): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 127 | " (15): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 128 | " (16): ReLU(inplace)\n", 129 | " (17): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 130 | " (18): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 131 | " (19): ReLU(inplace)\n", 132 | " (20): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 133 | " (21): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 134 | " (22): ReLU(inplace)\n", 135 | " (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 136 | " (24): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 137 | " (25): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 138 | " (26): ReLU(inplace)\n", 139 | " (27): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 140 | " (28): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 141 | " (29): ReLU(inplace)\n", 142 | " (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 143 | " (31): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 144 | " (32): ReLU(inplace)\n", 145 | " (33): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 146 | " (34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 147 | " (35): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 148 | " (36): ReLU(inplace)\n", 149 | " (37): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 150 | " (38): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 151 | " (39): ReLU(inplace)\n", 152 | " (40): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 153 | " (41): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 154 | " (42): ReLU(inplace)\n", 155 | " (43): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 156 | " (44): Dropout(p=0.5)\n", 157 | " )\n", 158 | " (classifier): Linear(in_features=512, out_features=10, bias=True)\n", 159 | " )\n", 160 | ")" 161 | ] 162 | }, 163 | "execution_count": 8, 164 | "metadata": {}, 165 | "output_type": "execute_result" 166 | } 167 | ], 168 | "source": [ 169 | "if model == 'vgg':\n", 170 | " from models.vgg import VGG_ReLU\n", 171 | " net = nn.DataParallel(VGG_ReLU('VGG16', nclass, img_width=img_width).cuda())\n", 172 | " \n", 173 | "net" 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": 9, 179 | "metadata": {}, 180 | "outputs": [], 181 | "source": [ 182 | "if resume:\n", 183 | " print(f'==> Resuming from 
{model_out}')\n", 184 | " net.load_state_dict(torch.load(model_out))" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": 10, 190 | "metadata": {}, 191 | "outputs": [], 192 | "source": [ 193 | "cudnn.benchmark = True" 194 | ] 195 | }, 196 | { 197 | "cell_type": "code", 198 | "execution_count": 11, 199 | "metadata": {}, 200 | "outputs": [], 201 | "source": [ 202 | "criterion = nn.CrossEntropyLoss()" 203 | ] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "execution_count": 19, 208 | "metadata": {}, 209 | "outputs": [], 210 | "source": [ 211 | "def train(epoch):\n", 212 | " print('Epoch: %d' % epoch)\n", 213 | " net.train()\n", 214 | " train_loss = 0\n", 215 | " correct = 0\n", 216 | " total = 0\n", 217 | " for batch_idx, (inputs, targets) in enumerate(trainloader):\n", 218 | " inputs, targets = inputs.cuda(), targets.cuda()\n", 219 | " optimizer.zero_grad()\n", 220 | " outputs, _ = net(inputs)\n", 221 | " loss = criterion(outputs, targets)\n", 222 | " loss.backward()\n", 223 | " optimizer.step()\n", 224 | " pred = torch.max(outputs, dim=1)[1]\n", 225 | " correct += torch.sum(pred.eq(targets)).item()\n", 226 | " total += targets.numel()\n", 227 | " print(f'[TRAIN] Acc: {100.*correct/total:.3f}')" 228 | ] 229 | }, 230 | { 231 | "cell_type": "code", 232 | "execution_count": 20, 233 | "metadata": {}, 234 | "outputs": [], 235 | "source": [ 236 | "def test(epoch):\n", 237 | " net.eval()\n", 238 | " test_loss = 0\n", 239 | " correct = 0\n", 240 | " total = 0\n", 241 | " with torch.no_grad():\n", 242 | " for batch_idx, (inputs, targets) in enumerate(testloader):\n", 243 | " inputs, targets = inputs.cuda(), targets.cuda()\n", 244 | " outputs, _ = net(inputs)\n", 245 | " loss = criterion(outputs, targets)\n", 246 | " test_loss += loss.item()\n", 247 | " _, predicted = outputs.max(1)\n", 248 | " total += targets.size(0)\n", 249 | " correct += predicted.eq(targets).sum().item()\n", 250 | " print(f'[TEST] Acc: {100.*correct/total:.3f}')\n", 251 | "\n", 252 | " # Save checkpoint after each epoch\n", 253 | " torch.save(net.state_dict(), model_out)" 254 | ] 255 | }, 256 | { 257 | "cell_type": "code", 258 | "execution_count": 21, 259 | "metadata": {}, 260 | "outputs": [], 261 | "source": [ 262 | "if data == 'cifar10':\n", 263 | " epochs = [50, 50, 50, 50]\n", 264 | "elif data == 'cifar100':\n", 265 | " epochs = [50, 50, 50, 50]\n", 266 | "elif data == 'stl10':\n", 267 | " epochs = [50, 50, 50, 50]" 268 | ] 269 | }, 270 | { 271 | "cell_type": "code", 272 | "execution_count": 22, 273 | "metadata": {}, 274 | "outputs": [], 275 | "source": [ 276 | "count = 0" 277 | ] 278 | }, 279 | { 280 | "cell_type": "code", 281 | "execution_count": 23, 282 | "metadata": { 283 | "scrolled": true 284 | }, 285 | "outputs": [ 286 | { 287 | "name": "stdout", 288 | "output_type": "stream", 289 | "text": [ 290 | "Epoch: 0\n", 291 | "[TRAIN] Acc: 39.858\n", 292 | "CPU times: user 8.15 s, sys: 3.68 s, total: 11.8 s\n", 293 | "Wall time: 13.2 s\n", 294 | "[TEST] Acc: 58.800\n", 295 | "CPU times: user 724 ms, sys: 512 ms, total: 1.24 s\n", 296 | "Wall time: 1.35 s\n" 297 | ] 298 | } 299 | ], 300 | "source": [ 301 | "for epoch in epochs:\n", 302 | " optimizer = Adam(net.parameters(), lr=lr)\n", 303 | " for _ in range(epoch):\n", 304 | " %time train(count)\n", 305 | " %time test(count)\n", 306 | " count += 1\n", 307 | " lr /= 10" 308 | ] 309 | }, 310 | { 311 | "cell_type": "code", 312 | "execution_count": null, 313 | "metadata": {}, 314 | "outputs": [], 315 | "source": [ 316 | "lr=0.01\n", 317 | 
"data='stl10'\n", 318 | "root='./data/'\n", 319 | "model='vgg'\n", 320 | "model_out='./checkpoint/stl10_vgg_ReLU.pth'\n", 321 | "resume = False" 322 | ] 323 | }, 324 | { 325 | "cell_type": "code", 326 | "execution_count": 12, 327 | "metadata": {}, 328 | "outputs": [ 329 | { 330 | "name": "stdout", 331 | "output_type": "stream", 332 | "text": [ 333 | "Epoch: 0\n", 334 | "[TRAIN] Acc: 10.520\n", 335 | "[TEST] Acc: 12.537\n", 336 | "Epoch: 1\n", 337 | "[TRAIN] Acc: 11.360\n", 338 | "[TEST] Acc: 10.725\n", 339 | "Epoch: 2\n", 340 | "[TRAIN] Acc: 11.120\n", 341 | "[TEST] Acc: 12.338\n", 342 | "Epoch: 3\n", 343 | "[TRAIN] Acc: 12.340\n", 344 | "[TEST] Acc: 13.550\n", 345 | "Epoch: 4\n", 346 | "[TRAIN] Acc: 15.680\n", 347 | "[TEST] Acc: 16.250\n", 348 | "Epoch: 5\n", 349 | "[TRAIN] Acc: 18.020\n", 350 | "[TEST] Acc: 16.225\n", 351 | "Epoch: 6\n", 352 | "[TRAIN] Acc: 22.100\n", 353 | "[TEST] Acc: 15.400\n", 354 | "Epoch: 7\n", 355 | "[TRAIN] Acc: 23.760\n", 356 | "[TEST] Acc: 23.137\n", 357 | "Epoch: 8\n", 358 | "[TRAIN] Acc: 25.460\n", 359 | "[TEST] Acc: 17.062\n", 360 | "Epoch: 9\n", 361 | "[TRAIN] Acc: 25.860\n", 362 | "[TEST] Acc: 17.038\n", 363 | "Epoch: 10\n", 364 | "[TRAIN] Acc: 27.720\n", 365 | "[TEST] Acc: 22.488\n", 366 | "Epoch: 11\n", 367 | "[TRAIN] Acc: 28.540\n", 368 | "[TEST] Acc: 19.038\n", 369 | "Epoch: 12\n", 370 | "[TRAIN] Acc: 28.640\n", 371 | "[TEST] Acc: 21.575\n", 372 | "Epoch: 13\n", 373 | "[TRAIN] Acc: 30.200\n", 374 | "[TEST] Acc: 28.875\n", 375 | "Epoch: 14\n", 376 | "[TRAIN] Acc: 31.180\n", 377 | "[TEST] Acc: 31.887\n", 378 | "Epoch: 15\n", 379 | "[TRAIN] Acc: 34.300\n", 380 | "[TEST] Acc: 26.600\n", 381 | "Epoch: 16\n", 382 | "[TRAIN] Acc: 36.260\n", 383 | "[TEST] Acc: 36.513\n", 384 | "Epoch: 17\n", 385 | "[TRAIN] Acc: 38.140\n", 386 | "[TEST] Acc: 35.388\n", 387 | "Epoch: 18\n", 388 | "[TRAIN] Acc: 39.220\n", 389 | "[TEST] Acc: 33.462\n", 390 | "Epoch: 19\n", 391 | "[TRAIN] Acc: 42.460\n", 392 | "[TEST] Acc: 36.725\n", 393 | "Epoch: 20\n", 394 | "[TRAIN] Acc: 44.300\n", 395 | "[TEST] Acc: 39.100\n", 396 | "Epoch: 21\n", 397 | "[TRAIN] Acc: 48.400\n", 398 | "[TEST] Acc: 40.150\n", 399 | "Epoch: 22\n", 400 | "[TRAIN] Acc: 51.240\n", 401 | "[TEST] Acc: 43.062\n", 402 | "Epoch: 23\n", 403 | "[TRAIN] Acc: 53.240\n", 404 | "[TEST] Acc: 44.175\n", 405 | "Epoch: 24\n", 406 | "[TRAIN] Acc: 55.720\n", 407 | "[TEST] Acc: 39.562\n", 408 | "Epoch: 25\n", 409 | "[TRAIN] Acc: 58.700\n", 410 | "[TEST] Acc: 41.650\n", 411 | "Epoch: 26\n", 412 | "[TRAIN] Acc: 61.200\n", 413 | "[TEST] Acc: 46.125\n", 414 | "Epoch: 27\n", 415 | "[TRAIN] Acc: 66.120\n", 416 | "[TEST] Acc: 44.750\n", 417 | "Epoch: 28\n", 418 | "[TRAIN] Acc: 66.560\n", 419 | "[TEST] Acc: 43.150\n", 420 | "Epoch: 29\n", 421 | "[TRAIN] Acc: 69.440\n", 422 | "[TEST] Acc: 42.725\n", 423 | "Epoch: 30\n", 424 | "[TRAIN] Acc: 71.460\n", 425 | "[TEST] Acc: 50.337\n", 426 | "Epoch: 31\n", 427 | "[TRAIN] Acc: 75.540\n", 428 | "[TEST] Acc: 47.850\n", 429 | "Epoch: 32\n", 430 | "[TRAIN] Acc: 77.500\n", 431 | "[TEST] Acc: 49.288\n", 432 | "Epoch: 33\n", 433 | "[TRAIN] Acc: 80.260\n", 434 | "[TEST] Acc: 47.425\n", 435 | "Epoch: 34\n", 436 | "[TRAIN] Acc: 82.940\n", 437 | "[TEST] Acc: 51.987\n", 438 | "Epoch: 35\n", 439 | "[TRAIN] Acc: 85.680\n", 440 | "[TEST] Acc: 49.163\n", 441 | "Epoch: 36\n", 442 | "[TRAIN] Acc: 86.440\n", 443 | "[TEST] Acc: 53.325\n", 444 | "Epoch: 37\n", 445 | "[TRAIN] Acc: 89.260\n", 446 | "[TEST] Acc: 50.450\n", 447 | "Epoch: 38\n", 448 | "[TRAIN] Acc: 90.100\n", 449 | "[TEST] Acc: 48.700\n", 450 | "Epoch: 39\n", 
451 | "[TRAIN] Acc: 91.460\n", 452 | "[TEST] Acc: 52.163\n", 453 | "Epoch: 40\n", 454 | "[TRAIN] Acc: 93.160\n", 455 | "[TEST] Acc: 52.612\n", 456 | "Epoch: 41\n", 457 | "[TRAIN] Acc: 91.440\n", 458 | "[TEST] Acc: 47.525\n", 459 | "Epoch: 42\n", 460 | "[TRAIN] Acc: 92.740\n", 461 | "[TEST] Acc: 49.025\n", 462 | "Epoch: 43\n", 463 | "[TRAIN] Acc: 94.740\n", 464 | "[TEST] Acc: 56.675\n", 465 | "Epoch: 44\n", 466 | "[TRAIN] Acc: 94.840\n", 467 | "[TEST] Acc: 57.112\n", 468 | "Epoch: 45\n", 469 | "[TRAIN] Acc: 94.580\n", 470 | "[TEST] Acc: 54.900\n", 471 | "Epoch: 46\n", 472 | "[TRAIN] Acc: 95.580\n", 473 | "[TEST] Acc: 57.763\n", 474 | "Epoch: 47\n", 475 | "[TRAIN] Acc: 97.140\n", 476 | "[TEST] Acc: 54.913\n", 477 | "Epoch: 48\n", 478 | "[TRAIN] Acc: 97.840\n", 479 | "[TEST] Acc: 53.850\n", 480 | "Epoch: 49\n", 481 | "[TRAIN] Acc: 97.600\n", 482 | "[TEST] Acc: 53.175\n", 483 | "Epoch: 50\n", 484 | "[TRAIN] Acc: 98.820\n", 485 | "[TEST] Acc: 59.850\n", 486 | "Epoch: 51\n", 487 | "[TRAIN] Acc: 99.780\n", 488 | "[TEST] Acc: 60.200\n", 489 | "Epoch: 52\n", 490 | "[TRAIN] Acc: 99.980\n", 491 | "[TEST] Acc: 60.038\n", 492 | "Epoch: 53\n", 493 | "[TRAIN] Acc: 100.000\n", 494 | "[TEST] Acc: 60.425\n", 495 | "Epoch: 54\n", 496 | "[TRAIN] Acc: 99.980\n", 497 | "[TEST] Acc: 60.325\n", 498 | "Epoch: 55\n", 499 | "[TRAIN] Acc: 100.000\n", 500 | "[TEST] Acc: 60.462\n", 501 | "Epoch: 56\n", 502 | "[TRAIN] Acc: 99.920\n", 503 | "[TEST] Acc: 60.250\n", 504 | "Epoch: 57\n", 505 | "[TRAIN] Acc: 100.000\n", 506 | "[TEST] Acc: 60.350\n", 507 | "Epoch: 58\n", 508 | "[TRAIN] Acc: 100.000\n", 509 | "[TEST] Acc: 60.438\n", 510 | "Epoch: 59\n", 511 | "[TRAIN] Acc: 100.000\n", 512 | "[TEST] Acc: 60.163\n", 513 | "Epoch: 60\n", 514 | "[TRAIN] Acc: 100.000\n", 515 | "[TEST] Acc: 60.163\n", 516 | "Epoch: 61\n", 517 | "[TRAIN] Acc: 99.980\n", 518 | "[TEST] Acc: 60.275\n", 519 | "Epoch: 62\n", 520 | "[TRAIN] Acc: 100.000\n", 521 | "[TEST] Acc: 60.050\n", 522 | "Epoch: 63\n", 523 | "[TRAIN] Acc: 100.000\n", 524 | "[TEST] Acc: 60.038\n", 525 | "Epoch: 64\n", 526 | "[TRAIN] Acc: 100.000\n", 527 | "[TEST] Acc: 60.200\n", 528 | "Epoch: 65\n", 529 | "[TRAIN] Acc: 100.000\n", 530 | "[TEST] Acc: 60.150\n", 531 | "Epoch: 66\n", 532 | "[TRAIN] Acc: 100.000\n", 533 | "[TEST] Acc: 60.300\n", 534 | "Epoch: 67\n", 535 | "[TRAIN] Acc: 100.000\n", 536 | "[TEST] Acc: 60.112\n", 537 | "Epoch: 68\n", 538 | "[TRAIN] Acc: 100.000\n", 539 | "[TEST] Acc: 60.325\n", 540 | "Epoch: 69\n", 541 | "[TRAIN] Acc: 100.000\n", 542 | "[TEST] Acc: 60.138\n", 543 | "Epoch: 70\n", 544 | "[TRAIN] Acc: 100.000\n", 545 | "[TEST] Acc: 60.050\n", 546 | "Epoch: 71\n", 547 | "[TRAIN] Acc: 100.000\n", 548 | "[TEST] Acc: 60.025\n", 549 | "Epoch: 72\n", 550 | "[TRAIN] Acc: 100.000\n", 551 | "[TEST] Acc: 60.013\n", 552 | "Epoch: 73\n", 553 | "[TRAIN] Acc: 100.000\n", 554 | "[TEST] Acc: 60.200\n", 555 | "Epoch: 74\n", 556 | "[TRAIN] Acc: 99.960\n", 557 | "[TEST] Acc: 59.987\n", 558 | "Epoch: 75\n", 559 | "[TRAIN] Acc: 99.680\n", 560 | "[TEST] Acc: 57.750\n", 561 | "Epoch: 76\n", 562 | "[TRAIN] Acc: 99.340\n", 563 | "[TEST] Acc: 58.200\n", 564 | "Epoch: 77\n", 565 | "[TRAIN] Acc: 99.400\n", 566 | "[TEST] Acc: 58.288\n", 567 | "Epoch: 78\n", 568 | "[TRAIN] Acc: 99.420\n", 569 | "[TEST] Acc: 58.575\n", 570 | "Epoch: 79\n", 571 | "[TRAIN] Acc: 99.780\n", 572 | "[TEST] Acc: 58.888\n", 573 | "Epoch: 80\n", 574 | "[TRAIN] Acc: 99.860\n", 575 | "[TEST] Acc: 59.575\n", 576 | "Epoch: 81\n", 577 | "[TRAIN] Acc: 99.880\n", 578 | "[TEST] Acc: 59.312\n", 579 | "Epoch: 82\n", 580 | 
"[TRAIN] Acc: 99.820\n", 581 | "[TEST] Acc: 59.450\n", 582 | "Epoch: 83\n", 583 | "[TRAIN] Acc: 99.900\n", 584 | "[TEST] Acc: 60.013\n", 585 | "Epoch: 84\n", 586 | "[TRAIN] Acc: 99.940\n", 587 | "[TEST] Acc: 59.900\n", 588 | "Epoch: 85\n", 589 | "[TRAIN] Acc: 99.960\n", 590 | "[TEST] Acc: 59.925\n", 591 | "Epoch: 86\n", 592 | "[TRAIN] Acc: 99.920\n", 593 | "[TEST] Acc: 60.112\n", 594 | "Epoch: 87\n", 595 | "[TRAIN] Acc: 99.980\n", 596 | "[TEST] Acc: 59.950\n", 597 | "Epoch: 88\n", 598 | "[TRAIN] Acc: 99.920\n", 599 | "[TEST] Acc: 59.538\n", 600 | "Epoch: 89\n", 601 | "[TRAIN] Acc: 99.980\n", 602 | "[TEST] Acc: 60.275\n", 603 | "Epoch: 90\n", 604 | "[TRAIN] Acc: 100.000\n", 605 | "[TEST] Acc: 60.288\n", 606 | "Epoch: 91\n", 607 | "[TRAIN] Acc: 99.940\n", 608 | "[TEST] Acc: 60.250\n", 609 | "Epoch: 92\n", 610 | "[TRAIN] Acc: 99.980\n", 611 | "[TEST] Acc: 59.812\n", 612 | "Epoch: 93\n", 613 | "[TRAIN] Acc: 99.860\n", 614 | "[TEST] Acc: 58.650\n", 615 | "Epoch: 94\n", 616 | "[TRAIN] Acc: 99.840\n", 617 | "[TEST] Acc: 60.300\n", 618 | "Epoch: 95\n", 619 | "[TRAIN] Acc: 99.920\n", 620 | "[TEST] Acc: 60.013\n", 621 | "Epoch: 96\n", 622 | "[TRAIN] Acc: 99.980\n", 623 | "[TEST] Acc: 59.900\n", 624 | "Epoch: 97\n", 625 | "[TRAIN] Acc: 99.960\n", 626 | "[TEST] Acc: 59.700\n", 627 | "Epoch: 98\n", 628 | "[TRAIN] Acc: 99.980\n", 629 | "[TEST] Acc: 60.438\n", 630 | "Epoch: 99\n", 631 | "[TRAIN] Acc: 100.000\n", 632 | "[TEST] Acc: 60.737\n", 633 | "Epoch: 100\n", 634 | "[TRAIN] Acc: 100.000\n", 635 | "[TEST] Acc: 60.513\n", 636 | "Epoch: 101\n", 637 | "[TRAIN] Acc: 100.000\n", 638 | "[TEST] Acc: 60.675\n", 639 | "Epoch: 102\n", 640 | "[TRAIN] Acc: 100.000\n", 641 | "[TEST] Acc: 60.700\n", 642 | "Epoch: 103\n", 643 | "[TRAIN] Acc: 100.000\n", 644 | "[TEST] Acc: 60.312\n", 645 | "Epoch: 104\n", 646 | "[TRAIN] Acc: 100.000\n", 647 | "[TEST] Acc: 60.163\n", 648 | "Epoch: 105\n", 649 | "[TRAIN] Acc: 100.000\n", 650 | "[TEST] Acc: 60.100\n", 651 | "Epoch: 106\n", 652 | "[TRAIN] Acc: 100.000\n", 653 | "[TEST] Acc: 60.500\n", 654 | "Epoch: 107\n", 655 | "[TRAIN] Acc: 100.000\n", 656 | "[TEST] Acc: 60.225\n", 657 | "Epoch: 108\n", 658 | "[TRAIN] Acc: 100.000\n", 659 | "[TEST] Acc: 60.288\n", 660 | "Epoch: 109\n", 661 | "[TRAIN] Acc: 100.000\n", 662 | "[TEST] Acc: 60.612\n", 663 | "Epoch: 110\n", 664 | "[TRAIN] Acc: 100.000\n", 665 | "[TEST] Acc: 60.375\n", 666 | "Epoch: 111\n", 667 | "[TRAIN] Acc: 100.000\n", 668 | "[TEST] Acc: 60.462\n", 669 | "Epoch: 112\n", 670 | "[TRAIN] Acc: 100.000\n", 671 | "[TEST] Acc: 60.388\n", 672 | "Epoch: 113\n", 673 | "[TRAIN] Acc: 100.000\n", 674 | "[TEST] Acc: 60.425\n", 675 | "Epoch: 114\n", 676 | "[TRAIN] Acc: 100.000\n", 677 | "[TEST] Acc: 60.225\n", 678 | "Epoch: 115\n", 679 | "[TRAIN] Acc: 100.000\n", 680 | "[TEST] Acc: 60.462\n", 681 | "Epoch: 116\n", 682 | "[TRAIN] Acc: 100.000\n", 683 | "[TEST] Acc: 60.400\n", 684 | "Epoch: 117\n", 685 | "[TRAIN] Acc: 100.000\n", 686 | "[TEST] Acc: 60.462\n", 687 | "Epoch: 118\n", 688 | "[TRAIN] Acc: 100.000\n", 689 | "[TEST] Acc: 60.462\n", 690 | "Epoch: 119\n", 691 | "[TRAIN] Acc: 100.000\n", 692 | "[TEST] Acc: 60.538\n", 693 | "Epoch: 120\n", 694 | "[TRAIN] Acc: 100.000\n", 695 | "[TEST] Acc: 60.538\n", 696 | "Epoch: 121\n", 697 | "[TRAIN] Acc: 100.000\n", 698 | "[TEST] Acc: 60.487\n", 699 | "Epoch: 122\n", 700 | "[TRAIN] Acc: 100.000\n", 701 | "[TEST] Acc: 60.812\n", 702 | "Epoch: 123\n", 703 | "[TRAIN] Acc: 100.000\n", 704 | "[TEST] Acc: 60.438\n", 705 | "Epoch: 124\n", 706 | "[TRAIN] Acc: 100.000\n", 707 | "[TEST] Acc: 60.250\n", 708 
| "Epoch: 125\n", 709 | "[TRAIN] Acc: 100.000\n", 710 | "[TEST] Acc: 60.663\n", 711 | "Epoch: 126\n", 712 | "[TRAIN] Acc: 100.000\n", 713 | "[TEST] Acc: 60.612\n", 714 | "Epoch: 127\n", 715 | "[TRAIN] Acc: 100.000\n", 716 | "[TEST] Acc: 60.475\n", 717 | "Epoch: 128\n", 718 | "[TRAIN] Acc: 100.000\n", 719 | "[TEST] Acc: 60.800\n", 720 | "Epoch: 129\n", 721 | "[TRAIN] Acc: 100.000\n", 722 | "[TEST] Acc: 60.712\n", 723 | "Epoch: 130\n", 724 | "[TRAIN] Acc: 100.000\n", 725 | "[TEST] Acc: 60.712\n", 726 | "Epoch: 131\n", 727 | "[TRAIN] Acc: 100.000\n", 728 | "[TEST] Acc: 60.763\n", 729 | "Epoch: 132\n", 730 | "[TRAIN] Acc: 100.000\n", 731 | "[TEST] Acc: 60.725\n", 732 | "Epoch: 133\n", 733 | "[TRAIN] Acc: 100.000\n", 734 | "[TEST] Acc: 60.600\n", 735 | "Epoch: 134\n", 736 | "[TRAIN] Acc: 100.000\n", 737 | "[TEST] Acc: 60.625\n", 738 | "Epoch: 135\n", 739 | "[TRAIN] Acc: 100.000\n", 740 | "[TEST] Acc: 60.788\n", 741 | "Epoch: 136\n", 742 | "[TRAIN] Acc: 100.000\n", 743 | "[TEST] Acc: 60.700\n", 744 | "Epoch: 137\n", 745 | "[TRAIN] Acc: 100.000\n", 746 | "[TEST] Acc: 60.538\n", 747 | "Epoch: 138\n", 748 | "[TRAIN] Acc: 100.000\n", 749 | "[TEST] Acc: 60.600\n", 750 | "Epoch: 139\n", 751 | "[TRAIN] Acc: 100.000\n", 752 | "[TEST] Acc: 60.737\n", 753 | "Epoch: 140\n", 754 | "[TRAIN] Acc: 100.000\n", 755 | "[TEST] Acc: 60.737\n", 756 | "Epoch: 141\n", 757 | "[TRAIN] Acc: 100.000\n", 758 | "[TEST] Acc: 60.612\n", 759 | "Epoch: 142\n", 760 | "[TRAIN] Acc: 100.000\n", 761 | "[TEST] Acc: 60.388\n", 762 | "Epoch: 143\n", 763 | "[TRAIN] Acc: 100.000\n", 764 | "[TEST] Acc: 60.212\n", 765 | "Epoch: 144\n", 766 | "[TRAIN] Acc: 100.000\n", 767 | "[TEST] Acc: 60.413\n", 768 | "Epoch: 145\n", 769 | "[TRAIN] Acc: 100.000\n", 770 | "[TEST] Acc: 60.325\n", 771 | "Epoch: 146\n", 772 | "[TRAIN] Acc: 100.000\n", 773 | "[TEST] Acc: 60.550\n", 774 | "Epoch: 147\n", 775 | "[TRAIN] Acc: 100.000\n", 776 | "[TEST] Acc: 60.438\n", 777 | "Epoch: 148\n", 778 | "[TRAIN] Acc: 100.000\n", 779 | "[TEST] Acc: 60.525\n", 780 | "Epoch: 149\n", 781 | "[TRAIN] Acc: 100.000\n", 782 | "[TEST] Acc: 60.650\n", 783 | "Epoch: 150\n", 784 | "[TRAIN] Acc: 100.000\n", 785 | "[TEST] Acc: 60.612\n", 786 | "Epoch: 151\n", 787 | "[TRAIN] Acc: 100.000\n", 788 | "[TEST] Acc: 60.650\n", 789 | "Epoch: 152\n", 790 | "[TRAIN] Acc: 100.000\n", 791 | "[TEST] Acc: 60.538\n", 792 | "Epoch: 153\n", 793 | "[TRAIN] Acc: 100.000\n", 794 | "[TEST] Acc: 60.675\n", 795 | "Epoch: 154\n", 796 | "[TRAIN] Acc: 100.000\n", 797 | "[TEST] Acc: 60.575\n", 798 | "Epoch: 155\n", 799 | "[TRAIN] Acc: 100.000\n", 800 | "[TEST] Acc: 60.650\n", 801 | "Epoch: 156\n", 802 | "[TRAIN] Acc: 100.000\n", 803 | "[TEST] Acc: 60.688\n", 804 | "Epoch: 157\n", 805 | "[TRAIN] Acc: 100.000\n", 806 | "[TEST] Acc: 60.562\n", 807 | "Epoch: 158\n", 808 | "[TRAIN] Acc: 100.000\n", 809 | "[TEST] Acc: 60.737\n", 810 | "Epoch: 159\n", 811 | "[TRAIN] Acc: 100.000\n", 812 | "[TEST] Acc: 60.800\n", 813 | "Epoch: 160\n", 814 | "[TRAIN] Acc: 100.000\n", 815 | "[TEST] Acc: 60.625\n", 816 | "Epoch: 161\n", 817 | "[TRAIN] Acc: 100.000\n", 818 | "[TEST] Acc: 60.638\n", 819 | "Epoch: 162\n", 820 | "[TRAIN] Acc: 100.000\n", 821 | "[TEST] Acc: 60.600\n", 822 | "Epoch: 163\n", 823 | "[TRAIN] Acc: 100.000\n", 824 | "[TEST] Acc: 60.688\n", 825 | "Epoch: 164\n", 826 | "[TRAIN] Acc: 100.000\n" 827 | ] 828 | }, 829 | { 830 | "name": "stdout", 831 | "output_type": "stream", 832 | "text": [ 833 | "[TEST] Acc: 60.550\n", 834 | "Epoch: 165\n", 835 | "[TRAIN] Acc: 100.000\n", 836 | "[TEST] Acc: 60.600\n", 837 | "Epoch: 
166\n", 838 | "[TRAIN] Acc: 100.000\n", 839 | "[TEST] Acc: 60.675\n", 840 | "Epoch: 167\n", 841 | "[TRAIN] Acc: 100.000\n", 842 | "[TEST] Acc: 60.562\n", 843 | "Epoch: 168\n", 844 | "[TRAIN] Acc: 100.000\n", 845 | "[TEST] Acc: 60.712\n", 846 | "Epoch: 169\n", 847 | "[TRAIN] Acc: 100.000\n", 848 | "[TEST] Acc: 60.638\n", 849 | "Epoch: 170\n", 850 | "[TRAIN] Acc: 100.000\n", 851 | "[TEST] Acc: 60.750\n", 852 | "Epoch: 171\n", 853 | "[TRAIN] Acc: 100.000\n", 854 | "[TEST] Acc: 60.825\n", 855 | "Epoch: 172\n", 856 | "[TRAIN] Acc: 100.000\n", 857 | "[TEST] Acc: 60.850\n", 858 | "Epoch: 173\n", 859 | "[TRAIN] Acc: 100.000\n", 860 | "[TEST] Acc: 60.700\n", 861 | "Epoch: 174\n", 862 | "[TRAIN] Acc: 100.000\n", 863 | "[TEST] Acc: 60.700\n", 864 | "Epoch: 175\n", 865 | "[TRAIN] Acc: 100.000\n", 866 | "[TEST] Acc: 60.663\n", 867 | "Epoch: 176\n", 868 | "[TRAIN] Acc: 100.000\n", 869 | "[TEST] Acc: 60.737\n", 870 | "Epoch: 177\n", 871 | "[TRAIN] Acc: 100.000\n", 872 | "[TEST] Acc: 60.487\n", 873 | "Epoch: 178\n", 874 | "[TRAIN] Acc: 100.000\n", 875 | "[TEST] Acc: 60.737\n", 876 | "Epoch: 179\n", 877 | "[TRAIN] Acc: 100.000\n", 878 | "[TEST] Acc: 60.763\n", 879 | "Epoch: 180\n", 880 | "[TRAIN] Acc: 100.000\n", 881 | "[TEST] Acc: 60.750\n", 882 | "Epoch: 181\n", 883 | "[TRAIN] Acc: 100.000\n", 884 | "[TEST] Acc: 60.700\n", 885 | "Epoch: 182\n", 886 | "[TRAIN] Acc: 100.000\n", 887 | "[TEST] Acc: 60.600\n", 888 | "Epoch: 183\n", 889 | "[TRAIN] Acc: 100.000\n", 890 | "[TEST] Acc: 60.737\n", 891 | "Epoch: 184\n", 892 | "[TRAIN] Acc: 100.000\n", 893 | "[TEST] Acc: 60.737\n", 894 | "Epoch: 185\n", 895 | "[TRAIN] Acc: 100.000\n", 896 | "[TEST] Acc: 60.650\n", 897 | "Epoch: 186\n", 898 | "[TRAIN] Acc: 100.000\n", 899 | "[TEST] Acc: 60.750\n", 900 | "Epoch: 187\n", 901 | "[TRAIN] Acc: 100.000\n", 902 | "[TEST] Acc: 60.788\n", 903 | "Epoch: 188\n", 904 | "[TRAIN] Acc: 100.000\n", 905 | "[TEST] Acc: 60.763\n", 906 | "Epoch: 189\n", 907 | "[TRAIN] Acc: 100.000\n", 908 | "[TEST] Acc: 60.600\n", 909 | "Epoch: 190\n", 910 | "[TRAIN] Acc: 100.000\n", 911 | "[TEST] Acc: 60.650\n", 912 | "Epoch: 191\n", 913 | "[TRAIN] Acc: 100.000\n", 914 | "[TEST] Acc: 60.688\n", 915 | "Epoch: 192\n", 916 | "[TRAIN] Acc: 100.000\n", 917 | "[TEST] Acc: 60.800\n", 918 | "Epoch: 193\n", 919 | "[TRAIN] Acc: 100.000\n", 920 | "[TEST] Acc: 60.675\n", 921 | "Epoch: 194\n", 922 | "[TRAIN] Acc: 100.000\n", 923 | "[TEST] Acc: 60.763\n", 924 | "Epoch: 195\n", 925 | "[TRAIN] Acc: 100.000\n", 926 | "[TEST] Acc: 60.763\n", 927 | "Epoch: 196\n", 928 | "[TRAIN] Acc: 100.000\n", 929 | "[TEST] Acc: 60.800\n", 930 | "Epoch: 197\n", 931 | "[TRAIN] Acc: 100.000\n", 932 | "[TEST] Acc: 60.500\n", 933 | "Epoch: 198\n", 934 | "[TRAIN] Acc: 100.000\n", 935 | "[TEST] Acc: 60.750\n", 936 | "Epoch: 199\n", 937 | "[TRAIN] Acc: 100.000\n", 938 | "[TEST] Acc: 60.913\n" 939 | ] 940 | } 941 | ], 942 | "source": [ 943 | "for epoch in epochs:\n", 944 | " optimizer = Adam(net.parameters(), lr=lr)\n", 945 | " for _ in range(epoch):\n", 946 | " train(count)\n", 947 | " test(count)\n", 948 | " count += 1\n", 949 | " lr /= 10" 950 | ] 951 | }, 952 | { 953 | "cell_type": "code", 954 | "execution_count": 2, 955 | "metadata": {}, 956 | "outputs": [], 957 | "source": [ 958 | "lr=0.01\n", 959 | "data='cifar100'\n", 960 | "root='./data/'\n", 961 | "model='vgg'\n", 962 | "model_out='./checkpoint/cifar100_vgg_ReLU.pth'\n", 963 | "resume = False" 964 | ] 965 | }, 966 | { 967 | "cell_type": "code", 968 | "execution_count": 12, 969 | "metadata": {}, 970 | "outputs": [ 971 | { 972 | 
"name": "stdout", 973 | "output_type": "stream", 974 | "text": [ 975 | "Epoch: 0\n", 976 | "[TRAIN] Acc: 1.368\n", 977 | "[TEST] Acc: 2.140\n", 978 | "Epoch: 1\n", 979 | "[TRAIN] Acc: 2.760\n", 980 | "[TEST] Acc: 4.120\n", 981 | "Epoch: 2\n", 982 | "[TRAIN] Acc: 5.608\n", 983 | "[TEST] Acc: 5.680\n", 984 | "Epoch: 3\n", 985 | "[TRAIN] Acc: 8.578\n", 986 | "[TEST] Acc: 6.810\n", 987 | "Epoch: 4\n", 988 | "[TRAIN] Acc: 12.042\n", 989 | "[TEST] Acc: 12.300\n", 990 | "Epoch: 5\n", 991 | "[TRAIN] Acc: 16.670\n", 992 | "[TEST] Acc: 15.630\n", 993 | "Epoch: 6\n", 994 | "[TRAIN] Acc: 23.162\n", 995 | "[TEST] Acc: 19.630\n", 996 | "Epoch: 7\n", 997 | "[TRAIN] Acc: 28.002\n", 998 | "[TEST] Acc: 27.200\n", 999 | "Epoch: 8\n", 1000 | "[TRAIN] Acc: 33.262\n", 1001 | "[TEST] Acc: 27.840\n", 1002 | "Epoch: 9\n", 1003 | "[TRAIN] Acc: 37.922\n", 1004 | "[TEST] Acc: 30.170\n", 1005 | "Epoch: 10\n", 1006 | "[TRAIN] Acc: 42.210\n", 1007 | "[TEST] Acc: 32.640\n", 1008 | "Epoch: 11\n", 1009 | "[TRAIN] Acc: 46.446\n", 1010 | "[TEST] Acc: 35.670\n", 1011 | "Epoch: 12\n", 1012 | "[TRAIN] Acc: 50.478\n", 1013 | "[TEST] Acc: 38.810\n", 1014 | "Epoch: 13\n", 1015 | "[TRAIN] Acc: 54.634\n", 1016 | "[TEST] Acc: 38.920\n", 1017 | "Epoch: 14\n", 1018 | "[TRAIN] Acc: 58.306\n", 1019 | "[TEST] Acc: 39.060\n", 1020 | "Epoch: 15\n", 1021 | "[TRAIN] Acc: 61.852\n", 1022 | "[TEST] Acc: 41.480\n", 1023 | "Epoch: 16\n", 1024 | "[TRAIN] Acc: 65.646\n", 1025 | "[TEST] Acc: 42.990\n", 1026 | "Epoch: 17\n", 1027 | "[TRAIN] Acc: 69.024\n", 1028 | "[TEST] Acc: 41.180\n", 1029 | "Epoch: 18\n", 1030 | "[TRAIN] Acc: 72.144\n", 1031 | "[TEST] Acc: 44.390\n", 1032 | "Epoch: 19\n", 1033 | "[TRAIN] Acc: 75.098\n", 1034 | "[TEST] Acc: 44.820\n", 1035 | "Epoch: 20\n", 1036 | "[TRAIN] Acc: 78.286\n", 1037 | "[TEST] Acc: 41.610\n", 1038 | "Epoch: 21\n", 1039 | "[TRAIN] Acc: 80.022\n", 1040 | "[TEST] Acc: 45.260\n", 1041 | "Epoch: 22\n", 1042 | "[TRAIN] Acc: 82.516\n", 1043 | "[TEST] Acc: 45.280\n", 1044 | "Epoch: 23\n", 1045 | "[TRAIN] Acc: 84.522\n", 1046 | "[TEST] Acc: 44.260\n", 1047 | "Epoch: 24\n", 1048 | "[TRAIN] Acc: 86.320\n", 1049 | "[TEST] Acc: 45.860\n", 1050 | "Epoch: 25\n", 1051 | "[TRAIN] Acc: 87.398\n", 1052 | "[TEST] Acc: 44.110\n", 1053 | "Epoch: 26\n", 1054 | "[TRAIN] Acc: 88.418\n", 1055 | "[TEST] Acc: 45.950\n", 1056 | "Epoch: 27\n", 1057 | "[TRAIN] Acc: 89.944\n", 1058 | "[TEST] Acc: 44.060\n", 1059 | "Epoch: 28\n", 1060 | "[TRAIN] Acc: 90.550\n", 1061 | "[TEST] Acc: 46.000\n", 1062 | "Epoch: 29\n", 1063 | "[TRAIN] Acc: 91.666\n", 1064 | "[TEST] Acc: 46.130\n", 1065 | "Epoch: 30\n", 1066 | "[TRAIN] Acc: 91.758\n", 1067 | "[TEST] Acc: 47.080\n", 1068 | "Epoch: 31\n", 1069 | "[TRAIN] Acc: 92.146\n", 1070 | "[TEST] Acc: 44.580\n", 1071 | "Epoch: 32\n", 1072 | "[TRAIN] Acc: 93.332\n", 1073 | "[TEST] Acc: 46.770\n", 1074 | "Epoch: 33\n", 1075 | "[TRAIN] Acc: 93.428\n", 1076 | "[TEST] Acc: 45.930\n", 1077 | "Epoch: 34\n", 1078 | "[TRAIN] Acc: 93.750\n", 1079 | "[TEST] Acc: 45.700\n", 1080 | "Epoch: 35\n", 1081 | "[TRAIN] Acc: 93.674\n", 1082 | "[TEST] Acc: 45.920\n", 1083 | "Epoch: 36\n", 1084 | "[TRAIN] Acc: 93.838\n", 1085 | "[TEST] Acc: 46.050\n", 1086 | "Epoch: 37\n", 1087 | "[TRAIN] Acc: 94.356\n", 1088 | "[TEST] Acc: 47.170\n", 1089 | "Epoch: 38\n", 1090 | "[TRAIN] Acc: 94.646\n", 1091 | "[TEST] Acc: 46.360\n", 1092 | "Epoch: 39\n", 1093 | "[TRAIN] Acc: 95.138\n", 1094 | "[TEST] Acc: 46.480\n", 1095 | "Epoch: 40\n", 1096 | "[TRAIN] Acc: 94.976\n", 1097 | "[TEST] Acc: 44.330\n", 1098 | "Epoch: 41\n", 1099 | "[TRAIN] Acc: 
95.432\n", 1100 | "[TEST] Acc: 46.040\n", 1101 | "Epoch: 42\n", 1102 | "[TRAIN] Acc: 94.842\n", 1103 | "[TEST] Acc: 45.940\n", 1104 | "Epoch: 43\n", 1105 | "[TRAIN] Acc: 95.722\n", 1106 | "[TEST] Acc: 47.120\n", 1107 | "Epoch: 44\n", 1108 | "[TRAIN] Acc: 95.536\n", 1109 | "[TEST] Acc: 46.750\n", 1110 | "Epoch: 45\n", 1111 | "[TRAIN] Acc: 95.678\n", 1112 | "[TEST] Acc: 45.530\n", 1113 | "Epoch: 46\n", 1114 | "[TRAIN] Acc: 95.576\n", 1115 | "[TEST] Acc: 46.480\n", 1116 | "Epoch: 47\n", 1117 | "[TRAIN] Acc: 96.134\n", 1118 | "[TEST] Acc: 47.770\n", 1119 | "Epoch: 48\n", 1120 | "[TRAIN] Acc: 95.662\n", 1121 | "[TEST] Acc: 47.220\n", 1122 | "Epoch: 49\n", 1123 | "[TRAIN] Acc: 96.006\n", 1124 | "[TEST] Acc: 47.680\n", 1125 | "Epoch: 50\n", 1126 | "[TRAIN] Acc: 98.940\n", 1127 | "[TEST] Acc: 50.570\n", 1128 | "Epoch: 51\n", 1129 | "[TRAIN] Acc: 99.756\n", 1130 | "[TEST] Acc: 50.750\n", 1131 | "Epoch: 52\n", 1132 | "[TRAIN] Acc: 99.910\n", 1133 | "[TEST] Acc: 50.870\n", 1134 | "Epoch: 53\n", 1135 | "[TRAIN] Acc: 99.908\n", 1136 | "[TEST] Acc: 50.460\n", 1137 | "Epoch: 54\n", 1138 | "[TRAIN] Acc: 99.932\n", 1139 | "[TEST] Acc: 50.870\n", 1140 | "Epoch: 55\n", 1141 | "[TRAIN] Acc: 99.942\n", 1142 | "[TEST] Acc: 51.020\n", 1143 | "Epoch: 56\n", 1144 | "[TRAIN] Acc: 99.954\n", 1145 | "[TEST] Acc: 50.830\n", 1146 | "Epoch: 57\n", 1147 | "[TRAIN] Acc: 99.968\n", 1148 | "[TEST] Acc: 50.850\n", 1149 | "Epoch: 58\n", 1150 | "[TRAIN] Acc: 99.960\n", 1151 | "[TEST] Acc: 50.780\n", 1152 | "Epoch: 59\n", 1153 | "[TRAIN] Acc: 99.950\n", 1154 | "[TEST] Acc: 50.890\n", 1155 | "Epoch: 60\n", 1156 | "[TRAIN] Acc: 99.952\n", 1157 | "[TEST] Acc: 50.740\n", 1158 | "Epoch: 61\n", 1159 | "[TRAIN] Acc: 99.950\n", 1160 | "[TEST] Acc: 50.890\n", 1161 | "Epoch: 62\n", 1162 | "[TRAIN] Acc: 99.946\n", 1163 | "[TEST] Acc: 50.320\n", 1164 | "Epoch: 63\n", 1165 | "[TRAIN] Acc: 99.946\n", 1166 | "[TEST] Acc: 50.360\n", 1167 | "Epoch: 64\n", 1168 | "[TRAIN] Acc: 99.952\n", 1169 | "[TEST] Acc: 50.320\n", 1170 | "Epoch: 65\n", 1171 | "[TRAIN] Acc: 99.958\n", 1172 | "[TEST] Acc: 50.520\n", 1173 | "Epoch: 66\n", 1174 | "[TRAIN] Acc: 99.958\n", 1175 | "[TEST] Acc: 50.570\n", 1176 | "Epoch: 67\n", 1177 | "[TRAIN] Acc: 99.916\n", 1178 | "[TEST] Acc: 50.170\n", 1179 | "Epoch: 68\n", 1180 | "[TRAIN] Acc: 99.948\n", 1181 | "[TEST] Acc: 50.550\n", 1182 | "Epoch: 69\n", 1183 | "[TRAIN] Acc: 99.934\n", 1184 | "[TEST] Acc: 50.110\n", 1185 | "Epoch: 70\n", 1186 | "[TRAIN] Acc: 99.906\n", 1187 | "[TEST] Acc: 50.440\n", 1188 | "Epoch: 71\n", 1189 | "[TRAIN] Acc: 99.878\n", 1190 | "[TEST] Acc: 50.380\n", 1191 | "Epoch: 72\n", 1192 | "[TRAIN] Acc: 99.904\n", 1193 | "[TEST] Acc: 50.750\n", 1194 | "Epoch: 73\n", 1195 | "[TRAIN] Acc: 99.940\n", 1196 | "[TEST] Acc: 50.900\n", 1197 | "Epoch: 74\n", 1198 | "[TRAIN] Acc: 99.964\n", 1199 | "[TEST] Acc: 50.530\n", 1200 | "Epoch: 75\n", 1201 | "[TRAIN] Acc: 99.958\n", 1202 | "[TEST] Acc: 50.690\n", 1203 | "Epoch: 76\n", 1204 | "[TRAIN] Acc: 99.940\n", 1205 | "[TEST] Acc: 50.710\n", 1206 | "Epoch: 77\n", 1207 | "[TRAIN] Acc: 99.918\n", 1208 | "[TEST] Acc: 50.690\n", 1209 | "Epoch: 78\n", 1210 | "[TRAIN] Acc: 99.944\n", 1211 | "[TEST] Acc: 50.540\n", 1212 | "Epoch: 79\n", 1213 | "[TRAIN] Acc: 99.932\n", 1214 | "[TEST] Acc: 50.200\n", 1215 | "Epoch: 80\n", 1216 | "[TRAIN] Acc: 99.948\n", 1217 | "[TEST] Acc: 50.370\n", 1218 | "Epoch: 81\n", 1219 | "[TRAIN] Acc: 99.850\n", 1220 | "[TEST] Acc: 50.420\n", 1221 | "Epoch: 82\n", 1222 | "[TRAIN] Acc: 99.902\n", 1223 | "[TEST] Acc: 50.370\n", 1224 | "Epoch: 83\n", 1225 
| "[TRAIN] Acc: 99.938\n", 1226 | "[TEST] Acc: 50.600\n", 1227 | "Epoch: 84\n", 1228 | "[TRAIN] Acc: 99.938\n", 1229 | "[TEST] Acc: 50.860\n", 1230 | "Epoch: 85\n", 1231 | "[TRAIN] Acc: 99.954\n", 1232 | "[TEST] Acc: 50.900\n", 1233 | "Epoch: 86\n", 1234 | "[TRAIN] Acc: 99.932\n", 1235 | "[TEST] Acc: 50.670\n", 1236 | "Epoch: 87\n", 1237 | "[TRAIN] Acc: 99.878\n", 1238 | "[TEST] Acc: 50.610\n", 1239 | "Epoch: 88\n", 1240 | "[TRAIN] Acc: 99.944\n", 1241 | "[TEST] Acc: 50.440\n", 1242 | "Epoch: 89\n", 1243 | "[TRAIN] Acc: 99.936\n", 1244 | "[TEST] Acc: 50.450\n", 1245 | "Epoch: 90\n", 1246 | "[TRAIN] Acc: 99.946\n", 1247 | "[TEST] Acc: 50.840\n", 1248 | "Epoch: 91\n", 1249 | "[TRAIN] Acc: 99.952\n", 1250 | "[TEST] Acc: 50.830\n", 1251 | "Epoch: 92\n", 1252 | "[TRAIN] Acc: 99.942\n", 1253 | "[TEST] Acc: 50.440\n", 1254 | "Epoch: 93\n", 1255 | "[TRAIN] Acc: 99.938\n", 1256 | "[TEST] Acc: 50.830\n", 1257 | "Epoch: 94\n", 1258 | "[TRAIN] Acc: 99.958\n", 1259 | "[TEST] Acc: 50.770\n", 1260 | "Epoch: 95\n", 1261 | "[TRAIN] Acc: 99.950\n", 1262 | "[TEST] Acc: 50.950\n", 1263 | "Epoch: 96\n", 1264 | "[TRAIN] Acc: 99.958\n", 1265 | "[TEST] Acc: 50.480\n", 1266 | "Epoch: 97\n", 1267 | "[TRAIN] Acc: 99.936\n", 1268 | "[TEST] Acc: 50.550\n", 1269 | "Epoch: 98\n", 1270 | "[TRAIN] Acc: 99.936\n", 1271 | "[TEST] Acc: 50.750\n", 1272 | "Epoch: 99\n", 1273 | "[TRAIN] Acc: 99.912\n", 1274 | "[TEST] Acc: 50.550\n", 1275 | "Epoch: 100\n", 1276 | "[TRAIN] Acc: 99.948\n", 1277 | "[TEST] Acc: 50.860\n", 1278 | "Epoch: 101\n", 1279 | "[TRAIN] Acc: 99.974\n", 1280 | "[TEST] Acc: 50.920\n", 1281 | "Epoch: 102\n", 1282 | "[TRAIN] Acc: 99.970\n", 1283 | "[TEST] Acc: 51.060\n", 1284 | "Epoch: 103\n", 1285 | "[TRAIN] Acc: 99.978\n", 1286 | "[TEST] Acc: 51.010\n", 1287 | "Epoch: 104\n", 1288 | "[TRAIN] Acc: 99.974\n", 1289 | "[TEST] Acc: 51.160\n", 1290 | "Epoch: 105\n", 1291 | "[TRAIN] Acc: 99.976\n", 1292 | "[TEST] Acc: 51.020\n", 1293 | "Epoch: 106\n", 1294 | "[TRAIN] Acc: 99.968\n", 1295 | "[TEST] Acc: 50.810\n", 1296 | "Epoch: 107\n", 1297 | "[TRAIN] Acc: 99.972\n", 1298 | "[TEST] Acc: 50.970\n", 1299 | "Epoch: 108\n", 1300 | "[TRAIN] Acc: 99.972\n", 1301 | "[TEST] Acc: 50.980\n", 1302 | "Epoch: 109\n", 1303 | "[TRAIN] Acc: 99.964\n", 1304 | "[TEST] Acc: 51.010\n", 1305 | "Epoch: 110\n", 1306 | "[TRAIN] Acc: 99.968\n", 1307 | "[TEST] Acc: 51.080\n", 1308 | "Epoch: 111\n", 1309 | "[TRAIN] Acc: 99.976\n", 1310 | "[TEST] Acc: 51.020\n", 1311 | "Epoch: 112\n", 1312 | "[TRAIN] Acc: 99.964\n", 1313 | "[TEST] Acc: 51.080\n", 1314 | "Epoch: 113\n", 1315 | "[TRAIN] Acc: 99.978\n", 1316 | "[TEST] Acc: 51.150\n", 1317 | "Epoch: 114\n", 1318 | "[TRAIN] Acc: 99.976\n", 1319 | "[TEST] Acc: 51.130\n", 1320 | "Epoch: 115\n", 1321 | "[TRAIN] Acc: 99.978\n", 1322 | "[TEST] Acc: 51.260\n", 1323 | "Epoch: 116\n", 1324 | "[TRAIN] Acc: 99.974\n", 1325 | "[TEST] Acc: 51.030\n", 1326 | "Epoch: 117\n", 1327 | "[TRAIN] Acc: 99.978\n", 1328 | "[TEST] Acc: 51.200\n", 1329 | "Epoch: 118\n", 1330 | "[TRAIN] Acc: 99.970\n", 1331 | "[TEST] Acc: 51.080\n", 1332 | "Epoch: 119\n", 1333 | "[TRAIN] Acc: 99.984\n", 1334 | "[TEST] Acc: 51.120\n", 1335 | "Epoch: 120\n", 1336 | "[TRAIN] Acc: 99.978\n", 1337 | "[TEST] Acc: 51.050\n", 1338 | "Epoch: 121\n", 1339 | "[TRAIN] Acc: 99.980\n", 1340 | "[TEST] Acc: 51.080\n", 1341 | "Epoch: 122\n", 1342 | "[TRAIN] Acc: 99.968\n", 1343 | "[TEST] Acc: 51.220\n", 1344 | "Epoch: 123\n", 1345 | "[TRAIN] Acc: 99.980\n", 1346 | "[TEST] Acc: 51.090\n", 1347 | "Epoch: 124\n", 1348 | "[TRAIN] Acc: 99.984\n", 1349 | "[TEST] 
Acc: 51.050\n", 1350 | "Epoch: 125\n", 1351 | "[TRAIN] Acc: 99.982\n", 1352 | "[TEST] Acc: 50.940\n", 1353 | "Epoch: 126\n", 1354 | "[TRAIN] Acc: 99.980\n", 1355 | "[TEST] Acc: 51.250\n", 1356 | "Epoch: 127\n", 1357 | "[TRAIN] Acc: 99.978\n", 1358 | "[TEST] Acc: 51.180\n", 1359 | "Epoch: 128\n", 1360 | "[TRAIN] Acc: 99.976\n", 1361 | "[TEST] Acc: 51.360\n", 1362 | "Epoch: 129\n", 1363 | "[TRAIN] Acc: 99.980\n", 1364 | "[TEST] Acc: 51.150\n", 1365 | "Epoch: 130\n", 1366 | "[TRAIN] Acc: 99.982\n", 1367 | "[TEST] Acc: 51.290\n", 1368 | "Epoch: 131\n", 1369 | "[TRAIN] Acc: 99.974\n", 1370 | "[TEST] Acc: 51.450\n", 1371 | "Epoch: 132\n", 1372 | "[TRAIN] Acc: 99.980\n", 1373 | "[TEST] Acc: 51.360\n", 1374 | "Epoch: 133\n", 1375 | "[TRAIN] Acc: 99.982\n", 1376 | "[TEST] Acc: 51.370\n", 1377 | "Epoch: 134\n", 1378 | "[TRAIN] Acc: 99.984\n", 1379 | "[TEST] Acc: 51.200\n", 1380 | "Epoch: 135\n", 1381 | "[TRAIN] Acc: 99.970\n", 1382 | "[TEST] Acc: 51.250\n", 1383 | "Epoch: 136\n", 1384 | "[TRAIN] Acc: 99.982\n", 1385 | "[TEST] Acc: 51.380\n", 1386 | "Epoch: 137\n", 1387 | "[TRAIN] Acc: 99.976\n", 1388 | "[TEST] Acc: 51.190\n", 1389 | "Epoch: 138\n", 1390 | "[TRAIN] Acc: 99.976\n", 1391 | "[TEST] Acc: 51.290\n", 1392 | "Epoch: 139\n", 1393 | "[TRAIN] Acc: 99.976\n", 1394 | "[TEST] Acc: 51.000\n", 1395 | "Epoch: 140\n", 1396 | "[TRAIN] Acc: 99.982\n", 1397 | "[TEST] Acc: 51.200\n", 1398 | "Epoch: 141\n", 1399 | "[TRAIN] Acc: 99.978\n", 1400 | "[TEST] Acc: 51.150\n", 1401 | "Epoch: 142\n", 1402 | "[TRAIN] Acc: 99.972\n", 1403 | "[TEST] Acc: 51.020\n", 1404 | "Epoch: 143\n", 1405 | "[TRAIN] Acc: 99.982\n", 1406 | "[TEST] Acc: 51.230\n", 1407 | "Epoch: 144\n", 1408 | "[TRAIN] Acc: 99.978\n", 1409 | "[TEST] Acc: 51.260\n", 1410 | "Epoch: 145\n", 1411 | "[TRAIN] Acc: 99.984\n", 1412 | "[TEST] Acc: 51.180\n", 1413 | "Epoch: 146\n", 1414 | "[TRAIN] Acc: 99.986\n", 1415 | "[TEST] Acc: 51.330\n", 1416 | "Epoch: 147\n", 1417 | "[TRAIN] Acc: 99.978\n", 1418 | "[TEST] Acc: 51.280\n", 1419 | "Epoch: 148\n", 1420 | "[TRAIN] Acc: 99.964\n", 1421 | "[TEST] Acc: 51.010\n", 1422 | "Epoch: 149\n", 1423 | "[TRAIN] Acc: 99.976\n", 1424 | "[TEST] Acc: 51.390\n", 1425 | "Epoch: 150\n", 1426 | "[TRAIN] Acc: 99.984\n", 1427 | "[TEST] Acc: 51.160\n", 1428 | "Epoch: 151\n", 1429 | "[TRAIN] Acc: 99.984\n", 1430 | "[TEST] Acc: 51.100\n", 1431 | "Epoch: 152\n", 1432 | "[TRAIN] Acc: 99.986\n", 1433 | "[TEST] Acc: 51.280\n", 1434 | "Epoch: 153\n", 1435 | "[TRAIN] Acc: 99.970\n", 1436 | "[TEST] Acc: 51.310\n", 1437 | "Epoch: 154\n", 1438 | "[TRAIN] Acc: 99.982\n", 1439 | "[TEST] Acc: 51.290\n", 1440 | "Epoch: 155\n", 1441 | "[TRAIN] Acc: 99.988\n", 1442 | "[TEST] Acc: 51.410\n", 1443 | "Epoch: 156\n", 1444 | "[TRAIN] Acc: 99.978\n", 1445 | "[TEST] Acc: 51.300\n", 1446 | "Epoch: 157\n", 1447 | "[TRAIN] Acc: 99.978\n", 1448 | "[TEST] Acc: 51.180\n", 1449 | "Epoch: 158\n", 1450 | "[TRAIN] Acc: 99.980\n", 1451 | "[TEST] Acc: 51.330\n", 1452 | "Epoch: 159\n", 1453 | "[TRAIN] Acc: 99.974\n", 1454 | "[TEST] Acc: 51.290\n", 1455 | "Epoch: 160\n", 1456 | "[TRAIN] Acc: 99.978\n", 1457 | "[TEST] Acc: 51.250\n", 1458 | "Epoch: 161\n", 1459 | "[TRAIN] Acc: 99.980\n", 1460 | "[TEST] Acc: 51.250\n", 1461 | "Epoch: 162\n", 1462 | "[TRAIN] Acc: 99.976\n", 1463 | "[TEST] Acc: 51.310\n", 1464 | "Epoch: 163\n", 1465 | "[TRAIN] Acc: 99.978\n", 1466 | "[TEST] Acc: 51.410\n", 1467 | "Epoch: 164\n", 1468 | "[TRAIN] Acc: 99.984\n", 1469 | "[TEST] Acc: 51.280\n", 1470 | "Epoch: 165\n", 1471 | "[TRAIN] Acc: 99.978\n", 1472 | "[TEST] Acc: 51.400\n", 1473 | 
"Epoch: 166\n" 1474 | ] 1475 | }, 1476 | { 1477 | "name": "stdout", 1478 | "output_type": "stream", 1479 | "text": [ 1480 | "[TRAIN] Acc: 99.984\n", 1481 | "[TEST] Acc: 51.500\n", 1482 | "Epoch: 167\n", 1483 | "[TRAIN] Acc: 99.984\n", 1484 | "[TEST] Acc: 51.160\n", 1485 | "Epoch: 168\n", 1486 | "[TRAIN] Acc: 99.984\n", 1487 | "[TEST] Acc: 51.310\n", 1488 | "Epoch: 169\n", 1489 | "[TRAIN] Acc: 99.980\n", 1490 | "[TEST] Acc: 51.300\n", 1491 | "Epoch: 170\n", 1492 | "[TRAIN] Acc: 99.982\n", 1493 | "[TEST] Acc: 51.410\n", 1494 | "Epoch: 171\n", 1495 | "[TRAIN] Acc: 99.984\n", 1496 | "[TEST] Acc: 51.350\n", 1497 | "Epoch: 172\n", 1498 | "[TRAIN] Acc: 99.978\n", 1499 | "[TEST] Acc: 51.240\n", 1500 | "Epoch: 173\n", 1501 | "[TRAIN] Acc: 99.988\n", 1502 | "[TEST] Acc: 51.270\n", 1503 | "Epoch: 174\n", 1504 | "[TRAIN] Acc: 99.978\n", 1505 | "[TEST] Acc: 51.310\n", 1506 | "Epoch: 175\n", 1507 | "[TRAIN] Acc: 99.982\n", 1508 | "[TEST] Acc: 51.350\n", 1509 | "Epoch: 176\n", 1510 | "[TRAIN] Acc: 99.974\n", 1511 | "[TEST] Acc: 51.310\n", 1512 | "Epoch: 177\n", 1513 | "[TRAIN] Acc: 99.978\n", 1514 | "[TEST] Acc: 51.400\n", 1515 | "Epoch: 178\n", 1516 | "[TRAIN] Acc: 99.982\n", 1517 | "[TEST] Acc: 51.120\n", 1518 | "Epoch: 179\n", 1519 | "[TRAIN] Acc: 99.970\n", 1520 | "[TEST] Acc: 51.280\n", 1521 | "Epoch: 180\n", 1522 | "[TRAIN] Acc: 99.988\n", 1523 | "[TEST] Acc: 51.110\n", 1524 | "Epoch: 181\n", 1525 | "[TRAIN] Acc: 99.980\n", 1526 | "[TEST] Acc: 51.390\n", 1527 | "Epoch: 182\n", 1528 | "[TRAIN] Acc: 99.980\n", 1529 | "[TEST] Acc: 51.380\n", 1530 | "Epoch: 183\n", 1531 | "[TRAIN] Acc: 99.976\n", 1532 | "[TEST] Acc: 51.250\n", 1533 | "Epoch: 184\n", 1534 | "[TRAIN] Acc: 99.982\n", 1535 | "[TEST] Acc: 51.360\n", 1536 | "Epoch: 185\n", 1537 | "[TRAIN] Acc: 99.974\n", 1538 | "[TEST] Acc: 51.350\n", 1539 | "Epoch: 186\n", 1540 | "[TRAIN] Acc: 99.980\n", 1541 | "[TEST] Acc: 51.400\n", 1542 | "Epoch: 187\n", 1543 | "[TRAIN] Acc: 99.980\n", 1544 | "[TEST] Acc: 51.390\n", 1545 | "Epoch: 188\n", 1546 | "[TRAIN] Acc: 99.986\n", 1547 | "[TEST] Acc: 51.410\n", 1548 | "Epoch: 189\n", 1549 | "[TRAIN] Acc: 99.976\n", 1550 | "[TEST] Acc: 51.330\n", 1551 | "Epoch: 190\n", 1552 | "[TRAIN] Acc: 99.980\n", 1553 | "[TEST] Acc: 51.290\n", 1554 | "Epoch: 191\n", 1555 | "[TRAIN] Acc: 99.982\n", 1556 | "[TEST] Acc: 51.290\n", 1557 | "Epoch: 192\n", 1558 | "[TRAIN] Acc: 99.988\n", 1559 | "[TEST] Acc: 51.210\n", 1560 | "Epoch: 193\n", 1561 | "[TRAIN] Acc: 99.974\n", 1562 | "[TEST] Acc: 51.490\n", 1563 | "Epoch: 194\n", 1564 | "[TRAIN] Acc: 99.978\n", 1565 | "[TEST] Acc: 51.220\n", 1566 | "Epoch: 195\n", 1567 | "[TRAIN] Acc: 99.972\n", 1568 | "[TEST] Acc: 51.370\n", 1569 | "Epoch: 196\n", 1570 | "[TRAIN] Acc: 99.982\n", 1571 | "[TEST] Acc: 51.360\n", 1572 | "Epoch: 197\n", 1573 | "[TRAIN] Acc: 99.980\n", 1574 | "[TEST] Acc: 51.360\n", 1575 | "Epoch: 198\n", 1576 | "[TRAIN] Acc: 99.986\n", 1577 | "[TEST] Acc: 51.430\n", 1578 | "Epoch: 199\n", 1579 | "[TRAIN] Acc: 99.974\n", 1580 | "[TEST] Acc: 51.090\n" 1581 | ] 1582 | } 1583 | ], 1584 | "source": [ 1585 | "for epoch in epochs:\n", 1586 | " optimizer = Adam(net.parameters(), lr=lr)\n", 1587 | " for _ in range(epoch):\n", 1588 | " train(count)\n", 1589 | " test(count)\n", 1590 | " count += 1\n", 1591 | " lr /= 10" 1592 | ] 1593 | }, 1594 | { 1595 | "cell_type": "code", 1596 | "execution_count": null, 1597 | "metadata": {}, 1598 | "outputs": [], 1599 | "source": [] 1600 | } 1601 | ], 1602 | "metadata": { 1603 | "kernelspec": { 1604 | "display_name": "Python (bayesian)", 1605 | 
"language": "python", 1606 | "name": "bayesian" 1607 | }, 1608 | "language_info": { 1609 | "codemirror_mode": { 1610 | "name": "ipython", 1611 | "version": 3 1612 | }, 1613 | "file_extension": ".py", 1614 | "mimetype": "text/x-python", 1615 | "name": "python", 1616 | "nbconvert_exporter": "python", 1617 | "pygments_lexer": "ipython3", 1618 | "version": "3.7.3" 1619 | } 1620 | }, 1621 | "nbformat": 4, 1622 | "nbformat_minor": 2 1623 | } 1624 | -------------------------------------------------------------------------------- /VGG_Sigmoid.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "import torch.nn as nn\n", 11 | "import torch.nn.functional as F\n", 12 | "import torch.backends.cudnn as cudnn\n", 13 | "from torch.optim import Adam, SGD\n", 14 | "import torchvision\n", 15 | "import torchvision.transforms as transforms\n", 16 | "\n", 17 | "import sys, os, math\n", 18 | "import argparse" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 2, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "lr=0.01\n", 28 | "data='cifar10'\n", 29 | "root='./data/'\n", 30 | "model='vgg'\n", 31 | "model_out='./checkpoint/cifar10_vgg_Sigmoid.pth'\n", 32 | "resume = False" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 3, 38 | "metadata": {}, 39 | "outputs": [ 40 | { 41 | "name": "stdout", 42 | "output_type": "stream", 43 | "text": [ 44 | "Files already downloaded and verified\n", 45 | "Files already downloaded and verified\n" 46 | ] 47 | } 48 | ], 49 | "source": [ 50 | "if data == 'cifar10':\n", 51 | " nclass = 10\n", 52 | " img_width = 32\n", 53 | " transform_train = transforms.Compose([\n", 54 | " transforms.RandomCrop(32, padding=4),\n", 55 | " transforms.RandomHorizontalFlip(),\n", 56 | " transforms.ToTensor(),\n", 57 | " ])\n", 58 | " transform_test = transforms.Compose([\n", 59 | " transforms.ToTensor(),\n", 60 | " ])\n", 61 | " trainset = torchvision.datasets.CIFAR10(root=root, train=True, download=True, transform=transform_train)\n", 62 | " trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)\n", 63 | " testset = torchvision.datasets.CIFAR10(root=root, train=False, download=True, transform=transform_test)\n", 64 | " testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": 4, 70 | "metadata": {}, 71 | "outputs": [ 72 | { 73 | "data": { 74 | "text/plain": [ 75 | "DataParallel(\n", 76 | " (module): VGG_Sigmoid(\n", 77 | " (features): Sequential(\n", 78 | " (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 79 | " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 80 | " (2): Sigmoid()\n", 81 | " (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 82 | " (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 83 | " (5): Sigmoid()\n", 84 | " (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 85 | " (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 86 | " (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 87 | " (9): Sigmoid()\n", 88 | " (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 
1), padding=(1, 1))\n", 89 | " (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 90 | " (12): Sigmoid()\n", 91 | " (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 92 | " (14): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 93 | " (15): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 94 | " (16): Sigmoid()\n", 95 | " (17): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 96 | " (18): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 97 | " (19): Sigmoid()\n", 98 | " (20): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 99 | " (21): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 100 | " (22): Sigmoid()\n", 101 | " (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 102 | " (24): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 103 | " (25): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 104 | " (26): Sigmoid()\n", 105 | " (27): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 106 | " (28): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 107 | " (29): Sigmoid()\n", 108 | " (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 109 | " (31): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 110 | " (32): Sigmoid()\n", 111 | " (33): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 112 | " (34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 113 | " (35): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 114 | " (36): Sigmoid()\n", 115 | " (37): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 116 | " (38): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 117 | " (39): Sigmoid()\n", 118 | " (40): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 119 | " (41): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 120 | " (42): Sigmoid()\n", 121 | " (43): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 122 | " (44): AvgPool2d(kernel_size=1, stride=1, padding=0)\n", 123 | " )\n", 124 | " (classifier): Linear(in_features=512, out_features=10, bias=True)\n", 125 | " )\n", 126 | ")" 127 | ] 128 | }, 129 | "execution_count": 4, 130 | "metadata": {}, 131 | "output_type": "execute_result" 132 | } 133 | ], 134 | "source": [ 135 | "if model == 'vgg':\n", 136 | " from models.vgg import VGG_Sigmoid\n", 137 | " net = nn.DataParallel(VGG_Sigmoid('VGG16', nclass, img_width=img_width).cuda())\n", 138 | " \n", 139 | "net" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": 5, 145 | "metadata": {}, 146 | "outputs": [], 147 | "source": [ 148 | "if resume:\n", 149 | " print(f'==> Resuming from {model_out}')\n", 150 | " net.load_state_dict(torch.load(model_out))" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": 6, 156 | "metadata": {}, 157 | "outputs": [], 158 | "source": [ 159 | "cudnn.benchmark = True" 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": 7, 165 | "metadata": {}, 166 | "outputs": [], 167 | "source": [ 168 | "criterion = 
nn.CrossEntropyLoss()" 169 | ] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "execution_count": 8, 174 | "metadata": {}, 175 | "outputs": [], 176 | "source": [ 177 | "def train(epoch):\n", 178 | " print('Epoch: %d' % epoch)\n", 179 | " net.train()\n", 180 | " train_loss = 0\n", 181 | " correct = 0\n", 182 | " total = 0\n", 183 | " for batch_idx, (inputs, targets) in enumerate(trainloader):\n", 184 | " inputs, targets = inputs.cuda(), targets.cuda()\n", 185 | " optimizer.zero_grad()\n", 186 | " outputs, _ = net(inputs)\n", 187 | " loss = criterion(outputs, targets)\n", 188 | " loss.backward()\n", 189 | " optimizer.step()\n", 190 | " pred = torch.max(outputs, dim=1)[1]\n", 191 | " correct += torch.sum(pred.eq(targets)).item()\n", 192 | " total += targets.numel()\n", 193 | " print(f'[TRAIN] Acc: {100.*correct/total:.3f}')" 194 | ] 195 | }, 196 | { 197 | "cell_type": "code", 198 | "execution_count": 9, 199 | "metadata": {}, 200 | "outputs": [], 201 | "source": [ 202 | "def test(epoch):\n", 203 | " net.eval()\n", 204 | " test_loss = 0\n", 205 | " correct = 0\n", 206 | " total = 0\n", 207 | " with torch.no_grad():\n", 208 | " for batch_idx, (inputs, targets) in enumerate(testloader):\n", 209 | " inputs, targets = inputs.cuda(), targets.cuda()\n", 210 | " outputs, _ = net(inputs)\n", 211 | " loss = criterion(outputs, targets)\n", 212 | " test_loss += loss.item()\n", 213 | " _, predicted = outputs.max(1)\n", 214 | " total += targets.size(0)\n", 215 | " correct += predicted.eq(targets).sum().item()\n", 216 | " print(f'[TEST] Acc: {100.*correct/total:.3f}')\n", 217 | "\n", 218 | " # Save checkpoint after each epoch\n", 219 | " torch.save(net.state_dict(), model_out)" 220 | ] 221 | }, 222 | { 223 | "cell_type": "code", 224 | "execution_count": 10, 225 | "metadata": {}, 226 | "outputs": [], 227 | "source": [ 228 | "if data == 'cifar10':\n", 229 | " epochs = [30, 20, 10]" 230 | ] 231 | }, 232 | { 233 | "cell_type": "code", 234 | "execution_count": 11, 235 | "metadata": {}, 236 | "outputs": [], 237 | "source": [ 238 | "count = 0" 239 | ] 240 | }, 241 | { 242 | "cell_type": "code", 243 | "execution_count": 12, 244 | "metadata": {}, 245 | "outputs": [ 246 | { 247 | "name": "stdout", 248 | "output_type": "stream", 249 | "text": [ 250 | "Epoch: 0\n", 251 | "[TRAIN] Acc: 10.134\n", 252 | "[TEST] Acc: 10.000\n", 253 | "Epoch: 1\n", 254 | "[TRAIN] Acc: 9.892\n", 255 | "[TEST] Acc: 10.000\n", 256 | "Epoch: 2\n", 257 | "[TRAIN] Acc: 10.144\n", 258 | "[TEST] Acc: 10.000\n", 259 | "Epoch: 3\n", 260 | "[TRAIN] Acc: 10.064\n", 261 | "[TEST] Acc: 10.000\n", 262 | "Epoch: 4\n", 263 | "[TRAIN] Acc: 9.952\n", 264 | "[TEST] Acc: 10.000\n", 265 | "Epoch: 5\n", 266 | "[TRAIN] Acc: 9.806\n", 267 | "[TEST] Acc: 10.000\n", 268 | "Epoch: 6\n", 269 | "[TRAIN] Acc: 9.954\n", 270 | "[TEST] Acc: 10.000\n", 271 | "Epoch: 7\n", 272 | "[TRAIN] Acc: 10.056\n", 273 | "[TEST] Acc: 10.000\n", 274 | "Epoch: 8\n", 275 | "[TRAIN] Acc: 10.066\n", 276 | "[TEST] Acc: 10.000\n", 277 | "Epoch: 9\n", 278 | "[TRAIN] Acc: 9.846\n", 279 | "[TEST] Acc: 10.000\n", 280 | "Epoch: 10\n", 281 | "[TRAIN] Acc: 10.130\n", 282 | "[TEST] Acc: 10.000\n", 283 | "Epoch: 11\n", 284 | "[TRAIN] Acc: 9.782\n", 285 | "[TEST] Acc: 10.000\n", 286 | "Epoch: 12\n", 287 | "[TRAIN] Acc: 9.916\n", 288 | "[TEST] Acc: 10.000\n", 289 | "Epoch: 13\n", 290 | "[TRAIN] Acc: 9.894\n", 291 | "[TEST] Acc: 10.000\n", 292 | "Epoch: 14\n", 293 | "[TRAIN] Acc: 9.910\n", 294 | "[TEST] Acc: 10.000\n", 295 | "Epoch: 15\n", 296 | "[TRAIN] Acc: 9.946\n", 297 | "[TEST] Acc: 10.000\n", 298 | 
"Epoch: 16\n", 299 | "[TRAIN] Acc: 9.912\n", 300 | "[TEST] Acc: 10.000\n", 301 | "Epoch: 17\n", 302 | "[TRAIN] Acc: 9.986\n", 303 | "[TEST] Acc: 10.000\n", 304 | "Epoch: 18\n", 305 | "[TRAIN] Acc: 9.800\n", 306 | "[TEST] Acc: 10.000\n", 307 | "Epoch: 19\n", 308 | "[TRAIN] Acc: 9.830\n", 309 | "[TEST] Acc: 10.000\n", 310 | "Epoch: 20\n", 311 | "[TRAIN] Acc: 10.036\n", 312 | "[TEST] Acc: 10.000\n", 313 | "Epoch: 21\n", 314 | "[TRAIN] Acc: 9.836\n", 315 | "[TEST] Acc: 10.000\n", 316 | "Epoch: 22\n", 317 | "[TRAIN] Acc: 10.176\n", 318 | "[TEST] Acc: 10.000\n", 319 | "Epoch: 23\n", 320 | "[TRAIN] Acc: 9.904\n", 321 | "[TEST] Acc: 10.000\n", 322 | "Epoch: 24\n", 323 | "[TRAIN] Acc: 10.136\n", 324 | "[TEST] Acc: 10.000\n", 325 | "Epoch: 25\n", 326 | "[TRAIN] Acc: 9.880\n", 327 | "[TEST] Acc: 10.000\n", 328 | "Epoch: 26\n", 329 | "[TRAIN] Acc: 9.950\n", 330 | "[TEST] Acc: 10.000\n", 331 | "Epoch: 27\n", 332 | "[TRAIN] Acc: 9.626\n", 333 | "[TEST] Acc: 10.000\n", 334 | "Epoch: 28\n", 335 | "[TRAIN] Acc: 9.986\n", 336 | "[TEST] Acc: 10.000\n", 337 | "Epoch: 29\n", 338 | "[TRAIN] Acc: 9.922\n", 339 | "[TEST] Acc: 10.000\n", 340 | "Epoch: 30\n", 341 | "[TRAIN] Acc: 9.978\n", 342 | "[TEST] Acc: 10.000\n", 343 | "Epoch: 31\n", 344 | "[TRAIN] Acc: 9.856\n", 345 | "[TEST] Acc: 10.000\n", 346 | "Epoch: 32\n", 347 | "[TRAIN] Acc: 9.920\n", 348 | "[TEST] Acc: 10.000\n", 349 | "Epoch: 33\n", 350 | "[TRAIN] Acc: 10.054\n", 351 | "[TEST] Acc: 10.000\n", 352 | "Epoch: 34\n", 353 | "[TRAIN] Acc: 9.746\n", 354 | "[TEST] Acc: 10.000\n", 355 | "Epoch: 35\n", 356 | "[TRAIN] Acc: 9.708\n", 357 | "[TEST] Acc: 10.000\n", 358 | "Epoch: 36\n", 359 | "[TRAIN] Acc: 9.814\n", 360 | "[TEST] Acc: 10.000\n", 361 | "Epoch: 37\n", 362 | "[TRAIN] Acc: 9.636\n", 363 | "[TEST] Acc: 10.000\n", 364 | "Epoch: 38\n", 365 | "[TRAIN] Acc: 9.928\n", 366 | "[TEST] Acc: 10.000\n", 367 | "Epoch: 39\n", 368 | "[TRAIN] Acc: 9.980\n", 369 | "[TEST] Acc: 10.000\n", 370 | "Epoch: 40\n", 371 | "[TRAIN] Acc: 9.930\n", 372 | "[TEST] Acc: 10.000\n", 373 | "Epoch: 41\n", 374 | "[TRAIN] Acc: 9.902\n", 375 | "[TEST] Acc: 10.000\n", 376 | "Epoch: 42\n", 377 | "[TRAIN] Acc: 9.766\n", 378 | "[TEST] Acc: 10.000\n", 379 | "Epoch: 43\n", 380 | "[TRAIN] Acc: 9.844\n", 381 | "[TEST] Acc: 10.000\n", 382 | "Epoch: 44\n", 383 | "[TRAIN] Acc: 9.758\n", 384 | "[TEST] Acc: 10.000\n", 385 | "Epoch: 45\n", 386 | "[TRAIN] Acc: 9.728\n", 387 | "[TEST] Acc: 10.000\n", 388 | "Epoch: 46\n", 389 | "[TRAIN] Acc: 9.646\n", 390 | "[TEST] Acc: 10.000\n", 391 | "Epoch: 47\n", 392 | "[TRAIN] Acc: 9.922\n", 393 | "[TEST] Acc: 10.000\n", 394 | "Epoch: 48\n", 395 | "[TRAIN] Acc: 9.704\n", 396 | "[TEST] Acc: 10.000\n", 397 | "Epoch: 49\n", 398 | "[TRAIN] Acc: 10.010\n", 399 | "[TEST] Acc: 10.000\n", 400 | "Epoch: 50\n", 401 | "[TRAIN] Acc: 10.000\n", 402 | "[TEST] Acc: 10.000\n", 403 | "Epoch: 51\n", 404 | "[TRAIN] Acc: 10.000\n", 405 | "[TEST] Acc: 10.000\n", 406 | "Epoch: 52\n", 407 | "[TRAIN] Acc: 10.000\n", 408 | "[TEST] Acc: 10.000\n", 409 | "Epoch: 53\n", 410 | "[TRAIN] Acc: 10.000\n", 411 | "[TEST] Acc: 10.000\n", 412 | "Epoch: 54\n", 413 | "[TRAIN] Acc: 10.000\n", 414 | "[TEST] Acc: 10.000\n", 415 | "Epoch: 55\n", 416 | "[TRAIN] Acc: 10.000\n", 417 | "[TEST] Acc: 10.000\n", 418 | "Epoch: 56\n", 419 | "[TRAIN] Acc: 10.000\n", 420 | "[TEST] Acc: 10.000\n", 421 | "Epoch: 57\n", 422 | "[TRAIN] Acc: 10.000\n", 423 | "[TEST] Acc: 10.000\n", 424 | "Epoch: 58\n", 425 | "[TRAIN] Acc: 9.836\n", 426 | "[TEST] Acc: 10.000\n", 427 | "Epoch: 59\n", 428 | "[TRAIN] Acc: 9.944\n", 429 | "[TEST] 
Acc: 10.000\n" 430 | ] 431 | } 432 | ], 433 | "source": [ 434 | "for epoch in epochs:\n", 435 | " optimizer = Adam(net.parameters(), lr=lr)\n", 436 | " for _ in range(epoch):\n", 437 | " train(count)\n", 438 | " test(count)\n", 439 | " count += 1\n", 440 | " lr /= 10" 441 | ] 442 | }, 443 | { 444 | "cell_type": "code", 445 | "execution_count": null, 446 | "metadata": {}, 447 | "outputs": [], 448 | "source": [] 449 | } 450 | ], 451 | "metadata": { 452 | "kernelspec": { 453 | "display_name": "Python (bayesian)", 454 | "language": "python", 455 | "name": "bayesian" 456 | }, 457 | "language_info": { 458 | "codemirror_mode": { 459 | "name": "ipython", 460 | "version": 3 461 | }, 462 | "file_extension": ".py", 463 | "mimetype": "text/x-python", 464 | "name": "python", 465 | "nbconvert_exporter": "python", 466 | "pygments_lexer": "ipython3", 467 | "version": "3.7.3" 468 | } 469 | }, 470 | "nbformat": 4, 471 | "nbformat_minor": 2 472 | } 473 | -------------------------------------------------------------------------------- /Visualization/ComparisonResultsProbAct.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kumar-shridhar/ProbAct-Probabilistic-Activation-Function/aa5e2759f7d500aae0d93da3b6638492d07b6623/Visualization/ComparisonResultsProbAct.png -------------------------------------------------------------------------------- /Visualization/LayerWiseSigma.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kumar-shridhar/ProbAct-Probabilistic-Activation-Function/aa5e2759f7d500aae0d93da3b6638492d07b6623/Visualization/LayerWiseSigma.pdf -------------------------------------------------------------------------------- /Visualization/OneTrainableSigma.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kumar-shridhar/ProbAct-Probabilistic-Activation-Function/aa5e2759f7d500aae0d93da3b6638492d07b6623/Visualization/OneTrainableSigma.pdf -------------------------------------------------------------------------------- /Visualization/OverfittingCIFAR100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kumar-shridhar/ProbAct-Probabilistic-Activation-Function/aa5e2759f7d500aae0d93da3b6638492d07b6623/Visualization/OverfittingCIFAR100.png -------------------------------------------------------------------------------- /Visualization/ProbAct.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kumar-shridhar/ProbAct-Probabilistic-Activation-Function/aa5e2759f7d500aae0d93da3b6638492d07b6623/Visualization/ProbAct.png -------------------------------------------------------------------------------- /Visualization/TestAcc.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kumar-shridhar/ProbAct-Probabilistic-Activation-Function/aa5e2759f7d500aae0d93da3b6638492d07b6623/Visualization/TestAcc.pdf -------------------------------------------------------------------------------- /Visualization/TestAccCIFAR10.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kumar-shridhar/ProbAct-Probabilistic-Activation-Function/aa5e2759f7d500aae0d93da3b6638492d07b6623/Visualization/TestAccCIFAR10.pdf 
-------------------------------------------------------------------------------- /Visualization/TrainAcc.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kumar-shridhar/ProbAct-Probabilistic-Activation-Function/aa5e2759f7d500aae0d93da3b6638492d07b6623/Visualization/TrainAcc.pdf -------------------------------------------------------------------------------- /Visualization/TrainAccCIFAR10.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kumar-shridhar/ProbAct-Probabilistic-Activation-Function/aa5e2759f7d500aae0d93da3b6638492d07b6623/Visualization/TrainAccCIFAR10.pdf -------------------------------------------------------------------------------- /models/ProbAct.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn import Parameter 5 | 6 | 7 | device = torch.device("cuda:0") 8 | 9 | class TrainableSigma(nn.Module): 10 | 11 | 12 | def __init__(self, num_parameters=1, init=0): 13 | self.num_parameters = num_parameters 14 | super(TrainableSigma, self).__init__() 15 | self.weight = Parameter(torch.Tensor(num_parameters).fill_(init)) 16 | 17 | def forward(self, input): 18 | 19 | mu = input 20 | 21 | if mu.is_cuda: 22 | eps = torch.cuda.FloatTensor(mu.size()).normal_(mean=0, std=1) 23 | else: 24 | eps = torch.FloatTensor(mu.size()).normal_(mean=0, std=1) 25 | 26 | return F.relu(mu) + self.weight * eps # ProbAct: ReLU(mu) + sigma * eps, eps ~ N(0, 1), sigma (= self.weight) is learned 27 | 28 | def extra_repr(self): 29 | return 'num_parameters={}'.format(self.num_parameters) 30 | -------------------------------------------------------------------------------- /models/swish.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | class Swish(nn.Module): 5 | 6 | def __init__(self, inplace=False): 7 | super().__init__() 8 | 9 | self.inplace = inplace 10 | 11 | def forward(self, x): # Swish: x * sigmoid(x) 12 | if self.inplace: 13 | x.mul_(F.sigmoid(x)) 14 | return x 15 | else: 16 | return x * F.sigmoid(x)
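A quick sanity check for the Swish module above -- a minimal sketch, assuming the models directory is importable as a package:

    import torch
    from models.swish import Swish

    act = Swish()                                   # out-of-place by default
    x = torch.randn(4, 8)
    expected = x * torch.sigmoid(x)                 # Swish(x) = x * sigmoid(x)
    assert torch.allclose(act(x), expected)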
-------------------------------------------------------------------------------- /models/vgg.py: -------------------------------------------------------------------------------- 1 | '''VGG11/13/16/19 in Pytorch.''' 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | cfg = { 7 | 'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 8 | 'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 9 | 'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 10 | 'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 11 | } 12 | 13 | 14 | class VGG_ReLU(nn.Module): 15 | def __init__(self, vgg_name, nclass, img_width=32): 16 | super(VGG_ReLU, self).__init__() 17 | self.img_width = img_width 18 | self.features = self._make_layers(cfg[vgg_name]) 19 | self.classifier = nn.Linear(512, nclass) 20 | 21 | def forward(self, x): 22 | out = self.features(x) 23 | out = out.view(out.size(0), -1) 24 | out = self.classifier(out) 25 | return out, None # return None, to make it compatible with VGG_noise 26 | 27 | def _make_layers(self, cfg): 28 | layers = [] 29 | in_channels = 3 30 | width = self.img_width 31 | for x in cfg: 32 | if x == 'M': 33 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 34 | width = width // 2 35 | else: 36 | layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1), 37 | nn.BatchNorm2d(x), 38 | nn.ReLU(inplace=True)] 39 | in_channels = x 40 | layers += [nn.Dropout(0.5)] 41 | return nn.Sequential(*layers) 42 | 43 | 44 | class VGG_Sigmoid(nn.Module): 45 | def __init__(self, vgg_name, nclass, img_width=32): 46 | super(VGG_Sigmoid, self).__init__() 47 | self.img_width = img_width 48 | self.features = self._make_layers(cfg[vgg_name]) 49 | self.classifier = nn.Linear(512, nclass) 50 | 51 | def forward(self, x): 52 | out = self.features(x) 53 | out = out.view(out.size(0), -1) 54 | out = self.classifier(out) 55 | return out, None # return None, to make it compatible with VGG_noise 56 | 57 | def _make_layers(self, cfg): 58 | layers = [] 59 | in_channels = 3 60 | width = self.img_width 61 | for x in cfg: 62 | if x == 'M': 63 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 64 | width = width // 2 65 | else: 66 | layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1), 67 | nn.BatchNorm2d(x), 68 | nn.Sigmoid()] 69 | in_channels = x 70 | layers += [nn.AvgPool2d(kernel_size=width, stride=1)] 71 | return nn.Sequential(*layers) 72 | 73 | 74 | 75 | class VGG_PReLU(nn.Module): 76 | def __init__(self, vgg_name, nclass, img_width=32): 77 | super(VGG_PReLU, self).__init__() 78 | self.img_width = img_width 79 | self.features = self._make_layers(cfg[vgg_name]) 80 | self.classifier = nn.Linear(512, nclass) 81 | 82 | def forward(self, x): 83 | out = self.features(x) 84 | out = out.view(out.size(0), -1) 85 | out = self.classifier(out) 86 | return out, None # return None, to make it compatible with VGG_noise 87 | 88 | def _make_layers(self, cfg): 89 | layers = [] 90 | in_channels = 3 91 | width = self.img_width 92 | for x in cfg: 93 | if x == 'M': 94 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 95 | width = width // 2 96 | else: 97 | layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1), 98 | nn.BatchNorm2d(x), 99 | nn.PReLU(1)] 100 | in_channels = x 101 | layers += [nn.AvgPool2d(kernel_size=width, stride=1)] 102 | return nn.Sequential(*layers) 103 | 104 | 105 | 106 | class VGG_LeakyReLU(nn.Module): 107 | def __init__(self, vgg_name, nclass, img_width=32): 108 | super(VGG_LeakyReLU, self).__init__() 109 | self.img_width = img_width 110 | self.features = self._make_layers(cfg[vgg_name]) 111 | self.classifier = nn.Linear(512, nclass) 112 | 113 | def forward(self, x): 114 | out = self.features(x) 115 | out = out.view(out.size(0), -1) 116 | out = self.classifier(out) 117 | return out, None # return None, to make it compatible with VGG_noise 118 | 119 | def _make_layers(self, cfg): 120 | layers = [] 121 | in_channels = 3 122 | width = self.img_width 123 | for x in cfg: 124 | if x == 'M': 125 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 126 | width = width // 2 127 | else: 128 | layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1), 129 | nn.BatchNorm2d(x), 130 | nn.LeakyReLU()] 131 | in_channels = x 132 | layers += [nn.AvgPool2d(kernel_size=width, stride=1)] 133 | return nn.Sequential(*layers) 134 | 135 |
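The four baseline classes above share one constructor signature, which is what lets the comparison notebooks swap activations freely (VGG_Sigmoid.ipynb, for example, builds VGG_Sigmoid('VGG16', nclass, img_width=img_width)). A minimal sketch of that pattern -- the dict, batch size, and class count are illustrative choices, not taken from the repository:

    import torch
    from models.vgg import VGG_ReLU, VGG_Sigmoid, VGG_PReLU, VGG_LeakyReLU

    baselines = {'relu': VGG_ReLU, 'sigmoid': VGG_Sigmoid,
                 'prelu': VGG_PReLU, 'leaky_relu': VGG_LeakyReLU}
    x = torch.randn(2, 3, 32, 32)                   # fake CIFAR-10 batch
    for name, cls in baselines.items():
        net = cls('VGG16', nclass=10, img_width=32)
        logits, _ = net(x)                          # every variant returns (logits, None)
        print(name, logits.shape)                   # torch.Size([2, 10])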
-------------------------------------------------------------------------------- /models/vgg_probact_non-trainable.py: -------------------------------------------------------------------------------- 1 | '''VGG11/13/16/19 in Pytorch.''' 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.distributions import normal 7 | import numpy as np 8 | from torch.nn import Parameter 9 | 10 | device = torch.device("cuda:0") 11 | 12 | 13 | class VGG_NonTrainable_ProbAct(nn.Module): 14 | 15 | def __init__(self, nclass, sigma, img_width=32): # sigma is required, so it must precede the defaulted img_width 16 | super(VGG_NonTrainable_ProbAct, self).__init__() 17 | 18 | self.sigma = sigma # fixed, non-trainable standard deviation of the activation noise 19 | self.conv1 = nn.Sequential( 20 | nn.Conv2d(3, 64, kernel_size=3, padding=1), 21 | nn.BatchNorm2d(64)) 22 | 23 | self.conv2 = nn.Sequential( 24 | nn.Conv2d(64, 64, kernel_size=3, padding=1), 25 | nn.BatchNorm2d(64)) 26 | 27 | self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) 28 | 29 | self.conv3 = nn.Sequential( 30 | nn.Conv2d(64, 128, kernel_size=3, padding=1), 31 | nn.BatchNorm2d(128)) 32 | 33 | self.conv4 = nn.Sequential( 34 | nn.Conv2d(128, 128, kernel_size=3, padding=1), 35 | nn.BatchNorm2d(128)) 36 | 37 | self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) 38 | 39 | self.conv5 = nn.Sequential( 40 | nn.Conv2d(128, 256, kernel_size=3, padding=1), 41 | nn.BatchNorm2d(256)) 42 | 43 | self.conv6 = nn.Sequential( 44 | nn.Conv2d(256, 256, kernel_size=3, padding=1), 45 | nn.BatchNorm2d(256)) 46 | 47 | self.conv7 = nn.Sequential( 48 | nn.Conv2d(256, 256, kernel_size=3, padding=1), 49 | nn.BatchNorm2d(256)) 50 | 51 | self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2) 52 | 53 | self.conv8 = nn.Sequential( 54 | nn.Conv2d(256, 512, kernel_size=3, padding=1), 55 | nn.BatchNorm2d(512)) 56 | 57 | self.conv9 = nn.Sequential( 58 | nn.Conv2d(512, 512, kernel_size=3, padding=1), 59 | nn.BatchNorm2d(512)) 60 | 61 | self.conv10 = nn.Sequential( 62 | nn.Conv2d(512, 512, kernel_size=3, padding=1), 63 | nn.BatchNorm2d(512)) 64 | 65 | self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2) 66 | 67 | self.conv11 = nn.Sequential( 68 | nn.Conv2d(512, 512, kernel_size=3, padding=1), 69 | nn.BatchNorm2d(512)) 70 | 71 | self.conv12 = nn.Sequential( 72 | nn.Conv2d(512, 512, kernel_size=3, padding=1), 73 | nn.BatchNorm2d(512)) 74 | 75 | self.conv13 = nn.Sequential( 76 | nn.Conv2d(512, 512, kernel_size=3, padding=1), 77 | nn.BatchNorm2d(512)) 78 | 79 | self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2) 80 | 81 | # self.dropout = nn.Dropout(0.5) 82 | 83 | # self.ProbActAF = TrainableSigma() # unused in this fixed-sigma variant 84 | 85 | # self.SwishAF = Swish() 86 | 87 | self.img_width = img_width 88 | # self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) 89 | self.globalAvgpool = nn.AvgPool2d(kernel_size=1, stride=1) 90 | 91 | self.classifier = nn.Sequential( 92 | # nn.Linear(512 * 7 * 7, 4096), 93 | # nn.ReLU(True), 94 | # nn.Dropout(), 95 | # nn.Linear(4096, 512), 96 | # nn.ReLU(True), 97 | # nn.Dropout(), 98 | nn.Linear(512, nclass), 99 | ) 100 | 101 | 102 | def non_trainable_sigma(self, x): 103 | 104 | mu = x 105 | shape = mu.size() 106 | 107 | if mu.is_cuda: 108 | eps = torch.cuda.FloatTensor(shape).normal_(mean=0, std=self.sigma) 109 | else: 110 | eps = torch.FloatTensor(shape).normal_(mean=0, std=self.sigma) 111 | 112 | return F.relu(mu) + eps # fixed-sigma ProbAct: ReLU(mu) + eps, eps ~ N(0, sigma) 113 | 114 | 115 | def forward(self, x): 116 | 117 | out = self.conv1(x) 118 | out = self.non_trainable_sigma(out) 119 | 120 | out = self.conv2(out) 121 | out = self.non_trainable_sigma(out) 122 | out = self.pool1(out) 123 | 124 | 125 | out = self.conv3(out) 126 | out = self.non_trainable_sigma(out) 127 | 128 | out = self.conv4(out) 129 | out = self.non_trainable_sigma(out) 130 | out = self.pool2(out) 131 | 132 | 133 | out = self.conv5(out) 134 | out = self.non_trainable_sigma(out) 135 | 136 | out = self.conv6(out) 137 | out = self.non_trainable_sigma(out) 138 | 139 | out = self.conv7(out) 140 | out = self.non_trainable_sigma(out) 141 | out = self.pool3(out) 142 | 143 | 144 | out = self.conv8(out) 145 | out = self.non_trainable_sigma(out) 146 | 147 | out = self.conv9(out) 148 | out = self.non_trainable_sigma(out) 149 | 150 | out = self.conv10(out) 151 | out = self.non_trainable_sigma(out) 152 | out = self.pool4(out) 153 | 154 | 155 | out = self.conv11(out) 156 | out = self.non_trainable_sigma(out) 157 | 158 | out = self.conv12(out) 159 | out = self.non_trainable_sigma(out) 160 | 161 | out = self.conv13(out) 162 | out = self.non_trainable_sigma(out) 163 | out = self.pool5(out) 164 | 165 | 166 | out = self.globalAvgpool(out) 167 | out = out.view(out.size(0), -1) 168 | out = self.classifier(out) 169 | return out 170 |
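A smoke test for the fixed-sigma model above -- a sketch, assuming the models directory is importable as a package; sigma=0.05 is an arbitrary illustrative value, and importlib is needed only because of the hyphen in the filename:

    import importlib
    import torch

    mod = importlib.import_module('models.vgg_probact_non-trainable')
    net = mod.VGG_NonTrainable_ProbAct(nclass=10, sigma=0.05)
    x = torch.randn(2, 3, 32, 32)                   # fake CIFAR-sized images
    out_a, out_b = net(x), net(x)
    print(out_a.shape)                              # torch.Size([2, 10])
    print(torch.allclose(out_a, out_b))             # almost surely False: the activations are stochastic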
nn.Linear(4096, 512), 96 | # nn.ReLU(True), 97 | # nn.Dropout(), 98 | nn.Linear(512, nclass), 99 | ) 100 | 101 | 102 | def forward(self, x): 103 | 104 | out = self.conv1(x) 105 | out = self.ProbActAF(out) 106 | 107 | out = self.conv2(out) 108 | out = self.ProbActAF(out) 109 | out = self.pool1(out) 110 | 111 | 112 | out = self.conv3(out) 113 | out = self.ProbActAF(out) 114 | 115 | out = self.conv4(out) 116 | out = self.ProbActAF(out) 117 | out = self.pool2(out) 118 | 119 | 120 | out = self.conv5(out) 121 | out = self.ProbActAF(out) 122 | 123 | out = self.conv6(out) 124 | out = self.ProbActAF(out) 125 | 126 | out = self.conv7(out) 127 | out = self.ProbActAF(out) 128 | out = self.pool3(out) 129 | 130 | 131 | out = self.conv8(out) 132 | out = self.ProbActAF(out) 133 | 134 | out = self.conv9(out) 135 | out = self.ProbActAF(out) 136 | 137 | out = self.conv10(out) 138 | out = self.ProbActAF(out) 139 | out = self.pool4(out) 140 | 141 | 142 | out = self.conv11(out) 143 | out = self.ProbActAF(out) 144 | 145 | out = self.conv12(out) 146 | out = self.ProbActAF(out) 147 | 148 | out = self.conv13(out) 149 | out = self.ProbActAF(out) 150 | out = self.pool5(out) 151 | 152 | 153 | out = self.globalAvgpool(out) 154 | out = out.view(out.size(0), -1) 155 | out = self.classifier(out) 156 | return out 157 | -------------------------------------------------------------------------------- /models/vgg_swish.py: -------------------------------------------------------------------------------- 1 | '''VGG11/13/16/19 in Pytorch.''' 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.distributions import normal 7 | import numpy as np 8 | from torch.nn import Parameter 9 | from .swish import Swish 10 | 11 | device = torch.device("cuda:0") 12 | 13 | class VGG_Swish(nn.Module): 14 | 15 | def __init__(self, nclass, img_width=32): 16 | super(VGG_Swish, self).__init__() 17 | 18 | 19 | self.conv1 = nn.Sequential( 20 | nn.Conv2d(3, 64, kernel_size=3, padding=1), 21 | nn.BatchNorm2d(64)) 22 | 23 | self.conv2 = nn.Sequential( 24 | nn.Conv2d(64, 64, kernel_size=3, padding=1), 25 | nn.BatchNorm2d(64)) 26 | 27 | self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) 28 | 29 | self.conv3 = nn.Sequential( 30 | nn.Conv2d(64, 128, kernel_size=3, padding=1), 31 | nn.BatchNorm2d(128)) 32 | 33 | self.conv4 = nn.Sequential( 34 | nn.Conv2d(128, 128, kernel_size=3, padding=1), 35 | nn.BatchNorm2d(128)) 36 | 37 | self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) 38 | 39 | self.conv5 = nn.Sequential( 40 | nn.Conv2d(128, 256, kernel_size=3, padding=1), 41 | nn.BatchNorm2d(256)) 42 | 43 | self.conv6 = nn.Sequential( 44 | nn.Conv2d(256, 256, kernel_size=3, padding=1), 45 | nn.BatchNorm2d(256)) 46 | 47 | self.conv7 = nn.Sequential( 48 | nn.Conv2d(256, 256, kernel_size=3, padding=1), 49 | nn.BatchNorm2d(256)) 50 | 51 | self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2) 52 | 53 | self.conv8 = nn.Sequential( 54 | nn.Conv2d(256, 512, kernel_size=3, padding=1), 55 | nn.BatchNorm2d(512)) 56 | 57 | self.conv9 = nn.Sequential( 58 | nn.Conv2d(512, 512, kernel_size=3, padding=1), 59 | nn.BatchNorm2d(512)) 60 | 61 | self.conv10 = nn.Sequential( 62 | nn.Conv2d(512, 512, kernel_size=3, padding=1), 63 | nn.BatchNorm2d(512)) 64 | 65 | self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2) 66 | 67 | self.conv11 = nn.Sequential( 68 | nn.Conv2d(512, 512, kernel_size=3, padding=1), 69 | nn.BatchNorm2d(512)) 70 | 71 | self.conv12 = nn.Sequential( 72 | nn.Conv2d(512, 512, kernel_size=3, padding=1), 73 | 
nn.BatchNorm2d(512)) 74 | 75 | self.conv13 = nn.Sequential( 76 | nn.Conv2d(512, 512, kernel_size=3, padding=1), 77 | nn.BatchNorm2d(512)) 78 | 79 | self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2) 80 | 81 | # self.dropout = nn.Dropout(0.5) 82 | 83 | self.SwishAF = Swish() 84 | 85 | self.img_width = img_width 86 | # self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) 87 | self.globalAvgpool = nn.AvgPool2d(kernel_size=1, stride=1) 88 | 89 | self.classifier = nn.Sequential( 90 | # nn.Linear(512 * 7 * 7, 4096), 91 | # nn.ReLU(True), 92 | # nn.Dropout(), 93 | # nn.Linear(4096, 512), 94 | # nn.ReLU(True), 95 | # nn.Dropout(), 96 | nn.Linear(512, nclass), 97 | ) 98 | 99 | 100 | def forward(self, x): 101 | 102 | out = self.conv1(x) 103 | out = self.SwishAF(out) 104 | 105 | out = self.conv2(out) 106 | out = self.SwishAF(out) 107 | out = self.pool1(out) 108 | 109 | 110 | out = self.conv3(out) 111 | out = self.SwishAF(out) 112 | 113 | out = self.conv4(out) 114 | out = self.SwishAF(out) 115 | out = self.pool2(out) 116 | 117 | 118 | out = self.conv5(out) 119 | out = self.SwishAF(out) 120 | 121 | out = self.conv6(out) 122 | out = self.SwishAF(out) 123 | 124 | out = self.conv7(out) 125 | out = self.SwishAF(out) 126 | out = self.pool3(out) 127 | 128 | 129 | out = self.conv8(out) 130 | out = self.SwishAF(out) 131 | 132 | out = self.conv9(out) 133 | out = self.SwishAF(out) 134 | 135 | out = self.conv10(out) 136 | out = self.SwishAF(out) 137 | out = self.pool4(out) 138 | 139 | 140 | out = self.conv11(out) 141 | out = self.SwishAF(out) 142 | 143 | out = self.conv12(out) 144 | out = self.SwishAF(out) 145 | 146 | out = self.conv13(out) 147 | out = self.SwishAF(out) 148 | out = self.pool5(out) 149 | 150 | 151 | out = self.globalAvgpool(out) 152 | out = out.view(out.size(0), -1) 153 | out = self.classifier(out) 154 | return out 155 | --------------------------------------------------------------------------------
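Because TrainableSigma keeps injecting noise even in eval mode (it has no train/eval gating), a ProbAct network can be queried several times and the stochastic predictions averaged. A minimal Monte-Carlo inference sketch -- the sample count, the filled-in sigma of 0.1, and the dummy batch are illustrative assumptions; a freshly constructed net starts with sigma = 0, so all passes would coincide until sigma is trained:

    import torch
    from models.vgg_probact_trainable import VGG_ProbAct_Trainable

    net = VGG_ProbAct_Trainable(nclass=10)
    net.ProbActAF.weight.data.fill_(0.1)            # stand-in for a learned sigma
    net.eval()

    x = torch.randn(4, 3, 32, 32)                   # fake CIFAR-10 batch
    with torch.no_grad():
        samples = torch.stack([net(x) for _ in range(10)])  # (10, 4, 10) logits
    mean_logits = samples.mean(dim=0)               # averaged prediction
    std_logits = samples.std(dim=0)                 # rough per-class uncertainty
    pred = mean_logits.argmax(dim=1)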