├── Kaggle 21st place solution for TGS Salt Identification Challenge.ipynb ├── README.md ├── bam.py ├── lovasz_losses.py ├── model.py ├── networks.py └── utils.py /Kaggle 21st place solution for TGS Salt Identification Challenge.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Imports" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": { 14 | "scrolled": true 15 | }, 16 | "outputs": [], 17 | "source": [ 18 | "%matplotlib inline\n", 19 | "%reload_ext autoreload\n", 20 | "%autoreload 2\n", 21 | "from fastai.conv_learner import *\n", 22 | "from fastai.dataset import *\n", 23 | "from networks import GCN,SEModule,Refine\n", 24 | "from fastai.models.senet import *\n", 25 | "from skimage.transform import resize\n", 26 | "import json\n", 27 | "from sklearn.model_selection import train_test_split, StratifiedKFold , KFold\n", 28 | "from sklearn.metrics import jaccard_similarity_score\n", 29 | "from networks import *\n", 30 | "from pycocotools import mask as cocomask\n", 31 | "from utils import *\n", 32 | "from lovasz_losses import lovasz_hinge\n", 33 | "from bam import *\n", 34 | "print(torch.__version__)\n", 35 | "torch.cuda.is_available()\n", 36 | "torch.backends.cudnn.benchmark=True" 37 | ] 38 | }, 39 | { 40 | "cell_type": "markdown", 41 | "metadata": {}, 42 | "source": [ 43 | "## Paths" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 2, 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "PATH = Path('data/tgs/')\n", 53 | "TRN_MASKS = 'trn_masks'\n", 54 | "TRN_IMG = 'trn_images'\n", 55 | "TRN_MSK = 'trn_masks'\n", 56 | "TST_IMG = 'tst_images'\n", 57 | "trn = pd.read_csv(PATH/'train.csv')\n", 58 | "dpth = pd.read_csv(PATH/'depths.csv')" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": 3, 64 | "metadata": {}, 65 | "outputs": [], 66 | "source": [ 67 | "def show_img(im, figsize=None, ax=None, alpha=None):\n", 68 | " if not ax: fig,ax = plt.subplots(figsize=figsize)\n", 69 | " ax.imshow(im, alpha=alpha)\n", 70 | " ax.set_axis_off()\n", 71 | " return ax" 72 | ] 73 | }, 74 | { 75 | "cell_type": "markdown", 76 | "metadata": {}, 77 | "source": [ 78 | "## Datasets" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": 4, 84 | "metadata": {}, 85 | "outputs": [], 86 | "source": [ 87 | "class DepthDataset(Dataset):\n", 88 | " def __init__(self,ds,dpth_dict):\n", 89 | " self.dpth = dpth_dict\n", 90 | " self.ds = ds\n", 91 | " \n", 92 | " def __getitem__(self,i):\n", 93 | " val = self.ds[i]\n", 94 | " return val[0],self.dpth[self.ds.fnames[i].split('/')[1][:-4]],val[1]\n", 95 | " \n", 96 | "class MatchedFilesDataset(FilesDataset):\n", 97 | " def __init__(self, fnames, y, transform, path):\n", 98 | " self.y=y\n", 99 | " assert(len(fnames)==len(y))\n", 100 | " super().__init__(fnames, transform, path)\n", 101 | " \n", 102 | " def get_x(self, i): \n", 103 | " return open_image(os.path.join(self.path, self.fnames[i]))\n", 104 | " \n", 105 | " def get_y(self, i):\n", 106 | " return open_image(os.path.join(str(self.path), str(self.y[i])))\n", 107 | "\n", 108 | " def get_c(self): return 0\n", 109 | " \n", 110 | "class TestFilesDataset(FilesDataset):\n", 111 | " def __init__(self, fnames, y, transform,flip, path):\n", 112 | " self.y=y\n", 113 | " self.flip = flip\n", 114 | " super().__init__(fnames, transform, path)\n", 115 | " \n", 116 | " def get_x(self, i): \n", 
117 | " im = open_image(os.path.join(self.path, self.fnames[i]))\n", 118 | " return np.fliplr(im) if self.flip else im\n", 119 | " \n", 120 | " def get_y(self, i):\n", 121 | " im = open_image(os.path.join(str(self.path), str(self.y[i])))\n", 122 | " return np.fliplr(im) if self.flip else im\n", 123 | " def get_c(self): return 0" 124 | ] 125 | }, 126 | { 127 | "cell_type": "markdown", 128 | "metadata": {}, 129 | "source": [ 130 | "## Creating K-Fold" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": 5, 136 | "metadata": {}, 137 | "outputs": [], 138 | "source": [ 139 | "x_names = np.array([f'{TRN_IMG}/{o.name}' for o in (PATH/TRN_MASKS).iterdir()])\n", 140 | "y_names = np.array([f'{TRN_MASKS}/{o.name}' for o in (PATH/TRN_MASKS).iterdir()])\n", 141 | "tst_x = np.array([f'{TST_IMG}/{o.name}' for o in (PATH/TST_IMG).iterdir()])\n", 142 | "f_name = [o.split('/')[-1] for o in x_names]\n", 143 | "\n", 144 | "c = dpth.set_index('id')\n", 145 | "dpth_dict = c['z'].to_dict()\n", 146 | "\n", 147 | "kf = 10\n", 148 | "kfold = KFold(n_splits=kf, shuffle=True, random_state=42)\n", 149 | "\n", 150 | "train_folds = []\n", 151 | "val_folds = []\n", 152 | "for idxs in kfold.split(f_name):\n", 153 | " train_folds.append([f_name[idx] for idx in idxs[0]])\n", 154 | " val_folds.append([f_name[idx] for idx in idxs[1]])\n", 155 | " " 156 | ] 157 | }, 158 | { 159 | "cell_type": "markdown", 160 | "metadata": {}, 161 | "source": [ 162 | "## Unet Model" 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": 7, 168 | "metadata": {}, 169 | "outputs": [], 170 | "source": [ 171 | "class UnetWithAttention(nn.Module):\n", 172 | " def __init__(self):\n", 173 | " super().__init__()\n", 174 | " self.rn = ResNetWithBAM()\n", 175 | " fs = 16\n", 176 | " self.up1 = UnetBlock(512,256,fs)\n", 177 | " self.up2 = UnetBlock(fs,128,fs)\n", 178 | " self.up3 = UnetBlock(fs,64,fs)\n", 179 | " self.up4 = nn.ConvTranspose2d(fs, fs, 2, stride=2)\n", 180 | "\n", 181 | " self.img_class = nn.Sequential(nn.AdaptiveAvgPool2d(1),\n", 182 | " Flatten(),\n", 183 | " nn.Dropout(0.3),\n", 184 | " nn.Linear(512,256),nn.ReLU(inplace=True),nn.BatchNorm1d(256),\n", 185 | " nn.Dropout(0.3),\n", 186 | " nn.Linear(256,1),nn.Sigmoid()\n", 187 | " )\n", 188 | " \n", 189 | " self.logit = nn.Sequential(nn.Conv2d(69,69,kernel_size=3,padding=1),nn.ReLU(inplace=True),\n", 190 | " nn.Conv2d(69,1,kernel_size=1,padding=0))\n", 191 | " \n", 192 | " self.ds1,self.ds2,self.ds3,self.ds4,self.ds5 = [conv_block(fs,1) for _ in range(5)]\n", 193 | " \n", 194 | " def forward(self,img,depth):\n", 195 | " e0,e1,e2,e3,e4 = self.rn(img)\n", 196 | " img_sz = img.size(2) \n", 197 | " d1 = self.up1(e4, e3) \n", 198 | " d2 = self.up2(d1, e2) \n", 199 | " d3 = self.up3(d2, e1) \n", 200 | " d4 = self.up4(d3) \n", 201 | "\n", 202 | " #Creating hyper column features\n", 203 | " hyp_column = torch.cat([create_interpolate(o,img_sz) for o in [d1,d2,d3,d4]],1)\n", 204 | " \n", 205 | " #Creating features for deep supervision\n", 206 | " ds1,ds2,ds3,ds4 = self.ds1(d1),self.ds2(d2),self.ds3(d3),self.ds4(d4)\n", 207 | " \n", 208 | " ds = torch.cat([create_interpolate(o,img_sz) for o in [ds1,ds2,ds3,ds4]],1)\n", 209 | " \n", 210 | " #Image classifier\n", 211 | " img_class = self.img_class(e4)\n", 212 | " \n", 213 | " img_class_up = create_interpolate(img_class.view(img_class.size(0),-1,1,1),img_sz,'nearest',None)\n", 214 | " \n", 215 | " #Fuse Deep supervision features\n", 216 | " ds = torch.cat([hyp_column,ds,img_class_up],1)\n", 217 | " \n", 218 | 
" x = self.logit(ds)\n", 219 | " \n", 220 | " return x[:,0],(img_class,*[o[:,0] for o in [ds1,ds2,ds3,ds4]])\n", 221 | "\n", 222 | "\n", 223 | "class UnetModel():\n", 224 | " def __init__(self,model,lr_cut,name='unet'):\n", 225 | " self.model,self.name = model,name\n", 226 | " self.lr_cut = lr_cut\n", 227 | "\n", 228 | " def get_layer_groups(self, precompute):\n", 229 | " lgs = list(split_by_idxs(children(self.model.rn), [2]))\n", 230 | " return lgs + [children(self.model)[1:]]\n", 231 | " \n", 232 | " " 233 | ] 234 | }, 235 | { 236 | "cell_type": "code", 237 | "execution_count": 8, 238 | "metadata": {}, 239 | "outputs": [], 240 | "source": [ 241 | "def get_tgs_model():\n", 242 | " f = resnet34\n", 243 | " cut,lr_cut = model_meta[f]\n", 244 | " m = to_gpu(UnetWithAttention())\n", 245 | " models = UnetModel(m,lr_cut)\n", 246 | " learn = ConvLearner(md, models)\n", 247 | " return learn\n", 248 | "\n", 249 | "def get_base(f,cut):\n", 250 | " layers = cut_model(f(True), cut)\n", 251 | " return nn.Sequential(*layers) \n" 252 | ] 253 | }, 254 | { 255 | "cell_type": "markdown", 256 | "metadata": {}, 257 | "source": [ 258 | "## Loss function " 259 | ] 260 | }, 261 | { 262 | "cell_type": "code", 263 | "execution_count": 9, 264 | "metadata": {}, 265 | "outputs": [], 266 | "source": [ 267 | "def change_tensor_size(targ,sz):\n", 268 | " if targ.size(1) == sz:\n", 269 | " return targ\n", 270 | " targ_np = np.array([cv2.resize(o,dsize=(sz,sz)) for o in to_np(targ)])\n", 271 | " return torch.tensor(targ_np,dtype=torch.float32,device=torch.device(\"cuda\"))\n", 272 | "\n", 273 | "def multi_lovasz_loss(logits,target):\n", 274 | " logit,cl_logit,ds1,ds2,ds3,ds4 = (logits[0],*logits[1][0]) if isinstance(logits[1],list) else (logits[0],*logits[1])\n", 275 | "\n", 276 | " cl_targets = (Flatten()(target).sum(1) != 0).type(torch.cuda.FloatTensor).view(cl_logit.size(0),-1)\n", 277 | " non_empty_imgs = cl_targets.view(cl_logit.size(0),1,1)\n", 278 | " cl = F.binary_cross_entropy(cl_logit,cl_targets)\n", 279 | " rf_loss = lovasz_hinge(logit,target)\n", 280 | " \n", 281 | " #Handling deep supervised features\n", 282 | " for o in [ds1,ds2,ds3,ds4]:\n", 283 | " targ_rs = change_tensor_size(target,o.size(1))\n", 284 | " o = o * non_empty_imgs\n", 285 | " rf_loss += lovasz_hinge(o,targ_rs)\n", 286 | " \n", 287 | " return 0.05*cl+ rf_loss\n", 288 | "\n" 289 | ] 290 | }, 291 | { 292 | "cell_type": "markdown", 293 | "metadata": {}, 294 | "source": [ 295 | "## Training loop" 296 | ] 297 | }, 298 | { 299 | "cell_type": "code", 300 | "execution_count": null, 301 | "metadata": { 302 | "scrolled": true 303 | }, 304 | "outputs": [], 305 | "source": [ 306 | "model = 'submission_model'\n", 307 | "bst_acc=[]\n", 308 | "use_clr_min=20\n", 309 | "use_clr_div=10\n", 310 | "aug_tfms = [\n", 311 | " RandomRotate(4, tfm_y=TfmType.CLASS),\n", 312 | " RandomFlip(tfm_y=TfmType.CLASS),\n", 313 | " RandomLighting(0.1, 0, tfm_y=TfmType.CLASS),\n", 314 | " RandomBlur([3,5,7]),\n", 315 | " RandomZoom(0.1,tfm_y=TfmType.CLASS)\n", 316 | " ]\n", 317 | "\n", 318 | "szs = [(224,16)]\n", 319 | "for sz,bs in szs:\n", 320 | " print([sz,bs])\n", 321 | " for i in range(kf) :\n", 322 | " print(f'fold_id{i}')\n", 323 | " \n", 324 | " trn_x = np.array([f'trn_images/{o}' for o in train_folds[i]])\n", 325 | " trn_y = np.array([f'trn_masks/{o}' for o in train_folds[i]])\n", 326 | " val_x = [f'trn_images/{o}' for o in val_folds[i]]\n", 327 | " val_y = [f'trn_masks/{o}' for o in val_folds[i]]\n", 328 | " \n", 329 | " tfms = tfms_from_model(resnet34, sz=sz, 
pad=0,crop_type=CropType.NO, tfm_y=TfmType.CLASS, aug_tfms=aug_tfms)\n",
330 |     "        datasets = ImageData.get_ds(MatchedFilesDataset, (trn_x,trn_y), (val_x,val_y), tfms,test=tst_x,path=PATH)\n",
331 |     "        md = ImageData(PATH, datasets, bs, num_workers=16, classes=None)\n",
332 |     "        denorm = md.trn_ds.denorm\n",
333 |     "        md.trn_dl.dataset = DepthDataset(md.trn_ds,dpth_dict)\n",
334 |     "        md.val_dl.dataset = DepthDataset(md.val_ds,dpth_dict)\n",
335 |     "        md.test_dl.dataset = DepthDataset(md.test_ds,dpth_dict)\n",
336 |     "        learn = get_tgs_model() \n",
337 |     "        learn.metrics=[my_eval]\n",
338 |     "        pa = f'{kf}_fold_{model}_{i}'\n",
339 |     "        print(pa)\n",
340 |     "        learn.unfreeze() \n",
341 |     "        learn.crit = multi_lovasz_loss  \n",
342 |     "        \n",
343 |     "        learn.fit(1e-2,n_cycle=1,wds=0.0001,cycle_len=100,use_clr=(10,8),best_save_name=pa)\n",
344 |     "        \n",
345 |     "        learn.load(pa)\n",
346 |     "        #Calculating mean IoU score\n",
347 |     "        v_targ = md.val_ds.ds[:][1]\n",
348 |     "        v_preds = np.zeros((len(v_targ),sz,sz))  \n",
349 |     "        v_pred = learn.predict()\n",
350 |     "        v_pred = to_np(torch.sigmoid(torch.from_numpy(v_pred)))\n",
351 |     "        p = ((v_pred)>0.5).astype(np.uint8)\n",
352 |     "        bst_acc.append(intersection_over_union_thresholds(v_targ,p))\n",
353 |     "        print(bst_acc[-1])"
354 |    ]
355 |   },
356 |   {
357 |    "cell_type": "markdown",
358 |    "metadata": {},
359 |    "source": [
360 |     "## Submission - TTA"
361 |    ]
362 |   },
363 |   {
364 |    "cell_type": "code",
365 |    "execution_count": 14,
366 |    "metadata": {},
367 |    "outputs": [
368 |     {
369 |      "data": {
370 |       "application/vnd.jupyter.widget-view+json": {
371 |        "model_id": "e2bba17f719a422ba82f0b91235f6d13",
372 |        "version_major": 2,
373 |        "version_minor": 0
374 |       },
375 |       "text/plain": [
376 |        "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))"
377 |       ]
378 |      },
379 |      "metadata": {},
380 |      "output_type": "display_data"
381 |     },
382 |     {
383 |      "name": "stdout",
384 |      "output_type": "stream",
385 |      "text": [
386 |       "10_fold_resnet_ds_0ft\n",
387 |       "10_fold_resnet_ds_1ft\n",
388 |       "10_fold_resnet_ds_2ft\n",
389 |       "10_fold_resnet_ds_3ft\n",
390 |       "10_fold_resnet_ds_4ft\n",
391 |       "10_fold_resnet_ds_5ft\n",
392 |       "10_fold_resnet_ds_6ft\n",
393 |       "10_fold_resnet_ds_7ft\n",
394 |       "10_fold_resnet_ds_8ft\n",
395 |       "10_fold_resnet_ds_9ft\n",
396 |       "\n"
397 |      ]
398 |     },
399 |     {
400 |      "data": {
401 |       "application/vnd.jupyter.widget-view+json": {
402 |        "model_id": "c6f4a17057bc4b47911111b2c8c3b941",
403 |        "version_major": 2,
404 |        "version_minor": 0
405 |       },
406 |       "text/plain": [
407 |        "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))"
408 |       ]
409 |      },
410 |      "metadata": {},
411 |      "output_type": "display_data"
412 |     },
413 |     {
414 |      "name": "stdout",
415 |      "output_type": "stream",
416 |      "text": [
417 |       "10_fold_resnet_ds_0ft\n",
418 |       "10_fold_resnet_ds_1ft\n",
419 |       "10_fold_resnet_ds_2ft\n",
420 |       "10_fold_resnet_ds_3ft\n",
421 |       "10_fold_resnet_ds_4ft\n",
422 |       "10_fold_resnet_ds_5ft\n",
423 |       "10_fold_resnet_ds_6ft\n",
424 |       "10_fold_resnet_ds_7ft\n",
425 |       "10_fold_resnet_ds_8ft\n",
426 |       "10_fold_resnet_ds_9ft\n",
427 |       "\n"
428 |      ]
429 |     }
430 |    ],
431 |    "source": [
432 |     "preds = np.zeros(shape = (18000,sz,sz))\n",
433 |     "for o in [True,False]:\n",
434 |     "    md.test_dl.dataset = TestFilesDataset(tst_x,tst_x,tfms[1],flip=o,path=PATH)\n",
435 |     "    md.test_dl.dataset = DepthDataset(md.test_dl.dataset,dpth_dict)\n",
436 |     "    \n",
437 |     "    for i in tqdm_notebook(range(kf)):\n",
438 |     "        pa = f'{kf}_fold_{model}_{i}'\n",
439 |     "        print(pa)\n",
440 |     "        learn.load(pa)\n",
441 |     "        pred = learn.predict(is_test=True)\n",
442 |     "        pred = to_np(torch.sigmoid(torch.from_numpy(pred)))    \n",
443 |     "        for im_idx,im in enumerate(pred):\n",
444 |     "            preds[im_idx] += np.fliplr(im) if o else im\n",
445 |     "        del pred"
446 |    ]
447 |   },
448 |   {
449 |    "cell_type": "code",
450 |    "execution_count": 15,
451 |    "metadata": {},
452 |    "outputs": [
453 |     {
454 |      "data": {
455 |       "text/plain": [
456 |        ""
457 |       ]
458 |      },
459 |      "execution_count": 15,
460 |      "metadata": {},
461 |      "output_type": "execute_result"
462 |     },
463 |     {
464 |      "data": {
465 |       "image/png": "<base64-encoded PNG omitted: matplotlib preview of one predicted salt mask>",
466 |       "text/plain": [
467 |        ""
" 468 | ] 469 | }, 470 | "metadata": {}, 471 | "output_type": "display_data" 472 | } 473 | ], 474 | "source": [ 475 | "#Quick look on if the submission masks looks good.\n", 476 | "plt.imshow(((preds[16]/kf*2)>0.5).astype(np.uint8))" 477 | ] 478 | }, 479 | { 480 | "cell_type": "code", 481 | "execution_count": 16, 482 | "metadata": {}, 483 | "outputs": [], 484 | "source": [ 485 | "p = [cv2.resize(o/kf*2,dsize=(101,101)) for o in preds]\n", 486 | "p = [(o>0.5).astype(np.uint8) for o in p]" 487 | ] 488 | }, 489 | { 490 | "cell_type": "code", 491 | "execution_count": 17, 492 | "metadata": {}, 493 | "outputs": [ 494 | { 495 | "data": { 496 | "application/vnd.jupyter.widget-view+json": { 497 | "model_id": "02ab4f5f0dd945eca5ff83514beeec81", 498 | "version_major": 2, 499 | "version_minor": 0 500 | }, 501 | "text/plain": [ 502 | "HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))" 503 | ] 504 | }, 505 | "metadata": {}, 506 | "output_type": "display_data" 507 | }, 508 | { 509 | "name": "stdout", 510 | "output_type": "stream", 511 | "text": [ 512 | "\n" 513 | ] 514 | } 515 | ], 516 | "source": [ 517 | "pred_dict = {id_[11:-4]:RLenc(p[i]) for i,id_ in tqdm_notebook(enumerate(tst_x))}\n", 518 | "sub = pd.DataFrame.from_dict(pred_dict,orient='index')\n", 519 | "sub.index.names = ['id']\n", 520 | "sub.columns = ['rle_mask']\n", 521 | "sub.to_csv('simple_k_fold_flipped.csv')" 522 | ] 523 | }, 524 | { 525 | "cell_type": "code", 526 | "execution_count": null, 527 | "metadata": {}, 528 | "outputs": [], 529 | "source": [] 530 | } 531 | ], 532 | "metadata": { 533 | "kernelspec": { 534 | "display_name": "Python [default]", 535 | "language": "python", 536 | "name": "python3" 537 | }, 538 | "language_info": { 539 | "codemirror_mode": { 540 | "name": "ipython", 541 | "version": 3 542 | }, 543 | "file_extension": ".py", 544 | "mimetype": "text/x-python", 545 | "name": "python", 546 | "nbconvert_exporter": "python", 547 | "pygments_lexer": "ipython3", 548 | "version": "3.6.4" 549 | } 550 | }, 551 | "nbformat": 4, 552 | "nbformat_minor": 2 553 | } 554 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TGS-SaltNet 2 | Kaggle | 21st place solution for TGS Salt Identification Challenge 3 | 4 | ## General 5 | 6 | I recently participated in a Kaggle competition [TGS Salt Identification Challenge](https://www.kaggle.com/c/tgs-salt-identification-challenge) 7 | and reached the 21st place. This repository contains the final code which resulted in the best model. The code demonstrates usage of different important techniques using [fast.ai](http://www.fast.ai/) and [PyTorch](https://pytorch.org/). 8 | 1. Use ResNet model as an encoder for UNet. 9 | 2. Add intermediate layers like [BAM](http://bmvc2018.org/contents/papers/0092.pdf),[Squeeze & Excitation](https://arxiv.org/abs/1803.02579) blocks in a ResNet34 model which can be easily replicated for other network architectures. 10 | 3. Show how to add [Deep supervision](https://www.kaggle.com/c/tgs-salt-identification-challenge/discussion/65933) to the network, and calculate loss and combine loss at different scale. 11 | 12 | ## Main software used 13 | 14 | 1. fastai - 0.7 15 | 2. pytorch - 0.4 16 | 3. python - 3.6 17 | 18 | ## Hardware required 19 | 20 | The code was tested with TitanX GPU/1080ti. 
21 | 
22 | ## Thanks
23 | 
24 | A special thanks to Heng for his generous contributions of ideas throughout the competition, to the long list of amazing Kaggle community members, and to Jeremy and the fast.ai community for the amazing and flexible fastai framework.
25 | 
26 | 
27 | 
28 | 
29 | 
30 | 
--------------------------------------------------------------------------------
/bam.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | 
6 | class Flatten(nn.Module):
7 |     def forward(self, x):
8 |         return x.view(x.size(0), -1)
9 | class ChannelGate(nn.Module):
10 |     def __init__(self, gate_channel, reduction_ratio=16, num_layers=1):
11 |         super(ChannelGate, self).__init__()
12 |         #self.gate_activation = gate_activation
13 |         self.gate_c = nn.Sequential()
14 |         self.gate_c.add_module( 'flatten', Flatten() )
15 |         gate_channels = [gate_channel]
16 |         gate_channels += [gate_channel // reduction_ratio] * num_layers
17 |         gate_channels += [gate_channel]
18 |         for i in range( len(gate_channels) - 2 ):
19 |             self.gate_c.add_module( 'gate_c_fc_%d'%i, nn.Linear(gate_channels[i], gate_channels[i+1]) )
20 |             self.gate_c.add_module( 'gate_c_bn_%d'%(i+1), nn.BatchNorm1d(gate_channels[i+1]) )
21 |             self.gate_c.add_module( 'gate_c_relu_%d'%(i+1), nn.ReLU() )
22 |         self.gate_c.add_module( 'gate_c_fc_final', nn.Linear(gate_channels[-2], gate_channels[-1]) )
23 |     def forward(self, in_tensor):
24 |         avg_pool = F.avg_pool2d( in_tensor, in_tensor.size(2), stride=in_tensor.size(2) )
25 |         return self.gate_c( avg_pool ).unsqueeze(2).unsqueeze(3).expand_as(in_tensor)
26 | 
27 | class SpatialGate(nn.Module):
28 |     def __init__(self, gate_channel, reduction_ratio=16, dilation_conv_num=2, dilation_val=4):
29 |         super(SpatialGate, self).__init__()
30 |         self.gate_s = nn.Sequential()
31 |         self.gate_s.add_module( 'gate_s_conv_reduce0', nn.Conv2d(gate_channel, gate_channel//reduction_ratio, kernel_size=1))
32 |         self.gate_s.add_module( 'gate_s_bn_reduce0', nn.BatchNorm2d(gate_channel//reduction_ratio) )
33 |         self.gate_s.add_module( 'gate_s_relu_reduce0',nn.ReLU() )
34 |         for i in range( dilation_conv_num ):
35 |             self.gate_s.add_module( 'gate_s_conv_di_%d'%i, nn.Conv2d(gate_channel//reduction_ratio, gate_channel//reduction_ratio, kernel_size=3, \
36 |                 padding=dilation_val, dilation=dilation_val) )
37 |             self.gate_s.add_module( 'gate_s_bn_di_%d'%i, nn.BatchNorm2d(gate_channel//reduction_ratio) )
38 |             self.gate_s.add_module( 'gate_s_relu_di_%d'%i, nn.ReLU() )
39 |         self.gate_s.add_module( 'gate_s_conv_final', nn.Conv2d(gate_channel//reduction_ratio, 1, kernel_size=1) )
40 |     def forward(self, in_tensor):
41 |         return self.gate_s( in_tensor ).expand_as(in_tensor)
42 | class BAM(nn.Module):
43 |     def __init__(self, gate_channel):
44 |         super(BAM, self).__init__()
45 |         self.channel_att = ChannelGate(gate_channel)
46 |         self.spatial_att = SpatialGate(gate_channel)
47 |     def forward(self,in_tensor):
48 |         att = 1 + torch.sigmoid( self.channel_att(in_tensor) * self.spatial_att(in_tensor) )
49 |         return att * in_tensor
50 | 
--------------------------------------------------------------------------------
/lovasz_losses.py:
--------------------------------------------------------------------------------
1 | """
2 | Lovasz-Softmax and Jaccard hinge loss in PyTorch
3 | Maxim Berman 2018 ESAT-PSI KU Leuven (MIT License)
4 | """
5 | 
6 | from __future__ import print_function, division
7 | 
8 | import torch
9 | from torch.autograd import Variable
10 | import torch.nn.functional as F
11 | import numpy as np
12 | try:
13 |     from itertools import ifilterfalse
14 | except ImportError: # py3k
15 |     from itertools import filterfalse as ifilterfalse
16 | 
17 | 
18 | def lovasz_grad(gt_sorted):
19 |     """
20 |     Computes gradient of the Lovasz extension w.r.t sorted errors
21 |     See Alg. 1 in paper
22 |     """
23 |     p = len(gt_sorted)
24 |     gts = gt_sorted.sum()
25 |     intersection = gts - gt_sorted.float().cumsum(0)
26 |     union = gts + (1 - gt_sorted).float().cumsum(0)
27 |     jaccard = 1. - intersection / union
28 |     if p > 1: # cover 1-pixel case
29 |         jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
30 |     return jaccard
31 | 
32 | 
33 | def iou_binary(preds, labels, EMPTY=1., ignore=None, per_image=True):
34 |     """
35 |     IoU for foreground class
36 |     binary: 1 foreground, 0 background
37 |     """
38 |     if not per_image:
39 |         preds, labels = (preds,), (labels,)
40 |     ious = []
41 |     for pred, label in zip(preds, labels):
42 |         intersection = ((label == 1) & (pred == 1)).sum()
43 |         union = ((label == 1) | ((pred == 1) & (label != ignore))).sum()
44 |         if not union:
45 |             iou = EMPTY
46 |         else:
47 |             iou = float(intersection) / union
48 |         ious.append(iou)
49 |     iou = mean(ious)    # mean across images if per_image
50 |     return 100 * iou
51 | 
52 | 
53 | def iou(preds, labels, C, EMPTY=1., ignore=None, per_image=False):
54 |     """
55 |     Array of IoU for each (non ignored) class
56 |     """
57 |     if not per_image:
58 |         preds, labels = (preds,), (labels,)
59 |     ious = []
60 |     for pred, label in zip(preds, labels):
61 |         iou = []
62 |         for i in range(C):
63 |             if i != ignore: # The ignored label is sometimes among predicted classes (ENet - CityScapes)
64 |                 intersection = ((label == i) & (pred == i)).sum()
65 |                 union = ((label == i) | ((pred == i) & (label != ignore))).sum()
66 |                 if not union:
67 |                     iou.append(EMPTY)
68 |                 else:
69 |                     iou.append(float(intersection) / union)
70 |         ious.append(iou)
71 |     ious = map(mean, zip(*ious)) # mean across images if per_image
72 |     return 100 * np.array(ious)
73 | 
74 | 
75 | # --------------------------- BINARY LOSSES ---------------------------
76 | 
77 | 
78 | def lovasz_hinge(logits, labels, per_image=True, ignore=None):
79 |     """
80 |     Binary Lovasz hinge loss
81 |       logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty)
82 |       labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
83 |       per_image: compute the loss per image instead of per batch
84 |       ignore: void class id
85 |     """
86 |     if per_image:
87 |         loss = mean(lovasz_hinge_flat(*flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore))
88 |                           for log, lab in zip(logits, labels))
89 |     else:
90 |         loss = lovasz_hinge_flat(*flatten_binary_scores(logits, labels, ignore))
91 |     return loss
92 | 
93 | 
94 | def lovasz_hinge_flat(logits, labels):
95 |     """
96 |     Binary Lovasz hinge loss
97 |       logits: [P] Variable, logits at each prediction (between -\infty and +\infty)
98 |       labels: [P] Tensor, binary ground truth labels (0 or 1)
99 |       ignore: label to ignore
100 |     """
101 |     if len(labels) == 0:
102 |         # only void pixels, the gradients should be 0
103 |         return logits.sum() * 0.
104 |     signs = 2. * labels.float() - 1.
105 |     errors = (1. 
- logits * Variable(signs)) 106 | errors_sorted, perm = torch.sort(errors, dim=0, descending=True) 107 | perm = perm.data 108 | gt_sorted = labels[perm] 109 | grad = lovasz_grad(gt_sorted) 110 | loss = torch.dot(F.elu(errors_sorted) +1, Variable(grad)) 111 | return loss 112 | 113 | 114 | def flatten_binary_scores(scores, labels, ignore=None): 115 | """ 116 | Flattens predictions in the batch (binary case) 117 | Remove labels equal to 'ignore' 118 | """ 119 | scores = scores.view(-1) 120 | labels = labels.view(-1) 121 | if ignore is None: 122 | return scores, labels 123 | valid = (labels != ignore) 124 | vscores = scores[valid] 125 | vlabels = labels[valid] 126 | return vscores, vlabels 127 | 128 | 129 | class StableBCELoss(torch.nn.modules.Module): 130 | def __init__(self): 131 | super(StableBCELoss, self).__init__() 132 | def forward(self, input, target): 133 | neg_abs = - input.abs() 134 | loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log() 135 | return loss.mean() 136 | 137 | 138 | def binary_xloss(logits, labels, ignore=None): 139 | """ 140 | Binary Cross entropy loss 141 | logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty) 142 | labels: [B, H, W] Tensor, binary ground truth masks (0 or 1) 143 | ignore: void class id 144 | """ 145 | logits, labels = flatten_binary_scores(logits, labels, ignore) 146 | loss = StableBCELoss()(logits, Variable(labels.float())) 147 | return loss 148 | 149 | 150 | # --------------------------- MULTICLASS LOSSES --------------------------- 151 | 152 | 153 | def lovasz_softmax(probas, labels, only_present=False, per_image=False, ignore=None): 154 | """ 155 | Multi-class Lovasz-Softmax loss 156 | probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1) 157 | labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1) 158 | only_present: average only on classes present in ground truth 159 | per_image: compute the loss per image instead of per batch 160 | ignore: void class labels 161 | """ 162 | if per_image: 163 | loss = mean(lovasz_softmax_flat(*flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore), only_present=only_present) 164 | for prob, lab in zip(probas, labels)) 165 | else: 166 | loss = lovasz_softmax_flat(*flatten_probas(probas, labels, ignore), only_present=only_present) 167 | return loss 168 | 169 | 170 | def lovasz_softmax_flat(probas, labels, only_present=False): 171 | """ 172 | Multi-class Lovasz-Softmax loss 173 | probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1) 174 | labels: [P] Tensor, ground truth labels (between 0 and C - 1) 175 | only_present: average only on classes present in ground truth 176 | """ 177 | C = probas.size(1) 178 | losses = [] 179 | for c in range(C): 180 | fg = (labels == c).float() # foreground for class c 181 | if only_present and fg.sum() == 0: 182 | continue 183 | errors = (Variable(fg) - probas[:, c]).abs() 184 | errors_sorted, perm = torch.sort(errors, 0, descending=True) 185 | perm = perm.data 186 | fg_sorted = fg[perm] 187 | losses.append(torch.dot(errors_sorted, Variable(lovasz_grad(fg_sorted)))) 188 | return mean(losses) 189 | 190 | 191 | def flatten_probas(probas, labels, ignore=None): 192 | """ 193 | Flattens predictions in the batch 194 | """ 195 | B, C, H, W = probas.size() 196 | probas = probas.permute(0, 2, 3, 1).contiguous().view(-1, C) # B * H * W, C = P, C 197 | labels = labels.view(-1) 198 | if ignore is None: 199 | return probas, labels 200 | valid = (labels != ignore) 201 | 
vprobas = probas[valid.nonzero().squeeze()] 202 | vlabels = labels[valid] 203 | return vprobas, vlabels 204 | 205 | def xloss(logits, labels, ignore=None): 206 | """ 207 | Cross entropy loss 208 | """ 209 | return F.cross_entropy(logits, Variable(labels), ignore_index=255) 210 | 211 | 212 | # --------------------------- HELPER FUNCTIONS --------------------------- 213 | 214 | def mean(l, ignore_nan=False, empty=0): 215 | """ 216 | nanmean compatible with generators. 217 | """ 218 | l = iter(l) 219 | if ignore_nan: 220 | l = ifilterfalse(np.isnan, l) 221 | try: 222 | n = 1 223 | acc = next(l) 224 | except StopIteration: 225 | if empty == 'raise': 226 | raise ValueError('Empty mean') 227 | return empty 228 | for n, v in enumerate(l, 2): 229 | acc += v 230 | if n == 1: 231 | return acc 232 | return acc / n 233 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | from .imports import * 2 | from .torch_imports import * 3 | from .core import * 4 | from .layer_optimizer import * 5 | from .swa import * 6 | from .fp16 import * 7 | 8 | IS_TORCH_04 = LooseVersion(torch.__version__) >= LooseVersion('0.4') 9 | 10 | def cut_model(m, cut): 11 | return list(m.children())[:cut] if cut else [m] 12 | 13 | def predict_to_bcolz(m, gen, arr, workers=4): 14 | arr.trim(len(arr)) 15 | lock=threading.Lock() 16 | m.eval() 17 | for x,*_ in tqdm(gen): 18 | y = to_np(m(VV(x)).data) 19 | with lock: 20 | arr.append(y) 21 | arr.flush() 22 | 23 | def num_features(m): 24 | c=children(m) 25 | if len(c)==0: return None 26 | for l in reversed(c): 27 | if hasattr(l, 'num_features'): return l.num_features 28 | res = num_features(l) 29 | if res is not None: return res 30 | 31 | def torch_item(x): return x.item() if hasattr(x,'item') else x[0] 32 | 33 | class Stepper(): 34 | def __init__(self, m, opt, crit, clip=0, reg_fn=None, fp16=False, loss_scale=1): 35 | self.m,self.opt,self.crit,self.clip,self.reg_fn = m,opt,crit,clip,reg_fn 36 | self.fp16 = fp16 37 | self.reset(True) 38 | if self.fp16: self.fp32_params = copy_model_to_fp32(m, opt) 39 | self.loss_scale = loss_scale 40 | 41 | def reset(self, train=True): 42 | if train: apply_leaf(self.m, set_train_mode) 43 | else: self.m.eval() 44 | if hasattr(self.m, 'reset'): 45 | self.m.reset() 46 | if self.fp16: self.fp32_params = copy_model_to_fp32(self.m, self.opt) 47 | 48 | def step(self, xs, y, epoch): 49 | xtra = [] 50 | output = self.m(*xs) 51 | if isinstance(output,tuple): output,*xtra = output 52 | if self.fp16: self.m.zero_grad() 53 | else: self.opt.zero_grad() 54 | loss = raw_loss = self.crit((output,xtra), y) 55 | if self.loss_scale != 1: assert(self.fp16); loss = loss*self.loss_scale 56 | if self.reg_fn: loss = self.reg_fn(output, xtra, raw_loss) 57 | loss.backward() 58 | if self.fp16: update_fp32_grads(self.fp32_params, self.m) 59 | if self.loss_scale != 1: 60 | for param in self.fp32_params: param.grad.data.div_(self.loss_scale) 61 | if self.clip: # Gradient clipping 62 | if IS_TORCH_04: nn.utils.clip_grad_norm_(trainable_params_(self.m), self.clip) 63 | else: nn.utils.clip_grad_norm(trainable_params_(self.m), self.clip) 64 | if 'wd' in self.opt.param_groups and self.opt.param_groups['wd'] != 0: 65 | #Weight decay out of the loss. After the gradient computation but before the step. 
66 |                 for group in self.opt.param_groups:
67 |                     lr, wd = group['lr'], group['wd']
68 |                     for p in group['params']:
69 |                         if p.grad is not None: p.data = p.data.add(-wd * lr, p.data)
70 |         self.opt.step()
71 |         if self.fp16: 
72 |             copy_fp32_to_model(self.m, self.fp32_params)
73 |             torch.cuda.synchronize()
74 |         return torch_item(raw_loss.data)
75 | 
76 |     def evaluate(self, xs, y):
77 |         preds = self.m(*xs)
78 |         #if isinstance(preds,tuple): preds=preds[0]
79 |         return preds[0], self.crit(preds, y)
80 | 
81 | def set_train_mode(m):
82 |     if (hasattr(m, 'running_mean') and (getattr(m,'bn_freeze',False)
83 |               or not getattr(m,'trainable',False))): m.eval()
84 |     elif (getattr(m,'drop_freeze',False) and hasattr(m, 'p')
85 |           and ('drop' in type(m).__name__.lower())): m.eval()
86 |     else: m.train()
87 | 
88 | def fit(model, data, n_epochs, opt, crit, metrics=None, callbacks=None, stepper=Stepper,
89 |         swa_model=None, swa_start=None, swa_eval_freq=None, **kwargs):
90 |     """ Fits a model
91 | 
92 |     Arguments:
93 |        model (model): any pytorch module, typically already moved to the GPU
94 |            (e.g. net = to_gpu(net))
95 |        data (ModelData): see ModelData class and subclasses (can be a list)
96 |        opt: an optimizer. Example: optim.Adam.
97 |        If n_epochs is a list, it needs to be the layer_optimizer to get the optimizer as it changes.
98 |        n_epochs(int or list): number of epochs (or list of number of epochs)
99 |        crit: loss function to optimize. Example: F.cross_entropy
100 |     """
101 | 
102 |     all_val = kwargs.pop('all_val') if 'all_val' in kwargs else False
103 |     get_ep_vals = kwargs.pop('get_ep_vals') if 'get_ep_vals' in kwargs else False
104 |     metrics = metrics or []
105 |     callbacks = callbacks or []
106 |     avg_mom=0.98
107 |     batch_num,avg_loss=0,0.
108 |     for cb in callbacks: cb.on_train_begin()
109 |     names = ["epoch", "trn_loss", "val_loss"] + [f.__name__ for f in metrics]
110 |     if swa_model is not None:
111 |         swa_names = ['swa_loss'] + [f'swa_{f.__name__}' for f in metrics]
112 |         names += swa_names
113 |         # will use this to call evaluate later
114 |         swa_stepper = stepper(swa_model, None, crit, **kwargs)
115 | 
116 |     layout = "{!s:10} " * len(names)
117 |     if not isinstance(n_epochs, Iterable): n_epochs=[n_epochs]
118 |     if not isinstance(data, Iterable): data = [data]
119 |     if len(data) == 1: data = data * len(n_epochs)
120 |     for cb in callbacks: cb.on_phase_begin()
121 |     model_stepper = stepper(model, opt.opt if hasattr(opt,'opt') else opt, crit, **kwargs)
122 |     ep_vals = collections.OrderedDict()
123 |     tot_epochs = int(np.ceil(np.array(n_epochs).sum()))
124 |     cnt_phases = np.array([ep * len(dat.trn_dl) for (ep,dat) in zip(n_epochs,data)]).cumsum()
125 |     phase = 0
126 |     for epoch in tnrange(tot_epochs, desc='Epoch'):
127 |         if phase >= len(n_epochs): break #Sometimes cumulated errors make this happen.
128 | model_stepper.reset(True) 129 | cur_data = data[phase] 130 | if hasattr(cur_data, 'trn_sampler'): cur_data.trn_sampler.set_epoch(epoch) 131 | if hasattr(cur_data, 'val_sampler'): cur_data.val_sampler.set_epoch(epoch) 132 | num_batch = len(cur_data.trn_dl) 133 | t = tqdm(iter(cur_data.trn_dl), leave=False, total=num_batch) 134 | if all_val: val_iter = IterBatch(cur_data.val_dl) 135 | 136 | for (*x,y) in t: 137 | batch_num += 1 138 | for cb in callbacks: cb.on_batch_begin() 139 | loss = model_stepper.step(V(x),V(y), epoch) 140 | avg_loss = avg_loss * avg_mom + loss * (1-avg_mom) 141 | debias_loss = avg_loss / (1 - avg_mom**batch_num) 142 | t.set_postfix(loss=debias_loss) 143 | stop=False 144 | los = debias_loss if not all_val else [debias_loss] + validate_next(model_stepper,metrics, val_iter) 145 | for cb in callbacks: stop = stop or cb.on_batch_end(los) 146 | if stop: return 147 | if batch_num >= cnt_phases[phase]: 148 | for cb in callbacks: cb.on_phase_end() 149 | phase += 1 150 | if phase >= len(n_epochs): 151 | t.close() 152 | break 153 | for cb in callbacks: cb.on_phase_begin() 154 | if isinstance(opt, LayerOptimizer): model_stepper.opt = opt.opt 155 | if cur_data != data[phase]: 156 | t.close() 157 | break 158 | 159 | if not all_val: 160 | vals = validate(model_stepper, cur_data.val_dl, metrics) 161 | stop=False 162 | for cb in callbacks: stop = stop or cb.on_epoch_end(vals) 163 | if swa_model is not None: 164 | if (epoch + 1) >= swa_start and ((epoch + 1 - swa_start) % swa_eval_freq == 0 or epoch == tot_epochs - 1): 165 | fix_batchnorm(swa_model, cur_data.trn_dl) 166 | swa_vals = validate(swa_stepper, cur_data.val_dl, metrics) 167 | vals += swa_vals 168 | 169 | if epoch == 0: print(layout.format(*names)) 170 | print_stats(epoch, [debias_loss] + vals) 171 | ep_vals = append_stats(ep_vals, epoch, [debias_loss] + vals) 172 | if stop: break 173 | for cb in callbacks: cb.on_train_end() 174 | if get_ep_vals: return vals, ep_vals 175 | else: return vals 176 | 177 | def append_stats(ep_vals, epoch, values, decimals=6): 178 | ep_vals[epoch]=list(np.round(values, decimals)) 179 | return ep_vals 180 | 181 | def print_stats(epoch, values, decimals=6): 182 | layout = "{!s:^10}" + " {!s:10}" * len(values) 183 | values = [epoch] + list(np.round(values, decimals)) 184 | print(layout.format(*values)) 185 | 186 | class IterBatch(): 187 | def __init__(self, dl): 188 | self.idx = 0 189 | self.dl = dl 190 | self.iter = iter(dl) 191 | 192 | def __iter__(self): return self 193 | 194 | def next(self): 195 | res = next(self.iter) 196 | self.idx += 1 197 | if self.idx == len(self.dl): 198 | self.iter = iter(self.dl) 199 | self.idx=0 200 | return res 201 | 202 | def validate_next(stepper, metrics, val_iter): 203 | """Computes the loss on the next minibatch of the validation set.""" 204 | stepper.reset(False) 205 | with no_grad_context(): 206 | (*x,y) = val_iter.next() 207 | preds,l = stepper.evaluate(VV(x), VV(y)) 208 | res = [delistify(to_np(l))] 209 | res += [f(preds.data,y) for f in metrics] 210 | stepper.reset(True) 211 | return res 212 | 213 | def validate(stepper, dl, metrics): 214 | batch_cnts,loss,res = [],[],[] 215 | stepper.reset(False) 216 | with no_grad_context(): 217 | for (*x,y) in iter(dl): 218 | preds, l = stepper.evaluate(VV(x), VV(y)) 219 | if isinstance(x,list): batch_cnts.append(len(x[0])) 220 | else: batch_cnts.append(len(x)) 221 | loss.append(to_np(l)) 222 | res.append([f(preds.data, y) for f in metrics]) 223 | return [np.average(loss, 0, weights=batch_cnts)] + 
list(np.average(np.stack(res), 0, weights=batch_cnts)) 224 | 225 | def get_prediction(x): 226 | if is_listy(x): x=x[0] 227 | return x.data 228 | 229 | def predict(m, dl): 230 | preda,_ = predict_with_targs_(m, dl) 231 | return np.concatenate(preda) 232 | 233 | def predict_batch(m, x): 234 | m.eval() 235 | if hasattr(m, 'reset'): m.reset() 236 | return m(VV(x)) 237 | 238 | def predict_with_targs_(m, dl): 239 | m.eval() 240 | if hasattr(m, 'reset'): m.reset() 241 | res = [] 242 | for *x,y in iter(dl): res.append([get_prediction(to_np(m(*VV(x)))),to_np(y)]) 243 | return zip(*res) 244 | 245 | def predict_with_targs(m, dl): 246 | preda,targa = predict_with_targs_(m, dl) 247 | return np.concatenate(preda), np.concatenate(targa) 248 | 249 | # From https://github.com/ncullen93/torchsample 250 | def model_summary(m, input_size): 251 | def register_hook(module): 252 | def hook(module, input, output): 253 | class_name = str(module.__class__).split('.')[-1].split("'")[0] 254 | module_idx = len(summary) 255 | 256 | m_key = '%s-%i' % (class_name, module_idx+1) 257 | summary[m_key] = OrderedDict() 258 | summary[m_key]['input_shape'] = list(input[0].size()) 259 | summary[m_key]['input_shape'][0] = -1 260 | if is_listy(output): 261 | summary[m_key]['output_shape'] = [[-1] + list(o.size())[1:] for o in output] 262 | else: 263 | summary[m_key]['output_shape'] = list(output.size()) 264 | summary[m_key]['output_shape'][0] = -1 265 | 266 | params = 0 267 | if hasattr(module, 'weight'): 268 | params += torch.prod(torch.LongTensor(list(module.weight.size()))) 269 | summary[m_key]['trainable'] = module.weight.requires_grad 270 | if hasattr(module, 'bias') and module.bias is not None: 271 | params += torch.prod(torch.LongTensor(list(module.bias.size()))) 272 | summary[m_key]['nb_params'] = params 273 | 274 | if (not isinstance(module, nn.Sequential) and 275 | not isinstance(module, nn.ModuleList) and 276 | not (module == m)): 277 | hooks.append(module.register_forward_hook(hook)) 278 | 279 | summary = OrderedDict() 280 | hooks = [] 281 | m.apply(register_hook) 282 | 283 | if is_listy(input_size[0]): 284 | x = [to_gpu(Variable(torch.rand(3,*in_size))) for in_size in input_size] 285 | else: x = [to_gpu(Variable(torch.rand(3,*input_size)))] 286 | m(*x) 287 | 288 | for h in hooks: h.remove() 289 | return summary 290 | 291 | -------------------------------------------------------------------------------- /networks.py: -------------------------------------------------------------------------------- 1 | from fastai.conv_learner import * 2 | from fastai.dataset import * 3 | from bam import * 4 | 5 | class SEModule(nn.Module): 6 | def __init__(self, ch, re=16): 7 | super(SEModule, self).__init__() 8 | self.se = nn.Sequential(nn.AdaptiveAvgPool2d(1), 9 | nn.Conv2d(ch,ch//re,1), 10 | nn.ReLU(inplace=True), 11 | nn.Conv2d(ch//re,ch,1), 12 | nn.Sigmoid() 13 | ) 14 | def forward(self, x): 15 | return x * self.se(x) 16 | 17 | class SCSEModule(nn.Module): 18 | def __init__(self, ch, re=16): 19 | super().__init__() 20 | self.cSE = nn.Sequential(nn.AdaptiveAvgPool2d(1), 21 | nn.Conv2d(ch,ch//re,1), 22 | nn.ReLU(inplace=True), 23 | nn.Conv2d(ch//re,ch,1), 24 | nn.Sigmoid() 25 | ) 26 | self.sSE = nn.Sequential(nn.Conv2d(ch,ch,1), 27 | nn.Sigmoid()) 28 | 29 | def forward(self, x): 30 | return x * self.cSE(x) + x * self.sSE(x) 31 | 32 | class GCN(nn.Module): 33 | def __init__(self,inp,out,ks=7): 34 | super().__init__() 35 | self.conv_l = nn.Sequential(nn.Conv2d(inp,out,(ks,1),padding=(ks//2,0)), 36 | 
nn.Conv2d(out,out,(1,ks),padding=(0,ks//2)) 37 | ) 38 | self.conv_r = nn.Sequential(nn.Conv2d(inp,out,(1,ks),padding=(0,ks//2)), 39 | nn.Conv2d(out,out,(ks,1),padding=(ks//2,0)) 40 | ) 41 | def forward(self,x): 42 | return self.conv_l(x) + self.conv_r(x) 43 | 44 | class Refine(nn.Module): 45 | def __init__(self,inp): 46 | super().__init__() 47 | self.conv1 = nn.Sequential(nn.BatchNorm2d(inp), 48 | nn.ReLU(inplace=True), 49 | nn.Conv2d(inp,inp,3,padding=1), 50 | nn.BatchNorm2d(inp), 51 | nn.ReLU(inplace=True), 52 | nn.Conv2d(inp,inp,3,padding=1)) 53 | def forward(self,x): 54 | return x + self.conv1(x) 55 | 56 | def conv_block(ni,nf,ks = 3): 57 | model = nn.Sequential(nn.Conv2d(ni,ni,kernel_size=ks,padding=ks//2),nn.ReLU(inplace=True),nn.BatchNorm2d(ni), 58 | nn.Conv2d(ni,nf,kernel_size=ks,padding=ks//2)) 59 | return model 60 | 61 | def create_interpolate(x,img_size,mode='bilinear',ac=False): 62 | sf = img_size//x.size(2) 63 | return x if sf == 1 else F.interpolate(x,scale_factor=sf,align_corners=ac,mode=mode) 64 | 65 | def unet_conv(ni,nf): 66 | model = nn.Sequential(nn.ReLU(),nn.BatchNorm2d(ni), 67 | nn.Conv2d(ni,ni,3,padding=1), 68 | nn.ReLU(),nn.BatchNorm2d(ni), 69 | nn.Conv2d(ni,nf,3,padding=1)) 70 | return model 71 | 72 | class ResNetWithBAM(nn.Module): 73 | def __init__(self): 74 | super().__init__() 75 | encoder = torchvision.models.resnet34(True) 76 | encoder.conv1.stride = (1,1) 77 | self.input_adjust = nn.Sequential(encoder.conv1, 78 | encoder.bn1, 79 | encoder.relu) 80 | self.pool = encoder.maxpool 81 | self.conv1 = encoder.layer1 82 | self.conv2 = encoder.layer2 83 | self.conv3 = encoder.layer3 84 | self.conv4 = encoder.layer4 85 | self.bam1,self.bam2,self.bam3,self.bam4 = BAM(64),BAM(128),BAM(256),BAM(512) 86 | 87 | def forward(self,x): 88 | inp =self.input_adjust(x) 89 | e0 = self.pool(inp) 90 | e1 = self.bam1(self.conv1(e0)) 91 | e2 = self.bam2(self.conv2(e1)) 92 | e3 = self.bam3(self.conv3(e2)) 93 | e4 = self.bam4(self.conv4(e3)) 94 | return e0,e1,e2,e3,e4 95 | 96 | class UnetBlock(nn.Module): 97 | def __init__(self, up_in, x_in, n_out): 98 | super().__init__() 99 | up_out = x_out = n_out//2 100 | self.x_conv = nn.Conv2d(x_in, x_out, 1) 101 | self.tr_conv = nn.ConvTranspose2d(up_in, up_out, 2, stride=2) 102 | self.bn = nn.BatchNorm2d(n_out) 103 | self.bam = BAM(n_out) 104 | 105 | def forward(self, up_p, x_p): 106 | up_p = self.tr_conv(up_p) 107 | x_p = self.x_conv(x_p) 108 | cat_p = torch.cat([up_p,x_p], dim=1) 109 | return self.bam(self.bn(F.relu(cat_p))) -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from pycocotools import mask as cocomask 2 | from fastai.conv_learner import * 3 | 4 | def get_segmentations(labeled): 5 | nr_true = int(labeled.max()) 6 | segmentations = [] 7 | for i in range(1, nr_true + 1): 8 | msk = labeled == i 9 | segmentation = rle_from_binary(msk.astype('uint8')) 10 | segmentation['counts'] = segmentation['counts'].decode("UTF-8") 11 | segmentations.append(segmentation) 12 | return segmentations 13 | 14 | def compute_precision_at(ious, threshold): 15 | mx1 = np.max(ious, axis=0) 16 | mx2 = np.max(ious, axis=1) 17 | tp = np.sum(mx2 >= threshold) 18 | fp = np.sum(mx2 < threshold) 19 | fn = np.sum(mx1 < threshold) 20 | return float(tp) / (tp + fp + fn) 21 | 22 | def compute_ious(gt, predictions): 23 | gt_ = get_segmentations(gt) 24 | predictions_ = get_segmentations(predictions) 25 | 26 | if len(gt_) == 0 and len(predictions_) 
== 0: 27 | return np.ones((1, 1)) 28 | elif len(gt_) != 0 and len(predictions_) == 0: 29 | return np.zeros((1, 1)) 30 | else: 31 | iscrowd = [0 for _ in predictions_] 32 | ious = cocomask.iou(gt_, predictions_, iscrowd) 33 | if not np.array(ious).size: 34 | ious = np.zeros((1, 1)) 35 | return ious 36 | 37 | def compute_eval_metric(gt, predictions): 38 | thresholds = [0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95] 39 | ious = compute_ious(gt, predictions) 40 | precisions = [compute_precision_at(ious, th) for th in thresholds] 41 | return sum(precisions) / len(precisions) 42 | 43 | def intersection_over_union_thresholds(y_true, y_pred): 44 | iouts = [] 45 | for y_t, y_p in list(zip(y_true, y_pred)): 46 | iouts.append(compute_eval_metric(y_t, y_p)) 47 | return np.mean(iouts) 48 | 49 | def rle_from_binary(prediction): 50 | prediction = np.asfortranarray(prediction) 51 | return cocomask.encode(prediction) 52 | 53 | def intersection_over_union(y_true, y_pred): 54 | ious = [] 55 | for y_t, y_p in list(zip(y_true, y_pred)): 56 | iou = compute_ious(y_t, y_p) 57 | iou_mean = 1.0 * np.sum(iou) / len(iou) 58 | ious.append(iou_mean) 59 | return np.mean(ious) 60 | 61 | def my_eval(pred,targ): 62 | pred = to_np(torch.sigmoid(pred)) 63 | targ = to_np(targ) 64 | losses = [] 65 | for i in range(targ.shape[0]): 66 | losses.append(compute_eval_metric(targ[i],((pred[i]>0.5).astype(np.uint8)))) 67 | return np.mean(losses) 68 | 69 | def RLenc(img, order='F', format=True): 70 | """ 71 | img is binary mask image, shape (r,c) 72 | order is down-then-right, i.e. Fortran 73 | format determines if the order needs to be preformatted (according to submission rules) or not 74 | 75 | returns run length as an array or string (if format is True) 76 | """ 77 | bytes = img.reshape(img.shape[0] * img.shape[1], order=order) 78 | runs = [] ## list of run lengths 79 | r = 0 ## the current run length 80 | pos = 1 ## count starts from 1 per WK 81 | for c in bytes: 82 | if (c == 0): 83 | if r != 0: 84 | runs.append((pos, r)) 85 | pos += r 86 | r = 0 87 | pos += 1 88 | else: 89 | r += 1 90 | 91 | # if last run is unsaved (i.e. data ends with 1) 92 | if r != 0: 93 | runs.append((pos, r)) 94 | pos += r 95 | r = 0 96 | 97 | if format: 98 | z = '' 99 | 100 | for rr in runs: 101 | z += '{} {} '.format(rr[0], rr[1]) 102 | return z[:-1] 103 | else: 104 | return runs 105 | 106 | 107 | 108 | --------------------------------------------------------------------------------
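A quick worked example of the run-length encoding above (a sketch, assuming `RLenc` from `utils.py`): the mask is flattened in Fortran (column-major) order and each run of 1s is reported as `start length` with 1-based positions, per the competition's submission rules.

```python
import numpy as np
from utils import RLenc

mask = np.array([[0, 1, 1],
                 [1, 1, 0],
                 [0, 0, 1]], dtype=np.uint8)

# Column-major flattening gives [0,1,0, 1,1,0, 1,0,1]:
# runs of 1s start at 1-based positions 2, 4, 7 and 9.
print(RLenc(mask))  # -> "2 1 4 2 7 1 9 1"
```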