├── README.md
├── guide.md
└── TGS_Salt_resnext50_unet_5Fold_scSE.ipynb


/README.md:
--------------------------------------------------------------------------------
## [[TGS Salt Identification Challenge](https://www.kaggle.com/c/tgs-salt-identification-challenge)] Bronze Medal Solution

First of all, thanks to the [Fastai library](https://github.com/fastai/fastai/tree/0.7.0/fastai), which saved me a lot of time building the model. A bronze medal isn't a very good score. However, this [notebook](https://github.com/alexshuang/TGS_Salt/blob/master/TGS_Salt_resnext50_unet_5Fold_scSE.ipynb) can be run directly on Google Colab, so you can use it as a starting point for experimenting with the ideas published by the top contestants. Most importantly, you don't need your own GPU for this competition, although a faster GPU certainly helps :).

For more information about the source code, please read my Chinese blog post: [here](https://github.com/alexshuang/Kaggle_TGS_Salt_Identification_Challenge/blob/master/guide.md)

---

### Network architecture

* U-net with a pre-trained resnet34/resnext50 encoder; resnext50 is better.

### Training regime

* Adam optimizer with weight decay. SGD gives better results than Adam, but it is also much slower than Adam.
* 5-fold cross-validation.
* Use the [albumentations](https://github.com/albu/albumentations) library for data augmentation: randomly crop a region covering 50%~100% of the image area, then scale it to 128x128. Scaling to 192x192 or 256x256 is recommended, but you will need a faster GPU.
* stage1: train by SGD + BCE, stage2: train by Adam + Lovasz Hinge Loss, until the validation loss stops decreasing.

### Techniques that helped

* Adding Concurrent Spatial and Channel Squeeze & Excitation blocks to the U-net decoder **(+0.032 Private LB)**: [https://arxiv.org/pdf/1803.02579.pdf.](https://arxiv.org/pdf/1803.02579.pdf.)
I used the [implementation](https://www.kaggle.com/c/tgs-salt-identification-challenge/discussion/66178) generously provided by [Bilal](https://www.kaggle.com/bkkaggle).
* Depth-stratified n-fold training (+0.01 LB). For me, 5 folds scored well.
* TTA: flip and non-flip (+0.02 LB)

### Techniques that did not help
* Hypercolumns: [https://arxiv.org/pdf/1411.5752.pdf.](https://arxiv.org/pdf/1411.5752.pdf.)
  * Many folks ([Heng](https://www.kaggle.com/c/tgs-salt-identification-challenge/discussion/64645#380301), et al.) reported a public LB improvement of anywhere from +0.005 to +0.01 after concatenating hypercolumns to their decoder's output. My attempt led to a lower score, and because it was computationally expensive on my Paperspace P5000 machine, I didn't try to troubleshoot it.


/guide.md:
--------------------------------------------------------------------------------
## Segmentation

![Figure 1 ](https://upload-images.jianshu.io/upload_images/13575947-e7d482d05e62cb96.jpg?imageMogr2/auto-orient/strip%7CimageView2/2/w/1240)
Figure 1 comes from the CamVid database, a site that provides training data for object detection and image segmentation. As the figure shows, segmentation marks the different objects in an image with different colors, and each color stands for a class: red is class 1, blue is class 2, and so on. In other words, segmentation is pixel-level image classification.

Besides autonomous driving, image segmentation is widely used in medical diagnosis, satellite image localization, image synthesis and other fields. This post uses the most popular segmentation competition currently on [kaggle](http://www.kaggle.com), the [TGS Salt Identification Challenge](https://www.kaggle.com/c/tgs-salt-identification-challenge), as an example of how to apply Unet to a real-world image segmentation problem. github: [here](https://github.com/alexshuang/TGS_Salt).

TGS used seismic reflection to produce the 3D geological images shown below and labeled the salt regions in them. Contestants need to train a machine learning model that separates salt from the surrounding rock.
![Figure 2: Images & masks](https://upload-images.jianshu.io/upload_images/13575947-ffe482168fdeff0f.png?imageMogr2/auto-orient/strip%7CimageView2/2/w/1240)

Figure
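2 shows five image/mask pairs from the training set. In each pair the left picture is the raw geological image and the right picture is its per-pixel label, called the mask: black areas are ordinary rock and white areas are salt. Segmentation trains an image-to-image model that learns from the raw image to generate a predicted mask (call it mask2), with the ground-truth mask as the target. The model learns to identify salt by minimizing the difference between mask and mask2.

### Dataset

The first step in building a dataset is usually to create the mask images from run-length data. The TGS training set already ships the mask images (each mask has the same file name as its image), so we don't need to create them ourselves.

Keep in mind, though, that not every segmentation dataset provides masks, and often you have to build them from run-length data. Run-length data looks like the rle_mask column in the figure below: numbers separated by spaces and read in pairs, where the first number of each pair is the 1-based start index into the flattened image vector and the second is the length of the run. The covered region is filled with the class id of the object (0, 1, 2, ...). rle_decode() converts run-length data into a mask.
![image.png](https://upload-images.jianshu.io/upload_images/13575947-376d4fda8d49f757.png?imageMogr2/auto-orient/strip%7CimageView2/2/w/1240)

```
import numpy as np

def rle_decode(mask_rle, shape=(101, 101)):
    s = mask_rle.split()
    # Even positions hold the 1-based start indices, odd positions the run lengths.
    starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])]
    starts -= 1  # convert to 0-based indices
    ends = starts + lengths
    img = np.zeros(shape[0]*shape[1], dtype=np.uint8)
    for lo, hi in zip(starts, ends):
        img[lo:hi] = 1
    # Reshape the flat vector back into a 2-D mask (row-major, mirroring
    # rle_encode() later in this post; transpose if the RLE is column-major).
    return img.reshape(shape)
```

As Figure 2 shows, the geological images are low resolution, only 101x101, which is awkward both for the network's convolution arithmetic and for recognition, so we resize them to 128x128.
```
from PIL import Image

def resize_img(fn, outp_dir, sz):
    # fn and outp_dir are pathlib.Path objects; the resized copy keeps the file name.
    Image.open(fn).resize((sz, sz)).save(outp_dir/fn.name)
```
Data augmentation is at the core of building the dataset. As with object detection, segmentation usually avoids random crops. In this project I augment with horizontal and vertical flips and slight lighting adjustments.
```
# fastai 0.7 transforms; tfm_y=TfmType.CLASS applies the same transform to the mask.
aug_tfms = [
    RandomFlip(tfm_y=TfmType.CLASS),
    RandomDihedral(tfm_y=TfmType.CLASS),
    # RandomRotate(4, tfm_y=TfmType.CLASS),
    RandomLighting(0.07, 0.07, tfm_y=TfmType.CLASS)
]
```

### Unet
[paper](https://arxiv.org/abs/1505.04597)
Although Unet dates from 2015, it is still the most widely used model for segmentation, and many of the top-ranked kagglers on the LB use it.
![image.png](https://upload-images.jianshu.io/upload_images/13575947-2ce6b66cc5d3df89.png?imageMogr2/auto-orient/strip%7CimageView2/2/w/1240)
The left side of Unet consists of convolution layers and the right side of upsampling layers. In the convolution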
layers, the activation just before each pooling layer is concatenated onto the activation of the corresponding upsampling layer.

Because the left half of Unet extracts image features with convolution layers, just like resnet, vgg or inception, Unet can be implemented as resnet/vgg/inception + upsampling. The benefit is that a mature pretrained model speeds up Unet training, and transfer learning makes a very noticeable difference. In this project I use a resnet34 + upsampling architecture.
```
import torch
import torch.nn as nn
import torch.nn.functional as F

class SaveFeatures():
    # Stores the output of module m every time a forward pass runs through it.
    features=None
    def __init__(self, m): self.hook = m.register_forward_hook(self.hook_fn)
    def hook_fn(self, module, input, output): self.features = output
    def remove(self): self.hook.remove()


class UnetBlock(nn.Module):
    def __init__(self, up_in, down_in, n_out, dp=False, ps=0.25):
        super().__init__()
        up_out = down_out = n_out // 2
        self.tr_conv = nn.ConvTranspose2d(up_in, up_out, 2, 2, bias=False)
        self.conv = nn.Conv2d(down_in, down_out, 1, bias=False)
        self.bn = nn.BatchNorm2d(n_out)
        self.dp = dp
        if dp: self.dropout = nn.Dropout(ps, inplace=True)

    def forward(self, up_x, down_x):
        x1 = self.tr_conv(up_x)   # upsample the decoder activation
        x2 = self.conv(down_x)    # 1x1 conv on the saved encoder activation
        x = torch.cat([x1, x2], dim=1)
        x = self.bn(F.relu(x))
        return self.dropout(x) if self.dp else x


class Unet34(nn.Module):
    def __init__(self, rn, drop_i=False, ps_i=None, drop_up=False, ps=None):
        super().__init__()
        self.rn = rn
        self.sfs = [SaveFeatures(rn[i]) for i in [2, 4, 5, 6]]
        self.drop_i = drop_i
        # Resolve the default before it is used to build the dropout layer.
        if ps_i is None: ps_i = 0.1
        if drop_i:
            self.dropout = nn.Dropout(ps_i, inplace=True)
        if ps is not None: assert len(ps) == 4
        if ps is None: ps = [0.1] * 4
        self.up1 = UnetBlock(512, 256, 256, drop_up, ps[0])
        self.up2 = UnetBlock(256, 128, 256, drop_up, ps[1])
        self.up3 = UnetBlock(256, 64, 256, drop_up, ps[2])
        self.up4 = UnetBlock(256, 64, 256, drop_up, ps[3])
        self.up5 = nn.ConvTranspose2d(256, 1, 2, 2)

    def forward(self, x):
        x = F.relu(self.rn(x))
        x = self.dropout(x) if self.drop_i else x
        x = self.up1(x,
                     self.sfs[3].features)
        x = self.up2(x, self.sfs[2].features)
        x = self.up3(x, self.sfs[1].features)
        x = self.up4(x, self.sfs[0].features)
        x = self.up5(x)
        return x[:, 0]

    def close(self):
        for o in self.sfs: o.remove()
```
By registering a forward hook with register_forward_hook(), the activations of selected resnet34 layers (2, 4, 5, 6) are saved and then concatenated into the matching upsampling layers during upsampling. The upsampling layers use ConvTranspose2d() for deconvolution (transposed convolution). It works in the opposite direction to Conv2d(), enlarging the grid size of the feature map. Readers who are not familiar with deconvolution arithmetic can consult the [convolution arithmetic tutorial](http://deeplearning.net/software/theano/tutorial/conv_arithmetic.html).

### Loss

As mentioned earlier, segmentation is essentially pixel-level classification. This project has only two classes, salt and rock, so like cats vs. dogs it is a binary classification problem and binary cross entropy, i.e. nn.BCEWithLogitsLoss(), is sufficient. Besides BCE I also tried [focal loss](https://arxiv.org/abs/1708.02002), which improved accuracy by 0.013.
![Figure 3](https://upload-images.jianshu.io/upload_images/13575947-ff420a35c83e81e8.png?imageMogr2/auto-orient/strip%7CimageView2/2/w/1240)

As the formula in Figure 3 shows, focal loss is a scaled version of cross entropy: $(1 - p_t)^\gamma$ is the scaling factor, so the full loss is $-(1 - p_t)^\gamma \log(p_t)$. The scale is not a constant: it depends on the predicted probability $p_t$, while $\gamma$ is a fixed hyperparameter. The formula is simple, yet in object detection focal loss performs far better than BCE. The logic behind it is that scaling the loss up or down turns judgments that were ambiguous into definite ones. In Figure 3, with gamma == 0 focal loss is equivalent to cross entropy (CE), shown as the blue curve: even at a probability of 0.6 the loss is still >= 0.5, as if to say "I estimate a 60% chance that this is not class B, but I still can't be sure it isn't." With gamma == 2 and the same probability of 0.6, the loss is close to 0, as if to say "I estimate a 60% chance that this is not class B, so I treat it as not class B." That is the power of the scaling.
```
#https://github.com/marvis/pytorch-yolo2/blob/master/FocalLoss.py
#https://github.com/unsky/focal-loss
class FocalLoss2d(nn.Module):
    def __init__(self, gamma=2, size_average=True):
        super(FocalLoss2d, self).__init__()
        self.gamma = gamma
        self.size_average = size_average

    def forward(self, logit, target, class_weight=None, type='softmax'):
        target = target.view(-1, 1).long()
        if type=='sigmoid':
            if class_weight is None:
                class_weight = \
                    [1]*2  # [0.5, 0.5]
            prob = torch.sigmoid(logit)
            prob = prob.view(-1, 1)
            prob = torch.cat((1-prob, prob), 1)
            select = torch.FloatTensor(len(prob), 2).zero_().cuda()
            select.scatter_(1, target, 1.)
        elif type=='softmax':
            B,C,H,W = logit.size()
            if class_weight is None:
                class_weight =[1]*C #[1/C]*C
            logit = logit.permute(0, 2, 3, 1).contiguous().view(-1, C)
            prob = F.softmax(logit,1)
            select = torch.FloatTensor(len(prob), C).zero_().cuda()
            select.scatter_(1, target, 1.)
        class_weight = torch.FloatTensor(class_weight).cuda().view(-1,1)
        class_weight = torch.gather(class_weight, 0, target)
        # Keep only the probability assigned to the true class of each pixel.
        prob = (prob*select).sum(1).view(-1,1)
        prob = torch.clamp(prob,1e-8,1-1e-8)
        batch_loss = - class_weight *(torch.pow((1-prob), self.gamma))*prob.log()
        if self.size_average:
            loss = batch_loss.mean()
        else:
            loss = batch_loss
        return loss
```

### Metric

![image.png](https://upload-images.jianshu.io/upload_images/13575947-23578a2dac8e4716.png?imageMogr2/auto-orient/strip%7CimageView2/2/w/1240)

The metric checks each prediction's IoU against the thresholds [0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95], scores the image by the fraction of thresholds it reaches, and averages that score over all images.
```
iou_thresholds = np.array([0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95])

def iou(img_true, img_pred):
    # img_pred holds logits, so > 0 corresponds to probability > 0.5.
    img_pred = (img_pred > 0).float()
    i = (img_true * img_pred).sum()
    # Union counts every pixel set in either mask exactly once.
    u = ((img_true + img_pred) > 0).float().sum()
    return i / u if u != 0 else u

def iou_metric(imgs_pred, imgs_true):
    num_images = len(imgs_true)
    scores = np.zeros(num_images)
    for i in range(num_images):
        # An empty prediction for an empty ground truth counts as a perfect score.
        if imgs_true[i].sum() == 0 and (imgs_pred[i] > 0).sum() == 0:
            scores[i] = 1
        else:
            scores[i] = (iou_thresholds <= float(iou(imgs_true[i], imgs_pred[i]))).mean()
    return scores.mean()
```

### Training

Training the Unet model takes roughly two steps:
- Use the [LR
Test](https://arxiv.org/pdf/1506.01186.pdf) to find a suitable learning-rate range.
- Train the model with [Cyclical Learning Rates (CLR)](https://arxiv.org/pdf/1506.01186.pdf) until it starts to overfit.
```
# cut, md, UnetModel and miou are assumed to be defined beforehand (not shown here).
wd = 4e-4
arch = resnet34
ps_i = 0.05
ps = np.array([0.1, 0.1, 0.1, 0.1]) * 1
m_base = get_base_model(arch, cut, True)
m = to_gpu(Unet34(m_base, drop_i=True, drop_up=True, ps=ps, ps_i=ps_i))
models = UnetModel(m)
learn = ConvLearner(md, models)
learn.opt_fn = optim.Adam
learn.crit = nn.BCEWithLogitsLoss()
learn.metrics = [accuracy_thresh(0.5), miou]
```
I stopped training once changing the learning rate no longer reduced the loss, the validation loss had converged, and overfitting looked likely.
![image.png](https://upload-images.jianshu.io/upload_images/13575947-0afe62a185e242e5.png?imageMogr2/auto-orient/strip%7CimageView2/2/w/1240)
![image.png](https://upload-images.jianshu.io/upload_images/13575947-9a693854572800f8.png?imageMogr2/auto-orient/strip%7CimageView2/2/w/1240)
The results show that the model needs more regularization to counter overfitting. Dropout did not help much with Unet in practice, so the effort should go into data augmentation and weight decay.

### Run length encoder

This is the reverse of rle_decode(): before submitting the output to kaggle, rle_encode() has to turn each mask into run-length data. Before that, downsample() must resize the mask back to 101x101.
```
import numpy as np
from skimage.transform import resize

def downsample(img, shape):
    if shape == img.shape: return img
    return resize(img, shape, mode='constant', preserve_range=True)

def rle_encode(im):
    pixels = im.flatten()
    # Pad with zeros so runs touching either end of the vector are detected.
    pixels = np.concatenate([[0], pixels, [0]])
    # 1-based positions where the pixel value changes (run starts and run ends).
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
    # Turn each end position into a run length.
    runs[1::2] -= runs[::2]
    return ' '.join(str(x) for x in runs)
```

### TTA (Test Time Augmentation)
We can raise the kaggle score by applying data augmentation to the test set. The thing to watch when using TTA for segmentation is that augmented images produce augmented outputs, so each output must be transformed back with the inverse of its transform before the outputs are averaged. For example, if image1 and its horizontally flipped copy image2 produce mask1 and mask2, mask2 must be flipped horizontally before the two masks are averaged.

---

## Summary
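As a short recap of the TTA step described above, here is a minimal sketch of flip-and-average prediction. It is not the notebook's code: `predict_with_flip_tta` and `predict_fn` are hypothetical names, and the sketch assumes `predict_fn` maps an `(N, H, W)` NumPy array of images to an `(N, H, W)` array of mask probabilities.

```python
import numpy as np

def predict_with_flip_tta(predict_fn, images):
    """Average predictions over the original and the horizontally flipped batch.

    predict_fn is a hypothetical callable: (N, H, W) images -> (N, H, W) probabilities.
    """
    images = np.asarray(images)
    preds = predict_fn(images)
    # Flip along the width axis, predict, then flip the output back so that it
    # lines up pixel-for-pixel with the prediction for the original images.
    flipped_preds = predict_fn(images[:, :, ::-1])[:, :, ::-1]
    return (preds + flipped_preds) / 2.0
```

Forgetting the second flip would average each pixel with its mirror image, which blurs the mask instead of improving it.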
That covers the key points of building and training a Unet model: dataset, model, loss and metric. I wrote this post early in the competition, so it differs slightly from the model I finally used. For example, the final model adds 5-fold cross validation and scSE blocks. When I have time I will write another post on the top-ranked contestants' solutions and the techniques behind them.


---


/TGS_Salt_resnext50_unet_5Fold_scSE.ipynb:
--------------------------------------------------------------------------------
1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "colab": { 8 | "base_uri": "https://localhost:8080/", 9 | "height": 440 10 | }, 11 | "colab_type": "code", 12 | "id": "0mTKWCsxXo4r", 13 | "outputId": "87668956-8fa4-4bb0-9247-015a74a51caa" 14 | }, 15 | "outputs": [], 16 | "source": [ 17 | "!apt-get install -y -qq software-properties-common python-software-properties module-init-tools\n", 18 | "!add-apt-repository -y ppa:alessandro-strada/ppa 2>&1 > /dev/null\n", 19 | "!apt-get update -qq 2>&1 > /dev/null\n", 20 | "!curl --header \"Host: launchpadlibrarian.net\" --header \"User-Agent: Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/62.0.3202.62 Safari/537.36\" --header \"Accept: text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8\" --header \"Accept-Language: zh,zh-CN;q=0.9,zh-TW;q=0.8,en-US;q=0.7,en;q=0.6\" \"https://launchpadlibrarian.net/386846978/google-drive-ocamlfuse_0.7.0-0ubuntu1_amd64.deb\" -o \"google-drive-ocamlfuse_0.7.0-0ubuntu1_amd64.deb\" -L\n", 21 | "!apt-get -y install -qq fuse\n", 22 | "!dpkg -i google-drive-ocamlfuse_0.7.0-0ubuntu1_amd64.deb\n", 23 | "\n", 24 | "from googleapiclient.discovery import build\n", 25 | "import io, os\n", 26 | "from googleapiclient.http import MediaIoBaseDownload\n", 27 | "from google.colab import auth\n", 28 | "\n", 29 | "auth.authenticate_user()\n", 30 | "\n", 31 | "drive_service = build('drive', 'v3')\n", 32 | "results = drive_service.files().list(\n", 33 | " q=\"name = 'kaggle.json'\", 
fields=\"files(id)\").execute()\n", 34 | "kaggle_api_key = results.get('files', [])\n", 35 | "\n", 36 | "filename = \"/content/.kaggle/kaggle.json\"\n", 37 | "os.makedirs(os.path.dirname(filename), exist_ok=True)\n", 38 | "\n", 39 | "request = drive_service.files().get_media(fileId=kaggle_api_key[0]['id'])\n", 40 | "fh = io.FileIO(filename, 'wb')\n", 41 | "downloader = MediaIoBaseDownload(fh, request)\n", 42 | "done = False\n", 43 | "while done is False:\n", 44 | " status, done = downloader.next_chunk()\n", 45 | " print(\"Download %d%%.\" % int(status.progress() * 100))\n", 46 | "os.chmod(filename, 600)\n", 47 | "!mkdir ~/.kaggle\n", 48 | "!cp /content/.kaggle/kaggle.json ~/.kaggle/kaggle.json\n", 49 | "\n", 50 | "from oauth2client.client import GoogleCredentials\n", 51 | "creds = GoogleCredentials.get_application_default()\n", 52 | "import getpass\n", 53 | "!google-drive-ocamlfuse -headless -id={creds.client_id} -secret={creds.client_secret} < /dev/null 2>&1 | grep URL\n", 54 | "vcode = getpass.getpass()\n", 55 | "!echo {vcode} | google-drive-ocamlfuse -headless -id={creds.client_id} -secret={creds.client_secret}" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 0, 61 | "metadata": { 62 | "colab": {}, 63 | "colab_type": "code", 64 | "id": "K102BSsYDGMy" 65 | }, 66 | "outputs": [], 67 | "source": [ 68 | "!mkdir -p drive\n", 69 | "!google-drive-ocamlfuse -o nonempty drive\n", 70 | "!ls drive > /dev/null" 71 | ] 72 | }, 73 | { 74 | "cell_type": "markdown", 75 | "metadata": { 76 | "colab_type": "text", 77 | "id": "teccmM5G-YOk" 78 | }, 79 | "source": [ 80 | "### Setup" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": null, 86 | "metadata": { 87 | "colab": { 88 | "base_uri": "https://localhost:8080/", 89 | "height": 134 90 | }, 91 | "colab_type": "code", 92 | "id": "EaDX0j2Kt4W1", 93 | "outputId": "0b410fa0-4198-4f63-ac96-ab67db380594" 94 | }, 95 | "outputs": [], 96 | "source": [ 97 | "!mkdir src -p && cd src && git clone 
https://github.com/fastai/fastai.git" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": null, 103 | "metadata": { 104 | "colab": { 105 | "base_uri": "https://localhost:8080/", 106 | "height": 54 107 | }, 108 | "colab_type": "code", 109 | "id": "698sgxtBuXZO", 110 | "outputId": "78e4c0d3-0361-43c7-90f4-7412461d2592" 111 | }, 112 | "outputs": [], 113 | "source": [ 114 | "!pip3 install -q bcolz graphviz sklearn_pandas isoweek pandas_summary ipywidgets torch torchvision torchtext" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": 0, 120 | "metadata": { 121 | "colab": {}, 122 | "colab_type": "code", 123 | "id": "D6T70YIYyuKp" 124 | }, 125 | "outputs": [], 126 | "source": [ 127 | "!git config --global user.email 'nikshuang@163.com'\n", 128 | "!git config --global user.name 'Alex Huang'\n", 129 | "!pip install -q kaggle\n", 130 | "!pip install -q -U git+https://github.com/albu/albumentations" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": null, 136 | "metadata": { 137 | "colab": { 138 | "base_uri": "https://localhost:8080/", 139 | "height": 118 140 | }, 141 | "colab_type": "code", 142 | "id": "kv9wLQUwc5Vd", 143 | "outputId": "89b20b5f-d89f-4759-eb4a-6efc2e41e4dc" 144 | }, 145 | "outputs": [], 146 | "source": [ 147 | "!git clone https://github.com/svishnu88/TGS-SaltIdentification-Open-Solution-fastai.git\n", 148 | "!cp TGS-SaltIdentification-Open-Solution-fastai/*.py .\n", 149 | "!pip install -q Cython --upgrade\n", 150 | "!pip install -q pycocotools --upgrade" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": 0, 156 | "metadata": { 157 | "colab": { 158 | "base_uri": "https://localhost:8080/", 159 | "height": 50 160 | }, 161 | "colab_type": "code", 162 | "id": "ib2O97lzt4cG", 163 | "outputId": "5304877b-0461-4af1-c124-51a5eccf8b1c" 164 | }, 165 | "outputs": [ 166 | { 167 | "name": "stdout", 168 | "output_type": "stream", 169 | "text": [ 170 | "0.4.1\n" 171 | ] 172 
| }, 173 | { 174 | "data": { 175 | "text/plain": [ 176 | "(True, True)" 177 | ] 178 | }, 179 | "execution_count": 8, 180 | "metadata": { 181 | "tags": [] 182 | }, 183 | "output_type": "execute_result" 184 | } 185 | ], 186 | "source": [ 187 | "import sys\n", 188 | "sys.path.append(\"/content/src/fastai/old\") # on windows use \\'s instead\n", 189 | "\n", 190 | "from fastai.conv_learner import *\n", 191 | "from fastai.dataset import *\n", 192 | "from skimage.transform import resize\n", 193 | "import gc\n", 194 | "from sklearn.model_selection import train_test_split, StratifiedKFold , KFold\n", 195 | "from sklearn.metrics import jaccard_similarity_score\n", 196 | "from pycocotools import mask as cocomask\n", 197 | "from utils import my_eval,intersection_over_union_thresholds,RLenc\n", 198 | "from lovasz_losses import *\n", 199 | "print(torch.__version__)\n", 200 | "torch.backends.cudnn.benchmark=True\n", 201 | "torch.cuda.is_available(), torch.backends.cudnn.enabled" 202 | ] 203 | }, 204 | { 205 | "cell_type": "code", 206 | "execution_count": 0, 207 | "metadata": { 208 | "colab": {}, 209 | "colab_type": "code", 210 | "id": "zkqghZ5gOF1L" 211 | }, 212 | "outputs": [], 213 | "source": [ 214 | "PATH = Path('data/salt')\n", 215 | "TRN_PATH = PATH/'train'\n", 216 | "TEST_PATH = PATH/'test'\n", 217 | "DISK_PATH = 'drive/tgs_salt'\n", 218 | "os.makedirs(DISK_PATH, exist_ok=True)\n", 219 | "os.makedirs(PATH, exist_ok=True)\n", 220 | "os.makedirs(TRN_PATH, exist_ok=True)\n", 221 | "os.makedirs(TEST_PATH, exist_ok=True)" 222 | ] 223 | }, 224 | { 225 | "cell_type": "code", 226 | "execution_count": 0, 227 | "metadata": { 228 | "colab": {}, 229 | "colab_type": "code", 230 | "id": "W8sZ5U5xp4kO" 231 | }, 232 | "outputs": [], 233 | "source": [ 234 | "from albumentations import (\n", 235 | " Resize,\n", 236 | " HorizontalFlip,\n", 237 | " VerticalFlip,\n", 238 | " Compose,\n", 239 | " RandomRotate90,\n", 240 | " RandomSizedCrop,\n", 241 | " OneOf,\n", 242 | " RandomSizedCrop,\n", 243 
| " Rotate,\n", 244 | " RandomContrast,\n", 245 | " RandomGamma,\n", 246 | " ElasticTransform,\n", 247 | " GridDistortion, \n", 248 | " OpticalDistortion,\n", 249 | " RandomBrightness\n", 250 | ") " 251 | ] 252 | }, 253 | { 254 | "cell_type": "code", 255 | "execution_count": null, 256 | "metadata": { 257 | "colab": { 258 | "base_uri": "https://localhost:8080/", 259 | "height": 235 260 | }, 261 | "colab_type": "code", 262 | "id": "9CJQ66L9dELB", 263 | "outputId": "4ead758a-b897-4118-983e-19e2d2e43923" 264 | }, 265 | "outputs": [], 266 | "source": [ 267 | "!wget http://files.fast.ai/models/weights.tgz\n", 268 | "!tar xf weights.tgz -C /content/src/fastai/old/fastai/\n", 269 | "!ls /content/src/fastai/old/fastai/weights" 270 | ] 271 | }, 272 | { 273 | "cell_type": "code", 274 | "execution_count": null, 275 | "metadata": { 276 | "colab": { 277 | "base_uri": "https://localhost:8080/", 278 | "height": 269 279 | }, 280 | "colab_type": "code", 281 | "id": "Q5YLtNMJzSyS", 282 | "outputId": "a5930376-e517-4a8a-cea3-4215152ab2c1" 283 | }, 284 | "outputs": [], 285 | "source": [ 286 | "!kaggle competitions download -c tgs-salt-identification-challenge -p {PATH}" 287 | ] 288 | }, 289 | { 290 | "cell_type": "code", 291 | "execution_count": 0, 292 | "metadata": { 293 | "colab": {}, 294 | "colab_type": "code", 295 | "id": "E5j5p85o0DD6" 296 | }, 297 | "outputs": [], 298 | "source": [ 299 | "!cd {TRN_PATH} && unzip -q /content/{PATH}/'train.zip'\n", 300 | "!cd {TEST_PATH} && unzip -q /content/{PATH}/'test.zip'" 301 | ] 302 | }, 303 | { 304 | "cell_type": "code", 305 | "execution_count": 0, 306 | "metadata": { 307 | "colab": { 308 | "base_uri": "https://localhost:8080/", 309 | "height": 50 310 | }, 311 | "colab_type": "code", 312 | "id": "Pmbd4oZ10mXH", 313 | "outputId": "8fc58192-70fc-45e0-962e-082360d4f164" 314 | }, 315 | "outputs": [ 316 | { 317 | "name": "stdout", 318 | "output_type": "stream", 319 | "text": [ 320 | "images\tmasks\n", 321 | "images\n" 322 | ] 323 | } 324 | ], 325 
| "source": [ 326 | "!ls {TRN_PATH}\n", 327 | "!ls {TEST_PATH}" 328 | ] 329 | }, 330 | { 331 | "cell_type": "markdown", 332 | "metadata": { 333 | "colab_type": "text", 334 | "id": "Dz4H7ksT0tMo" 335 | }, 336 | "source": [ 337 | "### EDA" 338 | ] 339 | }, 340 | { 341 | "cell_type": "code", 342 | "execution_count": 0, 343 | "metadata": { 344 | "colab": {}, 345 | "colab_type": "code", 346 | "id": "WpM9V6Pmzc3o" 347 | }, 348 | "outputs": [], 349 | "source": [ 350 | "depth = pd.read_csv(PATH/'depths.csv')\n", 351 | "train = pd.read_csv(PATH/'train.csv')\n", 352 | "submission = pd.read_csv(PATH/'sample_submission.csv')" 353 | ] 354 | }, 355 | { 356 | "cell_type": "code", 357 | "execution_count": 0, 358 | "metadata": { 359 | "colab": {}, 360 | "colab_type": "code", 361 | "id": "U25DaYEF1L4U" 362 | }, 363 | "outputs": [], 364 | "source": [ 365 | "submission.head()" 366 | ] 367 | }, 368 | { 369 | "cell_type": "code", 370 | "execution_count": 0, 371 | "metadata": { 372 | "colab": {}, 373 | "colab_type": "code", 374 | "id": "ZUNgd4s50-dd" 375 | }, 376 | "outputs": [], 377 | "source": [ 378 | "depth.head()" 379 | ] 380 | }, 381 | { 382 | "cell_type": "code", 383 | "execution_count": 0, 384 | "metadata": { 385 | "colab": {}, 386 | "colab_type": "code", 387 | "id": "NjEG1A6F5V9k" 388 | }, 389 | "outputs": [], 390 | "source": [ 391 | "depth.set_index('id').z.to_dict()" 392 | ] 393 | }, 394 | { 395 | "cell_type": "code", 396 | "execution_count": 0, 397 | "metadata": { 398 | "colab": {}, 399 | "colab_type": "code", 400 | "id": "w4TItUxR5ND0" 401 | }, 402 | "outputs": [], 403 | "source": [ 404 | "depth.z.to_dict()" 405 | ] 406 | }, 407 | { 408 | "cell_type": "code", 409 | "execution_count": 0, 410 | "metadata": { 411 | "colab": {}, 412 | "colab_type": "code", 413 | "id": "Y6-sKto01EPu" 414 | }, 415 | "outputs": [], 416 | "source": [ 417 | "train.head()" 418 | ] 419 | }, 420 | { 421 | "cell_type": "code", 422 | "execution_count": 0, 423 | "metadata": { 424 | "colab": {}, 425 | 
"colab_type": "code", 426 | "id": "2AgjhosP1G2s" 427 | }, 428 | "outputs": [], 429 | "source": [ 430 | "!ls {TRN_PATH}/images | head -10" 431 | ] 432 | }, 433 | { 434 | "cell_type": "code", 435 | "execution_count": 0, 436 | "metadata": { 437 | "colab": {}, 438 | "colab_type": "code", 439 | "id": "fFXbh_0F1Awx" 440 | }, 441 | "outputs": [], 442 | "source": [ 443 | "IMG_ID = '008a50a2ec'" 444 | ] 445 | }, 446 | { 447 | "cell_type": "code", 448 | "execution_count": 0, 449 | "metadata": { 450 | "colab": {}, 451 | "colab_type": "code", 452 | "id": "3PLLX_oV1i9Z" 453 | }, 454 | "outputs": [], 455 | "source": [ 456 | "Image.open(TRN_IMG_PATH/f'{IMG_ID}.png').resize((32, 32))" 457 | ] 458 | }, 459 | { 460 | "cell_type": "code", 461 | "execution_count": 0, 462 | "metadata": { 463 | "colab": {}, 464 | "colab_type": "code", 465 | "id": "pjCMbgZk2Eu4" 466 | }, 467 | "outputs": [], 468 | "source": [ 469 | "Image.open(TRN_MASK_PATH/f'{IMG_ID}.png').resize((32, 32))" 470 | ] 471 | }, 472 | { 473 | "cell_type": "code", 474 | "execution_count": 0, 475 | "metadata": { 476 | "colab": {}, 477 | "colab_type": "code", 478 | "id": "B-yOgm0x0QbQ" 479 | }, 480 | "outputs": [], 481 | "source": [ 482 | "rn34 = torchvision.models.resnet34()" 483 | ] 484 | }, 485 | { 486 | "cell_type": "code", 487 | "execution_count": 0, 488 | "metadata": { 489 | "colab": {}, 490 | "colab_type": "code", 491 | "id": "GRmr-npg3JpT" 492 | }, 493 | "outputs": [], 494 | "source": [ 495 | "rn34" 496 | ] 497 | }, 498 | { 499 | "cell_type": "code", 500 | "execution_count": 0, 501 | "metadata": { 502 | "colab": {}, 503 | "colab_type": "code", 504 | "id": "5SmuAbQh0Qfz" 505 | }, 506 | "outputs": [], 507 | "source": [ 508 | "rn34.layer2" 509 | ] 510 | }, 511 | { 512 | "cell_type": "code", 513 | "execution_count": 0, 514 | "metadata": { 515 | "colab": {}, 516 | "colab_type": "code", 517 | "id": "eLYRkoraNbfi" 518 | }, 519 | "outputs": [], 520 | "source": [ 521 | "m_base = get_base_model(arch, cut, True)" 522 | ] 523 | }, 
524 | { 525 | "cell_type": "code", 526 | "execution_count": 0, 527 | "metadata": { 528 | "colab": {}, 529 | "colab_type": "code", 530 | "id": "QlUi3xT5jlxA" 531 | }, 532 | "outputs": [], 533 | "source": [ 534 | "m_base[4]" 535 | ] 536 | }, 537 | { 538 | "cell_type": "code", 539 | "execution_count": 0, 540 | "metadata": { 541 | "colab": {}, 542 | "colab_type": "code", 543 | "id": "teVt8T7aNTnE" 544 | }, 545 | "outputs": [], 546 | "source": [ 547 | "pretrained_state_dict = torch.load('/root/.torch/models/resnet34-333f7ec4.pth')" 548 | ] 549 | }, 550 | { 551 | "cell_type": "code", 552 | "execution_count": 0, 553 | "metadata": { 554 | "colab": {}, 555 | "colab_type": "code", 556 | "id": "qDY9tpPZNT0R" 557 | }, 558 | "outputs": [], 559 | "source": [ 560 | "pretrained_state_dict.keys()" 561 | ] 562 | }, 563 | { 564 | "cell_type": "code", 565 | "execution_count": 0, 566 | "metadata": { 567 | "colab": {}, 568 | "colab_type": "code", 569 | "id": "prL_3f_JN4F2" 570 | }, 571 | "outputs": [], 572 | "source": [ 573 | "arch()" 574 | ] 575 | }, 576 | { 577 | "cell_type": "markdown", 578 | "metadata": { 579 | "colab_type": "text", 580 | "id": "fkL9G_kG07dN" 581 | }, 582 | "source": [ 583 | "### Dataset" 584 | ] 585 | }, 586 | { 587 | "cell_type": "code", 588 | "execution_count": 0, 589 | "metadata": { 590 | "colab": {}, 591 | "colab_type": "code", 592 | "id": "eG_MV_FOOfh9" 593 | }, 594 | "outputs": [], 595 | "source": [ 596 | "TRN_DN = 'train/images'\n", 597 | "MASK_DN = 'train/masks'\n", 598 | "TEST_DN = 'test/images'\n", 599 | "TRN_PATH = PATH/TRN_DN\n", 600 | "MASK_PATH = PATH/MASK_DN" 601 | ] 602 | }, 603 | { 604 | "cell_type": "code", 605 | "execution_count": 0, 606 | "metadata": { 607 | "colab": {}, 608 | "colab_type": "code", 609 | "id": "5Sj0FM7X_5gC" 610 | }, 611 | "outputs": [], 612 | "source": [ 613 | "today = datetime.datetime.now().strftime('%Y%m%d')\n", 614 | "\n", 615 | "def save_model_weights(path=DISK_PATH):\n", 616 | " targ_dir = f'{path}/{today}/models'\n", 617 
| " os.makedirs(targ_dir, exist_ok=True)\n", 618 | " for f in list(PATH.glob('**/*.h5')):\n", 619 | " shutil.copy('/'.join(f.parts), targ_dir)\n", 620 | " \n", 621 | "def load_model_weights(path=PATH, disk_path=DISK_PATH, subdir=None):\n", 622 | " if subdir == None: subdir = today\n", 623 | " os.makedirs(f'{path}/models', exist_ok=True)\n", 624 | " src_dir = f'{disk_path}/{subdir}/models'\n", 625 | " targ_dir = f'{path}/models'\n", 626 | " for f in os.listdir(src_dir):\n", 627 | " shutil.copy(os.path.join(src_dir, f), targ_dir)\n", 628 | " \n", 629 | "def save_dataset(path=DISK_PATH):\n", 630 | " targ_dir = f'{path}/{today}'\n", 631 | " os.makedirs(targ_dir, exist_ok=True)\n", 632 | " shutil.copy('test_x.pkl', targ_dir)\n", 633 | " shutil.copy('train_folds.pkl', targ_dir)\n", 634 | " shutil.copy('val_folds.pkl', targ_dir)\n", 635 | " \n", 636 | "def load_dataset(path=PATH, disk_path=DISK_PATH, time_stamp=None):\n", 637 | " if time_stamp == None: time_stamp = today\n", 638 | " src_dir = f'{disk_path}/{time_stamp}'\n", 639 | "# assert os.path.exists(src_dir) == True\n", 640 | " shutil.copy(f'{src_dir}/test_x.pkl', '.')\n", 641 | " shutil.copy(f'{src_dir}/train_folds.pkl', '.')\n", 642 | " shutil.copy(f'{src_dir}/val_folds.pkl', '.')" 643 | ] 644 | }, 645 | { 646 | "cell_type": "code", 647 | "execution_count": 0, 648 | "metadata": { 649 | "colab": {}, 650 | "colab_type": "code", 651 | "id": "kVd-dF4ZO8s7" 652 | }, 653 | "outputs": [], 654 | "source": [ 655 | "class DepthDataset(Dataset):\n", 656 | " def __init__(self,ds,dpth_dict):\n", 657 | " self.dpth = dpth_dict\n", 658 | " self.ds = ds\n", 659 | " \n", 660 | " def __getitem__(self,i):\n", 661 | " val = self.ds[i]\n", 662 | " return val[0],self.dpth[self.ds.fnames[i].split('/')[-1][:-4]],val[1]\n", 663 | " \n", 664 | "class DepthDatasetV2(Dataset):\n", 665 | " def __init__(self,ds,dpth_dict):\n", 666 | " self.dpth = dpth_dict\n", 667 | " self.ds = ds\n", 668 | " \n", 669 | " def __getitem__(self,i):\n", 670 | " val 
= self.ds[i]\n", 671 | " return val[0],self.dpth[self.ds.fnames[i].name[:-4]],val[1]\n", 672 | " \n", 673 | "class MatchedFilesDataset(FilesDataset):\n", 674 | " def __init__(self, fnames, y, transform, path):\n", 675 | " self.y=y\n", 676 | " assert(len(fnames)==len(y))\n", 677 | " super().__init__(fnames, transform, path)\n", 678 | " \n", 679 | " def get_x(self, i): \n", 680 | " return open_image(os.path.join(self.path, self.fnames[i]))\n", 681 | " \n", 682 | " def get_y(self, i):\n", 683 | " return open_image(os.path.join(str(self.path), str(self.y[i])))\n", 684 | "\n", 685 | " def get_c(self): return 0\n", 686 | " \n", 687 | "class TestFilesDataset(FilesDataset):\n", 688 | " def __init__(self, fnames, y, transform,flip, path):\n", 689 | " self.y=y\n", 690 | " self.flip = flip\n", 691 | " super().__init__(fnames, transform, path)\n", 692 | " \n", 693 | " def get_x(self, i): \n", 694 | " im = open_image(os.path.join(self.path, self.fnames[i]))\n", 695 | " return np.fliplr(im) if self.flip else im\n", 696 | " \n", 697 | " def get_y(self, i):\n", 698 | " im = open_image(os.path.join(str(self.path), str(self.y[i])))\n", 699 | " return np.fliplr(im) if self.flip else im\n", 700 | " def get_c(self): return 0" 701 | ] 702 | }, 703 | { 704 | "cell_type": "code", 705 | "execution_count": 0, 706 | "metadata": { 707 | "colab": {}, 708 | "colab_type": "code", 709 | "id": "SYwzDb6CPqEw" 710 | }, 711 | "outputs": [], 712 | "source": [ 713 | "# trn = pd.read_csv(PATH/'train.csv')\n", 714 | "dpth = pd.read_csv(PATH/'depths.csv')\n", 715 | "c = dpth.set_index('id')\n", 716 | "dpth_dict = c['z'].to_dict()" 717 | ] 718 | }, 719 | { 720 | "cell_type": "code", 721 | "execution_count": 0, 722 | "metadata": { 723 | "colab": {}, 724 | "colab_type": "code", 725 | "id": "4mlccVfddfRm" 726 | }, 727 | "outputs": [], 728 | "source": [ 729 | "kf = 5\n", 730 | "# kf = 10" 731 | ] 732 | }, 733 | { 734 | "cell_type": "code", 735 | "execution_count": 0, 736 | "metadata": { 737 | "colab": {}, 738 
| "colab_type": "code", 739 | "id": "1qSG--oWPJNS" 740 | }, 741 | "outputs": [], 742 | "source": [ 743 | "x_names = np.array([Path('/'.join([TRN_DN, f])) for f in list(os.listdir(TRN_PATH))])\n", 744 | "y_names = np.array([Path('/'.join([MASK_DN, f])) for f in list(os.listdir(TRN_PATH))])\n", 745 | "test_x = np.array([Path('/'.join([TEST_DN, f.name])) for f in list((PATH/TEST_DN).iterdir())])\n", 746 | "f_name = [o.name for o in x_names]\n", 747 | "\n", 748 | "c = dpth.set_index('id')\n", 749 | "dpth_dict = c['z'].to_dict()\n", 750 | "\n", 751 | "kfold = KFold(n_splits=kf, shuffle=True, random_state=42)\n", 752 | "\n", 753 | "train_folds = []\n", 754 | "val_folds = []\n", 755 | "for idxs in kfold.split(f_name):\n", 756 | " train_folds.append([f_name[idx] for idx in idxs[0]])\n", 757 | " val_folds.append([f_name[idx] for idx in idxs[1]])\n", 758 | "\n", 759 | "with open('train_folds.pkl', 'wb') as fp:\n", 760 | " pickle.dump(train_folds, fp)\n", 761 | "with open('val_folds.pkl', 'wb') as fp:\n", 762 | " pickle.dump(val_folds, fp)\n", 763 | "with open('test_x.pkl', 'wb') as fp:\n", 764 | " pickle.dump(test_x, fp)\n", 765 | "\n", 766 | "save_dataset()" 767 | ] 768 | }, 769 | { 770 | "cell_type": "code", 771 | "execution_count": 0, 772 | "metadata": { 773 | "colab": {}, 774 | "colab_type": "code", 775 | "id": "XgpkySKgYiMw" 776 | }, 777 | "outputs": [], 778 | "source": [ 779 | "load_dataset(time_stamp='20181008') # 5 folds\n", 780 | "# load_dataset(time_stamp='20181016') # 10 folds" 781 | ] 782 | }, 783 | { 784 | "cell_type": "code", 785 | "execution_count": 0, 786 | "metadata": { 787 | "colab": {}, 788 | "colab_type": "code", 789 | "id": "E9EHO3w5SZVD" 790 | }, 791 | "outputs": [], 792 | "source": [ 793 | "train_folds = pickle.load(open('train_folds.pkl',mode='rb'))\n", 794 | "val_folds = pickle.load(open('val_folds.pkl',mode='rb'))\n", 795 | "test_x = pickle.load(open('test_x.pkl',mode='rb'))" 796 | ] 797 | }, 798 | { 799 | "cell_type": "code", 800 | 
"execution_count": 0, 801 | "metadata": { 802 | "colab": {}, 803 | "colab_type": "code", 804 | "id": "ys7cmvFiSDy3" 805 | }, 806 | "outputs": [], 807 | "source": [ 808 | "def open_image(fn):\n", 809 | " \"\"\" Opens an image using OpenCV given the file path.\n", 810 | "\n", 811 | " Arguments:\n", 812 | " fn: the file path of the image\n", 813 | "\n", 814 | " Returns:\n", 815 | " The image in RGB format as numpy array of floats normalized to range between 0.0 - 1.0\n", 816 | " \"\"\"\n", 817 | " flags = cv2.IMREAD_UNCHANGED+cv2.IMREAD_ANYDEPTH+cv2.IMREAD_ANYCOLOR\n", 818 | " if not os.path.exists(fn) and not str(fn).startswith(\"http\"):\n", 819 | " raise OSError('No such file or directory: {}'.format(fn))\n", 820 | " elif os.path.isdir(fn) and not str(fn).startswith(\"http\"):\n", 821 | " raise OSError('Is a directory: {}'.format(fn))\n", 822 | "# elif isdicom(fn):\n", 823 | "# slice = pydicom.read_file(fn)\n", 824 | "# if slice.PhotometricInterpretation.startswith('MONOCHROME'):\n", 825 | "# # Make a fake RGB image\n", 826 | "# im = np.stack([slice.pixel_array]*3,-1)\n", 827 | "# return im / ((1 << slice.BitsStored)-1)\n", 828 | "# else:\n", 829 | "# # No support for RGB yet, as it involves various color spaces.\n", 830 | "# # It shouldn't be too difficult to add though, if needed.\n", 831 | "# raise OSError('Unsupported DICOM image with PhotometricInterpretation=={}'.format(slice.PhotometricInterpretation))\n", 832 | " else:\n", 833 | " #res = np.array(Image.open(fn), dtype=np.float32)/255\n", 834 | " #if len(res.shape)==2: res = np.repeat(res[...,None],3,2)\n", 835 | " #return res\n", 836 | " try:\n", 837 | " if str(fn).startswith(\"http\"):\n", 838 | " req = urllib.urlopen(str(fn))\n", 839 | " image = np.asarray(bytearray(req.read()), dtype=\"uint8\")\n", 840 | " im = cv2.imdecode(image, flags)\n", 841 | " else:\n", 842 | " im = cv2.imread(str(fn), flags)\n", 843 | " if im is None: raise OSError(f'File not 
recognized by opencv: {fn}') # imread/imdecode return None on failure, so check before astype\n", 844 | " return cv2.cvtColor(im.astype(np.float32)/255, cv2.COLOR_BGR2RGB)\n", 845 | " except Exception as e:\n", 846 | " raise OSError('Error handling image at: {}'.format(fn)) from e" 847 | ] 848 | }, 849 | { 850 | "cell_type": "markdown", 851 | "metadata": { 852 | "colab_type": "text", 853 | "id": "0vkz5LH-NF6i" 854 | }, 855 | "source": [ 856 | "### Model" 857 | ] 858 | }, 859 | { 860 | "cell_type": "code", 861 | "execution_count": 0, 862 | "metadata": { 863 | "colab": {}, 864 | "colab_type": "code", 865 | "id": "3Kl7WMiC9aCK" 866 | }, 867 | "outputs": [], 868 | "source": [ 869 | "class SCSEModule(nn.Module):\n", 870 | " def __init__(self, ch, re=16):\n", 871 | " super().__init__()\n", 872 | " self.cSE = nn.Sequential(nn.AdaptiveAvgPool2d(1),\n", 873 | " nn.Conv2d(ch,ch//re,1),\n", 874 | " nn.ReLU(inplace=True),\n", 875 | " nn.Conv2d(ch//re,ch,1),\n", 876 | " nn.Sigmoid())\n", 877 | " self.sSE = nn.Sequential(nn.Conv2d(ch,ch,1),\n", 878 | " nn.Sigmoid())\n", 879 | "\n", 880 | " def forward(self, x):\n", 881 | " return x * self.cSE(x) + x * self.sSE(x)\n", 882 | "\n", 883 | "\n", 884 | "class SaveFeatures():\n", 885 | " features=None\n", 886 | " def __init__(self, m): self.hook = m.register_forward_hook(self.hook_fn)\n", 887 | " def hook_fn(self, module, input, output): self.features = output\n", 888 | " def remove(self): self.hook.remove()\n", 889 | "\n", 890 | "\n", 891 | "class UnetBlock(nn.Module):\n", 892 | " def __init__(self, up_in, x_in, n_out):\n", 893 | " super().__init__()\n", 894 | " up_out = x_out = n_out//2\n", 895 | " self.x_conv = nn.Conv2d(x_in, x_out, 1)\n", 896 | " self.tr_conv = nn.ConvTranspose2d(up_in, up_out, 2, stride=2)\n", 897 | " self.x_scse = SCSEModule(x_in)\n", 898 | " self.up_scse = SCSEModule(up_in)\n", 899 | " self.bn = nn.BatchNorm2d(n_out)\n", 900 | " \n", 901 | " def forward(self, up_p, x_p):\n", 902 | " up_p = self.tr_conv(self.up_scse(up_p))\n", 903 | " x_p = self.x_conv(self.x_scse(x_p))\n", 904 | " cat_p = 
torch.cat([up_p,x_p], dim=1)\n", 905 | " return self.bn(F.relu(cat_p))\n", 906 | "\n", 907 | "\n", 908 | "class HyperColumn(nn.Module):\n", 909 | " def __init__(self, nf):\n", 910 | " super().__init__()\n", 911 | " self.conv1 = nn.Conv2d(nf, nf // 16, 3, 1, 1)\n", 912 | " self.conv2 = nn.Conv2d(nf // 16, 1, 1)\n", 913 | " self.bn = nn.BatchNorm2d(nf // 16)\n", 914 | " self.relu = nn.ReLU(inplace=True)\n", 915 | " \n", 916 | " def forward(self, x):\n", 917 | " return self.conv2(self.bn(self.relu(self.conv1(x))))\n", 918 | "\n", 919 | "\n", 920 | "class Unet34(nn.Module):\n", 921 | " def __init__(self, rn):\n", 922 | " super().__init__()\n", 923 | " self.rn = rn\n", 924 | " self.sfs = [SaveFeatures(rn[i]) for i in [2,4,5,6]]\n", 925 | " self.up1 = UnetBlock(512,256,128)\n", 926 | " self.up2 = UnetBlock(128,128,128)\n", 927 | " self.up3 = UnetBlock(128,64,128)\n", 928 | " self.up4 = UnetBlock(128,64,128)\n", 929 | "# self.up5 = nn.ConvTranspose2d(128, 1, 2, stride=2)\n", 930 | " self.up5 = nn.ConvTranspose2d(128, 16, 2, stride=2)\n", 931 | " self.dc = nn.Conv2d(128, 16, 1)\n", 932 | " self.drop = nn.Dropout2d(p=0.5, inplace=True)\n", 933 | " self.logit = nn.Sequential(\n", 934 | " nn.ReLU(inplace=True),\n", 935 | " nn.Conv2d(80, 1, 1)\n", 936 | " )\n", 937 | " \n", 938 | " def forward(self,img,depth):\n", 939 | " x = F.relu(self.rn(img))\n", 940 | " x1 = self.up1(x, self.sfs[3].features)\n", 941 | " x2 = self.up2(x1, self.sfs[2].features)\n", 942 | " x3 = self.up3(x2, self.sfs[1].features)\n", 943 | " x4 = self.up4(x3, self.sfs[0].features)\n", 944 | "# x = self.up5(x4)\n", 945 | " x5 = self.up5(x4)\n", 946 | " x = torch.cat((x5,\n", 947 | " F.interpolate(self.dc(x4),scale_factor=2,align_corners=False,mode='bilinear'),\n", 948 | " F.interpolate(self.dc(x3),scale_factor=4,align_corners=False,mode='bilinear'),\n", 949 | " F.interpolate(self.dc(x2),scale_factor=8,align_corners=False,mode='bilinear'),\n", 950 | " 
F.interpolate(self.dc(x1),scale_factor=16,align_corners=False,mode='bilinear')),1)\n", 951 | " x = self.logit(self.drop(x))\n", 952 | " return x[:,0]\n", 953 | " \n", 954 | " def close(self):\n", 955 | " for sf in self.sfs: sf.remove()\n", 956 | "\n", 957 | "\n", 958 | "class UnetModel():\n", 959 | " def __init__(self,model,lr_cut,name='unet'):\n", 960 | " self.model,self.name = model,name\n", 961 | " self.lr_cut = lr_cut\n", 962 | "\n", 963 | " def get_layer_groups(self, precompute):\n", 964 | " lgs = list(split_by_idxs(children(self.model.rn), [self.lr_cut]))\n", 965 | " return lgs + [children(self.model)[1:]]" 966 | ] 967 | }, 968 | { 969 | "cell_type": "code", 970 | "execution_count": 0, 971 | "metadata": { 972 | "colab": {}, 973 | "colab_type": "code", 974 | "id": "8Me7dRQVzdKB" 975 | }, 976 | "outputs": [], 977 | "source": [ 978 | "def dechannels(nf):\n", 979 | " return nn.Sequential(\n", 980 | " nn.Conv2d(nf, 32, 1),\n", 981 | " nn.ReLU(inplace=True)\n", 982 | " )\n", 983 | " \n", 984 | "\n", 985 | "class Unet50(nn.Module):\n", 986 | " def __init__(self, rn):\n", 987 | " super().__init__()\n", 988 | " self.rn = rn\n", 989 | " self.sfs = [SaveFeatures(rn[i]) for i in [2,4,5,6]]\n", 990 | " self.up1 = UnetBlock(2048, 1024, 512)\n", 991 | " self.up2 = UnetBlock(512, 512, 256)\n", 992 | " self.up3 = UnetBlock(256, 256, 128)\n", 993 | " self.up4 = UnetBlock(128, 64, 128)\n", 994 | " self.up5 = nn.ConvTranspose2d(128, 1, 2, stride=2)\n", 995 | "# self.up5 = nn.ConvTranspose2d(128, 128, 2, stride=2)\n", 996 | "# self.d1 = dechannels(512)\n", 997 | "# self.d2 = dechannels(256)\n", 998 | "# self.d3 = dechannels(128)\n", 999 | "# self.d4 = dechannels(128)\n", 1000 | "# self.d5 = dechannels(128)\n", 1001 | "# self.drop = nn.Dropout2d(p=0.5, inplace=True)\n", 1002 | "# self.logit = nn.Conv2d(160, 1, 1)\n", 1003 | "# self.logit = nn.Sequential(\n", 1004 | "# nn.Conv2d(1152, 128, 3, 1, 1),\n", 1005 | "# nn.ReLU(inplace=True),\n", 1006 | "# nn.Conv2d(128, 128, 1)\n", 1007 
| "# )\n", 1008 | " \n", 1009 | " def forward(self,img,depth):\n", 1010 | " x = F.relu(self.rn(img))\n", 1011 | " x1 = self.up1(x, self.sfs[3].features)\n", 1012 | " x2 = self.up2(x1, self.sfs[2].features)\n", 1013 | " x3 = self.up3(x2, self.sfs[1].features)\n", 1014 | " x4 = self.up4(x3, self.sfs[0].features)\n", 1015 | " x = self.up5(x4)\n", 1016 | "# x5 = self.up5(x4)\n", 1017 | "# # x = torch.cat((x5,\n", 1018 | "# # F.interpolate(x4,scale_factor=2,align_corners=False,mode='bilinear'),\n", 1019 | "# # F.interpolate(x3,scale_factor=4,align_corners=False,mode='bilinear'),\n", 1020 | "# # F.interpolate(x2,scale_factor=8,align_corners=False,mode='bilinear'),\n", 1021 | "# # F.interpolate(x1,scale_factor=16,align_corners=False,mode='bilinear')),1)\n", 1022 | "# x = torch.cat((self.d5(x5),\n", 1023 | "# F.interpolate(self.d4(x4),scale_factor=2,align_corners=False,mode='bilinear'),\n", 1024 | "# F.interpolate(self.d3(x3),scale_factor=4,align_corners=False,mode='bilinear'),\n", 1025 | "# F.interpolate(self.d2(x2),scale_factor=8,align_corners=False,mode='bilinear'),\n", 1026 | "# F.interpolate(self.d1(x1),scale_factor=16,align_corners=False,mode='bilinear')),1)\n", 1027 | "# x = self.logit(self.drop(x))\n", 1028 | " return x[:,0]\n", 1029 | " \n", 1030 | " def close(self):\n", 1031 | " for sf in self.sfs: sf.remove()" 1032 | ] 1033 | }, 1034 | { 1035 | "cell_type": "code", 1036 | "execution_count": 0, 1037 | "metadata": { 1038 | "colab": {}, 1039 | "colab_type": "code", 1040 | "id": "6KHDtXQ4WESz" 1041 | }, 1042 | "outputs": [], 1043 | "source": [ 1044 | "def get_tgs_model():\n", 1045 | "# f = resnet34\n", 1046 | " f = resnext50\n", 1047 | " cut,lr_cut = model_meta[f]\n", 1048 | " m_base = get_base(f,cut)\n", 1049 | "# m = to_gpu(Unet34(m_base))\n", 1050 | " m = to_gpu(Unet50(m_base))\n", 1051 | " models = UnetModel(m,lr_cut)\n", 1052 | " learn = ConvLearner(md, models)\n", 1053 | " return learn\n", 1054 | "\n", 1055 | "def get_base(f,cut):\n", 1056 | " layers = 
cut_model(f(True), cut)\n", 1057 | " return nn.Sequential(*layers)" 1058 | ] 1059 | }, 1060 | { 1061 | "cell_type": "code", 1062 | "execution_count": 0, 1063 | "metadata": { 1064 | "colab": { 1065 | "base_uri": "https://localhost:8080/", 1066 | "height": 34 1067 | }, 1068 | "colab_type": "code", 1069 | "id": "lWcwr6hJFFHU", 1070 | "outputId": "ddca2693-ffc7-4b8b-a8e9-967e76e8de35" 1071 | }, 1072 | "outputs": [ 1073 | { 1074 | "data": { 1075 | "text/plain": [ 1076 | "800" 1077 | ] 1078 | }, 1079 | "execution_count": 27, 1080 | "metadata": { 1081 | "tags": [] 1082 | }, 1083 | "output_type": "execute_result" 1084 | } 1085 | ], 1086 | "source": [ 1087 | "gc.collect()" 1088 | ] 1089 | }, 1090 | { 1091 | "cell_type": "markdown", 1092 | "metadata": { 1093 | "colab_type": "text", 1094 | "id": "PIwVeVA9xv04" 1095 | }, 1096 | "source": [ 1097 | "### Training" 1098 | ] 1099 | }, 1100 | { 1101 | "cell_type": "code", 1102 | "execution_count": 0, 1103 | "metadata": { 1104 | "colab": {}, 1105 | "colab_type": "code", 1106 | "id": "lZcBehwhkNb7" 1107 | }, 1108 | "outputs": [], 1109 | "source": [ 1110 | "load_model_weights(subdir='tags/score_8.41')\n", 1111 | "# load_model_weights(subdir='20181016')\n", 1112 | "# load_model_weights(subdir='20181018')" 1113 | ] 1114 | }, 1115 | { 1116 | "cell_type": "code", 1117 | "execution_count": 0, 1118 | "metadata": { 1119 | "colab": {}, 1120 | "colab_type": "code", 1121 | "id": "srDUaJW0fhTn" 1122 | }, 1123 | "outputs": [], 1124 | "source": [ 1125 | "def lovasz_hinge_flat(logits, labels):\n", 1126 | " if len(labels) == 0:\n", 1127 | " # only void pixels, the gradients should be 0\n", 1128 | " return logits.sum() * 0.\n", 1129 | " signs = 2. * labels.float() - 1.\n", 1130 | " errors = (1. 
- logits * Variable(signs.data))\n", 1131 | " errors_sorted, perm = torch.sort(errors, dim=0, descending=True)\n", 1132 | " perm = perm.data\n", 1133 | " gt_sorted = labels[perm]\n", 1134 | " grad = lovasz_grad(gt_sorted)\n", 1135 | " loss = torch.dot(F.elu(errors_sorted) +1, Variable(grad.data))\n", 1136 | " return loss\n", 1137 | "\n", 1138 | " \n", 1139 | "def lovasz_hinge(logits, labels, per_image=True, ignore=None):\n", 1140 | " if per_image:\n", 1141 | " loss = mean(lovasz_hinge_flat(*flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore))\n", 1142 | " for log, lab in zip(logits, labels))\n", 1143 | " else:\n", 1144 | " loss = lovasz_hinge_flat(*flatten_binary_scores(logits, labels, ignore))\n", 1145 | " return loss" 1146 | ] 1147 | }, 1148 | { 1149 | "cell_type": "code", 1150 | "execution_count": 0, 1151 | "metadata": { 1152 | "colab": {}, 1153 | "colab_type": "code", 1154 | "id": "haFc4sz-PIk3" 1155 | }, 1156 | "outputs": [], 1157 | "source": [ 1158 | "sz = 128\n", 1159 | "# sz = 192\n", 1160 | "\n", 1161 | "aug = Compose([\n", 1162 | " OneOf([RandomSizedCrop(min_max_height=(50, 101), height=sz, width=sz, p=0.5),\n", 1163 | " Resize(sz, sz, p=0.5)], p=1),\n", 1164 | " OneOf([VerticalFlip(p=0.5), RandomRotate90(p=0.5), HorizontalFlip(p=0.5)], p=0.5),\n", 1165 | "# OneOf([\n", 1166 | "# ElasticTransform(p=0.5, alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03),\n", 1167 | "# GridDistortion(p=0.5),\n", 1168 | "# OpticalDistortion(p=1, distort_limit=2, shift_limit=0.5) \n", 1169 | "# ], p=0.8),\n", 1170 | "# RandomContrast(limit=0.1, p=0.5),\n", 1171 | " RandomBrightness(limit=0.1, p=0.5),\n", 1172 | "# RandomGamma(p=0.5)\n", 1173 | "])" 1174 | ] 1175 | }, 1176 | { 1177 | "cell_type": "code", 1178 | "execution_count": 0, 1179 | "metadata": { 1180 | "colab": {}, 1181 | "colab_type": "code", 1182 | "id": "AZB93ZfeO1Kz" 1183 | }, 1184 | "outputs": [], 1185 | "source": [ 1186 | "class Transforms():\n", 1187 | " def __init__(self, sz, tfms, normalizer, 
denorm, crop_type=CropType.CENTER,\n", 1188 | " tfm_y=TfmType.NO, sz_y=None, is_trn=True):\n", 1189 | " if sz_y is None: sz_y = sz\n", 1190 | " self.sz,self.denorm,self.norm,self.sz_y = sz,denorm,normalizer,sz_y\n", 1191 | " crop_tfm = crop_fn_lu[crop_type](sz, tfm_y, sz_y)\n", 1192 | " self.tfms = tfms\n", 1193 | " self.tfms.append(crop_tfm)\n", 1194 | " if normalizer is not None: self.tfms.append(normalizer)\n", 1195 | " self.tfms.append(ChannelOrder(tfm_y))\n", 1196 | " self.channel_order = ChannelOrder(tfm_y)\n", 1197 | " self.trn_tfms = aug if is_trn == True else None\n", 1198 | "\n", 1199 | "# def __call__(self, im, y=None): return compose(im, y, self.trn_tfms, self.tfms, self.norm)\n", 1200 | " def __call__(self, im, y=None): return self.compose(im, y)\n", 1201 | " def __repr__(self): return str(self.trn_tfms) if self.trn_tfms else str(self.tfms)\n", 1202 | " \n", 1203 | " def compose(self, im, y):\n", 1204 | " if self.trn_tfms:\n", 1205 | " augmented = self.trn_tfms(image=im, mask=y)\n", 1206 | " im, y = self.norm(augmented['image'], augmented['mask'])\n", 1207 | " return self.channel_order(im, y)\n", 1208 | " else:\n", 1209 | " for fn in self.tfms:\n", 1210 | " im, y =fn(im, y)\n", 1211 | " return im if y is None else (im, y)\n", 1212 | "\n", 1213 | "\n", 1214 | "def image_gen(normalizer, denorm, sz, tfms=None, max_zoom=None, pad=0, crop_type=None,\n", 1215 | " tfm_y=None, sz_y=None, pad_mode=cv2.BORDER_REFLECT, scale=None, is_trn=True):\n", 1216 | " if tfm_y is None: tfm_y=TfmType.NO\n", 1217 | " if tfms is None: tfms=[]\n", 1218 | " elif not isinstance(tfms, collections.Iterable): tfms=[tfms]\n", 1219 | " if sz_y is None: sz_y = sz\n", 1220 | " if scale is None:\n", 1221 | " scale = [RandomScale(sz, max_zoom, tfm_y=tfm_y, sz_y=sz_y) if max_zoom is not None\n", 1222 | " else Scale(sz, tfm_y, sz_y=sz_y)]\n", 1223 | " elif not is_listy(scale): scale = [scale]\n", 1224 | " if pad: scale.append(AddPadding(pad, mode=pad_mode))\n", 1225 | " if 
crop_type!=CropType.GOOGLENET: tfms=scale+tfms\n", 1226 | " return Transforms(sz, tfms, normalizer, denorm, crop_type,\n", 1227 | " tfm_y=tfm_y, sz_y=sz_y, is_trn=is_trn)\n", 1228 | " \n", 1229 | " \n", 1230 | "def tfms_from_stats(stats, sz, aug_tfms=None, max_zoom=None, pad=0, crop_type=CropType.RANDOM,\n", 1231 | " tfm_y=None, sz_y=None, pad_mode=cv2.BORDER_REFLECT, norm_y=True, scale=None):\n", 1232 | " if aug_tfms is None: aug_tfms=[]\n", 1233 | " tfm_norm = Normalize(*stats, tfm_y=tfm_y if norm_y else TfmType.NO) if stats is not None else None\n", 1234 | " tfm_denorm = Denormalize(*stats) if stats is not None else None\n", 1235 | " val_crop = CropType.CENTER if crop_type in (CropType.RANDOM,CropType.GOOGLENET) else crop_type\n", 1236 | " val_tfm = image_gen(tfm_norm, tfm_denorm, sz, pad=pad, crop_type=val_crop,\n", 1237 | " tfm_y=tfm_y, sz_y=sz_y, scale=scale, is_trn=False)\n", 1238 | " trn_tfm = image_gen(tfm_norm, tfm_denorm, sz, pad=pad, crop_type=crop_type,\n", 1239 | " tfm_y=tfm_y, sz_y=sz_y, tfms=aug_tfms, max_zoom=max_zoom, pad_mode=pad_mode, scale=scale, is_trn=True)\n", 1240 | " return trn_tfm, val_tfm\n", 1241 | " \n", 1242 | " \n", 1243 | "def tfms_from_model(f_model, sz, aug_tfms=None, max_zoom=None, pad=0, crop_type=CropType.RANDOM,\n", 1244 | " tfm_y=None, sz_y=None, pad_mode=cv2.BORDER_REFLECT, norm_y=True, scale=None):\n", 1245 | " stats = inception_stats if f_model in inception_models else imagenet_stats\n", 1246 | " return tfms_from_stats(stats, sz, aug_tfms, max_zoom=max_zoom, pad=pad, crop_type=crop_type,\n", 1247 | " tfm_y=tfm_y, sz_y=sz_y, pad_mode=pad_mode, norm_y=norm_y, scale=scale)" 1248 | ] 1249 | }, 1250 | { 1251 | "cell_type": "markdown", 1252 | "metadata": { 1253 | "colab_type": "text", 1254 | "id": "b53Co8tFUCZh" 1255 | }, 1256 | "source": [ 1257 | "### Training: resnext50 + scSE + hypercolumns" 1258 | ] 1259 | }, 1260 | { 1261 | "cell_type": "code", 1262 | "execution_count": 0, 1263 | "metadata": { 1264 | "colab": { 1265 | 
"base_uri": "https://localhost:8080/", 1266 | "height": 1442 1267 | }, 1268 | "colab_type": "code", 1269 | "id": "vaZO9Kz44xYu", 1270 | "outputId": "98d1552c-a4a1-41b4-85b9-cfac455fb289" 1271 | }, 1272 | "outputs": [ 1273 | { 1274 | "name": "stdout", 1275 | "output_type": "stream", 1276 | "text": [ 1277 | "[128, 64]\n", 1278 | "fold_id0\n", 1279 | "5_fold_resnext50_128_scse_hcol_0\n" 1280 | ] 1281 | }, 1282 | { 1283 | "data": { 1284 | "application/vnd.jupyter.widget-view+json": { 1285 | "model_id": "b582b36f65e04eb380f8bfa3d7a05a74", 1286 | "version_major": 2, 1287 | "version_minor": 0 1288 | }, 1289 | "text/plain": [ 1290 | "HBox(children=(IntProgress(value=0, description='Epoch', max=10), HTML(value='')))" 1291 | ] 1292 | }, 1293 | "metadata": { 1294 | "tags": [] 1295 | }, 1296 | "output_type": "display_data" 1297 | }, 1298 | { 1299 | "name": "stdout", 1300 | "output_type": "stream", 1301 | "text": [ 1302 | "epoch trn_loss val_loss my_eval \n", 1303 | " 0 0.296024 0.229948 0.6635 \n", 1304 | " 1 0.226501 0.167776 0.7175 \n", 1305 | " 2 0.197844 0.174305 0.70425 \n", 1306 | " 3 0.178601 0.157858 0.734 \n", 1307 | " 4 0.159284 0.141243 0.758375 \n", 1308 | " 5 0.156893 0.144844 0.7555 \n", 1309 | " 6 0.139704 0.14005 0.76825 \n", 1310 | " 7 0.129483 0.135873 0.782125 \n", 1311 | " 8 0.117466 0.133637 0.77575 \n", 1312 | " 9 0.107742 0.148499 0.7775 \n", 1313 | "\n", 1314 | "0.7821250000000001\n", 1315 | "fold_id1\n", 1316 | "5_fold_resnext50_128_scse_hcol_1\n" 1317 | ] 1318 | }, 1319 | { 1320 | "data": { 1321 | "application/vnd.jupyter.widget-view+json": { 1322 | "model_id": "e95165aba22348e1a9e374735a41960f", 1323 | "version_major": 2, 1324 | "version_minor": 0 1325 | }, 1326 | "text/plain": [ 1327 | "HBox(children=(IntProgress(value=0, description='Epoch', max=10), HTML(value='')))" 1328 | ] 1329 | }, 1330 | "metadata": { 1331 | "tags": [] 1332 | }, 1333 | "output_type": "display_data" 1334 | }, 1335 | { 1336 | "name": "stdout", 1337 | "output_type": "stream", 
1338 | "text": [ 1339 | "epoch trn_loss val_loss my_eval \n", 1340 | " 0 0.298795 0.330592 0.577375 \n", 1341 | " 1 0.232322 0.154958 0.684125 \n", 1342 | " 2 0.202833 0.152708 0.708 \n", 1343 | " 3 0.182856 0.165284 0.688875 \n", 1344 | " 4 0.160326 0.142061 0.73875 \n", 1345 | " 5 0.149245 0.133409 0.737375 \n", 1346 | " 6 0.142726 0.129582 0.747375 \n", 1347 | " 7 0.136257 0.128909 0.754375 \n", 1348 | " 8 0.121961 0.14172 0.746 \n", 1349 | " 9 0.111444 0.138969 0.755375 \n", 1350 | "\n", 1351 | "0.7553750000000001\n", 1352 | "fold_id2\n", 1353 | "5_fold_resnext50_128_scse_hcol_2\n" 1354 | ] 1355 | }, 1356 | { 1357 | "data": { 1358 | "application/vnd.jupyter.widget-view+json": { 1359 | "model_id": "64c3fc40887d48a6923fa8811195da65", 1360 | "version_major": 2, 1361 | "version_minor": 0 1362 | }, 1363 | "text/plain": [ 1364 | "HBox(children=(IntProgress(value=0, description='Epoch', max=10), HTML(value='')))" 1365 | ] 1366 | }, 1367 | "metadata": { 1368 | "tags": [] 1369 | }, 1370 | "output_type": "display_data" 1371 | }, 1372 | { 1373 | "name": "stdout", 1374 | "output_type": "stream", 1375 | "text": [ 1376 | "epoch trn_loss val_loss my_eval \n", 1377 | " 0 0.302976 0.230204 0.6545 \n", 1378 | " 1 0.236865 0.176783 0.70275 \n", 1379 | " 2 0.200898 0.158097 0.706625 \n", 1380 | " 3 0.182162 0.147802 0.752 \n", 1381 | " 4 0.1738 0.146175 0.73025 \n", 1382 | " 5 0.157935 0.135695 0.758875 \n", 1383 | " 6 0.148242 0.143257 0.76125 \n", 1384 | " 7 0.135907 0.130015 0.768625 \n", 1385 | " 8 0.125547 0.12942 0.761 \n", 1386 | " 9 0.114384 0.135982 0.76275 \n", 1387 | "\n", 1388 | "0.7686250000000001\n", 1389 | "fold_id3\n", 1390 | "5_fold_resnext50_128_scse_hcol_3\n" 1391 | ] 1392 | }, 1393 | { 1394 | "data": { 1395 | "application/vnd.jupyter.widget-view+json": { 1396 | "model_id": "3a2960908675441aa81622bf7f1373c0", 1397 | "version_major": 2, 1398 | "version_minor": 0 1399 | }, 1400 | "text/plain": [ 1401 | "HBox(children=(IntProgress(value=0, description='Epoch', 
max=10), HTML(value='')))" 1402 | ] 1403 | }, 1404 | "metadata": { 1405 | "tags": [] 1406 | }, 1407 | "output_type": "display_data" 1408 | }, 1409 | { 1410 | "name": "stdout", 1411 | "output_type": "stream", 1412 | "text": [ 1413 | "epoch trn_loss val_loss my_eval \n", 1414 | " 0 0.298731 0.179404 0.649875 \n", 1415 | " 1 0.236641 0.150297 0.6965 \n", 1416 | " 2 0.208887 0.130067 0.725375 \n", 1417 | " 3 0.187371 0.116859 0.7545 \n", 1418 | " 4 0.170393 0.127697 0.749 \n", 1419 | " 5 0.160323 0.117837 0.75725 \n", 1420 | " 6 0.142183 0.105238 0.78125 \n", 1421 | " 7 0.135057 0.109958 0.776875 \n", 1422 | " 8 0.120477 0.110117 0.794 \n", 1423 | " 9 0.113762 0.101085 0.7865 \n", 1424 | "\n", 1425 | "0.794\n", 1426 | "fold_id4\n", 1427 | "5_fold_resnext50_128_scse_hcol_4\n" 1428 | ] 1429 | }, 1430 | { 1431 | "data": { 1432 | "application/vnd.jupyter.widget-view+json": { 1433 | "model_id": "3c65c3a78192409ea3cee73d3ad370ba", 1434 | "version_major": 2, 1435 | "version_minor": 0 1436 | }, 1437 | "text/plain": [ 1438 | "HBox(children=(IntProgress(value=0, description='Epoch', max=10), HTML(value='')))" 1439 | ] 1440 | }, 1441 | "metadata": { 1442 | "tags": [] 1443 | }, 1444 | "output_type": "display_data" 1445 | }, 1446 | { 1447 | "name": "stdout", 1448 | "output_type": "stream", 1449 | "text": [ 1450 | "epoch trn_loss val_loss my_eval \n", 1451 | " 0 0.302227 0.204683 0.65575 \n", 1452 | " 1 0.229877 0.183582 0.71175 \n", 1453 | " 2 0.199228 0.171225 0.68475 \n", 1454 | " 3 0.192234 0.1548 0.734625 \n", 1455 | " 4 0.167416 0.140409 0.74525 \n", 1456 | " 5 0.149521 0.126696 0.743125 \n", 1457 | " 6 0.136639 0.126258 0.774875 \n", 1458 | " 7 0.125201 0.119398 0.774875 \n", 1459 | " 8 0.120503 0.138388 0.762875 \n", 1460 | " 9 0.111816 0.117212 0.782875 \n", 1461 | "\n", 1462 | "0.782875\n" 1463 | ] 1464 | } 1465 | ], 1466 | "source": [ 1467 | "model = 'resnext50_128_scse_hcol'\n", 1468 | "arch = resnext50\n", 1469 | "bst_acc=[]\n", 1470 | "use_clr_min=20\n", 1471 | 
"use_clr_div=10\n", 1472 | "aug_tfms = [RandomFlip(tfm_y=TfmType.CLASS)]\n", 1473 | "\n", 1474 | "szs = [(128,64)]\n", 1475 | "for sz,bs in szs:\n", 1476 | " print([sz,bs])\n", 1477 | " for i in range(kf) :\n", 1478 | " print(f'fold_id{i}')\n", 1479 | " \n", 1480 | " trn_x = np.array([f'{TRN_DN}/{o}' for o in train_folds[i]])\n", 1481 | " trn_y = np.array([f'{MASK_DN}/{o}' for o in train_folds[i]])\n", 1482 | " val_x = [f'{TRN_DN}/{o}' for o in val_folds[i]]\n", 1483 | " val_y = [f'{MASK_DN}/{o}' for o in val_folds[i]]\n", 1484 | " \n", 1485 | " tfms = tfms_from_model(arch, sz=sz, pad=0,crop_type=CropType.NO, tfm_y=TfmType.CLASS, aug_tfms=aug_tfms)\n", 1486 | " datasets = ImageData.get_ds(MatchedFilesDataset, (trn_x,trn_y), (val_x,val_y), tfms,test=test_x,path=PATH)\n", 1487 | " md = ImageData(PATH, datasets, bs, num_workers=4, classes=None)\n", 1488 | " denorm = md.trn_ds.denorm\n", 1489 | " md.trn_dl.dataset = DepthDataset(md.trn_ds,dpth_dict)\n", 1490 | " md.val_dl.dataset = DepthDataset(md.val_ds,dpth_dict)\n", 1491 | " md.test_dl.dataset = DepthDataset(md.test_ds,dpth_dict)\n", 1492 | " learn = get_tgs_model() \n", 1493 | " learn.opt_fn = optim.Adam\n", 1494 | " learn.metrics=[my_eval]\n", 1495 | " pa = f'{kf}_fold_{model}_{i}'\n", 1496 | " print(pa)\n", 1497 | "\n", 1498 | " lr=1e-2\n", 1499 | " wd=1e-7\n", 1500 | " lrs = np.array([lr/100,lr/10,lr])\n", 1501 | "\n", 1502 | " learn.unfreeze() \n", 1503 | " learn.crit = nn.BCEWithLogitsLoss()\n", 1504 | " if os.path.exists(pa):\n", 1505 | " learn.load(pa)\n", 1506 | " learn.fit(lrs/2,1, wds=wd, cycle_len=10,use_clr=(20,8),best_save_name=pa)\n", 1507 | " \n", 1508 | " learn.load(pa) \n", 1509 | " # Calculating mean IoU score\n", 1510 | " v_targ = md.val_ds.ds[:][1]\n", 1511 | " v_preds = np.zeros((len(v_targ),sz,sz)) \n", 1512 | " v_pred = learn.predict()\n", 1513 | " v_pred = to_np(torch.sigmoid(torch.from_numpy(v_pred)))\n", 1514 | " p = ((v_pred)>0.5).astype(np.uint8)\n", 1515 | " 
bst_acc.append(intersection_over_union_thresholds(v_targ,p))\n", 1516 | " print(bst_acc[-1])\n", 1517 | " \n", 1518 | " del learn\n", 1519 | " del md\n", 1520 | " gc.collect()" 1521 | ] 1522 | }, 1523 | { 1524 | "cell_type": "code", 1525 | "execution_count": 0, 1526 | "metadata": { 1527 | "colab": { 1528 | "base_uri": "https://localhost:8080/", 1529 | "height": 4346 1530 | }, 1531 | "colab_type": "code", 1532 | "id": "ROi0HiEA4xeT", 1533 | "outputId": "3edaabab-5f31-4435-ee92-a9ad8d121d57" 1534 | }, 1535 | "outputs": [ 1536 | { 1537 | "name": "stdout", 1538 | "output_type": "stream", 1539 | "text": [ 1540 | "[128, 64]\n", 1541 | "fold_id0\n", 1542 | "5_fold_resnext50_128_scse_hcol_0\n" 1543 | ] 1544 | }, 1545 | { 1546 | "data": { 1547 | "application/vnd.jupyter.widget-view+json": { 1548 | "model_id": "cc2e39db82e44742bc11cb86df4f4ac8", 1549 | "version_major": 2, 1550 | "version_minor": 0 1551 | }, 1552 | "text/plain": [ 1553 | "HBox(children=(IntProgress(value=0, description='Epoch', max=20), HTML(value='')))" 1554 | ] 1555 | }, 1556 | "metadata": { 1557 | "tags": [] 1558 | }, 1559 | "output_type": "display_data" 1560 | }, 1561 | { 1562 | "name": "stdout", 1563 | "output_type": "stream", 1564 | "text": [ 1565 | "epoch trn_loss val_loss my_eval \n", 1566 | " 0 0.977655 0.9829 0.76025 \n", 1567 | " 1 0.984492 1.024706 0.763625 \n", 1568 | " 2 0.94598 0.968942 0.768125 \n", 1569 | " 3 0.91825 0.867097 0.782625 \n", 1570 | " 4 0.842798 0.835175 0.7835 \n", 1571 | " 5 0.811193 0.819348 0.79875 \n", 1572 | " 6 0.784356 0.763691 0.8005 \n", 1573 | " 7 0.763975 0.773445 0.800375 \n", 1574 | " 8 0.707037 0.741929 0.79875 \n", 1575 | " 9 0.676857 0.763426 0.80925 \n", 1576 | " 10 0.686296 0.84298 0.7955 \n", 1577 | " 11 0.778831 0.926706 0.788625 \n", 1578 | " 12 0.823366 0.929922 0.78175 \n", 1579 | " 13 0.846463 0.840977 0.768375 \n", 1580 | " 14 0.80643 0.766691 0.79525 \n", 1581 | " 15 0.785104 0.772868 0.793375 \n", 1582 | " 16 0.771523 0.761224 0.803625 \n", 1583 | " 
17 0.73849 0.711444 0.80725 \n", 1584 | " 18 0.698676 0.748108 0.807125 \n", 1585 | " 19 0.664383 0.732889 0.804625 \n", 1586 | "\n" 1587 | ] 1588 | }, 1589 | { 1590 | "data": { 1591 | "application/vnd.jupyter.widget-view+json": { 1592 | "model_id": "38873a4d51c9454a903188c1d9af5995", 1593 | "version_major": 2, 1594 | "version_minor": 0 1595 | }, 1596 | "text/plain": [ 1597 | "HBox(children=(IntProgress(value=0, description='Epoch', max=20), HTML(value='')))" 1598 | ] 1599 | }, 1600 | "metadata": { 1601 | "tags": [] 1602 | }, 1603 | "output_type": "display_data" 1604 | }, 1605 | { 1606 | "name": "stdout", 1607 | "output_type": "stream", 1608 | "text": [ 1609 | "epoch trn_loss val_loss my_eval \n", 1610 | " 0 0.642692 0.774102 0.79425 \n", 1611 | " 1 0.670723 0.791728 0.797125 \n", 1612 | " 2 0.701122 0.799743 0.794125 \n", 1613 | " 3 0.695085 0.809558 0.8015 \n", 1614 | " 4 0.709821 0.699042 0.801 \n", 1615 | " 5 0.686103 0.869248 0.78125 \n", 1616 | " 6 0.668086 0.807663 0.79025 \n", 1617 | " 7 0.708722 0.761571 0.786875 \n", 1618 | " 8 0.695549 0.760038 0.809125 \n", 1619 | " 9 0.710776 0.817635 0.794625 \n", 1620 | " 10 0.667159 0.799424 0.79725 \n", 1621 | " 11 0.674622 0.791136 0.797375 \n", 1622 | " 12 0.653431 0.732944 0.809125 \n", 1623 | " 13 0.618604 0.733997 0.803 \n", 1624 | " 14 0.614818 0.781419 0.8115 \n", 1625 | " 15 0.592244 0.753463 0.8135 \n", 1626 | " 16 0.593055 0.76313 0.818875 \n", 1627 | " 17 0.591503 0.800476 0.791875 \n", 1628 | " 18 0.592983 0.749201 0.816125 \n", 1629 | " 19 0.586945 0.764278 0.814375 \n", 1630 | "\n", 1631 | "0.818875\n", 1632 | "fold_id1\n", 1633 | "5_fold_resnext50_128_scse_hcol_1\n" 1634 | ] 1635 | }, 1636 | { 1637 | "data": { 1638 | "application/vnd.jupyter.widget-view+json": { 1639 | "model_id": "3d9c315bb12748429c63e7d655c9407d", 1640 | "version_major": 2, 1641 | "version_minor": 0 1642 | }, 1643 | "text/plain": [ 1644 | "HBox(children=(IntProgress(value=0, description='Epoch', max=20), HTML(value='')))" 1645 | ] 
1646 | }, 1647 | "metadata": { 1648 | "tags": [] 1649 | }, 1650 | "output_type": "display_data" 1651 | }, 1652 | { 1653 | "name": "stdout", 1654 | "output_type": "stream", 1655 | "text": [ 1656 | "epoch trn_loss val_loss my_eval \n", 1657 | " 0 0.871871 1.046633 0.740375 \n", 1658 | " 1 0.953284 1.179364 0.718875 \n", 1659 | " 2 0.927234 0.987735 0.73175 \n", 1660 | " 3 0.891905 0.919984 0.759875 \n", 1661 | " 4 0.824233 0.970662 0.76025 \n", 1662 | " 5 0.779335 0.885995 0.766125 \n", 1663 | " 6 0.750891 0.903214 0.774125 \n", 1664 | " 7 0.726601 0.843466 0.783625 \n", 1665 | " 8 0.684857 0.833532 0.789125 \n", 1666 | " 9 0.668015 0.820718 0.791125 \n", 1667 | " 10 0.638963 0.86629 0.77425 \n", 1668 | " 11 0.764584 1.057496 0.69975 \n", 1669 | " 12 0.819989 1.005252 0.747375 \n", 1670 | " 13 0.824506 0.927726 0.76375 \n", 1671 | " 14 0.803146 0.867914 0.77575 \n", 1672 | " 15 0.80264 0.916781 0.7715 \n", 1673 | " 16 0.771997 0.827193 0.786625 \n", 1674 | " 17 0.727303 0.858733 0.788625 \n", 1675 | " 18 0.686257 0.846492 0.779625 \n", 1676 | " 19 0.661692 0.800186 0.79375 \n", 1677 | "\n" 1678 | ] 1679 | }, 1680 | { 1681 | "data": { 1682 | "application/vnd.jupyter.widget-view+json": { 1683 | "model_id": "55f8389c6cd8486ebb04a737a18b32c2", 1684 | "version_major": 2, 1685 | "version_minor": 0 1686 | }, 1687 | "text/plain": [ 1688 | "HBox(children=(IntProgress(value=0, description='Epoch', max=20), HTML(value='')))" 1689 | ] 1690 | }, 1691 | "metadata": { 1692 | "tags": [] 1693 | }, 1694 | "output_type": "display_data" 1695 | }, 1696 | { 1697 | "name": "stdout", 1698 | "output_type": "stream", 1699 | "text": [ 1700 | "epoch trn_loss val_loss my_eval \n", 1701 | " 0 0.680326 0.828174 0.7895 \n", 1702 | " 1 0.68143 0.84101 0.775875 \n", 1703 | " 2 0.716512 0.921755 0.780375 \n", 1704 | " 3 0.716995 0.865955 0.780875 \n", 1705 | " 4 0.733966 0.841387 0.754875 \n", 1706 | " 5 0.726975 0.876049 0.77775 \n", 1707 | " 6 0.702523 0.901705 0.77425 \n", 1708 | " 7 0.672928 
0.878522 0.7805 \n", 1709 | " 8 0.668419 0.859559 0.7695 \n", 1710 | " 9 0.645093 0.857982 0.788 \n", 1711 | " 10 0.646567 0.948428 0.7555 \n", 1712 | " 11 0.659159 0.892936 0.783125 \n", 1713 | " 12 0.626061 0.893832 0.787125 \n", 1714 | " 13 0.633051 0.822176 0.787 \n", 1715 | " 14 0.60531 0.902848 0.788625 \n", 1716 | " 15 0.598807 0.832705 0.7875 \n", 1717 | " 16 0.571184 0.847312 0.787625 \n", 1718 | " 17 0.563408 0.871148 0.795625 \n", 1719 | " 18 0.571028 0.850332 0.791875 \n", 1720 | " 19 0.538525 0.868956 0.798375 \n", 1721 | "\n", 1722 | "0.7983750000000001\n", 1723 | "fold_id2\n", 1724 | "5_fold_resnext50_128_scse_hcol_2\n" 1725 | ] 1726 | }, 1727 | { 1728 | "data": { 1729 | "application/vnd.jupyter.widget-view+json": { 1730 | "model_id": "a4c02576e160455984579f98cf21250c", 1731 | "version_major": 2, 1732 | "version_minor": 0 1733 | }, 1734 | "text/plain": [ 1735 | "HBox(children=(IntProgress(value=0, description='Epoch', max=20), HTML(value='')))" 1736 | ] 1737 | }, 1738 | "metadata": { 1739 | "tags": [] 1740 | }, 1741 | "output_type": "display_data" 1742 | }, 1743 | { 1744 | "name": "stdout", 1745 | "output_type": "stream", 1746 | "text": [ 1747 | "epoch trn_loss val_loss my_eval \n", 1748 | " 0 0.930581 1.076116 0.750875 \n", 1749 | " 1 0.977241 0.991857 0.748 \n", 1750 | " 2 0.963887 0.864013 0.77875 \n", 1751 | " 3 0.921879 0.839139 0.7825 \n", 1752 | " 4 0.891747 0.80363 0.787875 \n", 1753 | " 5 0.842367 0.788842 0.7915 \n", 1754 | " 6 0.809913 0.766017 0.80325 \n", 1755 | " 7 0.745459 0.74085 0.81025 \n", 1756 | " 8 0.717501 0.751182 0.80775 \n", 1757 | " 9 0.672494 0.721534 0.81775 \n", 1758 | " 10 0.687882 0.777983 0.802125 \n", 1759 | " 11 0.835816 0.843362 0.768375 \n", 1760 | " 12 0.861769 0.882111 0.774125 \n", 1761 | " 13 0.916062 0.933928 0.7635 \n", 1762 | " 14 0.860758 0.787877 0.783875 \n", 1763 | " 15 0.817398 0.785797 0.791625 \n", 1764 | " 16 0.799109 0.764803 0.792125 \n", 1765 | " 17 0.763636 0.776863 0.801875 \n", 1766 | " 18 
0.734554 0.735275 0.809875 \n", 1767 | " 19 0.69847 0.73502 0.808125 \n", 1768 | "\n" 1769 | ] 1770 | }, 1771 | { 1772 | "data": { 1773 | "application/vnd.jupyter.widget-view+json": { 1774 | "model_id": "794f42c7ffbb49a4a48d5599479b1550", 1775 | "version_major": 2, 1776 | "version_minor": 0 1777 | }, 1778 | "text/plain": [ 1779 | "HBox(children=(IntProgress(value=0, description='Epoch', max=20), HTML(value='')))" 1780 | ] 1781 | }, 1782 | "metadata": { 1783 | "tags": [] 1784 | }, 1785 | "output_type": "display_data" 1786 | }, 1787 | { 1788 | "name": "stdout", 1789 | "output_type": "stream", 1790 | "text": [ 1791 | "epoch trn_loss val_loss my_eval \n", 1792 | " 0 0.671766 0.759295 0.791375 \n", 1793 | " 1 0.731133 0.811743 0.789875 \n", 1794 | " 2 0.763409 0.791452 0.789625 \n", 1795 | " 3 0.774678 0.789636 0.792 \n", 1796 | " 4 0.758107 0.77808 0.772 \n", 1797 | " 5 0.735919 0.742705 0.796625 \n", 1798 | " 6 0.720245 0.771582 0.782125 \n", 1799 | " 7 0.710452 0.743036 0.7825 \n", 1800 | " 8 0.710112 0.777343 0.790875 \n", 1801 | " 9 0.736112 0.743986 0.7875 \n", 1802 | " 10 0.70377 0.773617 0.7935 \n", 1803 | " 11 0.66605 0.770097 0.803 \n", 1804 | " 12 0.662434 0.712938 0.808625 \n", 1805 | " 13 0.652016 0.743833 0.807625 \n", 1806 | " 14 0.649992 0.737581 0.811125 \n", 1807 | " 15 0.638489 0.707261 0.81575 \n", 1808 | " 16 0.596832 0.710357 0.818 \n", 1809 | " 17 0.595883 0.704559 0.804625 \n", 1810 | " 18 0.572071 0.705111 0.82025 \n", 1811 | " 19 0.563588 0.69612 0.82325 \n", 1812 | "\n", 1813 | "0.8232499999999999\n", 1814 | "fold_id3\n", 1815 | "5_fold_resnext50_128_scse_hcol_3\n" 1816 | ] 1817 | }, 1818 | { 1819 | "data": { 1820 | "application/vnd.jupyter.widget-view+json": { 1821 | "model_id": "f4961bde0db542448bf37775dbc087d5", 1822 | "version_major": 2, 1823 | "version_minor": 0 1824 | }, 1825 | "text/plain": [ 1826 | "HBox(children=(IntProgress(value=0, description='Epoch', max=20), HTML(value='')))" 1827 | ] 1828 | }, 1829 | "metadata": { 1830 | "tags": 
[] 1831 | }, 1832 | "output_type": "display_data" 1833 | }, 1834 | { 1835 | "name": "stdout", 1836 | "output_type": "stream", 1837 | "text": [ 1838 | "epoch trn_loss val_loss my_eval \n", 1839 | " 0 0.910043 0.931082 0.774375 \n", 1840 | " 1 0.962166 0.841506 0.7835 \n", 1841 | " 2 0.966172 0.884748 0.76575 \n", 1842 | " 3 0.919205 0.797827 0.792 \n", 1843 | " 4 0.887789 0.838587 0.782125 \n", 1844 | " 5 0.833859 0.758443 0.79875 \n", 1845 | " 6 0.787326 0.748243 0.7995 \n", 1846 | " 7 0.768268 0.734664 0.80275 \n", 1847 | " 8 0.721092 0.701125 0.81575 \n", 1848 | " 9 0.678771 0.707262 0.817625 \n", 1849 | " 10 0.686963 0.761993 0.807 \n", 1850 | " 11 0.797779 0.845824 0.78325 \n", 1851 | " 12 0.836203 0.781706 0.79575 \n", 1852 | " 13 0.807488 0.871408 0.7955 \n", 1853 | " 14 0.786293 0.773449 0.801375 \n", 1854 | " 15 0.779008 0.795548 0.800875 \n", 1855 | " 16 0.756668 0.742518 0.80525 \n", 1856 | " 17 0.705088 0.703899 0.81325 \n", 1857 | " 18 0.677209 0.693749 0.819125 \n", 1858 | " 19 0.639317 0.679222 0.82 \n", 1859 | "\n" 1860 | ] 1861 | }, 1862 | { 1863 | "data": { 1864 | "application/vnd.jupyter.widget-view+json": { 1865 | "model_id": "46e4957616e84cd499a4f53b689cc927", 1866 | "version_major": 2, 1867 | "version_minor": 0 1868 | }, 1869 | "text/plain": [ 1870 | "HBox(children=(IntProgress(value=0, description='Epoch', max=20), HTML(value='')))" 1871 | ] 1872 | }, 1873 | "metadata": { 1874 | "tags": [] 1875 | }, 1876 | "output_type": "display_data" 1877 | }, 1878 | { 1879 | "name": "stdout", 1880 | "output_type": "stream", 1881 | "text": [ 1882 | "epoch trn_loss val_loss my_eval \n", 1883 | " 0 0.653376 0.719673 0.812375 \n", 1884 | " 1 0.701078 0.751476 0.80475 \n", 1885 | " 2 0.728884 0.854194 0.79975 \n", 1886 | " 3 0.752242 0.724502 0.8095 \n", 1887 | " 4 0.704703 0.742548 0.815 \n", 1888 | " 5 0.716587 0.734281 0.811875 \n", 1889 | " 6 0.683265 0.743123 0.806875 \n", 1890 | " 7 0.678836 0.815633 0.793625 \n", 1891 | " 8 0.689473 0.700409 0.81575 \n", 
1892 | " 9 0.665838 0.669478 0.82075 \n", 1893 | " 10 0.633243 0.730008 0.816125 \n", 1894 | " 11 0.655839 0.701937 0.816125 \n", 1895 | " 12 0.632317 0.649175 0.82675 \n", 1896 | " 13 0.598759 0.674421 0.83225 \n", 1897 | " 14 0.596253 0.666588 0.829875 \n", 1898 | " 15 0.586747 0.670523 0.827125 \n", 1899 | " 16 0.567371 0.686732 0.824375 \n", 1900 | " 17 0.56206 0.673798 0.8225 \n", 1901 | " 18 0.545179 0.67912 0.82925 \n", 1902 | " 19 0.526347 0.698485 0.827 \n", 1903 | "\n", 1904 | "0.8322499999999999\n", 1905 | "fold_id4\n", 1906 | "5_fold_resnext50_128_scse_hcol_4\n" 1907 | ] 1908 | }, 1909 | { 1910 | "data": { 1911 | "application/vnd.jupyter.widget-view+json": { 1912 | "model_id": "161724e6ee244e4c9dc4fc87b9a7c5f4", 1913 | "version_major": 2, 1914 | "version_minor": 0 1915 | }, 1916 | "text/plain": [ 1917 | "HBox(children=(IntProgress(value=0, description='Epoch', max=20), HTML(value='')))" 1918 | ] 1919 | }, 1920 | "metadata": { 1921 | "tags": [] 1922 | }, 1923 | "output_type": "display_data" 1924 | }, 1925 | { 1926 | "name": "stdout", 1927 | "output_type": "stream", 1928 | "text": [ 1929 | "epoch trn_loss val_loss my_eval \n", 1930 | " 0 0.92871 0.959824 0.784125 \n", 1931 | " 1 0.949663 0.872785 0.776375 \n", 1932 | " 2 0.909849 0.817117 0.790375 \n", 1933 | " 3 0.896153 0.804488 0.794125 \n", 1934 | " 4 0.83174 0.835692 0.79425 \n", 1935 | " 5 0.814119 0.780774 0.797125 \n", 1936 | " 6 0.755571 0.759111 0.801875 \n", 1937 | " 7 0.729774 0.78065 0.808 \n", 1938 | " 8 0.674973 0.741045 0.813375 \n", 1939 | " 9 0.660971 0.72706 0.8115 \n", 1940 | " 10 0.689892 0.819716 0.786375 \n", 1941 | " 11 0.789879 0.855918 0.789375 \n", 1942 | " 12 0.783312 0.796657 0.79 \n", 1943 | " 13 0.781245 0.913932 0.78375 \n", 1944 | " 14 0.787869 0.830769 0.782625 \n", 1945 | " 15 0.799666 0.75475 0.805375 \n", 1946 | " 16 0.750443 0.712465 0.81 \n", 1947 | " 17 0.708587 0.7156 0.812875 \n", 1948 | " 18 0.680404 0.758386 0.805625 \n", 1949 | " 19 0.647 0.703068 0.813125 \n", 
1950 | "\n" 1951 | ] 1952 | }, 1953 | { 1954 | "data": { 1955 | "application/vnd.jupyter.widget-view+json": { 1956 | "model_id": "882d2ce3e11448a7897138b5879d9d77", 1957 | "version_major": 2, 1958 | "version_minor": 0 1959 | }, 1960 | "text/plain": [ 1961 | "HBox(children=(IntProgress(value=0, description='Epoch', max=20), HTML(value='')))" 1962 | ] 1963 | }, 1964 | "metadata": { 1965 | "tags": [] 1966 | }, 1967 | "output_type": "display_data" 1968 | }, 1969 | { 1970 | "name": "stdout", 1971 | "output_type": "stream", 1972 | "text": [ 1973 | "epoch trn_loss val_loss my_eval \n", 1974 | " 0 0.625973 0.726331 0.820375 \n", 1975 | " 1 0.670526 0.805685 0.80375 \n", 1976 | " 2 0.701503 0.781464 0.788875 \n", 1977 | " 3 0.69349 0.791473 0.801625 \n", 1978 | " 4 0.682512 0.745718 0.81325 \n", 1979 | " 5 0.69381 0.70612 0.806625 \n", 1980 | " 6 0.665137 0.696765 0.816375 \n", 1981 | " 7 0.641031 0.729714 0.8195 \n", 1982 | " 8 0.632783 0.706463 0.80375 \n", 1983 | " 9 0.634154 0.687877 0.81325 \n", 1984 | " 10 0.626251 0.690043 0.809375 \n", 1985 | " 11 0.591347 0.709688 0.815125 \n", 1986 | " 12 0.575432 0.702248 0.816625 \n", 1987 | " 13 0.587349 0.694433 0.81625 \n", 1988 | " 14 0.573425 0.694701 0.820375 \n", 1989 | " 15 0.584016 0.66213 0.8265 \n", 1990 | " 16 0.547527 0.703967 0.824625 \n", 1991 | " 17 0.548586 0.674334 0.82825 \n", 1992 | " 18 0.534016 0.680121 0.829125 \n", 1993 | " 19 0.507791 0.67429 0.823375 \n", 1994 | "\n", 1995 | "0.8291249999999999\n" 1996 | ] 1997 | } 1998 | ], 1999 | "source": [ 2000 | "model = 'resnext50_128_scse_hcol'\n", 2001 | "arch = resnext50\n", 2002 | "bst_acc=[]\n", 2003 | "use_clr_min=20\n", 2004 | "use_clr_div=10\n", 2005 | "aug_tfms = [RandomFlip(tfm_y=TfmType.CLASS)]\n", 2006 | "\n", 2007 | "szs = [(128,64)]\n", 2008 | "for sz,bs in szs:\n", 2009 | " print([sz,bs])\n", 2010 | " for i in range(kf) :\n", 2011 | " print(f'fold_id{i}')\n", 2012 | " \n", 2013 | " trn_x = np.array([f'{TRN_DN}/{o}' for o in train_folds[i]])\n", 2014 
| " trn_y = np.array([f'{MASK_DN}/{o}' for o in train_folds[i]])\n", 2015 | " val_x = [f'{TRN_DN}/{o}' for o in val_folds[i]]\n", 2016 | " val_y = [f'{MASK_DN}/{o}' for o in val_folds[i]]\n", 2017 | " \n", 2018 | " tfms = tfms_from_model(arch, sz=sz, pad=0,crop_type=CropType.NO, tfm_y=TfmType.CLASS, aug_tfms=aug_tfms)\n", 2019 | " datasets = ImageData.get_ds(MatchedFilesDataset, (trn_x,trn_y), (val_x,val_y), tfms,test=test_x,path=PATH)\n", 2020 | " md = ImageData(PATH, datasets, bs, num_workers=16, classes=None)\n", 2021 | " denorm = md.trn_ds.denorm\n", 2022 | " md.trn_dl.dataset = DepthDataset(md.trn_ds,dpth_dict)\n", 2023 | " md.val_dl.dataset = DepthDataset(md.val_ds,dpth_dict)\n", 2024 | " md.test_dl.dataset = DepthDataset(md.test_ds,dpth_dict)\n", 2025 | " learn = get_tgs_model() \n", 2026 | " learn.opt_fn = optim.Adam\n", 2027 | " learn.metrics=[my_eval]\n", 2028 | " pa = f'{kf}_fold_{model}_{i}'\n", 2029 | " print(pa)\n", 2030 | "\n", 2031 | " lr=1e-2\n", 2032 | " wd=1e-7\n", 2033 | " lrs = np.array([lr/100,lr/10,lr])\n", 2034 | "\n", 2035 | " learn.unfreeze() \n", 2036 | " learn.crit = lovasz_hinge\n", 2037 | " learn.load(pa)\n", 2038 | " learn.fit(lrs/2,2, wds=wd, cycle_len=10,use_clr=(20,8),best_save_name=pa)\n", 2039 | " learn.fit(lrs/3,1, wds=wd, cycle_len=20,use_clr=(20,16),best_save_name=pa)\n", 2040 | " \n", 2041 | " learn.load(pa) \n", 2042 | " # Calculating mean IoU score\n", 2043 | " v_targ = md.val_ds.ds[:][1]\n", 2044 | " v_preds = np.zeros((len(v_targ),sz,sz)) \n", 2045 | " v_pred = learn.predict()\n", 2046 | " v_pred = to_np(torch.sigmoid(torch.from_numpy(v_pred)))\n", 2047 | " p = ((v_pred)>0.5).astype(np.uint8)\n", 2048 | " bst_acc.append(intersection_over_union_thresholds(v_targ,p))\n", 2049 | " print(bst_acc[-1])\n", 2050 | " \n", 2051 | " del learn\n", 2052 | " del md\n", 2053 | " gc.collect()\n", 2054 | "\n", 2055 | "save_model_weights()" 2056 | ] 2057 | }, 2058 | { 2059 | "cell_type": "code", 2060 | "execution_count": 0, 2061 | 
"metadata": { 2062 | "colab": { 2063 | "base_uri": "https://localhost:8080/", 2064 | "height": 3898 2065 | }, 2066 | "colab_type": "code", 2067 | "id": "mFTMgQkX6VtI", 2068 | "outputId": "9c990f72-dbe6-4ece-96f6-97e445e91f25" 2069 | }, 2070 | "outputs": [ 2071 | { 2072 | "name": "stdout", 2073 | "output_type": "stream", 2074 | "text": [ 2075 | "[128, 64]\n", 2076 | "fold_id0\n", 2077 | "5_fold_resnext50_128_scse_hcol_0\n" 2078 | ] 2079 | }, 2080 | { 2081 | "data": { 2082 | "application/vnd.jupyter.widget-view+json": { 2083 | "model_id": "2c64641c0ef945a182d60ab806fcfdc1", 2084 | "version_major": 2, 2085 | "version_minor": 0 2086 | }, 2087 | "text/plain": [ 2088 | "HBox(children=(IntProgress(value=0, description='Epoch', max=40), HTML(value='')))" 2089 | ] 2090 | }, 2091 | "metadata": { 2092 | "tags": [] 2093 | }, 2094 | "output_type": "display_data" 2095 | }, 2096 | { 2097 | "name": "stdout", 2098 | "output_type": "stream", 2099 | "text": [ 2100 | "epoch trn_loss val_loss my_eval \n", 2101 | " 0 0.580299 0.837661 0.80625 \n", 2102 | " 1 0.628745 0.803514 0.798375 \n", 2103 | " 2 0.619702 0.796008 0.80575 \n", 2104 | " 3 0.628366 0.750468 0.81175 \n", 2105 | " 4 0.623308 0.748861 0.802 \n", 2106 | " 5 0.650031 0.795647 0.803 \n", 2107 | " 6 0.646449 0.763365 0.81125 \n", 2108 | " 7 0.636074 0.781185 0.811375 \n", 2109 | " 8 0.623231 0.742726 0.809875 \n", 2110 | " 9 0.602579 0.753169 0.813375 \n", 2111 | " 10 0.585192 0.776042 0.817375 \n", 2112 | " 11 0.576949 0.790247 0.8175 \n", 2113 | " 12 0.562295 0.796752 0.813875 \n", 2114 | " 13 0.56244 0.795689 0.788875 \n", 2115 | " 14 0.575644 0.806945 0.8145 \n", 2116 | " 15 0.58317 0.765245 0.81525 \n", 2117 | " 16 0.560825 0.782946 0.805875 \n", 2118 | " 17 0.562333 0.797728 0.79975 \n", 2119 | " 18 0.562816 0.761856 0.816375 \n", 2120 | " 19 0.553854 0.775206 0.8135 \n", 2121 | " 20 0.546581 0.714756 0.82125 \n", 2122 | " 21 0.541287 0.770479 0.8095 \n", 2123 | " 22 0.518532 0.805129 0.813625 \n", 2124 | " 23 0.521006 
0.732549 0.806625 \n", 2125 | " 24 0.495842 0.853394 0.811375 \n", 2126 | " 25 0.493223 0.838361 0.81125 \n", 2127 | " 26 0.482846 0.831322 0.814375 \n", 2128 | " 27 0.483749 0.788495 0.820125 \n", 2129 | " 28 0.496672 0.792871 0.82125 \n", 2130 | " 29 0.497796 0.762518 0.823875 \n", 2131 | " 30 0.49775 0.79859 0.81975 \n", 2132 | " 31 0.491748 0.779934 0.822375 \n", 2133 | " 32 0.475641 0.790661 0.816625 \n", 2134 | " 33 0.4794 0.78795 0.82175 \n", 2135 | " 34 0.475329 0.764116 0.81825 \n", 2136 | " 35 0.455563 0.822148 0.815 \n", 2137 | " 36 0.451707 0.795527 0.821375 \n", 2138 | " 37 0.444007 0.772276 0.8155 \n", 2139 | " 38 0.435822 0.779683 0.817375 \n", 2140 | " 39 0.445946 0.780952 0.814 \n", 2141 | "\n", 2142 | "0.823875\n", 2143 | "fold_id1\n", 2144 | "5_fold_resnext50_128_scse_hcol_1\n" 2145 | ] 2146 | }, 2147 | { 2148 | "data": { 2149 | "application/vnd.jupyter.widget-view+json": { 2150 | "model_id": "c2faf773816a409e9d909ded6b7f2313", 2151 | "version_major": 2, 2152 | "version_minor": 0 2153 | }, 2154 | "text/plain": [ 2155 | "HBox(children=(IntProgress(value=0, description='Epoch', max=40), HTML(value='')))" 2156 | ] 2157 | }, 2158 | "metadata": { 2159 | "tags": [] 2160 | }, 2161 | "output_type": "display_data" 2162 | }, 2163 | { 2164 | "name": "stdout", 2165 | "output_type": "stream", 2166 | "text": [ 2167 | "epoch trn_loss val_loss my_eval \n", 2168 | " 0 0.565228 0.871326 0.785875 \n", 2169 | " 1 0.601723 0.882529 0.79425 \n", 2170 | " 2 0.642919 0.897614 0.78075 \n", 2171 | " 3 0.637186 0.878214 0.777875 \n", 2172 | " 4 0.609784 0.895229 0.79075 \n", 2173 | " 5 0.609371 0.926091 0.7865 \n", 2174 | " 6 0.601144 0.836737 0.788125 \n", 2175 | " 7 0.630633 0.871732 0.7835 \n", 2176 | " 8 0.611182 0.890666 0.7895 \n", 2177 | " 9 0.592919 0.831647 0.80075 \n", 2178 | " 10 0.567077 0.936684 0.786875 \n", 2179 | " 11 0.553323 0.966901 0.78875 \n", 2180 | " 12 0.559024 0.944055 0.77 \n", 2181 | " 13 0.590418 0.850013 0.796 \n", 2182 | " 14 0.573935 0.856166 
0.7945 \n", 2183 | " 15 0.567616 0.860534 0.79075 \n", 2184 | " 16 0.569173 0.960924 0.781 \n", 2185 | " 17 0.558296 0.893143 0.791 \n", 2186 | " 18 0.542129 0.934519 0.794125 \n", 2187 | " 19 0.547132 0.855531 0.80075 \n", 2188 | " 20 0.53702 0.854191 0.796875 \n", 2189 | " 21 0.538296 0.912877 0.787625 \n", 2190 | " 22 0.526649 0.917397 0.800375 \n", 2191 | " 23 0.502509 0.896211 0.7885 \n", 2192 | " 24 0.521145 0.898083 0.796375 \n", 2193 | " 25 0.491927 0.923525 0.802625 \n", 2194 | " 26 0.501542 0.864817 0.794 \n", 2195 | " 27 0.494032 0.884823 0.796375 \n", 2196 | " 28 0.490451 0.899924 0.795875 \n", 2197 | " 29 0.481898 0.880166 0.784 \n", 2198 | " 30 0.472968 0.895765 0.797375 \n", 2199 | " 31 0.462385 0.878445 0.793625 \n", 2200 | " 32 0.455645 0.896859 0.795 \n", 2201 | " 33 0.469805 0.894238 0.7995 \n", 2202 | " 34 0.455159 0.913979 0.801375 \n", 2203 | " 35 0.457995 0.898335 0.8015 \n", 2204 | " 36 0.453348 0.900837 0.804875 \n", 2205 | " 37 0.446832 0.902242 0.793 \n", 2206 | " 38 0.453402 0.907094 0.801375 \n", 2207 | " 39 0.434806 0.921633 0.80875 \n", 2208 | "\n", 2209 | "0.80875\n", 2210 | "fold_id2\n", 2211 | "5_fold_resnext50_128_scse_hcol_2\n" 2212 | ] 2213 | }, 2214 | { 2215 | "data": { 2216 | "application/vnd.jupyter.widget-view+json": { 2217 | "model_id": "2914de385caf4b7d879001cf1ca135af", 2218 | "version_major": 2, 2219 | "version_minor": 0 2220 | }, 2221 | "text/plain": [ 2222 | "HBox(children=(IntProgress(value=0, description='Epoch', max=40), HTML(value='')))" 2223 | ] 2224 | }, 2225 | "metadata": { 2226 | "tags": [] 2227 | }, 2228 | "output_type": "display_data" 2229 | }, 2230 | { 2231 | "name": "stdout", 2232 | "output_type": "stream", 2233 | "text": [ 2234 | "epoch trn_loss val_loss my_eval \n", 2235 | " 0 0.577075 0.786943 0.80825 \n", 2236 | " 1 0.609053 0.768417 0.804375 \n", 2237 | " 2 0.656391 0.73859 0.812125 \n", 2238 | " 3 0.652841 0.757531 0.797 \n", 2239 | " 4 0.655581 0.736351 0.809875 \n", 2240 | " 5 0.623059 0.69868 
0.812375 \n", 2241 | " 6 0.627167 0.77681 0.7915 \n", 2242 | " 7 0.623288 0.750636 0.802125 \n", 2243 | " 8 0.620909 0.76361 0.807125 \n", 2244 | " 9 0.628333 0.755546 0.793625 \n", 2245 | " 10 0.63122 0.744164 0.79475 \n", 2246 | " 11 0.63828 0.720994 0.8095 \n", 2247 | " 12 0.607952 0.712268 0.804125 \n", 2248 | " 13 0.588277 0.729939 0.80725 \n", 2249 | " 14 0.585946 0.705454 0.8255 \n", 2250 | " 15 0.58386 0.691177 0.81425 \n", 2251 | " 16 0.561116 0.713686 0.812625 \n", 2252 | " 17 0.553486 0.771892 0.818875 \n", 2253 | " 18 0.553116 0.707916 0.817875 \n", 2254 | " 19 0.565598 0.728426 0.815625 \n", 2255 | " 20 0.554347 0.766951 0.819625 \n", 2256 | " 21 0.536056 0.723741 0.820125 \n", 2257 | " 22 0.536135 0.724242 0.821625 \n", 2258 | " 23 0.554331 0.767052 0.8075 \n", 2259 | " 24 0.550846 0.718899 0.82025 \n", 2260 | " 25 0.539729 0.703877 0.81175 \n", 2261 | " 26 0.517849 0.694621 0.820125 \n", 2262 | " 27 0.504886 0.708674 0.821625 \n", 2263 | " 28 0.502828 0.726042 0.82175 \n", 2264 | " 29 0.498927 0.701111 0.81025 \n", 2265 | " 30 0.495155 0.715683 0.8185 \n", 2266 | " 31 0.507423 0.730976 0.82175 \n", 2267 | " 32 0.51031 0.709125 0.814125 \n", 2268 | " 33 0.504614 0.7094 0.81775 \n", 2269 | " 34 0.500099 0.694097 0.825375 \n", 2270 | " 35 0.482972 0.689517 0.824625 \n", 2271 | " 36 0.484635 0.679306 0.82225 \n", 2272 | " 37 0.481951 0.681534 0.828 \n", 2273 | " 38 0.475027 0.686403 0.825625 \n", 2274 | " 39 0.483564 0.695852 0.82175 \n", 2275 | "\n", 2276 | "0.828\n", 2277 | "fold_id3\n", 2278 | "5_fold_resnext50_128_scse_hcol_3\n" 2279 | ] 2280 | }, 2281 | { 2282 | "data": { 2283 | "application/vnd.jupyter.widget-view+json": { 2284 | "model_id": "019ac79c6d2e4e56ac325d5d3dd9a97a", 2285 | "version_major": 2, 2286 | "version_minor": 0 2287 | }, 2288 | "text/plain": [ 2289 | "HBox(children=(IntProgress(value=0, description='Epoch', max=40), HTML(value='')))" 2290 | ] 2291 | }, 2292 | "metadata": { 2293 | "tags": [] 2294 | }, 2295 | "output_type": 
"display_data" 2296 | }, 2297 | { 2298 | "name": "stdout", 2299 | "output_type": "stream", 2300 | "text": [ 2301 | "epoch trn_loss val_loss my_eval \n", 2302 | " 0 0.575313 0.695703 0.823625 \n", 2303 | " 1 0.624102 0.780891 0.813625 \n", 2304 | " 2 0.648184 0.697116 0.81825 \n", 2305 | " 3 0.651478 0.729888 0.80325 \n", 2306 | " 4 0.633732 0.689429 0.816875 \n", 2307 | " 5 0.6391 0.707463 0.821875 \n", 2308 | " 6 0.608808 0.734975 0.821 \n", 2309 | " 7 0.583592 0.727842 0.8085 \n", 2310 | " 8 0.607585 0.754664 0.806125 \n", 2311 | " 9 0.602211 0.745517 0.82925 \n", 2312 | " 10 0.58275 0.690118 0.826 \n", 2313 | " 11 0.577038 0.690472 0.816875 \n", 2314 | " 12 0.585021 0.703931 0.820875 \n", 2315 | " 13 0.573149 0.737506 0.822875 \n", 2316 | " 14 0.593045 0.701834 0.823125 \n", 2317 | " 15 0.576374 0.722945 0.8195 \n", 2318 | " 16 0.578671 0.718012 0.824125 \n", 2319 | " 17 0.576396 0.737476 0.825 \n", 2320 | " 18 0.562349 0.740334 0.82075 \n", 2321 | " 19 0.562596 0.703988 0.83375 \n", 2322 | " 20 0.545302 0.742709 0.824875 \n", 2323 | " 21 0.536242 0.730253 0.821375 \n", 2324 | " 22 0.538357 0.726912 0.834375 \n", 2325 | " 23 0.524991 0.705785 0.82975 \n", 2326 | " 24 0.539796 0.718887 0.821625 \n", 2327 | " 25 0.528741 0.74735 0.825125 \n", 2328 | " 26 0.52838 0.697965 0.833875 \n", 2329 | " 27 0.515824 0.80159 0.82975 \n", 2330 | " 28 0.519962 0.729874 0.813625 \n", 2331 | " 29 0.506576 0.729921 0.8205 \n", 2332 | " 30 0.488176 0.70798 0.823 \n", 2333 | " 31 0.487228 0.747718 0.824625 \n", 2334 | " 32 0.477002 0.720598 0.829625 \n", 2335 | " 33 0.468923 0.748624 0.833625 \n", 2336 | " 34 0.487961 0.757478 0.833125 \n", 2337 | " 35 0.470121 0.742304 0.813875 \n", 2338 | " 36 0.48037 0.745087 0.831875 \n", 2339 | " 37 0.476722 0.772283 0.8305 \n", 2340 | " 38 0.456916 0.735131 0.8195 \n", 2341 | " 39 0.461469 0.748882 0.8215 \n", 2342 | "\n", 2343 | "0.834375\n", 2344 | "fold_id4\n", 2345 | "5_fold_resnext50_128_scse_hcol_4\n" 2346 | ] 2347 | }, 2348 | { 2349 | 
"data": { 2350 | "application/vnd.jupyter.widget-view+json": { 2351 | "model_id": "d6269952b5694a12a836f74645dd9bbe", 2352 | "version_major": 2, 2353 | "version_minor": 0 2354 | }, 2355 | "text/plain": [ 2356 | "HBox(children=(IntProgress(value=0, description='Epoch', max=40), HTML(value='')))" 2357 | ] 2358 | }, 2359 | "metadata": { 2360 | "tags": [] 2361 | }, 2362 | "output_type": "display_data" 2363 | }, 2364 | { 2365 | "name": "stdout", 2366 | "output_type": "stream", 2367 | "text": [ 2368 | "epoch trn_loss val_loss my_eval \n", 2369 | " 0 0.546006 0.717454 0.8235 \n", 2370 | " 1 0.563938 0.751986 0.812375 \n", 2371 | " 2 0.598719 0.75113 0.815 \n", 2372 | " 3 0.588652 0.798206 0.822 \n", 2373 | " 4 0.599703 0.805366 0.8095 \n", 2374 | " 5 0.591598 0.734947 0.82075 \n", 2375 | " 6 0.596523 0.733508 0.811375 \n", 2376 | " 7 0.597403 0.742025 0.816125 \n", 2377 | " 8 0.588319 0.729897 0.80425 \n", 2378 | " 9 0.57466 0.743337 0.808375 \n", 2379 | " 10 0.590704 0.786838 0.8205 \n", 2380 | " 11 0.557619 0.687391 0.81425 \n", 2381 | " 12 0.539545 0.704461 0.82875 \n", 2382 | " 13 0.543599 0.810272 0.820625 \n", 2383 | " 14 0.549854 0.70434 0.819 \n", 2384 | " 15 0.541127 0.726442 0.820625 \n", 2385 | " 16 0.547006 0.710707 0.805625 \n", 2386 | " 17 0.556314 0.704044 0.825 \n", 2387 | " 18 0.533182 0.687741 0.81375 \n", 2388 | " 19 0.511729 0.7148 0.834 \n", 2389 | " 20 0.509799 0.690751 0.8375 \n", 2390 | " 21 0.496054 0.709561 0.83325 \n", 2391 | " 22 0.499565 0.745653 0.8235 \n", 2392 | " 23 0.506426 0.723755 0.824375 \n", 2393 | " 24 0.487424 0.73846 0.828125 \n", 2394 | " 25 0.485281 0.725011 0.822 \n", 2395 | " 26 0.466737 0.791306 0.827125 \n", 2396 | " 27 0.46954 0.759082 0.827625 \n", 2397 | " 28 0.479826 0.683398 0.8315 \n", 2398 | " 29 0.463535 0.724489 0.831 \n", 2399 | " 30 0.460917 0.712501 0.831 \n", 2400 | " 31 0.441873 0.723946 0.829 \n", 2401 | " 32 0.435788 0.724188 0.82975 \n", 2402 | " 33 0.438317 0.718923 0.821375 \n", 2403 | " 34 0.442253 
0.724379 0.82475 \n", 2404 | " 35 0.425489 0.734597 0.824125 \n", 2405 | " 36 0.44013 0.728877 0.820125 \n", 2406 | " 37 0.443834 0.744366 0.828 \n", 2407 | " 38 0.426803 0.729525 0.827875 \n", 2408 | " 39 0.435825 0.735564 0.822 \n", 2409 | "\n", 2410 | "0.8375\n" 2411 | ] 2412 | } 2413 | ], 2414 | "source": [ 2415 | "model = 'resnext50_128_scse_hcol'\n", 2416 | "arch = resnext50\n", 2417 | "bst_acc=[]\n", 2418 | "use_clr_min=20\n", 2419 | "use_clr_div=10\n", 2420 | "aug_tfms = [RandomFlip(tfm_y=TfmType.CLASS)]\n", 2421 | "\n", 2422 | "szs = [(128,64)]\n", 2423 | "for sz,bs in szs:\n", 2424 | " print([sz,bs])\n", 2425 | " for i in range(kf) :\n", 2426 | " print(f'fold_id{i}')\n", 2427 | " \n", 2428 | " trn_x = np.array([f'{TRN_DN}/{o}' for o in train_folds[i]])\n", 2429 | " trn_y = np.array([f'{MASK_DN}/{o}' for o in train_folds[i]])\n", 2430 | " val_x = [f'{TRN_DN}/{o}' for o in val_folds[i]]\n", 2431 | " val_y = [f'{MASK_DN}/{o}' for o in val_folds[i]]\n", 2432 | " \n", 2433 | " tfms = tfms_from_model(arch, sz=sz, pad=0,crop_type=CropType.NO, tfm_y=TfmType.CLASS, aug_tfms=aug_tfms)\n", 2434 | " datasets = ImageData.get_ds(MatchedFilesDataset, (trn_x,trn_y), (val_x,val_y), tfms,test=test_x,path=PATH)\n", 2435 | " md = ImageData(PATH, datasets, bs, num_workers=16, classes=None)\n", 2436 | " denorm = md.trn_ds.denorm\n", 2437 | " md.trn_dl.dataset = DepthDataset(md.trn_ds,dpth_dict)\n", 2438 | " md.val_dl.dataset = DepthDataset(md.val_ds,dpth_dict)\n", 2439 | " md.test_dl.dataset = DepthDataset(md.test_ds,dpth_dict)\n", 2440 | " learn = get_tgs_model() \n", 2441 | " learn.opt_fn = optim.Adam\n", 2442 | " learn.metrics=[my_eval]\n", 2443 | " pa = f'{kf}_fold_{model}_{i}'\n", 2444 | " print(pa)\n", 2445 | "\n", 2446 | " lr=1e-2\n", 2447 | " wd=1e-7\n", 2448 | " lrs = np.array([lr/100,lr/10,lr])\n", 2449 | "\n", 2450 | " learn.unfreeze() \n", 2451 | " learn.crit = lovasz_hinge\n", 2452 | " learn.load(pa)\n", 2453 | "# learn.fit(lrs/4,1, wds=wd, 
cycle_len=20,use_clr=(20,16),best_save_name=pa)\n", 2454 | " learn.fit(lrs/4,1, wds=wd, cycle_len=40,use_clr=(20,32),best_save_name=pa)\n", 2455 | " \n", 2456 | " learn.load(pa) \n", 2457 | " # Calculating mean IoU score\n", 2458 | " v_targ = md.val_ds.ds[:][1]\n", 2459 | " v_preds = np.zeros((len(v_targ),sz,sz)) \n", 2460 | " v_pred = learn.predict()\n", 2461 | " v_pred = to_np(torch.sigmoid(torch.from_numpy(v_pred)))\n", 2462 | " p = ((v_pred)>0.5).astype(np.uint8)\n", 2463 | " bst_acc.append(intersection_over_union_thresholds(v_targ,p))\n", 2464 | " print(bst_acc[-1])\n", 2465 | " \n", 2466 | " del learn\n", 2467 | " del md\n", 2468 | " gc.collect()\n", 2469 | "\n", 2470 | "save_model_weights()" 2471 | ] 2472 | }, 2473 | { 2474 | "cell_type": "code", 2475 | "execution_count": 0, 2476 | "metadata": { 2477 | "colab": { 2478 | "base_uri": "https://localhost:8080/", 2479 | "height": 1996 2480 | }, 2481 | "colab_type": "code", 2482 | "id": "H7YZnN_HiiuR", 2483 | "outputId": "ee40d675-4739-4c54-f4dd-a0a55119aad3" 2484 | }, 2485 | "outputs": [ 2486 | { 2487 | "name": "stdout", 2488 | "output_type": "stream", 2489 | "text": [ 2490 | "[128, 64]\n", 2491 | "fold_id0\n", 2492 | "5_fold_resnext50_128_scse_hcol_0\n" 2493 | ] 2494 | }, 2495 | { 2496 | "data": { 2497 | "application/vnd.jupyter.widget-view+json": { 2498 | "model_id": "56e79297e25847f0a15aff07aa418769", 2499 | "version_major": 2, 2500 | "version_minor": 0 2501 | }, 2502 | "text/plain": [ 2503 | "HBox(children=(IntProgress(value=0, description='Epoch', max=20), HTML(value='')))" 2504 | ] 2505 | }, 2506 | "metadata": { 2507 | "tags": [] 2508 | }, 2509 | "output_type": "display_data" 2510 | }, 2511 | { 2512 | "name": "stdout", 2513 | "output_type": "stream", 2514 | "text": [ 2515 | "epoch trn_loss val_loss my_eval \n", 2516 | " 0 0.464316 0.82814 0.816375 \n", 2517 | " 1 0.503966 0.973039 0.808125 \n", 2518 | " 2 0.524512 0.875324 0.802125 \n", 2519 | " 3 0.552585 0.802356 0.8155 \n", 2520 | " 4 0.524045 0.898251 
0.8165 \n", 2521 | " 5 0.519004 0.884085 0.815 \n", 2522 | " 6 0.512701 0.761209 0.816625 \n", 2523 | " 7 0.520394 0.781913 0.80575 \n", 2524 | " 8 0.501048 0.748156 0.815875 \n", 2525 | " 9 0.504658 0.800892 0.808875 \n", 2526 | " 10 0.491141 0.764409 0.813375 \n", 2527 | " 11 0.480288 0.845715 0.823 \n", 2528 | " 12 0.482994 0.864056 0.817125 \n", 2529 | " 13 0.474893 0.844127 0.818 \n", 2530 | " 14 0.475831 0.82016 0.822375 \n", 2531 | " 15 0.462766 0.80165 0.816 \n", 2532 | " 16 0.459729 0.801812 0.822625 \n", 2533 | " 17 0.452857 0.825043 0.821125 \n", 2534 | " 18 0.432667 0.798102 0.819125 \n", 2535 | " 19 0.448399 0.799831 0.82425 \n", 2536 | "\n", 2537 | "0.8242499999999999\n", 2538 | "fold_id1\n", 2539 | "5_fold_resnext50_128_scse_hcol_1\n" 2540 | ] 2541 | }, 2542 | { 2543 | "data": { 2544 | "application/vnd.jupyter.widget-view+json": { 2545 | "model_id": "8aea5543371447dbafac57bda9d368dd", 2546 | "version_major": 2, 2547 | "version_minor": 0 2548 | }, 2549 | "text/plain": [ 2550 | "HBox(children=(IntProgress(value=0, description='Epoch', max=20), HTML(value='')))" 2551 | ] 2552 | }, 2553 | "metadata": { 2554 | "tags": [] 2555 | }, 2556 | "output_type": "display_data" 2557 | }, 2558 | { 2559 | "name": "stdout", 2560 | "output_type": "stream", 2561 | "text": [ 2562 | "epoch trn_loss val_loss my_eval \n", 2563 | " 0 0.460871 0.923814 0.807875 \n", 2564 | " 1 0.463133 0.902192 0.804 \n", 2565 | " 2 0.487417 0.895001 0.79675 \n", 2566 | " 3 0.49072 0.918243 0.7825 \n", 2567 | " 4 0.51323 0.852118 0.80025 \n", 2568 | " 5 0.496024 0.872492 0.789875 \n", 2569 | " 6 0.506847 0.856714 0.802125 \n", 2570 | " 7 0.506169 0.917013 0.785 \n", 2571 | " 8 0.480601 0.872236 0.7935 \n", 2572 | " 9 0.478031 0.859447 0.79525 \n", 2573 | " 10 0.465379 0.893305 0.79625 \n", 2574 | " 11 0.455413 0.920751 0.802125 \n", 2575 | " 12 0.450254 0.921361 0.797125 \n", 2576 | " 13 0.437148 0.914842 0.802 \n", 2577 | " 14 0.445062 0.855681 0.797125 \n", 2578 | " 15 0.438345 0.867074 
0.79825 \n", 2579 | " 16 0.445464 0.852298 0.80075 \n", 2580 | " 17 0.439749 0.874568 0.79525 \n", 2581 | " 18 0.434257 0.861378 0.809375 \n", 2582 | " 19 0.41873 0.868665 0.806125 \n", 2583 | "\n", 2584 | "0.809375\n", 2585 | "fold_id2\n", 2586 | "5_fold_resnext50_128_scse_hcol_2\n" 2587 | ] 2588 | }, 2589 | { 2590 | "data": { 2591 | "application/vnd.jupyter.widget-view+json": { 2592 | "model_id": "758e5881ec094853afe6f93c9588adcd", 2593 | "version_major": 2, 2594 | "version_minor": 0 2595 | }, 2596 | "text/plain": [ 2597 | "HBox(children=(IntProgress(value=0, description='Epoch', max=20), HTML(value='')))" 2598 | ] 2599 | }, 2600 | "metadata": { 2601 | "tags": [] 2602 | }, 2603 | "output_type": "display_data" 2604 | }, 2605 | { 2606 | "name": "stdout", 2607 | "output_type": "stream", 2608 | "text": [ 2609 | "epoch trn_loss val_loss my_eval \n", 2610 | " 0 0.46378 0.7404 0.81825 \n", 2611 | " 1 0.523088 0.696791 0.81175 \n", 2612 | " 2 0.561582 0.741509 0.806625 \n", 2613 | " 3 0.559633 0.743049 0.79175 \n", 2614 | " 70%|███████ | 35/50 [00:57<00:22, 1.49s/it, loss=0.537]" 2615 | ] 2616 | }, 2617 | { 2618 | "ename": "KeyboardInterrupt", 2619 | "evalue": "ignored", 2620 | "output_type": "error", 2621 | "traceback": [ 2622 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 2623 | "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", 2624 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 38\u001b[0m \u001b[0mlearn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpa\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 39\u001b[0m \u001b[0;31m# learn.fit(lrs/4,1, wds=wd, cycle_len=20,use_clr=(20,16),best_save_name=pa)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 40\u001b[0;31m 
\u001b[0mlearn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlrs\u001b[0m\u001b[0;34m/\u001b[0m\u001b[0;36m5\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mwds\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mwd\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcycle_len\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m20\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0muse_clr\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m20\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m16\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mbest_save_name\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mpa\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 41\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 42\u001b[0m \u001b[0mlearn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpa\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 2625 | "\u001b[0;32m/content/src/fastai/old/fastai/learner.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, lrs, n_cycle, wds, **kwargs)\u001b[0m\n\u001b[1;32m 300\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msched\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 301\u001b[0m \u001b[0mlayer_opt\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_layer_opt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlrs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mwds\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 302\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit_gen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlayer_opt\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mn_cycle\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 303\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 304\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mwarm_up\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mwds\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 2626 | "\u001b[0;32m/content/src/fastai/old/fastai/learner.py\u001b[0m in \u001b[0;36mfit_gen\u001b[0;34m(self, model, data, layer_opt, n_cycle, cycle_len, cycle_mult, cycle_save_name, best_save_name, use_clr, use_clr_beta, metrics, callbacks, use_wd_sched, norm_wds, wds_sched_mult, use_swa, swa_start, swa_eval_freq, **kwargs)\u001b[0m\n\u001b[1;32m 247\u001b[0m \u001b[0mmetrics\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmetrics\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcallbacks\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcallbacks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreg_fn\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreg_fn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mclip\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclip\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfp16\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfp16\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 248\u001b[0m \u001b[0mswa_model\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mswa_model\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0muse_swa\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mswa_start\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mswa_start\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 249\u001b[0;31m 
swa_eval_freq=swa_eval_freq, **kwargs)\n\u001b[0m\u001b[1;32m 250\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 251\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mget_layer_groups\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodels\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_layer_groups\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 2627 | "\u001b[0;32m/content/src/fastai/old/fastai/model.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(model, data, n_epochs, opt, crit, metrics, callbacks, stepper, swa_model, swa_start, swa_eval_freq, visualize, **kwargs)\u001b[0m\n\u001b[1;32m 139\u001b[0m \u001b[0mbatch_num\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 140\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mcb\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mcallbacks\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mcb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_batch_begin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 141\u001b[0;31m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel_stepper\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mV\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mV\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mepoch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 142\u001b[0m \u001b[0mavg_loss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mavg_loss\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mavg_mom\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mloss\u001b[0m \u001b[0;34m*\u001b[0m 
\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0mavg_mom\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 143\u001b[0m \u001b[0mdebias_loss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mavg_loss\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mavg_mom\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0mbatch_num\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 2628 | "\u001b[0;32m/content/src/fastai/old/fastai/model.py\u001b[0m in \u001b[0;36mstep\u001b[0;34m(self, xs, y, epoch)\u001b[0m\n\u001b[1;32m 55\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mloss_scale\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32massert\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfp16\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m;\u001b[0m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mloss_scale\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 56\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreg_fn\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreg_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mxtra\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mraw_loss\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 57\u001b[0;31m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 58\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfp16\u001b[0m\u001b[0;34m:\u001b[0m 
\u001b[0mupdate_fp32_grads\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfp32_params\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mm\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mloss_scale\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 2629 | "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph)\u001b[0m\n\u001b[1;32m 91\u001b[0m \u001b[0mproducts\u001b[0m\u001b[0;34m.\u001b[0m \u001b[0mDefaults\u001b[0m \u001b[0mto\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 92\u001b[0m \"\"\"\n\u001b[0;32m---> 93\u001b[0;31m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 94\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 95\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 2630 | "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables)\u001b[0m\n\u001b[1;32m 88\u001b[0m 
Variable._execution_engine.run_backward(\n\u001b[1;32m 89\u001b[0m \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 90\u001b[0;31m allow_unreachable=True) # allow_unreachable flag\n\u001b[0m\u001b[1;32m 91\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 92\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 2631 | "\u001b[0;31mKeyboardInterrupt\u001b[0m: " 2632 | ] 2633 | } 2634 | ], 2635 | "source": [ 2636 | "model = 'resnext50_128_scse_hcol'\n", 2637 | "arch = resnext50\n", 2638 | "bst_acc=[]\n", 2639 | "use_clr_min=20\n", 2640 | "use_clr_div=10\n", 2641 | "aug_tfms = [RandomFlip(tfm_y=TfmType.CLASS)]\n", 2642 | "\n", 2643 | "szs = [(128,64)]\n", 2644 | "for sz,bs in szs:\n", 2645 | " print([sz,bs])\n", 2646 | " for i in range(kf) :\n", 2647 | " print(f'fold_id{i}')\n", 2648 | " \n", 2649 | " trn_x = np.array([f'{TRN_DN}/{o}' for o in train_folds[i]])\n", 2650 | " trn_y = np.array([f'{MASK_DN}/{o}' for o in train_folds[i]])\n", 2651 | " val_x = [f'{TRN_DN}/{o}' for o in val_folds[i]]\n", 2652 | " val_y = [f'{MASK_DN}/{o}' for o in val_folds[i]]\n", 2653 | " \n", 2654 | " tfms = tfms_from_model(arch, sz=sz, pad=0,crop_type=CropType.NO, tfm_y=TfmType.CLASS, aug_tfms=aug_tfms)\n", 2655 | " datasets = ImageData.get_ds(MatchedFilesDataset, (trn_x,trn_y), (val_x,val_y), tfms,test=test_x,path=PATH)\n", 2656 | " md = ImageData(PATH, datasets, bs, num_workers=16, classes=None)\n", 2657 | " denorm = md.trn_ds.denorm\n", 2658 | " md.trn_dl.dataset = DepthDataset(md.trn_ds,dpth_dict)\n", 2659 | " md.val_dl.dataset = DepthDataset(md.val_ds,dpth_dict)\n", 2660 | " md.test_dl.dataset = DepthDataset(md.test_ds,dpth_dict)\n", 2661 | " learn = get_tgs_model() \n", 2662 | " learn.opt_fn = optim.Adam\n", 2663 | " learn.metrics=[my_eval]\n", 2664 | " pa = f'{kf}_fold_{model}_{i}'\n", 
2665 | " print(pa)\n", 2666 | "\n", 2667 | " lr=1e-2\n", 2668 | " wd=1e-7\n", 2669 | " lrs = np.array([lr/100,lr/10,lr])\n", 2670 | "\n", 2671 | " learn.unfreeze() \n", 2672 | " learn.crit = lovasz_hinge\n", 2673 | " learn.load(pa)\n", 2674 | "# learn.fit(lrs/4,1, wds=wd, cycle_len=20,use_clr=(20,16),best_save_name=pa)\n", 2675 | " learn.fit(lrs/5,1, wds=wd, cycle_len=20,use_clr=(20,16),best_save_name=pa)\n", 2676 | " \n", 2677 | " learn.load(pa) \n", 2678 | " # Calculating mean IoU score\n", 2679 | " v_targ = md.val_ds.ds[:][1]\n", 2680 | " v_preds = np.zeros((len(v_targ),sz,sz)) \n", 2681 | " v_pred = learn.predict()\n", 2682 | " v_pred = to_np(torch.sigmoid(torch.from_numpy(v_pred)))\n", 2683 | " p = ((v_pred)>0.5).astype(np.uint8)\n", 2684 | " bst_acc.append(intersection_over_union_thresholds(v_targ,p))\n", 2685 | " print(bst_acc[-1])\n", 2686 | " \n", 2687 | " del learn\n", 2688 | " del md\n", 2689 | " gc.collect()\n", 2690 | "\n", 2691 | "# save_model_weights()" 2692 | ] 2693 | }, 2694 | { 2695 | "cell_type": "code", 2696 | "execution_count": 0, 2697 | "metadata": { 2698 | "colab": {}, 2699 | "colab_type": "code", 2700 | "id": "R2qbSFsNYJqO" 2701 | }, 2702 | "outputs": [], 2703 | "source": [ 2704 | "save_model_weights()" 2705 | ] 2706 | }, 2707 | { 2708 | "cell_type": "code", 2709 | "execution_count": 0, 2710 | "metadata": { 2711 | "colab": { 2712 | "base_uri": "https://localhost:8080/", 2713 | "height": 34 2714 | }, 2715 | "colab_type": "code", 2716 | "id": "TagiWKJ16V2z", 2717 | "outputId": "4c898c92-63b4-4bb0-af97-306f6113ac9a" 2718 | }, 2719 | "outputs": [ 2720 | { 2721 | "data": { 2722 | "text/plain": [ 2723 | "([0.823875, 0.80875, 0.828, 0.834375, 0.8375], 0.8265)" 2724 | ] 2725 | }, 2726 | "execution_count": 31, 2727 | "metadata": { 2728 | "tags": [] 2729 | }, 2730 | "output_type": "execute_result" 2731 | } 2732 | ], 2733 | "source": [ 2734 | "bst_acc,np.mean(bst_acc)#With 128" 2735 | ] 2736 | }, 2737 | { 2738 | "cell_type": "code", 2739 | 
"execution_count": 0, 2740 | "metadata": { 2741 | "colab": {}, 2742 | "colab_type": "code", 2743 | "id": "FswF_Yw14xk2" 2744 | }, 2745 | "outputs": [], 2746 | "source": [ 2747 | "gc.collect()" 2748 | ] 2749 | }, 2750 | { 2751 | "cell_type": "markdown", 2752 | "metadata": { 2753 | "colab_type": "text", 2754 | "id": "YxJSW5-QTk8b" 2755 | }, 2756 | "source": [ 2757 | "### Training: resnext50 + scSE" 2758 | ] 2759 | }, 2760 | { 2761 | "cell_type": "code", 2762 | "execution_count": 0, 2763 | "metadata": { 2764 | "colab": {}, 2765 | "colab_type": "code", 2766 | "id": "jXY7FrEcrLHj" 2767 | }, 2768 | "outputs": [], 2769 | "source": [ 2770 | "model = 'resnext50_128_scse'\n", 2771 | "arch = resnext50\n", 2772 | "bst_acc=[]\n", 2773 | "use_clr_min=20\n", 2774 | "use_clr_div=10\n", 2775 | "aug_tfms = [RandomFlip(tfm_y=TfmType.CLASS)]\n", 2776 | "\n", 2777 | "szs = [(128,64)]\n", 2778 | "for sz,bs in szs:\n", 2779 | " print([sz,bs])\n", 2780 | " for i in range(kf) :\n", 2781 | " print(f'fold_id{i}')\n", 2782 | " \n", 2783 | " trn_x = np.array([f'{TRN_DN}/{o}' for o in train_folds[i]])\n", 2784 | " trn_y = np.array([f'{MASK_DN}/{o}' for o in train_folds[i]])\n", 2785 | " val_x = [f'{TRN_DN}/{o}' for o in val_folds[i]]\n", 2786 | " val_y = [f'{MASK_DN}/{o}' for o in val_folds[i]]\n", 2787 | " \n", 2788 | " tfms = tfms_from_model(arch, sz=sz, pad=0,crop_type=CropType.NO, tfm_y=TfmType.CLASS, aug_tfms=aug_tfms)\n", 2789 | " datasets = ImageData.get_ds(MatchedFilesDataset, (trn_x,trn_y), (val_x,val_y), tfms,test=test_x,path=PATH)\n", 2790 | " md = ImageData(PATH, datasets, bs, num_workers=4, classes=None)\n", 2791 | " denorm = md.trn_ds.denorm\n", 2792 | " md.trn_dl.dataset = DepthDataset(md.trn_ds,dpth_dict)\n", 2793 | " md.val_dl.dataset = DepthDataset(md.val_ds,dpth_dict)\n", 2794 | " md.test_dl.dataset = DepthDataset(md.test_ds,dpth_dict)\n", 2795 | " learn = get_tgs_model() \n", 2796 | " learn.opt_fn = optim.Adam\n", 2797 | " learn.metrics=[my_eval]\n", 2798 | " pa = 
f'{kf}_fold_{model}_{i}'\n", 2799 | " print(pa)\n", 2800 | "\n", 2801 | " lr=1e-2\n", 2802 | " wd=1e-7\n", 2803 | " lrs = np.array([lr/100,lr/10,lr])\n", 2804 | "\n", 2805 | " learn.unfreeze() \n", 2806 | " learn.crit = nn.BCEWithLogitsLoss()\n", 2807 | " if os.path.exists(pa):\n", 2808 | " learn.load(pa)\n", 2809 | " learn.fit(lrs/2,1, wds=wd, cycle_len=10,use_clr=(20,8),best_save_name=pa)\n", 2810 | " \n", 2811 | " learn.load(pa) \n", 2812 | " # Calculating mean IoU score\n", 2813 | " v_targ = md.val_ds.ds[:][1]\n", 2814 | " v_preds = np.zeros((len(v_targ),sz,sz)) \n", 2815 | " v_pred = learn.predict()\n", 2816 | " v_pred = to_np(torch.sigmoid(torch.from_numpy(v_pred)))\n", 2817 | " p = ((v_pred)>0.5).astype(np.uint8)\n", 2818 | " bst_acc.append(intersection_over_union_thresholds(v_targ,p))\n", 2819 | " print(bst_acc[-1])" 2820 | ] 2821 | }, 2822 | { 2823 | "cell_type": "code", 2824 | "execution_count": 0, 2825 | "metadata": { 2826 | "colab": {}, 2827 | "colab_type": "code", 2828 | "id": "6mFbdBCuWbgc" 2829 | }, 2830 | "outputs": [], 2831 | "source": [ 2832 | "model = 'resnext50_128_scse'\n", 2833 | "arch = resnet34\n", 2834 | "bst_acc=[]\n", 2835 | "use_clr_min=20\n", 2836 | "use_clr_div=10\n", 2837 | "aug_tfms = [RandomFlip(tfm_y=TfmType.CLASS)]\n", 2838 | "\n", 2839 | "szs = [(128,64)]\n", 2840 | "for sz,bs in szs:\n", 2841 | " print([sz,bs])\n", 2842 | " for i in range(kf) :\n", 2843 | " print(f'fold_id{i}')\n", 2844 | " \n", 2845 | " trn_x = np.array([f'{TRN_DN}/{o}' for o in train_folds[i]])\n", 2846 | " trn_y = np.array([f'{MASK_DN}/{o}' for o in train_folds[i]])\n", 2847 | " val_x = [f'{TRN_DN}/{o}' for o in val_folds[i]]\n", 2848 | " val_y = [f'{MASK_DN}/{o}' for o in val_folds[i]]\n", 2849 | " \n", 2850 | " tfms = tfms_from_model(arch, sz=sz, pad=0,crop_type=CropType.NO, tfm_y=TfmType.CLASS, aug_tfms=aug_tfms)\n", 2851 | " datasets = ImageData.get_ds(MatchedFilesDataset, (trn_x,trn_y), (val_x,val_y), tfms,test=test_x,path=PATH)\n", 2852 | " md = 
ImageData(PATH, datasets, bs, num_workers=16, classes=None)\n", 2853 | " denorm = md.trn_ds.denorm\n", 2854 | " md.trn_dl.dataset = DepthDataset(md.trn_ds,dpth_dict)\n", 2855 | " md.val_dl.dataset = DepthDataset(md.val_ds,dpth_dict)\n", 2856 | " md.test_dl.dataset = DepthDataset(md.test_ds,dpth_dict)\n", 2857 | " learn = get_tgs_model() \n", 2858 | " learn.opt_fn = optim.Adam\n", 2859 | " learn.metrics=[my_eval]\n", 2860 | " pa = f'{kf}_fold_{model}_{i}'\n", 2861 | " print(pa)\n", 2862 | "\n", 2863 | " lr=1e-2\n", 2864 | " wd=1e-7\n", 2865 | " lrs = np.array([lr/100,lr/10,lr])\n", 2866 | "\n", 2867 | " learn.unfreeze() \n", 2868 | " learn.crit = lovasz_hinge\n", 2869 | " learn.load(pa)\n", 2870 | "# learn.fit(lrs/3,2, wds=wd, cycle_len=10,use_clr=(20,8),best_save_name=pa)\n", 2871 | " learn.fit(lrs/3,3, wds=wd, cycle_len=10,use_clr=(20,8),best_save_name=pa)\n", 2872 | " \n", 2873 | " learn.load(pa) \n", 2874 | " # Calculating mean IoU score\n", 2875 | " v_targ = md.val_ds.ds[:][1]\n", 2876 | " v_preds = np.zeros((len(v_targ),sz,sz)) \n", 2877 | " v_pred = learn.predict()\n", 2878 | " v_pred = to_np(torch.sigmoid(torch.from_numpy(v_pred)))\n", 2879 | " p = ((v_pred)>0.5).astype(np.uint8)\n", 2880 | " bst_acc.append(intersection_over_union_thresholds(v_targ,p))\n", 2881 | " print(bst_acc[-1])" 2882 | ] 2883 | }, 2884 | { 2885 | "cell_type": "code", 2886 | "execution_count": 0, 2887 | "metadata": { 2888 | "colab": {}, 2889 | "colab_type": "code", 2890 | "id": "UpWadtxBDMYv" 2891 | }, 2892 | "outputs": [], 2893 | "source": [ 2894 | "model = 'resnext50_128_scse'\n", 2895 | "arch = resnext50\n", 2896 | "bst_acc=[]\n", 2897 | "use_clr_min=20\n", 2898 | "use_clr_div=10\n", 2899 | "aug_tfms = [RandomFlip(tfm_y=TfmType.CLASS)]\n", 2900 | "\n", 2901 | "szs = [(128,64)]\n", 2902 | "for sz,bs in szs:\n", 2903 | " print([sz,bs])\n", 2904 | " for i in range(kf) :\n", 2905 | " print(f'fold_id{i}')\n", 2906 | " \n", 2907 | " trn_x = np.array([f'{TRN_DN}/{o}' for o in 
train_folds[i]])\n", 2908 | " trn_y = np.array([f'{MASK_DN}/{o}' for o in train_folds[i]])\n", 2909 | " val_x = [f'{TRN_DN}/{o}' for o in val_folds[i]]\n", 2910 | " val_y = [f'{MASK_DN}/{o}' for o in val_folds[i]]\n", 2911 | " \n", 2912 | " tfms = tfms_from_model(arch, sz=sz, pad=0,crop_type=CropType.NO, tfm_y=TfmType.CLASS, aug_tfms=aug_tfms)\n", 2913 | " datasets = ImageData.get_ds(MatchedFilesDataset, (trn_x,trn_y), (val_x,val_y), tfms,test=test_x,path=PATH)\n", 2914 | " md = ImageData(PATH, datasets, bs, num_workers=16, classes=None)\n", 2915 | " denorm = md.trn_ds.denorm\n", 2916 | " md.trn_dl.dataset = DepthDataset(md.trn_ds,dpth_dict)\n", 2917 | " md.val_dl.dataset = DepthDataset(md.val_ds,dpth_dict)\n", 2918 | " md.test_dl.dataset = DepthDataset(md.test_ds,dpth_dict)\n", 2919 | " learn = get_tgs_model() \n", 2920 | " learn.opt_fn = optim.Adam\n", 2921 | " learn.metrics=[my_eval]\n", 2922 | " pa = f'{kf}_fold_{model}_{i}'\n", 2923 | " print(pa)\n", 2924 | "\n", 2925 | " lr=1e-2\n", 2926 | " wd=1e-7\n", 2927 | " lrs = np.array([lr/100,lr/10,lr])\n", 2928 | "\n", 2929 | " learn.unfreeze() \n", 2930 | " learn.crit = lovasz_hinge\n", 2931 | " learn.load(pa)\n", 2932 | "# learn.fit(lrs/3,1, wds=wd, cycle_len=10,use_clr=(20,8),best_save_name=pa)\n", 2933 | " learn.fit(lrs/3,1, wds=wd, cycle_len=20,use_clr=(20,10),best_save_name=pa)\n", 2934 | " \n", 2935 | " learn.load(pa) \n", 2936 | " # Calculating mean IoU score\n", 2937 | " v_targ = md.val_ds.ds[:][1]\n", 2938 | " v_preds = np.zeros((len(v_targ),sz,sz)) \n", 2939 | " v_pred = learn.predict()\n", 2940 | " v_pred = to_np(torch.sigmoid(torch.from_numpy(v_pred)))\n", 2941 | " p = ((v_pred)>0.5).astype(np.uint8)\n", 2942 | " bst_acc.append(intersection_over_union_thresholds(v_targ,p))\n", 2943 | " print(bst_acc[-1])" 2944 | ] 2945 | }, 2946 | { 2947 | "cell_type": "code", 2948 | "execution_count": 0, 2949 | "metadata": { 2950 | "colab": { 2951 | "base_uri": "https://localhost:8080/", 2952 | "height": 5830 2953 | }, 
2954 | "colab_type": "code", 2955 | "id": "qrPUrDI6pieX", 2956 | "outputId": "8c65fb9e-b0e7-422a-f6ab-0e82cce4b87b" 2957 | }, 2958 | "outputs": [ 2959 | { 2960 | "name": "stdout", 2961 | "output_type": "stream", 2962 | "text": [ 2963 | "[128, 64]\n", 2964 | "fold_id0\n", 2965 | "5_fold_resnext50_128_scse_0\n" 2966 | ] 2967 | }, 2968 | { 2969 | "data": { 2970 | "application/vnd.jupyter.widget-view+json": { 2971 | "model_id": "9a222f50d2b1423db51268c7a6edcbcc", 2972 | "version_major": 2, 2973 | "version_minor": 0 2974 | }, 2975 | "text/plain": [ 2976 | "HBox(children=(IntProgress(value=0, description='Epoch', max=20), HTML(value='')))" 2977 | ] 2978 | }, 2979 | "metadata": { 2980 | "tags": [] 2981 | }, 2982 | "output_type": "display_data" 2983 | }, 2984 | { 2985 | "name": "stdout", 2986 | "output_type": "stream", 2987 | "text": [ 2988 | "epoch trn_loss val_loss my_eval \n", 2989 | " 0 0.533707 0.76256 0.8175 \n", 2990 | " 1 0.594565 0.858352 0.798875 \n", 2991 | " 2 0.601546 0.788414 0.812125 \n", 2992 | " 3 0.602312 0.763777 0.816625 \n", 2993 | " 4 0.597194 0.756633 0.816125 \n", 2994 | " 5 0.56907 0.767038 0.81325 \n", 2995 | " 6 0.564023 0.746384 0.82175 \n", 2996 | " 7 0.579532 0.831695 0.81125 \n", 2997 | " 8 0.551373 0.760372 0.810875 \n", 2998 | " 9 0.54197 0.845872 0.809625 \n", 2999 | " 10 0.539457 0.726901 0.813625 \n", 3000 | " 11 0.528602 0.7233 0.81875 \n", 3001 | " 12 0.520925 0.772535 0.822875 \n", 3002 | " 13 0.518783 0.764549 0.8235 \n", 3003 | " 14 0.513134 0.735043 0.830125 \n", 3004 | " 15 0.510657 0.727242 0.826625 \n", 3005 | " 16 0.493573 0.704724 0.82725 \n", 3006 | " 17 0.481923 0.73212 0.829125 \n", 3007 | " 18 0.471987 0.698919 0.828 \n", 3008 | " 19 0.463161 0.711147 0.82725 \n", 3009 | "\n" 3010 | ] 3011 | }, 3012 | { 3013 | "data": { 3014 | "application/vnd.jupyter.widget-view+json": { 3015 | "model_id": "076489da5da74916b79331f943e5b625", 3016 | "version_major": 2, 3017 | "version_minor": 0 3018 | }, 3019 | "text/plain": [ 3020 | 
"HBox(children=(IntProgress(value=0, description='Epoch', max=40), HTML(value='')))" 3021 | ] 3022 | }, 3023 | "metadata": { 3024 | "tags": [] 3025 | }, 3026 | "output_type": "display_data" 3027 | }, 3028 | { 3029 | "name": "stdout", 3030 | "output_type": "stream", 3031 | "text": [ 3032 | "epoch trn_loss val_loss my_eval \n", 3033 | " 0 0.479967 0.827555 0.821125 \n", 3034 | " 1 0.506115 0.7792 0.807 \n", 3035 | " 2 0.521105 0.704616 0.823375 \n", 3036 | " 3 0.531661 0.759747 0.813625 \n", 3037 | " 4 0.527358 0.76719 0.8165 \n", 3038 | " 5 0.599692 0.78576 0.8095 \n", 3039 | " 6 0.579218 0.723072 0.814375 \n", 3040 | " 7 0.553373 0.735349 0.81025 \n", 3041 | " 8 0.522601 0.74302 0.82075 \n", 3042 | " 9 0.51634 0.725892 0.8255 \n", 3043 | " 10 0.507244 0.736419 0.810875 \n", 3044 | " 11 0.495299 0.747552 0.8235 \n", 3045 | " 12 0.517736 0.709832 0.8205 \n", 3046 | " 13 0.52503 0.702308 0.82 \n", 3047 | " 14 0.517949 0.763225 0.822875 \n", 3048 | " 15 0.481496 0.760829 0.825875 \n", 3049 | " 16 0.467406 0.75723 0.818 \n", 3050 | " 17 0.48397 0.719035 0.821125 \n", 3051 | " 18 0.480835 0.74827 0.81975 \n", 3052 | " 19 0.464327 0.758697 0.822625 \n", 3053 | " 20 0.460972 0.714269 0.8275 \n", 3054 | " 21 0.451309 0.763296 0.824375 \n", 3055 | " 22 0.443526 0.740101 0.82525 \n", 3056 | " 23 0.432715 0.74261 0.829875 \n", 3057 | " 24 0.459113 0.808962 0.8255 \n", 3058 | " 25 0.449143 0.793144 0.827 \n", 3059 | " 26 0.460752 0.774263 0.824 \n", 3060 | " 27 0.448464 0.768298 0.823625 \n", 3061 | " 28 0.429408 0.746062 0.826875 \n", 3062 | " 29 0.418083 0.767724 0.830375 \n", 3063 | " 30 0.420722 0.716629 0.82625 \n", 3064 | " 31 0.426644 0.73742 0.829625 \n", 3065 | " 32 0.425052 0.748321 0.82975 \n", 3066 | " 33 0.419258 0.759544 0.829125 \n", 3067 | " 34 0.411915 0.743406 0.82825 \n", 3068 | " 35 0.4213 0.755378 0.82675 \n", 3069 | " 36 0.404763 0.741769 0.831125 \n", 3070 | " 37 0.398306 0.767223 0.825625 \n", 3071 | " 38 0.396879 0.733644 0.825625 \n", 3072 | " 39 
0.413759 0.745653 0.826 \n", 3073 | "\n", 3074 | "0.831125\n", 3075 | "fold_id1\n", 3076 | "5_fold_resnext50_128_scse_1\n" 3077 | ] 3078 | }, 3079 | { 3080 | "data": { 3081 | "application/vnd.jupyter.widget-view+json": { 3082 | "model_id": "756be4b59566488e876b7880d74f2849", 3083 | "version_major": 2, 3084 | "version_minor": 0 3085 | }, 3086 | "text/plain": [ 3087 | "HBox(children=(IntProgress(value=0, description='Epoch', max=20), HTML(value='')))" 3088 | ] 3089 | }, 3090 | "metadata": { 3091 | "tags": [] 3092 | }, 3093 | "output_type": "display_data" 3094 | }, 3095 | { 3096 | "name": "stdout", 3097 | "output_type": "stream", 3098 | "text": [ 3099 | "epoch trn_loss val_loss my_eval \n", 3100 | " 0 0.467102 0.833521 0.792125 \n", 3101 | " 1 0.514805 0.866801 0.770625 \n", 3102 | " 2 0.512858 0.888006 0.78125 \n", 3103 | " 3 0.523391 0.872697 0.785625 \n", 3104 | " 4 0.541148 0.89274 0.7915 \n", 3105 | " 5 0.534021 0.817045 0.794 \n", 3106 | " 6 0.511217 0.827669 0.80725 \n", 3107 | " 7 0.504867 0.870196 0.796375 \n", 3108 | " 8 0.482823 0.813677 0.80225 \n", 3109 | " 9 0.470734 0.859631 0.79925 \n", 3110 | " 10 0.50188 0.805922 0.790625 \n", 3111 | " 11 0.50007 0.865748 0.790125 \n", 3112 | " 12 0.490715 0.808949 0.79925 \n", 3113 | " 13 0.467356 0.822104 0.80175 \n", 3114 | " 14 0.459628 0.823243 0.8005 \n", 3115 | " 15 0.442727 0.81513 0.789125 \n", 3116 | " 16 0.433787 0.801445 0.803375 \n", 3117 | " 17 0.428531 0.821544 0.79725 \n", 3118 | " 18 0.427835 0.775848 0.80675 \n", 3119 | " 19 0.41361 0.812282 0.805 \n", 3120 | "\n" 3121 | ] 3122 | }, 3123 | { 3124 | "data": { 3125 | "application/vnd.jupyter.widget-view+json": { 3126 | "model_id": "bc7d0f393faf4c2ea88f2dee53b90a55", 3127 | "version_major": 2, 3128 | "version_minor": 0 3129 | }, 3130 | "text/plain": [ 3131 | "HBox(children=(IntProgress(value=0, description='Epoch', max=40), HTML(value='')))" 3132 | ] 3133 | }, 3134 | "metadata": { 3135 | "tags": [] 3136 | }, 3137 | "output_type": "display_data" 3138 | 
}, 3139 | { 3140 | "name": "stdout", 3141 | "output_type": "stream", 3142 | "text": [ 3143 | "epoch trn_loss val_loss my_eval \n", 3144 | " 0 0.443954 0.821174 0.795625 \n", 3145 | " 1 0.473285 0.845159 0.792 \n", 3146 | " 2 0.465905 0.985209 0.785 \n", 3147 | " 3 0.456353 0.974172 0.786875 \n", 3148 | " 4 0.467155 0.920761 0.79475 \n", 3149 | " 5 0.470211 0.900441 0.79725 \n", 3150 | " 6 0.455301 0.96447 0.792 \n", 3151 | " 7 0.47351 0.90089 0.78875 \n", 3152 | " 8 0.463646 0.856061 0.794875 \n", 3153 | " 9 0.443074 0.860522 0.792 \n", 3154 | " 10 0.429617 0.824053 0.798375 \n", 3155 | " 11 0.445916 0.875289 0.796375 \n", 3156 | " 12 0.453059 0.872382 0.798125 \n", 3157 | " 13 0.447237 0.872192 0.801125 \n", 3158 | " 14 0.44465 0.867648 0.800875 \n", 3159 | " 15 0.447324 0.889119 0.795 \n", 3160 | " 16 0.449389 0.843531 0.800875 \n", 3161 | " 17 0.438213 0.823923 0.8075 \n", 3162 | " 18 0.434376 0.847579 0.7975 \n", 3163 | " 19 0.443935 0.843054 0.801625 \n", 3164 | " 20 0.422225 0.849103 0.80025 \n", 3165 | " 21 0.414883 0.866637 0.7965 \n", 3166 | " 22 0.415798 0.935687 0.791625 \n", 3167 | " 23 0.401541 0.8712 0.800375 \n", 3168 | " 24 0.408011 0.873514 0.793375 \n", 3169 | " 25 0.38908 0.84675 0.798875 \n", 3170 | " 26 0.392715 0.857604 0.79125 \n", 3171 | " 27 0.376516 0.896894 0.792125 \n", 3172 | " 28 0.369762 0.973169 0.786875 \n", 3173 | " 29 0.385661 0.912236 0.798875 \n", 3174 | " 30 0.38374 0.864333 0.801625 \n", 3175 | " 31 0.36842 0.901064 0.806 \n", 3176 | " 32 0.378856 0.892412 0.800125 \n", 3177 | " 33 0.377577 0.88206 0.801 \n", 3178 | " 34 0.35965 0.856498 0.79675 \n", 3179 | " 35 0.365454 0.8875 0.798875 \n", 3180 | " 36 0.371453 0.905397 0.80125 \n", 3181 | " 37 0.365299 0.876433 0.7985 \n", 3182 | " 38 0.358871 0.897993 0.799625 \n", 3183 | " 39 0.360378 0.904567 0.797375 \n", 3184 | "\n", 3185 | "0.8075\n", 3186 | "fold_id2\n", 3187 | "5_fold_resnext50_128_scse_2\n" 3188 | ] 3189 | }, 3190 | { 3191 | "data": { 3192 | 
"application/vnd.jupyter.widget-view+json": { 3193 | "model_id": "a25f7884982344d68ae68729bf31f33e", 3194 | "version_major": 2, 3195 | "version_minor": 0 3196 | }, 3197 | "text/plain": [ 3198 | "HBox(children=(IntProgress(value=0, description='Epoch', max=20), HTML(value='')))" 3199 | ] 3200 | }, 3201 | "metadata": { 3202 | "tags": [] 3203 | }, 3204 | "output_type": "display_data" 3205 | }, 3206 | { 3207 | "name": "stdout", 3208 | "output_type": "stream", 3209 | "text": [ 3210 | "epoch trn_loss val_loss my_eval \n", 3211 | " 0 0.485156 0.756766 0.82075 \n", 3212 | " 1 0.517438 0.832651 0.814375 \n", 3213 | " 2 0.518396 0.735615 0.815 \n", 3214 | " 3 0.522302 0.761968 0.806625 \n", 3215 | " 4 0.540298 0.807242 0.81675 \n", 3216 | " 5 0.536279 0.714934 0.81575 \n", 3217 | " 6 0.510553 0.791743 0.797375 \n", 3218 | " 7 0.511985 0.739001 0.817625 \n", 3219 | " 8 0.50142 0.754067 0.81825 \n", 3220 | " 9 0.496233 0.749199 0.818 \n", 3221 | " 10 0.479415 0.739346 0.821375 \n", 3222 | " 11 0.4808 0.692525 0.826 \n", 3223 | " 12 0.473838 0.731722 0.823 \n", 3224 | " 13 0.467366 0.737328 0.82275 \n", 3225 | " 14 0.476735 0.719575 0.81975 \n", 3226 | " 15 0.448495 0.726288 0.820375 \n", 3227 | " 16 0.434496 0.749208 0.82025 \n", 3228 | " 17 0.426012 0.765359 0.825 \n", 3229 | " 18 0.435108 0.743209 0.8205 \n", 3230 | " 19 0.433077 0.765477 0.82475 \n", 3231 | "\n" 3232 | ] 3233 | }, 3234 | { 3235 | "data": { 3236 | "application/vnd.jupyter.widget-view+json": { 3237 | "model_id": "354e58e006fe4d679dab82ffe007f623", 3238 | "version_major": 2, 3239 | "version_minor": 0 3240 | }, 3241 | "text/plain": [ 3242 | "HBox(children=(IntProgress(value=0, description='Epoch', max=40), HTML(value='')))" 3243 | ] 3244 | }, 3245 | "metadata": { 3246 | "tags": [] 3247 | }, 3248 | "output_type": "display_data" 3249 | }, 3250 | { 3251 | "name": "stdout", 3252 | "output_type": "stream", 3253 | "text": [ 3254 | "epoch trn_loss val_loss my_eval \n", 3255 | " 0 0.443196 0.750643 0.826125 \n", 3256 | 
" 1 0.463696 0.795019 0.81775 \n", 3257 | " 2 0.475888 0.753367 0.819375 \n", 3258 | " 3 0.478167 0.75828 0.814875 \n", 3259 | " 4 0.471563 0.778656 0.813625 \n", 3260 | " 5 0.483763 0.742092 0.81475 \n", 3261 | " 6 0.479101 0.748993 0.81875 \n", 3262 | " 7 0.476277 0.780733 0.810875 \n", 3263 | " 8 0.498361 0.69888 0.82975 \n", 3264 | " 9 0.486142 0.739697 0.81925 \n", 3265 | " 10 0.457164 0.801588 0.825375 \n", 3266 | " 11 0.455916 0.745798 0.815875 \n", 3267 | " 12 0.457613 0.76608 0.818625 \n", 3268 | " 13 0.45103 0.773046 0.8205 \n", 3269 | " 14 0.447433 0.764471 0.813 \n", 3270 | " 15 0.436356 0.770631 0.82025 \n", 3271 | " 16 0.440452 0.798791 0.824875 \n", 3272 | " 17 0.43147 0.753223 0.825125 \n", 3273 | " 18 0.432604 0.828235 0.8245 \n", 3274 | " 19 0.435804 0.800897 0.81925 \n", 3275 | " 20 0.416318 0.782468 0.82325 \n", 3276 | " 21 0.425427 0.760889 0.82525 \n", 3277 | " 22 0.414975 0.779994 0.817375 \n", 3278 | " 23 0.41498 0.785758 0.82025 \n", 3279 | " 24 0.423222 0.784227 0.81475 \n", 3280 | " 25 0.417059 0.753324 0.823875 \n", 3281 | " 26 0.406137 0.769448 0.825125 \n", 3282 | " 27 0.398413 0.820555 0.824125 \n", 3283 | " 28 0.393823 0.790955 0.82025 \n", 3284 | " 29 0.404996 0.793106 0.827375 \n", 3285 | " 30 0.388835 0.761253 0.827 \n", 3286 | " 31 0.38426 0.748922 0.829625 \n", 3287 | " 32 0.380705 0.752186 0.829125 \n", 3288 | " 33 0.378552 0.722271 0.827 \n", 3289 | " 34 0.378998 0.761381 0.83075 \n", 3290 | " 35 0.380338 0.760387 0.828625 \n", 3291 | " 36 0.375267 0.79428 0.825375 \n", 3292 | " 37 0.37371 0.777205 0.830375 \n", 3293 | " 38 0.374267 0.77222 0.8245 \n", 3294 | " 39 0.372311 0.783963 0.824625 \n", 3295 | "\n", 3296 | "0.83075\n", 3297 | "fold_id3\n", 3298 | "5_fold_resnext50_128_scse_3\n" 3299 | ] 3300 | }, 3301 | { 3302 | "data": { 3303 | "application/vnd.jupyter.widget-view+json": { 3304 | "model_id": "f568474ab13a402f984119d9c0490b8e", 3305 | "version_major": 2, 3306 | "version_minor": 0 3307 | }, 3308 | "text/plain": [ 3309 
| "HBox(children=(IntProgress(value=0, description='Epoch', max=20), HTML(value='')))" 3310 | ] 3311 | }, 3312 | "metadata": { 3313 | "tags": [] 3314 | }, 3315 | "output_type": "display_data" 3316 | }, 3317 | { 3318 | "name": "stdout", 3319 | "output_type": "stream", 3320 | "text": [ 3321 | "epoch trn_loss val_loss my_eval \n", 3322 | " 0 0.496033 0.709309 0.81325 \n", 3323 | " 1 0.521143 0.718663 0.82275 \n", 3324 | " 2 0.54441 0.835224 0.81875 \n", 3325 | " 3 0.549383 0.731907 0.821625 \n", 3326 | " 4 0.539594 0.64432 0.82425 \n", 3327 | " 5 0.534994 0.674456 0.82175 \n", 3328 | " 6 0.529376 0.71022 0.82975 \n", 3329 | " 7 0.502961 0.68506 0.824875 \n", 3330 | " 8 0.498022 0.686693 0.830625 \n", 3331 | " 9 0.500132 0.725364 0.822125 \n", 3332 | " 10 0.50281 0.675001 0.831875 \n", 3333 | " 11 0.512494 0.684257 0.836875 \n", 3334 | " 12 0.484919 0.696663 0.83475 \n", 3335 | " 13 0.469144 0.709017 0.837875 \n", 3336 | " 14 0.469601 0.659997 0.834375 \n", 3337 | " 15 0.46119 0.692171 0.834125 \n", 3338 | " 16 0.453164 0.69287 0.83025 \n", 3339 | " 17 0.443224 0.70075 0.834625 \n", 3340 | " 18 0.44307 0.694192 0.83 \n", 3341 | " 19 0.428427 0.728258 0.83575 \n", 3342 | "\n" 3343 | ] 3344 | }, 3345 | { 3346 | "data": { 3347 | "application/vnd.jupyter.widget-view+json": { 3348 | "model_id": "5eaf0a7d85f94e4aac0efe9e6d732595", 3349 | "version_major": 2, 3350 | "version_minor": 0 3351 | }, 3352 | "text/plain": [ 3353 | "HBox(children=(IntProgress(value=0, description='Epoch', max=40), HTML(value='')))" 3354 | ] 3355 | }, 3356 | "metadata": { 3357 | "tags": [] 3358 | }, 3359 | "output_type": "display_data" 3360 | }, 3361 | { 3362 | "name": "stdout", 3363 | "output_type": "stream", 3364 | "text": [ 3365 | "epoch trn_loss val_loss my_eval \n", 3366 | " 0 0.422124 0.726504 0.8295 \n", 3367 | " 1 0.462113 0.731205 0.8245 \n", 3368 | " 2 0.486406 0.728322 0.829125 \n", 3369 | " 3 0.487896 0.765298 0.828875 \n", 3370 | " 4 0.484002 0.75814 0.8245 \n", 3371 | " 5 0.491141 
0.697839 0.827125 \n", 3372 | " 6 0.477768 0.709401 0.822625 \n", 3373 | " 7 0.483948 0.766132 0.82175 \n", 3374 | " 8 0.490673 0.695626 0.830375 \n", 3375 | " 9 0.472208 0.698568 0.81975 \n", 3376 | " 10 0.466838 0.67741 0.827375 \n", 3377 | " 11 0.463749 0.797842 0.81275 \n", 3378 | " 12 0.448588 0.679358 0.830125 \n", 3379 | " 13 0.469459 0.723771 0.828375 \n", 3380 | " 14 0.452238 0.672009 0.830875 \n", 3381 | " 15 0.447767 0.722987 0.830875 \n", 3382 | " 16 0.43008 0.666037 0.831 \n", 3383 | " 17 0.430931 0.738625 0.827625 \n", 3384 | " 18 0.427639 0.6799 0.83275 \n", 3385 | " 19 0.430004 0.645604 0.82625 \n", 3386 | " 20 0.440587 0.710382 0.825625 \n", 3387 | " 21 0.427687 0.758092 0.8225 \n", 3388 | " 22 0.424524 0.714262 0.830125 \n", 3389 | " 23 0.415856 0.734838 0.829375 \n", 3390 | " 24 0.420022 0.709158 0.834375 \n", 3391 | " 25 0.40909 0.71966 0.838125 \n", 3392 | " 26 0.392587 0.719539 0.836625 \n", 3393 | " 27 0.393285 0.732876 0.82375 \n", 3394 | " 28 0.389524 0.695115 0.83525 \n", 3395 | " 29 0.391245 0.734108 0.829375 \n", 3396 | " 30 0.384423 0.73317 0.833625 \n", 3397 | " 31 0.373081 0.731024 0.830875 \n", 3398 | " 32 0.389755 0.754126 0.83025 \n", 3399 | " 33 0.38323 0.752325 0.83575 \n", 3400 | " 34 0.375477 0.733968 0.83 \n", 3401 | " 35 0.371067 0.743891 0.83375 \n", 3402 | " 36 0.382864 0.747606 0.833875 \n", 3403 | " 37 0.370416 0.745533 0.83125 \n", 3404 | " 38 0.362293 0.751041 0.834 \n", 3405 | " 39 0.365368 0.740053 0.833625 \n", 3406 | "\n", 3407 | "0.8381250000000001\n", 3408 | "fold_id4\n", 3409 | "5_fold_resnext50_128_scse_4\n" 3410 | ] 3411 | }, 3412 | { 3413 | "data": { 3414 | "application/vnd.jupyter.widget-view+json": { 3415 | "model_id": "f6ae02badebd42bcbf21aaabae7cd6f4", 3416 | "version_major": 2, 3417 | "version_minor": 0 3418 | }, 3419 | "text/plain": [ 3420 | "HBox(children=(IntProgress(value=0, description='Epoch', max=20), HTML(value='')))" 3421 | ] 3422 | }, 3423 | "metadata": { 3424 | "tags": [] 3425 | }, 3426 | 
"output_type": "display_data" 3427 | }, 3428 | { 3429 | "name": "stdout", 3430 | "output_type": "stream", 3431 | "text": [ 3432 | "epoch trn_loss val_loss my_eval \n", 3433 | " 0 0.47414 0.722221 0.827 \n", 3434 | " 1 0.491114 0.712578 0.82425 \n", 3435 | " 2 0.51879 0.73661 0.815875 \n", 3436 | " 3 0.540288 0.692429 0.826375 \n", 3437 | " 4 0.534986 0.700633 0.827 \n", 3438 | " 5 0.518888 0.720056 0.826625 \n", 3439 | " 6 0.510556 0.819519 0.814 \n", 3440 | " 7 0.498047 0.740013 0.821125 \n", 3441 | " 8 0.467538 0.726275 0.821875 \n", 3442 | " 9 0.465198 0.740374 0.82125 \n", 3443 | " 10 0.454122 0.729831 0.822125 \n", 3444 | " 11 0.465724 0.69371 0.827625 \n", 3445 | " 12 0.459633 0.711312 0.82275 \n", 3446 | " 13 0.44714 0.709925 0.8275 \n", 3447 | " 14 0.450566 0.755889 0.81475 \n", 3448 | " 15 0.433707 0.689775 0.836125 \n", 3449 | " 16 0.419188 0.687338 0.8345 \n", 3450 | " 17 0.419773 0.675787 0.834 \n", 3451 | " 18 0.422193 0.681697 0.83225 \n", 3452 | " 19 0.414564 0.676884 0.833125 \n", 3453 | "\n" 3454 | ] 3455 | }, 3456 | { 3457 | "data": { 3458 | "application/vnd.jupyter.widget-view+json": { 3459 | "model_id": "95d90582589d4934be7383e615f1dbd0", 3460 | "version_major": 2, 3461 | "version_minor": 0 3462 | }, 3463 | "text/plain": [ 3464 | "HBox(children=(IntProgress(value=0, description='Epoch', max=40), HTML(value='')))" 3465 | ] 3466 | }, 3467 | "metadata": { 3468 | "tags": [] 3469 | }, 3470 | "output_type": "display_data" 3471 | }, 3472 | { 3473 | "name": "stdout", 3474 | "output_type": "stream", 3475 | "text": [ 3476 | "epoch trn_loss val_loss my_eval \n", 3477 | " 0 0.416673 0.743005 0.82225 \n", 3478 | " 1 0.451756 0.796281 0.822 \n", 3479 | " 2 0.47234 0.772019 0.824125 \n", 3480 | " 3 0.479158 0.759053 0.807875 \n", 3481 | " 4 0.46425 0.769957 0.815375 \n", 3482 | " 5 0.464162 0.772572 0.815875 \n", 3483 | " 6 0.460326 0.737922 0.820625 \n", 3484 | " 7 0.468473 0.750393 0.818375 \n", 3485 | " 8 0.458967 0.704921 0.831375 \n", 3486 | " 9 0.439696 
0.740977 0.832625 \n", 3487 | " 10 0.445156 0.723011 0.824875 \n", 3488 | " 11 0.434864 0.725554 0.833625 \n", 3489 | " 12 0.422101 0.704498 0.828375 \n", 3490 | " 13 0.429801 0.699662 0.82975 \n", 3491 | " 14 0.432414 0.763216 0.81775 \n", 3492 | " 15 0.423432 0.72516 0.822875 \n", 3493 | " 16 0.425951 0.732884 0.832625 \n", 3494 | " 17 0.41282 0.699522 0.830125 \n", 3495 | " 18 0.40507 0.741123 0.8335 \n", 3496 | " 19 0.406175 0.729666 0.8355 \n", 3497 | " 20 0.423782 0.71638 0.82775 \n", 3498 | " 21 0.407595 0.724523 0.82375 \n", 3499 | " 22 0.400728 0.743059 0.82825 \n", 3500 | " 23 0.397718 0.712238 0.828375 \n", 3501 | " 24 0.400653 0.771733 0.81775 \n", 3502 | " 25 0.406241 0.706003 0.83425 \n", 3503 | " 26 0.384568 0.726671 0.830625 \n", 3504 | " 27 0.395504 0.695435 0.83125 \n", 3505 | " 28 0.392318 0.724178 0.831125 \n", 3506 | " 29 0.381853 0.684155 0.83 \n", 3507 | " 30 0.375725 0.688363 0.83775 \n", 3508 | " 31 0.385391 0.703829 0.836875 \n", 3509 | " 32 0.362482 0.684262 0.84275 \n", 3510 | " 33 0.354302 0.701925 0.83675 \n", 3511 | " 34 0.365955 0.715952 0.83425 \n", 3512 | " 35 0.367506 0.688223 0.8365 \n", 3513 | " 36 0.361231 0.706321 0.839 \n", 3514 | " 37 0.35691 0.693329 0.835625 \n", 3515 | " 38 0.363043 0.709944 0.832125 \n", 3516 | " 39 0.350134 0.699815 0.834625 \n", 3517 | "\n", 3518 | "0.8427500000000001\n" 3519 | ] 3520 | } 3521 | ], 3522 | "source": [ 3523 | "model = 'resnext50_128_scse'\n", 3524 | "arch = resnext50\n", 3525 | "bst_acc=[]\n", 3526 | "use_clr_min=20\n", 3527 | "use_clr_div=10\n", 3528 | "aug_tfms = [RandomFlip(tfm_y=TfmType.CLASS)]\n", 3529 | "\n", 3530 | "szs = [(128,64)]\n", 3531 | "for sz,bs in szs:\n", 3532 | " print([sz,bs])\n", 3533 | " for i in range(kf) :\n", 3534 | " print(f'fold_id{i}')\n", 3535 | " \n", 3536 | " trn_x = np.array([f'{TRN_DN}/{o}' for o in train_folds[i]])\n", 3537 | " trn_y = np.array([f'{MASK_DN}/{o}' for o in train_folds[i]])\n", 3538 | " val_x = [f'{TRN_DN}/{o}' for o in val_folds[i]]\n", 
3539 | " val_y = [f'{MASK_DN}/{o}' for o in val_folds[i]]\n", 3540 | " \n", 3541 | " tfms = tfms_from_model(arch, sz=sz, pad=0,crop_type=CropType.NO, tfm_y=TfmType.CLASS, aug_tfms=aug_tfms)\n", 3542 | " datasets = ImageData.get_ds(MatchedFilesDataset, (trn_x,trn_y), (val_x,val_y), tfms,test=test_x,path=PATH)\n", 3543 | " md = ImageData(PATH, datasets, bs, num_workers=16, classes=None)\n", 3544 | " denorm = md.trn_ds.denorm\n", 3545 | " md.trn_dl.dataset = DepthDataset(md.trn_ds,dpth_dict)\n", 3546 | " md.val_dl.dataset = DepthDataset(md.val_ds,dpth_dict)\n", 3547 | " md.test_dl.dataset = DepthDataset(md.test_ds,dpth_dict)\n", 3548 | " learn = get_tgs_model() \n", 3549 | " learn.opt_fn = optim.Adam\n", 3550 | " learn.metrics=[my_eval]\n", 3551 | " pa = f'{kf}_fold_{model}_{i}'\n", 3552 | " print(pa)\n", 3553 | "\n", 3554 | " lr=1e-2\n", 3555 | " wd=1e-7\n", 3556 | " lrs = np.array([lr/100,lr/10,lr])\n", 3557 | "\n", 3558 | " learn.unfreeze() \n", 3559 | " learn.crit = lovasz_hinge\n", 3560 | " learn.load(pa)\n", 3561 | "# learn.fit(lrs/3,1, wds=wd, cycle_len=10,use_clr=(20,8),best_save_name=pa)\n", 3562 | " learn.fit(lrs/4,1, wds=wd, cycle_len=20,use_clr=(20,16),best_save_name=pa)\n", 3563 | " learn.fit(lrs/5,1, wds=wd, cycle_len=40,use_clr=(20,32),best_save_name=pa)\n", 3564 | " \n", 3565 | " learn.load(pa) \n", 3566 | " #Calculating mean IoU score\n", 3567 | " v_targ = md.val_ds.ds[:][1]\n", 3568 | " v_preds = np.zeros((len(v_targ),sz,sz)) \n", 3569 | " v_pred = learn.predict()\n", 3570 | " v_pred = to_np(torch.sigmoid(torch.from_numpy(v_pred)))\n", 3571 | " p = ((v_pred)>0.5).astype(np.uint8)\n", 3572 | " bst_acc.append(intersection_over_union_thresholds(v_targ,p))\n", 3573 | " print(bst_acc[-1])\n", 3574 | "\n", 3575 | "save_model_weights()" 3576 | ] 3577 | }, 3578 | { 3579 | "cell_type": "code", 3580 | "execution_count": 0, 3581 | "metadata": { 3582 | "colab": { 3583 | "base_uri": "https://localhost:8080/", 3584 | "height": 2218 3585 | }, 3586 | "colab_type": 
"code", 3587 | "id": "ga4JYphvCPyd", 3588 | "outputId": "b589a96a-cf8d-40d5-b11f-a32c260d5aa5" 3589 | }, 3590 | "outputs": [ 3591 | { 3592 | "name": "stdout", 3593 | "output_type": "stream", 3594 | "text": [ 3595 | "[128, 64]\n", 3596 | "fold_id0\n", 3597 | "5_fold_resnext50_128_scse_0\n" 3598 | ] 3599 | }, 3600 | { 3601 | "data": { 3602 | "application/vnd.jupyter.widget-view+json": { 3603 | "model_id": "3f19b564db994e6687c8b43ab8f917ed", 3604 | "version_major": 2, 3605 | "version_minor": 0 3606 | }, 3607 | "text/plain": [ 3608 | "HBox(children=(IntProgress(value=0, description='Epoch', max=20, style=ProgressStyle(description_width='initia…" 3609 | ] 3610 | }, 3611 | "metadata": { 3612 | "tags": [] 3613 | }, 3614 | "output_type": "display_data" 3615 | }, 3616 | { 3617 | "name": "stdout", 3618 | "output_type": "stream", 3619 | "text": [ 3620 | "epoch trn_loss val_loss my_eval \n", 3621 | " 0 0.352572 0.727342 0.824 \n", 3622 | " 1 0.369447 0.676443 0.83 \n", 3623 | " 2 0.387457 0.791503 0.825 \n", 3624 | " 3 0.351119 0.766687 0.827625 \n", 3625 | " 4 0.36598 0.731143 0.83275 \n", 3626 | " 5 0.364853 0.723793 0.8195 \n", 3627 | " 6 0.381143 0.841033 0.829125 \n", 3628 | " 7 0.364594 0.680827 0.83075 \n", 3629 | " 8 0.349329 0.828338 0.817375 \n", 3630 | " 9 0.35967 0.746246 0.834 \n", 3631 | " 10 0.355637 0.749717 0.83125 \n", 3632 | " 11 0.353103 0.816089 0.83225 \n", 3633 | " 12 0.33925 0.821926 0.83525 \n", 3634 | " 13 0.347948 0.751935 0.83625 \n", 3635 | " 14 0.325283 0.842732 0.833625 \n", 3636 | " 15 0.332459 0.769492 0.83075 \n", 3637 | " 16 0.329128 0.773248 0.83475 \n", 3638 | " 17 0.326993 0.837361 0.837 \n", 3639 | " 18 0.33641 0.736122 0.83775 \n", 3640 | " 19 0.32058 0.791872 0.8365 \n", 3641 | "\n", 3642 | "0.8377500000000001\n", 3643 | "fold_id1\n", 3644 | "5_fold_resnext50_128_scse_1\n" 3645 | ] 3646 | }, 3647 | { 3648 | "data": { 3649 | "application/vnd.jupyter.widget-view+json": { 3650 | "model_id": "68bf61a771034be3aa0b6105e3a7b93f", 3651 | 
"version_major": 2, 3652 | "version_minor": 0 3653 | }, 3654 | "text/plain": [ 3655 | "HBox(children=(IntProgress(value=0, description='Epoch', max=20, style=ProgressStyle(description_width='initia…" 3656 | ] 3657 | }, 3658 | "metadata": { 3659 | "tags": [] 3660 | }, 3661 | "output_type": "display_data" 3662 | }, 3663 | { 3664 | "name": "stdout", 3665 | "output_type": "stream", 3666 | "text": [ 3667 | "epoch trn_loss val_loss my_eval \n", 3668 | " 0 0.329473 0.895153 0.791 \n", 3669 | " 1 0.345151 0.857987 0.802125 \n", 3670 | " 2 0.346842 0.957962 0.787625 \n", 3671 | " 3 0.359232 0.914885 0.78725 \n", 3672 | " 4 0.35904 0.945004 0.79475 \n", 3673 | " 5 0.364206 0.90916 0.79025 \n", 3674 | " 6 0.352084 0.907089 0.803125 \n", 3675 | " 7 0.340385 0.884313 0.79975 \n", 3676 | " 8 0.331111 0.935045 0.80525 \n", 3677 | " 9 0.327103 0.912742 0.80725 \n", 3678 | " 10 0.334112 0.903979 0.799875 \n", 3679 | " 11 0.339094 0.997254 0.79625 \n", 3680 | " 12 0.336058 0.903867 0.802 \n", 3681 | " 13 0.330904 0.945125 0.79575 \n", 3682 | " 14 0.328544 0.888129 0.798625 \n", 3683 | " 15 0.322872 0.931454 0.7975 \n", 3684 | " 16 0.320054 0.917544 0.803375 \n", 3685 | " 17 0.318902 0.909964 0.8055 \n", 3686 | " 18 0.318903 0.864422 0.80375 \n", 3687 | " 19 0.319137 0.889656 0.803875 \n", 3688 | "\n", 3689 | "0.8072499999999999\n", 3690 | "fold_id2\n", 3691 | "5_fold_resnext50_128_scse_2\n" 3692 | ] 3693 | }, 3694 | { 3695 | "data": { 3696 | "application/vnd.jupyter.widget-view+json": { 3697 | "model_id": "01b404522f664b47ad0a31f6acfa7b1a", 3698 | "version_major": 2, 3699 | "version_minor": 0 3700 | }, 3701 | "text/plain": [ 3702 | "HBox(children=(IntProgress(value=0, description='Epoch', max=20, style=ProgressStyle(description_width='initia…" 3703 | ] 3704 | }, 3705 | "metadata": { 3706 | "tags": [] 3707 | }, 3708 | "output_type": "display_data" 3709 | }, 3710 | { 3711 | "name": "stdout", 3712 | "output_type": "stream", 3713 | "text": [ 3714 | "epoch trn_loss val_loss my_eval \n", 
3715 | " 0 0.334564 0.907237 0.827 \n", 3716 | " 1 0.356609 0.77755 0.821125 \n", 3717 | " 2 0.359052 0.832501 0.83225 \n", 3718 | " 3 0.365679 0.80743 0.821 \n", 3719 | " 4 0.361891 0.790572 0.814625 \n", 3720 | " 5 0.343242 0.863644 0.823625 \n", 3721 | " 6 0.339013 0.821969 0.828 \n", 3722 | " 7 0.334319 0.807286 0.818625 \n", 3723 | " 8 0.327691 0.864193 0.828125 \n", 3724 | " 9 0.328807 0.806045 0.81325 \n", 3725 | " 10 0.34561 0.859754 0.82925 \n", 3726 | " 11 0.334322 0.871514 0.828 \n", 3727 | " 12 0.320915 0.854894 0.823 \n", 3728 | " 13 0.30964 0.863667 0.8175 \n", 3729 | " 14 0.3198 0.776783 0.81925 \n", 3730 | " 15 0.312285 0.829952 0.83175 \n", 3731 | " 16 0.315182 0.793502 0.8215 \n", 3732 | " 17 0.304829 0.833349 0.82575 \n", 3733 | " 18 0.305454 0.814216 0.833375 \n", 3734 | " 19 0.309406 0.807178 0.82525 \n", 3735 | "\n", 3736 | "0.8333750000000001\n", 3737 | "fold_id3\n", 3738 | "5_fold_resnext50_128_scse_3\n" 3739 | ] 3740 | }, 3741 | { 3742 | "data": { 3743 | "application/vnd.jupyter.widget-view+json": { 3744 | "model_id": "7e9d4a962e6644e7a36ccbb7aa2958ef", 3745 | "version_major": 2, 3746 | "version_minor": 0 3747 | }, 3748 | "text/plain": [ 3749 | "HBox(children=(IntProgress(value=0, description='Epoch', max=20, style=ProgressStyle(description_width='initia…" 3750 | ] 3751 | }, 3752 | "metadata": { 3753 | "tags": [] 3754 | }, 3755 | "output_type": "display_data" 3756 | }, 3757 | { 3758 | "name": "stdout", 3759 | "output_type": "stream", 3760 | "text": [ 3761 | "epoch trn_loss val_loss my_eval \n", 3762 | " 0 0.303919 0.798046 0.829125 \n", 3763 | " 1 0.330936 0.707061 0.83625 \n", 3764 | " 2 0.344996 0.750156 0.829125 \n", 3765 | " 3 0.344157 0.924753 0.833125 \n", 3766 | " 4 0.326378 0.760323 0.833875 \n", 3767 | " 5 0.328699 0.870832 0.825625 \n", 3768 | " 6 0.331057 0.820294 0.83625 \n", 3769 | " 7 0.34331 0.716414 0.838625 \n", 3770 | " 8 0.335372 0.924066 0.83275 \n", 3771 | " 9 0.323994 0.739209 0.83925 \n", 3772 | " 10 0.334121 0.761376 
0.834 \n", 3773 | " 11 0.322592 0.740011 0.83975 \n", 3774 | " 12 0.330691 0.819215 0.837125 \n", 3775 | " 13 0.313002 0.78467 0.842875 \n", 3776 | " 14 0.318578 0.795953 0.83725 \n", 3777 | " 15 0.316326 0.765305 0.842625 \n", 3778 | " 16 0.310763 0.723363 0.84525 \n", 3779 | " 17 0.296692 0.825693 0.84075 \n", 3780 | " 18 0.29445 0.823472 0.842625 \n", 3781 | " 19 0.297867 0.751366 0.83925 \n", 3782 | "\n", 3783 | "0.8452500000000001\n", 3784 | "fold_id4\n", 3785 | "5_fold_resnext50_128_scse_4\n" 3786 | ] 3787 | }, 3788 | { 3789 | "data": { 3790 | "application/vnd.jupyter.widget-view+json": { 3791 | "model_id": "7a025c00a503477aa9cfcba1f1ab3cd2", 3792 | "version_major": 2, 3793 | "version_minor": 0 3794 | }, 3795 | "text/plain": [ 3796 | "HBox(children=(IntProgress(value=0, description='Epoch', max=20, style=ProgressStyle(description_width='initia…" 3797 | ] 3798 | }, 3799 | "metadata": { 3800 | "tags": [] 3801 | }, 3802 | "output_type": "display_data" 3803 | }, 3804 | { 3805 | "name": "stdout", 3806 | "output_type": "stream", 3807 | "text": [ 3808 | "epoch trn_loss val_loss my_eval \n", 3809 | " 0 0.32856 0.730436 0.834625 \n", 3810 | " 1 0.335845 0.899728 0.837125 \n", 3811 | " 2 0.348538 0.913843 0.8295 \n", 3812 | " 3 0.33758 0.769679 0.8235 \n", 3813 | " 4 0.346654 0.802165 0.835625 \n", 3814 | " 5 0.339364 0.746379 0.83825 \n", 3815 | " 6 0.343552 0.756106 0.828125 \n", 3816 | " 7 0.331363 0.792459 0.8335 \n", 3817 | " 8 0.31865 0.735218 0.836125 \n", 3818 | " 9 0.325964 0.810198 0.839 \n", 3819 | " 10 0.311348 0.795061 0.841875 \n", 3820 | " 11 0.321848 0.720627 0.83475 \n", 3821 | " 12 0.311088 0.884234 0.8375 \n", 3822 | " 13 0.304974 0.749801 0.836 \n", 3823 | " 14 0.305366 0.71613 0.83375 \n", 3824 | " 15 0.314715 0.693329 0.84125 \n", 3825 | " 16 0.318737 0.787573 0.83675 \n", 3826 | " 17 0.304981 0.747158 0.83425 \n", 3827 | " 18 0.302457 0.767533 0.834 \n", 3828 | " 19 0.299126 0.759485 0.833625 \n", 3829 | "\n", 3830 | "0.841875\n" 3831 | ] 3832 | 
} 3833 | ], 3834 | "source": [ 3835 | "model = 'resnext50_128_scse'\n", 3836 | "arch = resnext50\n", 3837 | "bst_acc=[]\n", 3838 | "use_clr_min=20\n", 3839 | "use_clr_div=10\n", 3840 | "aug_tfms = [RandomFlip(tfm_y=TfmType.CLASS)]\n", 3841 | "\n", 3842 | "szs = [(128,64)]\n", 3843 | "for sz,bs in szs:\n", 3844 | " print([sz,bs])\n", 3845 | " for i in range(kf) :\n", 3846 | " print(f'fold_id{i}')\n", 3847 | " \n", 3848 | " trn_x = np.array([f'{TRN_DN}/{o}' for o in train_folds[i]])\n", 3849 | " trn_y = np.array([f'{MASK_DN}/{o}' for o in train_folds[i]])\n", 3850 | " val_x = [f'{TRN_DN}/{o}' for o in val_folds[i]]\n", 3851 | " val_y = [f'{MASK_DN}/{o}' for o in val_folds[i]]\n", 3852 | " \n", 3853 | " tfms = tfms_from_model(arch, sz=sz, pad=0,crop_type=CropType.NO, tfm_y=TfmType.CLASS, aug_tfms=aug_tfms)\n", 3854 | " datasets = ImageData.get_ds(MatchedFilesDataset, (trn_x,trn_y), (val_x,val_y), tfms,test=test_x,path=PATH)\n", 3855 | " md = ImageData(PATH, datasets, bs, num_workers=16, classes=None)\n", 3856 | " denorm = md.trn_ds.denorm\n", 3857 | " md.trn_dl.dataset = DepthDataset(md.trn_ds,dpth_dict)\n", 3858 | " md.val_dl.dataset = DepthDataset(md.val_ds,dpth_dict)\n", 3859 | " md.test_dl.dataset = DepthDataset(md.test_ds,dpth_dict)\n", 3860 | " learn = get_tgs_model() \n", 3861 | " learn.opt_fn = optim.Adam\n", 3862 | " learn.metrics=[my_eval]\n", 3863 | " pa = f'{kf}_fold_{model}_{i}'\n", 3864 | " print(pa)\n", 3865 | "\n", 3866 | " lr=1e-2\n", 3867 | " wd=1e-7\n", 3868 | " lrs = np.array([lr/150,lr/20,lr])\n", 3869 | "\n", 3870 | " learn.unfreeze()\n", 3871 | " learn.bn_freeze(True)\n", 3872 | " learn.crit = lovasz_hinge\n", 3873 | " learn.load(pa)\n", 3874 | " learn.fit(lrs/10,1, wds=wd, cycle_len=20,use_clr=(20,20),best_save_name=pa)\n", 3875 | " \n", 3876 | " learn.load(pa) \n", 3877 | " #Calculating mean IoU score\n", 3878 | " v_targ = md.val_ds.ds[:][1]\n", 3879 | " v_preds = np.zeros((len(v_targ),sz,sz)) \n", 3880 | " v_pred = learn.predict()\n", 3881 | " 
v_pred = to_np(torch.sigmoid(torch.from_numpy(v_pred)))\n", 3882 | " p = ((v_pred)>0.5).astype(np.uint8)\n", 3883 | " bst_acc.append(intersection_over_union_thresholds(v_targ,p))\n", 3884 | " print(bst_acc[-1])\n", 3885 | "\n", 3886 | " if i < kf - 1:\n", 3887 | " del learn\n", 3888 | " del md\n", 3889 | " gc.collect()\n", 3890 | " \n", 3891 | "save_model_weights()" 3892 | ] 3893 | }, 3894 | { 3895 | "cell_type": "code", 3896 | "execution_count": 0, 3897 | "metadata": { 3898 | "colab": { 3899 | "base_uri": "https://localhost:8080/", 3900 | "height": 1730 3901 | }, 3902 | "colab_type": "code", 3903 | "id": "ITXgqK7jjdZL", 3904 | "outputId": "0b712303-2fb6-4477-9985-1c67849c6563" 3905 | }, 3906 | "outputs": [ 3907 | { 3908 | "name": "stdout", 3909 | "output_type": "stream", 3910 | "text": [ 3911 | "[128, 64]\n", 3912 | "fold_id0\n", 3913 | "5_fold_resnext50_128_scse_0\n" 3914 | ] 3915 | }, 3916 | { 3917 | "data": { 3918 | "application/vnd.jupyter.widget-view+json": { 3919 | "model_id": "80b17c7b21dd4f91b3af5c1ad0da18a9", 3920 | "version_major": 2, 3921 | "version_minor": 0 3922 | }, 3923 | "text/plain": [ 3924 | "HBox(children=(IntProgress(value=0, description='Epoch', max=20, style=ProgressStyle(description_width='initia…" 3925 | ] 3926 | }, 3927 | "metadata": { 3928 | "tags": [] 3929 | }, 3930 | "output_type": "display_data" 3931 | }, 3932 | { 3933 | "name": "stdout", 3934 | "output_type": "stream", 3935 | "text": [ 3936 | "epoch trn_loss val_loss my_eval \n", 3937 | " 0 0.352473 0.759049 0.829625 \n", 3938 | " 1 0.363571 0.880663 0.826 \n", 3939 | " 2 0.35344 0.807012 0.824875 \n", 3940 | " 3 0.344199 0.811331 0.82425 \n", 3941 | " 4 0.341503 0.780056 0.826375 \n", 3942 | " 5 0.349747 0.841056 0.820375 \n", 3943 | " 6 0.350928 0.883612 0.823375 \n", 3944 | " 7 0.346245 0.738591 0.819875 \n", 3945 | " 8 0.339464 0.790437 0.826375 \n", 3946 | " 9 0.339998 0.844214 0.826875 \n", 3947 | " 10 0.347131 0.780476 0.83 \n", 3948 | " 11 0.341404 0.798782 0.829 \n", 3949 | " 
12 0.339125 0.789289 0.829875 \n", 3950 | " 13 0.334993 0.868688 0.827625 \n", 3951 | " 14 0.328118 0.880147 0.83275 \n", 3952 | " 15 0.306373 0.860532 0.831625 \n", 3953 | " 16 0.304504 0.829747 0.830875 \n", 3954 | " 17 0.300264 0.837326 0.832625 \n", 3955 | " 18 0.317846 0.774235 0.836375 \n", 3956 | " 19 0.310489 0.787925 0.833875 \n", 3957 | "\n", 3958 | "0.8363749999999999\n", 3959 | "fold_id1\n", 3960 | "5_fold_resnext50_128_scse_1\n" 3961 | ] 3962 | }, 3963 | { 3964 | "data": { 3965 | "application/vnd.jupyter.widget-view+json": { 3966 | "model_id": "40209053452647b2858ae4b6ad13be2a", 3967 | "version_major": 2, 3968 | "version_minor": 0 3969 | }, 3970 | "text/plain": [ 3971 | "HBox(children=(IntProgress(value=0, description='Epoch', max=20, style=ProgressStyle(description_width='initia…" 3972 | ] 3973 | }, 3974 | "metadata": { 3975 | "tags": [] 3976 | }, 3977 | "output_type": "display_data" 3978 | }, 3979 | { 3980 | "name": "stdout", 3981 | "output_type": "stream", 3982 | "text": [ 3983 | "epoch trn_loss val_loss my_eval \n", 3984 | " 0 0.327861 0.839905 0.7995 \n", 3985 | " 1 0.332155 0.878513 0.80625 \n", 3986 | " 2 0.327904 0.842675 0.78325 \n", 3987 | " 3 0.323279 1.045407 0.800125 \n", 3988 | " 4 0.34445 0.838635 0.794375 \n", 3989 | " 5 0.336768 0.906733 0.8 \n", 3990 | " 6 0.325074 0.921411 0.804625 \n", 3991 | " 7 0.331317 0.867638 0.802125 \n", 3992 | " 8 0.336123 0.883447 0.796875 \n", 3993 | " 9 0.331685 0.851472 0.80425 \n", 3994 | " 10 0.315916 0.857878 0.790625 \n", 3995 | " 11 0.317591 1.001778 0.802125 \n", 3996 | " 12 0.322167 0.868535 0.79975 \n", 3997 | " 13 0.324085 0.90506 0.800625 \n", 3998 | " 14 0.309266 0.923913 0.795375 \n", 3999 | " 15 0.311752 0.973415 0.79975 \n", 4000 | " 16 0.300871 0.953171 0.8045 \n", 4001 | " 17 0.290202 0.962913 0.800125 \n", 4002 | " 18 0.297091 0.981357 0.80025 \n", 4003 | " 19 0.299129 0.941654 0.802625 \n", 4004 | "\n", 4005 | "0.80625\n", 4006 | "fold_id2\n", 4007 | "5_fold_resnext50_128_scse_2\n" 4008 
| ] 4009 | }, 4010 | { 4011 | "data": { 4012 | "application/vnd.jupyter.widget-view+json": { 4013 | "model_id": "91f55feb2d0541cd800bc6b78ff08643", 4014 | "version_major": 2, 4015 | "version_minor": 0 4016 | }, 4017 | "text/plain": [ 4018 | "HBox(children=(IntProgress(value=0, description='Epoch', max=20, style=ProgressStyle(description_width='initia…" 4019 | ] 4020 | }, 4021 | "metadata": { 4022 | "tags": [] 4023 | }, 4024 | "output_type": "display_data" 4025 | }, 4026 | { 4027 | "name": "stdout", 4028 | "output_type": "stream", 4029 | "text": [ 4030 | "epoch trn_loss val_loss my_eval \n", 4031 | " 0 0.337081 0.823228 0.83475 \n", 4032 | " 1 0.344549 0.913375 0.831625 \n", 4033 | " 2 0.329596 0.828047 0.831125 \n", 4034 | " 3 0.325009 0.777783 0.8285 \n", 4035 | " 4 0.33951 0.821626 0.808375 \n", 4036 | " 5 0.338317 0.940252 0.820875 \n", 4037 | " 6 0.330713 0.791693 0.83 \n", 4038 | " 7 0.332787 0.776885 0.82775 \n", 4039 | " 8 0.332073 0.806915 0.82675 \n", 4040 | " 9 0.329484 0.799763 0.817375 \n", 4041 | " 10 0.329214 0.797196 0.82475 \n", 4042 | " 11 0.31536 0.82193 0.832625 \n", 4043 | " 12 0.321162 0.843901 0.827375 \n", 4044 | " 13 0.310198 0.90338 0.828375 \n", 4045 | " 14 0.292332 0.859527 0.8295 \n", 4046 | " 15 0.291901 0.843471 0.8235 \n", 4047 | " 16 0.296143 0.85044 0.819875 \n", 4048 | " 17 0.302732 0.847719 0.827 \n", 4049 | " 18 0.303203 0.819016 0.827875 \n", 4050 | " 19 0.303959 0.815134 0.8255 \n", 4051 | "\n", 4052 | "0.8347500000000001\n", 4053 | "fold_id3\n", 4054 | "5_fold_resnext50_128_scse_3\n" 4055 | ] 4056 | }, 4057 | { 4058 | "data": { 4059 | "application/vnd.jupyter.widget-view+json": { 4060 | "model_id": "28a3862558874de39a8aed9c8a9bf792", 4061 | "version_major": 2, 4062 | "version_minor": 0 4063 | }, 4064 | "text/plain": [ 4065 | "HBox(children=(IntProgress(value=0, description='Epoch', max=20, style=ProgressStyle(description_width='initia…" 4066 | ] 4067 | }, 4068 | "metadata": { 4069 | "tags": [] 4070 | }, 4071 | "output_type": 
"display_data" 4072 | }, 4073 | { 4074 | "name": "stdout", 4075 | "output_type": "stream", 4076 | "text": [ 4077 | "epoch trn_loss val_loss my_eval \n", 4078 | " 0 0.316421 0.866826 0.83525 \n", 4079 | " 1 0.337529 0.992894 0.828 \n", 4080 | " 2 0.355061 0.758108 0.83175 \n", 4081 | " 3 0.341312 0.793649 0.82925 \n", 4082 | " 4 0.3637 0.741804 0.83575 \n", 4083 | " 5 0.330458 0.743388 0.839375 \n", 4084 | " 6 0.334183 0.73421 0.841375 \n", 4085 | " 7 0.30901 0.724088 0.83675 \n", 4086 | " 8 0.322369 0.710755 0.830625 \n", 4087 | " 9 0.312626 0.907613 0.841 \n", 4088 | " 10 0.314347 0.75424 0.838875 \n", 4089 | " 11 0.287641 0.762639 0.836125 \n", 4090 | " 12 0.291472 0.738226 0.844875 \n", 4091 | " 13 0.294669 0.741112 0.841375 \n", 4092 | " 14 0.294856 0.747115 0.842375 \n", 4093 | " 15 0.295397 0.75156 0.843125 \n", 4094 | " 16 0.287345 0.807438 0.841875 \n", 4095 | " 17 0.295449 0.754707 0.843625 \n", 4096 | " 82%|████████▏ | 41/50 [00:54<00:10, 1.22s/it, loss=0.29] " 4097 | ] 4098 | } 4099 | ], 4100 | "source": [ 4101 | "model = 'resnext50_128_scse'\n", 4102 | "arch = resnext50\n", 4103 | "bst_acc=[]\n", 4104 | "use_clr_min=20\n", 4105 | "use_clr_div=10\n", 4106 | "aug_tfms = [RandomFlip(tfm_y=TfmType.CLASS)]\n", 4107 | "\n", 4108 | "szs = [(128,64)]\n", 4109 | "for sz,bs in szs:\n", 4110 | " print([sz,bs])\n", 4111 | " for i in range(kf) :\n", 4112 | " print(f'fold_id{i}')\n", 4113 | " \n", 4114 | " trn_x = np.array([f'{TRN_DN}/{o}' for o in train_folds[i]])\n", 4115 | " trn_y = np.array([f'{MASK_DN}/{o}' for o in train_folds[i]])\n", 4116 | " val_x = [f'{TRN_DN}/{o}' for o in val_folds[i]]\n", 4117 | " val_y = [f'{MASK_DN}/{o}' for o in val_folds[i]]\n", 4118 | " \n", 4119 | " tfms = tfms_from_model(arch, sz=sz, pad=0,crop_type=CropType.NO, tfm_y=TfmType.CLASS, aug_tfms=aug_tfms)\n", 4120 | " datasets = ImageData.get_ds(MatchedFilesDataset, (trn_x,trn_y), (val_x,val_y), tfms,test=test_x,path=PATH)\n", 4121 | " md = ImageData(PATH, datasets, bs, 
num_workers=16, classes=None)\n", 4122 | " denorm = md.trn_ds.denorm\n", 4123 | " md.trn_dl.dataset = DepthDataset(md.trn_ds,dpth_dict)\n", 4124 | " md.val_dl.dataset = DepthDataset(md.val_ds,dpth_dict)\n", 4125 | " md.test_dl.dataset = DepthDataset(md.test_ds,dpth_dict)\n", 4126 | " learn = get_tgs_model() \n", 4127 | " learn.opt_fn = optim.Adam\n", 4128 | " learn.metrics=[my_eval]\n", 4129 | " pa = f'{kf}_fold_{model}_{i}'\n", 4130 | " print(pa)\n", 4131 | "\n", 4132 | " lr=1e-2\n", 4133 | " wd=1e-7\n", 4134 | " lrs = np.array([lr/150,lr/20,lr])\n", 4135 | "\n", 4136 | " learn.unfreeze()\n", 4137 | " learn.bn_freeze(True)\n", 4138 | " learn.crit = lovasz_hinge\n", 4139 | " learn.load(pa)\n", 4140 | " learn.fit(lrs/10,1, wds=wd, cycle_len=20,use_clr=(20,20),best_save_name=pa)\n", 4141 | " \n", 4142 | " learn.load(pa) \n", 4143 | " #Calculating mean IoU score\n", 4144 | " v_targ = md.val_ds.ds[:][1]\n", 4145 | " v_preds = np.zeros((len(v_targ),sz,sz)) \n", 4146 | " v_pred = learn.predict()\n", 4147 | " v_pred = to_np(torch.sigmoid(torch.from_numpy(v_pred)))\n", 4148 | " p = ((v_pred)>0.5).astype(np.uint8)\n", 4149 | " bst_acc.append(intersection_over_union_thresholds(v_targ,p))\n", 4150 | " print(bst_acc[-1])\n", 4151 | "\n", 4152 | " if i < kf - 1:\n", 4153 | " del learn\n", 4154 | " del md\n", 4155 | " gc.collect()\n", 4156 | " \n", 4157 | "save_model_weights()" 4158 | ] 4159 | }, 4160 | { 4161 | "cell_type": "code", 4162 | "execution_count": 0, 4163 | "metadata": { 4164 | "colab": {}, 4165 | "colab_type": "code", 4166 | "id": "lTpda6ZVbF0b" 4167 | }, 4168 | "outputs": [], 4169 | "source": [ 4170 | "save_model_weights()" 4171 | ] 4172 | }, 4173 | { 4174 | "cell_type": "code", 4175 | "execution_count": 0, 4176 | "metadata": { 4177 | "colab": {}, 4178 | "colab_type": "code", 4179 | "id": "xP7Kn5wYs3NG" 4180 | }, 4181 | "outputs": [], 4182 | "source": [ 4183 | "load_model_weights(time_stamp='20181010')" 4184 | ] 4185 | }, 4186 | { 4187 | "cell_type": "code", 4188 | 
"execution_count": 0, 4189 | "metadata": { 4190 | "colab": { 4191 | "base_uri": "https://localhost:8080/", 4192 | "height": 50 4193 | }, 4194 | "colab_type": "code", 4195 | "id": "lz43bF8JWcKC", 4196 | "outputId": "3efd88a3-094c-448b-b5f2-5b6ae26e2a4e" 4197 | }, 4198 | "outputs": [ 4199 | { 4200 | "data": { 4201 | "text/plain": [ 4202 | "([0.835, 0.8065000000000001, 0.8315, 0.8447499999999999, 0.8440000000000001],\n", 4203 | " 0.8323500000000001)" 4204 | ] 4205 | }, 4206 | "execution_count": 74, 4207 | "metadata": { 4208 | "tags": [] 4209 | }, 4210 | "output_type": "execute_result" 4211 | } 4212 | ], 4213 | "source": [ 4214 | "bst_acc,np.mean(bst_acc)#With 128" 4215 | ] 4216 | }, 4217 | { 4218 | "cell_type": "markdown", 4219 | "metadata": { 4220 | "colab_type": "text", 4221 | "id": "rF8h3iawT6PI" 4222 | }, 4223 | "source": [ 4224 | "### Submit" 4225 | ] 4226 | }, 4227 | { 4228 | "cell_type": "code", 4229 | "execution_count": 0, 4230 | "metadata": { 4231 | "colab": {}, 4232 | "colab_type": "code", 4233 | "id": "LqbkAvBMMKR6" 4234 | }, 4235 | "outputs": [], 4236 | "source": [ 4237 | "# sz = 128\n", 4238 | "# bs = 64\n", 4239 | "# model = 'resnext50_128_scse'\n", 4240 | "\n", 4241 | "# # trn_x = np.array([f'{TRN_DN}/{o}' for o in train_folds[0]])\n", 4242 | "# # trn_y = np.array([f'{MASK_DN}/{o}' for o in train_folds[0]])\n", 4243 | "# # val_x = [f'{TRN_DN}/{o}' for o in val_folds[0]]\n", 4244 | "# # val_y = [f'{MASK_DN}/{o}' for o in val_folds[0]]\n", 4245 | " \n", 4246 | "# aug_tfms = [RandomFlip(tfm_y=TfmType.CLASS)]\n", 4247 | "# tfms = tfms_from_model(arch, sz=sz, pad=0,crop_type=CropType.NO, tfm_y=TfmType.CLASS, aug_tfms=aug_tfms)\n", 4248 | "# datasets = ImageData.get_ds(MatchedFilesDataset, (trn_x,trn_y), (val_x,val_y), tfms,test=test_x,path=PATH)\n", 4249 | "# md = ImageData(PATH, datasets, bs, num_workers=16, classes=None)\n", 4250 | "# # tfms = tfms_from_model(resnet34, sz=sz, pad=0,crop_type=CropType.NO, tfm_y=TfmType.CLASS, aug_tfms=aug_tfms)\n", 4251 | 
"# # datasets = ImageData.get_ds(MatchedFilesDataset, (trn_x,trn_y), (val_x,val_y), tfms,test=test_x,path=PATH)\n", 4252 | "# # md = ImageData(PATH, datasets, bs, num_workers=16, classes=None)\n", 4253 | "\n", 4254 | "# learn = get_tgs_model() \n", 4255 | "# learn.opt_fn = optim.Adam\n", 4256 | "# learn.crit = lovasz_hinge\n", 4257 | "# learn.metrics=[my_eval]" 4258 | ] 4259 | }, 4260 | { 4261 | "cell_type": "code", 4262 | "execution_count": 0, 4263 | "metadata": { 4264 | "colab": { 4265 | "base_uri": "https://localhost:8080/", 4266 | "height": 252 4267 | }, 4268 | "colab_type": "code", 4269 | "id": "ms-zAHP2WcGb", 4270 | "outputId": "e1af7781-86a4-4154-82cf-4d6eb7e55897" 4271 | }, 4272 | "outputs": [ 4273 | { 4274 | "data": { 4275 | "application/vnd.jupyter.widget-view+json": { 4276 | "model_id": "ef25c63e5456401eb8edde4089530451", 4277 | "version_major": 2, 4278 | "version_minor": 0 4279 | }, 4280 | "text/plain": [ 4281 | "HBox(children=(IntProgress(value=0, max=5), HTML(value='')))" 4282 | ] 4283 | }, 4284 | "metadata": { 4285 | "tags": [] 4286 | }, 4287 | "output_type": "display_data" 4288 | }, 4289 | { 4290 | "name": "stdout", 4291 | "output_type": "stream", 4292 | "text": [ 4293 | "5_fold_resnext50_128_scse_0\n", 4294 | "5_fold_resnext50_128_scse_1\n", 4295 | "5_fold_resnext50_128_scse_2\n", 4296 | "5_fold_resnext50_128_scse_3\n", 4297 | "5_fold_resnext50_128_scse_4\n", 4298 | "\n" 4299 | ] 4300 | }, 4301 | { 4302 | "data": { 4303 | "application/vnd.jupyter.widget-view+json": { 4304 | "model_id": "5c581f68713c40a090b48c2b02908785", 4305 | "version_major": 2, 4306 | "version_minor": 0 4307 | }, 4308 | "text/plain": [ 4309 | "HBox(children=(IntProgress(value=0, max=5), HTML(value='')))" 4310 | ] 4311 | }, 4312 | "metadata": { 4313 | "tags": [] 4314 | }, 4315 | "output_type": "display_data" 4316 | }, 4317 | { 4318 | "name": "stdout", 4319 | "output_type": "stream", 4320 | "text": [ 4321 | "5_fold_resnext50_128_scse_0\n", 4322 | "5_fold_resnext50_128_scse_1\n", 
4323 | "5_fold_resnext50_128_scse_2\n", 4324 | "5_fold_resnext50_128_scse_3\n", 4325 | "5_fold_resnext50_128_scse_4\n", 4326 | "\n" 4327 | ] 4328 | } 4329 | ], 4330 | "source": [ 4331 | "preds = np.zeros(shape = (18000,sz,sz))\n", 4332 | "for o in [True,False]:\n", 4333 | " md.test_dl.dataset = TestFilesDataset(test_x,test_x,tfms[1],flip=o,path=PATH)\n", 4334 | " md.test_dl.dataset = DepthDatasetV2(md.test_dl.dataset,dpth_dict)\n", 4335 | " \n", 4336 | " for i in tqdm_notebook(range(kf)):\n", 4337 | " pa = f'{kf}_fold_{model}_{i}'\n", 4338 | " print(pa)\n", 4339 | " learn.load(pa)\n", 4340 | " pred = learn.predict(is_test=True)\n", 4341 | " pred = to_np(torch.sigmoid(torch.from_numpy(pred))) \n", 4342 | " for im_idx,im in enumerate(pred):\n", 4343 | " preds[im_idx] += np.fliplr(im) if o else im\n", 4344 | " del pred" 4345 | ] 4346 | }, 4347 | { 4348 | "cell_type": "code", 4349 | "execution_count": 0, 4350 | "metadata": { 4351 | "colab": {}, 4352 | "colab_type": "code", 4353 | "id": "ZHKkmKrqWb7W" 4354 | }, 4355 | "outputs": [], 4356 | "source": [ 4357 | "# plt.imshow(((preds[16]/10)>0.5).astype(np.uint8))" 4358 | ] 4359 | }, 4360 | { 4361 | "cell_type": "code", 4362 | "execution_count": 0, 4363 | "metadata": { 4364 | "colab": {}, 4365 | "colab_type": "code", 4366 | "id": "pe59_TSGWb4K" 4367 | }, 4368 | "outputs": [], 4369 | "source": [ 4370 | "p = [cv2.resize(o/10,dsize=(101,101)) for o in preds]\n", 4371 | "p = [(o>0.5).astype(np.uint8) for o in p]" 4372 | ] 4373 | }, 4374 | { 4375 | "cell_type": "code", 4376 | "execution_count": 0, 4377 | "metadata": { 4378 | "colab": { 4379 | "base_uri": "https://localhost:8080/", 4380 | "height": 34 4381 | }, 4382 | "colab_type": "code", 4383 | "id": "qu8Gin-NTMef", 4384 | "outputId": "acaf9176-75b0-4b53-d2f5-ece8890b9862" 4385 | }, 4386 | "outputs": [ 4387 | { 4388 | "data": { 4389 | "text/plain": [ 4390 | "(101, 101)" 4391 | ] 4392 | }, 4393 | "execution_count": 67, 4394 | "metadata": { 4395 | "tags": [] 4396 | }, 4397 | 
"output_type": "execute_result" 4398 | } 4399 | ], 4400 | "source": [ 4401 | "p[0].shape" 4402 | ] 4403 | }, 4404 | { 4405 | "cell_type": "code", 4406 | "execution_count": 0, 4407 | "metadata": { 4408 | "colab": { 4409 | "base_uri": "https://localhost:8080/", 4410 | "height": 50 4411 | }, 4412 | "colab_type": "code", 4413 | "id": "7ZRdo1lAWb08", 4414 | "outputId": "615f7553-793c-4f59-d6a1-22ccdedbbf6d" 4415 | }, 4416 | "outputs": [ 4417 | { 4418 | "data": { 4419 | "application/vnd.jupyter.widget-view+json": { 4420 | "model_id": "6a28697675af4cbc90c8dcd4afba37d2", 4421 | "version_major": 2, 4422 | "version_minor": 0 4423 | }, 4424 | "text/plain": [ 4425 | "HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))" 4426 | ] 4427 | }, 4428 | "metadata": { 4429 | "tags": [] 4430 | }, 4431 | "output_type": "display_data" 4432 | }, 4433 | { 4434 | "name": "stdout", 4435 | "output_type": "stream", 4436 | "text": [ 4437 | "\n" 4438 | ] 4439 | } 4440 | ], 4441 | "source": [ 4442 | "pred_dict = {id_.name[:-4]:RLenc(p[i]) for i,id_ in tqdm_notebook(enumerate(test_x))}\n", 4443 | "sub = pd.DataFrame.from_dict(pred_dict,orient='index')\n", 4444 | "sub.index.names = ['id']\n", 4445 | "sub.columns = ['rle_mask']\n", 4446 | "sub.to_csv('submission.csv')" 4447 | ] 4448 | }, 4449 | { 4450 | "cell_type": "code", 4451 | "execution_count": 0, 4452 | "metadata": { 4453 | "colab": { 4454 | "base_uri": "https://localhost:8080/", 4455 | "height": 34 4456 | }, 4457 | "colab_type": "code", 4458 | "id": "et8aAjFVWbxh", 4459 | "outputId": "3af2bb8f-dcdc-4187-a48e-9dabfaebf9f4" 4460 | }, 4461 | "outputs": [ 4462 | { 4463 | "name": "stdout", 4464 | "output_type": "stream", 4465 | "text": [ 4466 | "Successfully submitted to TGS Salt Identification Challenge" 4467 | ] 4468 | } 4469 | ], 4470 | "source": [ 4471 | "!kaggle competitions submit -c tgs-salt-identification-challenge -f submission.csv -m \"\"" 4472 | ] 4473 | }, 4474 | { 4475 | "cell_type": "code", 4476 | 
"execution_count": null, 4477 | "metadata": { 4478 | "colab": { 4479 | "base_uri": "https://localhost:8080/", 4480 | "height": 386 4481 | }, 4482 | "colab_type": "code", 4483 | "id": "38KczbM-WbuM", 4484 | "outputId": "0a1ab856-ec21-45ad-803e-fa3dc06ec9dc" 4485 | }, 4486 | "outputs": [], 4487 | "source": [ 4488 | "!kaggle competitions submissions -c tgs-salt-identification-challenge" 4489 | ] 4490 | }, 4491 | { 4492 | "cell_type": "code", 4493 | "execution_count": 0, 4494 | "metadata": { 4495 | "colab": {}, 4496 | "colab_type": "code", 4497 | "id": "7rbnFVohWbqd" 4498 | }, 4499 | "outputs": [], 4500 | "source": [ 4501 | "!fusermount -u drive" 4502 | ] 4503 | }, 4504 | { 4505 | "cell_type": "code", 4506 | "execution_count": 0, 4507 | "metadata": { 4508 | "colab": {}, 4509 | "colab_type": "code", 4510 | "id": "emlLQBDaj6fk" 4511 | }, 4512 | "outputs": [], 4513 | "source": [] 4514 | } 4515 | ], 4516 | "metadata": { 4517 | "accelerator": "GPU", 4518 | "colab": { 4519 | "collapsed_sections": [ 4520 | "Dz4H7ksT0tMo" 4521 | ], 4522 | "name": "TGS_Salt_Unet50_5-Fold_scSE_score_8.42.ipynb", 4523 | "provenance": [], 4524 | "version": "0.3.2" 4525 | }, 4526 | "kernelspec": { 4527 | "display_name": "Python 3", 4528 | "language": "python", 4529 | "name": "python3" 4530 | }, 4531 | "language_info": { 4532 | "codemirror_mode": { 4533 | "name": "ipython", 4534 | "version": 3 4535 | }, 4536 | "file_extension": ".py", 4537 | "mimetype": "text/x-python", 4538 | "name": "python", 4539 | "nbconvert_exporter": "python", 4540 | "pygments_lexer": "ipython3", 4541 | "version": "3.5.4" 4542 | } 4543 | }, 4544 | "nbformat": 4, 4545 | "nbformat_minor": 1 4546 | } 4547 | --------------------------------------------------------------------------------