├── .gitignore
├── NNTI_Project_Vision_Kairanda_Mohanta.pdf
├── README.md
├── Vision_task_1.ipynb
├── dataloader.py
├── dataset.py
├── deeplabv3.py
├── eval.py
├── expt_logs
│   ├── deeplab_cityscapes
│   │   ├── slurm-4164877.out
│   │   └── slurm-4176446.out
│   ├── fcn_cityscapes
│   │   ├── loss_200.png
│   │   ├── metric_train_200.png
│   │   ├── metric_val_200.png
│   │   ├── roc_train_200.png
│   │   ├── roc_val_200.png
│   │   ├── seg_train_200.png
│   │   ├── seg_val_200.png
│   │   └── slurm-4164875.out
│   ├── fcn_pascal
│   │   ├── log.out
│   │   ├── loss_290.png
│   │   ├── metric_train_290.png
│   │   ├── metric_val_290.png
│   │   ├── seg_train_290.png
│   │   └── seg_val_290.png
│   ├── r2unet2_cityscapes
│   │   └── slurm-4164873.out
│   ├── recnet_cityscapes
│   │   └── slurm-4164907.out
│   ├── resunet_cityscapes
│   │   └── slurm-4165984.out
│   ├── slurm-4164885.out
│   ├── slurm-4176447.out
│   ├── unet_cityscapes
│   │   └── slurm-4164874.out
│   └── unet_pascal
│       └── slurm-4176449.out
├── fcn.py
├── main.py
├── metrics.py
├── r2unet.py
├── resnet.py
├── scripts
│   ├── slurm_run.sh
│   └── slurm_setup.sh
└── vis.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | slurm-output/*
2 | logs/*
3 | cityscapes/*
4 | __pycache__/*

--------------------------------------------------------------------------------
/NNTI_Project_Vision_Kairanda_Mohanta.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/navamikairanda/R2U-Net/e8c7fe61554e703662a18f86b76a0629db81ffea/NNTI_Project_Vision_Kairanda_Mohanta.pdf

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # R2U-net
2 | PyTorch implementation of "Fully Convolutional Network", "Recurrent Residual Convolutional Neural Network based on U-Net (R2U-Net)" and "DeepLabV3" on the PascalVOC and Cityscapes datasets.
3 | 
4 | ## Contributors
5 | 
6 | Navami Kairanda  
7 | Priyanka Mohanta
8 | 
9 | ## Requirements
10 | 
11 | The following packages are used:
12 | 
13 | * python 3.8
14 | * pytorch 1.7
15 | * torchvision 0.8.1
16 | * pytorch-lightning 1.2.3
17 | 
18 | ## Prerequisites
19 | 
20 | 
21 | For tasks 2 and 3,
22 | ### Dataset preparation
23 | Download and unzip gtFine_trainvaltest.zip (241MB) and leftImg8bit_trainvaltest.zip (11GB) from the Cityscapes site:
24 | https://www.cityscapes-dataset.com/downloads/
25 | 
26 | Generate trainId labels for the dataset using the scripts provided by the Cityscapes authors, https://github.com/mcordts/cityscapesScripts
27 | ```
28 | git clone https://github.com/mcordts/cityscapesScripts.git
29 | pip install cityscapesScripts
30 | CITYSCAPES_DATASET_PATH=/HPS/Navami/work/code/nnti/R2U-Net/cityscapes/
31 | export CITYSCAPES_DATASET=$CITYSCAPES_DATASET_PATH
32 | python /HPS/Navami/work/code/nnti/cityscapesScripts/cityscapesscripts/preparation/createTrainIdLabelImgs.py
33 | ```
34 | Download the pretrained ResNet model from https://download.pytorch.org/models/resnet50-19c8e357.pth and update the corresponding path in resnet.py
35 | 
36 | ## Train and Test
37 | 
38 | For task 1, run the Vision_task_1.ipynb Jupyter notebook
39 | 
40 | For tasks 2 and 3,
41 | ```
42 | python main.py /path/to/expt/logdir
43 | ```
44 | 
45 | ## Test
46 | 
47 | For tasks 2 and 3, download the model from Microsoft Teams
48 | 
49 | ```
50 | python eval.py /path/to/expt/logdir {model_name}.tar
51 | ```
52 | 
53 | 
54 | 
55 | ## References
56 | 
57 | 
58 | Task 1:
59 | Jonathan Long, Evan Shelhamer, and Trevor Darrell. Fully Convolutional Networks for
60 | Semantic Segmentation. arXiv e-prints, page arXiv:1411.4038, November 2014.
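The preparation steps above can be sanity-checked before training. A minimal sketch (not part of the repository; `CITYSCAPES_ROOT` is an assumed path matching the `CITYSCAPES_DATASET` export above):

```python
# Optional sanity check: confirm trainId label images exist after running
# createTrainIdLabelImgs.py. Cityscapes gtFine has 2975 train / 500 val frames.
import glob
import os

CITYSCAPES_ROOT = "cityscapes"  # assumption: the dataset root used above

for split in ("train", "val"):
    pattern = os.path.join(CITYSCAPES_ROOT, "gtFine", split, "*",
                           "*_gtFine_labelTrainIds.png")
    n = len(glob.glob(pattern))
    print("{}: {} trainId label images".format(split, n))  # expect 2975 / 500
```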
61 | 
62 | Task 2:
63 | Md Zahangir Alom, Mahmudul Hasan, Chris Yakopcic, Tarek M Taha, and Vijayan K Asari.
64 | Recurrent residual convolutional neural network based on U-Net (R2U-Net) for medical image
65 | segmentation. arXiv preprint arXiv:1802.06955, 2018.
66 | 
67 | Task 3:
68 | Liang-Chieh Chen, George Papandreou, Florian Schroff, and Hartwig Adam. Rethinking atrous
69 | convolution for semantic image segmentation. arXiv preprint arXiv:1706.05587, 2017.
70 | 

--------------------------------------------------------------------------------
/Vision_task_1.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "accelerator": "GPU",
6 | "colab": {
7 | "name": "Vision_task_1.ipynb",
8 | "provenance": [],
9 | "collapsed_sections": []
10 | },
11 | "kernelspec": {
12 | "display_name": "Python 3",
13 | "language": "python",
14 | "name": "python3"
15 | },
16 | "language_info": {
17 | "codemirror_mode": {
18 | "name": "ipython",
19 | "version": 3
20 | },
21 | "file_extension": ".py",
22 | "mimetype": "text/x-python",
23 | "name": "python",
24 | "nbconvert_exporter": "python",
25 | "pygments_lexer": "ipython3",
26 | "version": "3.6.9"
27 | }
28 | },
29 | "cells": [
30 | {
31 | "cell_type": "markdown",
32 | "metadata": {
33 | "id": "8xnrc3Of4LC9"
34 | },
35 | "source": [
36 | "# Image Segmentation Task 1\n",
37 | "#### Welcome to the first task of Image Segmentation. Image segmentation is the process of partitioning an image into sets of pixels that represent objects. In this task, you will be introduced to the problem of image segmentation and to the programming pipeline involved."
38 | ]
39 | },
40 | {
41 | "cell_type": "markdown",
42 | "metadata": {
43 | "id": "7eY-YQmU4LDB"
44 | },
45 | "source": [
46 | "For the purpose of this task we will be using the PASCAL VOC dataset. The dataset contains a total of 2913 images with segmentation annotations. The code in the cell below will download and extract the dataset."
47 | ]
48 | },
49 | {
50 | "cell_type": "code",
51 | "metadata": {
52 | "id": "zM_t4c-S3k31"
53 | },
54 | "source": [
55 | "!wget http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar\n",
56 | "!tar -xvf VOCtrainval_11-May-2012.tar"
57 | ],
58 | "execution_count": null,
59 | "outputs": []
60 | },
61 | {
62 | "cell_type": "code",
63 | "metadata": {
64 | "id": "6lvs9XIpBaI0"
65 | },
66 | "source": [
67 | "!pip install scipy==1.1.0"
68 | ],
69 | "execution_count": null,
70 | "outputs": []
71 | },
72 | {
73 | "cell_type": "markdown",
74 | "metadata": {
75 | "id": "s-A7VhYD4LDD"
76 | },
77 | "source": [
78 | "### 1.1 Loading the dataset"
79 | ]
80 | },
81 | {
82 | "cell_type": "code",
83 | "metadata": {
84 | "id": "qunDv45j24Mg"
85 | },
86 | "source": [
87 | "import os\n",
88 | "from os.path import join as pjoin\n",
89 | "import collections\n",
90 | "import json\n",
91 | "import torch\n",
92 | "import imageio\n",
93 | "import numpy as np\n",
94 | "import scipy.misc as m\n",
95 | "import scipy.io as io\n",
96 | "import matplotlib.pyplot as plt\n",
97 | "import glob\n",
98 | "\n",
99 | "from PIL import Image\n",
100 | "from tqdm import tqdm\n",
101 | "from torch.utils import data\n",
102 | "from torchvision import transforms\n",
103 | "\n",
104 | "import pdb\n",
105 | "import time\n",
106 | "import torch.nn as nn\n",
107 | "import torchvision.models.vgg as vgg\n",
108 | "import torch.optim as optim\n",
109 | "\n",
110 | "import sys\n",
111 | "\n",
112 | "class pascalVOCDataset(data.Dataset):\n",
113 | "    \"\"\"Data loader for the Pascal VOC semantic segmentation dataset.\n",
114 | "\n",
115 | "    Annotations from both the original VOC data (which consist of RGB images\n",
116 | "    in which colours map to specific classes) and the SBD (Berkeley) dataset\n",
117 | "    (where annotations are stored as .mat files) are converted into a common\n",
118 | "    `label_mask` format. Under this format, each mask is an (M,N) array of\n",
119 | "    integer values from 0 to 21, where 0 represents the background class.\n",
120 | "\n",
121 | "    The label masks are stored in a new folder, called `pre_encoded`, which\n",
122 | "    is added as a subdirectory of the `SegmentationClass` folder in the\n",
123 | "    original Pascal VOC data layout.\n",
124 | "\n",
125 | "    A total of five data splits are provided for working with the VOC data:\n",
126 | "        train: The original VOC 2012 training data - 1464 images\n",
127 | "        val: The original VOC 2012 validation data - 1449 images\n",
128 | "        trainval: The combination of `train` and `val` - 2913 images\n",
129 | "        train_aug: The unique images present in both the train split and\n",
130 | "                   training images from SBD - 8829 images (the unique members\n",
131 | "                   of the result of combining lists of length 1464 and 8498)\n",
132 | "        train_aug_val: The original VOC 2012 validation data minus the images\n",
133 | "                       present in `train_aug` (This is done with the same logic as\n",
134 | "                       the validation set used in the FCN PAMI paper, but with VOC 2012\n",
135 | "                       rather than VOC 2011) - 904 images\n",
136 | "    \"\"\"\n",
137 | "\n",
138 | "    def __init__(\n",
139 | "        self,\n",
140 | "        root,\n",
141 | "        sbd_path=None,\n",
142 | "        split=\"train_aug\",\n",
143 | "        is_transform=False,\n",
144 | "        img_size=512,\n",
145 | "        augmentations=None,\n",
146 | "        img_norm=True,\n",
147 | "        test_mode=False,\n",
148 | "    ):\n",
149 | "        self.root = root\n",
150 | "        self.sbd_path = sbd_path\n",
151 | "        self.split = split\n",
152 | "        self.is_transform = is_transform\n",
153 | "        self.augmentations = augmentations\n",
154 | "        self.img_norm = img_norm\n",
155 | "        self.test_mode = test_mode\n",
156 | "        self.n_classes = 21\n",
157 | "        #self.mean = np.array([104.00699, 116.66877, 122.67892])\n",
158 | "        self.mean = torch.tensor([0.485, 0.456, 0.406])\n",
159 | "        self.std = torch.tensor([0.229, 0.224, 0.225])\n",
160 | "        self.files = collections.defaultdict(list)\n",
161 | "        self.img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size)\n",
162 | "\n",
163 | "        if not self.test_mode:\n",
164 | "            for split in [\"train\", \"val\", \"trainval\"]:\n",
165 | "                path = pjoin(self.root, \"ImageSets/Segmentation\", split + \".txt\")\n",
166 | "                file_list = tuple(open(path, \"r\"))\n",
167 | "                file_list = [id_.rstrip() for id_ in file_list]\n",
168 | "                self.files[split] = file_list\n",
169 | "            self.setup_annotations()\n",
170 | "\n",
171 | "        self.tf = transforms.Compose(\n",
172 | "            [\n",
173 | "                # add more transformations as you see fit\n",
174 | "                #transforms.Resize(256),\n",
175 | "                #transforms.CenterCrop(224),\n",
176 | "                transforms.ToTensor(),\n",
177 | "                transforms.Normalize(self.mean, self.std),\n",
178 | "            ]\n",
179 | "        )\n",
180 | "\n",
181 | "    def __len__(self):\n",
182 | "        return len(self.files[self.split])\n",
183 | "\n",
184 | "    def __getitem__(self, index):\n",
185 | "        im_name = self.files[self.split][index]\n",
186 | "        im_path = pjoin(self.root, \"JPEGImages\", im_name + \".jpg\")\n",
187 | "        lbl_path = pjoin(self.root, \"SegmentationClass/pre_encoded\", im_name + \".png\")\n",
188 | "        im = Image.open(im_path)\n",
189 | "        lbl = Image.open(lbl_path)\n",
190 | "        if self.augmentations is not None:\n",
191 | "            im, lbl = self.augmentations(im, lbl)\n",
192 | "        if self.is_transform:\n",
193 | "            im, lbl = self.transform(im, lbl)\n",
194 | "        return im, torch.clamp(lbl, max=20)\n",
195 | "\n",
196 | "    def transform(self, img, lbl):\n",
197 | "        if self.img_size == (\"same\", \"same\"):\n",
(\"same\", \"same\"):\n", 198 | " pass\n", 199 | " else:\n", 200 | " img = img.resize((self.img_size[0], self.img_size[1])) # uint8 with RGB mode\n", 201 | " lbl = lbl.resize((self.img_size[0], self.img_size[1]), resample=Image.NEAREST) \n", 202 | " img = self.tf(img)\n", 203 | " lbl = torch.from_numpy(np.array(lbl)).long()\n", 204 | " lbl[lbl == 255] = 0\n", 205 | " return img, lbl\n", 206 | "\n", 207 | " def get_pascal_labels(self):\n", 208 | " \"\"\"Load the mapping that associates pascal classes with label colors\n", 209 | "\n", 210 | " Returns:\n", 211 | " np.ndarray with dimensions (21, 3)\n", 212 | " \"\"\"\n", 213 | " return np.asarray(\n", 214 | " [\n", 215 | " [0, 0, 0],\n", 216 | " [128, 0, 0],\n", 217 | " [0, 128, 0],\n", 218 | " [128, 128, 0],\n", 219 | " [0, 0, 128],\n", 220 | " [128, 0, 128],\n", 221 | " [0, 128, 128],\n", 222 | " [128, 128, 128],\n", 223 | " [64, 0, 0],\n", 224 | " [192, 0, 0],\n", 225 | " [64, 128, 0],\n", 226 | " [192, 128, 0],\n", 227 | " [64, 0, 128],\n", 228 | " [192, 0, 128],\n", 229 | " [64, 128, 128],\n", 230 | " [192, 128, 128],\n", 231 | " [0, 64, 0],\n", 232 | " [128, 64, 0],\n", 233 | " [0, 192, 0],\n", 234 | " [128, 192, 0],\n", 235 | " [0, 64, 128],\n", 236 | " ]\n", 237 | " )\n", 238 | "\n", 239 | " def encode_segmap(self, mask):\n", 240 | " \"\"\"Encode segmentation label images as pascal classes\n", 241 | "\n", 242 | " Args:\n", 243 | " mask (np.ndarray): raw segmentation label image of dimension\n", 244 | " (M, N, 3), in which the Pascal classes are encoded as colours.\n", 245 | "\n", 246 | " Returns:\n", 247 | " (np.ndarray): class map with dimensions (M,N), where the value at\n", 248 | " a given location is the integer denoting the class index.\n", 249 | " \"\"\"\n", 250 | " mask = mask.astype(int)\n", 251 | " label_mask = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.int16)\n", 252 | " for ii, label in enumerate(self.get_pascal_labels()):\n", 253 | " label_mask[np.where(np.all(mask == label, axis=-1))[:2]] = ii\n", 254 | " label_mask = label_mask.astype(int)\n", 255 | " # print(np.unique(label_mask))\n", 256 | " return label_mask\n", 257 | "\n", 258 | " def decode_segmap(self, label_mask, plot=False):\n", 259 | " \"\"\"Decode segmentation class labels into a color image\n", 260 | "\n", 261 | " Args:\n", 262 | " label_mask (np.ndarray): an (M,N) array of integer values denoting\n", 263 | " the class label at each spatial location.\n", 264 | " plot (bool, optional): whether to show the resulting color image\n", 265 | " in a figure.\n", 266 | "\n", 267 | " Returns:\n", 268 | " (np.ndarray, optional): the resulting decoded color image.\n", 269 | " \"\"\"\n", 270 | " label_colours = self.get_pascal_labels()\n", 271 | " r = label_mask.copy()\n", 272 | " g = label_mask.copy()\n", 273 | " b = label_mask.copy()\n", 274 | " for ll in range(0, self.n_classes):\n", 275 | " r[label_mask == ll] = label_colours[ll, 0]\n", 276 | " g[label_mask == ll] = label_colours[ll, 1]\n", 277 | " b[label_mask == ll] = label_colours[ll, 2]\n", 278 | " rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3))\n", 279 | " rgb[:, :, 0] = r / 255.0\n", 280 | " rgb[:, :, 1] = g / 255.0\n", 281 | " rgb[:, :, 2] = b / 255.0\n", 282 | " if plot:\n", 283 | " plt.imshow(rgb)\n", 284 | " plt.show()\n", 285 | " else:\n", 286 | " return rgb\n", 287 | "\n", 288 | " def setup_annotations(self):\n", 289 | " \"\"\"Sets up Berkley annotations by adding image indices to the\n", 290 | " `train_aug` split and pre-encode all segmentation labels into the\n", 291 | " common 
292 | "        function also defines the `train_aug` and `train_aug_val` data splits\n",
293 | "        according to the description in the class docstring\n",
294 | "        \"\"\"\n",
295 | "        sbd_path = self.sbd_path\n",
296 | "        target_path = pjoin(self.root, \"SegmentationClass/pre_encoded\")\n",
297 | "        if not os.path.exists(target_path):\n",
298 | "            os.makedirs(target_path)\n",
299 | "        train_aug = self.files[\"train\"]\n",
300 | "\n",
301 | "        # keep unique elements (stable)\n",
302 | "        train_aug = [train_aug[i] for i in sorted(np.unique(train_aug, return_index=True)[1])]\n",
303 | "        self.files[\"train_aug\"] = train_aug\n",
304 | "        set_diff = set(self.files[\"val\"]) - set(train_aug)  # remove overlap\n",
305 | "        self.files[\"train_aug_val\"] = list(set_diff)\n",
306 | "\n",
307 | "        pre_encoded = glob.glob(pjoin(target_path, \"*.png\"))\n",
308 | "        expected = np.unique(self.files[\"train_aug\"] + self.files[\"val\"]).size\n",
309 | "\n",
310 | "        if len(pre_encoded) != expected:\n",
311 | "            print(\"Pre-encoding segmentation masks...\")\n",
312 | "\n",
313 | "            for ii in tqdm(self.files[\"trainval\"]):\n",
314 | "                fname = ii + \".png\"\n",
315 | "                lbl_path = pjoin(self.root, \"SegmentationClass\", fname)\n",
316 | "                lbl = self.encode_segmap(m.imread(lbl_path))\n",
317 | "                lbl = m.toimage(lbl, high=lbl.max(), low=lbl.min())\n",
318 | "                m.imsave(pjoin(target_path, fname), lbl)\n",
319 | "\n",
320 | "        assert expected == 2913, \"unexpected dataset sizes\""
321 | ],
322 | "execution_count": null,
323 | "outputs": []
324 | },
325 | {
326 | "cell_type": "markdown",
327 | "metadata": {
328 | "id": "ZwcyE41Q4LDI"
329 | },
330 | "source": [
331 | "### 1.2 Define the model architecture (2.0 points)\n",
332 | "In this section you have the freedom to design your own model. Keep in mind, though, that to perform image segmentation you need to implement an architecture that does pixel-level classification, i.e. for each pixel in the image you need to predict the probability of it belonging to one of the 21 categories."
333 | ]
334 | },
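To make this contract concrete before the model definition: a segmentation network must map `(N, 3, H, W)` images to `(N, 21, H, W)` per-pixel class scores. A minimal sketch with a hypothetical 1x1-convolution stand-in (illustrative only, not the Segnet defined next):

```python
import torch
import torch.nn as nn

# Hypothetical toy "model": a single 1x1 convolution, used only to show the
# required input/output shapes for pixel-level classification.
toy = nn.Conv2d(3, 21, kernel_size=1)
x = torch.randn(4, 3, 224, 224)        # (N, 3, H, W) input images
logits = toy(x)                        # (N, 21, H, W) per-pixel class scores
probs = logits.softmax(dim=1)          # per-pixel probabilities over 21 classes
print(probs.shape)                     # torch.Size([4, 21, 224, 224])
```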
335 | {
336 | "cell_type": "code",
337 | "metadata": {
338 | "id": "CatAsvH3GTXs"
339 | },
340 | "source": [
341 | "import torch.nn as nn\n",
342 | "\n",
343 | "class Segnet(nn.Module):\n",
344 | "\n",
345 | "    def __init__(self, n_classes):\n",
346 | "        super(Segnet, self).__init__()\n",
347 | "        # define the layers for your model\n",
348 | "        self.vgg_model = vgg.vgg16(pretrained=True, progress=True).to(device)\n",
349 | "        #del self.vgg_model.classifier\n",
350 | "        self.relu = nn.ReLU(inplace=True)\n",
351 | "        self.deconv1 = nn.ConvTranspose2d(512, 512, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)\n",
352 | "        self.bn1 = nn.BatchNorm2d(512)  #TODO BN not mentioned in paper\n",
353 | "        self.deconv2 = nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)\n",
354 | "        self.bn2 = nn.BatchNorm2d(256)\n",
355 | "        self.deconv3 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)\n",
356 | "        self.bn3 = nn.BatchNorm2d(128)\n",
357 | "        self.deconv4 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)\n",
358 | "        self.bn4 = nn.BatchNorm2d(64)\n",
359 | "        self.deconv5 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)\n",
360 | "        self.bn5 = nn.BatchNorm2d(32)\n",
361 | "        self.classifier = nn.Conv2d(32, n_classes, kernel_size=1)\n",
362 | "\n",
363 | "    def forward(self, x):\n",
364 | "        # define the forward pass\n",
365 | "        x = self.vgg_model.features(x)  # (B, 512, H/32, W/32)\n",
366 | "        output = self.vgg_model.avgpool(x)  # (B, 512, 7, 7); unused\n",
367 | "        output_zero = torch.zeros([4, 21, 512, 512], requires_grad=True)  # always background; unused\n",
368 | "        score = self.bn1(self.relu(self.deconv1(x)))  # size=(N, 512, x.H/16, x.W/16)\n",
369 | "        score = self.bn2(self.relu(self.deconv2(score)))  # size=(N, 256, x.H/8, x.W/8)\n",
370 | "        score = self.bn3(self.relu(self.deconv3(score)))  # size=(N, 128, x.H/4, x.W/4)\n",
371 | "        score = self.bn4(self.relu(self.deconv4(score)))  # size=(N, 64, x.H/2, x.W/2)\n",
372 | "        score = self.bn5(self.relu(self.deconv5(score)))  # size=(N, 32, x.H, x.W)\n",
373 | "        score = self.classifier(score)  # size=(N, n_classes, x.H/1, x.W/1)\n",
374 | "        return score  # size=(N, n_class, x.H/1, x.W/1)\n"
375 | ],
376 | "execution_count": null,
377 | "outputs": []
378 | },
379 | {
380 | "cell_type": "code",
381 | "metadata": {
382 | "id": "QfQiOnEkGZat"
383 | },
384 | "source": [
385 | "# Creating an instance of the model defined above.\n",
386 | "# You can modify it in case you need to pass parameters to the constructor.\n",
387 | "device = torch.device(\"cuda\")  #if torch.cuda.is_available() else \"cpu\")\n",
388 | "num_gpu = list(range(torch.cuda.device_count()))\n",
389 | "\n",
390 | "n_classes = 21\n",
391 | "\n",
392 | "model = nn.DataParallel(Segnet(n_classes), device_ids=num_gpu).to(device)\n",
393 | "\n",
394 | "\n"
395 | ],
396 | "execution_count": null,
397 | "outputs": []
398 | },
399 | {
400 | "cell_type": "markdown",
401 | "metadata": {
402 | "id": "05k1AY_f4LDL"
403 | },
404 | "source": [
405 | "### 1.3 Hyperparameters (0.5 points)\n",
406 | "Define all the hyperparameters (not restricted to the three given below) that you find useful here."
407 | ]
408 | },
409 | {
410 | "cell_type": "code",
411 | "metadata": {
412 | "id": "ykVrbCUw4LDL"
413 | },
414 | "source": [
415 | "\n",
416 | "# Hyper-parameters\n",
417 | "\n",
418 | "# Setup experiment log folder\n",
419 | "expt_logdir = sys.argv[1]  # when running interactively, replace with a fixed path such as 'expt_logs'\n",
420 | "os.makedirs(expt_logdir, exist_ok=True)\n",
421 | "local_path = 'VOCdevkit/VOC2012/'\n",
422 | "bs = 32\n",
423 | "num_workers = 8\n",
424 | "n_classes = 21\n",
425 | "img_size = 224  #'same'\n",
426 | "\n",
427 | "# Training parameters\n",
428 | "epochs = 300\n",
429 | "lr = 0.001\n",
430 | "\n",
431 | "# Logging options\n",
432 | "i_save = 50  #save model after every i_save epochs\n",
433 | "i_vis = 10\n",
434 | "rows, cols = 5, 2  #Show 10 images in the dataset along with target and predicted masks\n",
435 | "\n"
436 | ],
437 | "execution_count": null,
438 | "outputs": []
439 | },
440 | {
441 | "cell_type": "markdown",
442 | "metadata": {
443 | "id": "7tqO-3LH4LDL"
444 | },
445 | "source": [
446 | "### 1.4 Dataset and Dataloader (0.5 points)\n",
447 | "Create the dataset using the pascalVOCDataset class defined above. Use local_path, defined in the cell above, as root."
448 | ]
449 | },
450 | {
451 | "cell_type": "code",
452 | "metadata": {
453 | "id": "kKZyzK7x4LDM"
454 | },
455 | "source": [
456 | "# dataset variable\n",
457 | "test_split = 'val'\n",
458 | "train_dst = pascalVOCDataset(local_path, split=\"train\", is_transform=True, img_size=img_size)\n",
459 | "test_dst = pascalVOCDataset(local_path, split=test_split, is_transform=True, img_size=img_size)\n",
460 | "\n",
461 | "# dataloader variable\n",
462 | "trainloader = torch.utils.data.DataLoader(train_dst, batch_size=bs, num_workers=num_workers, pin_memory=True, shuffle=True)\n",
463 | "testloader = torch.utils.data.DataLoader(test_dst, batch_size=bs, num_workers=num_workers, pin_memory=True, shuffle=True)\n"
464 | ],
465 | "execution_count": null,
466 | "outputs": []
467 | },
468 | {
469 | "cell_type": "markdown",
470 | "metadata": {
471 | "id": "uH5MFgv64LDM"
472 | },
473 | "source": [
474 | "### 1.5 Loss function and Optimizer (1.0 point)\n",
475 | "Define below the loss function you think would be most suitable for the segmentation task. You are free to choose any optimizer to train the network."
476 | ]
477 | },
478 | {
479 | "cell_type": "code",
480 | "metadata": {
481 | "id": "FP6cXGZb4LDM"
482 | },
483 | "source": [
484 | "# Loss function and Optimizer\n",
485 | "# loss function\n",
486 | "loss_f = nn.CrossEntropyLoss()\n",
487 | "\n",
488 | "# optimizer variable\n",
489 | "opt = optim.Adam(model.parameters(), lr=lr)\n",
490 | "\n"
491 | ],
492 | "execution_count": null,
493 | "outputs": []
494 | },
495 | {
496 | "cell_type": "markdown",
497 | "metadata": {
498 | "id": "rpIuY3_s4LDM"
499 | },
500 | "source": [
501 | "### 1.6 Training the model (3.0 points)\n",
502 | "Your task here is to complete the code below to perform a training loop and save the model weights after each epoch of training."
503 | ]
504 | },
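For reference before the loop: the `nn.CrossEntropyLoss` chosen in section 1.5 consumes raw logits of shape `(N, C, H, W)` and integer targets of shape `(N, H, W)`. A self-contained check with toy sizes (illustrative only):

```python
import torch
import torch.nn as nn

# CrossEntropyLoss applied to segmentation: logits (N, C, H, W) vs. integer
# class-id targets (N, H, W). Sizes below are toy values for illustration.
loss_f = nn.CrossEntropyLoss()
logits = torch.randn(2, 21, 8, 8, requires_grad=True)
targets = torch.randint(0, 21, (2, 8, 8))
loss = loss_f(logits, targets)
loss.backward()            # gradients flow back to the logits
print(loss.item())
```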
505 | {
506 | "cell_type": "code",
507 | "metadata": {
508 | "id": "Xz08hSdPKODm"
509 | },
510 | "source": [
511 | "train_vis = Vis(train_dst, expt_logdir, rows, cols)  # Vis (section 1.9) and Metrics (section 1.7) are defined in later cells; run those cells first\n",
512 | "test_vis = Vis(test_dst, expt_logdir, rows, cols)\n",
513 | "\n",
514 | "train_metrics = Metrics(n_classes, trainloader, 'train', device, expt_logdir)\n",
515 | "test_metrics = Metrics(n_classes, testloader, test_split, device, expt_logdir)\n",
516 | "\n",
517 | "epoch = -1\n",
518 | "train_metrics.compute(epoch, model)\n",
519 | "#train_metrics.plot_scalar_metrics(epoch)\n",
520 | "train_vis.visualize(epoch, model)\n",
521 | "\n",
522 | "test_metrics.compute(epoch, model)\n",
523 | "#test_metrics.plot_scalar_metrics(epoch)\n",
524 | "test_vis.visualize(epoch, model)\n",
525 | "\n",
526 | "losses = []\n",
527 | "for epoch in range(epochs):\n",
528 | "    st = time.time()\n",
529 | "    model.train()\n",
530 | "    for i, (inputs, labels) in enumerate(trainloader):\n",
531 | "        opt.zero_grad()\n",
532 | "        inputs = inputs.to(device)\n",
533 | "        labels = labels.to(device)\n",
534 | "        predictions = model(inputs)\n",
535 | "        loss = loss_f(predictions, labels)\n",
536 | "        loss.backward()\n",
537 | "        opt.step()\n",
538 | "        if i % 20 == 0:\n",
539 | "            print(\"Finish iter: {}, loss {}\".format(i, loss.data))\n",
540 | "    losses.append(loss.item())  # store a python float, not the tensor, to avoid holding the autograd graph\n",
541 | "    print(\"Training epoch: {}, loss: {}, time elapsed: {},\".format(epoch, loss, time.time() - st))\n",
542 | "\n",
543 | "    train_metrics.compute(epoch, model)\n",
544 | "    test_metrics.compute(epoch, model)\n",
545 | "\n",
546 | "    if epoch % i_save == 0:\n",
547 | "        torch.save(model.state_dict(), os.path.join(expt_logdir, \"{}.tar\".format(epoch)))\n",
548 | "    if epoch % i_vis == 0:\n",
549 | "        test_metrics.plot_scalar_metrics(epoch)  #section 1.8\n",
550 | "        test_vis.visualize(epoch, model)  #section 1.9\n",
551 | "\n",
552 | "        train_metrics.plot_scalar_metrics(epoch)  #section 1.8\n",
553 | "        train_vis.visualize(epoch, model)  #section 1.9\n",
554 | "\n",
555 | "        train_metrics.plot_loss(epoch, losses)  # section 1.8"
556 | ],
557 | "execution_count": null,
558 | "outputs": []
559 | },
560 | {
561 | "cell_type": "markdown",
562 | "metadata": {
563 | "id": "zRdCADPG4LDN"
564 | },
565 | "source": [
566 | "### 1.7 Evaluate your model (1.5 points)\n",
567 | "In this section you have to implement the evaluation metrics for your model. Calculate the values of F1-score, dice coefficient and AUC-ROC score on the data you used for training. You can use external packages like scikit-learn to compute the above metrics."
568 | ]
569 | },
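The dice coefficient used below can be cross-checked by hand on hard label maps. A minimal sketch, independent of the pytorch-lightning implementation that follows (`dice_per_class` is an illustrative helper, not part of the notebook):

```python
import torch

def dice_per_class(pred, target, cls):
    # Dice = 2*|P ∩ T| / (|P| + |T|), over pixels predicted/labelled as cls.
    p = (pred == cls)
    t = (target == cls)
    denom = p.sum().item() + t.sum().item()
    return 2.0 * (p & t).sum().item() / denom if denom > 0 else float("nan")

pred = torch.tensor([[0, 1], [1, 1]])
target = torch.tensor([[0, 1], [0, 1]])
print(dice_per_class(pred, target, 1))  # 2*2 / (3+2) = 0.8
```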
570 | {
571 | "cell_type": "code",
572 | "metadata": {
573 | "id": "XkdpzZFF4LDN"
574 | },
575 | "source": [
576 | "\n",
577 | "# Plot the evaluation metrics against epochs\n",
578 | "\n",
579 | "from pytorch_lightning import metrics\n",
580 | "\n",
581 | "class Dice(metrics.Metric):\n",
582 | "\n",
583 | "    def __init__(self):\n",
584 | "        super().__init__()\n",
585 | "        self.add_state(\"dice_score\", default=[])\n",
586 | "\n",
587 | "    def update(self, pred, target):\n",
588 | "        dice_score_val = metrics.functional.classification.dice_score(pred, target, bg=True)\n",
589 | "        self.dice_score.append(dice_score_val.item())\n",
590 | "\n",
591 | "    def compute(self):\n",
592 | "        self.dice_score = torch.tensor(self.dice_score)\n",
593 | "        return torch.mean(self.dice_score)\n",
594 | "\n",
595 | "\n",
596 | "class Metrics():\n",
597 | "    def __init__(self, n_classes, dataloader, split, device, expt_logdir):\n",
598 | "        self.dataloader = dataloader\n",
599 | "        self.device = device\n",
600 | "        accuracy = metrics.Accuracy().to(self.device)\n",
601 | "        iou = metrics.IoU(num_classes=n_classes).to(self.device)\n",
602 | "        dice = Dice().to(self.device)\n",
603 | "        recall = metrics.Recall(num_classes=n_classes, average='macro', mdmc_average='global').to(self.device)\n",
604 | "        roc = metrics.ROC(num_classes=n_classes, dist_sync_on_step=True).to(self.device)\n",
605 | "\n",
606 | "        self.eval_metrics = {'accuracy': {'module': accuracy, 'values': []},\n",
607 | "                             'iou': {'module': iou, 'values': []},\n",
608 | "                             'dice': {'module': dice, 'values': []},\n",
609 | "                             'sensitivity': {'module': recall, 'values': []},\n",
610 | "                             'auroc': {'module': roc, 'values': []}\n",
611 | "                             }\n",
612 | "        self.softmax = nn.Softmax(dim=1)\n",
613 | "        self.expt_logdir = expt_logdir\n",
614 | "        self.split = split\n",
615 | "\n",
616 | "    def compute_auroc(self, value):\n",
617 | "        fpr, tpr, _ = value\n",
618 | "        auc_scores = [torch.trapz(y, x) for x, y in zip(fpr, tpr)]\n",
619 | "        return torch.mean(torch.stack(auc_scores))\n",
620 | "\n",
621 | "    def compute(self, epoch, model):\n",
622 | "        model.eval()\n",
623 | "        with torch.no_grad():\n",
624 | "            for i, (inputs, labels) in enumerate(self.dataloader):\n",
625 | "                inputs = inputs.to(self.device)  # N, 3, H, W\n",
626 | "                labels = labels.to(self.device)  # N, H, W\n",
627 | "\n",
628 | "                predictions = model(inputs)  # N, C, H, W\n",
629 | "                predictions = self.softmax(predictions)\n",
630 | "\n",
631 | "                for key in self.eval_metrics:\n",
632 | "                    # Evaluate AUC/ROC on a subset of the training data, as full evaluation leads to OOM errors on GPU\n",
633 | "                    # Full evaluation on validation/test data\n",
634 | "                    if key == 'auroc' and i > 20:\n",
635 | "                        continue\n",
636 | "                    self.eval_metrics[key]['module'].update(predictions, labels)\n",
637 | "\n",
638 | "        for key in self.eval_metrics:\n",
639 | "            value = self.eval_metrics[key]['module'].compute()\n",
640 | "            if key == 'auroc':\n",
641 | "                value = self.compute_auroc(value)\n",
642 | "            self.eval_metrics[key]['values'].append(value.item())\n",
643 | "            self.eval_metrics[key]['module'].reset()\n",
644 | "\n",
645 | "        metrics_string = \" ; \".join(\"{}: {:05.3f}\".format(key, self.eval_metrics[key]['values'][-1]) for key in self.eval_metrics)\n",
646 | "        print(\"Split: {}, epoch: {}, metrics: \".format(self.split, epoch) + metrics_string)\n",
647 | "\n",
648 | "    def plot_scalar_metrics(self, epoch):\n",
649 | "        fig = plt.figure(figsize=(13, 5))\n",
650 | "        ax = fig.gca()\n",
651 | "        for key, metric in self.eval_metrics.items():\n",
652 | "            ax.plot(metric['values'], label=key)\n",
653 | "        ax.legend(fontsize=\"16\")\n",
654 | "        ax.set_xlabel(\"Epochs\", fontsize=\"16\")\n",
655 | "        ax.set_ylabel(\"Metric\", fontsize=\"16\")\n",
656 | "        ax.set_title(\"Evaluation metric vs epochs\", fontsize=\"16\")\n",
657 | "        plt.savefig(os.path.join(self.expt_logdir, 'metric_{}_{}.png'.format(self.split, epoch)))\n",
658 | "        plt.clf()\n",
659 | "\n",
660 | "    def plot_loss(self, epoch, losses):\n",
661 | "        fig = plt.figure(figsize=(13, 5))\n",
662 | "        ax = fig.gca()\n",
663 | "        ax.plot(losses)\n",
664 | "        ax.set_xlabel(\"Epochs\", fontsize=\"16\")\n",
665 | "        ax.set_ylabel(\"Loss\", fontsize=\"16\")\n",
666 | "        ax.set_title(\"Training loss vs. epochs\", fontsize=\"16\")\n",
667 | "        plt.savefig(os.path.join(self.expt_logdir, 'loss_{}.png'.format(epoch)))\n",
668 | "        plt.clf()\n",
669 | "\n"
670 | ],
671 | "execution_count": null,
672 | "outputs": []
673 | },
674 | {
675 | "cell_type": "markdown",
676 | "metadata": {
677 | "id": "s4D-bF384LDN"
678 | },
679 | "source": [
680 | "### 1.8 Plot the evaluation metrics against epochs (1.0 point)\n",
681 | "In section 1.6 we saved the weights of the model after each epoch. In this section, you have to calculate the evaluation metrics after each epoch of training by loading the saved weights for each epoch. Once you have calculated the evaluation metrics for each epoch, plot them against the epochs."
682 | ]
683 | },
684 | {
685 | "cell_type": "code",
686 | "metadata": {
687 | "id": "4zSvY9H-4LDN"
688 | },
689 | "source": [],
690 | "execution_count": null,
691 | "outputs": []
692 | },
693 | {
694 | "cell_type": "markdown",
695 | "metadata": {
696 | "id": "o5JBjJij4LDN"
697 | },
698 | "source": [
699 | "### 1.9 Visualize results (0.5 points)\n",
700 | "For any 10 images in the dataset, show the images along with their segmentation masks."
701 | ]
702 | },
703 | {
704 | "cell_type": "code",
705 | "metadata": {
706 | "id": "3NS50IL_c7Mf"
707 | },
708 | "source": [
709 | "def image_grid(images, rows=None, cols=None, fill=True, show_axes=False):\n",
710 | "    \"\"\"\n",
711 | "    A util function for plotting a grid of images.\n",
712 | "\n",
713 | "    Args:\n",
714 | "        images: (N, H, W, 4) array of RGBA images\n",
715 | "        rows: number of rows in the grid\n",
716 | "        cols: number of columns in the grid\n",
717 | "        fill: boolean indicating if the space between images should be filled\n",
718 | "        show_axes: boolean indicating if the axes of the plots should be visible\n",
719 | "        (note: only the RGB channels are rendered; an alpha channel,\n",
720 | "        if present, is ignored)\n",
721 | "\n",
722 | "    Returns:\n",
723 | "        None\n",
724 | "    \"\"\"\n",
725 | "    if (rows is None) != (cols is None):\n",
726 | "        raise ValueError(\"Specify either both rows and cols or neither.\")\n",
727 | "\n",
728 | "    if rows is None:\n",
729 | "        rows = len(images)\n",
730 | "        cols = 1\n",
731 | "\n",
732 | "    gridspec_kw = {\"wspace\": 0.0, \"hspace\": 0.0} if fill else {}\n",
733 | "    fig, axarr = plt.subplots(rows, cols, gridspec_kw=gridspec_kw, figsize=(15, 9))\n",
734 | "\n",
735 | "    for ax, im in zip(axarr.ravel(), images):\n",
736 | "        # only render RGB channels\n",
737 | "        ax.imshow(im[..., :3])\n",
738 | "        if not show_axes:\n",
739 | "            ax.set_axis_off()\n",
740 | "\n",
741 | "class Vis():\n",
742 | "\n",
743 | "    def __init__(self, dst, expt_logdir, rows, cols):\n",
744 | "\n",
745 | "        self.dst = dst\n",
746 | "        self.expt_logdir = expt_logdir\n",
747 | "        self.rows = rows\n",
748 | "        self.cols = cols\n",
749 | "        self.images = []\n",
750 | "        self.images_vis = []\n",
751 | "        self.labels_vis = []\n",
752 | "        image_ids = np.random.randint(len(dst), size=rows*cols)\n",
753 | "\n",
754 | "        for image_id in image_ids:\n",
755 | "            image, label = dst[image_id][0], dst[image_id][1]\n",
756 | "            image = image[None, ...]\n",
757 | "            self.images.append(image)\n",
758 | "\n",
759 | "            image = torch.squeeze(image)\n",
760 | "            image = image * self.dst.std[:, None, None] + self.dst.mean[:, None, None]\n",
761 | "            image = torch.movedim(image, 0, -1)  # (3,H,W) to (H,W,3)\n",
762 | "            image = image.cpu().numpy()\n",
763 | "            self.images_vis.append(image)\n",
764 | "\n",
765 | "            label = label.cpu().numpy()\n",
766 | "            label = dst.decode_segmap(label)\n",
767 | "            self.labels_vis.append(label)\n",
768 | "\n",
769 | "        self.images = torch.cat(self.images, axis=0)\n",
770 | "\n",
771 | "    def visualize(self, epoch, model):\n",
772 | "\n",
773 | "        prediction = model(self.images)\n",
774 | "        prediction = torch.argmax(prediction, dim=1)\n",
775 | "        prediction = prediction.cpu().numpy()\n",
776 | "\n",
777 | "        rgb_vis = []\n",
778 | "        for image, label, pred in zip(self.images_vis, self.labels_vis, prediction):\n",
779 | "            pred = self.dst.decode_segmap(pred)\n",
780 | "            rgb_vis.extend([image, label, pred])\n",
781 | "        rgb_vis = np.array(rgb_vis)\n",
782 | "\n",
783 | "        image_grid(rgb_vis, rows=self.rows, cols=3*self.cols)\n",
784 | "        plt.savefig(os.path.join(self.expt_logdir, 'seg_{}_{}.png'.format(self.dst.split, epoch)))\n"
785 | ],
786 | "execution_count": null,
787 | "outputs": []
788 | }
789 | ]
790 | }
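The empty code cell in section 1.8 is left as the student's task. One possible (untested) completion, reusing `model`, `epochs`, `i_save`, `expt_logdir` and the `Metrics` objects defined in the notebook above:

```python
# Hypothetical completion of section 1.8: reload each checkpoint saved by the
# loop in section 1.6 and recompute the metrics with the classes from 1.7.
for epoch in range(0, epochs, i_save):
    ckpt = os.path.join(expt_logdir, "{}.tar".format(epoch))
    if not os.path.exists(ckpt):
        continue
    model.load_state_dict(torch.load(ckpt))
    train_metrics.compute(epoch, model)
    test_metrics.compute(epoch, model)
train_metrics.plot_scalar_metrics(epochs - 1)
test_metrics.plot_scalar_metrics(epochs - 1)
```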
--------------------------------------------------------------------------------
/dataloader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | #from torchvision import datasets
3 | 
4 | import torchvision.transforms as transforms
5 | import pdb
6 | import numpy as np
7 | 
8 | from dataset import Cityscapes
9 | from dataset import ignoreClassId
10 | 
11 | img_size = 256
12 | 
13 | def targetToTensor(target):
14 |     """
15 |     A util function for transforming a target segmentation mask
16 |     Args:
17 |         target: PIL image containing the segmentation mask
18 |     Returns:
19 |         torch tensor of dimensions (H, W)
20 |     """
21 |     target = np.array(target)
22 |     target = np.where(target == 255, ignoreClassId, target)
23 |     target = torch.as_tensor(target, dtype=torch.int64)
24 |     return target
25 | 
26 | image_transform = transforms.Compose([
27 |     # you can add other transformations in this list
28 |     transforms.Resize([img_size,]),
29 |     transforms.ToTensor(),
30 |     transforms.Normalize(mean=[0.485, 0.456, 0.406],
31 |                          std=[0.229, 0.224, 0.225])
32 | ])
33 | 
34 | target_transform = transforms.Compose([
35 |     transforms.Resize([img_size,], interpolation=0),  # 0: InterpolationMode.NEAREST
36 |     targetToTensor
37 | ])
38 | 
39 | # Load Cityscapes train and test datasets
40 | def load_dataset(batch_size, num_workers, split='train'):
41 |     """
42 |     A util function for loading the dataset and dataloader
43 |     Args:
44 |         batch_size: batch size (hyperparameter)
45 |         num_workers: number of dataloader workers (hyperparameter)
46 |         split: split name as a string; can be any split allowed by the dataset
47 |     Returns:
48 |         data_loader: an iterable over batches of the dataset
49 |         data_set: the processed dataset
50 |     """
51 | 
52 |     data_set = Cityscapes(root='cityscapes', split=split, mode='fine', target_type='semantic_basic', transform=image_transform, target_transform=target_transform)
53 | 
54 |     data_loader = torch.utils.data.DataLoader(data_set, batch_size=batch_size, num_workers=num_workers, shuffle=True, pin_memory=True)
55 |     return data_loader, data_set
56 | 
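A quick usage sketch for `load_dataset` above (batch size and worker count are arbitrary; assumes the prepared Cityscapes data lives under `./cityscapes`):

```python
# Illustrative smoke test for the loader above (values are illustrative).
from dataloader import load_dataset

loader, dataset = load_dataset(batch_size=4, num_workers=2, split='val')
images, targets = next(iter(loader))
print(images.shape)   # e.g. torch.Size([4, 3, 256, 512]) for 1024x2048 frames
print(targets.shape)  # torch.Size([4, 256, 512]); integer trainIds, 19 = ignore
```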
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from collections import namedtuple
4 | from typing import Any, Callable, Dict, List, Optional, Union, Tuple
5 | 
6 | from torchvision.datasets.utils import extract_archive, verify_str_arg, iterable_to_str
7 | from torchvision.datasets.vision import VisionDataset
8 | from PIL import Image
9 | import numpy as np
10 | import pdb
11 | ignoreClassId = 19
12 | 
13 | class Cityscapes(VisionDataset):
14 |     """`Cityscapes <http://www.cityscapes-dataset.com/>`_ Dataset.
15 | 
16 |     Args:
17 |         root (string): Root directory of dataset where directory ``leftImg8bit``
18 |             and ``gtFine`` or ``gtCoarse`` are located.
19 |         split (string, optional): The image split to use, ``train``, ``test`` or ``val`` if mode="fine"
20 |             otherwise ``train``, ``train_extra`` or ``val``
21 |         mode (string, optional): The quality mode to use, ``fine`` or ``coarse``
22 |         target_type (string or list, optional): Type of target to use, ``instance``, ``semantic``, ``polygon``
23 |             or ``color`` (this repo adds ``semantic_basic`` for trainId labels). Can also be a list to output a tuple with all specified target types.
24 |         transform (callable, optional): A function/transform that takes in a PIL image
25 |             and returns a transformed version. E.g, ``transforms.RandomCrop``
26 |         target_transform (callable, optional): A function/transform that takes in the
27 |             target and transforms it.
28 |         transforms (callable, optional): A function/transform that takes input sample and its target as entry
29 |             and returns a transformed version.
30 | 
31 |     Examples:
32 | 
33 |         Get semantic segmentation target
34 | 
35 |         .. code-block:: python
36 | 
37 |             dataset = Cityscapes('./data/cityscapes', split='train', mode='fine',
38 |                                  target_type='semantic')
39 | 
40 |             img, smnt = dataset[0]
41 | 
42 |         Get multiple targets
43 | 
44 |         .. code-block:: python
45 | 
46 |             dataset = Cityscapes('./data/cityscapes', split='train', mode='fine',
47 |                                  target_type=['instance', 'color', 'polygon'])
48 | 
49 |             img, (inst, col, poly) = dataset[0]
50 | 
51 |         Validate on the "coarse" set
52 | 
53 |         .. code-block:: python
54 | 
55 |             dataset = Cityscapes('./data/cityscapes', split='val', mode='coarse',
56 |                                  target_type='semantic')
57 | 
58 |             img, smnt = dataset[0]
59 |     """
60 | 
61 |     # Based on https://github.com/mcordts/cityscapesScripts
62 |     CityscapesClass = namedtuple('CityscapesClass', ['name', 'id', 'train_id', 'category', 'category_id',
63 |                                                      'has_instances', 'ignore_in_eval', 'color'])
64 | 
65 |     classes = [
66 |         CityscapesClass('unlabeled', 0, 255, 'void', 0, False, True, (0, 0, 0)),
67 |         CityscapesClass('ego vehicle', 1, 255, 'void', 0, False, True, (0, 0, 0)),
68 |         CityscapesClass('rectification border', 2, 255, 'void', 0, False, True, (0, 0, 0)),
69 |         CityscapesClass('out of roi', 3, 255, 'void', 0, False, True, (0, 0, 0)),
70 |         CityscapesClass('static', 4, 255, 'void', 0, False, True, (0, 0, 0)),
71 |         CityscapesClass('dynamic', 5, 255, 'void', 0, False, True, (111, 74, 0)),
72 |         CityscapesClass('ground', 6, 255, 'void', 0, False, True, (81, 0, 81)),
73 |         CityscapesClass('road', 7, 0, 'flat', 1, False, False, (128, 64, 128)),
74 |         CityscapesClass('sidewalk', 8, 1, 'flat', 1, False, False, (244, 35, 232)),
75 |         CityscapesClass('parking', 9, 255, 'flat', 1, False, True, (250, 170, 160)),
76 |         CityscapesClass('rail track', 10, 255, 'flat', 1, False, True, (230, 150, 140)),
77 |         CityscapesClass('building', 11, 2, 'construction', 2, False, False, (70, 70, 70)),
78 |         CityscapesClass('wall', 12, 3, 'construction', 2, False, False, (102, 102, 156)),
79 |         CityscapesClass('fence', 13, 4, 'construction', 2, False, False, (190, 153, 153)),
80 |         CityscapesClass('guard rail', 14, 255, 'construction', 2, False, True, (180, 165, 180)),
81 |         CityscapesClass('bridge', 15, 255, 'construction', 2, False, True, (150, 100, 100)),
82 |         CityscapesClass('tunnel', 16, 255, 'construction', 2, False, True, (150, 120, 90)),
83 |         CityscapesClass('pole', 17, 5, 'object', 3, False, False, (153, 153, 153)),
84 |         CityscapesClass('polegroup', 18, 255, 'object', 3, False, True, (153, 153, 153)),
85 |         CityscapesClass('traffic light', 19, 6, 'object', 3, False, False, (250, 170, 30)),
86 |         CityscapesClass('traffic sign', 20, 7, 'object', 3, False, False, (220, 220, 0)),
87 |         CityscapesClass('vegetation', 21, 8, 'nature', 4, False, False, (107, 142, 35)),
88 |         CityscapesClass('terrain', 22, 9, 'nature', 4, False, False, (152, 251, 152)),
89 |         CityscapesClass('sky', 23, 10, 'sky', 5, False, False, (70, 130, 180)),
90 |         CityscapesClass('person', 24, 11, 'human', 6, True, False, (220, 20, 60)),
91 |         CityscapesClass('rider', 25, 12, 'human', 6, True, False, (255, 0, 0)),
92 |         CityscapesClass('car', 26, 13, 'vehicle', 7, True, False, (0, 0, 142)),
93 |         CityscapesClass('truck', 27, 14, 'vehicle', 7, True, False, (0, 0, 70)),
94 |         CityscapesClass('bus', 28, 15, 'vehicle', 7, True, False, (0, 60, 100)),
95 |         CityscapesClass('caravan', 29, 255, 'vehicle', 7, True, True, (0, 0, 90)),
96 |         CityscapesClass('trailer', 30, 255, 'vehicle', 7, True, True, (0, 0, 110)),
97 |         CityscapesClass('train', 31, 16, 'vehicle', 7, True, False, (0, 80, 100)),
98 |         CityscapesClass('motorcycle', 32, 17,
'vehicle', 7, True, False, (0, 0, 230)), 99 | CityscapesClass('bicycle', 33, 18, 'vehicle', 7, True, False, (119, 11, 32)), 100 | CityscapesClass('license plate', -1, -1, 'vehicle', 7, False, True, (0, 0, 142)), 101 | ] 102 | 103 | def __init__( 104 | self, 105 | root: str, 106 | split: str = "train", 107 | mode: str = "fine", 108 | target_type: Union[List[str], str] = "instance", 109 | transform: Optional[Callable] = None, 110 | target_transform: Optional[Callable] = None, 111 | transforms: Optional[Callable] = None, 112 | ) -> None: 113 | super(Cityscapes, self).__init__(root, transforms, transform, target_transform) 114 | self.mode = 'gtFine' if mode == 'fine' else 'gtCoarse' 115 | self.images_dir = os.path.join(self.root, 'leftImg8bit', split) 116 | self.targets_dir = os.path.join(self.root, self.mode, split) 117 | self.target_type = target_type 118 | self.split = split 119 | self.images = [] 120 | self.targets = [] 121 | verify_str_arg(mode, "mode", ("fine", "coarse")) 122 | if mode == "fine": 123 | valid_modes = ("train", "test", "val") 124 | else: 125 | valid_modes = ("train", "train_extra", "val") 126 | msg = ("Unknown value '{}' for argument split if mode is '{}'. " 127 | "Valid values are {{{}}}.") 128 | msg = msg.format(split, mode, iterable_to_str(valid_modes)) 129 | verify_str_arg(split, "split", valid_modes, msg) 130 | 131 | if not isinstance(target_type, list): 132 | self.target_type = [target_type] 133 | [verify_str_arg(value, "target_type", 134 | ("instance", "semantic", "polygon", "color", "semantic_basic")) 135 | for value in self.target_type] 136 | 137 | if not os.path.isdir(self.images_dir) or not os.path.isdir(self.targets_dir): 138 | 139 | if split == 'train_extra': 140 | image_dir_zip = os.path.join(self.root, 'leftImg8bit{}'.format('_trainextra.zip')) 141 | else: 142 | image_dir_zip = os.path.join(self.root, 'leftImg8bit{}'.format('_trainvaltest.zip')) 143 | 144 | if self.mode == 'gtFine': 145 | target_dir_zip = os.path.join(self.root, '{}{}'.format(self.mode, '_trainvaltest.zip')) 146 | elif self.mode == 'gtCoarse': 147 | target_dir_zip = os.path.join(self.root, '{}{}'.format(self.mode, '.zip')) 148 | 149 | if os.path.isfile(image_dir_zip) and os.path.isfile(target_dir_zip): 150 | extract_archive(from_path=image_dir_zip, to_path=self.root) 151 | extract_archive(from_path=target_dir_zip, to_path=self.root) 152 | else: 153 | raise RuntimeError('Dataset not found or incomplete. Please make sure all required folders for the' 154 | ' specified "split" and "mode" are inside the "root" directory') 155 | 156 | for city in os.listdir(self.images_dir): 157 | img_dir = os.path.join(self.images_dir, city) 158 | target_dir = os.path.join(self.targets_dir, city) 159 | for file_name in os.listdir(img_dir): 160 | target_types = [] 161 | for t in self.target_type: 162 | target_name = '{}_{}'.format(file_name.split('_leftImg8bit')[0], 163 | self._get_target_suffix(self.mode, t)) 164 | target_types.append(os.path.join(target_dir, target_name)) 165 | 166 | self.images.append(os.path.join(img_dir, file_name)) 167 | self.targets.append(target_types) 168 | 169 | self.trainId2Color = {label.train_id : label.color for label in self.classes} 170 | self.trainId2Name = {label.train_id : label.name for label in self.classes} 171 | 172 | def __getitem__(self, index: int) -> Tuple[Any, Any]: 173 | """ 174 | Args: 175 | index (int): Index 176 | Returns: 177 | tuple: (image, target) where target is a tuple of all target types if target_type is a list with more 178 | than one item. 
Otherwise target is a json object if target_type="polygon", else the image segmentation.
179 |         """
180 | 
181 |         image = Image.open(self.images[index]).convert('RGB')
182 | 
183 |         targets: Any = []
184 |         for i, t in enumerate(self.target_type):
185 |             if t == 'polygon':
186 |                 target = self._load_json(self.targets[index][i])
187 |             else:
188 |                 target = Image.open(self.targets[index][i])
189 | 
190 |             targets.append(target)
191 | 
192 |         target = tuple(targets) if len(targets) > 1 else targets[0]
193 | 
194 |         if self.transforms is not None:
195 |             image, target = self.transforms(image, target)
196 | 
197 |         return image, target
198 | 
199 | 
200 |     def __len__(self) -> int:
201 |         return len(self.images)
202 | 
203 |     def extra_repr(self) -> str:
204 |         lines = ["Split: {split}", "Mode: {mode}", "Type: {target_type}"]
205 |         return '\n'.join(lines).format(**self.__dict__)
206 | 
207 |     def _load_json(self, path: str) -> Dict[str, Any]:
208 |         with open(path, 'r') as file:
209 |             data = json.load(file)
210 |         return data
211 | 
212 |     def _get_target_suffix(self, mode: str, target_type: str) -> str:
213 |         if target_type == 'instance':
214 |             return '{}_instanceIds.png'.format(mode)
215 |         elif target_type == 'semantic':
216 |             return '{}_labelIds.png'.format(mode)
217 |         elif target_type == 'semantic_basic':
218 |             return '{}_labelTrainIds.png'.format(mode)
219 |         elif target_type == 'color':
220 |             return '{}_color.png'.format(mode)
221 |         else:
222 |             return '{}_polygons.json'.format(mode)
223 | 
224 |     def decode_segmap(self, label_mask):
225 |         """Decode segmentation class labels into a color image
226 | 
227 |         Args:
228 |             label_mask (np.ndarray): an (M,N) array of integer values denoting
229 |                 the class label at each spatial location.
230 | 
231 |         Returns:
232 |             (np.ndarray): the resulting decoded color image.
233 |         """
234 |         rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3))
235 |         for trainId, color in self.trainId2Color.items():
236 |             if trainId == 255:
237 |                 trainId = ignoreClassId
238 |             rgb[label_mask == trainId] = color
239 |         rgb = rgb / 255.0
240 |         return rgb
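For illustration, `decode_segmap` turns a trainId mask into an RGB image. A minimal sketch with a random mask (values are illustrative; assumes the prepared dataset is available under `./cityscapes`):

```python
# Colorize a fake trainId mask using the mapping defined in the class above.
import numpy as np
from dataset import Cityscapes

dataset = Cityscapes(root='cityscapes', split='val', mode='fine',
                     target_type='semantic_basic')
mask = np.random.randint(0, 20, size=(64, 128))  # random trainIds incl. ignore=19
rgb = dataset.decode_segmap(mask)                # (64, 128, 3) floats in [0, 1]
print(rgb.shape, rgb.min(), rgb.max())
```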
--------------------------------------------------------------------------------
/deeplabv3.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 | import torchvision.models.vgg as vgg
5 | from resnet import ResNet50_OS16
6 | 
7 | class ASPP(nn.Module):
8 |     '''
9 |     This module implements atrous spatial pyramid pooling (ASPP) on the DeepLab net with a VGG backbone.
10 | 
11 |     Args:
12 |         num_classes: number of classes to be predicted
13 |         feature_map: feature map produced by the backbone net (input to forward)
14 | 
15 |     Returns:
16 |         feature map after performing ASPP, of shape (batch_size, num_classes, h/16, w/16)
17 |     '''
18 |     def __init__(self, num_classes):
19 |         super(ASPP, self).__init__()
20 | 
21 |         self.conv_1x1_1 = nn.Conv2d(512, 256, kernel_size=1)
22 |         self.bn_conv_1x1_1 = nn.BatchNorm2d(256)
23 | 
24 |         self.conv_3x3_1 = nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=6, dilation=6)
25 |         self.bn_conv_3x3_1 = nn.BatchNorm2d(256)
26 | 
27 |         self.conv_3x3_2 = nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=12, dilation=12)
28 |         self.bn_conv_3x3_2 = nn.BatchNorm2d(256)
29 | 
30 |         self.conv_3x3_3 = nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=18, dilation=18)
31 |         self.bn_conv_3x3_3 = nn.BatchNorm2d(256)
32 | 
33 |         self.avg_pool = nn.AdaptiveAvgPool2d(1)
34 | 
35 |         self.conv_1x1_2 = nn.Conv2d(512, 256, kernel_size=1)
36 |         self.bn_conv_1x1_2 = nn.BatchNorm2d(256)
37 | 
38 |         self.conv_1x1_3 = nn.Conv2d(1280, 256, kernel_size=1)  # (1280 = 5*256)
39 |         self.bn_conv_1x1_3 = nn.BatchNorm2d(256)
40 | 
41 |         self.conv_1x1_4 = nn.Conv2d(256, num_classes, kernel_size=1)
42 | 
43 |     def forward(self, feature_map):
44 |         # (feature_map has shape (batch_size, 512, h/16, w/16)) (assuming self.resnet is ResNet18_OS16 or ResNet34_OS16. If self.resnet instead is ResNet18_OS8 or ResNet34_OS8, it will be (batch_size, 512, h/8, w/8))
45 |         feature_map_h = feature_map.size()[2]  # (== h/16)
46 |         feature_map_w = feature_map.size()[3]  # (== w/16)
47 | 
48 |         out_1x1 = F.relu(self.bn_conv_1x1_1(self.conv_1x1_1(feature_map)))  # (shape: (batch_size, 256, h/16, w/16))
49 |         out_3x3_1 = F.relu(self.bn_conv_3x3_1(self.conv_3x3_1(feature_map)))  # (shape: (batch_size, 256, h/16, w/16))
50 |         out_3x3_2 = F.relu(self.bn_conv_3x3_2(self.conv_3x3_2(feature_map)))  # (shape: (batch_size, 256, h/16, w/16))
51 |         out_3x3_3 = F.relu(self.bn_conv_3x3_3(self.conv_3x3_3(feature_map)))  # (shape: (batch_size, 256, h/16, w/16))
52 | 
53 |         out_img = self.avg_pool(feature_map)  # (shape: (batch_size, 512, 1, 1))
54 |         out_img = F.relu(self.bn_conv_1x1_2(self.conv_1x1_2(out_img)))  # (shape: (batch_size, 256, 1, 1))
55 |         #out_img = F.upsample(out_img, size=(feature_map_h, feature_map_w), mode="bilinear")  # (shape: (batch_size, 256, h/16, w/16))
56 |         out_img = F.interpolate(out_img, size=(feature_map_h, feature_map_w), scale_factor=None, mode="bilinear", align_corners=True, recompute_scale_factor=None)
57 | 
58 |         out = torch.cat([out_1x1, out_3x3_1, out_3x3_2, out_3x3_3, out_img], 1)  # (shape: (batch_size, 1280, h/16, w/16))
59 |         out = F.relu(self.bn_conv_1x1_3(self.conv_1x1_3(out)))  # (shape: (batch_size, 256, h/16, w/16))
60 |         out = self.conv_1x1_4(out)  # (shape: (batch_size, num_classes, h/16, w/16))
61 | 
62 |         return out
63 | 
64 | class ASPP_Bottleneck(nn.Module):
65 |     '''
66 |     This module implements atrous spatial pyramid pooling (ASPP) on the DeepLab net with a ResNet50 backbone.
67 | 
68 |     Args:
69 |         num_classes: number of classes to be predicted
70 |         feature_map: feature map produced by the backbone net (input to forward)
71 | 
72 |     Returns:
73 |         feature map after performing ASPP, of shape (batch_size, num_classes, h/16, w/16)
74 |     '''
75 |     def __init__(self, num_classes):
76 |         super(ASPP_Bottleneck, self).__init__()
77 | 
78 |         self.conv_1x1_1 = nn.Conv2d(4*512, 256, kernel_size=1)
79 |         self.bn_conv_1x1_1 = nn.BatchNorm2d(256)
80 | 
81 |         self.conv_3x3_1 = nn.Conv2d(4*512, 256, kernel_size=3, stride=1, padding=6, dilation=6)
82 |         self.bn_conv_3x3_1 = nn.BatchNorm2d(256)
83 | 
84 |         self.conv_3x3_2 = nn.Conv2d(4*512, 256, kernel_size=3, stride=1, padding=12, dilation=12)
85 |         self.bn_conv_3x3_2 = nn.BatchNorm2d(256)
86 | 
87 |         self.conv_3x3_3 = nn.Conv2d(4*512, 256, kernel_size=3, stride=1, padding=18, dilation=18)
88 |         self.bn_conv_3x3_3 = nn.BatchNorm2d(256)
89 | 
90 |         self.avg_pool = nn.AdaptiveAvgPool2d(1)
91 | 
92 |         self.conv_1x1_2 = nn.Conv2d(4*512, 256, kernel_size=1)
93 |         self.bn_conv_1x1_2 = nn.BatchNorm2d(256)
94 | 
95 |         self.conv_1x1_3 = nn.Conv2d(1280, 256, kernel_size=1)  # (1280 = 5*256)
96 |         self.bn_conv_1x1_3 = nn.BatchNorm2d(256)
97 | 
98 |         self.conv_1x1_4 = nn.Conv2d(256, num_classes, kernel_size=1)
99 | 
100 |     def forward(self, feature_map):
101 |         # (feature_map has shape (batch_size, 4*512, h/16, w/16))
102 | 
103 |         feature_map_h = feature_map.size()[2]  # (== h/16)
104 |         feature_map_w = feature_map.size()[3]  # (== w/16)
105 | 
106 |         out_1x1 = F.relu(self.bn_conv_1x1_1(self.conv_1x1_1(feature_map)))  # (shape: (batch_size, 256, h/16, w/16))
107 |         out_3x3_1 = F.relu(self.bn_conv_3x3_1(self.conv_3x3_1(feature_map)))  # (shape: (batch_size, 256, h/16, w/16))
108 |         out_3x3_2 = F.relu(self.bn_conv_3x3_2(self.conv_3x3_2(feature_map)))  # (shape: (batch_size, 256, h/16, w/16))
109 |         out_3x3_3 = F.relu(self.bn_conv_3x3_3(self.conv_3x3_3(feature_map)))  # (shape: (batch_size, 256, h/16, w/16))
110 | 
111 |         out_img = self.avg_pool(feature_map)  # (shape: (batch_size, 4*512, 1, 1))
112 |         out_img = F.relu(self.bn_conv_1x1_2(self.conv_1x1_2(out_img)))  # (shape: (batch_size, 256, 1, 1))
113 |         #out_img = F.upsample(out_img, size=(feature_map_h, feature_map_w), mode="bilinear")  # (shape: (batch_size, 256, h/16, w/16))
114 |         out_img = F.interpolate(out_img, size=(feature_map_h, feature_map_w), scale_factor=None, mode="bilinear", align_corners=True, recompute_scale_factor=None)
115 | 
116 |         out = torch.cat([out_1x1, out_3x3_1, out_3x3_2, out_3x3_3, out_img], 1)  # (shape: (batch_size, 1280, h/16, w/16))
117 |         out = F.relu(self.bn_conv_1x1_3(self.conv_1x1_3(out)))  # (shape: (batch_size, 256, h/16, w/16))
118 |         out = self.conv_1x1_4(out)  # (shape: (batch_size, num_classes, h/16, w/16))
119 | 
120 |         return out
121 | 
122 | class DeepLabV3(nn.Module):
123 |     '''
124 |     DeepLabV3 net framework
125 | 
126 |     Args:
127 |         n_class: number of classes to be predicted
128 |         backbone: takes either 'vgg' or 'resnet'; this decides the pretrained backbone used
129 | 
130 |     Returns:
131 |         feature map extracted by the DeepLabV3 net, of shape (batch_size, num_classes, h, w)
132 | 
133 |     '''
134 |     def __init__(self, n_class, backbone):
135 |         super(DeepLabV3, self).__init__()
136 | 
137 |         self.num_classes = n_class
138 | 
139 |         if backbone == 'vgg':
140 |             self.features = vgg.vgg16(pretrained=True).features
141 |             self.aspp = ASPP(num_classes=self.num_classes)
142 |         elif backbone == 'resnet':
143 |             self.features = ResNet50_OS16()
144 |             self.aspp = ASPP_Bottleneck(num_classes=self.num_classes)
145 | 
146 |     def forward(self, x):
147 |         # (x has shape (batch_size, 3, h, w))
148 | 
149 |         h = x.size()[2]
150 |         w = x.size()[3]
151 | 
152 |         #feature_map = self.pretrained.features(x)
153 |         feature_map = self.features(x)  #If self.resnet is ResNet50-152, it will be (batch_size, 4*512, h/16, w/16))
154 | 
155 |         output = self.aspp(feature_map)  # (shape: (batch_size, num_classes, h/16, w/16))
156 | 
157 |         #output = F.upsample(output, size=(h, w), mode="bilinear", align_corners=True)  # (shape: (batch_size, num_classes, h, w))
158 |         output = F.interpolate(output, size=(h, w), scale_factor=None, mode="bilinear", align_corners=True, recompute_scale_factor=None)
159 |         return output
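A quick shape check for `DeepLabV3` above; a minimal sketch assuming the VGG backbone (the ResNet branch additionally needs the pretrained-weights path configured in `resnet.py`):

```python
# Smoke test: output resolution should match the input, one channel per class.
import torch
from deeplabv3 import DeepLabV3

model = DeepLabV3(n_class=20, backbone='vgg')  # downloads VGG16 weights on first use
model.eval()
with torch.no_grad():
    x = torch.randn(1, 3, 224, 224)
    y = model(x)
print(y.shape)  # torch.Size([1, 20, 224, 224])
```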
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | import torch.nn.functional as F
5 | import numpy as np
6 | import time
7 | import pdb
8 | import sys
9 | import os
10 | 
11 | from fcn import Segnet
12 | from r2unet import U_Net, R2U_Net, RecU_Net, ResU_Net
13 | from deeplabv3 import DeepLabV3
14 | from dataloader import load_dataset
15 | from metrics import Metrics
16 | from vis import Vis
17 | 
18 | expt_logdir = sys.argv[1]
19 | ckpt_name = sys.argv[2]
20 | 
21 | #Dataset parameters
22 | num_workers = 8
23 | batch_size = 16
24 | n_classes = 20
25 | img_size = 224
26 | test_split = 'val'
27 | 
28 | # Logging options
29 | rows, cols = 5, 2  #Show 10 images in the dataset along with target and predicted masks
30 | 
31 | device = torch.device("cuda")  # if torch.cuda.is_available() else "cpu")
32 | num_gpu = list(range(torch.cuda.device_count()))
33 | 
34 | testloader, test_dst = load_dataset(batch_size, num_workers, split=test_split)
35 | 
36 | # Creating an instance of the model
37 | #model = Segnet(n_classes)  #Fully Convolutional Networks
38 | #model = U_Net(img_ch=3,output_ch=n_classes)  #U Network
39 | #model = R2U_Net(img_ch=3,output_ch=n_classes,t=2)  #Residual Recurrent U Network, R2Unet (t=2)
40 | #model = R2U_Net(img_ch=3,output_ch=n_classes,t=3)  #Residual Recurrent U Network, R2Unet (t=3)
41 | #model = RecU_Net(img_ch=3,output_ch=n_classes,t=2)  #Recurrent U Network, RecUnet (t=2)
42 | #model = ResU_Net(img_ch=3,output_ch=n_classes)  #Residual U Network, ResUnet
43 | #model = DeepLabV3(n_classes, 'vgg')  #DeepLabV3 VGG backbone
44 | model = DeepLabV3(n_classes, 'resnet')  #DeepLabV3 Resnet backbone
45 | 
46 | print('Evaluation logs for model: {}'.format(model.__class__.__name__))
47 | 
48 | model = nn.DataParallel(model, device_ids=num_gpu).to(device)
49 | model_params = torch.load(os.path.join(expt_logdir, "{}".format(ckpt_name)))
50 | model.load_state_dict(model_params)
51 | 
52 | #Visualization of test data
53 | test_vis = Vis(test_dst, expt_logdir, rows, cols)
54 | #Metrics calculator for test data
55 | test_metrics = Metrics(n_classes, testloader, test_split, device, expt_logdir)
56 | 
57 | model.eval()
61 | with torch.no_grad():
62 |     for i, (inputs, labels) in enumerate(testloader):
63 |         inputs = inputs.to(device)
64 |         labels = labels.to(device)
65 |         predictions = model(inputs)
66 | 
67 | epoch = ckpt_name  # tag the output plots/logs with the checkpoint name
68 | 
69 | test_metrics.compute(epoch, model)
70 | test_metrics.plot_roc(epoch)
71 | test_vis.visualize(epoch, model)
--------------------------------------------------------------------------------
/expt_logs/fcn_cityscapes/loss_200.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/navamikairanda/R2U-Net/e8c7fe61554e703662a18f86b76a0629db81ffea/expt_logs/fcn_cityscapes/loss_200.png
--------------------------------------------------------------------------------
/expt_logs/fcn_cityscapes/metric_train_200.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/navamikairanda/R2U-Net/e8c7fe61554e703662a18f86b76a0629db81ffea/expt_logs/fcn_cityscapes/metric_train_200.png
--------------------------------------------------------------------------------
/expt_logs/fcn_cityscapes/metric_val_200.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/navamikairanda/R2U-Net/e8c7fe61554e703662a18f86b76a0629db81ffea/expt_logs/fcn_cityscapes/metric_val_200.png
--------------------------------------------------------------------------------
/expt_logs/fcn_cityscapes/roc_train_200.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/navamikairanda/R2U-Net/e8c7fe61554e703662a18f86b76a0629db81ffea/expt_logs/fcn_cityscapes/roc_train_200.png
--------------------------------------------------------------------------------
/expt_logs/fcn_cityscapes/roc_val_200.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/navamikairanda/R2U-Net/e8c7fe61554e703662a18f86b76a0629db81ffea/expt_logs/fcn_cityscapes/roc_val_200.png
--------------------------------------------------------------------------------
/expt_logs/fcn_cityscapes/seg_train_200.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/navamikairanda/R2U-Net/e8c7fe61554e703662a18f86b76a0629db81ffea/expt_logs/fcn_cityscapes/seg_train_200.png
--------------------------------------------------------------------------------
/expt_logs/fcn_cityscapes/seg_val_200.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/navamikairanda/R2U-Net/e8c7fe61554e703662a18f86b76a0629db81ffea/expt_logs/fcn_cityscapes/seg_val_200.png
--------------------------------------------------------------------------------
/expt_logs/fcn_pascal/loss_290.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/navamikairanda/R2U-Net/e8c7fe61554e703662a18f86b76a0629db81ffea/expt_logs/fcn_pascal/loss_290.png
--------------------------------------------------------------------------------
/expt_logs/fcn_pascal/metric_train_290.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/navamikairanda/R2U-Net/e8c7fe61554e703662a18f86b76a0629db81ffea/expt_logs/fcn_pascal/metric_train_290.png
--------------------------------------------------------------------------------
/expt_logs/fcn_pascal/metric_val_290.png:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/navamikairanda/R2U-Net/e8c7fe61554e703662a18f86b76a0629db81ffea/expt_logs/fcn_pascal/metric_val_290.png -------------------------------------------------------------------------------- /expt_logs/fcn_pascal/seg_train_290.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/navamikairanda/R2U-Net/e8c7fe61554e703662a18f86b76a0629db81ffea/expt_logs/fcn_pascal/seg_train_290.png -------------------------------------------------------------------------------- /expt_logs/fcn_pascal/seg_val_290.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/navamikairanda/R2U-Net/e8c7fe61554e703662a18f86b76a0629db81ffea/expt_logs/fcn_pascal/seg_val_290.png -------------------------------------------------------------------------------- /expt_logs/recnet_cityscapes/slurm-4164907.out: -------------------------------------------------------------------------------- 1 | Experiment logs for model: RecU_Net 2 | /HPS/Navami/static00/anaconda3/envs/nnti/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py:50: UserWarning: Metric `ROC` will save all targets and predictions in buffer. For large datasets this may lead to large memory footprint. 3 | warnings.warn(*args, **kwargs) 4 | Split: train, epoch: -1, metrics: accuracy: 0.001 ; iou: 0.000 ; dice: 0.000 ; sensitivity: 0.050 ; auroc: 0.478 5 | Split: val, epoch: -1, metrics: accuracy: 0.001 ; iou: 0.000 ; dice: 0.000 ; sensitivity: 0.050 ; auroc: 0.477 6 | Finish iter: 0, loss 3.096630096435547 7 | Finish iter: 20, loss 1.734170913696289 8 | Finish iter: 40, loss 1.3681607246398926 9 | Finish iter: 60, loss 1.0925949811935425 10 | Finish iter: 80, loss 1.1361868381500244 11 | Finish iter: 100, loss 0.9968492984771729 12 | Finish iter: 120, loss 0.9619495868682861 13 | Finish iter: 140, loss 0.9718397259712219 14 | Finish iter: 160, loss 0.8027595281600952 15 | Finish iter: 180, loss 0.8099526762962341 16 | Training epoch: 0, loss: 0.8703349828720093, time elapsed: 222.9545829296112, 17 | Split: train, epoch: 0, metrics: accuracy: 0.658 ; iou: 0.150 ; dice: 0.195 ; sensitivity: 0.229 ; auroc: 0.836 18 | Split: val, epoch: 0, metrics: accuracy: 0.655 ; iou: 0.142 ; dice: 0.188 ; sensitivity: 0.223 ; auroc: 0.828 19 | Finish iter: 0, loss 0.7533780932426453 20 | Finish iter: 20, loss 0.7803657650947571 21 | Finish iter: 40, loss 0.7961416840553284 22 | Finish iter: 60, loss 0.7296813130378723 23 | Finish iter: 80, loss 0.8249067068099976 24 | Finish iter: 100, loss 0.7133967876434326 25 | Finish iter: 120, loss 0.7341651320457458 26 | Finish iter: 140, loss 0.9754429459571838 27 | Finish iter: 160, loss 0.8659511208534241 28 | Finish iter: 180, loss 0.9135880470275879 29 | Training epoch: 1, loss: 0.7714760303497314, time elapsed: 222.1364574432373, 30 | Split: train, epoch: 1, metrics: accuracy: 0.457 ; iou: 0.109 ; dice: 0.149 ; sensitivity: 0.162 ; auroc: 0.705 31 | Split: val, epoch: 1, metrics: accuracy: 0.437 ; iou: 0.100 ; dice: 0.142 ; sensitivity: 0.158 ; auroc: 0.684 32 | Finish iter: 0, loss 0.6544751524925232 33 | Finish iter: 20, loss 0.591164231300354 34 | Finish iter: 40, loss 0.6446842551231384 35 | Finish iter: 60, loss 0.5980395674705505 36 | Finish iter: 80, loss 0.7326205372810364 37 | Finish iter: 100, loss 0.5569791793823242 38 | Finish iter: 120, loss 0.6157584190368652 39 | Finish 
iter: 140, loss 0.7069805264472961 40 | Finish iter: 160, loss 0.6000252366065979 41 | Finish iter: 180, loss 0.6038381457328796 42 | Training epoch: 2, loss: 0.7230420708656311, time elapsed: 224.21626806259155, 43 | Split: train, epoch: 2, metrics: accuracy: 0.658 ; iou: 0.182 ; dice: 0.233 ; sensitivity: 0.273 ; auroc: 0.897 44 | Split: val, epoch: 2, metrics: accuracy: 0.652 ; iou: 0.170 ; dice: 0.223 ; sensitivity: 0.266 ; auroc: 0.896 45 | Finish iter: 0, loss 0.6346749663352966 46 | Finish iter: 20, loss 0.6947142481803894 47 | Finish iter: 40, loss 0.6705856919288635 48 | Finish iter: 60, loss 0.5365504622459412 49 | Finish iter: 80, loss 0.528434693813324 50 | Finish iter: 100, loss 0.5324239730834961 51 | Finish iter: 120, loss 0.6468367576599121 52 | Finish iter: 140, loss 0.7753591537475586 53 | Finish iter: 160, loss 0.5400938987731934 54 | Finish iter: 180, loss 0.5998868942260742 55 | Training epoch: 3, loss: 0.7295322418212891, time elapsed: 224.61195755004883, 56 | Split: train, epoch: 3, metrics: accuracy: 0.641 ; iou: 0.179 ; dice: 0.224 ; sensitivity: 0.258 ; auroc: 0.876 57 | Split: val, epoch: 3, metrics: accuracy: 0.641 ; iou: 0.173 ; dice: 0.219 ; sensitivity: 0.253 ; auroc: 0.867 58 | Finish iter: 0, loss 0.5563568472862244 59 | Finish iter: 20, loss 0.6322904229164124 60 | Finish iter: 40, loss 0.5547595620155334 61 | Finish iter: 60, loss 0.5865216255187988 62 | Finish iter: 80, loss 0.486945778131485 63 | Finish iter: 100, loss 0.5746210217475891 64 | Finish iter: 120, loss 0.5670445561408997 65 | Finish iter: 140, loss 0.5102142691612244 66 | Finish iter: 160, loss 0.5183233022689819 67 | Finish iter: 180, loss 0.5215664505958557 68 | Training epoch: 4, loss: 0.5192659497261047, time elapsed: 224.38132619857788, 69 | Split: train, epoch: 4, metrics: accuracy: 0.576 ; iou: 0.163 ; dice: 0.213 ; sensitivity: 0.238 ; auroc: 0.849 70 | Split: val, epoch: 4, metrics: accuracy: 0.594 ; iou: 0.166 ; dice: 0.217 ; sensitivity: 0.241 ; auroc: 0.853 71 | Finish iter: 0, loss 0.4707641303539276 72 | Finish iter: 20, loss 0.5613135695457458 73 | Finish iter: 40, loss 0.5993478894233704 74 | Finish iter: 60, loss 0.5015940070152283 75 | Finish iter: 80, loss 0.7493917942047119 76 | Finish iter: 100, loss 0.5558897256851196 77 | Finish iter: 120, loss 0.6096292734146118 78 | Finish iter: 140, loss 0.4526526629924774 79 | Finish iter: 160, loss 0.5142977237701416 80 | Finish iter: 180, loss 0.4457816183567047 81 | Training epoch: 5, loss: 0.5669042468070984, time elapsed: 224.48723697662354, 82 | Split: train, epoch: 5, metrics: accuracy: 0.408 ; iou: 0.128 ; dice: 0.162 ; sensitivity: 0.189 ; auroc: 0.754 83 | Split: val, epoch: 5, metrics: accuracy: 0.407 ; iou: 0.126 ; dice: 0.161 ; sensitivity: 0.189 ; auroc: 0.737 84 | Finish iter: 0, loss 0.606767475605011 85 | Finish iter: 20, loss 0.4652656316757202 86 | Finish iter: 40, loss 0.582310676574707 87 | Finish iter: 60, loss 0.5417382717132568 88 | Finish iter: 80, loss 0.4760960638523102 89 | Finish iter: 100, loss 0.46125590801239014 90 | Finish iter: 120, loss 0.4065217971801758 91 | Finish iter: 140, loss 0.4881526827812195 92 | Finish iter: 160, loss 0.5201624035835266 93 | Finish iter: 180, loss 0.5114322304725647 94 | Training epoch: 6, loss: 0.47913604974746704, time elapsed: 224.44510698318481, 95 | Split: train, epoch: 6, metrics: accuracy: 0.697 ; iou: 0.200 ; dice: 0.263 ; sensitivity: 0.265 ; auroc: 0.908 96 | Split: val, epoch: 6, metrics: accuracy: 0.714 ; iou: 0.204 ; dice: 0.268 ; sensitivity: 0.268 ; 
auroc: 0.914 97 | Finish iter: 0, loss 0.49509552121162415 98 | Finish iter: 20, loss 0.3823366165161133 99 | Finish iter: 40, loss 0.43715178966522217 100 | Finish iter: 60, loss 0.5023162961006165 101 | Finish iter: 80, loss 0.578927755355835 102 | Finish iter: 100, loss 0.5502452254295349 103 | Finish iter: 120, loss 0.48745614290237427 104 | Finish iter: 140, loss 0.5085086822509766 105 | Finish iter: 160, loss 0.42567503452301025 106 | Finish iter: 180, loss 0.4138880670070648 107 | Training epoch: 7, loss: 0.4474240243434906, time elapsed: 224.62291145324707, 108 | Split: train, epoch: 7, metrics: accuracy: 0.384 ; iou: 0.127 ; dice: 0.174 ; sensitivity: 0.204 ; auroc: 0.814 109 | Split: val, epoch: 7, metrics: accuracy: 0.395 ; iou: 0.125 ; dice: 0.174 ; sensitivity: 0.205 ; auroc: 0.803 110 | Finish iter: 0, loss 0.500510036945343 111 | Finish iter: 20, loss 0.63545823097229 112 | Finish iter: 40, loss 0.4919646382331848 113 | Finish iter: 60, loss 0.4169360399246216 114 | Finish iter: 80, loss 0.42249441146850586 115 | Finish iter: 100, loss 0.4113334119319916 116 | Finish iter: 120, loss 0.5824847221374512 117 | Finish iter: 140, loss 0.5847809314727783 118 | Finish iter: 160, loss 0.48813849687576294 119 | Finish iter: 180, loss 0.37656962871551514 120 | Training epoch: 8, loss: 0.4572499990463257, time elapsed: 224.76057362556458, 121 | Split: train, epoch: 8, metrics: accuracy: 0.457 ; iou: 0.149 ; dice: 0.189 ; sensitivity: 0.218 ; auroc: 0.850 122 | Split: val, epoch: 8, metrics: accuracy: 0.458 ; iou: 0.146 ; dice: 0.187 ; sensitivity: 0.214 ; auroc: 0.844 123 | Finish iter: 0, loss 0.41113680601119995 124 | Finish iter: 20, loss 0.37345001101493835 125 | Finish iter: 40, loss 0.4271112084388733 126 | Finish iter: 60, loss 0.4669220447540283 127 | Finish iter: 80, loss 0.40964239835739136 128 | Finish iter: 100, loss 0.46367865800857544 129 | Finish iter: 120, loss 0.4244126081466675 130 | Finish iter: 140, loss 0.3661944568157196 131 | Finish iter: 160, loss 0.4107145071029663 132 | Finish iter: 180, loss 0.3749409317970276 133 | Training epoch: 9, loss: 0.5286173224449158, time elapsed: 224.63939237594604, 134 | Split: train, epoch: 9, metrics: accuracy: 0.634 ; iou: 0.199 ; dice: 0.260 ; sensitivity: 0.286 ; auroc: 0.923 135 | Split: val, epoch: 9, metrics: accuracy: 0.648 ; iou: 0.200 ; dice: 0.261 ; sensitivity: 0.286 ; auroc: 0.922 136 | Finish iter: 0, loss 0.3928765654563904 137 | Finish iter: 20, loss 0.4250600039958954 138 | Finish iter: 40, loss 0.46629735827445984 139 | Finish iter: 60, loss 0.4707067906856537 140 | Finish iter: 80, loss 0.44829922914505005 141 | Finish iter: 100, loss 0.6314753293991089 142 | Finish iter: 120, loss 0.4258241355419159 143 | Finish iter: 140, loss 0.3183186948299408 144 | Finish iter: 160, loss 0.47284257411956787 145 | Finish iter: 180, loss 0.5213043093681335 146 | Training epoch: 10, loss: 0.45791250467300415, time elapsed: 224.82197999954224, 147 | Split: train, epoch: 10, metrics: accuracy: 0.438 ; iou: 0.146 ; dice: 0.193 ; sensitivity: 0.222 ; auroc: 0.833 148 | Split: val, epoch: 10, metrics: accuracy: 0.440 ; iou: 0.144 ; dice: 0.191 ; sensitivity: 0.219 ; auroc: 0.819 149 | Finish iter: 0, loss 0.4494038224220276 150 | Finish iter: 20, loss 0.40498948097229004 151 | Finish iter: 40, loss 0.43616026639938354 152 | Finish iter: 60, loss 0.3866649568080902 153 | Finish iter: 80, loss 0.4006049633026123 154 | Finish iter: 100, loss 0.4519660770893097 155 | Finish iter: 120, loss 0.4253271222114563 156 | Finish iter: 140, 
loss 0.397969126701355 157 | Finish iter: 160, loss 0.4158984124660492 158 | Finish iter: 180, loss 0.35578396916389465 159 | Training epoch: 11, loss: 0.3956805467605591, time elapsed: 222.0421645641327, 160 | Split: train, epoch: 11, metrics: accuracy: 0.532 ; iou: 0.176 ; dice: 0.231 ; sensitivity: 0.255 ; auroc: 0.875 161 | Split: val, epoch: 11, metrics: accuracy: 0.517 ; iou: 0.169 ; dice: 0.220 ; sensitivity: 0.245 ; auroc: 0.864 162 | Finish iter: 0, loss 0.38147592544555664 163 | Finish iter: 20, loss 0.42995136976242065 164 | Finish iter: 40, loss 0.4450961947441101 165 | Finish iter: 60, loss 0.4024580717086792 166 | Finish iter: 80, loss 0.4237862229347229 167 | Finish iter: 100, loss 0.450382798910141 168 | Finish iter: 120, loss 0.45022130012512207 169 | Finish iter: 140, loss 0.39389777183532715 170 | Finish iter: 160, loss 0.44292181730270386 171 | Finish iter: 180, loss 0.3599338233470917 172 | Training epoch: 12, loss: 0.4369013011455536, time elapsed: 224.34075570106506, 173 | Split: train, epoch: 12, metrics: accuracy: 0.462 ; iou: 0.154 ; dice: 0.213 ; sensitivity: 0.247 ; auroc: 0.868 174 | Split: val, epoch: 12, metrics: accuracy: 0.456 ; iou: 0.147 ; dice: 0.206 ; sensitivity: 0.240 ; auroc: 0.862 175 | Finish iter: 0, loss 0.3809870183467865 176 | Finish iter: 20, loss 0.414207398891449 177 | Finish iter: 40, loss 0.3973848223686218 178 | Finish iter: 60, loss 0.427240788936615 179 | Finish iter: 80, loss 0.3929983079433441 180 | Finish iter: 100, loss 0.44583600759506226 181 | Finish iter: 120, loss 0.4124112129211426 182 | Finish iter: 140, loss 0.35344088077545166 183 | Finish iter: 160, loss 0.4166591763496399 184 | Finish iter: 180, loss 0.5352100133895874 185 | Training epoch: 13, loss: 0.4193961024284363, time elapsed: 224.40169143676758, 186 | Split: train, epoch: 13, metrics: accuracy: 0.410 ; iou: 0.141 ; dice: 0.185 ; sensitivity: 0.240 ; auroc: 0.816 187 | Split: val, epoch: 13, metrics: accuracy: 0.400 ; iou: 0.136 ; dice: 0.177 ; sensitivity: 0.234 ; auroc: 0.808 188 | Finish iter: 0, loss 0.3833897113800049 189 | Finish iter: 20, loss 0.3788790702819824 190 | Finish iter: 40, loss 0.4196552634239197 191 | Finish iter: 60, loss 0.4937152862548828 192 | Finish iter: 80, loss 0.34513193368911743 193 | Finish iter: 100, loss 0.4234275221824646 194 | Finish iter: 120, loss 0.38009604811668396 195 | Finish iter: 140, loss 0.4472934305667877 196 | Finish iter: 160, loss 0.41466259956359863 197 | Finish iter: 180, loss 0.38428962230682373 198 | Training epoch: 14, loss: 0.27795225381851196, time elapsed: 224.69092988967896, 199 | Split: train, epoch: 14, metrics: accuracy: 0.410 ; iou: 0.145 ; dice: 0.198 ; sensitivity: 0.256 ; auroc: 0.809 200 | Split: val, epoch: 14, metrics: accuracy: 0.416 ; iou: 0.143 ; dice: 0.196 ; sensitivity: 0.248 ; auroc: 0.801 201 | Finish iter: 0, loss 0.36740925908088684 202 | Finish iter: 20, loss 0.38066917657852173 203 | Finish iter: 40, loss 0.31385648250579834 204 | Finish iter: 60, loss 0.34194064140319824 205 | Finish iter: 80, loss 0.4044250547885895 206 | Finish iter: 100, loss 0.3796727955341339 207 | Finish iter: 120, loss 0.4355350434780121 208 | Finish iter: 140, loss 0.37146759033203125 209 | Finish iter: 160, loss 0.4842931628227234 210 | Finish iter: 180, loss 0.3284326493740082 211 | Training epoch: 15, loss: 0.4855916202068329, time elapsed: 224.57464861869812, 212 | Split: train, epoch: 15, metrics: accuracy: 0.358 ; iou: 0.129 ; dice: 0.169 ; sensitivity: 0.213 ; auroc: 0.773 213 | Split: val, epoch: 15, 
metrics: accuracy: 0.356 ; iou: 0.128 ; dice: 0.167 ; sensitivity: 0.210 ; auroc: 0.746 214 | Finish iter: 0, loss 0.3064555823802948 215 | Finish iter: 20, loss 0.2943767309188843 216 | Finish iter: 40, loss 0.4294513165950775 217 | Finish iter: 60, loss 0.4359145760536194 218 | Finish iter: 80, loss 0.4915398955345154 219 | Finish iter: 100, loss 0.2923773527145386 220 | Finish iter: 120, loss 0.44726112484931946 221 | Finish iter: 140, loss 0.34521031379699707 222 | Finish iter: 160, loss 0.30557551980018616 223 | Finish iter: 180, loss 0.4416017234325409 224 | Training epoch: 16, loss: 0.33157557249069214, time elapsed: 224.68105149269104, 225 | Split: train, epoch: 16, metrics: accuracy: 0.434 ; iou: 0.160 ; dice: 0.219 ; sensitivity: 0.277 ; auroc: 0.859 226 | Split: val, epoch: 16, metrics: accuracy: 0.413 ; iou: 0.151 ; dice: 0.207 ; sensitivity: 0.255 ; auroc: 0.834 227 | Finish iter: 0, loss 0.41087082028388977 228 | Finish iter: 20, loss 0.38660091161727905 229 | Finish iter: 40, loss 0.3508269190788269 230 | Finish iter: 60, loss 0.3339795470237732 231 | Finish iter: 80, loss 0.32900574803352356 232 | Finish iter: 100, loss 0.5039730668067932 233 | Finish iter: 120, loss 0.4514555335044861 234 | Finish iter: 140, loss 0.3932616710662842 235 | Finish iter: 160, loss 0.4209910035133362 236 | Finish iter: 180, loss 0.3018573224544525 237 | Training epoch: 17, loss: 0.3991266191005707, time elapsed: 224.54220461845398, 238 | Split: train, epoch: 17, metrics: accuracy: 0.353 ; iou: 0.133 ; dice: 0.177 ; sensitivity: 0.205 ; auroc: 0.799 239 | Split: val, epoch: 17, metrics: accuracy: 0.352 ; iou: 0.132 ; dice: 0.175 ; sensitivity: 0.204 ; auroc: 0.796 240 | Finish iter: 0, loss 0.36665427684783936 241 | Finish iter: 20, loss 0.3940759599208832 242 | Finish iter: 40, loss 0.39053159952163696 243 | Finish iter: 60, loss 0.4100719392299652 244 | Finish iter: 80, loss 0.4159899055957794 245 | Finish iter: 100, loss 0.3502296507358551 246 | Finish iter: 120, loss 0.3434220850467682 247 | Finish iter: 140, loss 0.3621370196342468 248 | Finish iter: 160, loss 0.40795382857322693 249 | Finish iter: 180, loss 0.34858500957489014 250 | Training epoch: 18, loss: 0.34872114658355713, time elapsed: 224.5522906780243, 251 | Split: train, epoch: 18, metrics: accuracy: 0.340 ; iou: 0.138 ; dice: 0.185 ; sensitivity: 0.233 ; auroc: 0.782 252 | Split: val, epoch: 18, metrics: accuracy: 0.322 ; iou: 0.132 ; dice: 0.176 ; sensitivity: 0.224 ; auroc: 0.775 253 | Finish iter: 0, loss 0.3808671534061432 254 | Finish iter: 20, loss 0.4460797607898712 255 | Finish iter: 40, loss 0.3672599494457245 256 | Finish iter: 60, loss 0.4194376468658447 257 | Finish iter: 80, loss 0.3973221480846405 258 | Finish iter: 100, loss 0.30685436725616455 259 | Finish iter: 120, loss 0.3381350040435791 260 | Finish iter: 140, loss 0.328183114528656 261 | Finish iter: 160, loss 0.2665712833404541 262 | Finish iter: 180, loss 0.34319230914115906 263 | Training epoch: 19, loss: 0.3452523648738861, time elapsed: 224.63177466392517, 264 | Split: train, epoch: 19, metrics: accuracy: 0.281 ; iou: 0.109 ; dice: 0.152 ; sensitivity: 0.204 ; auroc: 0.747 265 | Split: val, epoch: 19, metrics: accuracy: 0.267 ; iou: 0.105 ; dice: 0.145 ; sensitivity: 0.193 ; auroc: 0.714 266 | Finish iter: 0, loss 0.3419574499130249 267 | Finish iter: 20, loss 0.3091698884963989 268 | Finish iter: 40, loss 0.3453525900840759 269 | Finish iter: 60, loss 0.3618669807910919 270 | Finish iter: 80, loss 0.3083277642726898 271 | Finish iter: 100, loss 
0.35186922550201416 272 | Finish iter: 120, loss 0.3879014551639557 273 | Finish iter: 140, loss 0.33003848791122437 274 | Finish iter: 160, loss 0.3634563982486725 275 | Finish iter: 180, loss 0.388264536857605 276 | Training epoch: 20, loss: 0.3991178274154663, time elapsed: 224.45906805992126, 277 | Split: train, epoch: 20, metrics: accuracy: 0.341 ; iou: 0.123 ; dice: 0.169 ; sensitivity: 0.213 ; auroc: 0.747 278 | Split: val, epoch: 20, metrics: accuracy: 0.350 ; iou: 0.124 ; dice: 0.170 ; sensitivity: 0.214 ; auroc: 0.735 279 | /HPS/Navami/work/code/nnti/R2U-Net/metrics.py:76: RuntimeWarning: More than 20 figures have been opened. Figures created through the pyplot interface (`matplotlib.pyplot.figure`) are retained until explicitly closed and may consume too much memory. (To control this warning, see the rcParam `figure.max_open_warning`). 280 | fig = plt.figure(figsize=(13, 5)) 281 | Finish iter: 0, loss 0.37978535890579224 282 | Finish iter: 20, loss 0.3246116638183594 283 | Finish iter: 40, loss 0.33002081513404846 284 | Finish iter: 60, loss 0.35457926988601685 285 | Finish iter: 80, loss 0.36316782236099243 286 | Finish iter: 100, loss 0.3146795928478241 287 | Finish iter: 120, loss 0.298304945230484 288 | Finish iter: 140, loss 0.38142672181129456 289 | Finish iter: 160, loss 0.28994330763816833 290 | Finish iter: 180, loss 0.38127878308296204 291 | Training epoch: 21, loss: 0.285248726606369, time elapsed: 222.40826106071472, 292 | Split: train, epoch: 21, metrics: accuracy: 0.190 ; iou: 0.082 ; dice: 0.124 ; sensitivity: 0.153 ; auroc: 0.655 293 | Split: val, epoch: 21, metrics: accuracy: 0.197 ; iou: 0.083 ; dice: 0.124 ; sensitivity: 0.153 ; auroc: 0.643 294 | Finish iter: 0, loss 0.35551750659942627 295 | Finish iter: 20, loss 0.39345964789390564 296 | Finish iter: 40, loss 0.41399240493774414 297 | Finish iter: 60, loss 0.28399330377578735 298 | Finish iter: 80, loss 0.31827837228775024 299 | Finish iter: 100, loss 0.32892072200775146 300 | Finish iter: 120, loss 0.26426270604133606 301 | Finish iter: 140, loss 0.3769584894180298 302 | Finish iter: 160, loss 0.3680155873298645 303 | Finish iter: 180, loss 0.3657911717891693 304 | Training epoch: 22, loss: 0.4583113491535187, time elapsed: 224.56580138206482, 305 | Split: train, epoch: 22, metrics: accuracy: 0.289 ; iou: 0.106 ; dice: 0.156 ; sensitivity: 0.181 ; auroc: 0.737 306 | Split: val, epoch: 22, metrics: accuracy: 0.292 ; iou: 0.107 ; dice: 0.156 ; sensitivity: 0.184 ; auroc: 0.732 307 | Finish iter: 0, loss 0.3023519217967987 308 | Finish iter: 20, loss 0.2918838858604431 309 | Finish iter: 40, loss 0.31473496556282043 310 | Finish iter: 60, loss 0.3169395625591278 311 | Finish iter: 80, loss 0.3990042507648468 312 | Finish iter: 100, loss 0.2976300120353699 313 | Finish iter: 120, loss 0.3199124336242676 314 | Finish iter: 140, loss 0.2753380835056305 315 | Finish iter: 160, loss 0.313218891620636 316 | Finish iter: 180, loss 0.3559046983718872 317 | Training epoch: 23, loss: 0.2921752333641052, time elapsed: 224.6181263923645, 318 | Split: train, epoch: 23, metrics: accuracy: 0.209 ; iou: 0.095 ; dice: 0.131 ; sensitivity: 0.188 ; auroc: 0.715 319 | Split: val, epoch: 23, metrics: accuracy: 0.206 ; iou: 0.092 ; dice: 0.127 ; sensitivity: 0.186 ; auroc: 0.717 320 | Finish iter: 0, loss 0.406324177980423 321 | Finish iter: 20, loss 0.28551095724105835 322 | Finish iter: 40, loss 0.2799364924430847 323 | Finish iter: 60, loss 0.3632131814956665 324 | Finish iter: 80, loss 0.4061095714569092 325 | Finish iter: 
100, loss 0.3283922076225281 326 | Finish iter: 120, loss 0.3069410026073456 327 | Finish iter: 140, loss 0.3122752010822296 328 | Finish iter: 160, loss 0.412832111120224 329 | Finish iter: 180, loss 0.3134084641933441 330 | Training epoch: 24, loss: 0.2899927496910095, time elapsed: 224.51668286323547, 331 | Split: train, epoch: 24, metrics: accuracy: 0.260 ; iou: 0.088 ; dice: 0.132 ; sensitivity: 0.147 ; auroc: 0.722 332 | Split: val, epoch: 24, metrics: accuracy: 0.239 ; iou: 0.081 ; dice: 0.122 ; sensitivity: 0.135 ; auroc: 0.711 333 | Finish iter: 0, loss 0.3531290292739868 334 | Finish iter: 20, loss 0.3644753694534302 335 | Finish iter: 40, loss 0.2889963984489441 336 | Finish iter: 60, loss 0.24203309416770935 337 | Finish iter: 80, loss 0.2590658664703369 338 | Finish iter: 100, loss 0.2770468294620514 339 | Finish iter: 120, loss 0.2934582233428955 340 | Finish iter: 140, loss 0.3050212562084198 341 | Finish iter: 160, loss 0.39288514852523804 342 | Finish iter: 180, loss 0.2981954514980316 343 | Training epoch: 25, loss: 0.3111414909362793, time elapsed: 224.64652729034424, 344 | Split: train, epoch: 25, metrics: accuracy: 0.179 ; iou: 0.067 ; dice: 0.098 ; sensitivity: 0.123 ; auroc: 0.654 345 | Split: val, epoch: 25, metrics: accuracy: 0.169 ; iou: 0.064 ; dice: 0.094 ; sensitivity: 0.114 ; auroc: 0.629 346 | Finish iter: 0, loss 0.27898475527763367 347 | Finish iter: 20, loss 0.32310914993286133 348 | Finish iter: 40, loss 0.21607059240341187 349 | Finish iter: 60, loss 0.29836222529411316 350 | Finish iter: 80, loss 0.35458576679229736 351 | Finish iter: 100, loss 0.2929541766643524 352 | Finish iter: 120, loss 0.286101371049881 353 | Finish iter: 140, loss 0.30629226565361023 354 | Finish iter: 160, loss 0.26518699526786804 355 | Finish iter: 180, loss 0.45952025055885315 356 | Training epoch: 26, loss: 0.32927560806274414, time elapsed: 225.18393516540527, 357 | Split: train, epoch: 26, metrics: accuracy: 0.166 ; iou: 0.070 ; dice: 0.107 ; sensitivity: 0.152 ; auroc: 0.675 358 | Split: val, epoch: 26, metrics: accuracy: 0.162 ; iou: 0.069 ; dice: 0.105 ; sensitivity: 0.140 ; auroc: 0.637 359 | Finish iter: 0, loss 0.31298893690109253 360 | Finish iter: 20, loss 0.2934369742870331 361 | Finish iter: 40, loss 0.31999602913856506 362 | Finish iter: 60, loss 0.3501349091529846 363 | Finish iter: 80, loss 0.2627689838409424 364 | Finish iter: 100, loss 0.31145307421684265 365 | Finish iter: 120, loss 0.31612956523895264 366 | Finish iter: 140, loss 0.2735452950000763 367 | Finish iter: 160, loss 0.22983306646347046 368 | Finish iter: 180, loss 0.2922663390636444 369 | Training epoch: 27, loss: 0.37907907366752625, time elapsed: 224.99953508377075, 370 | Split: train, epoch: 27, metrics: accuracy: 0.209 ; iou: 0.089 ; dice: 0.125 ; sensitivity: 0.170 ; auroc: 0.718 371 | Split: val, epoch: 27, metrics: accuracy: 0.197 ; iou: 0.086 ; dice: 0.120 ; sensitivity: 0.165 ; auroc: 0.696 372 | Finish iter: 0, loss 0.274125874042511 373 | Finish iter: 20, loss 0.3254241645336151 374 | Finish iter: 40, loss 0.27571600675582886 375 | Finish iter: 60, loss 0.3376780152320862 376 | Finish iter: 80, loss 0.2994401156902313 377 | Finish iter: 100, loss 0.279597669839859 378 | Finish iter: 120, loss 0.29254019260406494 379 | Finish iter: 140, loss 0.3131135404109955 380 | Finish iter: 160, loss 0.29822972416877747 381 | Finish iter: 180, loss 0.26995110511779785 382 | Training epoch: 28, loss: 0.4325577914714813, time elapsed: 225.03701901435852, 383 | Split: train, epoch: 28, metrics: 
accuracy: 0.147 ; iou: 0.052 ; dice: 0.082 ; sensitivity: 0.135 ; auroc: 0.645 384 | Split: val, epoch: 28, metrics: accuracy: 0.143 ; iou: 0.051 ; dice: 0.080 ; sensitivity: 0.126 ; auroc: 0.623 385 | Finish iter: 0, loss 0.2827097177505493 386 | Finish iter: 20, loss 0.3212285041809082 387 | Finish iter: 40, loss 0.305167555809021 388 | Finish iter: 60, loss 0.3265172243118286 389 | Finish iter: 80, loss 0.3406413197517395 390 | Finish iter: 100, loss 0.29336410760879517 391 | Finish iter: 120, loss 0.2871347963809967 392 | Finish iter: 140, loss 0.3297807276248932 393 | Finish iter: 160, loss 0.3518953025341034 394 | Finish iter: 180, loss 0.29808467626571655 395 | Training epoch: 29, loss: 0.2874864339828491, time elapsed: 225.00453519821167, 396 | Split: train, epoch: 29, metrics: accuracy: 0.177 ; iou: 0.077 ; dice: 0.109 ; sensitivity: 0.161 ; auroc: 0.740 397 | Split: val, epoch: 29, metrics: accuracy: 0.175 ; iou: 0.077 ; dice: 0.108 ; sensitivity: 0.158 ; auroc: 0.728 398 | Finish iter: 0, loss 0.2453809231519699 399 | Finish iter: 20, loss 0.27719762921333313 400 | Finish iter: 40, loss 0.3016561269760132 401 | Finish iter: 60, loss 0.3434916138648987 402 | Finish iter: 80, loss 0.22985142469406128 403 | Finish iter: 100, loss 0.26230889558792114 404 | Finish iter: 120, loss 0.2923572063446045 405 | Finish iter: 140, loss 0.2886323928833008 406 | Finish iter: 160, loss 0.35096442699432373 407 | Finish iter: 180, loss 0.28475844860076904 408 | Training epoch: 30, loss: 0.2563552260398865, time elapsed: 224.95156741142273, 409 | Split: train, epoch: 30, metrics: accuracy: 0.205 ; iou: 0.072 ; dice: 0.109 ; sensitivity: 0.170 ; auroc: 0.730 410 | Split: val, epoch: 30, metrics: accuracy: 0.200 ; iou: 0.073 ; dice: 0.108 ; sensitivity: 0.155 ; auroc: 0.722 411 | Finish iter: 0, loss 0.2743076980113983 412 | Finish iter: 20, loss 0.27766740322113037 413 | Finish iter: 40, loss 0.3013075888156891 414 | Finish iter: 60, loss 0.28021439909935 415 | Finish iter: 80, loss 0.38058996200561523 416 | Finish iter: 100, loss 0.2762143611907959 417 | Finish iter: 120, loss 0.26985424757003784 418 | Finish iter: 140, loss 0.2425134778022766 419 | Finish iter: 160, loss 0.2543611228466034 420 | Finish iter: 180, loss 0.2683083713054657 421 | Training epoch: 31, loss: 0.3686622679233551, time elapsed: 222.39211130142212, 422 | Split: train, epoch: 31, metrics: accuracy: 0.161 ; iou: 0.051 ; dice: 0.085 ; sensitivity: 0.143 ; auroc: 0.742 423 | Split: val, epoch: 31, metrics: accuracy: 0.153 ; iou: 0.050 ; dice: 0.083 ; sensitivity: 0.136 ; auroc: 0.736 424 | Finish iter: 0, loss 0.2516822814941406 425 | Finish iter: 20, loss 0.3445367217063904 426 | Finish iter: 40, loss 0.3368116021156311 427 | Finish iter: 60, loss 0.33008840680122375 428 | Finish iter: 80, loss 0.2853473722934723 429 | Finish iter: 100, loss 0.33237770199775696 430 | Finish iter: 120, loss 0.2574792802333832 431 | Finish iter: 140, loss 0.2718927562236786 432 | Finish iter: 160, loss 0.23351307213306427 433 | Finish iter: 180, loss 0.2738317549228668 434 | Training epoch: 32, loss: 0.260929137468338, time elapsed: 224.53262901306152, 435 | Split: train, epoch: 32, metrics: accuracy: 0.083 ; iou: 0.024 ; dice: 0.041 ; sensitivity: 0.105 ; auroc: 0.687 436 | Split: val, epoch: 32, metrics: accuracy: 0.078 ; iou: 0.022 ; dice: 0.038 ; sensitivity: 0.095 ; auroc: 0.676 437 | Finish iter: 0, loss 0.26160529255867004 438 | Finish iter: 20, loss 0.2686249911785126 439 | Finish iter: 40, loss 0.20337703824043274 440 | Finish iter: 60, 
loss 0.3191201388835907 441 | Finish iter: 80, loss 0.3112674355506897 442 | Finish iter: 100, loss 0.26892751455307007 443 | Finish iter: 120, loss 0.24695800244808197 444 | Finish iter: 140, loss 0.23338539898395538 445 | Finish iter: 160, loss 0.2973196804523468 446 | Finish iter: 180, loss 0.2834741175174713 447 | Training epoch: 33, loss: 0.2615412473678589, time elapsed: 224.5028417110443, 448 | Split: train, epoch: 33, metrics: accuracy: 0.168 ; iou: 0.065 ; dice: 0.097 ; sensitivity: 0.143 ; auroc: 0.663 449 | Split: val, epoch: 33, metrics: accuracy: 0.163 ; iou: 0.063 ; dice: 0.094 ; sensitivity: 0.137 ; auroc: 0.649 450 | Finish iter: 0, loss 0.29246729612350464 451 | Finish iter: 20, loss 0.25836461782455444 452 | Finish iter: 40, loss 0.25087815523147583 453 | Finish iter: 60, loss 0.3124537765979767 454 | Finish iter: 80, loss 0.29675087332725525 455 | Finish iter: 100, loss 0.21701502799987793 456 | Finish iter: 120, loss 0.22084668278694153 457 | Finish iter: 140, loss 0.24752181768417358 458 | Finish iter: 160, loss 0.22644315659999847 459 | Finish iter: 180, loss 0.29453057050704956 460 | Training epoch: 34, loss: 0.22412998974323273, time elapsed: 224.57934999465942, 461 | Split: train, epoch: 34, metrics: accuracy: 0.140 ; iou: 0.067 ; dice: 0.098 ; sensitivity: 0.146 ; auroc: 0.669 462 | Split: val, epoch: 34, metrics: accuracy: 0.134 ; iou: 0.064 ; dice: 0.094 ; sensitivity: 0.137 ; auroc: 0.652 463 | Finish iter: 0, loss 0.2676863670349121 464 | Finish iter: 20, loss 0.2569328546524048 465 | Finish iter: 40, loss 0.24016353487968445 466 | Finish iter: 60, loss 0.2988206446170807 467 | Finish iter: 80, loss 0.2553076446056366 468 | Finish iter: 100, loss 0.26859259605407715 469 | Finish iter: 120, loss 0.2834089994430542 470 | Finish iter: 140, loss 0.2752552628517151 471 | Finish iter: 160, loss 0.2848752737045288 472 | Finish iter: 180, loss 0.2506951689720154 473 | Training epoch: 35, loss: 0.2483908236026764, time elapsed: 224.5297200679779, 474 | Split: train, epoch: 35, metrics: accuracy: 0.272 ; iou: 0.074 ; dice: 0.114 ; sensitivity: 0.178 ; auroc: 0.838 475 | Split: val, epoch: 35, metrics: accuracy: 0.256 ; iou: 0.070 ; dice: 0.108 ; sensitivity: 0.162 ; auroc: 0.813 476 | Finish iter: 0, loss 0.31354382634162903 477 | Finish iter: 20, loss 0.24957585334777832 478 | Finish iter: 40, loss 0.27353864908218384 479 | Finish iter: 60, loss 0.2660386264324188 480 | Finish iter: 80, loss 0.23612399399280548 481 | Finish iter: 100, loss 0.21762485802173615 482 | Finish iter: 120, loss 0.26515960693359375 483 | Finish iter: 140, loss 0.20159664750099182 484 | Finish iter: 160, loss 0.26149803400039673 485 | Finish iter: 180, loss 0.24584227800369263 486 | Training epoch: 36, loss: 0.35677823424339294, time elapsed: 224.63166570663452, 487 | Split: train, epoch: 36, metrics: accuracy: 0.122 ; iou: 0.036 ; dice: 0.062 ; sensitivity: 0.131 ; auroc: 0.764 488 | Split: val, epoch: 36, metrics: accuracy: 0.116 ; iou: 0.034 ; dice: 0.059 ; sensitivity: 0.124 ; auroc: 0.754 489 | Finish iter: 0, loss 0.2086993157863617 490 | Finish iter: 20, loss 0.25042128562927246 491 | Finish iter: 40, loss 0.27439257502555847 492 | Finish iter: 60, loss 0.27353277802467346 493 | Finish iter: 80, loss 0.2510380148887634 494 | Finish iter: 100, loss 0.24373498558998108 495 | Finish iter: 120, loss 0.24634885787963867 496 | Finish iter: 140, loss 0.22443893551826477 497 | Finish iter: 160, loss 0.27249130606651306 498 | Finish iter: 180, loss 0.2668968439102173 499 | Training epoch: 37, 
loss: 0.2307763248682022, time elapsed: 224.54947638511658, 500 | Split: train, epoch: 37, metrics: accuracy: 0.109 ; iou: 0.026 ; dice: 0.042 ; sensitivity: 0.111 ; auroc: 0.730 501 | Split: val, epoch: 37, metrics: accuracy: 0.104 ; iou: 0.024 ; dice: 0.039 ; sensitivity: 0.101 ; auroc: 0.719 502 | Finish iter: 0, loss 0.2041214555501938 503 | Finish iter: 20, loss 0.26884686946868896 504 | Finish iter: 40, loss 0.24988485872745514 505 | Finish iter: 60, loss 0.2689952552318573 506 | Finish iter: 80, loss 0.26950401067733765 507 | Finish iter: 100, loss 0.23611576855182648 508 | Finish iter: 120, loss 0.21867437660694122 509 | Finish iter: 140, loss 0.241507425904274 510 | Finish iter: 160, loss 0.24259521067142487 511 | Finish iter: 180, loss 0.2915859818458557 512 | Training epoch: 38, loss: 0.24499215185642242, time elapsed: 224.67239665985107, 513 | Split: train, epoch: 38, metrics: accuracy: 0.118 ; iou: 0.032 ; dice: 0.056 ; sensitivity: 0.115 ; auroc: 0.741 514 | Split: val, epoch: 38, metrics: accuracy: 0.115 ; iou: 0.032 ; dice: 0.055 ; sensitivity: 0.107 ; auroc: 0.734 515 | Finish iter: 0, loss 0.20849041640758514 516 | Finish iter: 20, loss 0.2347530573606491 517 | Finish iter: 40, loss 0.24605247378349304 518 | Finish iter: 60, loss 0.26445212960243225 519 | Finish iter: 80, loss 0.26354604959487915 520 | Finish iter: 100, loss 0.2515641152858734 521 | Finish iter: 120, loss 0.2469959855079651 522 | Finish iter: 140, loss 0.2658264636993408 523 | Finish iter: 160, loss 0.2823251783847809 524 | Finish iter: 180, loss 0.22824491560459137 525 | Training epoch: 39, loss: 0.2730962038040161, time elapsed: 224.63642048835754, 526 | Split: train, epoch: 39, metrics: accuracy: 0.224 ; iou: 0.063 ; dice: 0.099 ; sensitivity: 0.152 ; auroc: 0.780 527 | Split: val, epoch: 39, metrics: accuracy: 0.219 ; iou: 0.061 ; dice: 0.095 ; sensitivity: 0.146 ; auroc: 0.773 528 | Finish iter: 0, loss 0.2262418270111084 529 | Finish iter: 20, loss 0.2660892903804779 530 | Finish iter: 40, loss 0.22499984502792358 531 | Finish iter: 60, loss 0.28787514567375183 532 | Finish iter: 80, loss 0.2684687077999115 533 | Finish iter: 100, loss 0.27320003509521484 534 | Finish iter: 120, loss 0.26300981640815735 535 | Finish iter: 140, loss 0.2274027317762375 536 | Finish iter: 160, loss 0.22937822341918945 537 | Finish iter: 180, loss 0.2617340385913849 538 | Training epoch: 40, loss: 0.2003132849931717, time elapsed: 225.11757969856262, 539 | Split: train, epoch: 40, metrics: accuracy: 0.096 ; iou: 0.019 ; dice: 0.033 ; sensitivity: 0.074 ; auroc: 0.713 540 | Split: val, epoch: 40, metrics: accuracy: 0.094 ; iou: 0.018 ; dice: 0.032 ; sensitivity: 0.070 ; auroc: 0.698 541 | Finish iter: 0, loss 0.276131808757782 542 | Finish iter: 20, loss 0.21457217633724213 543 | Finish iter: 40, loss 0.2082192450761795 544 | Finish iter: 60, loss 0.24439769983291626 545 | Finish iter: 80, loss 0.2207152545452118 546 | Finish iter: 100, loss 0.22443707287311554 547 | Finish iter: 120, loss 0.2518899738788605 548 | Finish iter: 140, loss 0.23937612771987915 549 | Finish iter: 160, loss 0.2498416304588318 550 | Finish iter: 180, loss 0.21088139712810516 551 | Training epoch: 41, loss: 0.2440328598022461, time elapsed: 222.85695791244507, 552 | Split: train, epoch: 41, metrics: accuracy: 0.278 ; iou: 0.107 ; dice: 0.161 ; sensitivity: 0.203 ; auroc: 0.871 553 | Split: val, epoch: 41, metrics: accuracy: 0.272 ; iou: 0.099 ; dice: 0.150 ; sensitivity: 0.193 ; auroc: 0.861 554 | Finish iter: 0, loss 0.24968548119068146 555 | 
Finish iter: 20, loss 0.3039020597934723 556 | Finish iter: 40, loss 0.22005993127822876 557 | Finish iter: 60, loss 0.2158166915178299 558 | Finish iter: 80, loss 0.21177855134010315 559 | Finish iter: 100, loss 0.2700449824333191 560 | Finish iter: 120, loss 0.23322193324565887 561 | Finish iter: 140, loss 0.26149266958236694 562 | Finish iter: 160, loss 0.1921999156475067 563 | Finish iter: 180, loss 0.258588045835495 564 | Training epoch: 42, loss: 0.23209309577941895, time elapsed: 224.65499424934387, 565 | Split: train, epoch: 42, metrics: accuracy: 0.189 ; iou: 0.051 ; dice: 0.084 ; sensitivity: 0.134 ; auroc: 0.817 566 | Split: val, epoch: 42, metrics: accuracy: 0.200 ; iou: 0.048 ; dice: 0.079 ; sensitivity: 0.130 ; auroc: 0.818 567 | Finish iter: 0, loss 0.23476412892341614 568 | Finish iter: 20, loss 0.16938193142414093 569 | Finish iter: 40, loss 0.20973698794841766 570 | Finish iter: 60, loss 0.21492667496204376 571 | Finish iter: 80, loss 0.22675982117652893 572 | Finish iter: 100, loss 0.19863682985305786 573 | Finish iter: 120, loss 0.26829057931900024 574 | Finish iter: 140, loss 0.21482831239700317 575 | Finish iter: 160, loss 0.27618080377578735 576 | Finish iter: 180, loss 0.18258145451545715 577 | Training epoch: 43, loss: 0.255708247423172, time elapsed: 224.70654726028442, 578 | Split: train, epoch: 43, metrics: accuracy: 0.123 ; iou: 0.034 ; dice: 0.056 ; sensitivity: 0.112 ; auroc: 0.731 579 | Split: val, epoch: 43, metrics: accuracy: 0.121 ; iou: 0.033 ; dice: 0.054 ; sensitivity: 0.106 ; auroc: 0.722 580 | Finish iter: 0, loss 0.22796718776226044 581 | Finish iter: 20, loss 0.174246683716774 582 | Finish iter: 40, loss 0.2072448581457138 583 | Finish iter: 60, loss 0.24479351937770844 584 | Finish iter: 80, loss 0.22781144082546234 585 | Finish iter: 100, loss 0.19067372381687164 586 | Finish iter: 120, loss 0.20479290187358856 587 | Finish iter: 140, loss 0.2452375292778015 588 | Finish iter: 160, loss 0.34680038690567017 589 | Finish iter: 180, loss 0.25671693682670593 590 | Training epoch: 44, loss: 0.24830737709999084, time elapsed: 224.56873083114624, 591 | Split: train, epoch: 44, metrics: accuracy: 0.111 ; iou: 0.027 ; dice: 0.043 ; sensitivity: 0.104 ; auroc: 0.672 592 | Split: val, epoch: 44, metrics: accuracy: 0.107 ; iou: 0.026 ; dice: 0.041 ; sensitivity: 0.097 ; auroc: 0.665 593 | Finish iter: 0, loss 0.22861337661743164 594 | Finish iter: 20, loss 0.19176740944385529 595 | Finish iter: 40, loss 0.2522750496864319 596 | Finish iter: 60, loss 0.19462254643440247 597 | Finish iter: 80, loss 0.2658678889274597 598 | Finish iter: 100, loss 0.18857507407665253 599 | Finish iter: 120, loss 0.19897720217704773 600 | Finish iter: 140, loss 0.2385285496711731 601 | Finish iter: 160, loss 0.2237362414598465 602 | Finish iter: 180, loss 0.24283240735530853 603 | Training epoch: 45, loss: 0.22083404660224915, time elapsed: 224.747811794281, 604 | Split: train, epoch: 45, metrics: accuracy: 0.224 ; iou: 0.088 ; dice: 0.139 ; sensitivity: 0.169 ; auroc: 0.826 605 | Split: val, epoch: 45, metrics: accuracy: 0.219 ; iou: 0.087 ; dice: 0.137 ; sensitivity: 0.164 ; auroc: 0.823 606 | Finish iter: 0, loss 0.18895769119262695 607 | Finish iter: 20, loss 0.22417007386684418 608 | Finish iter: 40, loss 0.27337685227394104 609 | Finish iter: 60, loss 0.20547997951507568 610 | Finish iter: 80, loss 0.19572879374027252 611 | Finish iter: 100, loss 0.2076464295387268 612 | Finish iter: 120, loss 0.2055884152650833 613 | Finish iter: 140, loss 0.23983660340309143 614 | Finish 
iter: 160, loss 0.2617143392562866 615 | Finish iter: 180, loss 0.2056431621313095 616 | Training epoch: 46, loss: 0.22824415564537048, time elapsed: 224.8711371421814, 617 | Split: train, epoch: 46, metrics: accuracy: 0.126 ; iou: 0.033 ; dice: 0.056 ; sensitivity: 0.115 ; auroc: 0.782 618 | Split: val, epoch: 46, metrics: accuracy: 0.121 ; iou: 0.031 ; dice: 0.052 ; sensitivity: 0.103 ; auroc: 0.770 619 | Finish iter: 0, loss 0.20441094040870667 620 | Finish iter: 20, loss 0.1864471286535263 621 | Finish iter: 40, loss 0.25640669465065 622 | Finish iter: 60, loss 0.21895313262939453 623 | Finish iter: 80, loss 0.19303825497627258 624 | Finish iter: 100, loss 0.22693473100662231 625 | Finish iter: 120, loss 0.20328420400619507 626 | Finish iter: 140, loss 0.1979500651359558 627 | Finish iter: 160, loss 0.2012885957956314 628 | Finish iter: 180, loss 0.22707578539848328 629 | Training epoch: 47, loss: 0.2392863631248474, time elapsed: 224.65988302230835, 630 | Split: train, epoch: 47, metrics: accuracy: 0.199 ; iou: 0.057 ; dice: 0.097 ; sensitivity: 0.162 ; auroc: 0.817 631 | Split: val, epoch: 47, metrics: accuracy: 0.191 ; iou: 0.054 ; dice: 0.093 ; sensitivity: 0.152 ; auroc: 0.796 632 | Finish iter: 0, loss 0.19313274323940277 633 | Finish iter: 20, loss 0.20693014562129974 634 | Finish iter: 40, loss 0.2045283317565918 635 | Finish iter: 60, loss 0.24371977150440216 636 | Finish iter: 80, loss 0.24768412113189697 637 | Finish iter: 100, loss 0.1750010848045349 638 | Finish iter: 120, loss 0.14992469549179077 639 | Finish iter: 140, loss 0.21040092408657074 640 | Finish iter: 160, loss 0.19208942353725433 641 | Finish iter: 180, loss 0.2414512038230896 642 | Training epoch: 48, loss: 0.18421702086925507, time elapsed: 224.84717082977295, 643 | Split: train, epoch: 48, metrics: accuracy: 0.248 ; iou: 0.061 ; dice: 0.099 ; sensitivity: 0.148 ; auroc: 0.831 644 | Split: val, epoch: 48, metrics: accuracy: 0.241 ; iou: 0.058 ; dice: 0.095 ; sensitivity: 0.134 ; auroc: 0.828 645 | Finish iter: 0, loss 0.20802651345729828 646 | Finish iter: 20, loss 0.18077684938907623 647 | Finish iter: 40, loss 0.2073429375886917 648 | Finish iter: 60, loss 0.19641906023025513 649 | Finish iter: 80, loss 0.1908656656742096 650 | Finish iter: 100, loss 0.18482336401939392 651 | Finish iter: 120, loss 0.20296573638916016 652 | Finish iter: 140, loss 0.1988840252161026 653 | Finish iter: 160, loss 0.1744031459093094 654 | Finish iter: 180, loss 0.20435886085033417 655 | Training epoch: 49, loss: 0.19845493137836456, time elapsed: 224.81350111961365, 656 | Split: train, epoch: 49, metrics: accuracy: 0.137 ; iou: 0.034 ; dice: 0.060 ; sensitivity: 0.126 ; auroc: 0.765 657 | Split: val, epoch: 49, metrics: accuracy: 0.130 ; iou: 0.032 ; dice: 0.057 ; sensitivity: 0.115 ; auroc: 0.753 658 | Finish iter: 0, loss 0.18818572163581848 659 | Finish iter: 20, loss 0.20390978455543518 660 | Finish iter: 40, loss 0.22899888455867767 661 | Finish iter: 60, loss 0.19940170645713806 662 | Finish iter: 80, loss 0.17929142713546753 663 | Finish iter: 100, loss 0.265813946723938 664 | Finish iter: 120, loss 0.23961570858955383 665 | Finish iter: 140, loss 0.20194853842258453 666 | Finish iter: 160, loss 0.16585294902324677 667 | Finish iter: 180, loss 0.23924243450164795 668 | Training epoch: 50, loss: 0.24049393832683563, time elapsed: 224.7023241519928, 669 | Split: train, epoch: 50, metrics: accuracy: 0.119 ; iou: 0.029 ; dice: 0.052 ; sensitivity: 0.112 ; auroc: 0.721 670 | Split: val, epoch: 50, metrics: accuracy: 0.113 ; 
iou: 0.027 ; dice: 0.048 ; sensitivity: 0.102 ; auroc: 0.721 671 | Finish iter: 0, loss 0.20156119763851166 672 | Finish iter: 20, loss 0.20586538314819336 673 | Finish iter: 40, loss 0.2080424726009369 674 | Finish iter: 60, loss 0.25534194707870483 675 | Finish iter: 80, loss 0.1869318187236786 676 | Finish iter: 100, loss 0.2203545868396759 677 | Finish iter: 120, loss 0.20290730893611908 678 | Finish iter: 140, loss 0.21353371441364288 679 | Finish iter: 160, loss 0.19977974891662598 680 | Finish iter: 180, loss 0.16397738456726074 681 | Training epoch: 51, loss: 0.20044006407260895, time elapsed: 222.37581372261047, 682 | Split: train, epoch: 51, metrics: accuracy: 0.081 ; iou: 0.018 ; dice: 0.033 ; sensitivity: 0.101 ; auroc: 0.733 683 | Split: val, epoch: 51, metrics: accuracy: 0.078 ; iou: 0.017 ; dice: 0.031 ; sensitivity: 0.099 ; auroc: 0.732 684 | Finish iter: 0, loss 0.20715473592281342 685 | Finish iter: 20, loss 0.21102556586265564 686 | Finish iter: 40, loss 0.2078854739665985 687 | Finish iter: 60, loss 0.18801815807819366 688 | Finish iter: 80, loss 0.1606634557247162 689 | Finish iter: 100, loss 0.20689304172992706 690 | Finish iter: 120, loss 0.1755586862564087 691 | Finish iter: 140, loss 0.17115460336208344 692 | Finish iter: 160, loss 0.22802618145942688 693 | Finish iter: 180, loss 0.1878541260957718 694 | Training epoch: 52, loss: 0.2041531503200531, time elapsed: 224.56232213974, 695 | Split: train, epoch: 52, metrics: accuracy: 0.140 ; iou: 0.042 ; dice: 0.074 ; sensitivity: 0.138 ; auroc: 0.763 696 | Split: val, epoch: 52, metrics: accuracy: 0.142 ; iou: 0.044 ; dice: 0.076 ; sensitivity: 0.137 ; auroc: 0.753 697 | Finish iter: 0, loss 0.17199796438217163 698 | Finish iter: 20, loss 0.18570876121520996 699 | Finish iter: 40, loss 0.22601495683193207 700 | Finish iter: 60, loss 0.2064376026391983 701 | Finish iter: 80, loss 0.18670842051506042 702 | Finish iter: 100, loss 0.1703394055366516 703 | Finish iter: 120, loss 0.1971365362405777 704 | Finish iter: 140, loss 0.15915559232234955 705 | Finish iter: 160, loss 0.16433703899383545 706 | Finish iter: 180, loss 0.22027051448822021 707 | Training epoch: 53, loss: 0.15492978692054749, time elapsed: 224.71779918670654, 708 | Split: train, epoch: 53, metrics: accuracy: 0.137 ; iou: 0.028 ; dice: 0.044 ; sensitivity: 0.105 ; auroc: 0.773 709 | Split: val, epoch: 53, metrics: accuracy: 0.140 ; iou: 0.029 ; dice: 0.044 ; sensitivity: 0.100 ; auroc: 0.766 710 | Finish iter: 0, loss 0.16071194410324097 711 | Finish iter: 20, loss 0.16300790011882782 712 | Finish iter: 40, loss 0.15992717444896698 713 | Finish iter: 60, loss 0.15727530419826508 714 | Finish iter: 80, loss 0.13757923245429993 715 | Finish iter: 100, loss 0.22198988497257233 716 | Finish iter: 120, loss 0.19321446120738983 717 | Finish iter: 140, loss 0.17388254404067993 718 | Finish iter: 160, loss 0.19370543956756592 719 | Finish iter: 180, loss 0.20812121033668518 720 | Training epoch: 54, loss: 0.18142195045948029, time elapsed: 224.57340335845947, 721 | Split: train, epoch: 54, metrics: accuracy: 0.209 ; iou: 0.056 ; dice: 0.092 ; sensitivity: 0.119 ; auroc: 0.816 722 | Split: val, epoch: 54, metrics: accuracy: 0.211 ; iou: 0.055 ; dice: 0.089 ; sensitivity: 0.121 ; auroc: 0.814 723 | Finish iter: 0, loss 0.17574791610240936 724 | Finish iter: 20, loss 0.1758182942867279 725 | Finish iter: 40, loss 0.19372285902500153 726 | Finish iter: 60, loss 0.16941040754318237 727 | Finish iter: 80, loss 0.1921256184577942 728 | Finish iter: 100, loss 
0.18103018403053284 729 | Finish iter: 120, loss 0.16482532024383545 730 | Finish iter: 140, loss 0.13739457726478577 731 | Finish iter: 160, loss 0.1869034767150879 732 | Finish iter: 180, loss 0.19975391030311584 733 | Training epoch: 55, loss: 0.2116287797689438, time elapsed: 224.63432049751282, 734 | Split: train, epoch: 55, metrics: accuracy: 0.129 ; iou: 0.024 ; dice: 0.041 ; sensitivity: 0.112 ; auroc: 0.783 735 | Split: val, epoch: 55, metrics: accuracy: 0.130 ; iou: 0.024 ; dice: 0.040 ; sensitivity: 0.114 ; auroc: 0.775 736 | Finish iter: 0, loss 0.16425248980522156 737 | Finish iter: 20, loss 0.1595265120267868 738 | Finish iter: 40, loss 0.20804363489151 739 | Finish iter: 60, loss 0.17259660363197327 740 | Finish iter: 80, loss 0.15504199266433716 741 | Finish iter: 100, loss 0.1854289174079895 742 | Finish iter: 120, loss 0.1598692387342453 743 | Finish iter: 140, loss 0.188557431101799 744 | Finish iter: 160, loss 0.16851204633712769 745 | Finish iter: 180, loss 0.2127394825220108 746 | Training epoch: 56, loss: 0.20321017503738403, time elapsed: 224.61475729942322, 747 | Split: train, epoch: 56, metrics: accuracy: 0.129 ; iou: 0.031 ; dice: 0.053 ; sensitivity: 0.125 ; auroc: 0.760 748 | Split: val, epoch: 56, metrics: accuracy: 0.126 ; iou: 0.028 ; dice: 0.050 ; sensitivity: 0.120 ; auroc: 0.749 749 | Finish iter: 0, loss 0.15569813549518585 750 | Finish iter: 20, loss 0.20177629590034485 751 | Finish iter: 40, loss 0.17251890897750854 752 | Finish iter: 60, loss 0.18890252709388733 753 | Finish iter: 80, loss 0.16125354170799255 754 | Finish iter: 100, loss 0.17161643505096436 755 | Finish iter: 120, loss 0.15792596340179443 756 | Finish iter: 140, loss 0.16054876148700714 757 | Finish iter: 160, loss 0.17157451808452606 758 | Finish iter: 180, loss 0.20921923220157623 759 | Training epoch: 57, loss: 0.15304303169250488, time elapsed: 224.5606620311737, 760 | Split: train, epoch: 57, metrics: accuracy: 0.085 ; iou: 0.018 ; dice: 0.031 ; sensitivity: 0.099 ; auroc: 0.766 761 | Split: val, epoch: 57, metrics: accuracy: 0.083 ; iou: 0.018 ; dice: 0.031 ; sensitivity: 0.099 ; auroc: 0.745 762 | Finish iter: 0, loss 0.15643085539340973 763 | Finish iter: 20, loss 0.16552896797657013 764 | Finish iter: 40, loss 0.1680721640586853 765 | Finish iter: 60, loss 0.16013319790363312 766 | Finish iter: 80, loss 0.13167522847652435 767 | Finish iter: 100, loss 0.15066173672676086 768 | Finish iter: 120, loss 0.16933146119117737 769 | Finish iter: 140, loss 0.21060186624526978 770 | Finish iter: 160, loss 0.24676264822483063 771 | Finish iter: 180, loss 0.27221715450286865 772 | Training epoch: 58, loss: 0.2811741530895233, time elapsed: 224.4683792591095, 773 | Split: train, epoch: 58, metrics: accuracy: 0.296 ; iou: 0.074 ; dice: 0.122 ; sensitivity: 0.164 ; auroc: 0.847 774 | Split: val, epoch: 58, metrics: accuracy: 0.304 ; iou: 0.071 ; dice: 0.118 ; sensitivity: 0.161 ; auroc: 0.848 775 | Finish iter: 0, loss 0.20801261067390442 776 | Finish iter: 20, loss 0.24710941314697266 777 | Finish iter: 40, loss 0.2207755744457245 778 | Finish iter: 60, loss 0.24301058053970337 779 | Finish iter: 80, loss 0.17701943218708038 780 | Finish iter: 100, loss 0.228857159614563 781 | Finish iter: 120, loss 0.15228557586669922 782 | Finish iter: 140, loss 0.15933704376220703 783 | Finish iter: 160, loss 0.14372657239437103 784 | Finish iter: 180, loss 0.18796630203723907 785 | Training epoch: 59, loss: 0.16767290234565735, time elapsed: 224.69456934928894, 786 | Split: train, epoch: 59, metrics: 
accuracy: 0.224 ; iou: 0.045 ; dice: 0.068 ; sensitivity: 0.098 ; auroc: 0.764 787 | Split: val, epoch: 59, metrics: accuracy: 0.226 ; iou: 0.044 ; dice: 0.066 ; sensitivity: 0.098 ; auroc: 0.760 788 | Finish iter: 0, loss 0.2117057740688324 789 | Finish iter: 20, loss 0.1887184977531433 790 | Finish iter: 40, loss 0.19980323314666748 791 | Finish iter: 60, loss 0.1562756448984146 792 | Finish iter: 80, loss 0.1728668510913849 793 | Finish iter: 100, loss 0.1731467992067337 794 | Finish iter: 120, loss 0.17320315539836884 795 | Finish iter: 140, loss 0.14773474633693695 796 | Finish iter: 160, loss 0.17663611471652985 797 | Finish iter: 180, loss 0.18231201171875 798 | Training epoch: 60, loss: 0.15470512211322784, time elapsed: 224.72885537147522, 799 | Split: train, epoch: 60, metrics: accuracy: 0.181 ; iou: 0.040 ; dice: 0.070 ; sensitivity: 0.131 ; auroc: 0.822 800 | Split: val, epoch: 60, metrics: accuracy: 0.184 ; iou: 0.039 ; dice: 0.068 ; sensitivity: 0.129 ; auroc: 0.824 801 | Finish iter: 0, loss 0.20250962674617767 802 | Finish iter: 20, loss 0.18445345759391785 803 | Finish iter: 40, loss 0.16857336461544037 804 | Finish iter: 60, loss 0.17456099390983582 805 | Finish iter: 80, loss 0.18322400748729706 806 | Finish iter: 100, loss 0.16350042819976807 807 | Finish iter: 120, loss 0.16045399010181427 808 | Finish iter: 140, loss 0.148673415184021 809 | Finish iter: 160, loss 0.16373884677886963 810 | Finish iter: 180, loss 0.1800476759672165 811 | Training epoch: 61, loss: 0.13524653017520905, time elapsed: 222.93676567077637, 812 | Split: train, epoch: 61, metrics: accuracy: 0.138 ; iou: 0.039 ; dice: 0.065 ; sensitivity: 0.118 ; auroc: 0.773 813 | Split: val, epoch: 61, metrics: accuracy: 0.138 ; iou: 0.039 ; dice: 0.064 ; sensitivity: 0.114 ; auroc: 0.772 814 | Finish iter: 0, loss 0.13618117570877075 815 | Finish iter: 20, loss 0.1619071662425995 816 | Finish iter: 40, loss 0.14374476671218872 817 | Finish iter: 60, loss 0.14664283394813538 818 | Finish iter: 80, loss 0.13551212847232819 819 | Finish iter: 100, loss 0.12011497467756271 820 | Finish iter: 120, loss 0.1501600742340088 821 | Finish iter: 140, loss 0.1488906443119049 822 | Finish iter: 160, loss 0.13774967193603516 823 | Finish iter: 180, loss 0.15357476472854614 824 | Training epoch: 62, loss: 0.14078806340694427, time elapsed: 223.99112057685852, 825 | Split: train, epoch: 62, metrics: accuracy: 0.149 ; iou: 0.023 ; dice: 0.039 ; sensitivity: 0.119 ; auroc: 0.829 826 | Split: val, epoch: 62, metrics: accuracy: 0.152 ; iou: 0.021 ; dice: 0.036 ; sensitivity: 0.113 ; auroc: 0.818 827 | Finish iter: 0, loss 0.16627013683319092 828 | Finish iter: 20, loss 0.14743277430534363 829 | Finish iter: 40, loss 0.15496790409088135 830 | Finish iter: 60, loss 0.12184068560600281 831 | Finish iter: 80, loss 0.1463831663131714 832 | Finish iter: 100, loss 0.14590947329998016 833 | Finish iter: 120, loss 0.12285211682319641 834 | Finish iter: 140, loss 0.16452917456626892 835 | Finish iter: 160, loss 0.13205642998218536 836 | Finish iter: 180, loss 0.147676482796669 837 | Training epoch: 63, loss: 0.1370786875486374, time elapsed: 224.73858428001404, 838 | Split: train, epoch: 63, metrics: accuracy: 0.161 ; iou: 0.028 ; dice: 0.045 ; sensitivity: 0.123 ; auroc: 0.840 839 | Split: val, epoch: 63, metrics: accuracy: 0.166 ; iou: 0.025 ; dice: 0.041 ; sensitivity: 0.116 ; auroc: 0.840 840 | Finish iter: 0, loss 0.16074803471565247 841 | Finish iter: 20, loss 0.1554383635520935 842 | Finish iter: 40, loss 0.1684873104095459 843 | 
Finish iter: 60, loss 0.13460831344127655 844 | Finish iter: 80, loss 0.1305561661720276 845 | Finish iter: 100, loss 0.13526758551597595 846 | Finish iter: 120, loss 0.15422232449054718 847 | Finish iter: 140, loss 0.14360299706459045 848 | Finish iter: 160, loss 0.17748554050922394 849 | Finish iter: 180, loss 0.2410663366317749 850 | Training epoch: 64, loss: 0.19946010410785675, time elapsed: 224.59910464286804, 851 | Split: train, epoch: 64, metrics: accuracy: 0.083 ; iou: 0.025 ; dice: 0.043 ; sensitivity: 0.108 ; auroc: 0.739 852 | Split: val, epoch: 64, metrics: accuracy: 0.091 ; iou: 0.027 ; dice: 0.045 ; sensitivity: 0.101 ; auroc: 0.734 853 | Finish iter: 0, loss 0.21289128065109253 854 | Finish iter: 20, loss 0.1630212664604187 855 | Finish iter: 40, loss 0.2430124580860138 856 | Finish iter: 60, loss 0.18679164350032806 857 | Finish iter: 80, loss 0.18657450377941132 858 | Finish iter: 100, loss 0.1931016743183136 859 | Finish iter: 120, loss 0.16185976564884186 860 | Finish iter: 140, loss 0.1366070806980133 861 | Finish iter: 160, loss 0.20697863399982452 862 | Finish iter: 180, loss 0.20821061730384827 863 | Training epoch: 65, loss: 0.18065457046031952, time elapsed: 224.62199711799622, 864 | Split: train, epoch: 65, metrics: accuracy: 0.154 ; iou: 0.026 ; dice: 0.046 ; sensitivity: 0.132 ; auroc: 0.853 865 | Split: val, epoch: 65, metrics: accuracy: 0.156 ; iou: 0.025 ; dice: 0.044 ; sensitivity: 0.127 ; auroc: 0.839 866 | Finish iter: 0, loss 0.1492009460926056 867 | Finish iter: 20, loss 0.1482774317264557 868 | Finish iter: 40, loss 0.14694887399673462 869 | Finish iter: 60, loss 0.13432322442531586 870 | Finish iter: 80, loss 0.1840825229883194 871 | Finish iter: 100, loss 0.1972145438194275 872 | Finish iter: 120, loss 0.18640029430389404 873 | Finish iter: 140, loss 0.15168479084968567 874 | Finish iter: 160, loss 0.1633639931678772 875 | Finish iter: 180, loss 0.1555173397064209 876 | Training epoch: 66, loss: 0.16145040094852448, time elapsed: 224.6425986289978, 877 | Split: train, epoch: 66, metrics: accuracy: 0.148 ; iou: 0.022 ; dice: 0.037 ; sensitivity: 0.117 ; auroc: 0.819 878 | Split: val, epoch: 66, metrics: accuracy: 0.152 ; iou: 0.021 ; dice: 0.036 ; sensitivity: 0.117 ; auroc: 0.811 879 | Finish iter: 0, loss 0.17294126749038696 880 | Finish iter: 20, loss 0.1815544068813324 881 | Finish iter: 40, loss 0.1495041698217392 882 | Finish iter: 60, loss 0.15845559537410736 883 | Finish iter: 80, loss 0.14723458886146545 884 | Finish iter: 100, loss 0.1624229997396469 885 | Finish iter: 120, loss 0.1302926391363144 886 | Finish iter: 140, loss 0.14835157990455627 887 | Finish iter: 160, loss 0.14395540952682495 888 | Finish iter: 180, loss 0.12003740668296814 889 | Training epoch: 67, loss: 0.14434009790420532, time elapsed: 225.73910403251648, 890 | Split: train, epoch: 67, metrics: accuracy: 0.142 ; iou: 0.027 ; dice: 0.044 ; sensitivity: 0.125 ; auroc: 0.826 891 | Split: val, epoch: 67, metrics: accuracy: 0.146 ; iou: 0.025 ; dice: 0.042 ; sensitivity: 0.123 ; auroc: 0.820 892 | Finish iter: 0, loss 0.15502163767814636 893 | Finish iter: 20, loss 0.14835070073604584 894 | Finish iter: 40, loss 0.1358965039253235 895 | Finish iter: 60, loss 0.1485985666513443 896 | Finish iter: 80, loss 0.13990865647792816 897 | Finish iter: 100, loss 0.1345319300889969 898 | Finish iter: 120, loss 0.12352322041988373 899 | Finish iter: 140, loss 0.14095570147037506 900 | Finish iter: 160, loss 0.15028104186058044 901 | Finish iter: 180, loss 0.14921221137046814 902 | 
Training epoch: 68, loss: 0.13025711476802826, time elapsed: 225.08654403686523, 903 | Split: train, epoch: 68, metrics: accuracy: 0.132 ; iou: 0.025 ; dice: 0.041 ; sensitivity: 0.108 ; auroc: 0.791 904 | Split: val, epoch: 68, metrics: accuracy: 0.138 ; iou: 0.024 ; dice: 0.040 ; sensitivity: 0.106 ; auroc: 0.787 905 | Finish iter: 0, loss 0.13117606937885284 906 | Finish iter: 20, loss 0.13748764991760254 907 | Finish iter: 40, loss 0.10333605855703354 908 | Finish iter: 60, loss 0.13875937461853027 909 | Finish iter: 80, loss 0.16018806397914886 910 | Finish iter: 100, loss 0.14252416789531708 911 | Finish iter: 120, loss 0.13364091515541077 912 | Finish iter: 140, loss 0.13652962446212769 913 | Finish iter: 160, loss 0.11081434786319733 914 | Finish iter: 180, loss 0.15426649153232574 915 | Training epoch: 69, loss: 0.14111877977848053, time elapsed: 224.74978733062744, 916 | Split: train, epoch: 69, metrics: accuracy: 0.151 ; iou: 0.033 ; dice: 0.051 ; sensitivity: 0.103 ; auroc: 0.793 917 | Split: val, epoch: 69, metrics: accuracy: 0.157 ; iou: 0.032 ; dice: 0.051 ; sensitivity: 0.103 ; auroc: 0.798 918 | Finish iter: 0, loss 0.12499471008777618 919 | Finish iter: 20, loss 0.1354910433292389 920 | Finish iter: 40, loss 0.13276401162147522 921 | Finish iter: 60, loss 0.11486832797527313 922 | Finish iter: 80, loss 0.13470035791397095 923 | Finish iter: 100, loss 0.139710932970047 924 | Finish iter: 120, loss 0.14417123794555664 925 | Finish iter: 140, loss 0.14873750507831573 926 | Finish iter: 160, loss 0.14743997156620026 927 | Finish iter: 180, loss 0.1265813708305359 928 | Training epoch: 70, loss: 0.14727549254894257, time elapsed: 224.65670561790466, 929 | Split: train, epoch: 70, metrics: accuracy: 0.148 ; iou: 0.035 ; dice: 0.056 ; sensitivity: 0.115 ; auroc: 0.820 930 | Split: val, epoch: 70, metrics: accuracy: 0.154 ; iou: 0.035 ; dice: 0.057 ; sensitivity: 0.117 ; auroc: 0.816 931 | Finish iter: 0, loss 0.13123080134391785 932 | Finish iter: 20, loss 0.1246054545044899 933 | Finish iter: 40, loss 0.11533865332603455 934 | Finish iter: 60, loss 0.1541678011417389 935 | Finish iter: 80, loss 0.12383051216602325 936 | Finish iter: 100, loss 0.12983080744743347 937 | Finish iter: 120, loss 0.1175856813788414 938 | Finish iter: 140, loss 0.12782561779022217 939 | Finish iter: 160, loss 0.15854589641094208 940 | Finish iter: 180, loss 0.12385441362857819 941 | Training epoch: 71, loss: 0.10180258005857468, time elapsed: 222.18698287010193, 942 | Split: train, epoch: 71, metrics: accuracy: 0.121 ; iou: 0.018 ; dice: 0.031 ; sensitivity: 0.090 ; auroc: 0.811 943 | Split: val, epoch: 71, metrics: accuracy: 0.129 ; iou: 0.018 ; dice: 0.032 ; sensitivity: 0.090 ; auroc: 0.802 944 | Finish iter: 0, loss 0.14828996360301971 945 | Finish iter: 20, loss 0.123553566634655 946 | Finish iter: 40, loss 0.12963415682315826 947 | Finish iter: 60, loss 0.13532724976539612 948 | Finish iter: 80, loss 0.12922139465808868 949 | Finish iter: 100, loss 0.11579883098602295 950 | Finish iter: 120, loss 0.1396297961473465 951 | Finish iter: 140, loss 0.11340813338756561 952 | Finish iter: 160, loss 0.1416081190109253 953 | Finish iter: 180, loss 0.1241280660033226 954 | Training epoch: 72, loss: 0.13368354737758636, time elapsed: 224.34346199035645, 955 | Split: train, epoch: 72, metrics: accuracy: 0.143 ; iou: 0.031 ; dice: 0.049 ; sensitivity: 0.108 ; auroc: 0.827 956 | Split: val, epoch: 72, metrics: accuracy: 0.150 ; iou: 0.031 ; dice: 0.050 ; sensitivity: 0.107 ; auroc: 0.819 957 | Finish iter: 0, 
loss 0.12597514688968658 958 | Finish iter: 20, loss 0.14176532626152039 959 | Finish iter: 40, loss 0.16965964436531067 960 | Finish iter: 60, loss 0.1529451608657837 961 | Finish iter: 80, loss 0.17004938423633575 962 | Finish iter: 100, loss 0.1650736927986145 963 | Finish iter: 120, loss 0.2201787829399109 964 | Finish iter: 140, loss 0.20582321286201477 965 | Finish iter: 160, loss 0.19539301097393036 966 | Finish iter: 180, loss 0.2044394314289093 967 | Training epoch: 73, loss: 0.2070743590593338, time elapsed: 224.43892455101013, 968 | Split: train, epoch: 73, metrics: accuracy: 0.154 ; iou: 0.036 ; dice: 0.061 ; sensitivity: 0.107 ; auroc: 0.819 969 | Split: val, epoch: 73, metrics: accuracy: 0.156 ; iou: 0.033 ; dice: 0.056 ; sensitivity: 0.105 ; auroc: 0.816 970 | Finish iter: 0, loss 0.16165009140968323 971 | Finish iter: 20, loss 0.16729722917079926 972 | Finish iter: 40, loss 0.14208349585533142 973 | Finish iter: 60, loss 0.15254954993724823 974 | Finish iter: 80, loss 0.13930568099021912 975 | Finish iter: 100, loss 0.15182146430015564 976 | Finish iter: 120, loss 0.1461336463689804 977 | Finish iter: 140, loss 0.12471267580986023 978 | Finish iter: 160, loss 0.15794888138771057 979 | Finish iter: 180, loss 0.10100105404853821 980 | Training epoch: 74, loss: 0.13216209411621094, time elapsed: 224.63199138641357, 981 | Split: train, epoch: 74, metrics: accuracy: 0.152 ; iou: 0.033 ; dice: 0.052 ; sensitivity: 0.098 ; auroc: 0.822 982 | Split: val, epoch: 74, metrics: accuracy: 0.159 ; iou: 0.031 ; dice: 0.050 ; sensitivity: 0.099 ; auroc: 0.819 983 | Finish iter: 0, loss 0.13213811814785004 984 | Finish iter: 20, loss 0.13186275959014893 985 | Finish iter: 40, loss 0.12212240695953369 986 | Finish iter: 60, loss 0.11558917164802551 987 | Finish iter: 80, loss 0.14122635126113892 988 | Finish iter: 100, loss 0.14273181557655334 989 | Finish iter: 120, loss 0.14282900094985962 990 | Finish iter: 140, loss 0.11955077201128006 991 | Finish iter: 160, loss 0.1448877900838852 992 | Finish iter: 180, loss 0.1555509716272354 993 | Training epoch: 75, loss: 0.1484820693731308, time elapsed: 224.60657453536987, 994 | Split: train, epoch: 75, metrics: accuracy: 0.117 ; iou: 0.022 ; dice: 0.039 ; sensitivity: 0.105 ; auroc: 0.806 995 | Split: val, epoch: 75, metrics: accuracy: 0.129 ; iou: 0.023 ; dice: 0.041 ; sensitivity: 0.105 ; auroc: 0.803 996 | Finish iter: 0, loss 0.14583608508110046 997 | Finish iter: 20, loss 0.12327469885349274 998 | Finish iter: 40, loss 0.1467132866382599 999 | Finish iter: 60, loss 0.1362418383359909 1000 | Finish iter: 80, loss 0.13265839219093323 1001 | Finish iter: 100, loss 0.151235893368721 1002 | Finish iter: 120, loss 0.1506972759962082 1003 | Finish iter: 140, loss 0.1086643785238266 1004 | Finish iter: 160, loss 0.1265842765569687 1005 | Finish iter: 180, loss 0.12410146743059158 1006 | Training epoch: 76, loss: 0.12957511842250824, time elapsed: 224.7165503501892, 1007 | Split: train, epoch: 76, metrics: accuracy: 0.119 ; iou: 0.012 ; dice: 0.022 ; sensitivity: 0.062 ; auroc: 0.766 1008 | Split: val, epoch: 76, metrics: accuracy: 0.130 ; iou: 0.013 ; dice: 0.023 ; sensitivity: 0.065 ; auroc: 0.769 1009 | Finish iter: 0, loss 0.11293905228376389 1010 | Finish iter: 20, loss 0.11115451157093048 1011 | Finish iter: 40, loss 0.10779604315757751 1012 | Finish iter: 60, loss 0.11546360701322556 1013 | Finish iter: 80, loss 0.10654239356517792 1014 | Finish iter: 100, loss 0.11994752287864685 1015 | Finish iter: 120, loss 0.11114854365587234 1016 | 
Finish iter: 140, loss 0.12721028923988342 1017 | Finish iter: 160, loss 0.1331745684146881 1018 | Finish iter: 180, loss 0.12457864731550217 1019 | Training epoch: 77, loss: 0.09588564932346344, time elapsed: 224.4464886188507, 1020 | Split: train, epoch: 77, metrics: accuracy: 0.125 ; iou: 0.019 ; dice: 0.033 ; sensitivity: 0.071 ; auroc: 0.757 1021 | Split: val, epoch: 77, metrics: accuracy: 0.135 ; iou: 0.020 ; dice: 0.034 ; sensitivity: 0.074 ; auroc: 0.772 1022 | Finish iter: 0, loss 0.08874652534723282 1023 | Finish iter: 20, loss 0.11783718317747116 1024 | Finish iter: 40, loss 0.11235391348600388 1025 | Finish iter: 60, loss 0.13793990015983582 1026 | Finish iter: 80, loss 0.11358150094747543 1027 | Finish iter: 100, loss 0.10812084376811981 1028 | Finish iter: 120, loss 0.10794121772050858 1029 | Finish iter: 140, loss 0.1319747418165207 1030 | Finish iter: 160, loss 0.11351863294839859 1031 | Finish iter: 180, loss 0.10848184674978256 1032 | Training epoch: 78, loss: 0.11304108798503876, time elapsed: 224.46401166915894, 1033 | Split: train, epoch: 78, metrics: accuracy: 0.136 ; iou: 0.029 ; dice: 0.047 ; sensitivity: 0.082 ; auroc: 0.793 1034 | Split: val, epoch: 78, metrics: accuracy: 0.146 ; iou: 0.031 ; dice: 0.049 ; sensitivity: 0.085 ; auroc: 0.803 1035 | Finish iter: 0, loss 0.12216303497552872 1036 | Finish iter: 20, loss 0.11272578686475754 1037 | Finish iter: 40, loss 0.08755368739366531 1038 | Finish iter: 60, loss 0.09321939945220947 1039 | Finish iter: 80, loss 0.11285552382469177 1040 | Finish iter: 100, loss 0.12223965674638748 1041 | Finish iter: 120, loss 0.12580323219299316 1042 | Finish iter: 140, loss 0.1302492320537567 1043 | Finish iter: 160, loss 0.10205391049385071 1044 | Finish iter: 180, loss 0.12916462123394012 1045 | Training epoch: 79, loss: 0.11535406857728958, time elapsed: 224.73864769935608, 1046 | Split: train, epoch: 79, metrics: accuracy: 0.146 ; iou: 0.027 ; dice: 0.045 ; sensitivity: 0.100 ; auroc: 0.811 1047 | Split: val, epoch: 79, metrics: accuracy: 0.150 ; iou: 0.025 ; dice: 0.043 ; sensitivity: 0.099 ; auroc: 0.804 1048 | Finish iter: 0, loss 0.11765128374099731 1049 | Finish iter: 20, loss 0.0998358204960823 1050 | Finish iter: 40, loss 0.12034642696380615 1051 | Finish iter: 60, loss 0.12326247245073318 1052 | Finish iter: 80, loss 0.12309829890727997 1053 | Finish iter: 100, loss 0.10554520040750504 1054 | Finish iter: 120, loss 0.11247579753398895 1055 | Finish iter: 140, loss 0.11241700500249863 1056 | Finish iter: 160, loss 0.11756744235754013 1057 | Finish iter: 180, loss 0.10487847030162811 1058 | Training epoch: 80, loss: 0.09332511574029922, time elapsed: 224.6728298664093, 1059 | Split: train, epoch: 80, metrics: accuracy: 0.132 ; iou: 0.027 ; dice: 0.045 ; sensitivity: 0.079 ; auroc: 0.830 1060 | Split: val, epoch: 80, metrics: accuracy: 0.142 ; iou: 0.027 ; dice: 0.046 ; sensitivity: 0.081 ; auroc: 0.827 1061 | Finish iter: 0, loss 0.10732047259807587 1062 | slurmstepd-gpu20-09: error: *** JOB 4164907 ON gpu20-09 CANCELLED AT 2021-03-29T22:20:17 DUE TO TIME LIMIT *** 1063 | -------------------------------------------------------------------------------- /fcn.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torchvision.models.vgg as vgg 3 | 4 | class Segnet(nn.Module): 5 | ''' 6 | Fully Convolutional Network (FCN) 7 | 8 | Performs fcn on the inputs and returns the feature map. 
The decoder part of the network: it starts from the final convolutional feature map of the pretrained VGG16 backbone and upsamples it back to the input resolution with five transposed convolutions.
9 |     Args:
10 |         n_classes: number of classes to be predicted
11 | 
12 |     Returns:
13 |         feature map of size=(N, n_classes, x.H, x.W)
14 | 
15 | 
16 |     '''
17 |     def __init__(self, n_classes):
18 |         super(Segnet, self).__init__()
19 |         self.vgg_model = vgg.vgg16(pretrained=True, progress=True)#.to(device)
20 |         #del self.vgg_model.classifier
21 |         self.relu = nn.ReLU(inplace=True)
22 |         self.deconv1 = nn.ConvTranspose2d(512, 512, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
23 |         self.bn1 = nn.BatchNorm2d(512)
24 |         self.deconv2 = nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
25 |         self.bn2 = nn.BatchNorm2d(256)
26 |         self.deconv3 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
27 |         self.bn3 = nn.BatchNorm2d(128)
28 |         self.deconv4 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
29 |         self.bn4 = nn.BatchNorm2d(64)
30 |         self.deconv5 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
31 |         self.bn5 = nn.BatchNorm2d(32)
32 |         self.classifier = nn.Conv2d(32, n_classes, kernel_size=1)
33 | 
34 |     def forward(self, x):
35 |         x = self.vgg_model.features(x) # size=(N, 512, x.H/32, x.W/32)
36 |         score = self.bn1(self.relu(self.deconv1(x))) # size=(N, 512, x.H/16, x.W/16)
37 |         score = self.bn2(self.relu(self.deconv2(score))) # size=(N, 256, x.H/8, x.W/8)
38 |         score = self.bn3(self.relu(self.deconv3(score))) # size=(N, 128, x.H/4, x.W/4)
39 |         score = self.bn4(self.relu(self.deconv4(score))) # size=(N, 64, x.H/2, x.W/2)
40 |         score = self.bn5(self.relu(self.deconv5(score))) # size=(N, 32, x.H, x.W)
41 |         score = self.classifier(score) # size=(N, n_classes, x.H, x.W)
42 |         return score
43 | 
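44 | if __name__ == "__main__":
45 |     #Optional smoke test (a sketch, not part of the original pipeline): a
46 |     #224x224 input should be decoded back to full resolution. Instantiating
47 |     #Segnet downloads the VGG16 weights on first use.
48 |     import torch
49 |     net = Segnet(n_classes=20)
50 |     out = net(torch.randn(1, 3, 224, 224))
51 |     print(out.shape) #expected: torch.Size([1, 20, 224, 224])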
-------------------------------------------------------------------------------- /main.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | import torch.nn.functional as F
5 | import numpy as np
6 | import time
7 | import pdb
8 | import sys
9 | import os
10 | 
11 | from fcn import Segnet
12 | from r2unet import U_Net, R2U_Net, RecU_Net, ResU_Net
13 | #from deeplabv3_torchvision import DeepLabHead
14 | from deeplabv3 import DeepLabV3
15 | from dataloader import load_dataset
16 | from metrics import Metrics
17 | from vis import Vis
18 | 
19 | from torchvision.models.segmentation.segmentation import deeplabv3_resnet50
20 | 
21 | expt_logdir = sys.argv[1]
22 | os.makedirs(expt_logdir, exist_ok=True)
23 | 
24 | #Dataset parameters
25 | num_workers = 8
26 | batch_size = 16
27 | n_classes = 20
28 | img_size = 224
29 | test_split = 'val'
30 | 
31 | # Training parameters
32 | epochs = 300 #use 200
33 | lr = 0.001
34 | decayRate = 0.96
35 | #TODO weight decay, plot results for validation data
36 | 
37 | # Logging options
38 | i_save = 50 #save model after every i_save epochs
39 | i_vis = 10
40 | rows, cols = 5, 2 #Show 10 images in the dataset along with target and predicted masks
41 | 
42 | # Setting up the device
43 | device = torch.device("cuda")# if torch.cuda.is_available() else "cpu")
44 | num_gpu = list(range(torch.cuda.device_count()))
45 | 
46 | #Loading training and testing data
47 | trainloader, train_dst = load_dataset(batch_size, num_workers, split='train')
48 | testloader, test_dst = load_dataset(batch_size, num_workers, split=test_split)
49 | 
50 | # Creating an instance of the model
51 | #model = Segnet(n_classes) #Fully Convolutional Networks
52 | #model = U_Net(img_ch=3,output_ch=n_classes) #U Network
53 | #model = R2U_Net(img_ch=3,output_ch=n_classes,t=2) #Residual Recurrent U Network, R2Unet (t=2)
54 | #model = R2U_Net(img_ch=3,output_ch=n_classes,t=3) #Residual Recurrent U Network, R2Unet (t=3)
55 | #model = RecU_Net(img_ch=3,output_ch=n_classes,t=2) #Recurrent U Network, RecUnet (t=2)
56 | #model = ResU_Net(img_ch=3,output_ch=n_classes) #Residual U Network, ResUnet
57 | #model = DeepLabV3(n_classes, 'vgg') #DeepLabV3 VGG backbone
58 | model = DeepLabV3(n_classes, 'resnet') #DeepLabV3 Resnet backbone
59 | 
60 | print('Experiment logs for model: {}'.format(model.__class__.__name__))
61 | 
62 | model = nn.DataParallel(model, device_ids=num_gpu).to(device)
63 | # loss function
64 | loss_f = nn.CrossEntropyLoss() #TODO: is ignore_index required? e.g. ignore_index=19
65 | 
66 | # optimizer variable
67 | opt = optim.Adam(model.parameters(), lr=lr)
68 | lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer=opt, gamma=decayRate)
69 | #torch.optim.lr_scheduler.StepLR(optimizer,step_size=3, gamma=0.1)
70 | 
71 | #TODO random seed
72 | #Visualization of train and test data
73 | train_vis = Vis(train_dst, expt_logdir, rows, cols)
74 | test_vis = Vis(test_dst, expt_logdir, rows, cols)
75 | 
76 | #Metrics calculator for train and test data
77 | train_metrics = Metrics(n_classes, trainloader, 'train', device, expt_logdir)
78 | test_metrics = Metrics(n_classes, testloader, test_split, device, expt_logdir)
79 | 
80 | epoch = -1
81 | train_metrics.compute(epoch, model)
82 | train_metrics.plot_scalar_metrics(epoch)
83 | train_metrics.plot_roc(epoch)
84 | train_vis.visualize(epoch, model)
85 | 
86 | test_metrics.compute(epoch, model)
87 | test_metrics.plot_scalar_metrics(epoch)
88 | test_metrics.plot_roc(epoch)
89 | test_vis.visualize(epoch, model)
90 | 
91 | #Training
92 | losses = []
93 | for epoch in range(epochs):
94 |     st = time.time()
95 |     model.train()
96 |     for i, (inputs, labels) in enumerate(trainloader):
97 |         opt.zero_grad()
98 |         inputs = inputs.to(device)
99 |         labels = labels.to(device)
100 |         predictions = model(inputs)
101 |         loss = loss_f(predictions, labels)
102 |         loss.backward()
103 |         opt.step()
104 |         if i % 20 == 0:
105 |             print("Finish iter: {}, loss {}".format(i, loss.item()))
106 |     lr_scheduler.step()
107 |     losses.append(loss.item()) #store a plain float, not the graph-carrying tensor
108 |     print("Training epoch: {}, loss: {}, time elapsed: {},".format(epoch, loss.item(), time.time() - st))
109 | 
110 |     train_metrics.compute(epoch, model)
111 |     test_metrics.compute(epoch, model)
112 | 
113 |     if epoch % i_save == 0:
114 |         torch.save(model.state_dict(), os.path.join(expt_logdir, "{}.tar".format(epoch))) #file name example: '0.tar'
115 |     if epoch % i_vis == 0: # Metric calculation and visualization
116 |         test_metrics.plot_scalar_metrics(epoch)
117 |         test_metrics.plot_roc(epoch)
118 |         test_vis.visualize(epoch, model)
119 | 
120 |         train_metrics.plot_scalar_metrics(epoch)
121 |         train_metrics.plot_roc(epoch)
122 |         train_vis.visualize(epoch, model)
123 | 
124 |         train_metrics.plot_loss(epoch, losses)
125 | 
126 | 
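127 | #Example (a sketch): a checkpoint written by the torch.save call above can be
128 | #restored into the DataParallel-wrapped model for later evaluation, e.g.
129 | #    model.load_state_dict(torch.load(os.path.join(expt_logdir, '0.tar')))
130 | #    model.eval()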
-------------------------------------------------------------------------------- /metrics.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from pytorch_lightning import metrics
4 | import matplotlib.pyplot as plt
5 | import os
6 | import pdb
7 | 
8 | class Dice(metrics.Metric):
9 |     '''
10 |     Module to calculate the Dice metric
11 | 
12 |     Args:
13 |         None
14 | 
15 |     Returns:
16 |         None
17 | 
18 |     '''
19 |     def __init__(self):
20 |         super().__init__()
21 |         self.add_state("dice_score", default=[])
22 | 
23 |     def update(self, pred, target):
24 |         '''
25 |         Accumulates the dice score for one batch
26 |         Args:
27 |             pred: predicted class probabilities from the network
28 |             target: ground truth labels
29 | 
30 |         Returns:
31 |             None
32 |         '''
33 |         dice_score_val = metrics.functional.classification.dice_score(pred, target, bg=True)
34 |         self.dice_score.append(dice_score_val.item())
35 | 
36 |     def compute(self):
37 |         '''
38 |         Computes the mean dice score over all accumulated batches
39 |         Args:
40 |             None
41 |         Returns:
42 |             dice_score
43 |         '''
44 |         dice_score = torch.tensor(self.dice_score) #do not overwrite the registered list state
45 |         return torch.mean(dice_score)
46 | 
47 | 
48 | class Metrics():
49 |     '''
50 |     Metrics Calculator
51 | 
52 |     Calculates the required metrics for the given dataset and model.
53 |     The metrics calculated are accuracy, iou, dice score, sensitivity, auroc.
54 |     Args:
55 |         n_classes: Number of classes to predict
56 |         dataloader: Dataloader of the dataset for which metric calculation is performed.
57 |         split: Name of the split ('train', 'val', or any split provided by the dataloader), used to label logs and plot files
58 |         device: Device on which the model resides.
59 |         expt_logdir: Path to store the plots
60 | 
61 |     Returns:
62 |         None
63 |     '''
64 |     def __init__(self, n_classes, dataloader, split, device, expt_logdir):
65 |         self.dataloader = dataloader
66 |         self.device = device
67 |         accuracy = metrics.Accuracy().to(self.device)
68 |         iou = metrics.IoU(num_classes=n_classes).to(self.device)
69 |         dice = Dice().to(self.device)
70 |         recall = metrics.Recall(num_classes=n_classes,average='macro', mdmc_average='global').to(self.device)
71 |         roc = metrics.ROC(num_classes=n_classes,dist_sync_on_step=True).to(self.device)
72 | 
73 |         self.eval_metrics = {'accuracy': {'module': accuracy, 'values': []},
74 |                 'iou': {'module': iou, 'values': []},
75 |                 'dice': {'module': dice, 'values': []},
76 |                 'sensitivity': {'module': recall, 'values': []},
77 |                 'auroc': {'module': roc, 'values': []}
78 |                 }
79 |         self.softmax = nn.Softmax(dim=1)
80 |         self.expt_logdir = expt_logdir
81 |         self.split = split
82 | 
83 |     def compute_auroc(self, value): #computes AUROC: per-class trapezoidal integration of the ROC curve, averaged over classes
84 |         self.fpr, self.tpr, _ = value
85 |         auc_scores = [torch.trapz(y, x) for x, y in zip(self.fpr, self.tpr)]
86 |         return torch.mean(torch.stack(auc_scores))
87 | 
88 |     def compute(self, epoch, model): #computes the metrics
89 |         model.eval()
90 |         with torch.no_grad():
91 |             for i, (inputs, labels) in enumerate(self.dataloader):
92 |                 inputs = inputs.to(self.device) #N, 3, H, W
93 |                 labels = labels.to(self.device) #N, H, W
94 | 
95 |                 predictions = model(inputs) #N, C, H, W
96 |                 predictions = self.softmax(predictions)
97 | 
98 |                 for key in self.eval_metrics:
99 |                     #Evaluate AUC/ROC on subset of the training data, otherwise leads to OOM errors on GPU
100 |                     #Full evaluation on validation/test data
101 |                     if key == 'auroc' and i > 20:
102 |                         continue
103 |                     self.eval_metrics[key]['module'].update(predictions, labels)
104 | 
105 |         for key in self.eval_metrics:
106 |             value = self.eval_metrics[key]['module'].compute()
107 |             if key == 'auroc':
108 |                 value = self.compute_auroc(value)
109 |             self.eval_metrics[key]['values'].append(value.item())
110 |             self.eval_metrics[key]['module'].reset()
111 | 
112 |         metrics_string = " ; ".join("{}: {:05.3f}".format(key, self.eval_metrics[key]['values'][-1]) for key in self.eval_metrics)
113 |         print("Split: {}, epoch: {}, metrics: ".format(self.split, epoch) + metrics_string)
114 | 
115 |     def plot_scalar_metrics(self, epoch): #for plotting the scalar metrics against epochs
116 |         fig = plt.figure(figsize=(13, 5))
117 |         ax = fig.gca()
118 |         for key, metric in self.eval_metrics.items():
119 |             ax.plot(metric['values'], label=key)
120 |         ax.legend(fontsize="16")
121 |         ax.set_xlabel("Epochs", fontsize="16")
122 |         ax.set_ylabel("Metric", fontsize="16")
123 |         ax.set_title("Evaluation metric vs epochs", fontsize="16")
124 |         plt.savefig(os.path.join(self.expt_logdir, 'metric_{}_{}.png'.format(self.split, epoch))) #example file name: 'metric_val_100.png'
125 |         plt.clf()
126 | 
127 |     def plot_roc(self, epoch): #for plotting per-class ROC curves
128 |         fig = plt.figure(figsize=(13, 5))
129 |         ax = fig.gca()
130 |         trainId2Name = self.dataloader.dataset.trainId2Name
131 |         for class_idx, (x, y) in enumerate(zip(self.fpr, self.tpr)):
132 |             class_idx = 255 if class_idx == 19 else class_idx #trainId 19 is the void class (id 255 in Cityscapes)
133 |             ax.plot(x.cpu().numpy(), y.cpu().numpy(), label=trainId2Name[class_idx])
134 |         ax.legend(fontsize="8", ncol=2, loc='lower right')
135 |         ax.set_xlabel("FPR (False Positive Rate)", fontsize="16")
136 |         ax.set_ylabel("TPR (True Positive Rate)", fontsize="16")
137 |         ax.set_title("ROC Curve", fontsize="16")
138 |         plt.savefig(os.path.join(self.expt_logdir, 'roc_{}_{}.png'.format(self.split, epoch))) #example file name: 'roc_val_100.png'
139 |         plt.clf()
140 | 
141 |     def plot_loss(self, epoch, losses): #for plotting losses against epochs
142 |         fig = plt.figure(figsize=(13, 5))
143 |         ax = fig.gca()
144 |         ax.plot(losses)
145 |         ax.set_xlabel("Epochs", fontsize="16")
146 |         ax.set_ylabel("Loss", fontsize="16")
147 |         ax.set_title("Training loss vs. epochs", fontsize="16")
148 |         plt.savefig(os.path.join(self.expt_logdir, 'loss_{}.png'.format(epoch))) #example file name: 'loss_100.png'
149 |         plt.clf()
150 | 
151 | 
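152 | if __name__ == "__main__":
153 |     #Tiny self-check (a sketch, not used by main.py): the AUROC in
154 |     #compute_auroc is just trapezoidal integration of TPR over FPR.
155 |     fpr = torch.tensor([0.0, 0.25, 0.5, 1.0])
156 |     tpr = torch.tensor([0.0, 0.6, 0.8, 1.0])
157 |     print(torch.trapz(tpr, fpr)) #tensor(0.7000)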
-------------------------------------------------------------------------------- /r2unet.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | class conv_block(nn.Module):
5 |     '''
6 |     Double 3x3 convolution block used at each resolution level of U-Net (encoder and decoder).
7 |     Args:
8 |         ch_in : number of input channels
9 |         ch_out : number of output channels
10 |     Returns:
11 |         feature map of the given input
12 |     '''
13 |     def __init__(self,ch_in,ch_out):
14 |         super(conv_block,self).__init__()
15 |         self.conv = nn.Sequential(
16 |             nn.Conv2d(ch_in, ch_out, kernel_size=3,stride=1,padding=1,bias=True),
17 |             nn.BatchNorm2d(ch_out),
18 |             nn.ReLU(inplace=True),
19 |             nn.Conv2d(ch_out, ch_out, kernel_size=3,stride=1,padding=1,bias=True),
20 |             nn.BatchNorm2d(ch_out),
21 |             nn.ReLU(inplace=True)
22 |         )
23 | 
24 | 
25 |     def forward(self,x):
26 |         x = self.conv(x)
27 |         return x
28 | 
29 | class up_conv(nn.Module):
30 |     '''
31 |     Upsampling (deconvolutional) block of U-Net at the decoder end
32 |     Args:
33 |         ch_in : number of input channels
34 |         ch_out : number of output channels
35 |     Returns:
36 |         feature map of the given input
37 |     '''
38 |     def __init__(self,ch_in,ch_out):
39 |         super(up_conv,self).__init__()
40 |         self.up = nn.Sequential(
41 |             nn.Upsample(scale_factor=2),
42 |             nn.Conv2d(ch_in,ch_out,kernel_size=3,stride=1,padding=1,bias=True),
43 |             nn.BatchNorm2d(ch_out),
44 |             nn.ReLU(inplace=True)
45 |         )
46 | 
47 |     def forward(self,x):
48 |         x = self.up(x)
49 |         return x
50 | 
51 | class Recurrent_block(nn.Module):
52 |     '''
53 |     Recurrent convolution block for RU-Net and R2U-Net
54 |     Args:
55 |         ch_out : number of output channels
56 |         t: number of recurrent convolution steps
57 |     Returns:
58 |         feature map of the given input
59 |     '''
60 |     def __init__(self,ch_out,t=2):
61 |         super(Recurrent_block,self).__init__()
62 |         self.t = t
63 |         self.ch_out = ch_out
64 |         self.conv = nn.Sequential(
65 |             nn.Conv2d(ch_out,ch_out,kernel_size=3,stride=1,padding=1,bias=True),
66 |             nn.BatchNorm2d(ch_out),
67 |             nn.ReLU(inplace=True)
68 |         )
69 | 
70 |     def forward(self,x):
71 |         for i in range(self.t):
72 | 
73 |             if i==0:
74 |                 x1 = self.conv(x) #initial pass: x1 = conv(x)
75 | 
76 |             x1 = self.conv(x+x1) #recurrent pass: the block input x is fed back in before each application of the shared convolution
77 |         return x1
78 | 
79 | class RRCNN_block(nn.Module):
80 |     '''
81 |     Recurrent Residual convolution block for R2U-Net
82 |     Args:
83 |         ch_in : number of input channels
84 |         ch_out : number of output channels
85 |         t : number of recurrent steps in each Recurrent_block
86 |     Returns:
87 |         feature map of the given input
88 |     '''
89 |     def __init__(self,ch_in,ch_out,t=2):
90 |         super(RRCNN_block,self).__init__()
91 |         self.RCNN = nn.Sequential(
92 |             Recurrent_block(ch_out,t=t),
93 |             Recurrent_block(ch_out,t=t)
94 |         )
95 |         self.Conv_1x1 = nn.Conv2d(ch_in,ch_out,kernel_size=1,stride=1,padding=0)
96 | 
97 |     def forward(self,x):
98 |         x = self.Conv_1x1(x)
99 |         x1 = self.RCNN(x)
100 |         return x+x1 #residual learning
101 | 
102 | class RCNN_block(nn.Module):
103 |     '''
104 |     Recurrent convolution block for RU-Net
105 |     Args:
106 |         ch_in : number of input channels
107 |         ch_out : number of output channels
108 |         t : number of recurrent steps in each Recurrent_block
109 |     Returns:
110 |         feature map of the given input
111 |     '''
112 |     def __init__(self,ch_in,ch_out,t=2):
113 |         super(RCNN_block,self).__init__()
114 |         self.RCNN = nn.Sequential(
115 |             Recurrent_block(ch_out,t=t),
116 |             Recurrent_block(ch_out,t=t)
117 |         )
118 |         self.Conv_1x1 = nn.Conv2d(ch_in,ch_out,kernel_size=1,stride=1,padding=0)
119 | 
120 |     def forward(self,x):
121 |         x = self.Conv_1x1(x)
122 |         x = self.RCNN(x)
123 |         return x
124 | 
125 | class ResCNN_block(nn.Module):
126 |     '''
127 |     Residual convolution block
128 |     Args:
129 |         ch_in : number of input channels
130 |         ch_out : number of output channels
131 | 
132 |     Returns:
133 |         feature map of the given input
134 |     '''
135 |     def
__init__(self,ch_in,ch_out): 136 | super(ResCNN_block,self).__init__() 137 | self.Conv = conv_block(ch_in, ch_out) 138 | self.Conv_1x1 = nn.Conv2d(ch_in,ch_out,kernel_size=1,stride=1,padding=0) 139 | 140 | def forward(self,x): 141 | x1 = self.Conv_1x1(x) 142 | x = self.Conv(x) 143 | return x+x1 144 | 145 | class U_Net(nn.Module): 146 | ''' 147 | U-Net Network. 148 | Implements traditional U-Net with a compressive encoder and an expanding decoder 149 | 150 | Args: 151 | img_ch: Input image channels 152 | output_ch: Number of channels expected in the output 153 | 154 | Returns: 155 | Feature map of input (batch_size, output_ch=1,h,w) 156 | ''' 157 | def __init__(self,img_ch=3,output_ch=1): 158 | super(U_Net,self).__init__() 159 | 160 | self.Maxpool = nn.MaxPool2d(kernel_size=2,stride=2) 161 | 162 | self.Conv1 = conv_block(ch_in=img_ch,ch_out=64) 163 | self.Conv2 = conv_block(ch_in=64,ch_out=128) 164 | self.Conv3 = conv_block(ch_in=128,ch_out=256) 165 | self.Conv4 = conv_block(ch_in=256,ch_out=512) 166 | self.Conv5 = conv_block(ch_in=512,ch_out=1024) 167 | 168 | self.Up5 = up_conv(ch_in=1024,ch_out=512) 169 | self.Up_conv5 = conv_block(ch_in=1024, ch_out=512) 170 | 171 | self.Up4 = up_conv(ch_in=512,ch_out=256) 172 | self.Up_conv4 = conv_block(ch_in=512, ch_out=256) 173 | 174 | self.Up3 = up_conv(ch_in=256,ch_out=128) 175 | self.Up_conv3 = conv_block(ch_in=256, ch_out=128) 176 | 177 | self.Up2 = up_conv(ch_in=128,ch_out=64) 178 | self.Up_conv2 = conv_block(ch_in=128, ch_out=64) 179 | 180 | self.Conv_1x1 = nn.Conv2d(64,output_ch,kernel_size=1,stride=1,padding=0) 181 | 182 | 183 | def forward(self,x): 184 | # encoding path 185 | x1 = self.Conv1(x) 186 | 187 | x2 = self.Maxpool(x1) 188 | x2 = self.Conv2(x2) 189 | 190 | x3 = self.Maxpool(x2) 191 | x3 = self.Conv3(x3) 192 | 193 | x4 = self.Maxpool(x3) 194 | x4 = self.Conv4(x4) 195 | 196 | x5 = self.Maxpool(x4) 197 | x5 = self.Conv5(x5) 198 | 199 | # decoding + concat path 200 | d5 = self.Up5(x5) 201 | d5 = torch.cat((x4,d5),dim=1) 202 | 203 | d5 = self.Up_conv5(d5) 204 | 205 | d4 = self.Up4(d5) 206 | d4 = torch.cat((x3,d4),dim=1) 207 | d4 = self.Up_conv4(d4) 208 | 209 | d3 = self.Up3(d4) 210 | d3 = torch.cat((x2,d3),dim=1) 211 | d3 = self.Up_conv3(d3) 212 | 213 | d2 = self.Up2(d3) 214 | d2 = torch.cat((x1,d2),dim=1) 215 | d2 = self.Up_conv2(d2) 216 | 217 | d1 = self.Conv_1x1(d2) 218 | 219 | return d1 220 | 221 | 222 | class R2U_Net(nn.Module): 223 | ''' 224 | R2U-Net Network. 225 | Implements U-Net with a RRCNN block. 
226 | 227 | Args: 228 | img_ch: Input image channels 229 | output_ch: Number of channels expected in the output 230 | t: number of recurrent blocks expected 231 | 232 | Returns: 233 | Feature map of input (batch_size, output_ch=1,h,w) 234 | ''' 235 | def __init__(self,img_ch=3,output_ch=1,t=2): 236 | super(R2U_Net,self).__init__() 237 | 238 | self.Maxpool = nn.MaxPool2d(kernel_size=2,stride=2) 239 | self.Upsample = nn.Upsample(scale_factor=2) 240 | 241 | self.RRCNN1 = RRCNN_block(ch_in=img_ch,ch_out=64,t=t) 242 | 243 | self.RRCNN2 = RRCNN_block(ch_in=64,ch_out=128,t=t) 244 | 245 | self.RRCNN3 = RRCNN_block(ch_in=128,ch_out=256,t=t) 246 | 247 | self.RRCNN4 = RRCNN_block(ch_in=256,ch_out=512,t=t) 248 | 249 | self.RRCNN5 = RRCNN_block(ch_in=512,ch_out=1024,t=t) 250 | 251 | 252 | self.Up5 = up_conv(ch_in=1024,ch_out=512) 253 | self.Up_RRCNN5 = RRCNN_block(ch_in=1024, ch_out=512,t=t) 254 | 255 | self.Up4 = up_conv(ch_in=512,ch_out=256) 256 | self.Up_RRCNN4 = RRCNN_block(ch_in=512, ch_out=256,t=t) 257 | 258 | self.Up3 = up_conv(ch_in=256,ch_out=128) 259 | self.Up_RRCNN3 = RRCNN_block(ch_in=256, ch_out=128,t=t) 260 | 261 | self.Up2 = up_conv(ch_in=128,ch_out=64) 262 | self.Up_RRCNN2 = RRCNN_block(ch_in=128, ch_out=64,t=t) 263 | 264 | self.Conv_1x1 = nn.Conv2d(64,output_ch,kernel_size=1,stride=1,padding=0) 265 | 266 | 267 | def forward(self,x): 268 | # encoding path 269 | x1 = self.RRCNN1(x) 270 | 271 | x2 = self.Maxpool(x1) 272 | x2 = self.RRCNN2(x2) 273 | 274 | x3 = self.Maxpool(x2) 275 | x3 = self.RRCNN3(x3) 276 | 277 | x4 = self.Maxpool(x3) 278 | x4 = self.RRCNN4(x4) 279 | 280 | x5 = self.Maxpool(x4) 281 | x5 = self.RRCNN5(x5) 282 | 283 | # decoding + concat path 284 | d5 = self.Up5(x5) 285 | d5 = torch.cat((x4,d5),dim=1) 286 | d5 = self.Up_RRCNN5(d5) 287 | 288 | d4 = self.Up4(d5) 289 | d4 = torch.cat((x3,d4),dim=1) 290 | d4 = self.Up_RRCNN4(d4) 291 | 292 | d3 = self.Up3(d4) 293 | d3 = torch.cat((x2,d3),dim=1) 294 | d3 = self.Up_RRCNN3(d3) 295 | 296 | d2 = self.Up2(d3) 297 | d2 = torch.cat((x1,d2),dim=1) 298 | d2 = self.Up_RRCNN2(d2) 299 | 300 | d1 = self.Conv_1x1(d2) 301 | 302 | return d1 303 | 304 | class RecU_Net(nn.Module): 305 | ''' 306 | RU-Net Network. 307 | Implements U-Net with a RCNN block. 
308 | 309 | Args: 310 | img_ch: Input image channels 311 | output_ch: Number of channels expected in the output 312 | t: number of recurrent blocks expected 313 | 314 | Returns: 315 | Feature map of input (batch_size, output_ch=1,h,w) 316 | ''' 317 | def __init__(self,img_ch=3,output_ch=1,t=2): 318 | super(RecU_Net,self).__init__() 319 | 320 | self.Maxpool = nn.MaxPool2d(kernel_size=2,stride=2) 321 | self.Upsample = nn.Upsample(scale_factor=2) 322 | 323 | self.RCNN1 = RCNN_block(ch_in=img_ch,ch_out=64,t=t) 324 | 325 | self.RCNN2 = RCNN_block(ch_in=64,ch_out=128,t=t) 326 | 327 | self.RCNN3 = RCNN_block(ch_in=128,ch_out=256,t=t) 328 | 329 | self.RCNN4 = RCNN_block(ch_in=256,ch_out=512,t=t) 330 | 331 | self.RCNN5 = RCNN_block(ch_in=512,ch_out=1024,t=t) 332 | 333 | 334 | self.Up5 = up_conv(ch_in=1024,ch_out=512) 335 | self.Up_RCNN5 = RCNN_block(ch_in=1024, ch_out=512,t=t) 336 | 337 | self.Up4 = up_conv(ch_in=512,ch_out=256) 338 | self.Up_RCNN4 = RCNN_block(ch_in=512, ch_out=256,t=t) 339 | 340 | self.Up3 = up_conv(ch_in=256,ch_out=128) 341 | self.Up_RCNN3 = RCNN_block(ch_in=256, ch_out=128,t=t) 342 | 343 | self.Up2 = up_conv(ch_in=128,ch_out=64) 344 | self.Up_RCNN2 = RCNN_block(ch_in=128, ch_out=64,t=t) 345 | 346 | self.Conv_1x1 = nn.Conv2d(64,output_ch,kernel_size=1,stride=1,padding=0) 347 | 348 | 349 | def forward(self,x): 350 | # encoding path 351 | x1 = self.RCNN1(x) 352 | 353 | x2 = self.Maxpool(x1) 354 | x2 = self.RCNN2(x2) 355 | 356 | x3 = self.Maxpool(x2) 357 | x3 = self.RCNN3(x3) 358 | 359 | x4 = self.Maxpool(x3) 360 | x4 = self.RCNN4(x4) 361 | 362 | x5 = self.Maxpool(x4) 363 | x5 = self.RCNN5(x5) 364 | 365 | # decoding + concat path 366 | d5 = self.Up5(x5) 367 | d5 = torch.cat((x4,d5),dim=1) 368 | d5 = self.Up_RCNN5(d5) 369 | 370 | d4 = self.Up4(d5) 371 | d4 = torch.cat((x3,d4),dim=1) 372 | d4 = self.Up_RCNN4(d4) 373 | 374 | d3 = self.Up3(d4) 375 | d3 = torch.cat((x2,d3),dim=1) 376 | d3 = self.Up_RCNN3(d3) 377 | 378 | d2 = self.Up2(d3) 379 | d2 = torch.cat((x1,d2),dim=1) 380 | d2 = self.Up_RCNN2(d2) 381 | 382 | d1 = self.Conv_1x1(d2) 383 | 384 | return d1 385 | 386 | class ResU_Net(nn.Module): 387 | ''' 388 | Residual U-Net Network. 389 | Implements U-Net with a ResCNN block. 
390 | 391 | Args: 392 | img_ch: Input image channels 393 | output_ch: Number of channels expected in the output 394 | 395 | Returns: 396 | Feature map of size (batch_size, output_ch,h,w) 397 | ''' 398 | def __init__(self,img_ch=3,output_ch=1): 399 | super(ResU_Net,self).__init__() 400 | 401 | self.Maxpool = nn.MaxPool2d(kernel_size=2,stride=2) 402 | self.Upsample = nn.Upsample(scale_factor=2) 403 | 404 | self.ResCNN1 = ResCNN_block(ch_in=img_ch,ch_out=64) 405 | 406 | self.ResCNN2 = ResCNN_block(ch_in=64,ch_out=128) 407 | 408 | self.ResCNN3 = ResCNN_block(ch_in=128,ch_out=256) 409 | 410 | self.ResCNN4 = ResCNN_block(ch_in=256,ch_out=512) 411 | 412 | self.ResCNN5 = ResCNN_block(ch_in=512,ch_out=1024) 413 | 414 | 415 | self.Up5 = up_conv(ch_in=1024,ch_out=512) 416 | self.Up_ResCNN5 = ResCNN_block(ch_in=1024, ch_out=512) 417 | 418 | self.Up4 = up_conv(ch_in=512,ch_out=256) 419 | self.Up_ResCNN4 = ResCNN_block(ch_in=512, ch_out=256) 420 | 421 | self.Up3 = up_conv(ch_in=256,ch_out=128) 422 | self.Up_ResCNN3 = ResCNN_block(ch_in=256, ch_out=128) 423 | 424 | self.Up2 = up_conv(ch_in=128,ch_out=64) 425 | self.Up_ResCNN2 = ResCNN_block(ch_in=128, ch_out=64) 426 | 427 | self.Conv_1x1 = nn.Conv2d(64,output_ch,kernel_size=1,stride=1,padding=0) 428 | 429 | 430 | def forward(self,x): 431 | # encoding path 432 | x1 = self.ResCNN1(x) 433 | 434 | x2 = self.Maxpool(x1) 435 | x2 = self.ResCNN2(x2) 436 | 437 | x3 = self.Maxpool(x2) 438 | x3 = self.ResCNN3(x3) 439 | 440 | x4 = self.Maxpool(x3) 441 | x4 = self.ResCNN4(x4) 442 | 443 | x5 = self.Maxpool(x4) 444 | x5 = self.ResCNN5(x5) 445 | 446 | # decoding + concat path 447 | d5 = self.Up5(x5) 448 | d5 = torch.cat((x4,d5),dim=1) 449 | d5 = self.Up_ResCNN5(d5) 450 | 451 | d4 = self.Up4(d5) 452 | d4 = torch.cat((x3,d4),dim=1) 453 | d4 = self.Up_ResCNN4(d4) 454 | 455 | d3 = self.Up3(d4) 456 | d3 = torch.cat((x2,d3),dim=1) 457 | d3 = self.Up_ResCNN3(d3) 458 | 459 | d2 = self.Up2(d3) 460 | d2 = torch.cat((x1,d2),dim=1) 461 | d2 = self.Up_ResCNN2(d2) 462 | 463 | d1 = self.Conv_1x1(d2) 464 | 465 | return d1 466 | -------------------------------------------------------------------------------- /resnet.py: -------------------------------------------------------------------------------- 1 | # NOTE! 
OS: output stride, the ratio of input image resolution to final output resolution (OS16: output size is (img_h/16, img_w/16))
2 | 
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import torchvision.models as models
7 | 
8 | def make_layer(block, in_channels, channels, num_blocks, stride=1, dilation=1):
9 |     strides = [stride] + [1]*(num_blocks - 1) # (stride == 2, num_blocks == 4 --> strides == [2, 1, 1, 1])
10 | 
11 |     blocks = []
12 |     for stride in strides:
13 |         blocks.append(block(in_channels=in_channels, channels=channels, stride=stride, dilation=dilation))
14 |         in_channels = block.expansion*channels
15 | 
16 |     layer = nn.Sequential(*blocks) # (*blocks: call with unpacked list entries as arguments)
17 | 
18 |     return layer
19 | 
20 | class Bottleneck(nn.Module):
21 |     expansion = 4
22 | 
23 |     def __init__(self, in_channels, channels, stride=1, dilation=1):
24 |         super(Bottleneck, self).__init__()
25 | 
26 |         out_channels = self.expansion*channels
27 | 
28 |         self.conv1 = nn.Conv2d(in_channels, channels, kernel_size=1, bias=False)
29 |         self.bn1 = nn.BatchNorm2d(channels)
30 | 
31 |         self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, stride=stride, padding=dilation, dilation=dilation, bias=False)
32 |         self.bn2 = nn.BatchNorm2d(channels)
33 | 
34 |         self.conv3 = nn.Conv2d(channels, out_channels, kernel_size=1, bias=False)
35 |         self.bn3 = nn.BatchNorm2d(out_channels)
36 | 
37 |         if (stride != 1) or (in_channels != out_channels):
38 |             conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)
39 |             bn = nn.BatchNorm2d(out_channels)
40 |             self.downsample = nn.Sequential(conv, bn)
41 |         else:
42 |             self.downsample = nn.Sequential()
43 | 
44 |     def forward(self, x):
45 |         # (x has shape: (batch_size, in_channels, h, w))
46 | 
47 |         out = F.relu(self.bn1(self.conv1(x))) # (shape: (batch_size, channels, h, w))
48 |         out = F.relu(self.bn2(self.conv2(out))) # (shape: (batch_size, channels, h, w) if stride == 1, (batch_size, channels, h/2, w/2) if stride == 2)
49 |         out = self.bn3(self.conv3(out)) # (shape: (batch_size, out_channels, h, w) if stride == 1, (batch_size, out_channels, h/2, w/2) if stride == 2)
50 | 
51 |         out = out + self.downsample(x) # (shape: (batch_size, out_channels, h, w) if stride == 1, (batch_size, out_channels, h/2, w/2) if stride == 2)
52 | 
53 |         out = F.relu(out) # (shape: (batch_size, out_channels, h, w) if stride == 1, (batch_size, out_channels, h/2, w/2) if stride == 2)
54 | 
55 |         return out
56 | 
57 | class ResNet_Bottleneck_OS16(nn.Module):
58 |     def __init__(self, num_layers):
59 |         super(ResNet_Bottleneck_OS16, self).__init__()
60 | 
61 |         if num_layers == 50:
62 |             resnet = models.resnet50()
63 |             # load pretrained model:
64 |             resnet.load_state_dict(torch.load("pretrained_models/resnet50-19c8e357.pth"))
65 |             # remove the fully connected layer, avg pool and the last residual stage (torchvision's layer4):
66 |             self.resnet = nn.Sequential(*list(resnet.children())[:-3])
67 | 
68 |             print ("pretrained resnet, 50")
69 |         else:
70 |             raise Exception("num_layers must be in {50}!")
71 | 
72 |         self.layer5 = make_layer(Bottleneck, in_channels=4*256, channels=512, num_blocks=3, stride=1, dilation=2)
73 | 
74 |     def forward(self, x):
75 |         # (x has shape (batch_size, 3, h, w))
76 | 
77 |         # pass x through (parts of) the pretrained ResNet:
78 |         c4 = self.resnet(x) # (shape: (batch_size, 4*256, h/16, w/16)) (it's called c4 since 16 == 2^4)
79 | 
80 |         output = self.layer5(c4) # (shape: (batch_size, 4*512, h/16, w/16))
81 | 
82 |         return output
83 | 
84 | def ResNet50_OS16():
85 |     return ResNet_Bottleneck_OS16(num_layers=50)
86 | 
87 | 
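88 | if __name__ == "__main__":
89 |     #Optional smoke test (a sketch): with OS16, a 224x224 input maps to a
90 |     #14x14 feature map with 4*512 = 2048 channels. Assumes the pretrained
91 |     #weights are available at pretrained_models/resnet50-19c8e357.pth.
92 |     net = ResNet50_OS16()
93 |     out = net(torch.randn(1, 3, 224, 224))
94 |     print(out.shape) #expected: torch.Size([1, 2048, 14, 14])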
-------------------------------------------------------------------------------- /scripts/slurm_run.sh: --------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH -p gpu20
3 | #SBATCH -o /HPS/Navami/work/code/nnti/R2U-Net/logs/slurm-output/slurm-%j.out
4 | #SBATCH -t 0-08:00:00
5 | #SBATCH --gres gpu:2
6 | 
7 | cd /HPS/Navami/work/code/nnti/R2U-Net
8 | #submit with: sbatch scripts/slurm_run.sh
9 | 
10 | ## RUN
11 | # Make conda available:
12 | eval "$(conda shell.bash hook)"
13 | # Activate a conda environment:
14 | conda activate nnti
15 | 
16 | #python -u main.py logs/expt1_0 #FCN, bs=16
17 | #python -u main.py logs/expt2_0 #U-Net, bs=16
18 | #python -u main.py logs/expt3_0 #R2U-Net (t=2), bs=16
19 | #python -u main.py logs/expt4_0 #R2U-Net (t=3), bs=8
20 | #python -u main.py logs/expt5_0 #Recurrent U-Net
21 | #python -u main.py logs/expt6_0 #Residual U-Net
22 | #python -u main.py logs/expt7_0 #DeepLab V3 VGG backbone, bs=16
23 | #python -u main.py logs/expt8_0 #DeepLab V3 Resnet backbone, bs=16
24 | python -u main.py logs/expt8_1 #DeepLab V3 Resnet backbone, bs=16, LR decay
-------------------------------------------------------------------------------- /scripts/slurm_setup.sh: --------------------------------------------------------------------------------
1 | # Python modules remain the same as in task 1
2 | 
3 | # Dataset preparation
4 | # Download and unzip gtFine_trainvaltest.zip (241MB) and leftImg8bit_trainvaltest.zip (11GB) from the cityscapes site:
5 | # https://www.cityscapes-dataset.com/downloads/
6 | 
7 | # Generate trainId labels for the dataset, using the scripts provided by the Cityscapes authors https://github.com/mcordts/cityscapesScripts
8 | git clone https://github.com/mcordts/cityscapesScripts.git
9 | pip install cityscapesScripts
10 | CITYSCAPES_DATASET_PATH=/HPS/Navami/work/code/nnti/R2U-Net/cityscapes/
11 | export CITYSCAPES_DATASET=$CITYSCAPES_DATASET_PATH
12 | python cityscapesScripts/cityscapesscripts/preparation/createTrainIdLabelImgs.py
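13 | 
14 | # Sanity check (optional, a sketch): the script above should have written a
15 | # *_gtFine_labelTrainIds.png next to every *_gtFine_labelIds.png, e.g.
16 | # ls $CITYSCAPES_DATASET/gtFine/train/aachen/ | grep labelTrainIds | head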
-------------------------------------------------------------------------------- /vis.py: --------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import matplotlib.pyplot as plt
4 | import os
5 | import pdb
6 | 
7 | 
8 | mean = torch.tensor([0.485, 0.456, 0.406])
9 | std = torch.tensor([0.229, 0.224, 0.225])
10 | 
11 | def image_grid(images, rows=None, cols=None, fill=True, show_axes=False):
12 |     """
13 |     A util function for plotting a grid of images.
14 | 
15 |     Args:
16 |         images: (N, H, W, 3) or (N, H, W, 4) array of RGB(A) images;
17 |             only the first three channels of each image are rendered
18 |         rows: number of rows in the grid
19 |         cols: number of columns in the grid
20 |         fill: boolean indicating if the space between images should be filled
21 |         show_axes: boolean indicating if the axes of the plots should be visible
22 | 
23 |     Returns:
24 |         None
25 |     """
26 |     if (rows is None) != (cols is None):
27 |         raise ValueError("Specify either both rows and cols or neither.")
28 | 
29 |     if rows is None:
30 |         rows = len(images)
31 |         cols = 1
32 | 
33 |     gridspec_kw = {"wspace": 0.0, "hspace": 0.0} if fill else {}
34 |     fig, axarr = plt.subplots(rows, cols, gridspec_kw=gridspec_kw, figsize=(15, 9))
35 | 
36 |     for ax, im in zip(axarr.ravel(), images):
37 |         # only render RGB channels
38 |         ax.imshow(im[..., :3])
39 |         if not show_axes:
40 |             ax.set_axis_off()
41 | 
42 | class Vis():
43 |     """
44 |     Visualization module
45 |     Saves grids of dataset images together with their ground-truth and predicted segmentation masks.
46 | 
47 |     Args:
48 |         dst: train or validation dataset
49 |         expt_logdir: path to store the visualized images
50 |         rows: number of rows in the visualization grid
51 |         cols: number of columns in the visualization grid
52 | 
53 | 
54 |     Returns:
55 |         None
56 |     """
57 |     def __init__(self, dst, expt_logdir, rows, cols):
58 | 
59 |         self.dst = dst
60 |         self.expt_logdir = expt_logdir
61 |         self.rows = rows
62 |         self.cols = cols
63 |         self.images = []
64 |         self.images_vis = []
65 |         self.labels_vis = []
66 |         image_ids = np.random.randint(len(dst), size=rows*cols)
67 | 
68 |         for image_id in image_ids:
69 |             image, label = dst[image_id][0], dst[image_id][1]
70 |             image = image[None, ...]
71 |             self.images.append(image)
72 | 
73 |             image = torch.squeeze(image)
74 |             image = image * std[:, None, None] + mean[:, None, None] #undo the input normalization for display
75 |             image = torch.movedim(image, 0, -1) # (3,H,W) to (H,W,3)
76 |             image = image.cpu().numpy()
77 |             self.images_vis.append(image)
78 | 
79 |             label = label.cpu().numpy()
80 |             label = dst.decode_segmap(label)
81 |             self.labels_vis.append(label)
82 | 
83 |         self.images = torch.cat(self.images, axis=0)
84 | 
85 |     def visualize(self, epoch, model):
86 | 
87 |         with torch.no_grad(): #inference only, no gradients needed
88 |             prediction = model(self.images) #the DataParallel wrapper scatters this CPU batch onto the GPUs
89 |         prediction = torch.argmax(prediction, dim=1)
90 |         prediction = prediction.cpu().numpy()
91 | 
92 |         rgb_vis = []
93 |         for image, label, pred in zip(self.images_vis, self.labels_vis, prediction):
94 |             pred = self.dst.decode_segmap(pred)
95 |             rgb_vis.extend([image, label, pred])
96 |         rgb_vis = np.array(rgb_vis)
97 | 
98 |         image_grid(rgb_vis, rows=self.rows, cols=3*self.cols)
99 |         plt.savefig(os.path.join(self.expt_logdir, 'seg_{}_{}.png'.format(self.dst.split, epoch))) #example file name: seg_val_0.png
--------------------------------------------------------------------------------