├── README.md ├── Unet-Resnet34-Open Solution-FastAI.ipynb ├── lovasz_losses.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # TGS-Open-Solution-Fastai 2 | 3 | ## Kaggle Score: 0.829 4 | 5 | Main Architecture: Unet with ResNet34 encoder. 6 | 7 | Training: 8 | 9 | - RandomFlip augmentation 10 | - Batch sizes: 64, 128, 256 with cycle length = 30 11 | 12 | Cross Validation: 13 | 14 | - 5 folds 15 | - Test time augmentation with horizontal flip 16 | 17 | Final Submission: 18 | 19 | - Model Averaging of 5 models 20 | 21 | Getting Started? 22 | You can check [fast.ai lesson 14](http://course.fast.ai/lessons/lesson14.html), where Jeremy Howard, an amazing mentor and teacher, shows how to solve a similar problem. 23 | -------------------------------------------------------------------------------- /Unet-Resnet34-Open Solution-FastAI.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "heading_collapsed": true 7 | }, 8 | "source": [ 9 | "## Imports" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": { 16 | "hidden": true, 17 | "scrolled": true 18 | }, 19 | "outputs": [], 20 | "source": [ 21 | "%matplotlib inline\n", 22 | "%reload_ext autoreload\n", 23 | "%autoreload 2\n", 24 | "from fastai.conv_learner import *\n", 25 | "from fastai.dataset import *\n", 26 | "from fastai.models.resnet import vgg_resnet50\n", 27 | "from fastai.models.senet import *\n", 28 | "from skimage.transform import resize\n", 29 | "import json\n", 30 | "from sklearn.model_selection import train_test_split, StratifiedKFold , KFold\n", 31 | "from sklearn.metrics import jaccard_similarity_score\n", 32 | "from pycocotools import mask as cocomask\n", 33 | "from utils import my_eval,intersection_over_union_thresholds,RLenc\n", 34 | "from lovasz_losses import lovasz_hinge\n", 35 |
"print(torch.__version__)\n", 36 | "torch.cuda.is_available()\n", 37 | "torch.backends.cudnn.benchmark=True" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "metadata": { 43 | "heading_collapsed": true 44 | }, 45 | "source": [ 46 | "## Paths" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": 2, 52 | "metadata": { 53 | "hidden": true 54 | }, 55 | "outputs": [], 56 | "source": [ 57 | "PATH = Path('data/tgs/')\n", 58 | "TRN_MASKS = 'trn_masks'\n", 59 | "TRN_IMG = 'trn_images'\n", 60 | "TRN_MSK = 'trn_masks'\n", 61 | "TST_IMG = 'tst_images'\n", 62 | "trn = pd.read_csv(PATH/'train.csv')\n", 63 | "dpth = pd.read_csv(PATH/'depths.csv')" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": 3, 69 | "metadata": { 70 | "hidden": true 71 | }, 72 | "outputs": [], 73 | "source": [ 74 | "def show_img(im, figsize=None, ax=None, alpha=None):\n", 75 | " if not ax: fig,ax = plt.subplots(figsize=figsize)\n", 76 | " ax.imshow(im, alpha=alpha)\n", 77 | " ax.set_axis_off()\n", 78 | " return ax" 79 | ] 80 | }, 81 | { 82 | "cell_type": "markdown", 83 | "metadata": { 84 | "heading_collapsed": true 85 | }, 86 | "source": [ 87 | "## Datasets" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": 4, 93 | "metadata": { 94 | "hidden": true 95 | }, 96 | "outputs": [], 97 | "source": [ 98 | "class DepthDataset(Dataset):\n", 99 | " def __init__(self,ds,dpth_dict):\n", 100 | " self.dpth = dpth_dict\n", 101 | " self.ds = ds\n", 102 | " \n", 103 | " def __getitem__(self,i):\n", 104 | " val = self.ds[i]\n", 105 | " return val[0],self.dpth[self.ds.fnames[i].split('/')[1][:-4]],val[1]\n", 106 | " \n", 107 | "class MatchedFilesDataset(FilesDataset):\n", 108 | " def __init__(self, fnames, y, transform, path):\n", 109 | " self.y=y\n", 110 | " assert(len(fnames)==len(y))\n", 111 | " super().__init__(fnames, transform, path)\n", 112 | " \n", 113 | " def get_x(self, i): \n", 114 | " return open_image(os.path.join(self.path, 
self.fnames[i]))\n", 115 | " \n", 116 | " def get_y(self, i):\n", 117 | " return open_image(os.path.join(str(self.path), str(self.y[i])))\n", 118 | "\n", 119 | " def get_c(self): return 0\n", 120 | " \n", 121 | "class TestFilesDataset(FilesDataset):\n", 122 | " def __init__(self, fnames, y, transform,flip, path):\n", 123 | " self.y=y\n", 124 | " self.flip = flip\n", 125 | " super().__init__(fnames, transform, path)\n", 126 | " \n", 127 | " def get_x(self, i): \n", 128 | " im = open_image(os.path.join(self.path, self.fnames[i]))\n", 129 | " return np.fliplr(im) if self.flip else im\n", 130 | " \n", 131 | " def get_y(self, i):\n", 132 | " im = open_image(os.path.join(str(self.path), str(self.y[i])))\n", 133 | " return np.fliplr(im) if self.flip else im\n", 134 | " def get_c(self): return 0" 135 | ] 136 | }, 137 | { 138 | "cell_type": "markdown", 139 | "metadata": { 140 | "heading_collapsed": true 141 | }, 142 | "source": [ 143 | "## Creating K-Fold" 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": 5, 149 | "metadata": { 150 | "hidden": true 151 | }, 152 | "outputs": [], 153 | "source": [ 154 | "x_names = np.array([f'{TRN_IMG}/{o.name}' for o in (PATH/TRN_MASKS).iterdir()])\n", 155 | "y_names = np.array([f'{TRN_MASKS}/{o.name}' for o in (PATH/TRN_MASKS).iterdir()])\n", 156 | "tst_x = np.array([f'{TST_IMG}/{o.name}' for o in (PATH/TST_IMG).iterdir()])\n", 157 | "f_name = [o.split('/')[-1] for o in x_names]\n", 158 | "\n", 159 | "c = dpth.set_index('id')\n", 160 | "dpth_dict = c['z'].to_dict()\n", 161 | "\n", 162 | "kf = 5\n", 163 | "kfold = KFold(n_splits=kf, shuffle=True, random_state=42)\n", 164 | "\n", 165 | "train_folds = []\n", 166 | "val_folds = []\n", 167 | "for idxs in kfold.split(f_name):\n", 168 | " train_folds.append([f_name[idx] for idx in idxs[0]])\n", 169 | " val_folds.append([f_name[idx] for idx in idxs[1]])\n", 170 | " " 171 | ] 172 | }, 173 | { 174 | "cell_type": "code", 175 | "execution_count": 6, 176 | "metadata": { 177 | 
"hidden": true 178 | }, 179 | "outputs": [], 180 | "source": [ 181 | "train_folds = pickle.load(open('train_folds.pkl',mode='rb'))\n", 182 | "val_folds = pickle.load(open('val_folds.pkl',mode='rb'))\n", 183 | "tst_x = pickle.load(open('tst_x.pkl',mode='rb'))\n" 184 | ] 185 | }, 186 | { 187 | "cell_type": "markdown", 188 | "metadata": { 189 | "heading_collapsed": true 190 | }, 191 | "source": [ 192 | "## Unet Model" 193 | ] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "execution_count": 7, 198 | "metadata": { 199 | "hidden": true 200 | }, 201 | "outputs": [], 202 | "source": [ 203 | "class SaveFeatures():\n", 204 | " features=None\n", 205 | " def __init__(self, m): self.hook = m.register_forward_hook(self.hook_fn)\n", 206 | " def hook_fn(self, module, input, output): self.features = output\n", 207 | " def remove(self): self.hook.remove()\n", 208 | " \n", 209 | "class UnetBlock(nn.Module):\n", 210 | " def __init__(self, up_in, x_in, n_out):\n", 211 | " super().__init__()\n", 212 | " up_out = x_out = n_out//2\n", 213 | " self.x_conv = nn.Conv2d(x_in, x_out, 1)\n", 214 | " self.tr_conv = nn.ConvTranspose2d(up_in, up_out, 2, stride=2)\n", 215 | " self.bn = nn.BatchNorm2d(n_out)\n", 216 | " \n", 217 | " def forward(self, up_p, x_p):\n", 218 | " up_p = self.tr_conv(up_p)\n", 219 | " x_p = self.x_conv(x_p)\n", 220 | " cat_p = torch.cat([up_p,x_p], dim=1)\n", 221 | " return self.bn(F.relu(cat_p))\n", 222 | " \n", 223 | "class Unet34(nn.Module):\n", 224 | " def __init__(self, rn):\n", 225 | " super().__init__()\n", 226 | " self.rn = rn\n", 227 | " self.sfs = [SaveFeatures(rn[i]) for i in [2,4,5,6]]\n", 228 | " self.up1 = UnetBlock(512,256,128)\n", 229 | " self.up2 = UnetBlock(128,128,128)\n", 230 | " self.up3 = UnetBlock(128,64,128)\n", 231 | " self.up4 = UnetBlock(128,64,128)\n", 232 | " self.up5 = nn.ConvTranspose2d(128, 1, 2, stride=2)\n", 233 | " \n", 234 | " def forward(self,img,depth):\n", 235 | " x = F.relu(self.rn(img))\n", 236 | " x = self.up1(x, 
self.sfs[3].features)\n", 237 | " x = self.up2(x, self.sfs[2].features)\n", 238 | " x = self.up3(x, self.sfs[1].features)\n", 239 | " x = self.up4(x, self.sfs[0].features)\n", 240 | " x = self.up5(x)\n", 241 | " return x[:,0]\n", 242 | " \n", 243 | " def close(self):\n", 244 | " for sf in self.sfs: sf.remove()\n", 245 | "\n", 246 | "\n", 247 | "class UnetModel():\n", 248 | " def __init__(self,model,lr_cut,name='unet'):\n", 249 | " self.model,self.name = model,name\n", 250 | " self.lr_cut = lr_cut\n", 251 | "\n", 252 | " def get_layer_groups(self, precompute):\n", 253 | " lgs = list(split_by_idxs(children(self.model.rn), [self.lr_cut]))\n", 254 | " return lgs + [children(self.model)[1:]]\n", 255 | " \n", 256 | " " 257 | ] 258 | }, 259 | { 260 | "cell_type": "code", 261 | "execution_count": 8, 262 | "metadata": { 263 | "hidden": true 264 | }, 265 | "outputs": [], 266 | "source": [ 267 | "def get_tgs_model():\n", 268 | " f = resnet34\n", 269 | " cut,lr_cut = model_meta[f]\n", 270 | " m_base = get_base(f,cut)\n", 271 | " m = to_gpu(Unet34(m_base))\n", 272 | " models = UnetModel(m,lr_cut)\n", 273 | " learn = ConvLearner(md, models)\n", 274 | " return learn\n", 275 | "\n", 276 | "def get_base(f,cut):\n", 277 | " layers = cut_model(f(True), cut)\n", 278 | " return nn.Sequential(*layers) \n" 279 | ] 280 | }, 281 | { 282 | "cell_type": "markdown", 283 | "metadata": { 284 | "heading_collapsed": true 285 | }, 286 | "source": [ 287 | "## Training loop" 288 | ] 289 | }, 290 | { 291 | "cell_type": "code", 292 | "execution_count": 12, 293 | "metadata": { 294 | "hidden": true, 295 | "scrolled": true 296 | }, 297 | "outputs": [ 298 | { 299 | "name": "stdout", 300 | "output_type": "stream", 301 | "text": [ 302 | "[256, 64]\n", 303 | "fold_id0\n", 304 | "5_fold_simple_resnet34_0\n" 305 | ] 306 | }, 307 | { 308 | "data": { 309 | "application/vnd.jupyter.widget-view+json": { 310 | "model_id": "73bf451b359c4808be097145564a500b", 311 | "version_major": 2, 312 | "version_minor": 0 313 | 
}, 314 | "text/plain": [ 315 | "HBox(children=(IntProgress(value=0, description='Epoch', max=30), HTML(value='')))" 316 | ] 317 | }, 318 | "metadata": {}, 319 | "output_type": "display_data" 320 | }, 321 | { 322 | "name": "stdout", 323 | "output_type": "stream", 324 | "text": [ 325 | "epoch trn_loss val_loss my_eval \n", 326 | " 0 1.325157 0.93006 0.765038 \n", 327 | " 1 1.03828 0.913555 0.778321 \n", 328 | " 2 0.939755 0.824341 0.801629 \n", 329 | " 3 0.88104 0.808754 0.807644 \n", 330 | " 4 0.852355 0.849393 0.798747 \n", 331 | " 5 0.825028 0.805423 0.806266 \n", 332 | " 6 0.799593 0.807329 0.808271 \n", 333 | " 7 0.753976 0.780195 0.813283 \n", 334 | " 8 0.731826 0.771001 0.813033 \n", 335 | " 9 0.702984 0.773757 0.815539 \n", 336 | " 10 0.702384 0.875103 0.803258 \n", 337 | " 11 0.799662 0.888119 0.784085 \n", 338 | " 12 0.833936 0.828985 0.801754 \n", 339 | " 13 0.817738 0.80451 0.80802 \n", 340 | " 14 0.780242 0.826695 0.796491 \n", 341 | " 15 0.720219 0.791626 0.816541 \n", 342 | " 16 0.694303 0.796775 0.809398 \n", 343 | " 17 0.671596 0.772202 0.817544 \n", 344 | " 18 0.639161 0.788361 0.821053 \n", 345 | " 19 0.621297 0.773167 0.819799 \n", 346 | " 20 0.620868 0.905281 0.780952 \n", 347 | " 21 0.702979 0.856898 0.80401 \n", 348 | " 22 0.737761 0.812372 0.809524 \n", 349 | " 23 0.736216 0.825973 0.808271 \n", 350 | " 24 0.735905 0.814299 0.815915 \n", 351 | " 25 0.680798 0.77566 0.813033 \n", 352 | " 26 0.637793 0.781759 0.820426 \n", 353 | " 27 0.604662 0.776221 0.817544 \n", 354 | " 28 0.575798 0.775428 0.817794 \n", 355 | " 29 0.54858 0.78509 0.82005 \n", 356 | "\n", 357 | "0.8210526315789473\n", 358 | "fold_id1\n", 359 | "5_fold_simple_resnet34_1\n" 360 | ] 361 | }, 362 | { 363 | "data": { 364 | "application/vnd.jupyter.widget-view+json": { 365 | "model_id": "f5bad59687504bd8ba5ce0dae7b6b194", 366 | "version_major": 2, 367 | "version_minor": 0 368 | }, 369 | "text/plain": [ 370 | "HBox(children=(IntProgress(value=0, description='Epoch', max=30), 
HTML(value='')))" 371 | ] 372 | }, 373 | "metadata": {}, 374 | "output_type": "display_data" 375 | }, 376 | { 377 | "name": "stdout", 378 | "output_type": "stream", 379 | "text": [ 380 | "epoch trn_loss val_loss my_eval \n", 381 | " 0 1.348665 1.033372 0.747494 \n", 382 | " 1 1.078376 0.959368 0.769674 \n", 383 | " 2 0.967243 0.929143 0.750376 \n", 384 | " 3 0.907108 0.909645 0.76391 \n", 385 | " 4 0.845982 0.840461 0.79198 \n", 386 | " 5 0.822131 0.858117 0.775689 \n", 387 | " 6 0.80265 0.841035 0.784712 \n", 388 | " 7 0.736551 0.830294 0.793233 \n", 389 | " 8 0.71209 0.826889 0.788722 \n", 390 | " 9 0.68642 0.818437 0.792732 \n", 391 | " 10 0.67733 0.87536 0.769298 \n", 392 | " 11 0.768952 0.91329 0.77619 \n", 393 | " 12 0.772509 0.873983 0.787845 \n", 394 | " 13 0.76916 0.835945 0.802256 \n", 395 | " 14 0.733515 0.822953 0.781579 \n", 396 | " 15 0.687887 0.840586 0.768045 \n", 397 | " 16 0.653878 0.812393 0.787845 \n", 398 | " 17 0.622058 0.802763 0.799123 \n", 399 | " 18 0.590854 0.782723 0.800627 \n", 400 | " 19 0.564545 0.780446 0.805138 \n", 401 | " 20 0.584671 0.92863 0.771679 \n", 402 | " 21 0.711783 0.856759 0.791729 \n", 403 | " 22 0.704852 0.854031 0.782456 \n", 404 | " 23 0.687806 0.917681 0.760526 \n", 405 | " 24 0.681609 0.808088 0.796115 \n", 406 | " 25 0.621207 0.820036 0.793108 \n", 407 | " 26 0.59361 0.843446 0.791855 \n", 408 | " 27 0.558614 0.78089 0.805388 \n", 409 | " 28 0.526056 0.795237 0.813409 \n", 410 | " 29 0.487891 0.759982 0.806767 \n", 411 | "\n", 412 | "0.813408521303258\n", 413 | "fold_id2\n", 414 | "5_fold_simple_resnet34_2\n" 415 | ] 416 | }, 417 | { 418 | "data": { 419 | "application/vnd.jupyter.widget-view+json": { 420 | "model_id": "9fea58c3fd9a4c46ad2201f7173c0250", 421 | "version_major": 2, 422 | "version_minor": 0 423 | }, 424 | "text/plain": [ 425 | "HBox(children=(IntProgress(value=0, description='Epoch', max=30), HTML(value='')))" 426 | ] 427 | }, 428 | "metadata": {}, 429 | "output_type": "display_data" 430 | }, 431 | { 
432 | "name": "stdout", 433 | "output_type": "stream", 434 | "text": [ 435 | "epoch trn_loss val_loss my_eval \n", 436 | " 0 1.284321 1.039837 0.740226 \n", 437 | " 1 1.036117 0.917657 0.779073 \n", 438 | " 2 0.939302 0.862901 0.787469 \n", 439 | " 3 0.888223 0.843559 0.79198 \n", 440 | " 4 0.832163 0.804011 0.803383 \n", 441 | " 5 0.787491 0.81721 0.802256 \n", 442 | " 6 0.726067 0.760629 0.810902 \n", 443 | " 7 0.690015 0.747631 0.817794 \n", 444 | " 8 0.643923 0.724834 0.822306 \n", 445 | " 9 0.597543 0.736826 0.812281 \n", 446 | " 10 0.610031 0.85177 0.80614 \n", 447 | " 11 0.684985 0.956533 0.777569 \n", 448 | " 12 0.749024 0.841364 0.79787 \n", 449 | " 13 0.745144 0.886766 0.792857 \n", 450 | " 14 0.702938 0.766046 0.815539 \n", 451 | " 15 0.645297 0.744435 0.816416 \n", 452 | " 16 0.591229 0.78739 0.818922 \n", 453 | " 17 0.551063 0.742451 0.820677 \n", 454 | " 18 0.500171 0.755703 0.8099 \n", 455 | " 19 0.462674 0.748881 0.813659 \n", 456 | " 20 0.488585 0.876356 0.803133 \n", 457 | " 21 0.56171 0.868642 0.809398 \n", 458 | " 22 0.602444 0.842901 0.802632 \n", 459 | " 23 0.589505 0.927599 0.77193 \n", 460 | " 24 0.580468 0.85071 0.809649 \n", 461 | " 25 0.539967 0.775132 0.819048 \n", 462 | " 26 0.489004 0.795378 0.816667 \n", 463 | " 27 0.456576 0.778352 0.819674 \n", 464 | " 28 0.40038 0.766835 0.813534 \n", 465 | " 29 0.368988 0.795965 0.818421 \n", 466 | "\n", 467 | "0.8223057644110277\n", 468 | "fold_id3\n", 469 | "5_fold_simple_resnet34_3\n" 470 | ] 471 | }, 472 | { 473 | "data": { 474 | "application/vnd.jupyter.widget-view+json": { 475 | "model_id": "3229e9889bac41528d18ac2bcfdfcd81", 476 | "version_major": 2, 477 | "version_minor": 0 478 | }, 479 | "text/plain": [ 480 | "HBox(children=(IntProgress(value=0, description='Epoch', max=30), HTML(value='')))" 481 | ] 482 | }, 483 | "metadata": {}, 484 | "output_type": "display_data" 485 | }, 486 | { 487 | "name": "stdout", 488 | "output_type": "stream", 489 | "text": [ 490 | "epoch trn_loss val_loss 
my_eval \n", 491 | " 0 1.372224 1.006186 0.755013 \n", 492 | " 1 1.05301 0.937049 0.776817 \n", 493 | " 2 0.92111 0.897983 0.779323 \n", 494 | " 3 0.855703 0.91295 0.77594 \n", 495 | " 4 0.807139 0.875474 0.789474 \n", 496 | " 5 0.757861 0.825887 0.796491 \n", 497 | " 6 0.715815 0.825675 0.802005 \n", 498 | " 7 0.668804 0.803978 0.805138 \n", 499 | " 8 0.633585 0.779532 0.806266 \n", 500 | " 9 0.598439 0.775525 0.812406 \n", 501 | " 10 0.595149 0.944765 0.779699 \n", 502 | " 11 0.708947 1.008999 0.752005 \n", 503 | " 12 0.73316 0.865063 0.790727 \n", 504 | " 13 0.717747 0.815308 0.809273 \n", 505 | " 14 0.681317 0.86026 0.795489 \n", 506 | " 15 0.622366 0.784242 0.808647 \n", 507 | " 16 0.590551 0.795878 0.811404 \n", 508 | " 17 0.547718 0.788322 0.814411 \n", 509 | " 18 0.510185 0.791338 0.812155 \n", 510 | " 19 0.481769 0.784741 0.820175 \n", 511 | " 20 0.482904 0.940439 0.796366 \n", 512 | " 21 0.541447 0.916361 0.790476 \n", 513 | " 22 0.592643 0.869368 0.799248 \n", 514 | " 23 0.604125 0.884436 0.795363 \n", 515 | " 24 0.569964 0.86229 0.805514 \n", 516 | " 25 0.531787 0.899627 0.805013 \n", 517 | " 26 0.50264 0.862456 0.806767 \n", 518 | " 27 0.470732 0.825488 0.806015 \n", 519 | " 28 0.438868 0.843719 0.824185 \n", 520 | " 29 0.406581 0.835985 0.816667 \n", 521 | "\n", 522 | "0.8241854636591479\n", 523 | "fold_id4\n", 524 | "5_fold_simple_resnet34_4\n" 525 | ] 526 | }, 527 | { 528 | "data": { 529 | "application/vnd.jupyter.widget-view+json": { 530 | "model_id": "0b56e1d8fde146138f4f3b16098fbb4a", 531 | "version_major": 2, 532 | "version_minor": 0 533 | }, 534 | "text/plain": [ 535 | "HBox(children=(IntProgress(value=0, description='Epoch', max=30), HTML(value='')))" 536 | ] 537 | }, 538 | "metadata": {}, 539 | "output_type": "display_data" 540 | }, 541 | { 542 | "name": "stdout", 543 | "output_type": "stream", 544 | "text": [ 545 | "epoch trn_loss val_loss my_eval \n", 546 | " 0 1.361298 1.10476 0.728356 \n", 547 | " 1 1.036111 1.038042 0.748934 \n", 548 | " 
2 0.891879 0.983283 0.768632 \n", 549 | " 3 0.834489 1.03786 0.754078 \n", 550 | " 4 0.785943 0.919298 0.777039 \n", 551 | " 5 0.729437 0.951818 0.775784 \n", 552 | " 6 0.687682 0.937895 0.784191 \n", 553 | " 7 0.658635 0.902859 0.78143 \n", 554 | " 8 0.601494 0.886 0.796738 \n", 555 | " 9 0.571334 0.881849 0.792974 \n", 556 | " 10 0.575541 0.916668 0.784065 \n", 557 | " 11 0.667784 1.034864 0.754078 \n", 558 | " 12 0.690661 1.014947 0.768507 \n", 559 | " 13 0.674954 0.988458 0.767378 \n", 560 | " 14 0.627315 0.892071 0.783438 \n", 561 | " 15 0.58735 0.920014 0.783689 \n", 562 | " 16 0.553768 0.931009 0.787202 \n", 563 | " 17 0.498617 0.918791 0.794103 \n", 564 | " 18 0.470062 0.88031 0.799373 \n", 565 | " 19 0.435687 0.879542 0.804141 \n", 566 | " 20 0.444049 0.994929 0.784191 \n", 567 | " 21 0.557501 1.043981 0.758093 \n", 568 | " 22 0.586135 1.058152 0.76675 \n", 569 | " 23 0.566919 0.973012 0.78005 \n", 570 | " 24 0.53908 0.920541 0.791217 \n", 571 | " 25 0.479652 0.948885 0.784316 \n", 572 | " 26 0.433641 1.033326 0.789335 \n", 573 | " 27 0.402444 0.959513 0.793601 \n", 574 | " 28 0.374302 0.965276 0.793099 \n", 575 | " 29 0.346213 0.96585 0.793852 \n", 576 | "\n", 577 | "0.8041405269761606\n" 578 | ] 579 | } 580 | ], 581 | "source": [ 582 | "model = 'simple_resnet34'\n", 583 | "bst_acc=[]\n", 584 | "use_clr_min=20\n", 585 | "use_clr_div=10\n", 586 | "aug_tfms = [RandomFlip(tfm_y=TfmType.CLASS)]\n", 587 | "\n", 588 | "szs = [(256,64)]\n", 589 | "for sz,bs in szs:\n", 590 | " print([sz,bs])\n", 591 | " for i in range(kf) :\n", 592 | " print(f'fold_id{i}')\n", 593 | " \n", 594 | " trn_x = np.array([f'trn_images/{o}' for o in train_folds[i]])\n", 595 | " trn_y = np.array([f'trn_masks/{o}' for o in train_folds[i]])\n", 596 | " val_x = [f'trn_images/{o}' for o in val_folds[i]]\n", 597 | " val_y = [f'trn_masks/{o}' for o in val_folds[i]]\n", 598 | " \n", 599 | " tfms = tfms_from_model(resnet34, sz=sz, pad=0,crop_type=CropType.NO, tfm_y=TfmType.CLASS, 
aug_tfms=aug_tfms)\n", 600 | " datasets = ImageData.get_ds(MatchedFilesDataset, (trn_x,trn_y), (val_x,val_y), tfms,test=tst_x,path=PATH)\n", 601 | " md = ImageData(PATH, datasets, bs, num_workers=16, classes=None)\n", 602 | " denorm = md.trn_ds.denorm\n", 603 | " md.trn_dl.dataset = DepthDataset(md.trn_ds,dpth_dict)\n", 604 | " md.val_dl.dataset = DepthDataset(md.val_ds,dpth_dict)\n", 605 | " md.test_dl.dataset = DepthDataset(md.test_ds,dpth_dict)\n", 606 | " learn = get_tgs_model() \n", 607 | " learn.opt_fn = optim.Adam\n", 608 | " learn.metrics=[my_eval]\n", 609 | " pa = f'{kf}_fold_{model}_{i}'\n", 610 | " print(pa)\n", 611 | "\n", 612 | " lr=1e-2\n", 613 | " wd=1e-7\n", 614 | " lrs = np.array([lr/100,lr/10,lr])\n", 615 | "\n", 616 | " learn.unfreeze() \n", 617 | " learn.crit = lovasz_hinge\n", 618 | " learn.load(pa)\n", 619 | " learn.fit(lrs/2,3, wds=wd, cycle_len=10,use_clr=(20,8),best_save_name=pa)\n", 620 | "\n", 621 | "\n", 622 | " \n", 623 | " learn.load(pa) \n", 624 | " #Calcuating mean iou score\n", 625 | " v_targ = md.val_ds.ds[:][1]\n", 626 | " v_preds = np.zeros((len(v_targ),sz,sz)) \n", 627 | " v_pred = learn.predict()\n", 628 | " v_pred = to_np(torch.sigmoid(torch.from_numpy(v_pred)))\n", 629 | " p = ((v_pred)>0.5).astype(np.uint8)\n", 630 | " bst_acc.append(intersection_over_union_thresholds(v_targ,p))\n", 631 | " print(bst_acc[-1])" 632 | ] 633 | }, 634 | { 635 | "cell_type": "code", 636 | "execution_count": 13, 637 | "metadata": { 638 | "hidden": true 639 | }, 640 | "outputs": [ 641 | { 642 | "data": { 643 | "text/plain": [ 644 | "([0.8210526315789473,\n", 645 | " 0.813408521303258,\n", 646 | " 0.8223057644110277,\n", 647 | " 0.8241854636591479,\n", 648 | " 0.8041405269761606],\n", 649 | " 0.8170185815857083)" 650 | ] 651 | }, 652 | "execution_count": 13, 653 | "metadata": {}, 654 | "output_type": "execute_result" 655 | } 656 | ], 657 | "source": [ 658 | "bst_acc,np.mean(bst_acc)#With 256" 659 | ] 660 | }, 661 | { 662 | "cell_type": "code", 663 | 
"execution_count": 11, 664 | "metadata": { 665 | "hidden": true, 666 | "scrolled": true 667 | }, 668 | "outputs": [ 669 | { 670 | "data": { 671 | "text/plain": [ 672 | "([0.818922305764411,\n", 673 | " 0.793734335839599,\n", 674 | " 0.8147869674185464,\n", 675 | " 0.8031328320802005,\n", 676 | " 0.7912170639899623],\n", 677 | " 0.8043587010185437)" 678 | ] 679 | }, 680 | "execution_count": 11, 681 | "metadata": {}, 682 | "output_type": "execute_result" 683 | } 684 | ], 685 | "source": [ 686 | "bst_acc,np.mean(bst_acc) #With 128" 687 | ] 688 | }, 689 | { 690 | "cell_type": "markdown", 691 | "metadata": { 692 | "heading_collapsed": true 693 | }, 694 | "source": [ 695 | "## Submission - TTA" 696 | ] 697 | }, 698 | { 699 | "cell_type": "code", 700 | "execution_count": 14, 701 | "metadata": { 702 | "hidden": true 703 | }, 704 | "outputs": [ 705 | { 706 | "data": { 707 | "application/vnd.jupyter.widget-view+json": { 708 | "model_id": "45d11eaa2f124312897f36ec705e9fe6", 709 | "version_major": 2, 710 | "version_minor": 0 711 | }, 712 | "text/plain": [ 713 | "HBox(children=(IntProgress(value=0, max=5), HTML(value='')))" 714 | ] 715 | }, 716 | "metadata": {}, 717 | "output_type": "display_data" 718 | }, 719 | { 720 | "name": "stdout", 721 | "output_type": "stream", 722 | "text": [ 723 | "5_fold_simple_resnet34_0\n", 724 | "5_fold_simple_resnet34_1\n", 725 | "5_fold_simple_resnet34_2\n", 726 | "5_fold_simple_resnet34_3\n", 727 | "5_fold_simple_resnet34_4\n", 728 | "\n" 729 | ] 730 | }, 731 | { 732 | "data": { 733 | "application/vnd.jupyter.widget-view+json": { 734 | "model_id": "a904b594ed7a4238bedc9f73164937ee", 735 | "version_major": 2, 736 | "version_minor": 0 737 | }, 738 | "text/plain": [ 739 | "HBox(children=(IntProgress(value=0, max=5), HTML(value='')))" 740 | ] 741 | }, 742 | "metadata": {}, 743 | "output_type": "display_data" 744 | }, 745 | { 746 | "name": "stdout", 747 | "output_type": "stream", 748 | "text": [ 749 | "5_fold_simple_resnet34_0\n", 750 | 
"5_fold_simple_resnet34_1\n", 751 | "5_fold_simple_resnet34_2\n", 752 | "5_fold_simple_resnet34_3\n", 753 | "5_fold_simple_resnet34_4\n", 754 | "\n" 755 | ] 756 | } 757 | ], 758 | "source": [ 759 | "preds = np.zeros(shape = (18000,sz,sz))\n", 760 | "for o in [True,False]:\n", 761 | " md.test_dl.dataset = TestFilesDataset(tst_x,tst_x,tfms[1],flip=o,path=PATH)\n", 762 | " md.test_dl.dataset = DepthDataset(md.test_dl.dataset,dpth_dict)\n", 763 | " \n", 764 | " for i in tqdm_notebook(range(kf)):\n", 765 | " pa = f'{kf}_fold_{model}_{i}'\n", 766 | " print(pa)\n", 767 | " learn.load(pa)\n", 768 | " pred = learn.predict(is_test=True)\n", 769 | " pred = to_np(torch.sigmoid(torch.from_numpy(pred))) \n", 770 | " for im_idx,im in enumerate(pred):\n", 771 | " preds[im_idx] += np.fliplr(im) if o else im\n", 772 | " del pred" 773 | ] 774 | }, 775 | { 776 | "cell_type": "code", 777 | "execution_count": 15, 778 | "metadata": { 779 | "hidden": true 780 | }, 781 | "outputs": [ 782 | { 783 | "data": { 784 | "text/plain": [ 785 | "" 786 | ] 787 | }, 788 | "execution_count": 15, 789 | "metadata": {}, 790 | "output_type": "execute_result" 791 | }, 792 | { 793 | "data": { 794 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAQYAAAD8CAYAAACVSwr3AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAADY9JREFUeJzt3V2sXWWdx/HvbwoyEZlIeUspnQFNJxm8sJITJGFinJAR6E3xwglcaGNI6gUkmjgXVS/k0pmMmpDMkNRIrBMHhvgSesGMYmNi5kKgGCxvA1Rk5NiG+hYkY4JQ/3Nx1hk2ffbh7J6z19lrt99PsrPWfvaz9v53ted3nmetvVZTVUjSqD+ZdQGShsdgkNQwGCQ1DAZJDYNBUsNgkNToLRiS3JDkmSRHkuzt63MkTV/6+B5Dkk3As8DfAovAI8AtVfXU1D9M0tT1NWK4GjhSVc9X1R+Ae4FdPX2WpCk7q6f33Qq8OPJ8EXj/Sp0v3LypLt92dk+lnLmePfz2WZegAXmF3/6qqi6apG9fwZAxbW+asyTZA+wB+POtZ/Hwd7f1VMqZ6/pLd8y6BA3I9+ub/zNp376mEovA6E/6ZcDR0Q5Vta+qFqpq4aILNvVUxpnLUNB69BUMjwDbk1yR5G3AzcCBnj5L0pT1MpWoqteT3A58F9gE3F1VT/bxWZKmr69jDFTVA8ADfb2/VuY0QuvlNx8lNQyG04yjBU2DwSCpYTBIahgMpxGnEZoWg0FSw2CQ1DAYJDUMBkkNg0FSw2CQ1DAYJDUMBkkNg0FSw2CQ1DAYJDUMBkkNg0FSw2A4TXhlpabJYJDUMBhOA44WNG0Gw5wzFNQHg0FSw2CYY44W1BeDQVLDYJDUMBjmlNMI9clgkNQwGCQ1DAZJDYNBUsNgkNQwGCQ1DIY55KlK9e2s9Wyc5AXgFeAE8HpVLSTZDPw7cDnwAvB3VfXb9ZWpZYaCNsI0Rgx/U1U7qmqhe74XOFhV24GD3XNNgaGgjdLHVGIXsL9b3w/c1MNnSOrReoOhgO8leTTJnq7tkqo6BtAtLx63YZI9SQ4lOfTLX59YZxmSpmldxxiAa6vqaJKLgQeT/PekG1bVPmAfwMJ7/7TWWcdpz2mENtK6RgxVdbRbHge+A1wNvJRkC0C3PL7eIiVtrDUHQ5Jzk5y3vA58CHgCOADs7rrtBu5fb5GSNtZ6phKXAN9Jsvw+/1ZV/5nkEeC+JLcCPwc+sv4yJW2kNQdDVT0PvHdM+6+B69ZTlKTZ8puPc8ADj9poBoOkhsEgqWEwSGoYDJIaBoOkhsEgqWEwSGoYDJIaBsPA+eUmzYLBIKlhMAyYowXNisEgqWEwSGoYDAPlNEKzZDBIahgMA+RoQbNmMEhqGAySGgaDpIbBIKlhMAyMBx41BAaDpIbBMCCOFjQUBoOkhsEgqWEwDITTCA2JwSCpYTBIaqz5f7vWdDiF0BA5YpDUMBhmyNGChspgmBFDQUO2ajAkuTvJ8SRPjLRtTvJgkue65flde5LcmeRIksNJruqzeEn9mGTE8DXghpPa9gIHq2o7cLB7DnAjsL177AHumk6ZpxdHCxq6VYOhqn4I/Oak5l3A/m59P3DTSPvXa8mPgHcm2TKtYiVtjLUeY7ikqo4BdMuLu/atwIsj/Ra7NklzZNoHHzOmrcZ2TPYkOZTk0C9/fWLKZQyX0wjNg7UGw0vLU4RuebxrXwS2jfS7DDg67g2qal9VLVTVwkUXbFpjGfPFUNC8WGswHAB2d+u7gftH2j/WnZ24Bnh5ecohaX6s+pXoJPcAHwQuTLIIfB74AnBfkluBnwMf6bo/AOwEjgC/Bz7eQ81zx5GC5s2qwVBVt6zw0nVj+hZw23qLmtT1l+7gu0cf26iPWxNDQfPIbz5Kasz91ZXjfiMPZRThaEHzyhGDpMZcjhhW+028/PqsRg6OFDTvBhkM0/rBmoeDk9IQDW4qMe3ftv72lk7dYEYMff4
Ab+TUwiDS6WAQI4ZnD7991iVIGjGIYNgofY9KHC3odDGYqcQ8Mgh0ujqjRgySJuOI4RQ5StCZwBGDpMYZN2JY66lLRwo6k5yxIwZ/0KWVnbHBAG+cYlwOCcNCWnLGTSVWYjhIbzijRwySxjMYJDUMBkkNg0FSw2CQ1DAYJDUMBkkNg0FSw2CQ1DAYJDUMBkkNg0FSw2CQ1DAYJDUMBkkNg0FSw2CQ1Fg1GJLcneR4kidG2u5I8oskj3WPnSOvfSbJkSTPJLm+r8Il9WeSEcPXgBvGtH+5qnZ0jwcAklwJ3Ay8p9vmX5JsmlaxkjbGqsFQVT8EfjPh++0C7q2qV6vqZ8AR4Op11CdpBtZzjOH2JIe7qcb5XdtW4MWRPotdWyPJniSHkhx6jVfXUYakaVtrMNwFvBvYARwDvti1Z0zfGvcGVbWvqhaqauFszlljGZL6sKZgqKqXqupEVf0R+ApvTBcWgW0jXS8Djq6vREkbbU3BkGTLyNMPA8tnLA4ANyc5J8kVwHbg4fWVKGmjrfofziS5B/ggcGGSReDzwAeT7GBpmvAC8AmAqnoyyX3AU8DrwG1VdaKf0iX1JVVjDwFsqD/L5np/rpt1GdJp7fv1zUeramGSvn7zUVLDYJDUMBgkNQwGSQ2DQVLDYJDUMBgkNQwGSQ2DQVLDYJDUMBgkNQwGSQ2DQVLDYJDUMBgkNQwGSQ2DQVLDYJDUMBgkNQwGSQ2DQVLDYJDUMBgkNQwGSQ2DQVLDYJDUMBgkNQwGSQ2DQVLDYJDUMBgkNQwGSQ2DQVJj1WBIsi3JD5I8neTJJJ/s2jcneTDJc93y/K49Se5MciTJ4SRX9f2HkDRdk4wYXgc+XVV/BVwD3JbkSmAvcLCqtgMHu+cANwLbu8ce4K6pVy2pV6sGQ1Udq6ofd+uvAE8DW4FdwP6u237gpm59F/D1WvIj4J1Jtky9ckm9OaVjDEkuB94HPARcUlXHYCk8gIu7bluBF0c2W+zaJM2JiYMhyTuAbwGfqqrfvVXXMW015v32JDmU5NBrvDppGZI2wETBkORslkLhG1X17a75peUpQrc83rUvAttGNr8MOHrye1bVvqpaqKqFszlnrfVL6sEkZyUCfBV4uqq+NPLSAWB3t74buH+k/WPd2YlrgJeXpxyS5sNZE/S5Fvgo8HiSx7q2zwJfAO5Lcivwc+Aj3WsPADuBI8DvgY9PtWJJvVs1GKrqvxh/3ADgujH9C7htnXVJmiG/+SipYTBIahgMkhoGg6SGwSCpYTBIahgMkhoGg6SGwSCpYTBIahgMkhoGg6SGwSCpYTBIahgMkhoGg6SGwSCpYTBIahgMkhoGg6SGwSCpYTBIahgMkhoGg6SGwSCpYTBIahgMkhoGg6SGwSCpYTBIahgMkhoGg6SGwSCpYTBIaqwaDEm2JflBkqeTPJnkk137HUl+keSx7rFzZJvPJDmS5Jkk1/f5B5A0fWdN0Od14NNV9eMk5wGPJnmwe+3LVfVPo52TXAncDLwHuBT4fpK/rKoT0yxcUn9WHTFU1bGq+nG3/grwNLD1LTbZBdxbVa9W1c+AI8DV0yhW0sY4pWMMSS4H3gc81DXdnuRwkruTnN+1bQVeHNlskTFBkmRPkkNJDr3Gq6dcuKT+TBwMSd4BfAv4VFX9DrgLeDewAzgGfHG565jNq2mo2ldVC1W1cDbnnHLhkvozUTAkOZulUPhGVX0boKpeqqoTVfVH4Cu8MV1YBLaNbH4ZcHR6JUvq2yRnJQJ8FXi6qr400r5lpNuHgSe69QPAzUnOSXIFsB14eHolS+rbJGclrgU+Cjye5LGu7bPALUl2sDRNeAH4BEBVPZnkPuApls5o3OYZCWm+pKqZ/m98Eckvgf8FfjXrWiZwIfNRJ8xPrdY5feNq/YuqumiSjQcRDABJDlXVwqzrWM281AnzU6t1Tt96a/Ur0ZIaBoOkxpCCYd+sC5jQvNQJ81OrdU7fumodzDEGScMxpBGDpIGYeTA
kuaG7PPtIkr2zrudkSV5I8nh3afmhrm1zkgeTPNctz1/tfXqo6+4kx5M8MdI2tq4subPbx4eTXDWAWgd32f5b3GJgUPt1Q26FUFUzewCbgJ8C7wLeBvwEuHKWNY2p8QXgwpPa/hHY263vBf5hBnV9ALgKeGK1uoCdwH+wdB3LNcBDA6j1DuDvx/S9svt3cA5wRffvY9MG1bkFuKpbPw94tqtnUPv1Leqc2j6d9YjhauBIVT1fVX8A7mXpsu2h2wXs79b3AzdtdAFV9UPgNyc1r1TXLuDrteRHwDtP+kp7r1aodSUzu2y/Vr7FwKD261vUuZJT3qezDoaJLtGesQK+l+TRJHu6tkuq6hgs/SUBF8+sujdbqa6h7uc1X7bft5NuMTDY/TrNWyGMmnUwTHSJ9oxdW1VXATcCtyX5wKwLWoMh7ud1XbbfpzG3GFix65i2Dat12rdCGDXrYBj8JdpVdbRbHge+w9IQ7KXlIWO3PD67Ct9kpboGt59roJftj7vFAAPcr33fCmHWwfAIsD3JFUnextK9Ig/MuKb/l+Tc7j6XJDkX+BBLl5cfAHZ33XYD98+mwsZKdR0APtYdRb8GeHl5aDwrQ7xsf6VbDDCw/bpSnVPdpxtxFHWVI6w7WTqq+lPgc7Ou56Ta3sXS0dyfAE8u1wdcABwEnuuWm2dQ2z0sDRdfY+k3wq0r1cXSUPKfu338OLAwgFr/tavlcPcPd8tI/891tT4D3LiBdf41S0Psw8Bj3WPn0PbrW9Q5tX3qNx8lNWY9lZA0QAaDpIbBIKlhMEhqGAySGgaDpIbBIKlhMEhq/B+1vZtDcyHTvQAAAABJRU5ErkJggg==\n", 795 | "text/plain": [ 796 | "
" 797 | ] 798 | }, 799 | "metadata": {}, 800 | "output_type": "display_data" 801 | } 802 | ], 803 | "source": [ 804 | "plt.imshow(((preds[16]/10)>0.5).astype(np.uint8))" 805 | ] 806 | }, 807 | { 808 | "cell_type": "code", 809 | "execution_count": 16, 810 | "metadata": { 811 | "hidden": true 812 | }, 813 | "outputs": [], 814 | "source": [ 815 | "p = [cv2.resize(o/10,dsize=(101,101)) for o in preds]\n", 816 | "p = [(o>0.5).astype(np.uint8) for o in p]" 817 | ] 818 | }, 819 | { 820 | "cell_type": "code", 821 | "execution_count": 17, 822 | "metadata": { 823 | "hidden": true 824 | }, 825 | "outputs": [ 826 | { 827 | "data": { 828 | "application/vnd.jupyter.widget-view+json": { 829 | "model_id": "c76d2f1f04484c3f92fb502dde60a920", 830 | "version_major": 2, 831 | "version_minor": 0 832 | }, 833 | "text/plain": [ 834 | "HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))" 835 | ] 836 | }, 837 | "metadata": {}, 838 | "output_type": "display_data" 839 | }, 840 | { 841 | "name": "stdout", 842 | "output_type": "stream", 843 | "text": [ 844 | "\n" 845 | ] 846 | } 847 | ], 848 | "source": [ 849 | "pred_dict = {id_[11:-4]:RLenc(p[i]) for i,id_ in tqdm_notebook(enumerate(tst_x))}\n", 850 | "sub = pd.DataFrame.from_dict(pred_dict,orient='index')\n", 851 | "sub.index.names = ['id']\n", 852 | "sub.columns = ['rle_mask']\n", 853 | "sub.to_csv('simple_k_fold_flipped.csv')" 854 | ] 855 | }, 856 | { 857 | "cell_type": "code", 858 | "execution_count": null, 859 | "metadata": { 860 | "hidden": true 861 | }, 862 | "outputs": [], 863 | "source": [] 864 | } 865 | ], 866 | "metadata": { 867 | "kernelspec": { 868 | "display_name": "Python [default]", 869 | "language": "python", 870 | "name": "python3" 871 | }, 872 | "language_info": { 873 | "codemirror_mode": { 874 | "name": "ipython", 875 | "version": 3 876 | }, 877 | "file_extension": ".py", 878 | "mimetype": "text/x-python", 879 | "name": "python", 880 | "nbconvert_exporter": "python", 881 | "pygments_lexer": 
"ipython3", 882 | "version": "3.6.4" 883 | } 884 | }, 885 | "nbformat": 4, 886 | "nbformat_minor": 2 887 | } 888 | -------------------------------------------------------------------------------- /lovasz_losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Lovasz-Softmax and Jaccard hinge loss in PyTorch 3 | Maxim Berman 2018 ESAT-PSI KU Leuven (MIT License) 4 | """ 5 | 6 | from __future__ import print_function, division 7 | 8 | import torch 9 | from torch.autograd import Variable 10 | import torch.nn.functional as F 11 | import numpy as np 12 | try: 13 | from itertools import ifilterfalse 14 | except ImportError: # py3k 15 | from itertools import filterfalse 16 | 17 | 18 | def lovasz_grad(gt_sorted): 19 | """ 20 | Computes gradient of the Lovasz extension w.r.t sorted errors 21 | See Alg. 1 in paper 22 | """ 23 | p = len(gt_sorted) 24 | gts = gt_sorted.sum() 25 | intersection = gts - gt_sorted.float().cumsum(0) 26 | union = gts + (1 - gt_sorted).float().cumsum(0) 27 | jaccard = 1. 
- intersection / union 28 | if p > 1: # cover 1-pixel case 29 | jaccard[1:p] = jaccard[1:p] - jaccard[0:-1] 30 | return jaccard 31 | 32 | 33 | def iou_binary(preds, labels, EMPTY=1., ignore=None, per_image=True): 34 | """ 35 | IoU for foreground class 36 | binary: 1 foreground, 0 background 37 | """ 38 | if not per_image: 39 | preds, labels = (preds,), (labels,) 40 | ious = [] 41 | for pred, label in zip(preds, labels): 42 | intersection = ((label == 1) & (pred == 1)).sum() 43 | union = ((label == 1) | ((pred == 1) & (label != ignore))).sum() 44 | if not union: 45 | iou = EMPTY 46 | else: 47 | iou = float(intersection) / union 48 | ious.append(iou) 49 | iou = mean(ious) # mean accross images if per_image 50 | return 100 * iou 51 | 52 | 53 | def iou(preds, labels, C, EMPTY=1., ignore=None, per_image=False): 54 | """ 55 | Array of IoU for each (non ignored) class 56 | """ 57 | if not per_image: 58 | preds, labels = (preds,), (labels,) 59 | ious = [] 60 | for pred, label in zip(preds, labels): 61 | iou = [] 62 | for i in range(C): 63 | if i != ignore: # The ignored label is sometimes among predicted classes (ENet - CityScapes) 64 | intersection = ((label == i) & (pred == i)).sum() 65 | union = ((label == i) | ((pred == i) & (label != ignore))).sum() 66 | if not union: 67 | iou.append(EMPTY) 68 | else: 69 | iou.append(float(intersection) / union) 70 | ious.append(iou) 71 | ious = map(mean, zip(*ious)) # mean accross images if per_image 72 | return 100 * np.array(ious) 73 | 74 | 75 | # --------------------------- BINARY LOSSES --------------------------- 76 | 77 | 78 | def lovasz_hinge(logits, labels, per_image=True, ignore=None): 79 | """ 80 | Binary Lovasz hinge loss 81 | logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty) 82 | labels: [B, H, W] Tensor, binary ground truth masks (0 or 1) 83 | per_image: compute the loss per image instead of per batch 84 | ignore: void class id 85 | """ 86 | if per_image: 87 | loss = 
mean(lovasz_hinge_flat(*flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore)) 88 | for log, lab in zip(logits, labels)) 89 | else: 90 | loss = lovasz_hinge_flat(*flatten_binary_scores(logits, labels, ignore)) 91 | return loss 92 | 93 | 94 | def lovasz_hinge_flat(logits, labels): 95 | """ 96 | Binary Lovasz hinge loss 97 | logits: [P] Variable, logits at each prediction (between -\infty and +\infty) 98 | labels: [P] Tensor, binary ground truth labels (0 or 1) 99 | (no ignore argument: ignored labels must already be filtered out, e.g. by flatten_binary_scores) 100 | """ 101 | if len(labels) == 0: 102 | # only void pixels, the gradients should be 0 103 | return logits.sum() * 0. 104 | signs = 2. * labels.float() - 1. # map {0, 1} labels to {-1, +1} 105 | errors = (1. - logits * Variable(signs)) # hinge margin error per element 106 | errors_sorted, perm = torch.sort(errors, dim=0, descending=True) 107 | perm = perm.data 108 | gt_sorted = labels[perm] 109 | grad = lovasz_grad(gt_sorted) 110 | loss = torch.dot(F.elu(errors_sorted) +1, Variable(grad)) # ELU(x)+1 is a smooth, strictly positive function of the errors; NOTE(review): presumably a deliberate change from a ReLU hinge - confirm 111 | return loss 112 | 113 | 114 | def flatten_binary_scores(scores, labels, ignore=None): 115 | """ 116 | Flattens predictions in the batch (binary case) 117 | Remove labels equal to 'ignore' 118 | """ 119 | scores = scores.view(-1) 120 | labels = labels.view(-1) 121 | if ignore is None: 122 | return scores, labels 123 | valid = (labels != ignore) 124 | vscores = scores[valid] 125 | vlabels = labels[valid] 126 | return vscores, vlabels 127 | 128 | 129 | class StableBCELoss(torch.nn.modules.Module): 130 | def __init__(self): 131 | super(StableBCELoss, self).__init__() 132 | def forward(self, input, target): 133 | neg_abs = - input.abs() 134 | loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log() # max(x, 0) - x*t + log(1 + exp(-|x|)): BCE on logits whose exp argument is never positive 135 | return loss.mean() 136 | 137 | 138 | def binary_xloss(logits, labels, ignore=None): 139 | """ 140 | Binary Cross entropy loss 141 | logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty) 142 | labels: [B, H, W] Tensor, binary ground truth masks (0 or 1) 143 | ignore: void class id 144 | """ 145 | logits,
labels = flatten_binary_scores(logits, labels, ignore) 146 | loss = StableBCELoss()(logits, Variable(labels.float())) 147 | return loss 148 | 149 | 150 | # --------------------------- MULTICLASS LOSSES --------------------------- 151 | 152 | 153 | def lovasz_softmax(probas, labels, only_present=False, per_image=False, ignore=None): 154 | """ 155 | Multi-class Lovasz-Softmax loss 156 | probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1) 157 | labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1) 158 | only_present: average only on classes present in ground truth 159 | per_image: compute the loss per image instead of per batch 160 | ignore: void class labels 161 | """ 162 | if per_image: 163 | loss = mean(lovasz_softmax_flat(*flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore), only_present=only_present) 164 | for prob, lab in zip(probas, labels)) 165 | else: 166 | loss = lovasz_softmax_flat(*flatten_probas(probas, labels, ignore), only_present=only_present) 167 | return loss 168 | 169 | 170 | def lovasz_softmax_flat(probas, labels, only_present=False): 171 | """ 172 | Multi-class Lovasz-Softmax loss 173 | probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1) 174 | labels: [P] Tensor, ground truth labels (between 0 and C - 1) 175 | only_present: average only on classes present in ground truth 176 | """ 177 | C = probas.size(1) 178 | losses = [] 179 | for c in range(C): 180 | fg = (labels == c).float() # foreground for class c 181 | if only_present and fg.sum() == 0: 182 | continue 183 | errors = (Variable(fg) - probas[:, c]).abs() # |one-hot target - predicted probability| for class c 184 | errors_sorted, perm = torch.sort(errors, 0, descending=True) 185 | perm = perm.data 186 | fg_sorted = fg[perm] 187 | losses.append(torch.dot(errors_sorted, Variable(lovasz_grad(fg_sorted)))) 188 | return mean(losses) 189 | 190 | 191 | def flatten_probas(probas, labels, ignore=None): 192 | """ 193 | Flattens predictions in the batch 194 | """ 195 | B,
C, H, W = probas.size() 196 | probas = probas.permute(0, 2, 3, 1).contiguous().view(-1, C) # B * H * W, C = P, C 197 | labels = labels.view(-1) 198 | if ignore is None: 199 | return probas, labels 200 | valid = (labels != ignore) 201 | vprobas = probas[valid.nonzero().squeeze()] 202 | vlabels = labels[valid] 203 | return vprobas, vlabels 204 | 205 | def xloss(logits, labels, ignore=None): 206 | """ 207 | Cross entropy loss 208 | """ 209 | return F.cross_entropy(logits, Variable(labels), ignore_index=255) 210 | 211 | 212 | # --------------------------- HELPER FUNCTIONS --------------------------- 213 | 214 | def mean(l, ignore_nan=False, empty=0): 215 | """ 216 | nanmean compatible with generators. 217 | """ 218 | l = iter(l) 219 | if ignore_nan: 220 | l = ifilterfalse(np.isnan, l) 221 | try: 222 | n = 1 223 | acc = next(l) 224 | except StopIteration: 225 | if empty == 'raise': 226 | raise ValueError('Empty mean') 227 | return empty 228 | for n, v in enumerate(l, 2): 229 | acc += v 230 | if n == 1: 231 | return acc 232 | return acc / n 233 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from pycocotools import mask as cocomask 2 | from fastai.conv_learner import * 3 | 4 | def get_segmentations(labeled): 5 | nr_true = int(labeled.max()) 6 | segmentations = [] 7 | for i in range(1, nr_true + 1): 8 | msk = labeled == i 9 | segmentation = rle_from_binary(msk.astype('uint8')) 10 | segmentation['counts'] = segmentation['counts'].decode("UTF-8") 11 | segmentations.append(segmentation) 12 | return segmentations 13 | 14 | def compute_precision_at(ious, threshold): 15 | mx1 = np.max(ious, axis=0) 16 | mx2 = np.max(ious, axis=1) 17 | tp = np.sum(mx2 >= threshold) 18 | fp = np.sum(mx2 < threshold) 19 | fn = np.sum(mx1 < threshold) 20 | return float(tp) / (tp + fp + fn) 21 | 22 | def compute_ious(gt, predictions): 23 | gt_ = get_segmentations(gt) 24 
| predictions_ = get_segmentations(predictions) 25 | 26 | if len(gt_) == 0 and len(predictions_) == 0: # no objects in either mask: perfect match 27 | return np.ones((1, 1)) 28 | elif len(gt_) != 0 and len(predictions_) == 0: # objects expected but none predicted 29 | return np.zeros((1, 1)) 30 | else: 31 | iscrowd = [0 for _ in predictions_] 32 | ious = cocomask.iou(gt_, predictions_, iscrowd) 33 | if not np.array(ious).size: # empty IoU result (presumably when gt_ has no objects): score as no overlap 34 | ious = np.zeros((1, 1)) 35 | return ious 36 | 37 | def compute_eval_metric(gt, predictions): 38 | thresholds = [0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95] 39 | ious = compute_ious(gt, predictions) 40 | precisions = [compute_precision_at(ious, th) for th in thresholds] 41 | return sum(precisions) / len(precisions) 42 | 43 | def intersection_over_union_thresholds(y_true, y_pred): 44 | iouts = [] 45 | for y_t, y_p in list(zip(y_true, y_pred)): 46 | iouts.append(compute_eval_metric(y_t, y_p)) 47 | return np.mean(iouts) 48 | 49 | def rle_from_binary(prediction): 50 | prediction = np.asfortranarray(prediction) 51 | return cocomask.encode(prediction) 52 | 53 | def intersection_over_union(y_true, y_pred): 54 | ious = [] 55 | for y_t, y_p in list(zip(y_true, y_pred)): 56 | iou = compute_ious(y_t, y_p) 57 | iou_mean = 1.0 * np.sum(iou) / len(iou) 58 | ious.append(iou_mean) 59 | return np.mean(ious) 60 | 61 | def my_eval(pred,targ): 62 | pred = to_np(torch.sigmoid(pred)) 63 | targ = to_np(targ) 64 | losses = [] 65 | for i in range(targ.shape[0]): 66 | losses.append(compute_eval_metric(targ[i],((pred[i]>0.5).astype(np.uint8)))) 67 | return np.mean(losses) 68 | 69 | def RLenc(img, order='F', format=True): 70 | """ 71 | img is binary mask image, shape (r,c) 72 | order is down-then-right, i.e.
Fortran 73 | format determines if the order needs to be preformatted (according to submission rules) or not 74 | 75 | returns a list of (start, length) tuples with 1-based starts, or a space-separated 'start length ...' string (if format is True) 76 | """ 77 | bytes = img.reshape(img.shape[0] * img.shape[1], order=order) 78 | runs = [] ## list of run lengths 79 | r = 0 ## the current run length 80 | pos = 1 ## count starts from 1 per WK 81 | for c in bytes: 82 | if (c == 0): 83 | if r != 0: 84 | runs.append((pos, r)) 85 | pos += r 86 | r = 0 87 | pos += 1 88 | else: 89 | r += 1 90 | 91 | # if last run is unsaved (i.e. data ends with 1) 92 | if r != 0: 93 | runs.append((pos, r)) 94 | pos += r 95 | r = 0 96 | 97 | if format: 98 | z = '' 99 | 100 | for rr in runs: 101 | z += '{} {} '.format(rr[0], rr[1]) 102 | return z[:-1] 103 | else: 104 | return runs 105 | 106 | 107 | 108 | --------------------------------------------------------------------------------