├── README.md ├── Self-Supervised Learning.pdf └── Image_Colorization.ipynb /README.md: -------------------------------------------------------------------------------- 1 | # SS2021-19-08-2021 2 | Repository for Summer School 2021 3 | -------------------------------------------------------------------------------- /Self-Supervised Learning.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rudrabha/SS2021-19-08-2021/HEAD/Self-Supervised Learning.pdf -------------------------------------------------------------------------------- /Image_Colorization.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "Image_Colorization.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [], 9 | "include_colab_link": true 10 | }, 11 | "kernelspec": { 12 | "name": "python3", 13 | "display_name": "Python 3" 14 | }, 15 | "language_info": { 16 | "name": "python" 17 | }, 18 | "accelerator": "GPU" 19 | }, 20 | "cells": [ 21 | { 22 | "cell_type": "markdown", 23 | "metadata": { 24 | "id": "view-in-github", 25 | "colab_type": "text" 26 | }, 27 | "source": [ 28 | "\"Open" 29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "metadata": { 34 | "id": "kA26ItOODFRL" 35 | }, 36 | "source": [ 37 | "**Import Headers**" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "metadata": { 43 | "id": "H3AHwHkD0MJa" 44 | }, 45 | "source": [ 46 | "General Headers" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "metadata": { 52 | "colab": { 53 | "base_uri": "https://localhost:8080/" 54 | }, 55 | "id": "UON2uQnNAuDz", 56 | "outputId": "bbb6de39-2550-4b07-c839-c4e8aaace5d7" 57 | }, 58 | "source": [ 59 | "import os\n", 60 | "!pip install wget\n", 61 | "import wget\n", 62 | "import shutil\n", 63 | "import glob\n", 64 | "import cv2\n", 65 | "import numpy as np\n", 66 | "import random\n", 67 | "from tqdm import tqdm\n", 68 | "import matplotlib.pyplot as plt" 69 | ], 70 | "execution_count": 1, 71 | "outputs": [ 72 | { 73 | "output_type": "stream", 74 | "text": [ 75 | "Collecting wget\n", 76 | " Downloading wget-3.2.zip (10 kB)\n", 77 | "Building wheels for collected packages: wget\n", 78 | " Building wheel for wget (setup.py) ... \u001b[?25l\u001b[?25hdone\n", 79 | " Created wheel for wget: filename=wget-3.2-py3-none-any.whl size=9672 sha256=158faad761740dd33e4a2c1c7daaf4d2fe12c818894db0a4f8e438b87490c2f4\n", 80 | " Stored in directory: /root/.cache/pip/wheels/a1/b6/7c/0e63e34eb06634181c63adacca38b79ff8f35c37e3c13e3c02\n", 81 | "Successfully built wget\n", 82 | "Installing collected packages: wget\n", 83 | "Successfully installed wget-3.2\n" 84 | ], 85 | "name": "stdout" 86 | } 87 | ] 88 | }, 89 | { 90 | "cell_type": "markdown", 91 | "metadata": { 92 | "id": "knrDk5Y_0RaV" 93 | }, 94 | "source": [ 95 | "PyTorch based Headers" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "metadata": { 101 | "id": "6iqW-Yjhk5OL" 102 | }, 103 | "source": [ 104 | "from torch.utils.data import Dataset, DataLoader\n", 105 | "import torch \n", 106 | "from torch import nn\n", 107 | "from torch.nn import functional as F\n", 108 | "from torch import optim" 109 | ], 110 | "execution_count": 2, 111 | "outputs": [] 112 | }, 113 | { 114 | "cell_type": "markdown", 115 | "metadata": { 116 | "id": "5PoEZpF_0WQk" 117 | }, 118 | "source": [ 119 | "**SETTING UP THE DEVICE FOR GPU COMPUTATION**" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "metadata": { 125 | "colab": { 126 | "base_uri": "https://localhost:8080/" 127 | }, 128 | "id": "xkrsI3rEl4A7", 129 | "outputId": "22238fa4-66de-40f6-f5a4-95cde6b4302a" 130 | }, 131 | "source": [ 132 | "use_cuda = torch.cuda.is_available()\n", 133 | "print('use_cuda: {}'.format(use_cuda))\n", 134 | "device = torch.device(\"cuda\" if use_cuda else \"cpu\")\n", 135 | "print(\"Device to be used : \",device)\n", 136 | "!nvidia-smi" 137 | ], 138 | "execution_count": 3, 139 | "outputs": [ 140 | { 141 | "output_type": "stream", 142 | "text": [ 143 | "use_cuda: True\n", 144 | "Device to be used : cuda\n", 145 | "Thu Aug 19 06:03:07 2021 \n", 146 | "+-----------------------------------------------------------------------------+\n", 147 | "| NVIDIA-SMI 470.57.02 Driver Version: 460.32.03 CUDA Version: 11.2 |\n", 148 | "|-------------------------------+----------------------+----------------------+\n", 149 | "| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n", 150 | "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n", 151 | "| | | MIG M. |\n", 152 | "|===============================+======================+======================|\n", 153 | "| 0 Tesla K80 Off | 00000000:00:04.0 Off | 0 |\n", 154 | "| N/A 55C P8 29W / 149W | 3MiB / 11441MiB | 0% Default |\n", 155 | "| | | N/A |\n", 156 | "+-------------------------------+----------------------+----------------------+\n", 157 | " \n", 158 | "+-----------------------------------------------------------------------------+\n", 159 | "| Processes: |\n", 160 | "| GPU GI CI PID Type Process name GPU Memory |\n", 161 | "| ID ID Usage |\n", 162 | "|=============================================================================|\n", 163 | "| No running processes found |\n", 164 | "+-----------------------------------------------------------------------------+\n" 165 | ], 166 | "name": "stdout" 167 | } 168 | ] 169 | }, 170 | { 171 | "cell_type": "markdown", 172 | "metadata": { 173 | "id": "A4D25A8gDK4r" 174 | }, 175 | "source": [ 176 | "**Setting up Data Path**" 177 | ] 178 | }, 179 | { 180 | "cell_type": "code", 181 | "metadata": { 182 | "id": "qM68zdO2IRUH" 183 | }, 184 | "source": [ 185 | "#shutil.rmtree(\"/content/IMAGE_SUPER_RESOLVE_DATA/extracted_data\")" 186 | ], 187 | "execution_count": null, 188 | "outputs": [] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "metadata": { 193 | "id": "jd4gLw6uDV-E" 194 | }, 195 | "source": [ 196 | "parent_folder = \"/content/IMAGE_SUPER_RESOLVE_DATA\"\n", 197 | "\n", 198 | "if os.path.isdir(parent_folder):\n", 199 | " shutil.rmtree(parent_folder)\n", 200 | "os.mkdir(parent_folder)\n", 201 | "\n", 202 | "#Create Folder to download Raw Data\n", 203 | "raw_data_folder = os.path.join(parent_folder,\"raw_data\")\n", 204 | "extracted_data_folder = os.path.join(parent_folder,\"extracted_data\")\n", 205 | "\n", 206 | "if not os.path.isdir(raw_data_folder):\n", 207 | " os.mkdir(raw_data_folder)\n", 208 | "\n", 209 | "if not os.path.isdir(extracted_data_folder):\n", 210 | " os.mkdir(extracted_data_folder)\n", 211 | "\n", 212 | "image_data_folder = os.path.join(extracted_data_folder, \"images\")" 213 | ], 214 | "execution_count": 4, 215 | "outputs": [] 216 | }, 217 | { 218 | "cell_type": "markdown", 219 | "metadata": { 220 | "id": "0n-gyvN7EXMi" 221 | }, 222 | "source": [ 223 | "**Downloading Data**" 224 | ] 225 | }, 226 | { 227 | "cell_type": "code", 228 | "metadata": { 229 | "colab": { 230 | "base_uri": "https://localhost:8080/" 231 | }, 232 | "id": "TBnxcgceFT84", 233 | "outputId": "96b86e32-709d-4da4-f6a1-9d571e1c7efc" 234 | }, 235 | "source": [ 236 | "dataset_link = \"https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz\"\n", 237 | "raw_data = os.path.join(raw_data_folder, \"images.tar.gz\")\n", 238 | "print(\"Downloading Data\")\n", 239 | "wget.download(dataset_link, raw_data)\n", 240 | "print(\"Downloading Done\")" 241 | ], 242 | "execution_count": 5, 243 | "outputs": [ 244 | { 245 | "output_type": "stream", 246 | "text": [ 247 | "Downloading Data\n", 248 | "Downloading Done\n" 249 | ], 250 | "name": "stdout" 251 | } 252 | ] 253 | }, 254 | { 255 | "cell_type": "markdown", 256 | "metadata": { 257 | "id": "pndT1J33JG1H" 258 | }, 259 | "source": [ 260 | "**Extracting the Data**" 261 | ] 262 | }, 263 | { 264 | "cell_type": "code", 265 | "metadata": { 266 | "id": "dJgY_Nd3JGCQ" 267 | }, 268 | "source": [ 269 | "shutil.unpack_archive(raw_data, extracted_data_folder)" 270 | ], 271 | "execution_count": 6, 272 | "outputs": [] 273 | }, 274 | { 275 | "cell_type": "markdown", 276 | "metadata": { 277 | "id": "9ys5YOmXK2ZZ" 278 | }, 279 | "source": [ 280 | "**Listing the Dataset Features**" 281 | ] 282 | }, 283 | { 284 | "cell_type": "code", 285 | "metadata": { 286 | "id": "x9_u9N2EKuEd" 287 | }, 288 | "source": [ 289 | "def get_image_address(image_data_folder):\n", 290 | " image_address_list = []\n", 291 | " image_address_list = glob.glob(os.path.join(image_data_folder, \"*.jpg\"))\n", 292 | " print(\"Number of Files : \", len(image_address_list))\n", 293 | " for img_addr in image_address_list:\n", 294 | " try :\n", 295 | " img = cv2.imread(img_addr)\n", 296 | " x = img.shape\n", 297 | " except :\n", 298 | " image_address_list.remove(img_addr)\n", 299 | " os.remove(img_addr)\n", 300 | " \n", 301 | " print(\"Number of Files after removing : \", len(image_address_list))\n", 302 | "\n", 303 | " return image_address_list" 304 | ], 305 | "execution_count": 7, 306 | "outputs": [] 307 | }, 308 | { 309 | "cell_type": "markdown", 310 | "metadata": { 311 | "id": "0MJylShMkZQf" 312 | }, 313 | "source": [ 314 | "**MODULE_1 : Data Loader**" 315 | ] 316 | }, 317 | { 318 | "cell_type": "code", 319 | "metadata": { 320 | "id": "iLBq5_zjkd9u" 321 | }, 322 | "source": [ 323 | "class DataGenerator(Dataset):\n", 324 | "\t\n", 325 | " def __init__(self, image_list):\n", 326 | " self.files = image_list[:10]\n", 327 | " \n", 328 | "\n", 329 | " def __len__(self):\n", 330 | " return len(self.files)\n", 331 | " \n", 332 | "\n", 333 | " def __getitem__(self,idx):\n", 334 | "\n", 335 | " #print(files[idx])\n", 336 | " img = cv2.imread(self.files[idx])\n", 337 | " high_res_img = cv2.resize(img,(512,512))\n", 338 | " high_res_img = np.transpose(high_res_img, (2, 0, 1))\n", 339 | " low_res_img = cv2.resize(img,(512, 512))\n", 340 | " low_res_img = cv2.cvtColor(low_res_img, cv2.COLOR_BGR2GRAY)\n", 341 | " low_res_img = np.reshape(low_res_img, (512, 512, 1))\n", 342 | " low_res_img = np.transpose(low_res_img, (2, 0, 1))\n", 343 | " return torch.FloatTensor(high_res_img/255.), torch.FloatTensor(low_res_img/255.)\n", 344 | "\t\t\n", 345 | "\t\n", 346 | "def load_data(image_list, batch_size=32, num_workers=10, shuffle=True):\n", 347 | "\n", 348 | " dataset = DataGenerator(image_list)\n", 349 | " data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=shuffle)\n", 350 | "\n", 351 | " return data_loader" 352 | ], 353 | "execution_count": 12, 354 | "outputs": [] 355 | }, 356 | { 357 | "cell_type": "markdown", 358 | "metadata": { 359 | "id": "dFkfUgl1amEk" 360 | }, 361 | "source": [ 362 | "**Checking the dataloader**" 363 | ] 364 | }, 365 | { 366 | "cell_type": "code", 367 | "metadata": { 368 | "colab": { 369 | "base_uri": "https://localhost:8080/" 370 | }, 371 | "id": "2xN3ua4CZh1l", 372 | "outputId": "74058035-4898-48f6-d0e7-d9c88e942375" 373 | }, 374 | "source": [ 375 | "parent_folder = \"/content/IMAGE_SUPER_RESOLVE_DATA\"\n", 376 | "extracted_data_folder = os.path.join(parent_folder, \"extracted_data\")\n", 377 | "image_data_folder = os.path.join(extracted_data_folder, \"images\")\n", 378 | "image_address_list = get_image_address(image_data_folder)\n", 379 | "random.shuffle(image_address_list)\n", 380 | "\n", 381 | "train_img_addr_list = image_address_list[:int(0.7*len(image_address_list))]\n", 382 | "train_loader = load_data(train_img_addr_list, batch_size=1, num_workers=2, shuffle=True)\n", 383 | "check = iter(train_loader)\n", 384 | "\n" 385 | ], 386 | "execution_count": 13, 387 | "outputs": [ 388 | { 389 | "output_type": "stream", 390 | "text": [ 391 | "Number of Files : 7384\n", 392 | "Number of Files after removing : 7384\n" 393 | ], 394 | "name": "stdout" 395 | } 396 | ] 397 | }, 398 | { 399 | "cell_type": "code", 400 | "metadata": { 401 | "colab": { 402 | "base_uri": "https://localhost:8080/", 403 | "height": 221 404 | }, 405 | "id": "Eg9j4kc-aPYj", 406 | "outputId": "1d9257d2-6e0b-4cfc-cde6-7eba87418bbb" 407 | }, 408 | "source": [ 409 | " GT, input_img = next(check)\n", 410 | "input_img = input_img.numpy()[0]\n", 411 | "GT = GT.numpy()[0]\n", 412 | "\n", 413 | "input_img = np.transpose(input_img, (1, 2, 0))\n", 414 | "GT = np.transpose(GT, (1, 2, 0))\n", 415 | "\n", 416 | "input_img = cv2.cvtColor(input_img, cv2.COLOR_BGR2RGB)\n", 417 | "GT = cv2.cvtColor(GT, cv2.COLOR_BGR2RGB)\n", 418 | "\n", 419 | "f, axarr = plt.subplots(1,2)\n", 420 | "axarr[0].imshow(input_img)\n", 421 | "axarr[1].imshow(GT)" 422 | ], 423 | "execution_count": 15, 424 | "outputs": [ 425 | { 426 | "output_type": "execute_result", 427 | "data": { 428 | "text/plain": [ 429 | "" 430 | ] 431 | }, 432 | "metadata": {}, 433 | "execution_count": 15 434 | }, 435 | { 436 | "output_type": "display_data", 437 | "data": { 438 | "image/png": "\n", 439 | "text/plain": [ 440 | "
" 441 | ] 442 | }, 443 | "metadata": { 444 | "needs_background": "light" 445 | } 446 | } 447 | ] 448 | }, 449 | { 450 | "cell_type": "markdown", 451 | "metadata": { 452 | "id": "_Tf2iVeJvUDh" 453 | }, 454 | "source": [ 455 | "**MODULE 2 : Model Creation**" 456 | ] 457 | }, 458 | { 459 | "cell_type": "markdown", 460 | "metadata": { 461 | "id": "-E8p0uaux9w_" 462 | }, 463 | "source": [ 464 | "**Conv2D**" 465 | ] 466 | }, 467 | { 468 | "cell_type": "code", 469 | "metadata": { 470 | "id": "zEYELaVgvg5g" 471 | }, 472 | "source": [ 473 | "class Conv2d(nn.Module):\n", 474 | " def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):\n", 475 | " super().__init__(*args, **kwargs)\n", 476 | " self.conv_block = nn.Sequential(\n", 477 | " nn.Conv2d(cin, cout, kernel_size, stride, padding),\n", 478 | " nn.BatchNorm2d(cout)\n", 479 | " )\n", 480 | " self.act = nn.ReLU()\n", 481 | " self.residual = residual\n", 482 | "\n", 483 | " def forward(self, x):\n", 484 | " out = self.conv_block(x)\n", 485 | " if self.residual:\n", 486 | " out += x\n", 487 | " return self.act(out)" 488 | ], 489 | "execution_count": 16, 490 | "outputs": [] 491 | }, 492 | { 493 | "cell_type": "markdown", 494 | "metadata": { 495 | "id": "6bCywTOIyCIo" 496 | }, 497 | "source": [ 498 | "**Conv2D-T**" 499 | ] 500 | }, 501 | { 502 | "cell_type": "code", 503 | "metadata": { 504 | "id": "h8a1mNGRx3lM" 505 | }, 506 | "source": [ 507 | "class Conv2dTranspose(nn.Module):\n", 508 | " def __init__(self, cin, cout, kernel_size, stride, padding, output_padding=0, *args, **kwargs):\n", 509 | " super().__init__(*args, **kwargs)\n", 510 | " self.conv_block = nn.Sequential(\n", 511 | " nn.ConvTranspose2d(cin, cout, kernel_size, stride, padding, output_padding),\n", 512 | " nn.BatchNorm2d(cout)\n", 513 | " )\n", 514 | " self.act = nn.ReLU()\n", 515 | "\n", 516 | " def forward(self, x):\n", 517 | " out = self.conv_block(x)\n", 518 | " return self.act(out)" 519 | ], 520 | "execution_count": 17, 521 | "outputs": [] 522 | }, 523 | { 524 | "cell_type": "markdown", 525 | "metadata": { 526 | "id": "Eahg7WyeyHla" 527 | }, 528 | "source": [ 529 | "**Model**" 530 | ] 531 | }, 532 | { 533 | "cell_type": "code", 534 | "metadata": { 535 | "id": "nAYmhm0RyMJy" 536 | }, 537 | "source": [ 538 | "class Image_Super_Resolve(nn.Module):\n", 539 | " def __init__(self):\n", 540 | " super(Image_Super_Resolve, self).__init__()\n", 541 | "\n", 542 | " self.image_encoder = nn.Sequential(\n", 543 | " Conv2d(1, 4, kernel_size=3, stride=1, padding=1),\n", 544 | " \n", 545 | " Conv2d(4, 8, kernel_size=3, stride=1, padding=1),\n", 546 | " Conv2d(8, 8, kernel_size=3, stride=1, padding=1, residual=True),\n", 547 | " Conv2d(8, 8, kernel_size=3, stride=1, padding=1, residual=True),\n", 548 | " \n", 549 | " Conv2d(8, 16, kernel_size=3, stride=1, padding=1),\n", 550 | " Conv2d(16, 16, kernel_size=3, stride=1, padding=1, residual=True),\n", 551 | " Conv2d(16, 16, kernel_size=3, stride=1, padding=1, residual=True),\n", 552 | " \n", 553 | " Conv2d(16, 32, kernel_size=3, stride=1, padding=1),\n", 554 | " Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),\n", 555 | " Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), \n", 556 | " )\n", 557 | " self.image_decoder = nn.Sequential(\n", 558 | "\n", 559 | " # Conv2dTranspose(32, 32, kernel_size=3, stride=2, padding=1, output_padding=1),\n", 560 | "\n", 561 | " # Conv2dTranspose(32, 32, kernel_size=3, stride=2, padding=1, output_padding=1),\n", 562 | " Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),\n", 563 | " Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),\n", 564 | " Conv2d(32, 3, kernel_size=3, stride=1, padding=1),\n", 565 | " nn.Conv2d(3, 3, 1, 1, 0)\n", 566 | " )\n", 567 | " \n", 568 | " def forward(self, face_image):\n", 569 | "\n", 570 | " #print(\"Shape : \",face_image.shape)\n", 571 | " face_embedding = self.image_encoder(face_image)\n", 572 | " # print(\"Shape : \",face_embedding.shape)\n", 573 | " decoded_face = self.image_decoder(face_embedding)\n", 574 | " decoded_face = decoded_face + face_image\n", 575 | "\n", 576 | " decoded_face = torch.sigmoid(decoded_face)\n", 577 | " # print(\"Shape : \",decoded_face.shape)\n", 578 | " return decoded_face\n" 579 | ], 580 | "execution_count": 18, 581 | "outputs": [] 582 | }, 583 | { 584 | "cell_type": "markdown", 585 | "metadata": { 586 | "id": "lkhAMdpk9Wmz" 587 | }, 588 | "source": [ 589 | "**Code to check the model shape**" 590 | ] 591 | }, 592 | { 593 | "cell_type": "code", 594 | "metadata": { 595 | "id": "adkY3N__yHGs", 596 | "colab": { 597 | "base_uri": "https://localhost:8080/" 598 | }, 599 | "outputId": "de661625-4e39-41ba-ffe5-9766ff696398" 600 | }, 601 | "source": [ 602 | "model = Image_Super_Resolve()\n", 603 | "data = torch.rand(2, 1, 512, 512)\n", 604 | "print(data.shape)\n", 605 | "decoded_data = model.forward(data)\n", 606 | "print (decoded_data.shape)" 607 | ], 608 | "execution_count": 19, 609 | "outputs": [ 610 | { 611 | "output_type": "stream", 612 | "text": [ 613 | "torch.Size([2, 1, 512, 512])\n", 614 | "torch.Size([2, 3, 512, 512])\n" 615 | ], 616 | "name": "stdout" 617 | } 618 | ] 619 | }, 620 | { 621 | "cell_type": "code", 622 | "metadata": { 623 | "id": "iOBpnH7VnA4f" 624 | }, 625 | "source": [ 626 | "class PSNR:\n", 627 | " \"\"\"Peak Signal to Noise Ratio\n", 628 | " img1 and img2 have range [0, 255]\"\"\"\n", 629 | "\n", 630 | " def __init__(self):\n", 631 | " self.name = \"PSNR\"\n", 632 | "\n", 633 | " @staticmethod\n", 634 | " def __call__(img1, img2):\n", 635 | " mse = torch.mean((img1 - img2) ** 2)\n", 636 | " return 20 * torch.log10(255.0 / torch.sqrt(mse))\n" 637 | ], 638 | "execution_count": 20, 639 | "outputs": [] 640 | }, 641 | { 642 | "cell_type": "markdown", 643 | "metadata": { 644 | "id": "0wg7mYyzll_A" 645 | }, 646 | "source": [ 647 | "**MODULE 3 : Training**" 648 | ] 649 | }, 650 | { 651 | "cell_type": "markdown", 652 | "metadata": { 653 | "id": "jzpzfp66WdvW" 654 | }, 655 | "source": [ 656 | "**SAVE CHECKPOINT**" 657 | ] 658 | }, 659 | { 660 | "cell_type": "code", 661 | "metadata": { 662 | "id": "Bx72WwzNWhC7" 663 | }, 664 | "source": [ 665 | "def save_ckp(checkpoint, checkpoint_path):\n", 666 | " torch.save(checkpoint, checkpoint_path)" 667 | ], 668 | "execution_count": 21, 669 | "outputs": [] 670 | }, 671 | { 672 | "cell_type": "markdown", 673 | "metadata": { 674 | "id": "4nsJKr9QWhrT" 675 | }, 676 | "source": [ 677 | "**LOAD CHECKPOINT**" 678 | ] 679 | }, 680 | { 681 | "cell_type": "code", 682 | "metadata": { 683 | "id": "QYwO7XeWWmOJ" 684 | }, 685 | "source": [ 686 | "def load_ckp(checkpoint_path, model, model_opt):\n", 687 | " checkpoint = torch.load(checkpoint_path)\n", 688 | " model.load_state_dict(checkpoint['state_dict'])\n", 689 | " model_opt.load_state_dict(checkpoint['optimizer'])\n", 690 | " return model, model_opt, checkpoint['epoch']" 691 | ], 692 | "execution_count": 22, 693 | "outputs": [] 694 | }, 695 | { 696 | "cell_type": "markdown", 697 | "metadata": { 698 | "id": "sncpj_TLGqAO" 699 | }, 700 | "source": [ 701 | "**TRAIN EPOCH**" 702 | ] 703 | }, 704 | { 705 | "cell_type": "code", 706 | "metadata": { 707 | "id": "MFbFHbMuBz6S" 708 | }, 709 | "source": [ 710 | "def train_epoch(train_loader, model, optimizer, epoch):\n", 711 | "\n", 712 | " progress_bar = tqdm(enumerate(train_loader))\n", 713 | " total_loss = 0.0\n", 714 | " for step, (high_res_img, low_res_img) in progress_bar:\n", 715 | " # if high_res_img is None and low_res_img is None:\n", 716 | " # continue\n", 717 | " model.train()\n", 718 | " high_res_img = high_res_img.to(device)\n", 719 | " low_res_img = low_res_img.to(device)\n", 720 | "\n", 721 | " optimizer.zero_grad()\n", 722 | "\n", 723 | " pred_img = model.forward(low_res_img)\n", 724 | "\n", 725 | " mse = nn.MSELoss()\n", 726 | " psnr = PSNR()\n", 727 | "\n", 728 | " mse_loss = mse(pred_img, high_res_img)\n", 729 | " psnr_loss = psnr(pred_img*255.0, high_res_img*255.0)\n", 730 | "\n", 731 | " loss = mse_loss\n", 732 | "\n", 733 | " # print(loss)\n", 734 | " loss.backward()\n", 735 | " optimizer.step()\n", 736 | "\n", 737 | " progress_bar.set_description(\n", 738 | " \"Epoch : {} Training Loss : {} \".format(epoch, loss))\n", 739 | "\n", 740 | "\n", 741 | " return model, optimizer" 742 | ], 743 | "execution_count": 23, 744 | "outputs": [] 745 | }, 746 | { 747 | "cell_type": "markdown", 748 | "metadata": { 749 | "id": "-MI5PZkyGx5J" 750 | }, 751 | "source": [ 752 | "**VAL EPOCH**" 753 | ] 754 | }, 755 | { 756 | "cell_type": "code", 757 | "metadata": { 758 | "id": "Im6okWc4BvJn" 759 | }, 760 | "source": [ 761 | "def val_epoch(val_loader, model, optimizer, epoch):\n", 762 | "\n", 763 | " progress_bar = tqdm(enumerate(val_loader))\n", 764 | " total_loss = 0.0\n", 765 | " for step, (high_res_img, low_res_img) in progress_bar:\n", 766 | "\n", 767 | " try :\n", 768 | " if high_res_img is None and low_res_img is None:\n", 769 | " continue\n", 770 | "\n", 771 | " high_res_img = high_res_img.to(device)\n", 772 | " low_res_img = low_res_img.to(device)\n", 773 | "\n", 774 | " mse = nn.MSELoss()\n", 775 | " psnr = PSNR()\n", 776 | "\n", 777 | " model.eval()\n", 778 | " pred_img = model.forward(low_res_img)\n", 779 | "\n", 780 | " mse_loss = mse(pred_img, high_res_img)\n", 781 | " psnr_loss = psnr(pred_img*255.0, high_res_img*255.0)\n", 782 | "\n", 783 | " loss = mse_loss\n", 784 | "\n", 785 | " progress_bar.set_description(\n", 786 | " \"Epoch : {} Validation Loss : {} \".format(epoch-1, loss))\n", 787 | " except :\n", 788 | " continue\n" 789 | ], 790 | "execution_count": 24, 791 | "outputs": [] 792 | }, 793 | { 794 | "cell_type": "markdown", 795 | "metadata": { 796 | "id": "HNpkA_8fG2cH" 797 | }, 798 | "source": [ 799 | "**TEST EPOCH**" 800 | ] 801 | }, 802 | { 803 | "cell_type": "code", 804 | "metadata": { 805 | "id": "Mr6f0cRmBtXi" 806 | }, 807 | "source": [ 808 | "def test_epoch(test_loader, model, optimizer, epoch):\n", 809 | "\n", 810 | " progress_bar = tqdm(enumerate(test_loader))\n", 811 | " total_loss = 0.0\n", 812 | "\n", 813 | " no_img_to_write = 10\n", 814 | " inference_folder = \"/content/IMAGE_SUPER_RESOLVE_DATA/inference_data\"\n", 815 | " if not os.path.isdir(inference_folder):\n", 816 | " os.mkdir(inference_folder)\n", 817 | "\n", 818 | " if not os.path.isdir(os.path.join(inference_folder, str(epoch))):\n", 819 | " os.mkdir(os.path.join(inference_folder, str(epoch)))\n", 820 | "\n", 821 | " for step, (high_res_img, low_res_img) in progress_bar:\n", 822 | "\n", 823 | " try:\n", 824 | " if high_res_img is None and low_res_img is None:\n", 825 | " continue\n", 826 | "\n", 827 | " high_res_img = high_res_img.to(device)\n", 828 | " low_res_img = low_res_img.to(device)\n", 829 | "\n", 830 | " mse = nn.MSELoss()\n", 831 | " l1 = nn.L1Loss()\n", 832 | " psnr = PSNR()\n", 833 | "\n", 834 | " model.eval()\n", 835 | " pred_img = model.forward(low_res_img)\n", 836 | "\n", 837 | " #mse_loss = mse(pred_img, high_res_img)\n", 838 | " #psnr_loss = psnr(pred_img*255.0, high_res_img*255.0)\n", 839 | " l1_loss = l1(pred_img, high_res_img)\n", 840 | "\n", 841 | " loss = l1_loss\n", 842 | "\n", 843 | " progress_bar.set_description(\n", 844 | " \"Epoch : {} Test Loss : {} \".format(epoch-1, loss))\n", 845 | "\n", 846 | " if(step < no_img_to_write):\n", 847 | "\n", 848 | " p_img = pred_img.cpu().numpy().transpose(0, 2, 3, 1) * 255\n", 849 | " gt_img = high_res_img.cpu().numpy().transpose(0, 2, 3, 1) * 255\n", 850 | " inp_img = low_res_img.cpu().numpy().transpose(0, 2, 3, 1) * 255\n", 851 | "\n", 852 | " # cv2.imwrite(os.path.join(inference_folder, str(epoch),\n", 853 | " # \"img_\"+str(step)+\"_pred.jpg\"), p_img[0])\n", 854 | " # cv2.imwrite(os.path.join(inference_folder, str(epoch),\n", 855 | " # \"img_\"+str(step)+\"_gt.jpg\"), gt_img[0])\n", 856 | " # cv2.imwrite(os.path.join(inference_folder, str(epoch),\n", 857 | " # \"img_\"+str(step)+\"_inp.jpg\"), inp_img[0])\n", 858 | " except :\n", 859 | " continue" 860 | ], 861 | "execution_count": 27, 862 | "outputs": [] 863 | }, 864 | { 865 | "cell_type": "markdown", 866 | "metadata": { 867 | "id": "1t3jslWRG6sR" 868 | }, 869 | "source": [ 870 | "**Code to control the Train, Test & Val**" 871 | ] 872 | }, 873 | { 874 | "cell_type": "code", 875 | "metadata": { 876 | "id": "WOj06fhkBfON" 877 | }, 878 | "source": [ 879 | "def train_val_test(train_loader, val_loader, test_loader, model, optimizer, n_epoch, resume):\n", 880 | "\n", 881 | " checkpoint_path = \"/content/IMAGE_SUPER_RESOLVE_DATA/checkpoint.pt\"\n", 882 | "\n", 883 | " epoch = 0\n", 884 | " if resume:\n", 885 | " model, optimizer, epoch = load_ckp(\n", 886 | " checkpoint_path, model, optimizer)\n", 887 | "\n", 888 | " while 1:\n", 889 | " model, optimizer = train_epoch(train_loader, model, optimizer, epoch)\n", 890 | " checkpoint = {'epoch': epoch+1, 'state_dict': model.state_dict(),\n", 891 | " 'optimizer': optimizer.state_dict()}\n", 892 | " save_ckp(checkpoint, checkpoint_path)\n", 893 | " print(\"Checkpoint Saved\")\n", 894 | " # model, optimizer, epoch = load_ckp(checkpoint_path, model, optimizer)\n", 895 | " # print(\"Checkpoint Loaded\")\n", 896 | " with torch.no_grad():\n", 897 | " val_epoch(val_loader, model, optimizer, epoch)\n", 898 | " test_epoch(test_loader, model, optimizer, epoch)" 899 | ], 900 | "execution_count": 26, 901 | "outputs": [] 902 | }, 903 | { 904 | "cell_type": "markdown", 905 | "metadata": { 906 | "id": "DjZ7fZkUWOO4" 907 | }, 908 | "source": [ 909 | "**MAIN FUNCTION**" 910 | ] 911 | }, 912 | { 913 | "cell_type": "code", 914 | "metadata": { 915 | "id": "d80vuRX7nDFM" 916 | }, 917 | "source": [ 918 | "def main():\n", 919 | "\n", 920 | " parent_folder = \"/content/IMAGE_SUPER_RESOLVE_DATA\"\n", 921 | " extracted_data_folder = os.path.join(parent_folder, \"extracted_data\")\n", 922 | " image_data_folder = os.path.join(extracted_data_folder, \"images\")\n", 923 | " image_address_list = get_image_address(image_data_folder)\n", 924 | " random.shuffle(image_address_list)\n", 925 | "\n", 926 | " train_img_addr_list = image_address_list[:int(0.7*len(image_address_list))]\n", 927 | " val_img_addr_list = image_address_list[len(train_img_addr_list):int(\n", 928 | " len(train_img_addr_list) + 0.2*len(image_address_list))]\n", 929 | " test_img_addr_list = image_address_list[len(\n", 930 | " train_img_addr_list) + len(val_img_addr_list):]\n", 931 | "\n", 932 | " print(\"Total Number of Images : \", len(image_address_list))\n", 933 | " print(\"Train : {} Val : {} Test : {}\".format(\n", 934 | " len(train_img_addr_list), len(val_img_addr_list), len(test_img_addr_list)))\n", 935 | "\n", 936 | " train_loader = load_data(\n", 937 | " train_img_addr_list, batch_size=2, num_workers=2, shuffle=True)\n", 938 | " val_loader = load_data(val_img_addr_list, batch_size=2,\n", 939 | " num_workers=2, shuffle=True)\n", 940 | " test_loader = load_data(\n", 941 | " test_img_addr_list, batch_size=1, num_workers=2, shuffle=False)\n", 942 | "\n", 943 | " model = Image_Super_Resolve()\n", 944 | " model = model.to(device)\n", 945 | " optimizer = optim.Adam(\n", 946 | " [p for p in model.parameters() if p.requires_grad], lr=0.01)\n", 947 | " n_epoch = 100\n", 948 | " resume = False\n", 949 | " train_val_test(train_loader, val_loader, test_loader,\n", 950 | " model, optimizer, n_epoch, resume)" 951 | ], 952 | "execution_count": 28, 953 | "outputs": [] 954 | }, 955 | { 956 | "cell_type": "markdown", 957 | "metadata": { 958 | "id": "vwxERyJxWTFz" 959 | }, 960 | "source": [ 961 | "**CALLING THE MAIN FUNCTION**" 962 | ] 963 | }, 964 | { 965 | "cell_type": "code", 966 | "metadata": { 967 | "id": "R8TmBvZEDTp6", 968 | "colab": { 969 | "base_uri": "https://localhost:8080/", 970 | "height": 1000 971 | }, 972 | "outputId": "6d66510d-0a6f-450d-fa5b-b5177bfdffb4" 973 | }, 974 | "source": [ 975 | "main()" 976 | ], 977 | "execution_count": 29, 978 | "outputs": [ 979 | { 980 | "output_type": "stream", 981 | "text": [ 982 | "Number of Files : 7384\n", 983 | "Number of Files after removing : 7384\n", 984 | "Total Number of Images : 7384\n", 985 | "Train : 5168 Val : 1476 Test : 740\n" 986 | ], 987 | "name": "stdout" 988 | }, 989 | { 990 | "output_type": "stream", 991 | "text": [ 992 | "Epoch : 0 Training Loss : 0.07829584926366806 : : 5it [00:02, 2.43it/s]" 993 | ], 994 | "name": "stderr" 995 | }, 996 | { 997 | "output_type": "stream", 998 | "text": [ 999 | "Checkpoint Saved\n" 1000 | ], 1001 | "name": "stdout" 1002 | }, 1003 | { 1004 | "output_type": "stream", 1005 | "text": [ 1006 | "\n", 1007 | "Epoch : -1 Validation Loss : 0.3999069631099701 : : 5it [00:00, 6.73it/s]\n", 1008 | "Epoch : -1 Test Loss : 0.4609988331794739 : : 10it [00:00, 12.86it/s]\n", 1009 | "Epoch : 0 Training Loss : 0.05630511790513992 : : 5it [00:01, 2.62it/s]" 1010 | ], 1011 | "name": "stderr" 1012 | }, 1013 | { 1014 | "output_type": "stream", 1015 | "text": [ 1016 | "Checkpoint Saved\n" 1017 | ], 1018 | "name": "stdout" 1019 | }, 1020 | { 1021 | "output_type": "stream", 1022 | "text": [ 1023 | "\n", 1024 | "Epoch : -1 Validation Loss : 0.15192380547523499 : : 5it [00:00, 6.82it/s]\n", 1025 | "Epoch : -1 Test Loss : 0.3490328788757324 : : 10it [00:00, 13.25it/s]\n", 1026 | "Epoch : 0 Training Loss : 0.047594211995601654 : : 5it [00:01, 2.60it/s]" 1027 | ], 1028 | "name": "stderr" 1029 | }, 1030 | { 1031 | "output_type": "stream", 1032 | "text": [ 1033 | "Checkpoint Saved\n" 1034 | ], 1035 | "name": "stdout" 1036 | }, 1037 | { 1038 | "output_type": "stream", 1039 | "text": [ 1040 | "\n", 1041 | "Epoch : -1 Validation Loss : 0.06959714740514755 : : 5it [00:00, 7.03it/s]\n", 1042 | "Epoch : -1 Test Loss : 0.22591321170330048 : : 10it [00:00, 13.05it/s]\n", 1043 | "Epoch : 0 Training Loss : 0.04837007075548172 : : 5it [00:01, 2.57it/s]" 1044 | ], 1045 | "name": "stderr" 1046 | }, 1047 | { 1048 | "output_type": "stream", 1049 | "text": [ 1050 | "Checkpoint Saved\n" 1051 | ], 1052 | "name": "stdout" 1053 | }, 1054 | { 1055 | "output_type": "stream", 1056 | "text": [ 1057 | "\n", 1058 | "Epoch : -1 Validation Loss : 0.06871086359024048 : : 5it [00:00, 6.72it/s]\n", 1059 | "Epoch : -1 Test Loss : 0.18924197554588318 : : 10it [00:00, 13.37it/s]\n", 1060 | "Epoch : 0 Training Loss : 0.04076924920082092 : : 5it [00:01, 2.57it/s]" 1061 | ], 1062 | "name": "stderr" 1063 | }, 1064 | { 1065 | "output_type": "stream", 1066 | "text": [ 1067 | "Checkpoint Saved\n" 1068 | ], 1069 | "name": "stdout" 1070 | }, 1071 | { 1072 | "output_type": "stream", 1073 | "text": [ 1074 | "\n", 1075 | "Epoch : -1 Validation Loss : 0.037644024938344955 : : 5it [00:00, 6.30it/s]\n", 1076 | "Epoch : -1 Test Loss : 0.17227113246917725 : : 10it [00:00, 13.03it/s]\n", 1077 | "Epoch : 0 Training Loss : 0.034246839582920074 : : 5it [00:01, 2.60it/s]" 1078 | ], 1079 | "name": "stderr" 1080 | }, 1081 | { 1082 | "output_type": "stream", 1083 | "text": [ 1084 | "Checkpoint Saved\n" 1085 | ], 1086 | "name": "stdout" 1087 | }, 1088 | { 1089 | "output_type": "stream", 1090 | "text": [ 1091 | "\n", 1092 | "Epoch : -1 Validation Loss : 0.043083518743515015 : : 5it [00:00, 6.88it/s]\n", 1093 | "Epoch : -1 Test Loss : 0.159893199801445 : : 10it [00:00, 13.27it/s]\n", 1094 | "Epoch : 0 Training Loss : 0.02342107892036438 : : 5it [00:01, 2.61it/s]" 1095 | ], 1096 | "name": "stderr" 1097 | }, 1098 | { 1099 | "output_type": "stream", 1100 | "text": [ 1101 | "Checkpoint Saved\n" 1102 | ], 1103 | "name": "stdout" 1104 | }, 1105 | { 1106 | "output_type": "stream", 1107 | "text": [ 1108 | "\n", 1109 | "Epoch : -1 Validation Loss : 0.03414853662252426 : : 5it [00:00, 6.63it/s]\n", 1110 | "Epoch : -1 Test Loss : 0.14629341661930084 : : 10it [00:00, 13.22it/s]\n", 1111 | "Epoch : 0 Training Loss : 0.03479604423046112 : : 5it [00:01, 2.58it/s]" 1112 | ], 1113 | "name": "stderr" 1114 | }, 1115 | { 1116 | "output_type": "stream", 1117 | "text": [ 1118 | "Checkpoint Saved\n" 1119 | ], 1120 | "name": "stdout" 1121 | }, 1122 | { 1123 | "output_type": "stream", 1124 | "text": [ 1125 | "\n", 1126 | "Epoch : -1 Validation Loss : 0.03475822135806084 : : 5it [00:00, 7.00it/s]\n", 1127 | "Epoch : -1 Test Loss : 0.1355130523443222 : : 10it [00:00, 13.24it/s]\n", 1128 | "Epoch : 0 Training Loss : 0.021552659571170807 : : 5it [00:01, 2.57it/s]" 1129 | ], 1130 | "name": "stderr" 1131 | }, 1132 | { 1133 | "output_type": "stream", 1134 | "text": [ 1135 | "Checkpoint Saved\n" 1136 | ], 1137 | "name": "stdout" 1138 | }, 1139 | { 1140 | "output_type": "stream", 1141 | "text": [ 1142 | "\n", 1143 | "Epoch : -1 Validation Loss : 0.02842155657708645 : : 5it [00:00, 7.17it/s]\n", 1144 | "Epoch : -1 Test Loss : 0.11862193793058395 : : 10it [00:00, 13.04it/s]\n", 1145 | "Epoch : 0 Training Loss : 0.019530709832906723 : : 5it [00:01, 2.60it/s]" 1146 | ], 1147 | "name": "stderr" 1148 | }, 1149 | { 1150 | "output_type": "stream", 1151 | "text": [ 1152 | "Checkpoint Saved\n" 1153 | ], 1154 | "name": "stdout" 1155 | }, 1156 | { 1157 | "output_type": "stream", 1158 | "text": [ 1159 | "\n", 1160 | "Epoch : -1 Validation Loss : 0.024070754647254944 : : 5it [00:00, 6.79it/s]\n", 1161 | "Epoch : -1 Test Loss : 0.09877555072307587 : : 10it [00:00, 12.99it/s]\n", 1162 | "Epoch : 0 Training Loss : 0.017261836677789688 : : 5it [00:01, 2.57it/s]" 1163 | ], 1164 | "name": "stderr" 1165 | }, 1166 | { 1167 | "output_type": "stream", 1168 | "text": [ 1169 | "Checkpoint Saved\n" 1170 | ], 1171 | "name": "stdout" 1172 | }, 1173 | { 1174 | "output_type": "stream", 1175 | "text": [ 1176 | "\n", 1177 | "Epoch : -1 Validation Loss : 0.0100233880802989 : : 5it [00:00, 7.07it/s]\n", 1178 | "Epoch : -1 Test Loss : 0.09105490148067474 : : 10it [00:00, 13.22it/s]\n", 1179 | "Epoch : 0 Training Loss : 0.014956777915358543 : : 3it [00:01, 1.98it/s]\n" 1180 | ], 1181 | "name": "stderr" 1182 | }, 1183 | { 1184 | "output_type": "error", 1185 | "ename": "KeyboardInterrupt", 1186 | "evalue": "ignored", 1187 | "traceback": [ 1188 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 1189 | "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", 1190 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mmain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", 1191 | "\u001b[0;32m\u001b[0m in \u001b[0;36mmain\u001b[0;34m()\u001b[0m\n\u001b[1;32m 31\u001b[0m \u001b[0mresume\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 32\u001b[0m train_val_test(train_loader, val_loader, test_loader,\n\u001b[0;32m---> 33\u001b[0;31m model, optimizer, n_epoch, resume)\n\u001b[0m", 1192 | "\u001b[0;32m\u001b[0m in \u001b[0;36mtrain_val_test\u001b[0;34m(train_loader, val_loader, test_loader, model, optimizer, n_epoch, resume)\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0;32mwhile\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 11\u001b[0;31m \u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrain_epoch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_loader\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mepoch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 12\u001b[0m checkpoint = {'epoch': epoch+1, 'state_dict': model.state_dict(),\n\u001b[1;32m 13\u001b[0m 'optimizer': optimizer.state_dict()}\n", 1193 | "\u001b[0;32m\u001b[0m in \u001b[0;36mtrain_epoch\u001b[0;34m(train_loader, model, optimizer, epoch)\u001b[0m\n\u001b[1;32m 27\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 28\u001b[0m progress_bar.set_description(\n\u001b[0;32m---> 29\u001b[0;31m \"Epoch : {} Training Loss : {} \".format(epoch, loss))\n\u001b[0m\u001b[1;32m 30\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 31\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 1194 | "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torch/_tensor.py\u001b[0m in \u001b[0;36m__format__\u001b[0;34m(self, format_spec)\u001b[0m\n\u001b[1;32m 558\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mhandle_torch_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__format__\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mformat_spec\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 559\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 560\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitem\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__format__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mformat_spec\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 561\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mobject\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__format__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mformat_spec\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 562\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 1195 | "\u001b[0;31mKeyboardInterrupt\u001b[0m: " 1196 | ] 1197 | } 1198 | ] 1199 | }, 1200 | { 1201 | "cell_type": "markdown", 1202 | "metadata": { 1203 | "id": "0nb3p4nMZcue" 1204 | }, 1205 | "source": [ 1206 | "**INFERENCE**" 1207 | ] 1208 | }, 1209 | { 1210 | "cell_type": "code", 1211 | "metadata": { 1212 | "colab": { 1213 | "base_uri": "https://localhost:8080/", 1214 | "height": 167 1215 | }, 1216 | "id": "7cGdz-wLZf0e", 1217 | "outputId": "c59e892f-7840-4e6c-ad3d-ca912353a13a" 1218 | }, 1219 | "source": [ 1220 | "image_folder = \"/content/IMAGE_SUPER_RESOLVE_DATA/extracted_data/images\"\n", 1221 | "checkpoint_path = \"/content/IMAGE_SUPER_RESOLVE_DATA/checkpoint.pt\"\n", 1222 | "\n", 1223 | "img_addr = os.path.join(image_folder, \"Abyssinian_2.jpg\")\n", 1224 | "img = cv2.imread(img_addr)\n", 1225 | "high_res_img = cv2.resize(img,(512,512))\n", 1226 | "low_res_img = cv2.resize(img,(512, 512))\n", 1227 | "low_res_img = cv2.cvtColor(low_res_img, cv2.COLOR_BGR2GRAY)\n", 1228 | "low_res_img = np.reshape(low_res_img, (512, 512, 1))\n", 1229 | "\n", 1230 | "low_res_img = np.transpose(low_res_img, (2, 0, 1))\n", 1231 | "\n", 1232 | "low_res_img = torch.FloatTensor(low_res_img/255.).unsqueeze(0)\n", 1233 | "\n", 1234 | "\n", 1235 | "model = Image_Super_Resolve()\n", 1236 | "model = model.to(device)\n", 1237 | "optimizer = optim.Adam(\n", 1238 | " [p for p in model.parameters() if p.requires_grad], lr=0.01)\n", 1239 | "model, optimizer, epoch = load_ckp(checkpoint_path, model, optimizer)\n", 1240 | "low_res_img = low_res_img.to(device)\n", 1241 | "pred_img = model.forward(low_res_img)\n", 1242 | "\n", 1243 | "p_img = pred_img.detach().cpu().numpy().transpose(0, 2, 3, 1)[0]\n", 1244 | "\n", 1245 | "# p_img = cv2.cvtColor(p_img, )\n", 1246 | "inp_img = low_res_img.cpu().numpy().transpose(0, 2, 3, 1)[0]\n", 1247 | "\n", 1248 | "inp_img = np.reshape(inp_img, (512, 512))\n", 1249 | "#p_img = np.reshape(p_img, (512, 512, 3))\n", 1250 | "\n", 1251 | "high_res_img = cv2.cvtColor(high_res_img, cv2.COLOR_BGR2RGB)\n", 1252 | "p_img = cv2.cvtColor(p_img, cv2.COLOR_BGR2RGB)\n", 1253 | "\n", 1254 | "f, axarr = plt.subplots(1,3)\n", 1255 | "axarr[0].imshow(inp_img, cmap='gray')\n", 1256 | "axarr[1].imshow(p_img)\n", 1257 | "axarr[2].imshow(high_res_img)\n", 1258 | "\n", 1259 | "# print(\"Ground Truth\")\n", 1260 | "# plt.imshow(high_res_img[:,:,::-1])\n", 1261 | "# plt.show()\n", 1262 | "\n", 1263 | "# print(\"Input Image\")\n", 1264 | "# plt.imshow(inp_img[0, :,:,::-1])\n", 1265 | "# plt.show()\n", 1266 | "\n", 1267 | "# print(\"Predicted Image\")\n", 1268 | "# plt.imshow(p_img[0, :,:,::-1])\n", 1269 | "# plt.show()" 1270 | ], 1271 | "execution_count": 46, 1272 | "outputs": [ 1273 | { 1274 | "output_type": "execute_result", 1275 | "data": { 1276 | "text/plain": [ 1277 | "" 1278 | ] 1279 | }, 1280 | "metadata": {}, 1281 | "execution_count": 46 1282 | }, 1283 | { 1284 | "output_type": "display_data", 1285 | "data": { 1286 | "image/png": "\n", 1287 | "text/plain": [ 1288 | "
" 1289 | ] 1290 | }, 1291 | "metadata": { 1292 | "needs_background": "light" 1293 | } 1294 | } 1295 | ] 1296 | } 1297 | ] 1298 | } --------------------------------------------------------------------------------