├── README.md ├── requirements.txt └── pytorch_model.ipynb /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch Image Classification 2 | 3 | Classifies an image as containing either a dog or a cat (using Kaggle's public dataset), but could easily be extended to other image classification problems. 4 | 5 | ### Dependencies: 6 | - PyTorch / Torchvision 7 | - Numpy 8 | - PIL 9 | - CUDA 10 | 11 | ## Data 12 | 13 | The data directory structure I used was: 14 | 15 | * project 16 | * data 17 | * train 18 | * dogs 19 | * cats 20 | * validation 21 | * dogs 22 | * cats 23 | * test 24 | * test 25 | 26 | ## Performance 27 | The result of the notebook in this repo produced a log loss score on Kaggle's hidden dataset of 0.04988 -- further gains can probably be achieved by creating an ensemble of classifiers using this approach. 28 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ### Conda reqs 2 | #conda install cudatoolkit==9.0 cudnn=7.6.0=cuda9.0_0 3 | #conda install -c anaconda nb_conda 4 | 5 | attrs==19.1.0 6 | backcall==0.1.0 7 | bleach==3.3.0 8 | certifi==2024.7.4 9 | cffi==1.12.3 10 | decorator==4.4.0 11 | defusedxml==0.6.0 12 | entrypoints==0.3 13 | ipykernel==5.1.1 14 | ipython==8.10.0 15 | ipython-genutils==0.2.0 16 | jedi==0.13.3 17 | Jinja2==3.1.6 18 | jsonschema==3.0.1 19 | jupyter-client==5.3.1 20 | jupyter-core==4.11.2 21 | MarkupSafe==1.1.1 22 | mistune==2.0.3 23 | mkl-fft==1.0.12 24 | mkl-random==1.0.2 25 | nb-conda==2.2.1 26 | nb-conda-kernels==2.2.2 27 | nbconvert==6.5.1 28 | nbformat==4.4.0 29 | notebook==6.4.12 30 | numpy==1.22.0 31 | olefile==0.46 32 | pandocfilters==1.4.2 33 | parso==0.5.0 34 | pexpect==4.7.0 35 | pickleshare==0.7.5 36 | Pillow==10.3.0 37 | prometheus-client==0.7.1 38 | prompt-toolkit==2.0.9 39 | ptyprocess==0.6.0 40 | pycparser==2.19 41 | Pygments==2.15.0 42 | pyrsistent==0.14.11 43 | python-dateutil==2.8.0 44 | pyzmq==18.0.0 45 | Send2Trash==1.5.0 46 | six==1.12.0 47 | terminado==0.8.2 48 | testpath==0.4.2 49 | torch==2.7.1 50 | torchvision==0.3.0 51 | tornado==6.5.1 52 | traitlets==4.3.2 53 | wcwidth==0.1.7 54 | webencodings==0.5.1 55 | -------------------------------------------------------------------------------- /pytorch_model.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Image Classification of Dogs vs. Cats With PyTorch" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "### Imports & Environment" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 1, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "import os\n", 24 | "import random\n", 25 | "import collections\n", 26 | "import shutil\n", 27 | "import time\n", 28 | "import glob\n", 29 | "import csv\n", 30 | "import numpy as np\n", 31 | "\n", 32 | "import torch\n", 33 | "import torch.backends.cudnn as cudnn\n", 34 | "import torch.nn as nn\n", 35 | "import torch.optim as optim\n", 36 | "import torch.utils.data as data\n", 37 | "import torchvision.datasets as datasets\n", 38 | "import torchvision.models as models\n", 39 | "import torchvision.transforms as transforms\n", 40 | "\n", 41 | "from PIL import Image\n", 42 | "\n", 43 | "ROOT_DIR = os.getcwd()\n", 44 | "DATA_HOME_DIR = ROOT_DIR + '/data'" 45 | ] 46 | }, 47 | { 48 | "cell_type": "markdown", 49 | "metadata": {}, 50 | "source": [ 51 | "### Config & Hyperparameters" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": 2, 57 | "metadata": {}, 58 | "outputs": [], 59 | "source": [ 60 | "# paths\n", 61 | "data_path = DATA_HOME_DIR + '/' \n", 62 | "split_train_path = data_path + '/train/'\n", 63 | "full_train_path = data_path + '/train_full/'\n", 64 | "valid_path = data_path + '/valid/'\n", 65 | "test_path = DATA_HOME_DIR + '/test/test/'\n", 66 | "saved_model_path = ROOT_DIR + '/models/'\n", 67 | "submission_path = ROOT_DIR + '/submissions/'\n", 68 | "\n", 69 | "# data\n", 70 | "batch_size = 8\n", 71 | "\n", 72 | "# model\n", 73 | "nb_runs = 1\n", 74 | "nb_aug = 3\n", 75 | "epochs = 5\n", 76 | "lr = 1e-4\n", 77 | "clip = 0.001\n", 78 | "archs = [\"resnet152\"]\n", 79 | "\n", 80 | "model_names = sorted(name for name in models.__dict__ if name.islower() and not name.startswith(\"__\"))\n", 81 | "best_prec1 = 0" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": 3, 87 | "metadata": {}, 88 | "outputs": [ 89 | { 90 | "data": { 91 | "text/plain": [ 92 | "['_utils',\n", 93 | " 'alexnet',\n", 94 | " 'densenet',\n", 95 | " 'densenet121',\n", 96 | " 'densenet161',\n", 97 | " 'densenet169',\n", 98 | " 'densenet201',\n", 99 | " 'detection',\n", 100 | " 'googlenet',\n", 101 | " 'inception',\n", 102 | " 'inception_v3',\n", 103 | " 'mobilenet',\n", 104 | " 'mobilenet_v2',\n", 105 | " 'resnet',\n", 106 | " 'resnet101',\n", 107 | " 'resnet152',\n", 108 | " 'resnet18',\n", 109 | " 'resnet34',\n", 110 | " 'resnet50',\n", 111 | " 'resnext101_32x8d',\n", 112 | " 'resnext50_32x4d',\n", 113 | " 'segmentation',\n", 114 | " 'shufflenet_v2_x0_5',\n", 115 | " 'shufflenet_v2_x1_0',\n", 116 | " 'shufflenet_v2_x1_5',\n", 117 | " 'shufflenet_v2_x2_0',\n", 118 | " 'shufflenetv2',\n", 119 | " 'squeezenet',\n", 120 | " 'squeezenet1_0',\n", 121 | " 'squeezenet1_1',\n", 122 | " 'utils',\n", 123 | " 'vgg',\n", 124 | " 'vgg11',\n", 125 | " 'vgg11_bn',\n", 126 | " 'vgg13',\n", 127 | " 'vgg13_bn',\n", 128 | " 'vgg16',\n", 129 | " 'vgg16_bn',\n", 130 | " 'vgg19',\n", 131 | " 'vgg19_bn']" 132 | ] 133 | }, 134 | "execution_count": 3, 135 | "metadata": {}, 136 | "output_type": "execute_result" 137 | } 138 | ], 139 | "source": [ 140 | "model_names" 141 | ] 142 | }, 143 | { 144 | "cell_type": "markdown", 145 | "metadata": {}, 146 | "source": [ 147 | "### Helper Functions for Training" 148 | ] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "execution_count": 4, 153 | "metadata": {}, 154 | "outputs": [], 155 | "source": [ 156 | "def train(train_loader, model, criterion, optimizer, epoch):\n", 157 | " batch_time = AverageMeter()\n", 158 | " data_time = AverageMeter()\n", 159 | " losses = AverageMeter()\n", 160 | " acc = AverageMeter()\n", 161 | " end = time.time()\n", 162 | " \n", 163 | " # switch to train mode\n", 164 | " model.train()\n", 165 | " \n", 166 | " for i, (images, target) in enumerate(train_loader):\n", 167 | " # measure data loading time\n", 168 | " data_time.update(time.time() - end)\n", 169 | "\n", 170 | " target = target.cuda()\n", 171 | " image_var = torch.autograd.Variable(images)\n", 172 | " label_var = torch.autograd.Variable(target)\n", 173 | "\n", 174 | " # compute y_pred\n", 175 | " y_pred = model(image_var)\n", 176 | " loss = criterion(y_pred, label_var)\n", 177 | "\n", 178 | " # measure accuracy and record loss\n", 179 | " prec1, prec1 = accuracy(y_pred.data, target, topk=(1, 1))\n", 180 | " losses.update(loss.data, images.size(0))\n", 181 | " acc.update(prec1, images.size(0))\n", 182 | "\n", 183 | " # compute gradient and do SGD step\n", 184 | " optimizer.zero_grad()\n", 185 | " loss.backward()\n", 186 | " optimizer.step()\n", 187 | "\n", 188 | " # measure elapsed time\n", 189 | " batch_time.update(time.time() - end)\n", 190 | " end = time.time()" 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": 5, 196 | "metadata": {}, 197 | "outputs": [], 198 | "source": [ 199 | "def validate(val_loader, model, criterion, epoch):\n", 200 | " batch_time = AverageMeter()\n", 201 | " losses = AverageMeter()\n", 202 | " acc = AverageMeter()\n", 203 | "\n", 204 | " # switch to evaluate mode\n", 205 | " model.eval()\n", 206 | "\n", 207 | " end = time.time()\n", 208 | " for i, (images, labels) in enumerate(val_loader):\n", 209 | " labels = labels.cuda()\n", 210 | " image_var = torch.autograd.Variable(images)\n", 211 | " label_var = torch.autograd.Variable(labels)\n", 212 | "\n", 213 | " # compute y_pred\n", 214 | " y_pred = model(image_var)\n", 215 | " loss = criterion(y_pred, label_var)\n", 216 | "\n", 217 | " # measure accuracy and record loss\n", 218 | " prec1, temp_var = accuracy(y_pred.data, labels, topk=(1, 1))\n", 219 | " losses.update(loss.data, images.size(0))\n", 220 | " acc.update(prec1, images.size(0))\n", 221 | "\n", 222 | " # measure elapsed time\n", 223 | " batch_time.update(time.time() - end)\n", 224 | " end = time.time()\n", 225 | "\n", 226 | " print(' * EPOCH {epoch} | Accuracy: {acc.avg:.3f} | Loss: {losses.avg:.3f}'.format(epoch=epoch,\n", 227 | " acc=acc,\n", 228 | " losses=losses))\n", 229 | "\n", 230 | " return acc.avg" 231 | ] 232 | }, 233 | { 234 | "cell_type": "code", 235 | "execution_count": 18, 236 | "metadata": {}, 237 | "outputs": [], 238 | "source": [ 239 | "def test(test_loader, model):\n", 240 | " csv_map = collections.defaultdict(float)\n", 241 | " \n", 242 | " # switch to evaluate mode\n", 243 | " model.eval()\n", 244 | " \n", 245 | " for aug in range(nb_aug):\n", 246 | " print(\" * Predicting on test augmentation {}\".format(aug + 1))\n", 247 | " \n", 248 | " for i, (images, filepath) in enumerate(test_loader):\n", 249 | " # pop extension, treat as id to map\n", 250 | " filepath = os.path.splitext(os.path.basename(filepath[0]))[0]\n", 251 | " filepath = int(filepath)\n", 252 | "\n", 253 | " image_var = torch.autograd.Variable(images)\n", 254 | " y_pred = model(image_var)\n", 255 | " # get the index of the max log-probability\n", 256 | " smax = nn.Softmax()\n", 257 | " smax_out = smax(y_pred)[0]\n", 258 | " cat_prob = smax_out.data[0]\n", 259 | " dog_prob = smax_out.data[1]\n", 260 | " prob = dog_prob\n", 261 | " if cat_prob > dog_prob:\n", 262 | " prob = 1 - cat_prob\n", 263 | " prob = np.around(prob.cpu(), decimals=4)\n", 264 | " prob = np.clip(prob, clip, 1-clip)\n", 265 | " csv_map[filepath] += (prob / nb_aug)\n", 266 | "\n", 267 | " sub_fn = submission_path + '{0}epoch_{1}clip_{2}runs'.format(epochs, clip, nb_runs)\n", 268 | " \n", 269 | " for arch in archs:\n", 270 | " sub_fn += \"_{}\".format(arch)\n", 271 | " \n", 272 | " print(\"Writing Predictions to CSV...\")\n", 273 | " with open(sub_fn + '.csv', 'w') as csvfile:\n", 274 | " fieldnames = ['id', 'label']\n", 275 | " csv_w = csv.writer(csvfile)\n", 276 | " csv_w.writerow(('id', 'label'))\n", 277 | " for row in sorted(csv_map.items()):\n", 278 | " csv_w.writerow(row)\n", 279 | " print(\"Done.\")" 280 | ] 281 | }, 282 | { 283 | "cell_type": "code", 284 | "execution_count": 7, 285 | "metadata": {}, 286 | "outputs": [], 287 | "source": [ 288 | "def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):\n", 289 | " torch.save(state, filename)\n", 290 | " if is_best:\n", 291 | " shutil.copyfile(filename, 'model_best.pth.tar') " 292 | ] 293 | }, 294 | { 295 | "cell_type": "code", 296 | "execution_count": 8, 297 | "metadata": {}, 298 | "outputs": [], 299 | "source": [ 300 | "class AverageMeter(object):\n", 301 | " \"\"\"Computes and stores the average and current value\"\"\"\n", 302 | "\n", 303 | " def __init__(self):\n", 304 | " self.reset()\n", 305 | "\n", 306 | " def reset(self):\n", 307 | " self.val = 0\n", 308 | " self.avg = 0\n", 309 | " self.sum = 0\n", 310 | " self.count = 0\n", 311 | "\n", 312 | " def update(self, val, n=1):\n", 313 | " self.val = val\n", 314 | " self.sum += val * n\n", 315 | " self.count += n\n", 316 | " self.avg = self.sum / self.count" 317 | ] 318 | }, 319 | { 320 | "cell_type": "code", 321 | "execution_count": 9, 322 | "metadata": {}, 323 | "outputs": [], 324 | "source": [ 325 | "def adjust_learning_rate(optimizer, epoch):\n", 326 | " \"\"\"Sets the learning rate to the initial LR decayed by 10 every 30 epochs\"\"\"\n", 327 | " global lr\n", 328 | " lr = lr * (0.1**(epoch // 30))\n", 329 | " for param_group in optimizer.state_dict()['param_groups']:\n", 330 | " param_group['lr'] = lr\n", 331 | "\n", 332 | "\n", 333 | "def accuracy(y_pred, y_actual, topk=(1, )):\n", 334 | " \"\"\"Computes the precision@k for the specified values of k\"\"\"\n", 335 | " maxk = max(topk)\n", 336 | " batch_size = y_actual.size(0)\n", 337 | "\n", 338 | " _, pred = y_pred.topk(maxk, 1, True, True)\n", 339 | " pred = pred.t()\n", 340 | " correct = pred.eq(y_actual.view(1, -1).expand_as(pred))\n", 341 | "\n", 342 | " res = []\n", 343 | " for k in topk:\n", 344 | " correct_k = correct[:k].view(-1).float().sum(0)\n", 345 | " res.append(correct_k.mul_(100.0 / batch_size))\n", 346 | "\n", 347 | " return res" 348 | ] 349 | }, 350 | { 351 | "cell_type": "code", 352 | "execution_count": 10, 353 | "metadata": {}, 354 | "outputs": [], 355 | "source": [ 356 | "class TestImageFolder(data.Dataset):\n", 357 | " def __init__(self, root, transform=None):\n", 358 | " images = []\n", 359 | " for filename in sorted(glob.glob(test_path + \"*.jpg\")):\n", 360 | " images.append('{}'.format(filename))\n", 361 | "\n", 362 | " self.root = root\n", 363 | " self.imgs = images\n", 364 | " self.transform = transform\n", 365 | "\n", 366 | " def __getitem__(self, index):\n", 367 | " filename = self.imgs[index]\n", 368 | " img = Image.open(os.path.join(self.root, filename))\n", 369 | " if self.transform is not None:\n", 370 | " img = self.transform(img)\n", 371 | " return img, filename\n", 372 | "\n", 373 | " def __len__(self):\n", 374 | " return len(self.imgs)" 375 | ] 376 | }, 377 | { 378 | "cell_type": "code", 379 | "execution_count": 11, 380 | "metadata": {}, 381 | "outputs": [], 382 | "source": [ 383 | "def shear(img):\n", 384 | " width, height = img.size\n", 385 | " m = random.uniform(-0.05, 0.05)\n", 386 | " xshift = abs(m) * width\n", 387 | " new_width = width + int(round(xshift))\n", 388 | " img = img.transform((new_width, height), Image.AFFINE,\n", 389 | " (1, m, -xshift if m > 0 else 0, 0, 1, 0),\n", 390 | " Image.BICUBIC)\n", 391 | " return img" 392 | ] 393 | }, 394 | { 395 | "cell_type": "markdown", 396 | "metadata": {}, 397 | "source": [ 398 | "### Main Training Loop" 399 | ] 400 | }, 401 | { 402 | "cell_type": "code", 403 | "execution_count": 12, 404 | "metadata": { 405 | "scrolled": true 406 | }, 407 | "outputs": [], 408 | "source": [ 409 | "def main(mode=\"train\", resume=False):\n", 410 | " \n", 411 | " global best_prec1\n", 412 | " \n", 413 | " for arch in archs:\n", 414 | "\n", 415 | " # create model\n", 416 | " print(\"=> Starting {0} on '{1}' model\".format(mode, arch))\n", 417 | " model = models.__dict__[arch](pretrained=True)\n", 418 | " # Don't update non-classifier learned features in the pretrained networks\n", 419 | " for param in model.parameters():\n", 420 | " param.requires_grad = False\n", 421 | " # Replace the last fully-connected layer\n", 422 | " # Parameters of newly constructed modules have requires_grad=True by default\n", 423 | " # Final dense layer needs to replaced with the previous out chans, and number of classes\n", 424 | " # in this case -- resnet 101 - it's 2048 with two classes (cats and dogs)\n", 425 | " model.fc = nn.Linear(2048, 2)\n", 426 | "\n", 427 | " if arch.startswith('alexnet') or arch.startswith('vgg'):\n", 428 | " model.features = torch.nn.DataParallel(model.features)\n", 429 | " model.cuda()\n", 430 | " else:\n", 431 | " model = torch.nn.DataParallel(model).cuda()\n", 432 | " \n", 433 | " # optionally resume from a checkpoint\n", 434 | " if resume:\n", 435 | " if os.path.isfile(resume):\n", 436 | " print(\"=> Loading checkpoint '{}'\".format(resume))\n", 437 | " checkpoint = torch.load(resume)\n", 438 | " start_epoch = checkpoint['epoch']\n", 439 | " best_prec1 = checkpoint['best_prec1']\n", 440 | " model.load_state_dict(checkpoint['state_dict'])\n", 441 | " print(\"=> Loaded checkpoint (epoch {})\".format(checkpoint['epoch']))\n", 442 | " else:\n", 443 | " print(\"=> No checkpoint found at '{}'\".format(args.resume))\n", 444 | "\n", 445 | " cudnn.benchmark = True\n", 446 | "\n", 447 | " # Data loading code\n", 448 | " traindir = split_train_path\n", 449 | " valdir = valid_path\n", 450 | " testdir = test_path\n", 451 | "\n", 452 | " normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n", 453 | "\n", 454 | " train_loader = data.DataLoader(\n", 455 | " datasets.ImageFolder(traindir,\n", 456 | " transforms.Compose([\n", 457 | " # transforms.Lambda(shear),\n", 458 | " transforms.RandomResizedCrop(224),\n", 459 | " transforms.RandomHorizontalFlip(),\n", 460 | " transforms.ToTensor(),\n", 461 | " normalize,\n", 462 | " ])),\n", 463 | " batch_size=batch_size,\n", 464 | " shuffle=True,\n", 465 | " num_workers=4,\n", 466 | " pin_memory=True)\n", 467 | "\n", 468 | " val_loader = data.DataLoader(\n", 469 | " datasets.ImageFolder(valdir,\n", 470 | " transforms.Compose([\n", 471 | " transforms.Resize(256),\n", 472 | " transforms.CenterCrop(224),\n", 473 | " transforms.ToTensor(),\n", 474 | " normalize,\n", 475 | " ])),\n", 476 | " batch_size=batch_size,\n", 477 | " shuffle=True,\n", 478 | " num_workers=4,\n", 479 | " pin_memory=True)\n", 480 | "\n", 481 | " test_loader = data.DataLoader(\n", 482 | " TestImageFolder(testdir,\n", 483 | " transforms.Compose([\n", 484 | " # transforms.Lambda(shear),\n", 485 | " transforms.Resize(256),\n", 486 | " transforms.CenterCrop(224),\n", 487 | " transforms.RandomHorizontalFlip(),\n", 488 | " transforms.ToTensor(),\n", 489 | " normalize,\n", 490 | " ])),\n", 491 | " batch_size=1,\n", 492 | " shuffle=False,\n", 493 | " num_workers=1,\n", 494 | " pin_memory=False)\n", 495 | " \n", 496 | " \n", 497 | " if mode == \"test\":\n", 498 | " test(test_loader, model)\n", 499 | " return\n", 500 | " \n", 501 | " # define loss function (criterion) and pptimizer\n", 502 | " criterion = nn.CrossEntropyLoss().cuda()\n", 503 | " \n", 504 | " if mode == \"validate\":\n", 505 | " validate(val_loader, model, criterion, 0)\n", 506 | " return\n", 507 | "\n", 508 | " optimizer = optim.Adam(model.module.fc.parameters(), lr, weight_decay=1e-4)\n", 509 | "\n", 510 | " for epoch in range(epochs):\n", 511 | " adjust_learning_rate(optimizer, epoch)\n", 512 | "\n", 513 | " # train for one epoch\n", 514 | " train(train_loader, model, criterion, optimizer, epoch)\n", 515 | "\n", 516 | " # evaluate on validation set\n", 517 | " prec1 = validate(val_loader, model, criterion, epoch)\n", 518 | "\n", 519 | " # remember best Accuracy and save checkpoint\n", 520 | " is_best = prec1 > best_prec1\n", 521 | " best_prec1 = max(prec1, best_prec1)\n", 522 | " save_checkpoint({\n", 523 | " 'epoch': epoch + 1,\n", 524 | " 'arch': arch,\n", 525 | " 'state_dict': model.state_dict(),\n", 526 | " 'best_prec1': best_prec1,\n", 527 | " }, is_best)" 528 | ] 529 | }, 530 | { 531 | "cell_type": "markdown", 532 | "metadata": {}, 533 | "source": [ 534 | "### Run Train" 535 | ] 536 | }, 537 | { 538 | "cell_type": "code", 539 | "execution_count": 13, 540 | "metadata": {}, 541 | "outputs": [ 542 | { 543 | "name": "stdout", 544 | "output_type": "stream", 545 | "text": [ 546 | "=> Starting train on 'resnet152' model\n", 547 | " * EPOCH 0 | Accuracy: 99.005 | Loss: 0.045\n", 548 | " * EPOCH 1 | Accuracy: 99.502 | Loss: 0.032\n", 549 | " * EPOCH 2 | Accuracy: 99.254 | Loss: 0.035\n", 550 | " * EPOCH 3 | Accuracy: 99.254 | Loss: 0.035\n", 551 | " * EPOCH 4 | Accuracy: 99.502 | Loss: 0.027\n" 552 | ] 553 | } 554 | ], 555 | "source": [ 556 | "main(mode=\"train\")" 557 | ] 558 | }, 559 | { 560 | "cell_type": "code", 561 | "execution_count": 14, 562 | "metadata": {}, 563 | "outputs": [ 564 | { 565 | "name": "stdout", 566 | "output_type": "stream", 567 | "text": [ 568 | "=> Starting validate on 'resnet152' model\n", 569 | "=> Loading checkpoint 'model_best.pth.tar'\n", 570 | "=> Loaded checkpoint (epoch 2)\n", 571 | " * EPOCH 0 | Accuracy: 99.502 | Loss: 0.032\n" 572 | ] 573 | } 574 | ], 575 | "source": [ 576 | "main(mode=\"validate\", resume='model_best.pth.tar')" 577 | ] 578 | }, 579 | { 580 | "cell_type": "markdown", 581 | "metadata": {}, 582 | "source": [ 583 | "### Run Test" 584 | ] 585 | }, 586 | { 587 | "cell_type": "code", 588 | "execution_count": null, 589 | "metadata": {}, 590 | "outputs": [ 591 | { 592 | "name": "stdout", 593 | "output_type": "stream", 594 | "text": [ 595 | "=> Starting test on 'resnet152' model\n", 596 | "=> Loading checkpoint 'model_best.pth.tar'\n", 597 | "=> Loaded checkpoint (epoch 2)\n", 598 | " * Predicting on test augmentation 1\n" 599 | ] 600 | }, 601 | { 602 | "name": "stderr", 603 | "output_type": "stream", 604 | "text": [ 605 | "/home/robert/anaconda3/envs/pytorch-image-classification/lib/python3.7/site-packages/ipykernel_launcher.py:19: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n" 606 | ] 607 | }, 608 | { 609 | "name": "stdout", 610 | "output_type": "stream", 611 | "text": [ 612 | " * Predicting on test augmentation 2\n" 613 | ] 614 | } 615 | ], 616 | "source": [ 617 | "main(mode=\"test\", resume='model_best.pth.tar')" 618 | ] 619 | }, 620 | { 621 | "cell_type": "code", 622 | "execution_count": null, 623 | "metadata": {}, 624 | "outputs": [], 625 | "source": [] 626 | } 627 | ], 628 | "metadata": { 629 | "anaconda-cloud": {}, 630 | "kernelspec": { 631 | "display_name": "Python [conda env:pytorch-image-classification] *", 632 | "language": "python", 633 | "name": "conda-env-pytorch-image-classification-py" 634 | }, 635 | "language_info": { 636 | "codemirror_mode": { 637 | "name": "ipython", 638 | "version": 3 639 | }, 640 | "file_extension": ".py", 641 | "mimetype": "text/x-python", 642 | "name": "python", 643 | "nbconvert_exporter": "python", 644 | "pygments_lexer": "ipython3", 645 | "version": "3.7.4" 646 | } 647 | }, 648 | "nbformat": 4, 649 | "nbformat_minor": 1 650 | } 651 | --------------------------------------------------------------------------------