├── README.md
├── guide.md
└── TGS_Salt_resnext50_unet_5Fold_scSE.ipynb
/README.md:
--------------------------------------------------------------------------------
1 | ## [TGS Salt Identification Challenge](https://www.kaggle.com/c/tgs-salt-identification-challenge) Bronze Medal Solution
2 |
3 | First of all, thanks to the [Fastai library](https://github.com/fastai/fastai/tree/0.7.0/fastai), which saved me a lot of time building the model. A bronze medal isn't a top score; however, this [notebook](https://github.com/alexshuang/TGS_Salt/blob/master/TGS_Salt_resnext50_unet_5Fold_scSE.ipynb) runs directly on Google Colab, so you can use it as a baseline to experiment with the ideas published by other top contestants. Most importantly, you don't need your own GPU for this competition, though a faster GPU certainly helps :).
4 |
5 | For more details about the source code, see my Chinese blog post: [here](https://github.com/alexshuang/Kaggle_TGS_Salt_Identification_Challenge/blob/master/guide.md)
6 |
7 | ---
8 |
9 | ### Network architecture
10 |
11 | * U-net with a pre-trained resnet34 or resnext50 encoder; resnext50 worked better.
12 |
13 | ### Training regime
14 |
15 | * Adam optimizer with weight decay. SGD gave better results than Adam, but it is also much slower.
16 | * 5-fold cross validation.
17 | * Data augmentation with the [albumentations](https://github.com/albu/albumentations) library: random crop of 50% to 100% of the image area, resized to 128x128. Resizing to 192x192 or 256x256 is recommended, but that needs a faster GPU.
18 | * Stage 1: train with SGD + BCE. Stage 2: train with Adam + Lovasz hinge loss until the validation loss stops decreasing.
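The stage-2 Lovasz hinge loss is a surrogate for the Jaccard index (IoU), which is closer to the competition metric than BCE. Below is a minimal NumPy sketch of its forward computation for one image, following Berman et al.'s binary formulation. It only illustrates the math (there is no autograd here), and the function names are hypothetical; for training, the notebook imports a PyTorch implementation from `lovasz_losses.py`.

```python
import numpy as np


def lovasz_grad(gt_sorted):
    """Gradient of the Lovasz extension of the Jaccard loss w.r.t. the sorted errors."""
    gts = gt_sorted.sum()
    intersection = gts - np.cumsum(gt_sorted)
    union = gts + np.cumsum(1.0 - gt_sorted)
    jaccard = 1.0 - intersection / union
    # Discrete derivative: how much each additional error increases the Jaccard loss.
    jaccard[1:] = jaccard[1:] - jaccard[:-1]
    return jaccard


def lovasz_hinge_flat(logits, labels):
    """Binary Lovasz hinge loss for one image.

    logits: raw scores (any shape); labels: same shape, values in {0, 1}.
    """
    logits = np.asarray(logits, dtype=float).ravel()
    labels = np.asarray(labels, dtype=float).ravel()
    signs = 2.0 * labels - 1.0          # {0, 1} -> {-1, +1}
    errors = 1.0 - logits * signs       # hinge errors per pixel
    order = np.argsort(-errors)         # sort errors in decreasing order
    errors_sorted = errors[order]
    gt_sorted = labels[order]
    grad = lovasz_grad(gt_sorted)
    return float(np.dot(np.maximum(errors_sorted, 0.0), grad))


mask = np.array([[0, 1], [1, 1]])
print(lovasz_hinge_flat(5.0 * (2 * mask - 1), mask))  # 0.0: every pixel correct with margin >= 1
print(lovasz_hinge_flat(-(2 * mask - 1), mask))       # about 2.0: every pixel wrong
```

Unlike BCE, the loss is computed over the whole sorted error vector of an image, which is why it is usually applied per image and only after BCE has given the model a reasonable starting point.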
19 |
20 | ### Techniques that helped
21 |
22 | * Adding Concurrent Spatial and Channel Squeeze & Excitation (scSE) blocks to the U-net decoder **(+0.032 Private LB)**: [https://arxiv.org/pdf/1803.02579.pdf](https://arxiv.org/pdf/1803.02579.pdf). I used the [implementation](https://www.kaggle.com/c/tgs-salt-identification-challenge/discussion/66178) generously provided by [Bilal](https://www.kaggle.com/bkkaggle).
23 | * Depth-stratified n-fold training (+0.01 LB). For me, 5 folds worked well.
24 | * TTA: flip and non-flip (+0.02 LB).
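Since the scSE blocks gave the biggest gain, here is a minimal NumPy sketch of one block's forward pass on a single (C, H, W) feature map, for illustration only. It is not Bilal's implementation linked above: the weight names, the reduction ratio `r`, and the random weights are assumptions. The channel branch (cSE) gates each channel using globally pooled statistics, the spatial branch (sSE) gates each pixel using a 1x1 convolution, and the two recalibrated maps are combined by element-wise addition.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def scse(x, w1, b1, w2, b2, w_s, b_s):
    """Concurrent spatial and channel squeeze & excitation for one feature map.

    x: (C, H, W). w1: (C // r, C) and w2: (C, C // r) are the two FC layers of the
    channel branch; w_s: (C,) and b_s are a 1x1 conv squeezing channels to one map.
    """
    # cSE: squeeze space with global average pooling, excite each channel.
    z = x.mean(axis=(1, 2))                                 # (C,)
    s = sigmoid(w2 @ np.maximum(w1 @ z + b1, 0.0) + b2)     # (C,), values in (0, 1)
    cse = x * s[:, None, None]
    # sSE: squeeze channels with a 1x1 conv, excite each pixel.
    q = sigmoid(np.tensordot(w_s, x, axes=1) + b_s)         # (H, W), values in (0, 1)
    sse = x * q[None, :, :]
    return cse + sse


rng = np.random.default_rng(0)
C, H, W, r = 16, 8, 8, 4
x = rng.standard_normal((C, H, W))
out = scse(x,
           rng.standard_normal((C // r, C)), np.zeros(C // r),
           rng.standard_normal((C, C // r)), np.zeros(C),
           rng.standard_normal(C), 0.0)
print(out.shape)  # (16, 8, 8)
```

Because the block keeps the input shape, it can be dropped after any decoder stage without changing the rest of the U-net.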
25 |
26 | ### Techniques that didn't help
27 | * Hypercolumns: [https://arxiv.org/pdf/1411.5752.pdf](https://arxiv.org/pdf/1411.5752.pdf)
28 | * Many people ([Heng](https://www.kaggle.com/c/tgs-salt-identification-challenge/discussion/64645#380301), et al.) reported public LB improvements of +0.005 to +0.01 after concatenating hypercolumns to their decoder's output. My attempt lowered the score, and because it was computationally expensive on my Paperspace P5000 machine, I didn't bother troubleshooting it.
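In this setting, a hypercolumn means upsampling the output of every decoder stage to a common resolution and concatenating them along the channel axis before the final classifier. A minimal NumPy sketch, assuming nearest-neighbour upsampling (a PyTorch version would typically use `F.interpolate`); the toy channel counts and sizes are hypothetical. It also shows where the cost comes from: the final layer's input channels grow with every stage that is concatenated.

```python
import numpy as np


def upsample_nearest(x, size):
    """Nearest-neighbour upsampling of a (C, h, w) map to (C, size, size)."""
    c, h, w = x.shape
    assert size % h == 0 and size % w == 0, "size must be a multiple of h and w"
    return x.repeat(size // h, axis=1).repeat(size // w, axis=2)


def hypercolumn(decoder_outputs, size):
    """Upsample every decoder output to size x size and stack along the channel axis."""
    return np.concatenate([upsample_nearest(f, size) for f in decoder_outputs], axis=0)


rng = np.random.default_rng(0)
feats = [rng.standard_normal((4, s, s)) for s in (8, 16, 32)]  # toy decoder outputs
hc = hypercolumn(feats, 32)
print(hc.shape)  # (12, 32, 32)
```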
29 |
30 |
--------------------------------------------------------------------------------
/guide.md:
--------------------------------------------------------------------------------
1 | ## Segmentation
2 |
3 | 
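4 | Figure 1 comes from the CamVid database, a dataset that provides training data for object detection and image segmentation. As the figure shows, segmentation marks the different objects in an image with different colors, where each color stands for a class: red is class 1, blue is class 2, and so on. In other words, segmentation is pixel-level image classification.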
5 |
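6 | Besides autonomous driving, image segmentation is widely used in medical diagnosis, satellite imagery, image synthesis, and other fields. This post uses the currently most popular segmentation competition on [kaggle](http://www.kaggle.com), the [TGS Salt Identification Challenge](https://www.kaggle.com/c/tgs-salt-identification-challenge), as an example of how to apply Unet to a real-world image segmentation problem. GitHub: [here](https://github.com/alexshuang/TGS_Salt).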
7 |
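8 | TGS uses seismic reflection to produce 3D geological images like the ones below and labels the salt deposits in them. Participants must train a machine learning model that separates salt from the surrounding rock.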
9 | 
10 |
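11 | Figure 2 shows five image/mask pairs from the training set. In each pair, the left image is the original seismic image and the right one is its label, called the mask: black areas are ordinary rock and white areas are salt. Segmentation means training an image-to-image model that learns from the original image to generate a predicted mask (mask2). The ground-truth mask serves as the target, and the model learns to identify salt by minimizing the difference between mask and mask2.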
12 |
13 | ### Dataset
14 |
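15 | The first step in building the dataset is to create mask images from the run-length data. Since the TGS training set already provides the mask images (each mask has the same file name as its image), we don't need to create them ourselves.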
16 |
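17 | Keep in mind, however, that not every segmentation dataset provides masks; you may have to create them from run-length data. Run-length data looks like the rle_mask column in the figure below: space-separated numbers taken in pairs, where the first number of each pair is the (1-based) start index in the flattened image vector and the second is the length of the run. The pixels of each run are filled with the class number of that target (0, 1, 2, ...). For TGS, pixels are numbered top to bottom, then left to right (column-major order). rle_decode() converts run-length data into a mask.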
18 | 
19 |
20 | ```python
   | import numpy as np
   |
21 | def rle_decode(mask_rle, shape=(101, 101)):
22 |     s = mask_rle.split()
23 |     starts, lengths = [np.asarray(x, dtype=int) for x in (s[0::2], s[1::2])]
24 |     starts -= 1  # run-length starts are 1-based
25 |     ends = starts + lengths
26 |     img = np.zeros(shape[0] * shape[1], dtype=np.uint8)
27 |     for lo, hi in zip(starts, ends):
28 |         img[lo:hi] = 1
29 |     return img.reshape(shape, order='F')  # pixels are numbered top to bottom, then left to right
30 | ```
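To make the pairing and the column-major pixel order concrete, here is a small hypothetical helper (not part of the notebook) that decodes one run-length string per class into a single label map, filling each run with its class number as described above. The dict-of-strings input format is an assumption made for the sketch.

```python
import numpy as np


def rle_to_label_map(rles, shape):
    """Build a label map from one run-length string per class.

    rles: dict {class_id: "start length start length ..."} with 1-based starts.
    Pixels are numbered top to bottom, then left to right (column-major);
    pixels not covered by any run stay 0.
    """
    flat = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for class_id, rle in rles.items():
        nums = np.array(rle.split(), dtype=int)
        for start, length in zip(nums[0::2], nums[1::2]):
            flat[start - 1:start - 1 + length] = class_id
    return flat.reshape(shape, order='F')


label = rle_to_label_map({1: "1 2", 2: "7 3"}, shape=(3, 3))
print(label)
# [[1 0 2]
#  [1 0 2]
#  [0 0 2]]
```

Note how "1 2" fills the top of the first column rather than the first row: with column-major numbering, consecutive indices run down a column.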
31 |
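32 | As Figure 2 shows, the seismic images are low resolution, only 101x101 pixels. This size is inconvenient both for the network's convolutions (101 cannot be halved evenly by stride-2 downsampling, while 128 = 2^7 can) and for recognition, so we usually resize the images to 128x128 next.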
33 | ```python
   | from PIL import Image
   |
34 | def resize_img(fn, outp_dir, sz):
   |     # fn and outp_dir are pathlib.Path objects
35 |     Image.open(fn).resize((sz, sz)).save(outp_dir/fn.name)
36 | ```
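One detail the snippet leaves implicit: masks should be resized with nearest-neighbour interpolation so they stay binary, because smoothing filters create in-between values along salt boundaries. Since Pillow 7.0, `resize()` defaults to bicubic resampling for most image modes, so passing `Image.NEAREST` explicitly for masks is safer. The NumPy sketch below (a hypothetical helper, not from the notebook) shows what nearest-neighbour resizing does: every output pixel copies some input pixel, so no new values appear.

```python
import numpy as np


def resize_nearest(a, out_h, out_w):
    """Nearest-neighbour resize of a 2-D array; output values are a subset of input values."""
    h, w = a.shape
    rows = (np.arange(out_h) * h) // out_h   # source row for each output row
    cols = (np.arange(out_w) * w) // out_w   # source column for each output column
    return a[rows[:, None], cols[None, :]]


mask = (np.random.default_rng(0).random((101, 101)) > 0.5).astype(np.uint8)
big = resize_nearest(mask, 128, 128)
print(big.shape, np.unique(big))  # (128, 128) [0 1]
```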
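37 | Data augmentation is at the core of building the dataset. As in object detection, segmentation usually avoids random cropping, so in this project I used horizontal and vertical flips plus small brightness and contrast adjustments. `tfm_y=TfmType.CLASS` makes fastai apply the geometric transforms to the mask as well, so image and mask stay aligned.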
38 | ```python
   | from fastai.transforms import *  # fastai 0.7
   |
39 | aug_tfms = [
40 |     RandomFlip(tfm_y=TfmType.CLASS),
41 |     RandomDihedral(tfm_y=TfmType.CLASS),
42 |     # RandomRotate(4, tfm_y=TfmType.CLASS),
43 |     RandomLighting(0.07, 0.07, tfm_y=TfmType.CLASS)
44 | ]
45 | ```
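The key requirement behind `tfm_y` is that the image and its mask share the same random draw. A minimal NumPy sketch of a joint dihedral augmentation (random 90° rotation plus an optional flip) illustrates the idea; `random_dihedral` is a hypothetical name and this is not fastai's implementation. Photometric changes such as `RandomLighting` would be applied to the image only.

```python
import numpy as np


def random_dihedral(img, mask, rng):
    """Apply the same random symmetry of the square (8 possibilities) to img and mask."""
    k = int(rng.integers(4))        # number of 90-degree rotations
    flip = bool(rng.integers(2))    # whether to mirror horizontally afterwards
    img, mask = np.rot90(img, k), np.rot90(mask, k)
    if flip:
        img, mask = np.fliplr(img), np.fliplr(mask)
    return img, mask


rng = np.random.default_rng(42)
img = np.arange(16, dtype=float).reshape(4, 4)
mask = (img > 7).astype(np.uint8)   # a mask that is a pixel-wise function of the image
aug_img, aug_mask = random_dihedral(img, mask, rng)
assert np.array_equal(aug_mask, (aug_img > 7).astype(np.uint8))  # still aligned
```

Drawing `k` and `flip` once and applying them to both arrays is what keeps the labels correct; drawing them separately would silently corrupt the training targets.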
46 |
47 | ### Unet
48 | [paper](https://arxiv.org/abs/1505.04597)
49 | Although Unet dates from 2015, it is still the most widely used model in segmentation projects, and many of the top-ranked players on the Kaggle leaderboard use it.
50 | 
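51 | The left half of Unet consists of convolution layers and the right half of upsampling layers. The activations just before each pooling layer on the left are concatenated with the activations of the corresponding upsampling layer on the right.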
52 |
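53 | Because the left half of Unet, like resnet, vgg, and inception, extracts image features with convolutional layers, Unet can be built as resnet/vgg/inception + upsampling. The benefit is that a mature pretrained model speeds up Unet training, and transfer learning is very effective. In this project I used a resnet34 + upsampling architecture.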
54 | ```python
   | import torch
   | import torch.nn as nn
   | import torch.nn.functional as F
   |
55 | class SaveFeatures():
56 |     features = None
57 |     def __init__(self, m): self.hook = m.register_forward_hook(self.hook_fn)
58 |     def hook_fn(self, module, input, output): self.features = output  # runs after m.forward
59 |     def remove(self): self.hook.remove()
60 |
61 |
62 | class UnetBlock(nn.Module):
63 |     def __init__(self, up_in, down_in, n_out, dp=False, ps=0.25):
64 |         super().__init__()
65 |         up_out = down_out = n_out // 2
66 |         self.tr_conv = nn.ConvTranspose2d(up_in, up_out, 2, 2, bias=False)  # doubles H and W
67 |         self.conv = nn.Conv2d(down_in, down_out, 1, bias=False)  # 1x1 conv on the skip connection
68 |         self.bn = nn.BatchNorm2d(n_out)
69 |         self.dp = dp
70 |         if dp: self.dropout = nn.Dropout(ps, inplace=True)
71 |
72 |     def forward(self, up_x, down_x):
73 |         x1 = self.tr_conv(up_x)
74 |         x2 = self.conv(down_x)
75 |         x = torch.cat([x1, x2], dim=1)  # concatenate along the channel axis
76 |         x = self.bn(F.relu(x))
77 |         return self.dropout(x) if self.dp else x
78 |
79 |
80 | class Unet34(nn.Module):
81 |     def __init__(self, rn, drop_i=False, ps_i=None, drop_up=False, ps=None):
82 |         super().__init__()
83 |         self.rn = rn
84 |         self.sfs = [SaveFeatures(rn[i]) for i in [2, 4, 5, 6]]  # skip-connection activations
85 |         self.drop_i = drop_i
86 |         if ps_i is None: ps_i = 0.1  # set the default before the Dropout layer is built
87 |         if drop_i:
88 |             self.dropout = nn.Dropout(ps_i, inplace=True)
89 |         if ps is not None: assert len(ps) == 4
90 |         if ps is None: ps = [0.1] * 4
91 |         self.up1 = UnetBlock(512, 256, 256, drop_up, ps[0])
92 |         self.up2 = UnetBlock(256, 128, 256, drop_up, ps[1])
93 |         self.up3 = UnetBlock(256, 64, 256, drop_up, ps[2])
94 |         self.up4 = UnetBlock(256, 64, 256, drop_up, ps[3])
95 |         self.up5 = nn.ConvTranspose2d(256, 1, 2, 2)
96 |
97 |     def forward(self, x):
98 |         x = F.relu(self.rn(x))
99 |         x = self.dropout(x) if self.drop_i else x
100 |         x = self.up1(x, self.sfs[3].features)
101 |         x = self.up2(x, self.sfs[2].features)
102 |         x = self.up3(x, self.sfs[1].features)
103 |         x = self.up4(x, self.sfs[0].features)
104 |         x = self.up5(x)
105 |         return x[:, 0]  # (N, 1, H, W) -> (N, H, W) logits
106 |
107 |     def close(self):
108 |         for o in self.sfs: o.remove()
109 | ```
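110 | Forward hooks registered with `register_forward_hook()` (a method of `nn.Module`) save the activations of children 2, 4, 5 and 6 of the resnet34 base, which are then concatenated into the corresponding upsampling layers. The upsampling layers use `ConvTranspose2d()` for deconvolution. `ConvTranspose2d()` works in the opposite direction of `Conv2d()`: it increases the grid size of the feature map instead of reducing it. If you are not familiar with how deconvolution is computed, see the [convolution arithmetic tutorial](http://deeplearning.net/software/theano/tutorial/conv_arithmetic.html).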
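The size arithmetic can be checked in a few lines of plain Python. The sketch uses the output-size formulas from the PyTorch `Conv2d`/`ConvTranspose2d` documentation and traces the decoder above on a 128x128 input, assuming `rn` is torchvision's resnet34 cut before its pooling and fully connected layers, so the encoder output is 4x4 (128 / 32). Each kernel-2, stride-2 transposed convolution doubles the size, matching the 8, 16, 32 and 64 pixel skip features saved from children 6, 5, 4 and 2.

```python
def conv_out(n, k, s=1, p=0, d=1):
    """Output size of nn.Conv2d along one spatial dimension."""
    return (n + 2 * p - d * (k - 1) - 1) // s + 1


def conv_transpose_out(n, k, s=1, p=0, d=1, output_padding=0):
    """Output size of nn.ConvTranspose2d along one spatial dimension."""
    return (n - 1) * s - 2 * p + d * (k - 1) + output_padding + 1


size = 4  # resnet34 encoder output for a 128x128 input
sizes = []
for name in ["up1", "up2", "up3", "up4", "up5"]:
    size = conv_transpose_out(size, k=2, s=2)  # every up-block uses kernel 2, stride 2
    sizes.append((name, size))
print(sizes)  # [('up1', 8), ('up2', 16), ('up3', 32), ('up4', 64), ('up5', 128)]

# A regular conv with the same kernel and stride undoes the doubling:
print(conv_out(8, k=2, s=2))  # 4
```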
111 |
112 | ### Loss
113 |
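114 | As mentioned earlier, segmentation is essentially pixel-level image classification. This project has only two classes, salt and rock, so like cats vs. dogs it is a binary classification problem, and binary cross entropy, i.e. `nn.BCEWithLogitsLoss()`, is enough. Besides BCE, I also tried [focal loss](https://arxiv.org/abs/1708.02002), which improved accuracy by 0.013.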
115 | 
116 |
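117 | The formula in Figure 3 shows that focal loss is a scaled version of cross entropy: $FL(p_t) = -(1 - p_t)^\gamma \log(p_t)$, where $(1 - p_t)^\gamma$ is the scaling factor. This factor is not a constant, but it is not a learnable weight either: it depends on $p_t$, the predicted probability of the true class, and $\gamma$ is a fixed hyperparameter. Although the formula is simple, focal loss works much better than BCE in object detection. The logic behind it is that the scaling shrinks the loss on examples the model already classifies with reasonable confidence, turning a hesitant judgment into a settled one so that training concentrates on the hard examples. In Figure 3, when gamma == 0, focal loss is just cross entropy (CE), shown by the blue curve: even when the probability reaches 0.6, the loss is still >= 0.5, as if the model said "I estimate a 60% chance the output is not class B, but I still can't be sure it isn't class B." When gamma == 2, at the same probability of 0.6 the loss is close to 0, as if the model said "I estimate a 60% chance the output is not class B, and I consider it settled that it isn't class B." That is the power of the scaling.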
118 | ```python
   | import torch
   | import torch.nn as nn
   | import torch.nn.functional as F
   |
119 | # https://github.com/marvis/pytorch-yolo2/blob/master/FocalLoss.py
120 | # https://github.com/unsky/focal-loss
121 | class FocalLoss2d(nn.Module):
122 |     def __init__(self, gamma=2, size_average=True):
123 |         super(FocalLoss2d, self).__init__()
124 |         self.gamma = gamma
125 |         self.size_average = size_average
126 |
127 |     def forward(self, logit, target, class_weight=None, type='softmax'):
128 |         target = target.view(-1, 1).long()
129 |         if type == 'sigmoid':
130 |             if class_weight is None:
131 |                 class_weight = [1] * 2  # [0.5, 0.5]
132 |             prob = torch.sigmoid(logit)
133 |             prob = prob.view(-1, 1)
134 |             prob = torch.cat((1 - prob, prob), 1)  # columns: P(class 0), P(class 1)
135 |             select = torch.zeros(len(prob), 2, device=logit.device)
136 |             select.scatter_(1, target, 1.)  # one-hot mask of the true class
137 |         elif type == 'softmax':
138 |             B, C, H, W = logit.size()
139 |             if class_weight is None:
140 |                 class_weight = [1] * C  # [1/C]*C
141 |             logit = logit.permute(0, 2, 3, 1).contiguous().view(-1, C)
142 |             prob = F.softmax(logit, 1)
143 |             select = torch.zeros(len(prob), C, device=logit.device)
144 |             select.scatter_(1, target, 1.)
145 |         class_weight = torch.tensor(class_weight, dtype=torch.float, device=logit.device).view(-1, 1)
146 |         class_weight = torch.gather(class_weight, 0, target)
147 |         prob = (prob * select).sum(1).view(-1, 1)  # p_t: probability of the true class
148 |         prob = torch.clamp(prob, 1e-8, 1 - 1e-8)
149 |         batch_loss = -class_weight * (torch.pow((1 - prob), self.gamma)) * prob.log()
150 |         if self.size_average:
151 |             loss = batch_loss.mean()
152 |         else:
153 |             loss = batch_loss
154 |         return loss
155 | ```
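A quick numeric check of the Figure 3 values quoted above, as a standalone NumPy sketch independent of the PyTorch class. `focal_loss` here is a hypothetical helper that takes a single probability $p_t$ of the true class.

```python
import numpy as np


def focal_loss(p_t, gamma):
    """Focal loss for the probability p_t of the true class: -(1 - p_t)**gamma * log(p_t)."""
    p_t = np.clip(p_t, 1e-8, 1 - 1e-8)
    return -((1 - p_t) ** gamma) * np.log(p_t)


for gamma in (0, 2):
    print(gamma, round(float(focal_loss(0.6, gamma)), 4))
# 0 0.5108   (plain cross entropy)
# 2 0.0817   (the same example, down-weighted by (1 - 0.6)**2 = 0.16)
```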
156 |
157 | ### Metric
158 |
159 | 
160 |
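161 | The metric averages over the IoU thresholds [0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95]: for each image, compute the IoU between the predicted and true masks, score the fraction of thresholds that the IoU meets, and then average over all images. An image whose true and predicted masks are both empty scores 1.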
162 | ```python
   | import numpy as np
   |
163 | iou_thresholds = np.array([0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95])
164 |
165 | def iou(img_true, img_pred):  # torch tensors: binary target, raw logits
166 |     img_pred = (img_pred > 0).float()  # logit > 0 <=> sigmoid(logit) > 0.5
167 |     i = (img_true * img_pred).sum()
168 |     u = (img_true + img_pred).sum() - i  # union = |A| + |B| - |A and B|
169 |     return float(i / u) if u != 0 else 0.
170 |
171 | def iou_metric(imgs_pred, imgs_true):
172 |     num_images = len(imgs_true)
173 |     scores = np.zeros(num_images)
174 |     for i in range(num_images):
175 |         if imgs_true[i].sum() == (imgs_pred[i] > 0).sum() == 0:
176 |             scores[i] = 1  # both masks empty: perfect score
177 |         else:
178 |             scores[i] = (iou_thresholds <= iou(imgs_true[i], imgs_pred[i])).mean()
179 |     return scores.mean()
180 | ```
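To see how the metric behaves, here is a NumPy-only sketch of the per-image score on binary masks (the code above works on PyTorch tensors of logits). `image_score` is a hypothetical name; like the code above, it counts a threshold as met when IoU >= threshold.

```python
import numpy as np

iou_thresholds = np.array([0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95])


def image_score(true_mask, pred_mask):
    """Score one image: the fraction of IoU thresholds met by the predicted binary mask."""
    true_mask = np.asarray(true_mask, dtype=bool)
    pred_mask = np.asarray(pred_mask, dtype=bool)
    if not true_mask.any() and not pred_mask.any():
        return 1.0  # both masks empty: perfect score
    inter = np.logical_and(true_mask, pred_mask).sum()
    union = np.logical_or(true_mask, pred_mask).sum()
    return float((iou_thresholds <= inter / union).mean())


true = np.zeros((20, 20), dtype=np.uint8)
true[0:10, 0:10] = 1                  # 100 salt pixels
pred = np.zeros_like(true)
pred[0:10, 1:11] = 1                  # shifted by one column: IoU = 90 / 110
print(image_score(true, pred))                                # 0.7 (IoU 0.818 meets 0.5 ... 0.8)
print(image_score(true, np.zeros_like(true)))                 # 0.0 (salt missed entirely)
print(image_score(np.zeros_like(true), np.zeros_like(true)))  # 1.0
```

The step-like scoring means a one-pixel shift on a small salt body already costs 30% of that image's score, and predicting any salt on an empty image scores 0.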
181 |
182 | ### Training
183 |
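184 | Training the Unet model roughly takes two steps:
185 | - Use the [LR Test](https://arxiv.org/pdf/1506.01186.pdf) to find a suitable learning-rate range.
186 | - Train the model with [Cyclical Learning Rates (CLR)](https://arxiv.org/pdf/1506.01186.pdf) until it starts to overfit.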
187 | ```python
   | # md (ModelData), cut, miou, get_base_model and UnetModel are defined elsewhere in the notebook (not shown)
188 | wd = 4e-4
189 | arch = resnet34
190 | ps_i = 0.05
191 | ps = np.array([0.1, 0.1, 0.1, 0.1]) * 1
192 | m_base = get_base_model(arch, cut, True)
193 | m = to_gpu(Unet34(m_base, drop_i=True, drop_up=True, ps=ps, ps_i=ps_i))
194 | models = UnetModel(m)
195 | learn = ConvLearner(md, models)
196 | learn.opt_fn = optim.Adam
197 | learn.crit = nn.BCEWithLogitsLoss()
198 | learn.metrics = [accuracy_thresh(0.5), miou]
199 | ```
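200 | I stopped training once changing the learning rate could no longer reduce the loss, the validation loss had converged, and the model was at risk of overfitting.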
201 | 
202 | 
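203 | The results suggest the model needs more regularization to counter overfitting. In practice Dropout did not work well in this Unet, so the effort has to go into data augmentation and weight decay.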
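The two training steps listed at the start of this section are both learning-rate schedules. Below is a minimal NumPy sketch, assuming the learning rate is updated every iteration: an LR range test that sweeps the rate upward (exponentially here; a linear ramp also works), and the triangular policy from the CLR paper. The function names are hypothetical, and this is not fastai's own scheduler, whose schedules differ.

```python
import numpy as np


def lr_range_test_schedule(lr_min, lr_max, num_iters):
    """Learning rates for an LR range test, increasing exponentially from lr_min to lr_max."""
    return lr_min * (lr_max / lr_min) ** (np.arange(num_iters) / (num_iters - 1))


def triangular_clr(iteration, base_lr, max_lr, step_size):
    """Triangular cyclical LR: linear rise over step_size iterations, then linear fall."""
    cycle = np.floor(1 + iteration / (2 * step_size))
    x = np.abs(iteration / step_size - 2 * cycle + 1)
    return base_lr + (max_lr - base_lr) * np.maximum(0.0, 1 - x)


# Range test: record the loss at each rate and pick base_lr/max_lr where it still falls.
lrs = lr_range_test_schedule(1e-5, 1.0, 6)
# One full cycle of 200 iterations between 1e-4 and 1e-2:
cycle_lrs = [triangular_clr(i, 1e-4, 1e-2, step_size=100) for i in (0, 50, 100, 150, 200)]
```

In the range test you stop once the loss starts to blow up; the CLR bounds are then usually chosen a bit below that point.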
204 |
205 | ### Run length encoder
206 |
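207 | As the inverse of rle_decode(), before submitting predictions to Kaggle we use rle_encode() to generate the run-length string for each mask. Of course, each mask must first be resized back to 101x101 with downsample() and thresholded to 0/1.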
208 | ```python
   | import numpy as np
   | from skimage.transform import resize
   |
209 | def downsample(img, shape):
210 |     if shape == img.shape: return img
211 |     return resize(img, shape, mode='constant', preserve_range=True)
212 |
213 | def rle_encode(im):  # im: binary 0/1 mask, e.g. downsample(pred, (101, 101)) > 0.5
214 |     pixels = im.flatten(order='F')  # column-major: top to bottom, then left to right
215 |     pixels = np.concatenate([[0], pixels, [0]])
216 |     runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
217 |     runs[1::2] -= runs[::2]
218 |     return ' '.join(str(x) for x in runs)
219 | ```
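Before submitting, it is worth checking that the encoder and decoder are exact inverses and agree on pixel order. A self-contained NumPy sketch of that round trip on a tiny mask; both functions use the column-major order, and this decoder is a simplified variant written for the check.

```python
import numpy as np


def rle_encode(im):
    """Run-length encode a binary mask, numbering pixels top to bottom, then left to right."""
    pixels = np.concatenate([[0], im.flatten(order='F'), [0]])
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 1   # 1-based positions where runs start/end
    runs[1::2] -= runs[::2]                              # convert end positions to lengths
    return ' '.join(str(x) for x in runs)


def rle_decode(mask_rle, shape):
    nums = np.array(mask_rle.split(), dtype=int)
    img = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for start, length in zip(nums[0::2], nums[1::2]):
        img[start - 1:start - 1 + length] = 1
    return img.reshape(shape, order='F')


mask = np.array([[0, 1, 1],
                 [1, 1, 0],
                 [0, 0, 0]], dtype=np.uint8)
rle = rle_encode(mask)
print(rle)  # 2 1 4 2 7 1
assert np.array_equal(rle_decode(rle, mask.shape), mask)
```

An empty mask encodes to an empty string, which is how images without salt appear in the submission file.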
220 |
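221 | ### TTA (Test Time Augmentation)
222 | We can raise the Kaggle score by applying data augmentation to the test set. When using TTA for segmentation, keep in mind that augmented images produce augmented outputs, so before averaging the outputs, each one must first be transformed back according to its augmentation. For example, if the model produces mask1 from image1 and mask2 from image2, a horizontally flipped copy of image1, then mask2 must be flipped horizontally again before the masks are averaged.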
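A minimal NumPy sketch of flip TTA for one image, assuming a `predict` function that maps an (H, W) image to an (H, W) probability map (a hypothetical interface, not the notebook's). The toy model below is pixel-wise, so it commutes with flips: the correctly un-flipped TTA average must equal the plain prediction, while skipping the flip-back misaligns the pixels.

```python
import numpy as np


def tta_flip_predict(predict, image):
    """Average the predictions for an image and its horizontal mirror."""
    p1 = predict(image)
    p2 = np.fliplr(predict(np.fliplr(image)))  # flip the output back so pixels line up
    return (p1 + p2) / 2


def toy_predict(img):
    return 1.0 / (1.0 + np.exp(-img))  # pixel-wise sigmoid, stands in for a real model


img = np.array([[-2.0, 0.0, 3.0],
                [1.0, -1.0, 0.5]])
out = tta_flip_predict(toy_predict, img)
assert np.allclose(out, toy_predict(img))  # flipped back: consistent with the plain prediction
wrong = (toy_predict(img) + toy_predict(np.fliplr(img))) / 2
assert not np.allclose(wrong, toy_predict(img))  # forgetting the flip-back blurs left and right
```

With a real network the two predictions differ slightly, and averaging them is where the TTA gain comes from.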
223 |
224 | ---
225 |
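226 | ## Summary
227 | That covers the key points of building and training the Unet model: the dataset, model, loss, and metric. I wrote this post early in the competition, so it differs somewhat from the model I finally used; for example, the final model added 5-fold cross validation and scSE blocks. When I have time, I will write another post on the top-ranked solutions and the techniques behind them.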
228 |
229 |
230 | ---
231 |
--------------------------------------------------------------------------------
/TGS_Salt_resnext50_unet_5Fold_scSE.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {
7 | "colab": {
8 | "base_uri": "https://localhost:8080/",
9 | "height": 440
10 | },
11 | "colab_type": "code",
12 | "id": "0mTKWCsxXo4r",
13 | "outputId": "87668956-8fa4-4bb0-9247-015a74a51caa"
14 | },
15 | "outputs": [],
16 | "source": [
17 | "!apt-get install -y -qq software-properties-common python-software-properties module-init-tools\n",
18 | "!add-apt-repository -y ppa:alessandro-strada/ppa 2>&1 > /dev/null\n",
19 | "!apt-get update -qq 2>&1 > /dev/null\n",
20 | "!curl --header \"Host: launchpadlibrarian.net\" --header \"User-Agent: Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/62.0.3202.62 Safari/537.36\" --header \"Accept: text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8\" --header \"Accept-Language: zh,zh-CN;q=0.9,zh-TW;q=0.8,en-US;q=0.7,en;q=0.6\" \"https://launchpadlibrarian.net/386846978/google-drive-ocamlfuse_0.7.0-0ubuntu1_amd64.deb\" -o \"google-drive-ocamlfuse_0.7.0-0ubuntu1_amd64.deb\" -L\n",
21 | "!apt-get -y install -qq fuse\n",
22 | "!dpkg -i google-drive-ocamlfuse_0.7.0-0ubuntu1_amd64.deb\n",
23 | "\n",
24 | "from googleapiclient.discovery import build\n",
25 | "import io, os\n",
26 | "from googleapiclient.http import MediaIoBaseDownload\n",
27 | "from google.colab import auth\n",
28 | "\n",
29 | "auth.authenticate_user()\n",
30 | "\n",
31 | "drive_service = build('drive', 'v3')\n",
32 | "results = drive_service.files().list(\n",
33 | " q=\"name = 'kaggle.json'\", fields=\"files(id)\").execute()\n",
34 | "kaggle_api_key = results.get('files', [])\n",
35 | "\n",
36 | "filename = \"/content/.kaggle/kaggle.json\"\n",
37 | "os.makedirs(os.path.dirname(filename), exist_ok=True)\n",
38 | "\n",
39 | "request = drive_service.files().get_media(fileId=kaggle_api_key[0]['id'])\n",
40 | "fh = io.FileIO(filename, 'wb')\n",
41 | "downloader = MediaIoBaseDownload(fh, request)\n",
42 | "done = False\n",
43 | "while done is False:\n",
44 | " status, done = downloader.next_chunk()\n",
45 | " print(\"Download %d%%.\" % int(status.progress() * 100))\n",
46 | "os.chmod(filename, 0o600)\n",
47 | "!mkdir -p ~/.kaggle\n",
48 | "!cp /content/.kaggle/kaggle.json ~/.kaggle/kaggle.json\n",
49 | "\n",
50 | "from oauth2client.client import GoogleCredentials\n",
51 | "creds = GoogleCredentials.get_application_default()\n",
52 | "import getpass\n",
53 | "!google-drive-ocamlfuse -headless -id={creds.client_id} -secret={creds.client_secret} < /dev/null 2>&1 | grep URL\n",
54 | "vcode = getpass.getpass()\n",
55 | "!echo {vcode} | google-drive-ocamlfuse -headless -id={creds.client_id} -secret={creds.client_secret}"
56 | ]
57 | },
58 | {
59 | "cell_type": "code",
60 | "execution_count": 0,
61 | "metadata": {
62 | "colab": {},
63 | "colab_type": "code",
64 | "id": "K102BSsYDGMy"
65 | },
66 | "outputs": [],
67 | "source": [
68 | "!mkdir -p drive\n",
69 | "!google-drive-ocamlfuse -o nonempty drive\n",
70 | "!ls drive > /dev/null"
71 | ]
72 | },
73 | {
74 | "cell_type": "markdown",
75 | "metadata": {
76 | "colab_type": "text",
77 | "id": "teccmM5G-YOk"
78 | },
79 | "source": [
80 | "### Setup"
81 | ]
82 | },
83 | {
84 | "cell_type": "code",
85 | "execution_count": null,
86 | "metadata": {
87 | "colab": {
88 | "base_uri": "https://localhost:8080/",
89 | "height": 134
90 | },
91 | "colab_type": "code",
92 | "id": "EaDX0j2Kt4W1",
93 | "outputId": "0b410fa0-4198-4f63-ac96-ab67db380594"
94 | },
95 | "outputs": [],
96 | "source": [
97 | "!mkdir src -p && cd src && git clone https://github.com/fastai/fastai.git"
98 | ]
99 | },
100 | {
101 | "cell_type": "code",
102 | "execution_count": null,
103 | "metadata": {
104 | "colab": {
105 | "base_uri": "https://localhost:8080/",
106 | "height": 54
107 | },
108 | "colab_type": "code",
109 | "id": "698sgxtBuXZO",
110 | "outputId": "78e4c0d3-0361-43c7-90f4-7412461d2592"
111 | },
112 | "outputs": [],
113 | "source": [
114 | "!pip3 install -q bcolz graphviz sklearn_pandas isoweek pandas_summary ipywidgets torch torchvision torchtext"
115 | ]
116 | },
117 | {
118 | "cell_type": "code",
119 | "execution_count": 0,
120 | "metadata": {
121 | "colab": {},
122 | "colab_type": "code",
123 | "id": "D6T70YIYyuKp"
124 | },
125 | "outputs": [],
126 | "source": [
127 | "!git config --global user.email 'nikshuang@163.com'\n",
128 | "!git config --global user.name 'Alex Huang'\n",
129 | "!pip install -q kaggle\n",
130 | "!pip install -q -U git+https://github.com/albu/albumentations"
131 | ]
132 | },
133 | {
134 | "cell_type": "code",
135 | "execution_count": null,
136 | "metadata": {
137 | "colab": {
138 | "base_uri": "https://localhost:8080/",
139 | "height": 118
140 | },
141 | "colab_type": "code",
142 | "id": "kv9wLQUwc5Vd",
143 | "outputId": "89b20b5f-d89f-4759-eb4a-6efc2e41e4dc"
144 | },
145 | "outputs": [],
146 | "source": [
147 | "!git clone https://github.com/svishnu88/TGS-SaltIdentification-Open-Solution-fastai.git\n",
148 | "!cp TGS-SaltIdentification-Open-Solution-fastai/*.py .\n",
149 | "!pip install -q Cython --upgrade\n",
150 | "!pip install -q pycocotools --upgrade"
151 | ]
152 | },
153 | {
154 | "cell_type": "code",
155 | "execution_count": 0,
156 | "metadata": {
157 | "colab": {
158 | "base_uri": "https://localhost:8080/",
159 | "height": 50
160 | },
161 | "colab_type": "code",
162 | "id": "ib2O97lzt4cG",
163 | "outputId": "5304877b-0461-4af1-c124-51a5eccf8b1c"
164 | },
165 | "outputs": [
166 | {
167 | "name": "stdout",
168 | "output_type": "stream",
169 | "text": [
170 | "0.4.1\n"
171 | ]
172 | },
173 | {
174 | "data": {
175 | "text/plain": [
176 | "(True, True)"
177 | ]
178 | },
179 | "execution_count": 8,
180 | "metadata": {
181 | "tags": []
182 | },
183 | "output_type": "execute_result"
184 | }
185 | ],
186 | "source": [
187 | "import sys\n",
188 | "sys.path.append(\"/content/src/fastai/old\") # on windows use \\'s instead\n",
189 | "\n",
190 | "from fastai.conv_learner import *\n",
191 | "from fastai.dataset import *\n",
192 | "from skimage.transform import resize\n",
193 | "import gc\n",
194 | "from sklearn.model_selection import train_test_split, StratifiedKFold , KFold\n",
195 | "from sklearn.metrics import jaccard_similarity_score\n",
196 | "from pycocotools import mask as cocomask\n",
197 | "from utils import my_eval,intersection_over_union_thresholds,RLenc\n",
198 | "from lovasz_losses import *\n",
199 | "print(torch.__version__)\n",
200 | "torch.backends.cudnn.benchmark=True\n",
201 | "torch.cuda.is_available(), torch.backends.cudnn.enabled"
202 | ]
203 | },
204 | {
205 | "cell_type": "code",
206 | "execution_count": 0,
207 | "metadata": {
208 | "colab": {},
209 | "colab_type": "code",
210 | "id": "zkqghZ5gOF1L"
211 | },
212 | "outputs": [],
213 | "source": [
214 | "PATH = Path('data/salt')\n",
215 | "TRN_PATH = PATH/'train'\n",
216 | "TEST_PATH = PATH/'test'\n",
217 | "DISK_PATH = 'drive/tgs_salt'\n",
218 | "os.makedirs(DISK_PATH, exist_ok=True)\n",
219 | "os.makedirs(PATH, exist_ok=True)\n",
220 | "os.makedirs(TRN_PATH, exist_ok=True)\n",
221 | "os.makedirs(TEST_PATH, exist_ok=True)"
222 | ]
223 | },
224 | {
225 | "cell_type": "code",
226 | "execution_count": 0,
227 | "metadata": {
228 | "colab": {},
229 | "colab_type": "code",
230 | "id": "W8sZ5U5xp4kO"
231 | },
232 | "outputs": [],
233 | "source": [
234 | "from albumentations import (\n",
235 | " Resize,\n",
236 | " HorizontalFlip,\n",
237 | " VerticalFlip,\n",
238 | " Compose,\n",
239 | " RandomRotate90,\n",
240 | " RandomSizedCrop,\n",
241 | " OneOf,\n",
242 | " RandomSizedCrop,\n",
243 | " Rotate,\n",
244 | " RandomContrast,\n",
245 | " RandomGamma,\n",
246 | " ElasticTransform,\n",
247 | " GridDistortion, \n",
248 | " OpticalDistortion,\n",
249 | " RandomBrightness\n",
250 | ") "
251 | ]
252 | },
253 | {
254 | "cell_type": "code",
255 | "execution_count": null,
256 | "metadata": {
257 | "colab": {
258 | "base_uri": "https://localhost:8080/",
259 | "height": 235
260 | },
261 | "colab_type": "code",
262 | "id": "9CJQ66L9dELB",
263 | "outputId": "4ead758a-b897-4118-983e-19e2d2e43923"
264 | },
265 | "outputs": [],
266 | "source": [
267 | "!wget http://files.fast.ai/models/weights.tgz\n",
268 | "!tar xf weights.tgz -C /content/src/fastai/old/fastai/\n",
269 | "!ls /content/src/fastai/old/fastai/weights"
270 | ]
271 | },
272 | {
273 | "cell_type": "code",
274 | "execution_count": null,
275 | "metadata": {
276 | "colab": {
277 | "base_uri": "https://localhost:8080/",
278 | "height": 269
279 | },
280 | "colab_type": "code",
281 | "id": "Q5YLtNMJzSyS",
282 | "outputId": "a5930376-e517-4a8a-cea3-4215152ab2c1"
283 | },
284 | "outputs": [],
285 | "source": [
286 | "!kaggle competitions download -c tgs-salt-identification-challenge -p {PATH}"
287 | ]
288 | },
289 | {
290 | "cell_type": "code",
291 | "execution_count": 0,
292 | "metadata": {
293 | "colab": {},
294 | "colab_type": "code",
295 | "id": "E5j5p85o0DD6"
296 | },
297 | "outputs": [],
298 | "source": [
299 | "!cd {TRN_PATH} && unzip -q /content/{PATH}/'train.zip'\n",
300 | "!cd {TEST_PATH} && unzip -q /content/{PATH}/'test.zip'"
301 | ]
302 | },
303 | {
304 | "cell_type": "code",
305 | "execution_count": 0,
306 | "metadata": {
307 | "colab": {
308 | "base_uri": "https://localhost:8080/",
309 | "height": 50
310 | },
311 | "colab_type": "code",
312 | "id": "Pmbd4oZ10mXH",
313 | "outputId": "8fc58192-70fc-45e0-962e-082360d4f164"
314 | },
315 | "outputs": [
316 | {
317 | "name": "stdout",
318 | "output_type": "stream",
319 | "text": [
320 | "images\tmasks\n",
321 | "images\n"
322 | ]
323 | }
324 | ],
325 | "source": [
326 | "!ls {TRN_PATH}\n",
327 | "!ls {TEST_PATH}"
328 | ]
329 | },
330 | {
331 | "cell_type": "markdown",
332 | "metadata": {
333 | "colab_type": "text",
334 | "id": "Dz4H7ksT0tMo"
335 | },
336 | "source": [
337 | "### EDA"
338 | ]
339 | },
340 | {
341 | "cell_type": "code",
342 | "execution_count": 0,
343 | "metadata": {
344 | "colab": {},
345 | "colab_type": "code",
346 | "id": "WpM9V6Pmzc3o"
347 | },
348 | "outputs": [],
349 | "source": [
350 | "depth = pd.read_csv(PATH/'depths.csv')\n",
351 | "train = pd.read_csv(PATH/'train.csv')\n",
352 | "submission = pd.read_csv(PATH/'sample_submission.csv')"
353 | ]
354 | },
355 | {
356 | "cell_type": "code",
357 | "execution_count": 0,
358 | "metadata": {
359 | "colab": {},
360 | "colab_type": "code",
361 | "id": "U25DaYEF1L4U"
362 | },
363 | "outputs": [],
364 | "source": [
365 | "submission.head()"
366 | ]
367 | },
368 | {
369 | "cell_type": "code",
370 | "execution_count": 0,
371 | "metadata": {
372 | "colab": {},
373 | "colab_type": "code",
374 | "id": "ZUNgd4s50-dd"
375 | },
376 | "outputs": [],
377 | "source": [
378 | "depth.head()"
379 | ]
380 | },
381 | {
382 | "cell_type": "code",
383 | "execution_count": 0,
384 | "metadata": {
385 | "colab": {},
386 | "colab_type": "code",
387 | "id": "NjEG1A6F5V9k"
388 | },
389 | "outputs": [],
390 | "source": [
391 | "depth.set_index('id').z.to_dict()"
392 | ]
393 | },
394 | {
395 | "cell_type": "code",
396 | "execution_count": 0,
397 | "metadata": {
398 | "colab": {},
399 | "colab_type": "code",
400 | "id": "w4TItUxR5ND0"
401 | },
402 | "outputs": [],
403 | "source": [
404 | "depth.z.to_dict()"
405 | ]
406 | },
407 | {
408 | "cell_type": "code",
409 | "execution_count": 0,
410 | "metadata": {
411 | "colab": {},
412 | "colab_type": "code",
413 | "id": "Y6-sKto01EPu"
414 | },
415 | "outputs": [],
416 | "source": [
417 | "train.head()"
418 | ]
419 | },
420 | {
421 | "cell_type": "code",
422 | "execution_count": 0,
423 | "metadata": {
424 | "colab": {},
425 | "colab_type": "code",
426 | "id": "2AgjhosP1G2s"
427 | },
428 | "outputs": [],
429 | "source": [
430 | "!ls {TRN_PATH}/images | head -10"
431 | ]
432 | },
433 | {
434 | "cell_type": "code",
435 | "execution_count": 0,
436 | "metadata": {
437 | "colab": {},
438 | "colab_type": "code",
439 | "id": "fFXbh_0F1Awx"
440 | },
441 | "outputs": [],
442 | "source": [
443 | "IMG_ID = '008a50a2ec'"
444 | ]
445 | },
446 | {
447 | "cell_type": "code",
448 | "execution_count": 0,
449 | "metadata": {
450 | "colab": {},
451 | "colab_type": "code",
452 | "id": "3PLLX_oV1i9Z"
453 | },
454 | "outputs": [],
455 | "source": [
456 | "Image.open(TRN_PATH/'images'/f'{IMG_ID}.png').resize((32, 32))"
457 | ]
458 | },
459 | {
460 | "cell_type": "code",
461 | "execution_count": 0,
462 | "metadata": {
463 | "colab": {},
464 | "colab_type": "code",
465 | "id": "pjCMbgZk2Eu4"
466 | },
467 | "outputs": [],
468 | "source": [
469 | "Image.open(TRN_PATH/'masks'/f'{IMG_ID}.png').resize((32, 32))"
470 | ]
471 | },
472 | {
473 | "cell_type": "code",
474 | "execution_count": 0,
475 | "metadata": {
476 | "colab": {},
477 | "colab_type": "code",
478 | "id": "B-yOgm0x0QbQ"
479 | },
480 | "outputs": [],
481 | "source": [
482 | "rn34 = torchvision.models.resnet34()"
483 | ]
484 | },
485 | {
486 | "cell_type": "code",
487 | "execution_count": 0,
488 | "metadata": {
489 | "colab": {},
490 | "colab_type": "code",
491 | "id": "GRmr-npg3JpT"
492 | },
493 | "outputs": [],
494 | "source": [
495 | "rn34"
496 | ]
497 | },
498 | {
499 | "cell_type": "code",
500 | "execution_count": 0,
501 | "metadata": {
502 | "colab": {},
503 | "colab_type": "code",
504 | "id": "5SmuAbQh0Qfz"
505 | },
506 | "outputs": [],
507 | "source": [
508 | "rn34.layer2"
509 | ]
510 | },
511 | {
512 | "cell_type": "code",
513 | "execution_count": 0,
514 | "metadata": {
515 | "colab": {},
516 | "colab_type": "code",
517 | "id": "eLYRkoraNbfi"
518 | },
519 | "outputs": [],
520 | "source": [
521 | "m_base = get_base_model(arch, cut, True)"
522 | ]
523 | },
524 | {
525 | "cell_type": "code",
526 | "execution_count": 0,
527 | "metadata": {
528 | "colab": {},
529 | "colab_type": "code",
530 | "id": "QlUi3xT5jlxA"
531 | },
532 | "outputs": [],
533 | "source": [
534 | "m_base[4]"
535 | ]
536 | },
537 | {
538 | "cell_type": "code",
539 | "execution_count": 0,
540 | "metadata": {
541 | "colab": {},
542 | "colab_type": "code",
543 | "id": "teVt8T7aNTnE"
544 | },
545 | "outputs": [],
546 | "source": [
547 | "pretrained_state_dict = torch.load('/root/.torch/models/resnet34-333f7ec4.pth')"
548 | ]
549 | },
550 | {
551 | "cell_type": "code",
552 | "execution_count": 0,
553 | "metadata": {
554 | "colab": {},
555 | "colab_type": "code",
556 | "id": "qDY9tpPZNT0R"
557 | },
558 | "outputs": [],
559 | "source": [
560 | "pretrained_state_dict.keys()"
561 | ]
562 | },
563 | {
564 | "cell_type": "code",
565 | "execution_count": 0,
566 | "metadata": {
567 | "colab": {},
568 | "colab_type": "code",
569 | "id": "prL_3f_JN4F2"
570 | },
571 | "outputs": [],
572 | "source": [
573 | "arch()"
574 | ]
575 | },
576 | {
577 | "cell_type": "markdown",
578 | "metadata": {
579 | "colab_type": "text",
580 | "id": "fkL9G_kG07dN"
581 | },
582 | "source": [
583 | "### Dataset"
584 | ]
585 | },
586 | {
587 | "cell_type": "code",
588 | "execution_count": 0,
589 | "metadata": {
590 | "colab": {},
591 | "colab_type": "code",
592 | "id": "eG_MV_FOOfh9"
593 | },
594 | "outputs": [],
595 | "source": [
596 | "TRN_DN = 'train/images'\n",
597 | "MASK_DN = 'train/masks'\n",
598 | "TEST_DN = 'test/images'\n",
599 | "TRN_PATH = PATH/TRN_DN\n",
600 | "MASK_PATH = PATH/MASK_DN"
601 | ]
602 | },
603 | {
604 | "cell_type": "code",
605 | "execution_count": 0,
606 | "metadata": {
607 | "colab": {},
608 | "colab_type": "code",
609 | "id": "5Sj0FM7X_5gC"
610 | },
611 | "outputs": [],
612 | "source": [
613 | "today = datetime.datetime.now().strftime('%Y%m%d')\n",
614 | "\n",
615 | "def save_model_weights(path=DISK_PATH):\n",
616 | " targ_dir = f'{path}/{today}/models'\n",
617 | " os.makedirs(targ_dir, exist_ok=True)\n",
618 | " for f in list(PATH.glob('**/*.h5')):\n",
619 | " shutil.copy('/'.join(f.parts), targ_dir)\n",
620 | " \n",
621 | "def load_model_weights(path=PATH, disk_path=DISK_PATH, subdir=None):\n",
622 | " if subdir == None: subdir = today\n",
623 | " os.makedirs(f'{path}/models', exist_ok=True)\n",
624 | " src_dir = f'{disk_path}/{subdir}/models'\n",
625 | " targ_dir = f'{path}/models'\n",
626 | " for f in os.listdir(src_dir):\n",
627 | " shutil.copy(os.path.join(src_dir, f), targ_dir)\n",
628 | " \n",
629 | "def save_dataset(path=DISK_PATH):\n",
630 | " targ_dir = f'{path}/{today}'\n",
631 | " os.makedirs(targ_dir, exist_ok=True)\n",
632 | " shutil.copy('test_x.pkl', targ_dir)\n",
633 | " shutil.copy('train_folds.pkl', targ_dir)\n",
634 | " shutil.copy('val_folds.pkl', targ_dir)\n",
635 | " \n",
636 | "def load_dataset(path=PATH, disk_path=DISK_PATH, time_stamp=None):\n",
637 | " if time_stamp == None: time_stamp = today\n",
638 | " src_dir = f'{disk_path}/{time_stamp}'\n",
639 | "# assert os.path.exists(src_dir) == True\n",
640 | " shutil.copy(f'{src_dir}/test_x.pkl', '.')\n",
641 | " shutil.copy(f'{src_dir}/train_folds.pkl', '.')\n",
642 | " shutil.copy(f'{src_dir}/val_folds.pkl', '.')"
643 | ]
644 | },
645 | {
646 | "cell_type": "code",
647 | "execution_count": 0,
648 | "metadata": {
649 | "colab": {},
650 | "colab_type": "code",
651 | "id": "kVd-dF4ZO8s7"
652 | },
653 | "outputs": [],
654 | "source": [
655 | "class DepthDataset(Dataset):\n",
656 | " def __init__(self,ds,dpth_dict):\n",
657 | " self.dpth = dpth_dict\n",
658 | " self.ds = ds\n",
659 | " \n",
660 | " def __getitem__(self,i):\n",
661 | " val = self.ds[i]\n",
662 | " return val[0],self.dpth[self.ds.fnames[i].split('/')[-1][:-4]],val[1]\n",
663 | " \n",
664 | "class DepthDatasetV2(Dataset):\n",
665 | " def __init__(self,ds,dpth_dict):\n",
666 | " self.dpth = dpth_dict\n",
667 | " self.ds = ds\n",
668 | " \n",
669 | " def __getitem__(self,i):\n",
670 | " val = self.ds[i]\n",
671 | " return val[0],self.dpth[self.ds.fnames[i].name[:-4]],val[1]\n",
672 | " \n",
673 | "class MatchedFilesDataset(FilesDataset):\n",
674 | " def __init__(self, fnames, y, transform, path):\n",
675 | " self.y=y\n",
676 | " assert(len(fnames)==len(y))\n",
677 | " super().__init__(fnames, transform, path)\n",
678 | " \n",
679 | " def get_x(self, i): \n",
680 | " return open_image(os.path.join(self.path, self.fnames[i]))\n",
681 | " \n",
682 | " def get_y(self, i):\n",
683 | " return open_image(os.path.join(str(self.path), str(self.y[i])))\n",
684 | "\n",
685 | " def get_c(self): return 0\n",
686 | " \n",
687 | "class TestFilesDataset(FilesDataset):\n",
688 | " def __init__(self, fnames, y, transform,flip, path):\n",
689 | " self.y=y\n",
690 | " self.flip = flip\n",
691 | " super().__init__(fnames, transform, path)\n",
692 | " \n",
693 | " def get_x(self, i): \n",
694 | " im = open_image(os.path.join(self.path, self.fnames[i]))\n",
695 | " return np.fliplr(im) if self.flip else im\n",
696 | " \n",
697 | " def get_y(self, i):\n",
698 | " im = open_image(os.path.join(str(self.path), str(self.y[i])))\n",
699 | " return np.fliplr(im) if self.flip else im\n",
700 | " def get_c(self): return 0"
701 | ]
702 | },
703 | {
704 | "cell_type": "code",
705 | "execution_count": 0,
706 | "metadata": {
707 | "colab": {},
708 | "colab_type": "code",
709 | "id": "SYwzDb6CPqEw"
710 | },
711 | "outputs": [],
712 | "source": [
713 | "# trn = pd.read_csv(PATH/'train.csv')\n",
714 | "dpth = pd.read_csv(PATH/'depths.csv')\n",
715 | "c = dpth.set_index('id')\n",
716 | "dpth_dict = c['z'].to_dict()"
717 | ]
718 | },
719 | {
720 | "cell_type": "code",
721 | "execution_count": 0,
722 | "metadata": {
723 | "colab": {},
724 | "colab_type": "code",
725 | "id": "4mlccVfddfRm"
726 | },
727 | "outputs": [],
728 | "source": [
729 | "kf = 5\n",
730 | "# kf = 10"
731 | ]
732 | },
733 | {
734 | "cell_type": "code",
735 | "execution_count": 0,
736 | "metadata": {
737 | "colab": {},
738 | "colab_type": "code",
739 | "id": "1qSG--oWPJNS"
740 | },
741 | "outputs": [],
742 | "source": [
743 | "x_names = np.array([Path('/'.join([TRN_DN, f])) for f in list(os.listdir(TRN_PATH))])\n",
744 | "y_names = np.array([Path('/'.join([MASK_DN, f])) for f in list(os.listdir(TRN_PATH))])\n",
745 | "test_x = np.array([Path('/'.join([TEST_DN, f.name])) for f in list((PATH/TEST_DN).iterdir())])\n",
746 | "f_name = [o.name for o in x_names]\n",
747 | "\n",
748 | "c = dpth.set_index('id')\n",
749 | "dpth_dict = c['z'].to_dict()\n",
750 | "\n",
751 | "kfold = KFold(n_splits=kf, shuffle=True, random_state=42)\n",
752 | "\n",
753 | "train_folds = []\n",
754 | "val_folds = []\n",
755 | "for idxs in kfold.split(f_name):\n",
756 | " train_folds.append([f_name[idx] for idx in idxs[0]])\n",
757 | " val_folds.append([f_name[idx] for idx in idxs[1]])\n",
758 | "\n",
759 | "with open('train_folds.pkl', 'wb') as fp:\n",
760 | " pickle.dump(train_folds, fp)\n",
761 | "with open('val_folds.pkl', 'wb') as fp:\n",
762 | " pickle.dump(val_folds, fp)\n",
763 | "with open('test_x.pkl', 'wb') as fp:\n",
764 | " pickle.dump(test_x, fp)\n",
765 | "\n",
766 | "save_dataset()"
767 | ]
768 | },
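{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"The folds above come from scikit-learn's `KFold(n_splits=kf, shuffle=True, random_state=42)`. As a rough illustration only, the hypothetical helper `make_folds` below sketches the same idea in plain Python: shuffle the file names once, then deal them into `k` nearly equal validation chunks. It is not scikit-learn's implementation and will not put the same images in the same folds. It only shows the property the training loop relies on: every image is validated exactly once and never appears in the training list of its own fold.\n",
"\n",
"```python\n",
"import random\n",
"\n",
"def make_folds(names, k=5, seed=42):\n",
"    # Hypothetical helper: shuffle indices once, then split them into k\n",
"    # contiguous chunks whose sizes differ by at most one.\n",
"    idx = list(range(len(names)))\n",
"    random.Random(seed).shuffle(idx)\n",
"    sizes = [len(idx) // k + (1 if f < len(idx) % k else 0) for f in range(k)]\n",
"    train_folds, val_folds, start = [], [], 0\n",
"    for size in sizes:\n",
"        val_idx = set(idx[start:start + size])\n",
"        start += size\n",
"        val_folds.append([names[i] for i in sorted(val_idx)])\n",
"        train_folds.append([names[i] for i in range(len(names)) if i not in val_idx])\n",
"    return train_folds, val_folds\n",
"\n",
"names = ['img_%d.png' % i for i in range(12)]\n",
"trn, val = make_folds(names, k=5)\n",
"print([len(v) for v in val])  # [3, 3, 2, 2, 2]\n",
"```\n",
"\n",
"With 12 images and `k=5`, the first `12 % 5 = 2` folds receive one extra image, which is also how `KFold` sizes its folds."
]
},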
769 | {
770 | "cell_type": "code",
771 | "execution_count": 0,
772 | "metadata": {
773 | "colab": {},
774 | "colab_type": "code",
775 | "id": "XgpkySKgYiMw"
776 | },
777 | "outputs": [],
778 | "source": [
779 | "load_dataset(time_stamp='20181008') # 5 folds\n",
780 | "# load_dataset(time_stamp='20181016') # 10 folds"
781 | ]
782 | },
783 | {
784 | "cell_type": "code",
785 | "execution_count": 0,
786 | "metadata": {
787 | "colab": {},
788 | "colab_type": "code",
789 | "id": "E9EHO3w5SZVD"
790 | },
791 | "outputs": [],
792 | "source": [
793 | "train_folds = pickle.load(open('train_folds.pkl',mode='rb'))\n",
794 | "val_folds = pickle.load(open('val_folds.pkl',mode='rb'))\n",
795 | "test_x = pickle.load(open('test_x.pkl',mode='rb'))"
796 | ]
797 | },
798 | {
799 | "cell_type": "code",
800 | "execution_count": 0,
801 | "metadata": {
802 | "colab": {},
803 | "colab_type": "code",
804 | "id": "ys7cmvFiSDy3"
805 | },
806 | "outputs": [],
807 | "source": [
808 | "def open_image(fn):\n",
809 | " \"\"\" Opens an image using OpenCV given the file path.\n",
810 | "\n",
811 | " Arguments:\n",
812 | " fn: the file path of the image\n",
813 | "\n",
814 | " Returns:\n",
815 | " The image in RGB format as numpy array of floats normalized to range between 0.0 - 1.0\n",
816 | " \"\"\"\n",
817 | " flags = cv2.IMREAD_UNCHANGED+cv2.IMREAD_ANYDEPTH+cv2.IMREAD_ANYCOLOR\n",
818 | " if not os.path.exists(fn) and not str(fn).startswith(\"http\"):\n",
819 | " raise OSError('No such file or directory: {}'.format(fn))\n",
820 | " elif os.path.isdir(fn) and not str(fn).startswith(\"http\"):\n",
821 | " raise OSError('Is a directory: {}'.format(fn))\n",
822 | "# elif isdicom(fn):\n",
823 | "# slice = pydicom.read_file(fn)\n",
824 | "# if slice.PhotometricInterpretation.startswith('MONOCHROME'):\n",
825 | "# # Make a fake RGB image\n",
826 | "# im = np.stack([slice.pixel_array]*3,-1)\n",
827 | "# return im / ((1 << slice.BitsStored)-1)\n",
828 | "# else:\n",
829 | "# # No support for RGB yet, as it involves various color spaces.\n",
830 | "# # It shouldn't be too difficult to add though, if needed.\n",
831 | "# raise OSError('Unsupported DICOM image with PhotometricInterpretation=={}'.format(slice.PhotometricInterpretation))\n",
832 | " else:\n",
833 | " #res = np.array(Image.open(fn), dtype=np.float32)/255\n",
834 | " #if len(res.shape)==2: res = np.repeat(res[...,None],3,2)\n",
835 | " #return res\n",
836 | " try:\n",
837 | " if str(fn).startswith(\"http\"):\n",
838 | "                req = urllib.request.urlopen(str(fn))\n",
839 | " image = np.asarray(bytearray(req.read()), dtype=\"uint8\")\n",
840 | "                im = cv2.imdecode(image, flags)\n",
841 | "            else:\n",
842 | "                im = cv2.imread(str(fn), flags)\n",
843 | "            if im is None: raise OSError(f'File not recognized by opencv: {fn}')\n",
844 | "            return cv2.cvtColor(im.astype(np.float32)/255, cv2.COLOR_BGR2RGB)\n",
845 | " except Exception as e:\n",
846 | " raise OSError('Error handling image at: {}'.format(fn)) from e"
847 | ]
848 | },
849 | {
850 | "cell_type": "markdown",
851 | "metadata": {
852 | "colab_type": "text",
853 | "id": "0vkz5LH-NF6i"
854 | },
855 | "source": [
856 | "### Model"
857 | ]
858 | },
859 | {
860 | "cell_type": "code",
861 | "execution_count": 0,
862 | "metadata": {
863 | "colab": {},
864 | "colab_type": "code",
865 | "id": "3Kl7WMiC9aCK"
866 | },
867 | "outputs": [],
868 | "source": [
869 | "class SCSEModule(nn.Module):\n",
870 | " def __init__(self, ch, re=16):\n",
871 | " super().__init__()\n",
872 | " self.cSE = nn.Sequential(nn.AdaptiveAvgPool2d(1),\n",
873 | " nn.Conv2d(ch,ch//re,1),\n",
874 | " nn.ReLU(inplace=True),\n",
875 | " nn.Conv2d(ch//re,ch,1),\n",
876 | " nn.Sigmoid())\n",
877 | " self.sSE = nn.Sequential(nn.Conv2d(ch,ch,1),\n",
878 | " nn.Sigmoid())\n",
879 | "\n",
880 | " def forward(self, x):\n",
881 | " return x * self.cSE(x) + x * self.sSE(x)\n",
882 | "\n",
883 | "\n",
884 | "class SaveFeatures():\n",
885 | " features=None\n",
886 | " def __init__(self, m): self.hook = m.register_forward_hook(self.hook_fn)\n",
887 | " def hook_fn(self, module, input, output): self.features = output\n",
888 | " def remove(self): self.hook.remove()\n",
889 | "\n",
890 | "\n",
891 | "class UnetBlock(nn.Module):\n",
892 | " def __init__(self, up_in, x_in, n_out):\n",
893 | " super().__init__()\n",
894 | " up_out = x_out = n_out//2\n",
895 | " self.x_conv = nn.Conv2d(x_in, x_out, 1)\n",
896 | " self.tr_conv = nn.ConvTranspose2d(up_in, up_out, 2, stride=2)\n",
897 | " self.x_scse = SCSEModule(x_in)\n",
898 | " self.up_scse = SCSEModule(up_in)\n",
899 | " self.bn = nn.BatchNorm2d(n_out)\n",
900 | " \n",
901 | " def forward(self, up_p, x_p):\n",
902 | " up_p = self.tr_conv(self.up_scse(up_p))\n",
903 | " x_p = self.x_conv(self.x_scse(x_p))\n",
904 | " cat_p = torch.cat([up_p,x_p], dim=1)\n",
905 | " return self.bn(F.relu(cat_p))\n",
906 | "\n",
907 | "\n",
908 | "class HyperColumn(nn.Module):\n",
909 | " def __init__(self, nf):\n",
910 | " super().__init__()\n",
911 | " self.conv1 = nn.Conv2d(nf, nf // 16, 3, 1, 1)\n",
912 | " self.conv2 = nn.Conv2d(nf // 16, 1, 1)\n",
913 | " self.bn = nn.BatchNorm2d(nf // 16)\n",
914 | " self.relu = nn.ReLU(inplace=True)\n",
915 | " \n",
916 | " def forward(self, x):\n",
917 | " return self.conv2(self.bn(self.relu(self.conv1(x))))\n",
918 | "\n",
919 | "\n",
920 | "class Unet34(nn.Module):\n",
921 | " def __init__(self, rn):\n",
922 | " super().__init__()\n",
923 | " self.rn = rn\n",
924 | " self.sfs = [SaveFeatures(rn[i]) for i in [2,4,5,6]]\n",
925 | " self.up1 = UnetBlock(512,256,128)\n",
926 | " self.up2 = UnetBlock(128,128,128)\n",
927 | " self.up3 = UnetBlock(128,64,128)\n",
928 | " self.up4 = UnetBlock(128,64,128)\n",
929 | "# self.up5 = nn.ConvTranspose2d(128, 1, 2, stride=2)\n",
930 | " self.up5 = nn.ConvTranspose2d(128, 16, 2, stride=2)\n",
931 | " self.dc = nn.Conv2d(128, 16, 1)\n",
932 | " self.drop = nn.Dropout2d(p=0.5, inplace=True)\n",
933 | " self.logit = nn.Sequential(\n",
934 | " nn.ReLU(inplace=True),\n",
935 | " nn.Conv2d(80, 1, 1)\n",
936 | " )\n",
937 | " \n",
938 | " def forward(self,img,depth):\n",
939 | " x = F.relu(self.rn(img))\n",
940 | " x1 = self.up1(x, self.sfs[3].features)\n",
941 | " x2 = self.up2(x1, self.sfs[2].features)\n",
942 | " x3 = self.up3(x2, self.sfs[1].features)\n",
943 | " x4 = self.up4(x3, self.sfs[0].features)\n",
944 | "# x = self.up5(x4)\n",
945 | " x5 = self.up5(x4)\n",
946 | " x = torch.cat((x5,\n",
947 | " F.interpolate(self.dc(x4),scale_factor=2,align_corners=False,mode='bilinear'),\n",
948 | " F.interpolate(self.dc(x3),scale_factor=4,align_corners=False,mode='bilinear'),\n",
949 | " F.interpolate(self.dc(x2),scale_factor=8,align_corners=False,mode='bilinear'),\n",
950 | " F.interpolate(self.dc(x1),scale_factor=16,align_corners=False,mode='bilinear')),1)\n",
951 | " x = self.logit(self.drop(x))\n",
952 | " return x[:,0]\n",
953 | " \n",
954 | " def close(self):\n",
955 | " for sf in self.sfs: sf.remove()\n",
956 | "\n",
957 | "\n",
958 | "class UnetModel():\n",
959 | " def __init__(self,model,lr_cut,name='unet'):\n",
960 | " self.model,self.name = model,name\n",
961 | " self.lr_cut = lr_cut\n",
962 | "\n",
963 | " def get_layer_groups(self, precompute):\n",
964 | " lgs = list(split_by_idxs(children(self.model.rn), [self.lr_cut]))\n",
965 | " return lgs + [children(self.model)[1:]]"
966 | ]
967 | },
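{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"A minimal NumPy sketch of the scSE recalibration for one feature map of shape `(C, H, W)`, with random stand-in weights in place of the learned 1x1 convolutions (`scse` and its weight arguments are hypothetical names). The cSE branch mirrors `SCSEModule.cSE`: global average pool, bottleneck `C -> C//r`, ReLU, `C//r -> C`, sigmoid. One difference is worth flagging: `SCSEModule.sSE` above uses `nn.Conv2d(ch, ch, 1)` and therefore learns a separate spatial gate per channel, whereas the sSE described in the paper (arXiv 1803.02579) squeezes the channels into a single `(H, W)` map (`C -> 1`). The sketch follows the paper's single-map form.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def sigmoid(z):\n",
"    return 1.0 / (1.0 + np.exp(-z))\n",
"\n",
"def scse(x, w1, b1, w2, b2, ws, bs):\n",
"    # x: feature map (C, H, W); the weights stand in for learned 1x1 convs.\n",
"    # cSE: global average pool -> C//r bottleneck -> ReLU -> C -> sigmoid\n",
"    z = x.mean(axis=(1, 2))                     # (C,)\n",
"    h = np.maximum(w1 @ z + b1, 0.0)            # (C//r,)\n",
"    c_gate = sigmoid(w2 @ h + b2)               # (C,)\n",
"    # sSE (paper form): 1x1 conv C -> 1, one spatial gate shared by all channels\n",
"    s_gate = sigmoid(np.tensordot(ws, x, axes=(0, 0)) + bs)  # (H, W)\n",
"    # Same combination as SCSEModule.forward: x*cSE(x) + x*sSE(x)\n",
"    return x * c_gate[:, None, None] + x * s_gate[None, :, :]\n",
"\n",
"rng = np.random.default_rng(0)\n",
"C, H, W, r = 32, 8, 8, 16\n",
"x = rng.standard_normal((C, H, W))\n",
"out = scse(x,\n",
"           rng.standard_normal((C // r, C)), np.zeros(C // r),\n",
"           rng.standard_normal((C, C // r)), np.zeros(C),\n",
"           rng.standard_normal(C), 0.0)\n",
"print(out.shape)  # (32, 8, 8)\n",
"```\n",
"\n",
"With all-zero weights both gates equal `sigmoid(0) = 0.5`, so `x * 0.5 + x * 0.5` returns `x` unchanged, a quick check that the two branches are combined the same way as in `SCSEModule.forward`."
]
},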
968 | {
969 | "cell_type": "code",
970 | "execution_count": 0,
971 | "metadata": {
972 | "colab": {},
973 | "colab_type": "code",
974 | "id": "8Me7dRQVzdKB"
975 | },
976 | "outputs": [],
977 | "source": [
978 | "def dechannels(nf):\n",
979 | " return nn.Sequential(\n",
980 | " nn.Conv2d(nf, 32, 1),\n",
981 | " nn.ReLU(inplace=True)\n",
982 | " )\n",
983 | " \n",
984 | "\n",
985 | "class Unet50(nn.Module):\n",
986 | " def __init__(self, rn):\n",
987 | " super().__init__()\n",
988 | " self.rn = rn\n",
989 | " self.sfs = [SaveFeatures(rn[i]) for i in [2,4,5,6]]\n",
990 | " self.up1 = UnetBlock(2048, 1024, 512)\n",
991 | " self.up2 = UnetBlock(512, 512, 256)\n",
992 | " self.up3 = UnetBlock(256, 256, 128)\n",
993 | " self.up4 = UnetBlock(128, 64, 128)\n",
994 | " self.up5 = nn.ConvTranspose2d(128, 1, 2, stride=2)\n",
995 | "# self.up5 = nn.ConvTranspose2d(128, 128, 2, stride=2)\n",
996 | "# self.d1 = dechannels(512)\n",
997 | "# self.d2 = dechannels(256)\n",
998 | "# self.d3 = dechannels(128)\n",
999 | "# self.d4 = dechannels(128)\n",
1000 | "# self.d5 = dechannels(128)\n",
1001 | "# self.drop = nn.Dropout2d(p=0.5, inplace=True)\n",
1002 | "# self.logit = nn.Conv2d(160, 1, 1)\n",
1003 | "# self.logit = nn.Sequential(\n",
1004 | "# nn.Conv2d(1152, 128, 3, 1, 1),\n",
1005 | "# nn.ReLU(inplace=True),\n",
1006 | "# nn.Conv2d(128, 128, 1)\n",
1007 | "# )\n",
1008 | " \n",
1009 | " def forward(self,img,depth):\n",
1010 | " x = F.relu(self.rn(img))\n",
1011 | " x1 = self.up1(x, self.sfs[3].features)\n",
1012 | " x2 = self.up2(x1, self.sfs[2].features)\n",
1013 | " x3 = self.up3(x2, self.sfs[1].features)\n",
1014 | " x4 = self.up4(x3, self.sfs[0].features)\n",
1015 | " x = self.up5(x4)\n",
1016 | "# x5 = self.up5(x4)\n",
1017 | "# # x = torch.cat((x5,\n",
1018 | "# # F.interpolate(x4,scale_factor=2,align_corners=False,mode='bilinear'),\n",
1019 | "# # F.interpolate(x3,scale_factor=4,align_corners=False,mode='bilinear'),\n",
1020 | "# # F.interpolate(x2,scale_factor=8,align_corners=False,mode='bilinear'),\n",
1021 | "# # F.interpolate(x1,scale_factor=16,align_corners=False,mode='bilinear')),1)\n",
1022 | "# x = torch.cat((self.d5(x5),\n",
1023 | "# F.interpolate(self.d4(x4),scale_factor=2,align_corners=False,mode='bilinear'),\n",
1024 | "# F.interpolate(self.d3(x3),scale_factor=4,align_corners=False,mode='bilinear'),\n",
1025 | "# F.interpolate(self.d2(x2),scale_factor=8,align_corners=False,mode='bilinear'),\n",
1026 | "# F.interpolate(self.d1(x1),scale_factor=16,align_corners=False,mode='bilinear')),1)\n",
1027 | "# x = self.logit(self.drop(x))\n",
1028 | " return x[:,0]\n",
1029 | " \n",
1030 | " def close(self):\n",
1031 | " for sf in self.sfs: sf.remove()"
1032 | ]
1033 | },
1034 | {
1035 | "cell_type": "code",
1036 | "execution_count": 0,
1037 | "metadata": {
1038 | "colab": {},
1039 | "colab_type": "code",
1040 | "id": "6KHDtXQ4WESz"
1041 | },
1042 | "outputs": [],
1043 | "source": [
1044 | "def get_tgs_model():\n",
1045 | "# f = resnet34\n",
1046 | " f = resnext50\n",
1047 | " cut,lr_cut = model_meta[f]\n",
1048 | " m_base = get_base(f,cut)\n",
1049 | "# m = to_gpu(Unet34(m_base))\n",
1050 | " m = to_gpu(Unet50(m_base))\n",
1051 | " models = UnetModel(m,lr_cut)\n",
1052 | " learn = ConvLearner(md, models)\n",
1053 | " return learn\n",
1054 | "\n",
1055 | "def get_base(f,cut):\n",
1056 | " layers = cut_model(f(True), cut)\n",
1057 | " return nn.Sequential(*layers)"
1058 | ]
1059 | },
1060 | {
1061 | "cell_type": "code",
1062 | "execution_count": 0,
1063 | "metadata": {
1064 | "colab": {
1065 | "base_uri": "https://localhost:8080/",
1066 | "height": 34
1067 | },
1068 | "colab_type": "code",
1069 | "id": "lWcwr6hJFFHU",
1070 | "outputId": "ddca2693-ffc7-4b8b-a8e9-967e76e8de35"
1071 | },
1072 | "outputs": [
1073 | {
1074 | "data": {
1075 | "text/plain": [
1076 | "800"
1077 | ]
1078 | },
1079 | "execution_count": 27,
1080 | "metadata": {
1081 | "tags": []
1082 | },
1083 | "output_type": "execute_result"
1084 | }
1085 | ],
1086 | "source": [
1087 | "gc.collect()"
1088 | ]
1089 | },
1090 | {
1091 | "cell_type": "markdown",
1092 | "metadata": {
1093 | "colab_type": "text",
1094 | "id": "PIwVeVA9xv04"
1095 | },
1096 | "source": [
1097 | "### Training"
1098 | ]
1099 | },
1100 | {
1101 | "cell_type": "code",
1102 | "execution_count": 0,
1103 | "metadata": {
1104 | "colab": {},
1105 | "colab_type": "code",
1106 | "id": "lZcBehwhkNb7"
1107 | },
1108 | "outputs": [],
1109 | "source": [
1110 | "load_model_weights(subdir='tags/score_8.41')\n",
1111 | "# load_model_weights(subdir='20181016')\n",
1112 | "# load_model_weights(subdir='20181018')"
1113 | ]
1114 | },
1115 | {
1116 | "cell_type": "code",
1117 | "execution_count": 0,
1118 | "metadata": {
1119 | "colab": {},
1120 | "colab_type": "code",
1121 | "id": "srDUaJW0fhTn"
1122 | },
1123 | "outputs": [],
1124 | "source": [
1125 | "def lovasz_hinge_flat(logits, labels):\n",
1126 | " if len(labels) == 0:\n",
1127 | " # only void pixels, the gradients should be 0\n",
1128 | " return logits.sum() * 0.\n",
1129 | " signs = 2. * labels.float() - 1.\n",
1130 | " errors = (1. - logits * Variable(signs.data))\n",
1131 | " errors_sorted, perm = torch.sort(errors, dim=0, descending=True)\n",
1132 | " perm = perm.data\n",
1133 | " gt_sorted = labels[perm]\n",
1134 | " grad = lovasz_grad(gt_sorted)\n",
1135 | "    loss = torch.dot(F.elu(errors_sorted) + 1, Variable(grad.data))  # ELU+1: smooth variant of relu(errors) in the original Lovasz hinge\n",
1136 | " return loss\n",
1137 | "\n",
1138 | " \n",
1139 | "def lovasz_hinge(logits, labels, per_image=True, ignore=None):\n",
1140 | " if per_image:\n",
1141 | " loss = mean(lovasz_hinge_flat(*flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore))\n",
1142 | " for log, lab in zip(logits, labels))\n",
1143 | " else:\n",
1144 | " loss = lovasz_hinge_flat(*flatten_binary_scores(logits, labels, ignore))\n",
1145 | " return loss"
1146 | ]
1147 | },
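{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"`lovasz_hinge_flat` calls `lovasz_grad`, and `lovasz_hinge` calls `flatten_binary_scores` and `mean`; none of these are defined in this cell. For reference, here is a NumPy sketch of `lovasz_grad` following the formula in Berman et al.'s public Lovasz-Softmax code as I understand it, plus a hypothetical `lovasz_hinge_flat_np` that mirrors the cell above, including its `F.elu(errors_sorted) + 1` hinge (the original Lovasz hinge uses `relu(errors_sorted)`). This is an illustration, not necessarily the exact helper this notebook imports.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def lovasz_grad(gt_sorted):\n",
"    # gt_sorted: 0/1 ground-truth labels ordered by decreasing error.\n",
"    # Returns the gradient of the Lovasz extension of the Jaccard loss.\n",
"    gt_sorted = np.asarray(gt_sorted, dtype=np.float64)\n",
"    gts = gt_sorted.sum()\n",
"    intersection = gts - np.cumsum(gt_sorted)\n",
"    union = gts + np.cumsum(1.0 - gt_sorted)\n",
"    jaccard = 1.0 - intersection / union\n",
"    if len(gt_sorted) > 1:\n",
"        # Discrete derivative: successive differences of the Jaccard values\n",
"        jaccard[1:] = jaccard[1:] - jaccard[:-1]\n",
"    return jaccard\n",
"\n",
"def lovasz_hinge_flat_np(logits, labels):\n",
"    # Hypothetical NumPy mirror of lovasz_hinge_flat, keeping its ELU+1 hinge.\n",
"    logits = np.asarray(logits, dtype=np.float64)\n",
"    labels = np.asarray(labels, dtype=np.float64)\n",
"    signs = 2.0 * labels - 1.0\n",
"    errors = 1.0 - logits * signs\n",
"    perm = np.argsort(-errors, kind='stable')  # errors in descending order\n",
"    errors_sorted = errors[perm]\n",
"    # elu(e) + 1 equals e + 1 for e > 0 and exp(e) otherwise\n",
"    hinge = np.where(errors_sorted > 0, errors_sorted + 1.0,\n",
"                     np.exp(np.minimum(errors_sorted, 0.0)))\n",
"    return float(np.dot(hinge, lovasz_grad(labels[perm])))\n",
"\n",
"print(round(lovasz_hinge_flat_np([5.0, -5.0, 5.0, -5.0], [1, 0, 1, 0]), 4))  # 0.0183\n",
"```\n",
"\n",
"The gradient vector sums to 1 for any non-empty input (the differences telescope to the last Jaccard value, which is 1), so for uniformly confident, correct predictions the loss reduces to the shared ELU+1 value `exp(-4)`: small, but unlike the ReLU hinge not exactly zero."
]
},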
1148 | {
1149 | "cell_type": "code",
1150 | "execution_count": 0,
1151 | "metadata": {
1152 | "colab": {},
1153 | "colab_type": "code",
1154 | "id": "haFc4sz-PIk3"
1155 | },
1156 | "outputs": [],
1157 | "source": [
1158 | "sz = 128\n",
1159 | "# sz = 192\n",
1160 | "\n",
1161 | "aug = Compose([\n",
1162 | " OneOf([RandomSizedCrop(min_max_height=(50, 101), height=sz, width=sz, p=0.5),\n",
1163 | " Resize(sz, sz, p=0.5)], p=1),\n",
1164 | " OneOf([VerticalFlip(p=0.5), RandomRotate90(p=0.5), HorizontalFlip(p=0.5)], p=0.5),\n",
1165 | "# OneOf([\n",
1166 | "# ElasticTransform(p=0.5, alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03),\n",
1167 | "# GridDistortion(p=0.5),\n",
1168 | "# OpticalDistortion(p=1, distort_limit=2, shift_limit=0.5) \n",
1169 | "# ], p=0.8),\n",
1170 | "# RandomContrast(limit=0.1, p=0.5),\n",
1171 | " RandomBrightness(limit=0.1, p=0.5),\n",
1172 | "# RandomGamma(p=0.5)\n",
1173 | "])"
1174 | ]
1175 | },
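{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"`RandomSizedCrop(min_max_height=(50, 101), height=sz, width=sz)` cuts a crop whose height is drawn between 50 and 101 pixels and, with albumentations' default `w2h_ratio=1.0`, an equal width, then resizes it to `sz` x `sz`. Since the TGS images are 101x101, a quick calculation shows what that range means in area and in upscaling (`crop_area_fraction` is a hypothetical helper).\n",
"\n",
"```python\n",
"def crop_area_fraction(side, full=101):\n",
"    # Share of a full x full image covered by a side x side square crop.\n",
"    return (side * side) / (full * full)\n",
"\n",
"sz = 128\n",
"lo, hi = crop_area_fraction(50), crop_area_fraction(101)\n",
"print(round(lo, 3), hi)  # 0.245 1.0\n",
"for side in (50, 101):\n",
"    # Linear upscaling factor applied when the crop is resized to sz x sz\n",
"    print(side, round(sz / side, 2))  # 50 2.56, then 101 1.27\n",
"```\n",
"\n",
"So the crop keeps roughly 25% to 100% of the image area (50% to 100% of the side length), and the smallest crops are enlarged about 2.6 times."
]
},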
1176 | {
1177 | "cell_type": "code",
1178 | "execution_count": 0,
1179 | "metadata": {
1180 | "colab": {},
1181 | "colab_type": "code",
1182 | "id": "AZB93ZfeO1Kz"
1183 | },
1184 | "outputs": [],
1185 | "source": [
1186 | "class Transforms():\n",
1187 | " def __init__(self, sz, tfms, normalizer, denorm, crop_type=CropType.CENTER,\n",
1188 | " tfm_y=TfmType.NO, sz_y=None, is_trn=True):\n",
1189 | " if sz_y is None: sz_y = sz\n",
1190 | " self.sz,self.denorm,self.norm,self.sz_y = sz,denorm,normalizer,sz_y\n",
1191 | " crop_tfm = crop_fn_lu[crop_type](sz, tfm_y, sz_y)\n",
1192 | " self.tfms = tfms\n",
1193 | " self.tfms.append(crop_tfm)\n",
1194 | " if normalizer is not None: self.tfms.append(normalizer)\n",
1195 | " self.tfms.append(ChannelOrder(tfm_y))\n",
1196 | " self.channel_order = ChannelOrder(tfm_y)\n",
1197 | "        self.trn_tfms = aug if is_trn else None  # global albumentations pipeline, training only\n",
1198 | "\n",
1199 | "# def __call__(self, im, y=None): return compose(im, y, self.trn_tfms, self.tfms, self.norm)\n",
1200 | " def __call__(self, im, y=None): return self.compose(im, y)\n",
1201 | " def __repr__(self): return str(self.trn_tfms) if self.trn_tfms else str(self.tfms)\n",
1202 | " \n",
1203 | " def compose(self, im, y):\n",
1204 | " if self.trn_tfms:\n",
1205 | " augmented = self.trn_tfms(image=im, mask=y)\n",
1206 | " im, y = self.norm(augmented['image'], augmented['mask'])\n",
1207 | " return self.channel_order(im, y)\n",
1208 | " else:\n",
1209 | " for fn in self.tfms:\n",
1210 | " im, y =fn(im, y)\n",
1211 | " return im if y is None else (im, y)\n",
1212 | "\n",
1213 | "\n",
1214 | "def image_gen(normalizer, denorm, sz, tfms=None, max_zoom=None, pad=0, crop_type=None,\n",
1215 | " tfm_y=None, sz_y=None, pad_mode=cv2.BORDER_REFLECT, scale=None, is_trn=True):\n",
1216 | " if tfm_y is None: tfm_y=TfmType.NO\n",
1217 | " if tfms is None: tfms=[]\n",
1218 | "    elif not isinstance(tfms, collections.abc.Iterable): tfms=[tfms]\n",
1219 | " if sz_y is None: sz_y = sz\n",
1220 | " if scale is None:\n",
1221 | " scale = [RandomScale(sz, max_zoom, tfm_y=tfm_y, sz_y=sz_y) if max_zoom is not None\n",
1222 | " else Scale(sz, tfm_y, sz_y=sz_y)]\n",
1223 | " elif not is_listy(scale): scale = [scale]\n",
1224 | " if pad: scale.append(AddPadding(pad, mode=pad_mode))\n",
1225 | " if crop_type!=CropType.GOOGLENET: tfms=scale+tfms\n",
1226 | " return Transforms(sz, tfms, normalizer, denorm, crop_type,\n",
1227 | " tfm_y=tfm_y, sz_y=sz_y, is_trn=is_trn)\n",
1228 | " \n",
1229 | " \n",
1230 | "def tfms_from_stats(stats, sz, aug_tfms=None, max_zoom=None, pad=0, crop_type=CropType.RANDOM,\n",
1231 | " tfm_y=None, sz_y=None, pad_mode=cv2.BORDER_REFLECT, norm_y=True, scale=None):\n",
1232 | " if aug_tfms is None: aug_tfms=[]\n",
1233 | " tfm_norm = Normalize(*stats, tfm_y=tfm_y if norm_y else TfmType.NO) if stats is not None else None\n",
1234 | " tfm_denorm = Denormalize(*stats) if stats is not None else None\n",
1235 | " val_crop = CropType.CENTER if crop_type in (CropType.RANDOM,CropType.GOOGLENET) else crop_type\n",
1236 | " val_tfm = image_gen(tfm_norm, tfm_denorm, sz, pad=pad, crop_type=val_crop,\n",
1237 | " tfm_y=tfm_y, sz_y=sz_y, scale=scale, is_trn=False)\n",
1238 | " trn_tfm = image_gen(tfm_norm, tfm_denorm, sz, pad=pad, crop_type=crop_type,\n",
1239 | " tfm_y=tfm_y, sz_y=sz_y, tfms=aug_tfms, max_zoom=max_zoom, pad_mode=pad_mode, scale=scale, is_trn=True)\n",
1240 | " return trn_tfm, val_tfm\n",
1241 | " \n",
1242 | " \n",
1243 | "def tfms_from_model(f_model, sz, aug_tfms=None, max_zoom=None, pad=0, crop_type=CropType.RANDOM,\n",
1244 | " tfm_y=None, sz_y=None, pad_mode=cv2.BORDER_REFLECT, norm_y=True, scale=None):\n",
1245 | " stats = inception_stats if f_model in inception_models else imagenet_stats\n",
1246 | " return tfms_from_stats(stats, sz, aug_tfms, max_zoom=max_zoom, pad=pad, crop_type=crop_type,\n",
1247 | " tfm_y=tfm_y, sz_y=sz_y, pad_mode=pad_mode, norm_y=norm_y, scale=scale)"
1248 | ]
1249 | },
1250 | {
1251 | "cell_type": "markdown",
1252 | "metadata": {
1253 | "colab_type": "text",
1254 | "id": "b53Co8tFUCZh"
1255 | },
1256 | "source": [
1257 | "### Training: resnext50 + scSE + hypercolumns"
1258 | ]
1259 | },
1260 | {
1261 | "cell_type": "code",
1262 | "execution_count": 0,
1263 | "metadata": {
1264 | "colab": {
1265 | "base_uri": "https://localhost:8080/",
1266 | "height": 1442
1267 | },
1268 | "colab_type": "code",
1269 | "id": "vaZO9Kz44xYu",
1270 | "outputId": "98d1552c-a4a1-41b4-85b9-cfac455fb289"
1271 | },
1272 | "outputs": [
1273 | {
1274 | "name": "stdout",
1275 | "output_type": "stream",
1276 | "text": [
1277 | "[128, 64]\n",
1278 | "fold_id0\n",
1279 | "5_fold_resnext50_128_scse_hcol_0\n"
1280 | ]
1281 | },
1282 | {
1283 | "data": {
1284 | "application/vnd.jupyter.widget-view+json": {
1285 | "model_id": "b582b36f65e04eb380f8bfa3d7a05a74",
1286 | "version_major": 2,
1287 | "version_minor": 0
1288 | },
1289 | "text/plain": [
1290 | "HBox(children=(IntProgress(value=0, description='Epoch', max=10), HTML(value='')))"
1291 | ]
1292 | },
1293 | "metadata": {
1294 | "tags": []
1295 | },
1296 | "output_type": "display_data"
1297 | },
1298 | {
1299 | "name": "stdout",
1300 | "output_type": "stream",
1301 | "text": [
1302 | "epoch trn_loss val_loss my_eval \n",
1303 | " 0 0.296024 0.229948 0.6635 \n",
1304 | " 1 0.226501 0.167776 0.7175 \n",
1305 | " 2 0.197844 0.174305 0.70425 \n",
1306 | " 3 0.178601 0.157858 0.734 \n",
1307 | " 4 0.159284 0.141243 0.758375 \n",
1308 | " 5 0.156893 0.144844 0.7555 \n",
1309 | " 6 0.139704 0.14005 0.76825 \n",
1310 | " 7 0.129483 0.135873 0.782125 \n",
1311 | " 8 0.117466 0.133637 0.77575 \n",
1312 | " 9 0.107742 0.148499 0.7775 \n",
1313 | "\n",
1314 | "0.7821250000000001\n",
1315 | "fold_id1\n",
1316 | "5_fold_resnext50_128_scse_hcol_1\n"
1317 | ]
1318 | },
1319 | {
1320 | "data": {
1321 | "application/vnd.jupyter.widget-view+json": {
1322 | "model_id": "e95165aba22348e1a9e374735a41960f",
1323 | "version_major": 2,
1324 | "version_minor": 0
1325 | },
1326 | "text/plain": [
1327 | "HBox(children=(IntProgress(value=0, description='Epoch', max=10), HTML(value='')))"
1328 | ]
1329 | },
1330 | "metadata": {
1331 | "tags": []
1332 | },
1333 | "output_type": "display_data"
1334 | },
1335 | {
1336 | "name": "stdout",
1337 | "output_type": "stream",
1338 | "text": [
1339 | "epoch trn_loss val_loss my_eval \n",
1340 | " 0 0.298795 0.330592 0.577375 \n",
1341 | " 1 0.232322 0.154958 0.684125 \n",
1342 | " 2 0.202833 0.152708 0.708 \n",
1343 | " 3 0.182856 0.165284 0.688875 \n",
1344 | " 4 0.160326 0.142061 0.73875 \n",
1345 | " 5 0.149245 0.133409 0.737375 \n",
1346 | " 6 0.142726 0.129582 0.747375 \n",
1347 | " 7 0.136257 0.128909 0.754375 \n",
1348 | " 8 0.121961 0.14172 0.746 \n",
1349 | " 9 0.111444 0.138969 0.755375 \n",
1350 | "\n",
1351 | "0.7553750000000001\n",
1352 | "fold_id2\n",
1353 | "5_fold_resnext50_128_scse_hcol_2\n"
1354 | ]
1355 | },
1356 | {
1357 | "data": {
1358 | "application/vnd.jupyter.widget-view+json": {
1359 | "model_id": "64c3fc40887d48a6923fa8811195da65",
1360 | "version_major": 2,
1361 | "version_minor": 0
1362 | },
1363 | "text/plain": [
1364 | "HBox(children=(IntProgress(value=0, description='Epoch', max=10), HTML(value='')))"
1365 | ]
1366 | },
1367 | "metadata": {
1368 | "tags": []
1369 | },
1370 | "output_type": "display_data"
1371 | },
1372 | {
1373 | "name": "stdout",
1374 | "output_type": "stream",
1375 | "text": [
1376 | "epoch trn_loss val_loss my_eval \n",
1377 | " 0 0.302976 0.230204 0.6545 \n",
1378 | " 1 0.236865 0.176783 0.70275 \n",
1379 | " 2 0.200898 0.158097 0.706625 \n",
1380 | " 3 0.182162 0.147802 0.752 \n",
1381 | " 4 0.1738 0.146175 0.73025 \n",
1382 | " 5 0.157935 0.135695 0.758875 \n",
1383 | " 6 0.148242 0.143257 0.76125 \n",
1384 | " 7 0.135907 0.130015 0.768625 \n",
1385 | " 8 0.125547 0.12942 0.761 \n",
1386 | " 9 0.114384 0.135982 0.76275 \n",
1387 | "\n",
1388 | "0.7686250000000001\n",
1389 | "fold_id3\n",
1390 | "5_fold_resnext50_128_scse_hcol_3\n"
1391 | ]
1392 | },
1393 | {
1394 | "data": {
1395 | "application/vnd.jupyter.widget-view+json": {
1396 | "model_id": "3a2960908675441aa81622bf7f1373c0",
1397 | "version_major": 2,
1398 | "version_minor": 0
1399 | },
1400 | "text/plain": [
1401 | "HBox(children=(IntProgress(value=0, description='Epoch', max=10), HTML(value='')))"
1402 | ]
1403 | },
1404 | "metadata": {
1405 | "tags": []
1406 | },
1407 | "output_type": "display_data"
1408 | },
1409 | {
1410 | "name": "stdout",
1411 | "output_type": "stream",
1412 | "text": [
1413 | "epoch trn_loss val_loss my_eval \n",
1414 | " 0 0.298731 0.179404 0.649875 \n",
1415 | " 1 0.236641 0.150297 0.6965 \n",
1416 | " 2 0.208887 0.130067 0.725375 \n",
1417 | " 3 0.187371 0.116859 0.7545 \n",
1418 | " 4 0.170393 0.127697 0.749 \n",
1419 | " 5 0.160323 0.117837 0.75725 \n",
1420 | " 6 0.142183 0.105238 0.78125 \n",
1421 | " 7 0.135057 0.109958 0.776875 \n",
1422 | " 8 0.120477 0.110117 0.794 \n",
1423 | " 9 0.113762 0.101085 0.7865 \n",
1424 | "\n",
1425 | "0.794\n",
1426 | "fold_id4\n",
1427 | "5_fold_resnext50_128_scse_hcol_4\n"
1428 | ]
1429 | },
1430 | {
1431 | "data": {
1432 | "application/vnd.jupyter.widget-view+json": {
1433 | "model_id": "3c65c3a78192409ea3cee73d3ad370ba",
1434 | "version_major": 2,
1435 | "version_minor": 0
1436 | },
1437 | "text/plain": [
1438 | "HBox(children=(IntProgress(value=0, description='Epoch', max=10), HTML(value='')))"
1439 | ]
1440 | },
1441 | "metadata": {
1442 | "tags": []
1443 | },
1444 | "output_type": "display_data"
1445 | },
1446 | {
1447 | "name": "stdout",
1448 | "output_type": "stream",
1449 | "text": [
1450 | "epoch trn_loss val_loss my_eval \n",
1451 | " 0 0.302227 0.204683 0.65575 \n",
1452 | " 1 0.229877 0.183582 0.71175 \n",
1453 | " 2 0.199228 0.171225 0.68475 \n",
1454 | " 3 0.192234 0.1548 0.734625 \n",
1455 | " 4 0.167416 0.140409 0.74525 \n",
1456 | " 5 0.149521 0.126696 0.743125 \n",
1457 | " 6 0.136639 0.126258 0.774875 \n",
1458 | " 7 0.125201 0.119398 0.774875 \n",
1459 | " 8 0.120503 0.138388 0.762875 \n",
1460 | " 9 0.111816 0.117212 0.782875 \n",
1461 | "\n",
1462 | "0.782875\n"
1463 | ]
1464 | }
1465 | ],
1466 | "source": [
1467 | "model = 'resnext50_128_scse_hcol'\n",
1468 | "arch = resnext50\n",
1469 | "bst_acc=[]\n",
1470 | "use_clr_min=20\n",
1471 | "use_clr_div=10\n",
1472 | "aug_tfms = [RandomFlip(tfm_y=TfmType.CLASS)]\n",
1473 | "\n",
1474 | "szs = [(128,64)]\n",
1475 | "for sz,bs in szs:\n",
1476 | " print([sz,bs])\n",
1477 | " for i in range(kf) :\n",
1478 | " print(f'fold_id{i}')\n",
1479 | " \n",
1480 | " trn_x = np.array([f'{TRN_DN}/{o}' for o in train_folds[i]])\n",
1481 | " trn_y = np.array([f'{MASK_DN}/{o}' for o in train_folds[i]])\n",
1482 | " val_x = [f'{TRN_DN}/{o}' for o in val_folds[i]]\n",
1483 | " val_y = [f'{MASK_DN}/{o}' for o in val_folds[i]]\n",
1484 | " \n",
1485 | " tfms = tfms_from_model(arch, sz=sz, pad=0,crop_type=CropType.NO, tfm_y=TfmType.CLASS, aug_tfms=aug_tfms)\n",
1486 | " datasets = ImageData.get_ds(MatchedFilesDataset, (trn_x,trn_y), (val_x,val_y), tfms,test=test_x,path=PATH)\n",
1487 | " md = ImageData(PATH, datasets, bs, num_workers=4, classes=None)\n",
1488 | " denorm = md.trn_ds.denorm\n",
1489 | " md.trn_dl.dataset = DepthDataset(md.trn_ds,dpth_dict)\n",
1490 | " md.val_dl.dataset = DepthDataset(md.val_ds,dpth_dict)\n",
1491 | " md.test_dl.dataset = DepthDataset(md.test_ds,dpth_dict)\n",
1492 | " learn = get_tgs_model() \n",
1493 | " learn.opt_fn = optim.Adam\n",
1494 | " learn.metrics=[my_eval]\n",
1495 | " pa = f'{kf}_fold_{model}_{i}'\n",
1496 | " print(pa)\n",
1497 | "\n",
1498 | " lr=1e-2\n",
1499 | " wd=1e-7\n",
1500 | " lrs = np.array([lr/100,lr/10,lr])\n",
1501 | "\n",
1502 | " learn.unfreeze() \n",
1503 | " learn.crit = nn.BCEWithLogitsLoss()\n",
1504 | " if os.path.exists(pa):\n",
1505 | " learn.load(pa)\n",
1506 | " learn.fit(lrs/2,1, wds=wd, cycle_len=10,use_clr=(20,8),best_save_name=pa)\n",
1507 | " \n",
1508 | " learn.load(pa) \n",
1509 | "        # Calculating mean IoU score\n",
1510 | " v_targ = md.val_ds.ds[:][1]\n",
1511 | " v_preds = np.zeros((len(v_targ),sz,sz)) \n",
1512 | " v_pred = learn.predict()\n",
1513 | " v_pred = to_np(torch.sigmoid(torch.from_numpy(v_pred)))\n",
1514 | " p = ((v_pred)>0.5).astype(np.uint8)\n",
1515 | " bst_acc.append(intersection_over_union_thresholds(v_targ,p))\n",
1516 | " print(bst_acc[-1])\n",
1517 | " \n",
1518 | " del learn\n",
1519 | " del md\n",
1520 | " gc.collect()"
1521 | ]
1522 | },
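{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"The per-fold score printed above comes from `intersection_over_union_thresholds`, which is defined elsewhere in the notebook. As a hedged reference, the sketch below implements the usual reading of the competition metric with each mask treated as a single object: compute the IoU of the predicted and true masks, count a hit at each threshold from 0.50 to 0.95 in steps of 0.05 that the IoU exceeds, score an empty mask predicted as empty as 1, and average over images. `image_score` and `mean_iou_thresholds` are hypothetical names and may differ from the notebook's helper in edge cases.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# Ten IoU thresholds: 0.50, 0.55, ..., 0.95\n",
"THRESHOLDS = np.linspace(0.5, 0.95, 10)\n",
"\n",
"def image_score(truth, pred):\n",
"    # truth, pred: binary masks of the same shape, each treated as one object.\n",
"    truth, pred = truth.astype(bool), pred.astype(bool)\n",
"    if not truth.any() and not pred.any():\n",
"        return 1.0  # empty mask correctly predicted as empty\n",
"    iou = np.logical_and(truth, pred).sum() / np.logical_or(truth, pred).sum()\n",
"    # Fraction of thresholds at which the single object counts as a hit\n",
"    return float(np.mean(iou > THRESHOLDS))\n",
"\n",
"def mean_iou_thresholds(truths, preds):\n",
"    # Average the per-image scores over the whole validation fold\n",
"    return float(np.mean([image_score(t, p) for t, p in zip(truths, preds)]))\n",
"\n",
"t = np.zeros((4, 4), np.uint8)\n",
"t[0, :3] = 1                 # 3 true pixels\n",
"p = np.zeros((4, 4), np.uint8)\n",
"p[0, :2] = 1                 # 2 of them predicted, IoU = 2/3\n",
"print(image_score(t, p))     # 0.4\n",
"```\n",
"\n",
"Here `p` plays the role of one of the notebook's thresholded `(v_pred > 0.5)` masks; an IoU of 2/3 clears the four thresholds 0.50 to 0.65, giving 0.4."
]
},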
1523 | {
1524 | "cell_type": "code",
1525 | "execution_count": 0,
1526 | "metadata": {
1527 | "colab": {
1528 | "base_uri": "https://localhost:8080/",
1529 | "height": 4346
1530 | },
1531 | "colab_type": "code",
1532 | "id": "ROi0HiEA4xeT",
1533 | "outputId": "3edaabab-5f31-4435-ee92-a9ad8d121d57"
1534 | },
1535 | "outputs": [
1536 | {
1537 | "name": "stdout",
1538 | "output_type": "stream",
1539 | "text": [
1540 | "[128, 64]\n",
1541 | "fold_id0\n",
1542 | "5_fold_resnext50_128_scse_hcol_0\n"
1543 | ]
1544 | },
1545 | {
1546 | "data": {
1547 | "application/vnd.jupyter.widget-view+json": {
1548 | "model_id": "cc2e39db82e44742bc11cb86df4f4ac8",
1549 | "version_major": 2,
1550 | "version_minor": 0
1551 | },
1552 | "text/plain": [
1553 | "HBox(children=(IntProgress(value=0, description='Epoch', max=20), HTML(value='')))"
1554 | ]
1555 | },
1556 | "metadata": {
1557 | "tags": []
1558 | },
1559 | "output_type": "display_data"
1560 | },
1561 | {
1562 | "name": "stdout",
1563 | "output_type": "stream",
1564 | "text": [
1565 | "epoch trn_loss val_loss my_eval \n",
1566 | " 0 0.977655 0.9829 0.76025 \n",
1567 | " 1 0.984492 1.024706 0.763625 \n",
1568 | " 2 0.94598 0.968942 0.768125 \n",
1569 | " 3 0.91825 0.867097 0.782625 \n",
1570 | " 4 0.842798 0.835175 0.7835 \n",
1571 | " 5 0.811193 0.819348 0.79875 \n",
1572 | " 6 0.784356 0.763691 0.8005 \n",
1573 | " 7 0.763975 0.773445 0.800375 \n",
1574 | " 8 0.707037 0.741929 0.79875 \n",
1575 | " 9 0.676857 0.763426 0.80925 \n",
1576 | " 10 0.686296 0.84298 0.7955 \n",
1577 | " 11 0.778831 0.926706 0.788625 \n",
1578 | " 12 0.823366 0.929922 0.78175 \n",
1579 | " 13 0.846463 0.840977 0.768375 \n",
1580 | " 14 0.80643 0.766691 0.79525 \n",
1581 | " 15 0.785104 0.772868 0.793375 \n",
1582 | " 16 0.771523 0.761224 0.803625 \n",
1583 | " 17 0.73849 0.711444 0.80725 \n",
1584 | " 18 0.698676 0.748108 0.807125 \n",
1585 | " 19 0.664383 0.732889 0.804625 \n",
1586 | "\n"
1587 | ]
1588 | },
1589 | {
1590 | "data": {
1591 | "application/vnd.jupyter.widget-view+json": {
1592 | "model_id": "38873a4d51c9454a903188c1d9af5995",
1593 | "version_major": 2,
1594 | "version_minor": 0
1595 | },
1596 | "text/plain": [
1597 | "HBox(children=(IntProgress(value=0, description='Epoch', max=20), HTML(value='')))"
1598 | ]
1599 | },
1600 | "metadata": {
1601 | "tags": []
1602 | },
1603 | "output_type": "display_data"
1604 | },
1605 | {
1606 | "name": "stdout",
1607 | "output_type": "stream",
1608 | "text": [
1609 | "epoch trn_loss val_loss my_eval \n",
1610 | " 0 0.642692 0.774102 0.79425 \n",
1611 | " 1 0.670723 0.791728 0.797125 \n",
1612 | " 2 0.701122 0.799743 0.794125 \n",
1613 | " 3 0.695085 0.809558 0.8015 \n",
1614 | " 4 0.709821 0.699042 0.801 \n",
1615 | " 5 0.686103 0.869248 0.78125 \n",
1616 | " 6 0.668086 0.807663 0.79025 \n",
1617 | " 7 0.708722 0.761571 0.786875 \n",
1618 | " 8 0.695549 0.760038 0.809125 \n",
1619 | " 9 0.710776 0.817635 0.794625 \n",
1620 | " 10 0.667159 0.799424 0.79725 \n",
1621 | " 11 0.674622 0.791136 0.797375 \n",
1622 | " 12 0.653431 0.732944 0.809125 \n",
1623 | " 13 0.618604 0.733997 0.803 \n",
1624 | " 14 0.614818 0.781419 0.8115 \n",
1625 | " 15 0.592244 0.753463 0.8135 \n",
1626 | " 16 0.593055 0.76313 0.818875 \n",
1627 | " 17 0.591503 0.800476 0.791875 \n",
1628 | " 18 0.592983 0.749201 0.816125 \n",
1629 | " 19 0.586945 0.764278 0.814375 \n",
1630 | "\n",
1631 | "0.818875\n",
1632 | "fold_id1\n",
1633 | "5_fold_resnext50_128_scse_hcol_1\n"
1634 | ]
1635 | },
1636 | {
1637 | "data": {
1638 | "application/vnd.jupyter.widget-view+json": {
1639 | "model_id": "3d9c315bb12748429c63e7d655c9407d",
1640 | "version_major": 2,
1641 | "version_minor": 0
1642 | },
1643 | "text/plain": [
1644 | "HBox(children=(IntProgress(value=0, description='Epoch', max=20), HTML(value='')))"
1645 | ]
1646 | },
1647 | "metadata": {
1648 | "tags": []
1649 | },
1650 | "output_type": "display_data"
1651 | },
1652 | {
1653 | "name": "stdout",
1654 | "output_type": "stream",
1655 | "text": [
1656 | "epoch trn_loss val_loss my_eval \n",
1657 | " 0 0.871871 1.046633 0.740375 \n",
1658 | " 1 0.953284 1.179364 0.718875 \n",
1659 | " 2 0.927234 0.987735 0.73175 \n",
1660 | " 3 0.891905 0.919984 0.759875 \n",
1661 | " 4 0.824233 0.970662 0.76025 \n",
1662 | " 5 0.779335 0.885995 0.766125 \n",
1663 | " 6 0.750891 0.903214 0.774125 \n",
1664 | " 7 0.726601 0.843466 0.783625 \n",
1665 | " 8 0.684857 0.833532 0.789125 \n",
1666 | " 9 0.668015 0.820718 0.791125 \n",
1667 | " 10 0.638963 0.86629 0.77425 \n",
1668 | " 11 0.764584 1.057496 0.69975 \n",
1669 | " 12 0.819989 1.005252 0.747375 \n",
1670 | " 13 0.824506 0.927726 0.76375 \n",
1671 | " 14 0.803146 0.867914 0.77575 \n",
1672 | " 15 0.80264 0.916781 0.7715 \n",
1673 | " 16 0.771997 0.827193 0.786625 \n",
1674 | " 17 0.727303 0.858733 0.788625 \n",
1675 | " 18 0.686257 0.846492 0.779625 \n",
1676 | " 19 0.661692 0.800186 0.79375 \n",
1677 | "\n"
1678 | ]
1679 | },
1680 | {
1681 | "data": {
1682 | "application/vnd.jupyter.widget-view+json": {
1683 | "model_id": "55f8389c6cd8486ebb04a737a18b32c2",
1684 | "version_major": 2,
1685 | "version_minor": 0
1686 | },
1687 | "text/plain": [
1688 | "HBox(children=(IntProgress(value=0, description='Epoch', max=20), HTML(value='')))"
1689 | ]
1690 | },
1691 | "metadata": {
1692 | "tags": []
1693 | },
1694 | "output_type": "display_data"
1695 | },
1696 | {
1697 | "name": "stdout",
1698 | "output_type": "stream",
1699 | "text": [
1700 | "epoch trn_loss val_loss my_eval \n",
1701 | " 0 0.680326 0.828174 0.7895 \n",
1702 | " 1 0.68143 0.84101 0.775875 \n",
1703 | " 2 0.716512 0.921755 0.780375 \n",
1704 | " 3 0.716995 0.865955 0.780875 \n",
1705 | " 4 0.733966 0.841387 0.754875 \n",
1706 | " 5 0.726975 0.876049 0.77775 \n",
1707 | " 6 0.702523 0.901705 0.77425 \n",
1708 | " 7 0.672928 0.878522 0.7805 \n",
1709 | " 8 0.668419 0.859559 0.7695 \n",
1710 | " 9 0.645093 0.857982 0.788 \n",
1711 | " 10 0.646567 0.948428 0.7555 \n",
1712 | " 11 0.659159 0.892936 0.783125 \n",
1713 | " 12 0.626061 0.893832 0.787125 \n",
1714 | " 13 0.633051 0.822176 0.787 \n",
1715 | " 14 0.60531 0.902848 0.788625 \n",
1716 | " 15 0.598807 0.832705 0.7875 \n",
1717 | " 16 0.571184 0.847312 0.787625 \n",
1718 | " 17 0.563408 0.871148 0.795625 \n",
1719 | " 18 0.571028 0.850332 0.791875 \n",
1720 | " 19 0.538525 0.868956 0.798375 \n",
1721 | "\n",
1722 | "0.7983750000000001\n",
1723 | "fold_id2\n",
1724 | "5_fold_resnext50_128_scse_hcol_2\n"
1725 | ]
1726 | },
1727 | {
1728 | "data": {
1729 | "application/vnd.jupyter.widget-view+json": {
1730 | "model_id": "a4c02576e160455984579f98cf21250c",
1731 | "version_major": 2,
1732 | "version_minor": 0
1733 | },
1734 | "text/plain": [
1735 | "HBox(children=(IntProgress(value=0, description='Epoch', max=20), HTML(value='')))"
1736 | ]
1737 | },
1738 | "metadata": {
1739 | "tags": []
1740 | },
1741 | "output_type": "display_data"
1742 | },
1743 | {
1744 | "name": "stdout",
1745 | "output_type": "stream",
1746 | "text": [
1747 | "epoch trn_loss val_loss my_eval \n",
1748 | " 0 0.930581 1.076116 0.750875 \n",
1749 | " 1 0.977241 0.991857 0.748 \n",
1750 | " 2 0.963887 0.864013 0.77875 \n",
1751 | " 3 0.921879 0.839139 0.7825 \n",
1752 | " 4 0.891747 0.80363 0.787875 \n",
1753 | " 5 0.842367 0.788842 0.7915 \n",
1754 | " 6 0.809913 0.766017 0.80325 \n",
1755 | " 7 0.745459 0.74085 0.81025 \n",
1756 | " 8 0.717501 0.751182 0.80775 \n",
1757 | " 9 0.672494 0.721534 0.81775 \n",
1758 | " 10 0.687882 0.777983 0.802125 \n",
1759 | " 11 0.835816 0.843362 0.768375 \n",
1760 | " 12 0.861769 0.882111 0.774125 \n",
1761 | " 13 0.916062 0.933928 0.7635 \n",
1762 | " 14 0.860758 0.787877 0.783875 \n",
1763 | " 15 0.817398 0.785797 0.791625 \n",
1764 | " 16 0.799109 0.764803 0.792125 \n",
1765 | " 17 0.763636 0.776863 0.801875 \n",
1766 | " 18 0.734554 0.735275 0.809875 \n",
1767 | " 19 0.69847 0.73502 0.808125 \n",
1768 | "\n"
1769 | ]
1770 | },
1771 | {
1772 | "data": {
1773 | "application/vnd.jupyter.widget-view+json": {
1774 | "model_id": "794f42c7ffbb49a4a48d5599479b1550",
1775 | "version_major": 2,
1776 | "version_minor": 0
1777 | },
1778 | "text/plain": [
1779 | "HBox(children=(IntProgress(value=0, description='Epoch', max=20), HTML(value='')))"
1780 | ]
1781 | },
1782 | "metadata": {
1783 | "tags": []
1784 | },
1785 | "output_type": "display_data"
1786 | },
1787 | {
1788 | "name": "stdout",
1789 | "output_type": "stream",
1790 | "text": [
1791 | "epoch trn_loss val_loss my_eval \n",
1792 | " 0 0.671766 0.759295 0.791375 \n",
1793 | " 1 0.731133 0.811743 0.789875 \n",
1794 | " 2 0.763409 0.791452 0.789625 \n",
1795 | " 3 0.774678 0.789636 0.792 \n",
1796 | " 4 0.758107 0.77808 0.772 \n",
1797 | " 5 0.735919 0.742705 0.796625 \n",
1798 | " 6 0.720245 0.771582 0.782125 \n",
1799 | " 7 0.710452 0.743036 0.7825 \n",
1800 | " 8 0.710112 0.777343 0.790875 \n",
1801 | " 9 0.736112 0.743986 0.7875 \n",
1802 | " 10 0.70377 0.773617 0.7935 \n",
1803 | " 11 0.66605 0.770097 0.803 \n",
1804 | " 12 0.662434 0.712938 0.808625 \n",
1805 | " 13 0.652016 0.743833 0.807625 \n",
1806 | " 14 0.649992 0.737581 0.811125 \n",
1807 | " 15 0.638489 0.707261 0.81575 \n",
1808 | " 16 0.596832 0.710357 0.818 \n",
1809 | " 17 0.595883 0.704559 0.804625 \n",
1810 | " 18 0.572071 0.705111 0.82025 \n",
1811 | " 19 0.563588 0.69612 0.82325 \n",
1812 | "\n",
1813 | "0.8232499999999999\n",
1814 | "fold_id3\n",
1815 | "5_fold_resnext50_128_scse_hcol_3\n"
1816 | ]
1817 | },
1818 | {
1819 | "data": {
1820 | "application/vnd.jupyter.widget-view+json": {
1821 | "model_id": "f4961bde0db542448bf37775dbc087d5",
1822 | "version_major": 2,
1823 | "version_minor": 0
1824 | },
1825 | "text/plain": [
1826 | "HBox(children=(IntProgress(value=0, description='Epoch', max=20), HTML(value='')))"
1827 | ]
1828 | },
1829 | "metadata": {
1830 | "tags": []
1831 | },
1832 | "output_type": "display_data"
1833 | },
1834 | {
1835 | "name": "stdout",
1836 | "output_type": "stream",
1837 | "text": [
1838 | "epoch trn_loss val_loss my_eval \n",
1839 | " 0 0.910043 0.931082 0.774375 \n",
1840 | " 1 0.962166 0.841506 0.7835 \n",
1841 | " 2 0.966172 0.884748 0.76575 \n",
1842 | " 3 0.919205 0.797827 0.792 \n",
1843 | " 4 0.887789 0.838587 0.782125 \n",
1844 | " 5 0.833859 0.758443 0.79875 \n",
1845 | " 6 0.787326 0.748243 0.7995 \n",
1846 | " 7 0.768268 0.734664 0.80275 \n",
1847 | " 8 0.721092 0.701125 0.81575 \n",
1848 | " 9 0.678771 0.707262 0.817625 \n",
1849 | " 10 0.686963 0.761993 0.807 \n",
1850 | " 11 0.797779 0.845824 0.78325 \n",
1851 | " 12 0.836203 0.781706 0.79575 \n",
1852 | " 13 0.807488 0.871408 0.7955 \n",
1853 | " 14 0.786293 0.773449 0.801375 \n",
1854 | " 15 0.779008 0.795548 0.800875 \n",
1855 | " 16 0.756668 0.742518 0.80525 \n",
1856 | " 17 0.705088 0.703899 0.81325 \n",
1857 | " 18 0.677209 0.693749 0.819125 \n",
1858 | " 19 0.639317 0.679222 0.82 \n",
1859 | "\n"
1860 | ]
1861 | },
1862 | {
1863 | "data": {
1864 | "application/vnd.jupyter.widget-view+json": {
1865 | "model_id": "46e4957616e84cd499a4f53b689cc927",
1866 | "version_major": 2,
1867 | "version_minor": 0
1868 | },
1869 | "text/plain": [
1870 | "HBox(children=(IntProgress(value=0, description='Epoch', max=20), HTML(value='')))"
1871 | ]
1872 | },
1873 | "metadata": {
1874 | "tags": []
1875 | },
1876 | "output_type": "display_data"
1877 | },
1878 | {
1879 | "name": "stdout",
1880 | "output_type": "stream",
1881 | "text": [
1882 | "epoch trn_loss val_loss my_eval \n",
1883 | " 0 0.653376 0.719673 0.812375 \n",
1884 | " 1 0.701078 0.751476 0.80475 \n",
1885 | " 2 0.728884 0.854194 0.79975 \n",
1886 | " 3 0.752242 0.724502 0.8095 \n",
1887 | " 4 0.704703 0.742548 0.815 \n",
1888 | " 5 0.716587 0.734281 0.811875 \n",
1889 | " 6 0.683265 0.743123 0.806875 \n",
1890 | " 7 0.678836 0.815633 0.793625 \n",
1891 | " 8 0.689473 0.700409 0.81575 \n",
1892 | " 9 0.665838 0.669478 0.82075 \n",
1893 | " 10 0.633243 0.730008 0.816125 \n",
1894 | " 11 0.655839 0.701937 0.816125 \n",
1895 | " 12 0.632317 0.649175 0.82675 \n",
1896 | " 13 0.598759 0.674421 0.83225 \n",
1897 | " 14 0.596253 0.666588 0.829875 \n",
1898 | " 15 0.586747 0.670523 0.827125 \n",
1899 | " 16 0.567371 0.686732 0.824375 \n",
1900 | " 17 0.56206 0.673798 0.8225 \n",
1901 | " 18 0.545179 0.67912 0.82925 \n",
1902 | " 19 0.526347 0.698485 0.827 \n",
1903 | "\n",
1904 | "0.8322499999999999\n",
1905 | "fold_id4\n",
1906 | "5_fold_resnext50_128_scse_hcol_4\n"
1907 | ]
1908 | },
1909 | {
1910 | "data": {
1911 | "application/vnd.jupyter.widget-view+json": {
1912 | "model_id": "161724e6ee244e4c9dc4fc87b9a7c5f4",
1913 | "version_major": 2,
1914 | "version_minor": 0
1915 | },
1916 | "text/plain": [
1917 | "HBox(children=(IntProgress(value=0, description='Epoch', max=20), HTML(value='')))"
1918 | ]
1919 | },
1920 | "metadata": {
1921 | "tags": []
1922 | },
1923 | "output_type": "display_data"
1924 | },
1925 | {
1926 | "name": "stdout",
1927 | "output_type": "stream",
1928 | "text": [
1929 | "epoch trn_loss val_loss my_eval \n",
1930 | " 0 0.92871 0.959824 0.784125 \n",
1931 | " 1 0.949663 0.872785 0.776375 \n",
1932 | " 2 0.909849 0.817117 0.790375 \n",
1933 | " 3 0.896153 0.804488 0.794125 \n",
1934 | " 4 0.83174 0.835692 0.79425 \n",
1935 | " 5 0.814119 0.780774 0.797125 \n",
1936 | " 6 0.755571 0.759111 0.801875 \n",
1937 | " 7 0.729774 0.78065 0.808 \n",
1938 | " 8 0.674973 0.741045 0.813375 \n",
1939 | " 9 0.660971 0.72706 0.8115 \n",
1940 | " 10 0.689892 0.819716 0.786375 \n",
1941 | " 11 0.789879 0.855918 0.789375 \n",
1942 | " 12 0.783312 0.796657 0.79 \n",
1943 | " 13 0.781245 0.913932 0.78375 \n",
1944 | " 14 0.787869 0.830769 0.782625 \n",
1945 | " 15 0.799666 0.75475 0.805375 \n",
1946 | " 16 0.750443 0.712465 0.81 \n",
1947 | " 17 0.708587 0.7156 0.812875 \n",
1948 | " 18 0.680404 0.758386 0.805625 \n",
1949 | " 19 0.647 0.703068 0.813125 \n",
1950 | "\n"
1951 | ]
1952 | },
1953 | {
1954 | "data": {
1955 | "application/vnd.jupyter.widget-view+json": {
1956 | "model_id": "882d2ce3e11448a7897138b5879d9d77",
1957 | "version_major": 2,
1958 | "version_minor": 0
1959 | },
1960 | "text/plain": [
1961 | "HBox(children=(IntProgress(value=0, description='Epoch', max=20), HTML(value='')))"
1962 | ]
1963 | },
1964 | "metadata": {
1965 | "tags": []
1966 | },
1967 | "output_type": "display_data"
1968 | },
1969 | {
1970 | "name": "stdout",
1971 | "output_type": "stream",
1972 | "text": [
1973 | "epoch trn_loss val_loss my_eval \n",
1974 | " 0 0.625973 0.726331 0.820375 \n",
1975 | " 1 0.670526 0.805685 0.80375 \n",
1976 | " 2 0.701503 0.781464 0.788875 \n",
1977 | " 3 0.69349 0.791473 0.801625 \n",
1978 | " 4 0.682512 0.745718 0.81325 \n",
1979 | " 5 0.69381 0.70612 0.806625 \n",
1980 | " 6 0.665137 0.696765 0.816375 \n",
1981 | " 7 0.641031 0.729714 0.8195 \n",
1982 | " 8 0.632783 0.706463 0.80375 \n",
1983 | " 9 0.634154 0.687877 0.81325 \n",
1984 | " 10 0.626251 0.690043 0.809375 \n",
1985 | " 11 0.591347 0.709688 0.815125 \n",
1986 | " 12 0.575432 0.702248 0.816625 \n",
1987 | " 13 0.587349 0.694433 0.81625 \n",
1988 | " 14 0.573425 0.694701 0.820375 \n",
1989 | " 15 0.584016 0.66213 0.8265 \n",
1990 | " 16 0.547527 0.703967 0.824625 \n",
1991 | " 17 0.548586 0.674334 0.82825 \n",
1992 | " 18 0.534016 0.680121 0.829125 \n",
1993 | " 19 0.507791 0.67429 0.823375 \n",
1994 | "\n",
1995 | "0.8291249999999999\n"
1996 | ]
1997 | }
1998 | ],
1999 | "source": [
2000 | "model = 'resnext50_128_scse_hcol'\n",
2001 | "arch = resnext50\n",
2002 | "bst_acc=[]\n",
2003 | "use_clr_min=20\n",
2004 | "use_clr_div=10\n",
2005 | "aug_tfms = [RandomFlip(tfm_y=TfmType.CLASS)]\n",
2006 | "\n",
2007 | "szs = [(128,64)]\n",
2008 | "for sz,bs in szs:\n",
2009 | "    print([sz,bs])\n",
2010 | "    for i in range(kf):\n",
2011 | "        print(f'fold_id{i}')\n",
2012 | "        \n",
2013 | "        trn_x = np.array([f'{TRN_DN}/{o}' for o in train_folds[i]])\n",
2014 | "        trn_y = np.array([f'{MASK_DN}/{o}' for o in train_folds[i]])\n",
2015 | "        val_x = [f'{TRN_DN}/{o}' for o in val_folds[i]]\n",
2016 | "        val_y = [f'{MASK_DN}/{o}' for o in val_folds[i]]\n",
2017 | "        \n",
2018 | "        tfms = tfms_from_model(arch, sz=sz, pad=0,crop_type=CropType.NO, tfm_y=TfmType.CLASS, aug_tfms=aug_tfms)\n",
2019 | "        datasets = ImageData.get_ds(MatchedFilesDataset, (trn_x,trn_y), (val_x,val_y), tfms,test=test_x,path=PATH)\n",
2020 | "        md = ImageData(PATH, datasets, bs, num_workers=16, classes=None)\n",
2021 | "        denorm = md.trn_ds.denorm\n",
2022 | "        md.trn_dl.dataset = DepthDataset(md.trn_ds,dpth_dict)\n",
2023 | "        md.val_dl.dataset = DepthDataset(md.val_ds,dpth_dict)\n",
2024 | "        md.test_dl.dataset = DepthDataset(md.test_ds,dpth_dict)\n",
2025 | "        learn = get_tgs_model()\n",
2026 | "        learn.opt_fn = optim.Adam\n",
2027 | "        learn.metrics=[my_eval]\n",
2028 | "        pa = f'{kf}_fold_{model}_{i}'  # per-fold checkpoint name\n",
2029 | "        print(pa)\n",
2030 | "\n",
2031 | "        lr=1e-2\n",
2032 | "        wd=1e-7\n",
2033 | "        lrs = np.array([lr/100,lr/10,lr])  # one learning rate per layer group\n",
2034 | "\n",
2035 | "        learn.unfreeze()  # train every layer group, including the pretrained encoder\n",
2036 | "        learn.crit = lovasz_hinge\n",
2037 | "        learn.load(pa)  # resume from this fold's saved checkpoint\n",
2038 | "        learn.fit(lrs/2,2, wds=wd, cycle_len=10,use_clr=(20,8),best_save_name=pa)  # 2 cycles x 10 epochs\n",
2039 | "        learn.fit(lrs/3,1, wds=wd, cycle_len=20,use_clr=(20,16),best_save_name=pa)  # 1 cycle x 20 epochs\n",
2040 | "        \n",
2041 | "        learn.load(pa)  # reload the best checkpoint saved via best_save_name\n",
2042 | "        # Calculate the mean IoU score on the validation fold\n",
2043 | "        v_targ = md.val_ds.ds[:][1]  # validation masks\n",
2044 | "        v_preds = np.zeros((len(v_targ),sz,sz))\n",
2045 | "        v_pred = learn.predict()  # raw logits for the validation set\n",
2046 | "        v_pred = to_np(torch.sigmoid(torch.from_numpy(v_pred)))  # logits -> probabilities\n",
2047 | "        p = ((v_pred)>0.5).astype(np.uint8)  # binarize at 0.5\n",
2048 | "        bst_acc.append(intersection_over_union_thresholds(v_targ,p))\n",
2049 | "        print(bst_acc[-1])\n",
2050 | "        \n",
2051 | "        del learn\n",
2052 | "        del md\n",
2053 | "        gc.collect()  # free memory before the next fold\n",
2054 | "\n",
2055 | "save_model_weights()"
2056 | ]
2057 | },
2058 | {
2059 | "cell_type": "code",
2060 | "execution_count": 0,
2061 | "metadata": {
2062 | "colab": {
2063 | "base_uri": "https://localhost:8080/",
2064 | "height": 3898
2065 | },
2066 | "colab_type": "code",
2067 | "id": "mFTMgQkX6VtI",
2068 | "outputId": "9c990f72-dbe6-4ece-96f6-97e445e91f25"
2069 | },
2070 | "outputs": [
2071 | {
2072 | "name": "stdout",
2073 | "output_type": "stream",
2074 | "text": [
2075 | "[128, 64]\n",
2076 | "fold_id0\n",
2077 | "5_fold_resnext50_128_scse_hcol_0\n"
2078 | ]
2079 | },
2080 | {
2081 | "data": {
2082 | "application/vnd.jupyter.widget-view+json": {
2083 | "model_id": "2c64641c0ef945a182d60ab806fcfdc1",
2084 | "version_major": 2,
2085 | "version_minor": 0
2086 | },
2087 | "text/plain": [
2088 | "HBox(children=(IntProgress(value=0, description='Epoch', max=40), HTML(value='')))"
2089 | ]
2090 | },
2091 | "metadata": {
2092 | "tags": []
2093 | },
2094 | "output_type": "display_data"
2095 | },
2096 | {
2097 | "name": "stdout",
2098 | "output_type": "stream",
2099 | "text": [
2100 | "epoch trn_loss val_loss my_eval \n",
2101 | " 0 0.580299 0.837661 0.80625 \n",
2102 | " 1 0.628745 0.803514 0.798375 \n",
2103 | " 2 0.619702 0.796008 0.80575 \n",
2104 | " 3 0.628366 0.750468 0.81175 \n",
2105 | " 4 0.623308 0.748861 0.802 \n",
2106 | " 5 0.650031 0.795647 0.803 \n",
2107 | " 6 0.646449 0.763365 0.81125 \n",
2108 | " 7 0.636074 0.781185 0.811375 \n",
2109 | " 8 0.623231 0.742726 0.809875 \n",
2110 | " 9 0.602579 0.753169 0.813375 \n",
2111 | " 10 0.585192 0.776042 0.817375 \n",
2112 | " 11 0.576949 0.790247 0.8175 \n",
2113 | " 12 0.562295 0.796752 0.813875 \n",
2114 | " 13 0.56244 0.795689 0.788875 \n",
2115 | " 14 0.575644 0.806945 0.8145 \n",
2116 | " 15 0.58317 0.765245 0.81525 \n",
2117 | " 16 0.560825 0.782946 0.805875 \n",
2118 | " 17 0.562333 0.797728 0.79975 \n",
2119 | " 18 0.562816 0.761856 0.816375 \n",
2120 | " 19 0.553854 0.775206 0.8135 \n",
2121 | " 20 0.546581 0.714756 0.82125 \n",
2122 | " 21 0.541287 0.770479 0.8095 \n",
2123 | " 22 0.518532 0.805129 0.813625 \n",
2124 | " 23 0.521006 0.732549 0.806625 \n",
2125 | " 24 0.495842 0.853394 0.811375 \n",
2126 | " 25 0.493223 0.838361 0.81125 \n",
2127 | " 26 0.482846 0.831322 0.814375 \n",
2128 | " 27 0.483749 0.788495 0.820125 \n",
2129 | " 28 0.496672 0.792871 0.82125 \n",
2130 | " 29 0.497796 0.762518 0.823875 \n",
2131 | " 30 0.49775 0.79859 0.81975 \n",
2132 | " 31 0.491748 0.779934 0.822375 \n",
2133 | " 32 0.475641 0.790661 0.816625 \n",
2134 | " 33 0.4794 0.78795 0.82175 \n",
2135 | " 34 0.475329 0.764116 0.81825 \n",
2136 | " 35 0.455563 0.822148 0.815 \n",
2137 | " 36 0.451707 0.795527 0.821375 \n",
2138 | " 37 0.444007 0.772276 0.8155 \n",
2139 | " 38 0.435822 0.779683 0.817375 \n",
2140 | " 39 0.445946 0.780952 0.814 \n",
2141 | "\n",
2142 | "0.823875\n",
2143 | "fold_id1\n",
2144 | "5_fold_resnext50_128_scse_hcol_1\n"
2145 | ]
2146 | },
2147 | {
2148 | "data": {
2149 | "application/vnd.jupyter.widget-view+json": {
2150 | "model_id": "c2faf773816a409e9d909ded6b7f2313",
2151 | "version_major": 2,
2152 | "version_minor": 0
2153 | },
2154 | "text/plain": [
2155 | "HBox(children=(IntProgress(value=0, description='Epoch', max=40), HTML(value='')))"
2156 | ]
2157 | },
2158 | "metadata": {
2159 | "tags": []
2160 | },
2161 | "output_type": "display_data"
2162 | },
2163 | {
2164 | "name": "stdout",
2165 | "output_type": "stream",
2166 | "text": [
2167 | "epoch trn_loss val_loss my_eval \n",
2168 | " 0 0.565228 0.871326 0.785875 \n",
2169 | " 1 0.601723 0.882529 0.79425 \n",
2170 | " 2 0.642919 0.897614 0.78075 \n",
2171 | " 3 0.637186 0.878214 0.777875 \n",
2172 | " 4 0.609784 0.895229 0.79075 \n",
2173 | " 5 0.609371 0.926091 0.7865 \n",
2174 | " 6 0.601144 0.836737 0.788125 \n",
2175 | " 7 0.630633 0.871732 0.7835 \n",
2176 | " 8 0.611182 0.890666 0.7895 \n",
2177 | " 9 0.592919 0.831647 0.80075 \n",
2178 | " 10 0.567077 0.936684 0.786875 \n",
2179 | " 11 0.553323 0.966901 0.78875 \n",
2180 | " 12 0.559024 0.944055 0.77 \n",
2181 | " 13 0.590418 0.850013 0.796 \n",
2182 | " 14 0.573935 0.856166 0.7945 \n",
2183 | " 15 0.567616 0.860534 0.79075 \n",
2184 | " 16 0.569173 0.960924 0.781 \n",
2185 | " 17 0.558296 0.893143 0.791 \n",
2186 | " 18 0.542129 0.934519 0.794125 \n",
2187 | " 19 0.547132 0.855531 0.80075 \n",
2188 | " 20 0.53702 0.854191 0.796875 \n",
2189 | " 21 0.538296 0.912877 0.787625 \n",
2190 | " 22 0.526649 0.917397 0.800375 \n",
2191 | " 23 0.502509 0.896211 0.7885 \n",
2192 | " 24 0.521145 0.898083 0.796375 \n",
2193 | " 25 0.491927 0.923525 0.802625 \n",
2194 | " 26 0.501542 0.864817 0.794 \n",
2195 | " 27 0.494032 0.884823 0.796375 \n",
2196 | " 28 0.490451 0.899924 0.795875 \n",
2197 | " 29 0.481898 0.880166 0.784 \n",
2198 | " 30 0.472968 0.895765 0.797375 \n",
2199 | " 31 0.462385 0.878445 0.793625 \n",
2200 | " 32 0.455645 0.896859 0.795 \n",
2201 | " 33 0.469805 0.894238 0.7995 \n",
2202 | " 34 0.455159 0.913979 0.801375 \n",
2203 | " 35 0.457995 0.898335 0.8015 \n",
2204 | " 36 0.453348 0.900837 0.804875 \n",
2205 | " 37 0.446832 0.902242 0.793 \n",
2206 | " 38 0.453402 0.907094 0.801375 \n",
2207 | " 39 0.434806 0.921633 0.80875 \n",
2208 | "\n",
2209 | "0.80875\n",
2210 | "fold_id2\n",
2211 | "5_fold_resnext50_128_scse_hcol_2\n"
2212 | ]
2213 | },
2214 | {
2215 | "data": {
2216 | "application/vnd.jupyter.widget-view+json": {
2217 | "model_id": "2914de385caf4b7d879001cf1ca135af",
2218 | "version_major": 2,
2219 | "version_minor": 0
2220 | },
2221 | "text/plain": [
2222 | "HBox(children=(IntProgress(value=0, description='Epoch', max=40), HTML(value='')))"
2223 | ]
2224 | },
2225 | "metadata": {
2226 | "tags": []
2227 | },
2228 | "output_type": "display_data"
2229 | },
2230 | {
2231 | "name": "stdout",
2232 | "output_type": "stream",
2233 | "text": [
2234 | "epoch trn_loss val_loss my_eval \n",
2235 | " 0 0.577075 0.786943 0.80825 \n",
2236 | " 1 0.609053 0.768417 0.804375 \n",
2237 | " 2 0.656391 0.73859 0.812125 \n",
2238 | " 3 0.652841 0.757531 0.797 \n",
2239 | " 4 0.655581 0.736351 0.809875 \n",
2240 | " 5 0.623059 0.69868 0.812375 \n",
2241 | " 6 0.627167 0.77681 0.7915 \n",
2242 | " 7 0.623288 0.750636 0.802125 \n",
2243 | " 8 0.620909 0.76361 0.807125 \n",
2244 | " 9 0.628333 0.755546 0.793625 \n",
2245 | " 10 0.63122 0.744164 0.79475 \n",
2246 | " 11 0.63828 0.720994 0.8095 \n",
2247 | " 12 0.607952 0.712268 0.804125 \n",
2248 | " 13 0.588277 0.729939 0.80725 \n",
2249 | " 14 0.585946 0.705454 0.8255 \n",
2250 | " 15 0.58386 0.691177 0.81425 \n",
2251 | " 16 0.561116 0.713686 0.812625 \n",
2252 | " 17 0.553486 0.771892 0.818875 \n",
2253 | " 18 0.553116 0.707916 0.817875 \n",
2254 | " 19 0.565598 0.728426 0.815625 \n",
2255 | " 20 0.554347 0.766951 0.819625 \n",
2256 | " 21 0.536056 0.723741 0.820125 \n",
2257 | " 22 0.536135 0.724242 0.821625 \n",
2258 | " 23 0.554331 0.767052 0.8075 \n",
2259 | " 24 0.550846 0.718899 0.82025 \n",
2260 | " 25 0.539729 0.703877 0.81175 \n",
2261 | " 26 0.517849 0.694621 0.820125 \n",
2262 | " 27 0.504886 0.708674 0.821625 \n",
2263 | " 28 0.502828 0.726042 0.82175 \n",
2264 | " 29 0.498927 0.701111 0.81025 \n",
2265 | " 30 0.495155 0.715683 0.8185 \n",
2266 | " 31 0.507423 0.730976 0.82175 \n",
2267 | " 32 0.51031 0.709125 0.814125 \n",
2268 | " 33 0.504614 0.7094 0.81775 \n",
2269 | " 34 0.500099 0.694097 0.825375 \n",
2270 | " 35 0.482972 0.689517 0.824625 \n",
2271 | " 36 0.484635 0.679306 0.82225 \n",
2272 | " 37 0.481951 0.681534 0.828 \n",
2273 | " 38 0.475027 0.686403 0.825625 \n",
2274 | " 39 0.483564 0.695852 0.82175 \n",
2275 | "\n",
2276 | "0.828\n",
2277 | "fold_id3\n",
2278 | "5_fold_resnext50_128_scse_hcol_3\n"
2279 | ]
2280 | },
2281 | {
2282 | "data": {
2283 | "application/vnd.jupyter.widget-view+json": {
2284 | "model_id": "019ac79c6d2e4e56ac325d5d3dd9a97a",
2285 | "version_major": 2,
2286 | "version_minor": 0
2287 | },
2288 | "text/plain": [
2289 | "HBox(children=(IntProgress(value=0, description='Epoch', max=40), HTML(value='')))"
2290 | ]
2291 | },
2292 | "metadata": {
2293 | "tags": []
2294 | },
2295 | "output_type": "display_data"
2296 | },
2297 | {
2298 | "name": "stdout",
2299 | "output_type": "stream",
2300 | "text": [
2301 | "epoch trn_loss val_loss my_eval \n",
2302 | " 0 0.575313 0.695703 0.823625 \n",
2303 | " 1 0.624102 0.780891 0.813625 \n",
2304 | " 2 0.648184 0.697116 0.81825 \n",
2305 | " 3 0.651478 0.729888 0.80325 \n",
2306 | " 4 0.633732 0.689429 0.816875 \n",
2307 | " 5 0.6391 0.707463 0.821875 \n",
2308 | " 6 0.608808 0.734975 0.821 \n",
2309 | " 7 0.583592 0.727842 0.8085 \n",
2310 | " 8 0.607585 0.754664 0.806125 \n",
2311 | " 9 0.602211 0.745517 0.82925 \n",
2312 | " 10 0.58275 0.690118 0.826 \n",
2313 | " 11 0.577038 0.690472 0.816875 \n",
2314 | " 12 0.585021 0.703931 0.820875 \n",
2315 | " 13 0.573149 0.737506 0.822875 \n",
2316 | " 14 0.593045 0.701834 0.823125 \n",
2317 | " 15 0.576374 0.722945 0.8195 \n",
2318 | " 16 0.578671 0.718012 0.824125 \n",
2319 | " 17 0.576396 0.737476 0.825 \n",
2320 | " 18 0.562349 0.740334 0.82075 \n",
2321 | " 19 0.562596 0.703988 0.83375 \n",
2322 | " 20 0.545302 0.742709 0.824875 \n",
2323 | " 21 0.536242 0.730253 0.821375 \n",
2324 | " 22 0.538357 0.726912 0.834375 \n",
2325 | " 23 0.524991 0.705785 0.82975 \n",
2326 | " 24 0.539796 0.718887 0.821625 \n",
2327 | " 25 0.528741 0.74735 0.825125 \n",
2328 | " 26 0.52838 0.697965 0.833875 \n",
2329 | " 27 0.515824 0.80159 0.82975 \n",
2330 | " 28 0.519962 0.729874 0.813625 \n",
2331 | " 29 0.506576 0.729921 0.8205 \n",
2332 | " 30 0.488176 0.70798 0.823 \n",
2333 | " 31 0.487228 0.747718 0.824625 \n",
2334 | " 32 0.477002 0.720598 0.829625 \n",
2335 | " 33 0.468923 0.748624 0.833625 \n",
2336 | " 34 0.487961 0.757478 0.833125 \n",
2337 | " 35 0.470121 0.742304 0.813875 \n",
2338 | " 36 0.48037 0.745087 0.831875 \n",
2339 | " 37 0.476722 0.772283 0.8305 \n",
2340 | " 38 0.456916 0.735131 0.8195 \n",
2341 | " 39 0.461469 0.748882 0.8215 \n",
2342 | "\n",
2343 | "0.834375\n",
2344 | "fold_id4\n",
2345 | "5_fold_resnext50_128_scse_hcol_4\n"
2346 | ]
2347 | },
2348 | {
2349 | "data": {
2350 | "application/vnd.jupyter.widget-view+json": {
2351 | "model_id": "d6269952b5694a12a836f74645dd9bbe",
2352 | "version_major": 2,
2353 | "version_minor": 0
2354 | },
2355 | "text/plain": [
2356 | "HBox(children=(IntProgress(value=0, description='Epoch', max=40), HTML(value='')))"
2357 | ]
2358 | },
2359 | "metadata": {
2360 | "tags": []
2361 | },
2362 | "output_type": "display_data"
2363 | },
2364 | {
2365 | "name": "stdout",
2366 | "output_type": "stream",
2367 | "text": [
2368 | "epoch trn_loss val_loss my_eval \n",
2369 | " 0 0.546006 0.717454 0.8235 \n",
2370 | " 1 0.563938 0.751986 0.812375 \n",
2371 | " 2 0.598719 0.75113 0.815 \n",
2372 | " 3 0.588652 0.798206 0.822 \n",
2373 | " 4 0.599703 0.805366 0.8095 \n",
2374 | " 5 0.591598 0.734947 0.82075 \n",
2375 | " 6 0.596523 0.733508 0.811375 \n",
2376 | " 7 0.597403 0.742025 0.816125 \n",
2377 | " 8 0.588319 0.729897 0.80425 \n",
2378 | " 9 0.57466 0.743337 0.808375 \n",
2379 | " 10 0.590704 0.786838 0.8205 \n",
2380 | " 11 0.557619 0.687391 0.81425 \n",
2381 | " 12 0.539545 0.704461 0.82875 \n",
2382 | " 13 0.543599 0.810272 0.820625 \n",
2383 | " 14 0.549854 0.70434 0.819 \n",
2384 | " 15 0.541127 0.726442 0.820625 \n",
2385 | " 16 0.547006 0.710707 0.805625 \n",
2386 | " 17 0.556314 0.704044 0.825 \n",
2387 | " 18 0.533182 0.687741 0.81375 \n",
2388 | " 19 0.511729 0.7148 0.834 \n",
2389 | " 20 0.509799 0.690751 0.8375 \n",
2390 | " 21 0.496054 0.709561 0.83325 \n",
2391 | " 22 0.499565 0.745653 0.8235 \n",
2392 | " 23 0.506426 0.723755 0.824375 \n",
2393 | " 24 0.487424 0.73846 0.828125 \n",
2394 | " 25 0.485281 0.725011 0.822 \n",
2395 | " 26 0.466737 0.791306 0.827125 \n",
2396 | " 27 0.46954 0.759082 0.827625 \n",
2397 | " 28 0.479826 0.683398 0.8315 \n",
2398 | " 29 0.463535 0.724489 0.831 \n",
2399 | " 30 0.460917 0.712501 0.831 \n",
2400 | " 31 0.441873 0.723946 0.829 \n",
2401 | " 32 0.435788 0.724188 0.82975 \n",
2402 | " 33 0.438317 0.718923 0.821375 \n",
2403 | " 34 0.442253 0.724379 0.82475 \n",
2404 | " 35 0.425489 0.734597 0.824125 \n",
2405 | " 36 0.44013 0.728877 0.820125 \n",
2406 | " 37 0.443834 0.744366 0.828 \n",
2407 | " 38 0.426803 0.729525 0.827875 \n",
2408 | " 39 0.435825 0.735564 0.822 \n",
2409 | "\n",
2410 | "0.8375\n"
2411 | ]
2412 | }
2413 | ],
2414 | "source": [
2415 | "model = 'resnext50_128_scse_hcol'\n",
2416 | "arch = resnext50\n",
2417 | "bst_acc=[]\n",
2418 | "use_clr_min=20\n",
2419 | "use_clr_div=10\n",
2420 | "aug_tfms = [RandomFlip(tfm_y=TfmType.CLASS)]\n",
2421 | "\n",
2422 | "szs = [(128,64)]\n",
2423 | "for sz,bs in szs:\n",
2424 | "    print([sz,bs])\n",
2425 | "    for i in range(kf):\n",
2426 | "        print(f'fold_id{i}')\n",
2427 | "        \n",
2428 | "        trn_x = np.array([f'{TRN_DN}/{o}' for o in train_folds[i]])\n",
2429 | "        trn_y = np.array([f'{MASK_DN}/{o}' for o in train_folds[i]])\n",
2430 | "        val_x = [f'{TRN_DN}/{o}' for o in val_folds[i]]\n",
2431 | "        val_y = [f'{MASK_DN}/{o}' for o in val_folds[i]]\n",
2432 | "        \n",
2433 | "        tfms = tfms_from_model(arch, sz=sz, pad=0,crop_type=CropType.NO, tfm_y=TfmType.CLASS, aug_tfms=aug_tfms)\n",
2434 | "        datasets = ImageData.get_ds(MatchedFilesDataset, (trn_x,trn_y), (val_x,val_y), tfms,test=test_x,path=PATH)\n",
2435 | "        md = ImageData(PATH, datasets, bs, num_workers=16, classes=None)\n",
2436 | "        denorm = md.trn_ds.denorm\n",
2437 | "        md.trn_dl.dataset = DepthDataset(md.trn_ds,dpth_dict)\n",
2438 | "        md.val_dl.dataset = DepthDataset(md.val_ds,dpth_dict)\n",
2439 | "        md.test_dl.dataset = DepthDataset(md.test_ds,dpth_dict)\n",
2440 | "        learn = get_tgs_model()\n",
2441 | "        learn.opt_fn = optim.Adam\n",
2442 | "        learn.metrics=[my_eval]\n",
2443 | "        pa = f'{kf}_fold_{model}_{i}'  # per-fold checkpoint name\n",
2444 | "        print(pa)\n",
2445 | "\n",
2446 | "        lr=1e-2\n",
2447 | "        wd=1e-7\n",
2448 | "        lrs = np.array([lr/100,lr/10,lr])  # one learning rate per layer group\n",
2449 | "\n",
2450 | "        learn.unfreeze()  # train every layer group, including the pretrained encoder\n",
2451 | "        learn.crit = lovasz_hinge\n",
2452 | "        learn.load(pa)  # resume from this fold's saved checkpoint\n",
2453 | "        # learn.fit(lrs/4,1, wds=wd, cycle_len=20,use_clr=(20,16),best_save_name=pa)\n",
2454 | "        learn.fit(lrs/4,1, wds=wd, cycle_len=40,use_clr=(20,32),best_save_name=pa)  # 1 cycle x 40 epochs\n",
2455 | "        \n",
2456 | "        learn.load(pa)  # reload the best checkpoint saved via best_save_name\n",
2457 | "        # Calculate the mean IoU score on the validation fold\n",
2458 | "        v_targ = md.val_ds.ds[:][1]  # validation masks\n",
2459 | "        v_preds = np.zeros((len(v_targ),sz,sz))\n",
2460 | "        v_pred = learn.predict()  # raw logits for the validation set\n",
2461 | "        v_pred = to_np(torch.sigmoid(torch.from_numpy(v_pred)))  # logits -> probabilities\n",
2462 | "        p = ((v_pred)>0.5).astype(np.uint8)  # binarize at 0.5\n",
2463 | "        bst_acc.append(intersection_over_union_thresholds(v_targ,p))\n",
2464 | "        print(bst_acc[-1])\n",
2465 | "        \n",
2466 | "        del learn\n",
2467 | "        del md\n",
2468 | "        gc.collect()  # free memory before the next fold\n",
2469 | "\n",
2470 | "save_model_weights()"
2471 | ]
2472 | },
2473 | {
2474 | "cell_type": "code",
2475 | "execution_count": 0,
2476 | "metadata": {
2477 | "colab": {
2478 | "base_uri": "https://localhost:8080/",
2479 | "height": 1996
2480 | },
2481 | "colab_type": "code",
2482 | "id": "H7YZnN_HiiuR",
2483 | "outputId": "ee40d675-4739-4c54-f4dd-a0a55119aad3"
2484 | },
2485 | "outputs": [
2486 | {
2487 | "name": "stdout",
2488 | "output_type": "stream",
2489 | "text": [
2490 | "[128, 64]\n",
2491 | "fold_id0\n",
2492 | "5_fold_resnext50_128_scse_hcol_0\n"
2493 | ]
2494 | },
2495 | {
2496 | "data": {
2497 | "application/vnd.jupyter.widget-view+json": {
2498 | "model_id": "56e79297e25847f0a15aff07aa418769",
2499 | "version_major": 2,
2500 | "version_minor": 0
2501 | },
2502 | "text/plain": [
2503 | "HBox(children=(IntProgress(value=0, description='Epoch', max=20), HTML(value='')))"
2504 | ]
2505 | },
2506 | "metadata": {
2507 | "tags": []
2508 | },
2509 | "output_type": "display_data"
2510 | },
2511 | {
2512 | "name": "stdout",
2513 | "output_type": "stream",
2514 | "text": [
2515 | "epoch trn_loss val_loss my_eval \n",
2516 | " 0 0.464316 0.82814 0.816375 \n",
2517 | " 1 0.503966 0.973039 0.808125 \n",
2518 | " 2 0.524512 0.875324 0.802125 \n",
2519 | " 3 0.552585 0.802356 0.8155 \n",
2520 | " 4 0.524045 0.898251 0.8165 \n",
2521 | " 5 0.519004 0.884085 0.815 \n",
2522 | " 6 0.512701 0.761209 0.816625 \n",
2523 | " 7 0.520394 0.781913 0.80575 \n",
2524 | " 8 0.501048 0.748156 0.815875 \n",
2525 | " 9 0.504658 0.800892 0.808875 \n",
2526 | " 10 0.491141 0.764409 0.813375 \n",
2527 | " 11 0.480288 0.845715 0.823 \n",
2528 | " 12 0.482994 0.864056 0.817125 \n",
2529 | " 13 0.474893 0.844127 0.818 \n",
2530 | " 14 0.475831 0.82016 0.822375 \n",
2531 | " 15 0.462766 0.80165 0.816 \n",
2532 | " 16 0.459729 0.801812 0.822625 \n",
2533 | " 17 0.452857 0.825043 0.821125 \n",
2534 | " 18 0.432667 0.798102 0.819125 \n",
2535 | " 19 0.448399 0.799831 0.82425 \n",
2536 | "\n",
2537 | "0.8242499999999999\n",
2538 | "fold_id1\n",
2539 | "5_fold_resnext50_128_scse_hcol_1\n"
2540 | ]
2541 | },
2542 | {
2543 | "data": {
2544 | "application/vnd.jupyter.widget-view+json": {
2545 | "model_id": "8aea5543371447dbafac57bda9d368dd",
2546 | "version_major": 2,
2547 | "version_minor": 0
2548 | },
2549 | "text/plain": [
2550 | "HBox(children=(IntProgress(value=0, description='Epoch', max=20), HTML(value='')))"
2551 | ]
2552 | },
2553 | "metadata": {
2554 | "tags": []
2555 | },
2556 | "output_type": "display_data"
2557 | },
2558 | {
2559 | "name": "stdout",
2560 | "output_type": "stream",
2561 | "text": [
2562 | "epoch trn_loss val_loss my_eval \n",
2563 | " 0 0.460871 0.923814 0.807875 \n",
2564 | " 1 0.463133 0.902192 0.804 \n",
2565 | " 2 0.487417 0.895001 0.79675 \n",
2566 | " 3 0.49072 0.918243 0.7825 \n",
2567 | " 4 0.51323 0.852118 0.80025 \n",
2568 | " 5 0.496024 0.872492 0.789875 \n",
2569 | " 6 0.506847 0.856714 0.802125 \n",
2570 | " 7 0.506169 0.917013 0.785 \n",
2571 | " 8 0.480601 0.872236 0.7935 \n",
2572 | " 9 0.478031 0.859447 0.79525 \n",
2573 | " 10 0.465379 0.893305 0.79625 \n",
2574 | " 11 0.455413 0.920751 0.802125 \n",
2575 | " 12 0.450254 0.921361 0.797125 \n",
2576 | " 13 0.437148 0.914842 0.802 \n",
2577 | " 14 0.445062 0.855681 0.797125 \n",
2578 | " 15 0.438345 0.867074 0.79825 \n",
2579 | " 16 0.445464 0.852298 0.80075 \n",
2580 | " 17 0.439749 0.874568 0.79525 \n",
2581 | " 18 0.434257 0.861378 0.809375 \n",
2582 | " 19 0.41873 0.868665 0.806125 \n",
2583 | "\n",
2584 | "0.809375\n",
2585 | "fold_id2\n",
2586 | "5_fold_resnext50_128_scse_hcol_2\n"
2587 | ]
2588 | },
2605 | {
2606 | "name": "stdout",
2607 | "output_type": "stream",
2608 | "text": [
2609 | "epoch trn_loss val_loss my_eval \n",
2610 | " 0 0.46378 0.7404 0.81825 \n",
2611 | " 1 0.523088 0.696791 0.81175 \n",
2612 | " 2 0.561582 0.741509 0.806625 \n",
2613 | " 3 0.559633 0.743049 0.79175 \n"
2615 | ]
2616 | }
2634 | ],
2635 | "source": [
2636 | "model = 'resnext50_128_scse_hcol'\n",
2637 | "arch = resnext50\n",
2638 | "bst_acc=[]\n",
2639 | "use_clr_min=20\n",
2640 | "use_clr_div=10\n",
2641 | "aug_tfms = [RandomFlip(tfm_y=TfmType.CLASS)]\n",
2642 | "\n",
2643 | "szs = [(128,64)]\n",
2644 | "for sz,bs in szs:\n",
2645 | "    print([sz,bs])\n",
2646 | "    for i in range(kf):\n",
2647 | "        print(f'fold_id{i}')\n",
2648 | "\n",
2649 | "        trn_x = np.array([f'{TRN_DN}/{o}' for o in train_folds[i]])\n",
2650 | "        trn_y = np.array([f'{MASK_DN}/{o}' for o in train_folds[i]])\n",
2651 | "        val_x = [f'{TRN_DN}/{o}' for o in val_folds[i]]\n",
2652 | "        val_y = [f'{MASK_DN}/{o}' for o in val_folds[i]]\n",
2653 | "\n",
2654 | "        tfms = tfms_from_model(arch, sz=sz, pad=0,crop_type=CropType.NO, tfm_y=TfmType.CLASS, aug_tfms=aug_tfms)\n",
2655 | "        datasets = ImageData.get_ds(MatchedFilesDataset, (trn_x,trn_y), (val_x,val_y), tfms,test=test_x,path=PATH)\n",
2656 | "        md = ImageData(PATH, datasets, bs, num_workers=16, classes=None)\n",
2657 | "        denorm = md.trn_ds.denorm\n",
2658 | "        md.trn_dl.dataset = DepthDataset(md.trn_ds,dpth_dict)\n",
2659 | "        md.val_dl.dataset = DepthDataset(md.val_ds,dpth_dict)\n",
2660 | "        md.test_dl.dataset = DepthDataset(md.test_ds,dpth_dict)\n",
2661 | "        learn = get_tgs_model()\n",
2662 | "        learn.opt_fn = optim.Adam\n",
2663 | "        learn.metrics=[my_eval]\n",
2664 | "        pa = f'{kf}_fold_{model}_{i}'\n",
2665 | "        print(pa)\n",
2666 | "\n",
2667 | "        lr=1e-2\n",
2668 | "        wd=1e-7\n",
2669 | "        lrs = np.array([lr/100,lr/10,lr])\n",
2670 | "\n",
2671 | "        learn.unfreeze()\n",
2672 | "        learn.crit = lovasz_hinge\n",
2673 | "        learn.load(pa)\n",
2674 | "        # learn.fit(lrs/4,1, wds=wd, cycle_len=20,use_clr=(20,16),best_save_name=pa)\n",
2675 | "        learn.fit(lrs/5,1, wds=wd, cycle_len=20,use_clr=(20,16),best_save_name=pa)\n",
2676 | "\n",
2677 | "        learn.load(pa)\n",
2678 | "        # Calculate the mean IoU score on the validation fold\n",
2679 | "        v_targ = md.val_ds.ds[:][1]\n",
2681 | "        v_pred = learn.predict()\n",
2682 | "        v_pred = to_np(torch.sigmoid(torch.from_numpy(v_pred)))\n",
2683 | "        p = ((v_pred)>0.5).astype(np.uint8)\n",
2684 | "        bst_acc.append(intersection_over_union_thresholds(v_targ,p))\n",
2685 | "        print(bst_acc[-1])\n",
2686 | "\n",
2687 | "        del learn\n",
2688 | "        del md\n",
2689 | "        gc.collect()\n",
2690 | "\n",
2691 | "# save_model_weights()"
2692 | ]
2693 | },
2694 | {
2695 | "cell_type": "code",
2696 | "execution_count": 0,
2697 | "metadata": {
2698 | "colab": {},
2699 | "colab_type": "code",
2700 | "id": "R2qbSFsNYJqO"
2701 | },
2702 | "outputs": [],
2703 | "source": [
2704 | "save_model_weights()"
2705 | ]
2706 | },
2707 | {
2708 | "cell_type": "code",
2709 | "execution_count": 0,
2710 | "metadata": {
2711 | "colab": {
2712 | "base_uri": "https://localhost:8080/",
2713 | "height": 34
2714 | },
2715 | "colab_type": "code",
2716 | "id": "TagiWKJ16V2z",
2717 | "outputId": "4c898c92-63b4-4bb0-af97-306f6113ac9a"
2718 | },
2719 | "outputs": [
2720 | {
2721 | "data": {
2722 | "text/plain": [
2723 | "([0.823875, 0.80875, 0.828, 0.834375, 0.8375], 0.8265)"
2724 | ]
2725 | },
2726 | "execution_count": 31,
2727 | "metadata": {
2728 | "tags": []
2729 | },
2730 | "output_type": "execute_result"
2731 | }
2732 | ],
2733 | "source": [
2734 | "bst_acc, np.mean(bst_acc)  # 128x128 input size"
2735 | ]
2736 | },
2737 | {
2738 | "cell_type": "code",
2739 | "execution_count": 0,
2740 | "metadata": {
2741 | "colab": {},
2742 | "colab_type": "code",
2743 | "id": "FswF_Yw14xk2"
2744 | },
2745 | "outputs": [],
2746 | "source": [
2747 | "gc.collect()"
2748 | ]
2749 | },
2750 | {
2751 | "cell_type": "markdown",
2752 | "metadata": {
2753 | "colab_type": "text",
2754 | "id": "YxJSW5-QTk8b"
2755 | },
2756 | "source": [
2757 | "### Training: resnext50 + scSE"
2758 | ]
2759 | },
2760 | {
2761 | "cell_type": "code",
2762 | "execution_count": 0,
2763 | "metadata": {
2764 | "colab": {},
2765 | "colab_type": "code",
2766 | "id": "jXY7FrEcrLHj"
2767 | },
2768 | "outputs": [],
2769 | "source": [
2770 | "model = 'resnext50_128_scse'\n",
2771 | "arch = resnext50\n",
2772 | "bst_acc=[]\n",
2773 | "use_clr_min=20\n",
2774 | "use_clr_div=10\n",
2775 | "aug_tfms = [RandomFlip(tfm_y=TfmType.CLASS)]\n",
2776 | "\n",
2777 | "szs = [(128,64)]\n",
2778 | "for sz,bs in szs:\n",
2779 | "    print([sz,bs])\n",
2780 | "    for i in range(kf):\n",
2781 | "        print(f'fold_id{i}')\n",
2782 | "\n",
2783 | "        trn_x = np.array([f'{TRN_DN}/{o}' for o in train_folds[i]])\n",
2784 | "        trn_y = np.array([f'{MASK_DN}/{o}' for o in train_folds[i]])\n",
2785 | "        val_x = [f'{TRN_DN}/{o}' for o in val_folds[i]]\n",
2786 | "        val_y = [f'{MASK_DN}/{o}' for o in val_folds[i]]\n",
2787 | "\n",
2788 | "        tfms = tfms_from_model(arch, sz=sz, pad=0,crop_type=CropType.NO, tfm_y=TfmType.CLASS, aug_tfms=aug_tfms)\n",
2789 | "        datasets = ImageData.get_ds(MatchedFilesDataset, (trn_x,trn_y), (val_x,val_y), tfms,test=test_x,path=PATH)\n",
2790 | "        md = ImageData(PATH, datasets, bs, num_workers=4, classes=None)\n",
2791 | "        denorm = md.trn_ds.denorm\n",
2792 | "        md.trn_dl.dataset = DepthDataset(md.trn_ds,dpth_dict)\n",
2793 | "        md.val_dl.dataset = DepthDataset(md.val_ds,dpth_dict)\n",
2794 | "        md.test_dl.dataset = DepthDataset(md.test_ds,dpth_dict)\n",
2795 | "        learn = get_tgs_model()\n",
2796 | "        learn.opt_fn = optim.Adam\n",
2797 | "        learn.metrics=[my_eval]\n",
2798 | "        pa = f'{kf}_fold_{model}_{i}'\n",
2799 | "        print(pa)\n",
2800 | "\n",
2801 | "        lr=1e-2\n",
2802 | "        wd=1e-7\n",
2803 | "        lrs = np.array([lr/100,lr/10,lr])\n",
2804 | "\n",
2805 | "        learn.unfreeze()\n",
2806 | "        learn.crit = nn.BCEWithLogitsLoss()\n",
2807 | "        if os.path.exists(learn.get_model_path(pa)):  # weights are saved to get_model_path(pa), not to pa\n",
2808 | "            learn.load(pa)\n",
2809 | "        learn.fit(lrs/2,1, wds=wd, cycle_len=10,use_clr=(20,8),best_save_name=pa)\n",
2810 | "\n",
2811 | "        learn.load(pa)\n",
2812 | "        # Calculate the mean IoU score on the validation fold\n",
2813 | "        v_targ = md.val_ds.ds[:][1]\n",
2815 | "        v_pred = learn.predict()\n",
2816 | "        v_pred = to_np(torch.sigmoid(torch.from_numpy(v_pred)))\n",
2817 | "        p = ((v_pred)>0.5).astype(np.uint8)\n",
2818 | "        bst_acc.append(intersection_over_union_thresholds(v_targ,p))\n",
2819 | "        print(bst_acc[-1])"
2820 | ]
2821 | },
2822 | {
2823 | "cell_type": "code",
2824 | "execution_count": 0,
2825 | "metadata": {
2826 | "colab": {},
2827 | "colab_type": "code",
2828 | "id": "6mFbdBCuWbgc"
2829 | },
2830 | "outputs": [],
2831 | "source": [
2832 | "model = 'resnext50_128_scse'\n",
2833 | "arch = resnext50  # must match the backbone named in model, or learn.load(pa) fails\n",
2834 | "bst_acc=[]\n",
2835 | "use_clr_min=20\n",
2836 | "use_clr_div=10\n",
2837 | "aug_tfms = [RandomFlip(tfm_y=TfmType.CLASS)]\n",
2838 | "\n",
2839 | "szs = [(128,64)]\n",
2840 | "for sz,bs in szs:\n",
2841 | "    print([sz,bs])\n",
2842 | "    for i in range(kf):\n",
2843 | "        print(f'fold_id{i}')\n",
2844 | "\n",
2845 | "        trn_x = np.array([f'{TRN_DN}/{o}' for o in train_folds[i]])\n",
2846 | "        trn_y = np.array([f'{MASK_DN}/{o}' for o in train_folds[i]])\n",
2847 | "        val_x = [f'{TRN_DN}/{o}' for o in val_folds[i]]\n",
2848 | "        val_y = [f'{MASK_DN}/{o}' for o in val_folds[i]]\n",
2849 | "\n",
2850 | "        tfms = tfms_from_model(arch, sz=sz, pad=0,crop_type=CropType.NO, tfm_y=TfmType.CLASS, aug_tfms=aug_tfms)\n",
2851 | "        datasets = ImageData.get_ds(MatchedFilesDataset, (trn_x,trn_y), (val_x,val_y), tfms,test=test_x,path=PATH)\n",
2852 | "        md = ImageData(PATH, datasets, bs, num_workers=16, classes=None)\n",
2853 | "        denorm = md.trn_ds.denorm\n",
2854 | "        md.trn_dl.dataset = DepthDataset(md.trn_ds,dpth_dict)\n",
2855 | "        md.val_dl.dataset = DepthDataset(md.val_ds,dpth_dict)\n",
2856 | "        md.test_dl.dataset = DepthDataset(md.test_ds,dpth_dict)\n",
2857 | "        learn = get_tgs_model()\n",
2858 | "        learn.opt_fn = optim.Adam\n",
2859 | "        learn.metrics=[my_eval]\n",
2860 | "        pa = f'{kf}_fold_{model}_{i}'\n",
2861 | "        print(pa)\n",
2862 | "\n",
2863 | "        lr=1e-2\n",
2864 | "        wd=1e-7\n",
2865 | "        lrs = np.array([lr/100,lr/10,lr])\n",
2866 | "\n",
2867 | "        learn.unfreeze()\n",
2868 | "        learn.crit = lovasz_hinge\n",
2869 | "        learn.load(pa)\n",
2870 | "        # learn.fit(lrs/3,2, wds=wd, cycle_len=10,use_clr=(20,8),best_save_name=pa)\n",
2871 | "        learn.fit(lrs/3,3, wds=wd, cycle_len=10,use_clr=(20,8),best_save_name=pa)\n",
2872 | "\n",
2873 | "        learn.load(pa)\n",
2874 | "        # Calculate the mean IoU score on the validation fold\n",
2875 | "        v_targ = md.val_ds.ds[:][1]\n",
2877 | "        v_pred = learn.predict()\n",
2878 | "        v_pred = to_np(torch.sigmoid(torch.from_numpy(v_pred)))\n",
2879 | "        p = ((v_pred)>0.5).astype(np.uint8)\n",
2880 | "        bst_acc.append(intersection_over_union_thresholds(v_targ,p))\n",
2881 | "        print(bst_acc[-1])"
2882 | ]
2883 | },
2884 | {
2885 | "cell_type": "code",
2886 | "execution_count": 0,
2887 | "metadata": {
2888 | "colab": {},
2889 | "colab_type": "code",
2890 | "id": "UpWadtxBDMYv"
2891 | },
2892 | "outputs": [],
2893 | "source": [
2894 | "model = 'resnext50_128_scse'\n",
2895 | "arch = resnext50\n",
2896 | "bst_acc=[]\n",
2897 | "use_clr_min=20\n",
2898 | "use_clr_div=10\n",
2899 | "aug_tfms = [RandomFlip(tfm_y=TfmType.CLASS)]\n",
2900 | "\n",
2901 | "szs = [(128,64)]\n",
2902 | "for sz,bs in szs:\n",
2903 | "    print([sz,bs])\n",
2904 | "    for i in range(kf):\n",
2905 | "        print(f'fold_id{i}')\n",
2906 | "\n",
2907 | "        trn_x = np.array([f'{TRN_DN}/{o}' for o in train_folds[i]])\n",
2908 | "        trn_y = np.array([f'{MASK_DN}/{o}' for o in train_folds[i]])\n",
2909 | "        val_x = [f'{TRN_DN}/{o}' for o in val_folds[i]]\n",
2910 | "        val_y = [f'{MASK_DN}/{o}' for o in val_folds[i]]\n",
2911 | "\n",
2912 | "        tfms = tfms_from_model(arch, sz=sz, pad=0,crop_type=CropType.NO, tfm_y=TfmType.CLASS, aug_tfms=aug_tfms)\n",
2913 | "        datasets = ImageData.get_ds(MatchedFilesDataset, (trn_x,trn_y), (val_x,val_y), tfms,test=test_x,path=PATH)\n",
2914 | "        md = ImageData(PATH, datasets, bs, num_workers=16, classes=None)\n",
2915 | "        denorm = md.trn_ds.denorm\n",
2916 | "        md.trn_dl.dataset = DepthDataset(md.trn_ds,dpth_dict)\n",
2917 | "        md.val_dl.dataset = DepthDataset(md.val_ds,dpth_dict)\n",
2918 | "        md.test_dl.dataset = DepthDataset(md.test_ds,dpth_dict)\n",
2919 | "        learn = get_tgs_model()\n",
2920 | "        learn.opt_fn = optim.Adam\n",
2921 | "        learn.metrics=[my_eval]\n",
2922 | "        pa = f'{kf}_fold_{model}_{i}'\n",
2923 | "        print(pa)\n",
2924 | "\n",
2925 | "        lr=1e-2\n",
2926 | "        wd=1e-7\n",
2927 | "        lrs = np.array([lr/100,lr/10,lr])\n",
2928 | "\n",
2929 | "        learn.unfreeze()\n",
2930 | "        learn.crit = lovasz_hinge\n",
2931 | "        learn.load(pa)\n",
2932 | "        # learn.fit(lrs/3,1, wds=wd, cycle_len=10,use_clr=(20,8),best_save_name=pa)\n",
2933 | "        learn.fit(lrs/3,1, wds=wd, cycle_len=20,use_clr=(20,10),best_save_name=pa)\n",
2934 | "\n",
2935 | "        learn.load(pa)\n",
2936 | "        # Calculate the mean IoU score on the validation fold\n",
2937 | "        v_targ = md.val_ds.ds[:][1]\n",
2939 | "        v_pred = learn.predict()\n",
2940 | "        v_pred = to_np(torch.sigmoid(torch.from_numpy(v_pred)))\n",
2941 | "        p = ((v_pred)>0.5).astype(np.uint8)\n",
2942 | "        bst_acc.append(intersection_over_union_thresholds(v_targ,p))\n",
2943 | "        print(bst_acc[-1])"
2944 | ]
2945 | },
2946 | {
2947 | "cell_type": "code",
2948 | "execution_count": 0,
2949 | "metadata": {
2950 | "colab": {
2951 | "base_uri": "https://localhost:8080/",
2952 | "height": 5830
2953 | },
2954 | "colab_type": "code",
2955 | "id": "qrPUrDI6pieX",
2956 | "outputId": "8c65fb9e-b0e7-422a-f6ab-0e82cce4b87b"
2957 | },
2958 | "outputs": [
2959 | {
2960 | "name": "stdout",
2961 | "output_type": "stream",
2962 | "text": [
2963 | "[128, 64]\n",
2964 | "fold_id0\n",
2965 | "5_fold_resnext50_128_scse_0\n"
2966 | ]
2967 | },
2984 | {
2985 | "name": "stdout",
2986 | "output_type": "stream",
2987 | "text": [
2988 | "epoch trn_loss val_loss my_eval \n",
2989 | " 0 0.533707 0.76256 0.8175 \n",
2990 | " 1 0.594565 0.858352 0.798875 \n",
2991 | " 2 0.601546 0.788414 0.812125 \n",
2992 | " 3 0.602312 0.763777 0.816625 \n",
2993 | " 4 0.597194 0.756633 0.816125 \n",
2994 | " 5 0.56907 0.767038 0.81325 \n",
2995 | " 6 0.564023 0.746384 0.82175 \n",
2996 | " 7 0.579532 0.831695 0.81125 \n",
2997 | " 8 0.551373 0.760372 0.810875 \n",
2998 | " 9 0.54197 0.845872 0.809625 \n",
2999 | " 10 0.539457 0.726901 0.813625 \n",
3000 | " 11 0.528602 0.7233 0.81875 \n",
3001 | " 12 0.520925 0.772535 0.822875 \n",
3002 | " 13 0.518783 0.764549 0.8235 \n",
3003 | " 14 0.513134 0.735043 0.830125 \n",
3004 | " 15 0.510657 0.727242 0.826625 \n",
3005 | " 16 0.493573 0.704724 0.82725 \n",
3006 | " 17 0.481923 0.73212 0.829125 \n",
3007 | " 18 0.471987 0.698919 0.828 \n",
3008 | " 19 0.463161 0.711147 0.82725 \n",
3009 | "\n"
3010 | ]
3011 | },
3028 | {
3029 | "name": "stdout",
3030 | "output_type": "stream",
3031 | "text": [
3032 | "epoch trn_loss val_loss my_eval \n",
3033 | " 0 0.479967 0.827555 0.821125 \n",
3034 | " 1 0.506115 0.7792 0.807 \n",
3035 | " 2 0.521105 0.704616 0.823375 \n",
3036 | " 3 0.531661 0.759747 0.813625 \n",
3037 | " 4 0.527358 0.76719 0.8165 \n",
3038 | " 5 0.599692 0.78576 0.8095 \n",
3039 | " 6 0.579218 0.723072 0.814375 \n",
3040 | " 7 0.553373 0.735349 0.81025 \n",
3041 | " 8 0.522601 0.74302 0.82075 \n",
3042 | " 9 0.51634 0.725892 0.8255 \n",
3043 | " 10 0.507244 0.736419 0.810875 \n",
3044 | " 11 0.495299 0.747552 0.8235 \n",
3045 | " 12 0.517736 0.709832 0.8205 \n",
3046 | " 13 0.52503 0.702308 0.82 \n",
3047 | " 14 0.517949 0.763225 0.822875 \n",
3048 | " 15 0.481496 0.760829 0.825875 \n",
3049 | " 16 0.467406 0.75723 0.818 \n",
3050 | " 17 0.48397 0.719035 0.821125 \n",
3051 | " 18 0.480835 0.74827 0.81975 \n",
3052 | " 19 0.464327 0.758697 0.822625 \n",
3053 | " 20 0.460972 0.714269 0.8275 \n",
3054 | " 21 0.451309 0.763296 0.824375 \n",
3055 | " 22 0.443526 0.740101 0.82525 \n",
3056 | " 23 0.432715 0.74261 0.829875 \n",
3057 | " 24 0.459113 0.808962 0.8255 \n",
3058 | " 25 0.449143 0.793144 0.827 \n",
3059 | " 26 0.460752 0.774263 0.824 \n",
3060 | " 27 0.448464 0.768298 0.823625 \n",
3061 | " 28 0.429408 0.746062 0.826875 \n",
3062 | " 29 0.418083 0.767724 0.830375 \n",
3063 | " 30 0.420722 0.716629 0.82625 \n",
3064 | " 31 0.426644 0.73742 0.829625 \n",
3065 | " 32 0.425052 0.748321 0.82975 \n",
3066 | " 33 0.419258 0.759544 0.829125 \n",
3067 | " 34 0.411915 0.743406 0.82825 \n",
3068 | " 35 0.4213 0.755378 0.82675 \n",
3069 | " 36 0.404763 0.741769 0.831125 \n",
3070 | " 37 0.398306 0.767223 0.825625 \n",
3071 | " 38 0.396879 0.733644 0.825625 \n",
3072 | " 39 0.413759 0.745653 0.826 \n",
3073 | "\n",
3074 | "0.831125\n",
3075 | "fold_id1\n",
3076 | "5_fold_resnext50_128_scse_1\n"
3077 | ]
3078 | },
3095 | {
3096 | "name": "stdout",
3097 | "output_type": "stream",
3098 | "text": [
3099 | "epoch trn_loss val_loss my_eval \n",
3100 | " 0 0.467102 0.833521 0.792125 \n",
3101 | " 1 0.514805 0.866801 0.770625 \n",
3102 | " 2 0.512858 0.888006 0.78125 \n",
3103 | " 3 0.523391 0.872697 0.785625 \n",
3104 | " 4 0.541148 0.89274 0.7915 \n",
3105 | " 5 0.534021 0.817045 0.794 \n",
3106 | " 6 0.511217 0.827669 0.80725 \n",
3107 | " 7 0.504867 0.870196 0.796375 \n",
3108 | " 8 0.482823 0.813677 0.80225 \n",
3109 | " 9 0.470734 0.859631 0.79925 \n",
3110 | " 10 0.50188 0.805922 0.790625 \n",
3111 | " 11 0.50007 0.865748 0.790125 \n",
3112 | " 12 0.490715 0.808949 0.79925 \n",
3113 | " 13 0.467356 0.822104 0.80175 \n",
3114 | " 14 0.459628 0.823243 0.8005 \n",
3115 | " 15 0.442727 0.81513 0.789125 \n",
3116 | " 16 0.433787 0.801445 0.803375 \n",
3117 | " 17 0.428531 0.821544 0.79725 \n",
3118 | " 18 0.427835 0.775848 0.80675 \n",
3119 | " 19 0.41361 0.812282 0.805 \n",
3120 | "\n"
3121 | ]
3122 | },
3139 | {
3140 | "name": "stdout",
3141 | "output_type": "stream",
3142 | "text": [
3143 | "epoch trn_loss val_loss my_eval \n",
3144 | " 0 0.443954 0.821174 0.795625 \n",
3145 | " 1 0.473285 0.845159 0.792 \n",
3146 | " 2 0.465905 0.985209 0.785 \n",
3147 | " 3 0.456353 0.974172 0.786875 \n",
3148 | " 4 0.467155 0.920761 0.79475 \n",
3149 | " 5 0.470211 0.900441 0.79725 \n",
3150 | " 6 0.455301 0.96447 0.792 \n",
3151 | " 7 0.47351 0.90089 0.78875 \n",
3152 | " 8 0.463646 0.856061 0.794875 \n",
3153 | " 9 0.443074 0.860522 0.792 \n",
3154 | " 10 0.429617 0.824053 0.798375 \n",
3155 | " 11 0.445916 0.875289 0.796375 \n",
3156 | " 12 0.453059 0.872382 0.798125 \n",
3157 | " 13 0.447237 0.872192 0.801125 \n",
3158 | " 14 0.44465 0.867648 0.800875 \n",
3159 | " 15 0.447324 0.889119 0.795 \n",
3160 | " 16 0.449389 0.843531 0.800875 \n",
3161 | " 17 0.438213 0.823923 0.8075 \n",
3162 | " 18 0.434376 0.847579 0.7975 \n",
3163 | " 19 0.443935 0.843054 0.801625 \n",
3164 | " 20 0.422225 0.849103 0.80025 \n",
3165 | " 21 0.414883 0.866637 0.7965 \n",
3166 | " 22 0.415798 0.935687 0.791625 \n",
3167 | " 23 0.401541 0.8712 0.800375 \n",
3168 | " 24 0.408011 0.873514 0.793375 \n",
3169 | " 25 0.38908 0.84675 0.798875 \n",
3170 | " 26 0.392715 0.857604 0.79125 \n",
3171 | " 27 0.376516 0.896894 0.792125 \n",
3172 | " 28 0.369762 0.973169 0.786875 \n",
3173 | " 29 0.385661 0.912236 0.798875 \n",
3174 | " 30 0.38374 0.864333 0.801625 \n",
3175 | " 31 0.36842 0.901064 0.806 \n",
3176 | " 32 0.378856 0.892412 0.800125 \n",
3177 | " 33 0.377577 0.88206 0.801 \n",
3178 | " 34 0.35965 0.856498 0.79675 \n",
3179 | " 35 0.365454 0.8875 0.798875 \n",
3180 | " 36 0.371453 0.905397 0.80125 \n",
3181 | " 37 0.365299 0.876433 0.7985 \n",
3182 | " 38 0.358871 0.897993 0.799625 \n",
3183 | " 39 0.360378 0.904567 0.797375 \n",
3184 | "\n",
3185 | "0.8075\n",
3186 | "fold_id2\n",
3187 | "5_fold_resnext50_128_scse_2\n"
3188 | ]
3189 | },
3206 | {
3207 | "name": "stdout",
3208 | "output_type": "stream",
3209 | "text": [
3210 | "epoch trn_loss val_loss my_eval \n",
3211 | " 0 0.485156 0.756766 0.82075 \n",
3212 | " 1 0.517438 0.832651 0.814375 \n",
3213 | " 2 0.518396 0.735615 0.815 \n",
3214 | " 3 0.522302 0.761968 0.806625 \n",
3215 | " 4 0.540298 0.807242 0.81675 \n",
3216 | " 5 0.536279 0.714934 0.81575 \n",
3217 | " 6 0.510553 0.791743 0.797375 \n",
3218 | " 7 0.511985 0.739001 0.817625 \n",
3219 | " 8 0.50142 0.754067 0.81825 \n",
3220 | " 9 0.496233 0.749199 0.818 \n",
3221 | " 10 0.479415 0.739346 0.821375 \n",
3222 | " 11 0.4808 0.692525 0.826 \n",
3223 | " 12 0.473838 0.731722 0.823 \n",
3224 | " 13 0.467366 0.737328 0.82275 \n",
3225 | " 14 0.476735 0.719575 0.81975 \n",
3226 | " 15 0.448495 0.726288 0.820375 \n",
3227 | " 16 0.434496 0.749208 0.82025 \n",
3228 | " 17 0.426012 0.765359 0.825 \n",
3229 | " 18 0.435108 0.743209 0.8205 \n",
3230 | " 19 0.433077 0.765477 0.82475 \n",
3231 | "\n"
3232 | ]
3233 | },
3234 | {
3235 | "data": {
3236 | "application/vnd.jupyter.widget-view+json": {
3237 | "model_id": "354e58e006fe4d679dab82ffe007f623",
3238 | "version_major": 2,
3239 | "version_minor": 0
3240 | },
3241 | "text/plain": [
3242 | "HBox(children=(IntProgress(value=0, description='Epoch', max=40), HTML(value='')))"
3243 | ]
3244 | },
3245 | "metadata": {
3246 | "tags": []
3247 | },
3248 | "output_type": "display_data"
3249 | },
3250 | {
3251 | "name": "stdout",
3252 | "output_type": "stream",
3253 | "text": [
3254 | "epoch trn_loss val_loss my_eval \n",
3255 | " 0 0.443196 0.750643 0.826125 \n",
3256 | " 1 0.463696 0.795019 0.81775 \n",
3257 | " 2 0.475888 0.753367 0.819375 \n",
3258 | " 3 0.478167 0.75828 0.814875 \n",
3259 | " 4 0.471563 0.778656 0.813625 \n",
3260 | " 5 0.483763 0.742092 0.81475 \n",
3261 | " 6 0.479101 0.748993 0.81875 \n",
3262 | " 7 0.476277 0.780733 0.810875 \n",
3263 | " 8 0.498361 0.69888 0.82975 \n",
3264 | " 9 0.486142 0.739697 0.81925 \n",
3265 | " 10 0.457164 0.801588 0.825375 \n",
3266 | " 11 0.455916 0.745798 0.815875 \n",
3267 | " 12 0.457613 0.76608 0.818625 \n",
3268 | " 13 0.45103 0.773046 0.8205 \n",
3269 | " 14 0.447433 0.764471 0.813 \n",
3270 | " 15 0.436356 0.770631 0.82025 \n",
3271 | " 16 0.440452 0.798791 0.824875 \n",
3272 | " 17 0.43147 0.753223 0.825125 \n",
3273 | " 18 0.432604 0.828235 0.8245 \n",
3274 | " 19 0.435804 0.800897 0.81925 \n",
3275 | " 20 0.416318 0.782468 0.82325 \n",
3276 | " 21 0.425427 0.760889 0.82525 \n",
3277 | " 22 0.414975 0.779994 0.817375 \n",
3278 | " 23 0.41498 0.785758 0.82025 \n",
3279 | " 24 0.423222 0.784227 0.81475 \n",
3280 | " 25 0.417059 0.753324 0.823875 \n",
3281 | " 26 0.406137 0.769448 0.825125 \n",
3282 | " 27 0.398413 0.820555 0.824125 \n",
3283 | " 28 0.393823 0.790955 0.82025 \n",
3284 | " 29 0.404996 0.793106 0.827375 \n",
3285 | " 30 0.388835 0.761253 0.827 \n",
3286 | " 31 0.38426 0.748922 0.829625 \n",
3287 | " 32 0.380705 0.752186 0.829125 \n",
3288 | " 33 0.378552 0.722271 0.827 \n",
3289 | " 34 0.378998 0.761381 0.83075 \n",
3290 | " 35 0.380338 0.760387 0.828625 \n",
3291 | " 36 0.375267 0.79428 0.825375 \n",
3292 | " 37 0.37371 0.777205 0.830375 \n",
3293 | " 38 0.374267 0.77222 0.8245 \n",
3294 | " 39 0.372311 0.783963 0.824625 \n",
3295 | "\n",
3296 | "0.83075\n",
3297 | "fold_id3\n",
3298 | "5_fold_resnext50_128_scse_3\n"
3299 | ]
3300 | },
3301 | {
3302 | "data": {
3303 | "application/vnd.jupyter.widget-view+json": {
3304 | "model_id": "f568474ab13a402f984119d9c0490b8e",
3305 | "version_major": 2,
3306 | "version_minor": 0
3307 | },
3308 | "text/plain": [
3309 | "HBox(children=(IntProgress(value=0, description='Epoch', max=20), HTML(value='')))"
3310 | ]
3311 | },
3312 | "metadata": {
3313 | "tags": []
3314 | },
3315 | "output_type": "display_data"
3316 | },
3317 | {
3318 | "name": "stdout",
3319 | "output_type": "stream",
3320 | "text": [
3321 | "epoch trn_loss val_loss my_eval \n",
3322 | " 0 0.496033 0.709309 0.81325 \n",
3323 | " 1 0.521143 0.718663 0.82275 \n",
3324 | " 2 0.54441 0.835224 0.81875 \n",
3325 | " 3 0.549383 0.731907 0.821625 \n",
3326 | " 4 0.539594 0.64432 0.82425 \n",
3327 | " 5 0.534994 0.674456 0.82175 \n",
3328 | " 6 0.529376 0.71022 0.82975 \n",
3329 | " 7 0.502961 0.68506 0.824875 \n",
3330 | " 8 0.498022 0.686693 0.830625 \n",
3331 | " 9 0.500132 0.725364 0.822125 \n",
3332 | " 10 0.50281 0.675001 0.831875 \n",
3333 | " 11 0.512494 0.684257 0.836875 \n",
3334 | " 12 0.484919 0.696663 0.83475 \n",
3335 | " 13 0.469144 0.709017 0.837875 \n",
3336 | " 14 0.469601 0.659997 0.834375 \n",
3337 | " 15 0.46119 0.692171 0.834125 \n",
3338 | " 16 0.453164 0.69287 0.83025 \n",
3339 | " 17 0.443224 0.70075 0.834625 \n",
3340 | " 18 0.44307 0.694192 0.83 \n",
3341 | " 19 0.428427 0.728258 0.83575 \n",
3342 | "\n"
3343 | ]
3344 | },
3345 | {
3346 | "data": {
3347 | "application/vnd.jupyter.widget-view+json": {
3348 | "model_id": "5eaf0a7d85f94e4aac0efe9e6d732595",
3349 | "version_major": 2,
3350 | "version_minor": 0
3351 | },
3352 | "text/plain": [
3353 | "HBox(children=(IntProgress(value=0, description='Epoch', max=40), HTML(value='')))"
3354 | ]
3355 | },
3356 | "metadata": {
3357 | "tags": []
3358 | },
3359 | "output_type": "display_data"
3360 | },
3361 | {
3362 | "name": "stdout",
3363 | "output_type": "stream",
3364 | "text": [
3365 | "epoch trn_loss val_loss my_eval \n",
3366 | " 0 0.422124 0.726504 0.8295 \n",
3367 | " 1 0.462113 0.731205 0.8245 \n",
3368 | " 2 0.486406 0.728322 0.829125 \n",
3369 | " 3 0.487896 0.765298 0.828875 \n",
3370 | " 4 0.484002 0.75814 0.8245 \n",
3371 | " 5 0.491141 0.697839 0.827125 \n",
3372 | " 6 0.477768 0.709401 0.822625 \n",
3373 | " 7 0.483948 0.766132 0.82175 \n",
3374 | " 8 0.490673 0.695626 0.830375 \n",
3375 | " 9 0.472208 0.698568 0.81975 \n",
3376 | " 10 0.466838 0.67741 0.827375 \n",
3377 | " 11 0.463749 0.797842 0.81275 \n",
3378 | " 12 0.448588 0.679358 0.830125 \n",
3379 | " 13 0.469459 0.723771 0.828375 \n",
3380 | " 14 0.452238 0.672009 0.830875 \n",
3381 | " 15 0.447767 0.722987 0.830875 \n",
3382 | " 16 0.43008 0.666037 0.831 \n",
3383 | " 17 0.430931 0.738625 0.827625 \n",
3384 | " 18 0.427639 0.6799 0.83275 \n",
3385 | " 19 0.430004 0.645604 0.82625 \n",
3386 | " 20 0.440587 0.710382 0.825625 \n",
3387 | " 21 0.427687 0.758092 0.8225 \n",
3388 | " 22 0.424524 0.714262 0.830125 \n",
3389 | " 23 0.415856 0.734838 0.829375 \n",
3390 | " 24 0.420022 0.709158 0.834375 \n",
3391 | " 25 0.40909 0.71966 0.838125 \n",
3392 | " 26 0.392587 0.719539 0.836625 \n",
3393 | " 27 0.393285 0.732876 0.82375 \n",
3394 | " 28 0.389524 0.695115 0.83525 \n",
3395 | " 29 0.391245 0.734108 0.829375 \n",
3396 | " 30 0.384423 0.73317 0.833625 \n",
3397 | " 31 0.373081 0.731024 0.830875 \n",
3398 | " 32 0.389755 0.754126 0.83025 \n",
3399 | " 33 0.38323 0.752325 0.83575 \n",
3400 | " 34 0.375477 0.733968 0.83 \n",
3401 | " 35 0.371067 0.743891 0.83375 \n",
3402 | " 36 0.382864 0.747606 0.833875 \n",
3403 | " 37 0.370416 0.745533 0.83125 \n",
3404 | " 38 0.362293 0.751041 0.834 \n",
3405 | " 39 0.365368 0.740053 0.833625 \n",
3406 | "\n",
3407 | "0.8381250000000001\n",
3408 | "fold_id4\n",
3409 | "5_fold_resnext50_128_scse_4\n"
3410 | ]
3411 | },
3412 | {
3413 | "data": {
3414 | "application/vnd.jupyter.widget-view+json": {
3415 | "model_id": "f6ae02badebd42bcbf21aaabae7cd6f4",
3416 | "version_major": 2,
3417 | "version_minor": 0
3418 | },
3419 | "text/plain": [
3420 | "HBox(children=(IntProgress(value=0, description='Epoch', max=20), HTML(value='')))"
3421 | ]
3422 | },
3423 | "metadata": {
3424 | "tags": []
3425 | },
3426 | "output_type": "display_data"
3427 | },
3428 | {
3429 | "name": "stdout",
3430 | "output_type": "stream",
3431 | "text": [
3432 | "epoch trn_loss val_loss my_eval \n",
3433 | " 0 0.47414 0.722221 0.827 \n",
3434 | " 1 0.491114 0.712578 0.82425 \n",
3435 | " 2 0.51879 0.73661 0.815875 \n",
3436 | " 3 0.540288 0.692429 0.826375 \n",
3437 | " 4 0.534986 0.700633 0.827 \n",
3438 | " 5 0.518888 0.720056 0.826625 \n",
3439 | " 6 0.510556 0.819519 0.814 \n",
3440 | " 7 0.498047 0.740013 0.821125 \n",
3441 | " 8 0.467538 0.726275 0.821875 \n",
3442 | " 9 0.465198 0.740374 0.82125 \n",
3443 | " 10 0.454122 0.729831 0.822125 \n",
3444 | " 11 0.465724 0.69371 0.827625 \n",
3445 | " 12 0.459633 0.711312 0.82275 \n",
3446 | " 13 0.44714 0.709925 0.8275 \n",
3447 | " 14 0.450566 0.755889 0.81475 \n",
3448 | " 15 0.433707 0.689775 0.836125 \n",
3449 | " 16 0.419188 0.687338 0.8345 \n",
3450 | " 17 0.419773 0.675787 0.834 \n",
3451 | " 18 0.422193 0.681697 0.83225 \n",
3452 | " 19 0.414564 0.676884 0.833125 \n",
3453 | "\n"
3454 | ]
3455 | },
3456 | {
3457 | "data": {
3458 | "application/vnd.jupyter.widget-view+json": {
3459 | "model_id": "95d90582589d4934be7383e615f1dbd0",
3460 | "version_major": 2,
3461 | "version_minor": 0
3462 | },
3463 | "text/plain": [
3464 | "HBox(children=(IntProgress(value=0, description='Epoch', max=40), HTML(value='')))"
3465 | ]
3466 | },
3467 | "metadata": {
3468 | "tags": []
3469 | },
3470 | "output_type": "display_data"
3471 | },
3472 | {
3473 | "name": "stdout",
3474 | "output_type": "stream",
3475 | "text": [
3476 | "epoch trn_loss val_loss my_eval \n",
3477 | " 0 0.416673 0.743005 0.82225 \n",
3478 | " 1 0.451756 0.796281 0.822 \n",
3479 | " 2 0.47234 0.772019 0.824125 \n",
3480 | " 3 0.479158 0.759053 0.807875 \n",
3481 | " 4 0.46425 0.769957 0.815375 \n",
3482 | " 5 0.464162 0.772572 0.815875 \n",
3483 | " 6 0.460326 0.737922 0.820625 \n",
3484 | " 7 0.468473 0.750393 0.818375 \n",
3485 | " 8 0.458967 0.704921 0.831375 \n",
3486 | " 9 0.439696 0.740977 0.832625 \n",
3487 | " 10 0.445156 0.723011 0.824875 \n",
3488 | " 11 0.434864 0.725554 0.833625 \n",
3489 | " 12 0.422101 0.704498 0.828375 \n",
3490 | " 13 0.429801 0.699662 0.82975 \n",
3491 | " 14 0.432414 0.763216 0.81775 \n",
3492 | " 15 0.423432 0.72516 0.822875 \n",
3493 | " 16 0.425951 0.732884 0.832625 \n",
3494 | " 17 0.41282 0.699522 0.830125 \n",
3495 | " 18 0.40507 0.741123 0.8335 \n",
3496 | " 19 0.406175 0.729666 0.8355 \n",
3497 | " 20 0.423782 0.71638 0.82775 \n",
3498 | " 21 0.407595 0.724523 0.82375 \n",
3499 | " 22 0.400728 0.743059 0.82825 \n",
3500 | " 23 0.397718 0.712238 0.828375 \n",
3501 | " 24 0.400653 0.771733 0.81775 \n",
3502 | " 25 0.406241 0.706003 0.83425 \n",
3503 | " 26 0.384568 0.726671 0.830625 \n",
3504 | " 27 0.395504 0.695435 0.83125 \n",
3505 | " 28 0.392318 0.724178 0.831125 \n",
3506 | " 29 0.381853 0.684155 0.83 \n",
3507 | " 30 0.375725 0.688363 0.83775 \n",
3508 | " 31 0.385391 0.703829 0.836875 \n",
3509 | " 32 0.362482 0.684262 0.84275 \n",
3510 | " 33 0.354302 0.701925 0.83675 \n",
3511 | " 34 0.365955 0.715952 0.83425 \n",
3512 | " 35 0.367506 0.688223 0.8365 \n",
3513 | " 36 0.361231 0.706321 0.839 \n",
3514 | " 37 0.35691 0.693329 0.835625 \n",
3515 | " 38 0.363043 0.709944 0.832125 \n",
3516 | " 39 0.350134 0.699815 0.834625 \n",
3517 | "\n",
3518 | "0.8427500000000001\n"
3519 | ]
3520 | }
3521 | ],
3522 | "source": [
3523 | "model = 'resnext50_128_scse'\n",
3524 | "arch = resnext50\n",
3525 | "bst_acc=[]\n",
3526 | "use_clr_min=20\n",
3527 | "use_clr_div=10\n",
3528 | "aug_tfms = [RandomFlip(tfm_y=TfmType.CLASS)]\n",
3529 | "\n",
3530 | "szs = [(128,64)]\n",
3531 | "for sz,bs in szs:\n",
3532 | " print([sz,bs])\n",
3533 | " for i in range(kf):\n",
3534 | " print(f'fold_id{i}')\n",
3535 | " \n",
3536 | " trn_x = np.array([f'{TRN_DN}/{o}' for o in train_folds[i]])\n",
3537 | " trn_y = np.array([f'{MASK_DN}/{o}' for o in train_folds[i]])\n",
3538 | " val_x = [f'{TRN_DN}/{o}' for o in val_folds[i]]\n",
3539 | " val_y = [f'{MASK_DN}/{o}' for o in val_folds[i]]\n",
3540 | " \n",
3541 | " tfms = tfms_from_model(arch, sz=sz, pad=0,crop_type=CropType.NO, tfm_y=TfmType.CLASS, aug_tfms=aug_tfms)\n",
3542 | " datasets = ImageData.get_ds(MatchedFilesDataset, (trn_x,trn_y), (val_x,val_y), tfms,test=test_x,path=PATH)\n",
3543 | " md = ImageData(PATH, datasets, bs, num_workers=16, classes=None)\n",
3544 | " denorm = md.trn_ds.denorm\n",
3545 | " md.trn_dl.dataset = DepthDataset(md.trn_ds,dpth_dict)\n",
3546 | " md.val_dl.dataset = DepthDataset(md.val_ds,dpth_dict)\n",
3547 | " md.test_dl.dataset = DepthDataset(md.test_ds,dpth_dict)\n",
3548 | " learn = get_tgs_model() \n",
3549 | " learn.opt_fn = optim.Adam\n",
3550 | " learn.metrics=[my_eval]\n",
3551 | " pa = f'{kf}_fold_{model}_{i}'\n",
3552 | " print(pa)\n",
3553 | "\n",
3554 | " lr=1e-2\n",
3555 | " wd=1e-7\n",
3556 | " lrs = np.array([lr/100,lr/10,lr])\n",
3557 | "\n",
3558 | " learn.unfreeze() \n",
3559 | " learn.crit = lovasz_hinge\n",
3560 | " learn.load(pa)\n",
3561 | "# learn.fit(lrs/3,1, wds=wd, cycle_len=10,use_clr=(20,8),best_save_name=pa)\n",
3562 | " learn.fit(lrs/4,1, wds=wd, cycle_len=20,use_clr=(20,16),best_save_name=pa)\n",
3563 | " learn.fit(lrs/5,1, wds=wd, cycle_len=40,use_clr=(20,32),best_save_name=pa)\n",
3564 | " \n",
3565 | " learn.load(pa) \n",
3566 | " # Calculating mean IoU score\n",
3567 | " v_targ = md.val_ds.ds[:][1]\n",
3568 | " v_preds = np.zeros((len(v_targ),sz,sz)) \n",
3569 | " v_pred = learn.predict()\n",
3570 | " v_pred = to_np(torch.sigmoid(torch.from_numpy(v_pred)))\n",
3571 | " p = ((v_pred)>0.5).astype(np.uint8)\n",
3572 | " bst_acc.append(intersection_over_union_thresholds(v_targ,p))\n",
3573 | " print(bst_acc[-1])\n",
3574 | "\n",
3575 | "save_model_weights()"
3576 | ]
3577 | },
3578 | {
3579 | "cell_type": "code",
3580 | "execution_count": 0,
3581 | "metadata": {
3582 | "colab": {
3583 | "base_uri": "https://localhost:8080/",
3584 | "height": 2218
3585 | },
3586 | "colab_type": "code",
3587 | "id": "ga4JYphvCPyd",
3588 | "outputId": "b589a96a-cf8d-40d5-b11f-a32c260d5aa5"
3589 | },
3590 | "outputs": [
3591 | {
3592 | "name": "stdout",
3593 | "output_type": "stream",
3594 | "text": [
3595 | "[128, 64]\n",
3596 | "fold_id0\n",
3597 | "5_fold_resnext50_128_scse_0\n"
3598 | ]
3599 | },
3600 | {
3601 | "data": {
3602 | "application/vnd.jupyter.widget-view+json": {
3603 | "model_id": "3f19b564db994e6687c8b43ab8f917ed",
3604 | "version_major": 2,
3605 | "version_minor": 0
3606 | },
3607 | "text/plain": [
3608 | "HBox(children=(IntProgress(value=0, description='Epoch', max=20, style=ProgressStyle(description_width='initia…"
3609 | ]
3610 | },
3611 | "metadata": {
3612 | "tags": []
3613 | },
3614 | "output_type": "display_data"
3615 | },
3616 | {
3617 | "name": "stdout",
3618 | "output_type": "stream",
3619 | "text": [
3620 | "epoch trn_loss val_loss my_eval \n",
3621 | " 0 0.352572 0.727342 0.824 \n",
3622 | " 1 0.369447 0.676443 0.83 \n",
3623 | " 2 0.387457 0.791503 0.825 \n",
3624 | " 3 0.351119 0.766687 0.827625 \n",
3625 | " 4 0.36598 0.731143 0.83275 \n",
3626 | " 5 0.364853 0.723793 0.8195 \n",
3627 | " 6 0.381143 0.841033 0.829125 \n",
3628 | " 7 0.364594 0.680827 0.83075 \n",
3629 | " 8 0.349329 0.828338 0.817375 \n",
3630 | " 9 0.35967 0.746246 0.834 \n",
3631 | " 10 0.355637 0.749717 0.83125 \n",
3632 | " 11 0.353103 0.816089 0.83225 \n",
3633 | " 12 0.33925 0.821926 0.83525 \n",
3634 | " 13 0.347948 0.751935 0.83625 \n",
3635 | " 14 0.325283 0.842732 0.833625 \n",
3636 | " 15 0.332459 0.769492 0.83075 \n",
3637 | " 16 0.329128 0.773248 0.83475 \n",
3638 | " 17 0.326993 0.837361 0.837 \n",
3639 | " 18 0.33641 0.736122 0.83775 \n",
3640 | " 19 0.32058 0.791872 0.8365 \n",
3641 | "\n",
3642 | "0.8377500000000001\n",
3643 | "fold_id1\n",
3644 | "5_fold_resnext50_128_scse_1\n"
3645 | ]
3646 | },
3647 | {
3648 | "data": {
3649 | "application/vnd.jupyter.widget-view+json": {
3650 | "model_id": "68bf61a771034be3aa0b6105e3a7b93f",
3651 | "version_major": 2,
3652 | "version_minor": 0
3653 | },
3654 | "text/plain": [
3655 | "HBox(children=(IntProgress(value=0, description='Epoch', max=20, style=ProgressStyle(description_width='initia…"
3656 | ]
3657 | },
3658 | "metadata": {
3659 | "tags": []
3660 | },
3661 | "output_type": "display_data"
3662 | },
3663 | {
3664 | "name": "stdout",
3665 | "output_type": "stream",
3666 | "text": [
3667 | "epoch trn_loss val_loss my_eval \n",
3668 | " 0 0.329473 0.895153 0.791 \n",
3669 | " 1 0.345151 0.857987 0.802125 \n",
3670 | " 2 0.346842 0.957962 0.787625 \n",
3671 | " 3 0.359232 0.914885 0.78725 \n",
3672 | " 4 0.35904 0.945004 0.79475 \n",
3673 | " 5 0.364206 0.90916 0.79025 \n",
3674 | " 6 0.352084 0.907089 0.803125 \n",
3675 | " 7 0.340385 0.884313 0.79975 \n",
3676 | " 8 0.331111 0.935045 0.80525 \n",
3677 | " 9 0.327103 0.912742 0.80725 \n",
3678 | " 10 0.334112 0.903979 0.799875 \n",
3679 | " 11 0.339094 0.997254 0.79625 \n",
3680 | " 12 0.336058 0.903867 0.802 \n",
3681 | " 13 0.330904 0.945125 0.79575 \n",
3682 | " 14 0.328544 0.888129 0.798625 \n",
3683 | " 15 0.322872 0.931454 0.7975 \n",
3684 | " 16 0.320054 0.917544 0.803375 \n",
3685 | " 17 0.318902 0.909964 0.8055 \n",
3686 | " 18 0.318903 0.864422 0.80375 \n",
3687 | " 19 0.319137 0.889656 0.803875 \n",
3688 | "\n",
3689 | "0.8072499999999999\n",
3690 | "fold_id2\n",
3691 | "5_fold_resnext50_128_scse_2\n"
3692 | ]
3693 | },
3694 | {
3695 | "data": {
3696 | "application/vnd.jupyter.widget-view+json": {
3697 | "model_id": "01b404522f664b47ad0a31f6acfa7b1a",
3698 | "version_major": 2,
3699 | "version_minor": 0
3700 | },
3701 | "text/plain": [
3702 | "HBox(children=(IntProgress(value=0, description='Epoch', max=20, style=ProgressStyle(description_width='initia…"
3703 | ]
3704 | },
3705 | "metadata": {
3706 | "tags": []
3707 | },
3708 | "output_type": "display_data"
3709 | },
3710 | {
3711 | "name": "stdout",
3712 | "output_type": "stream",
3713 | "text": [
3714 | "epoch trn_loss val_loss my_eval \n",
3715 | " 0 0.334564 0.907237 0.827 \n",
3716 | " 1 0.356609 0.77755 0.821125 \n",
3717 | " 2 0.359052 0.832501 0.83225 \n",
3718 | " 3 0.365679 0.80743 0.821 \n",
3719 | " 4 0.361891 0.790572 0.814625 \n",
3720 | " 5 0.343242 0.863644 0.823625 \n",
3721 | " 6 0.339013 0.821969 0.828 \n",
3722 | " 7 0.334319 0.807286 0.818625 \n",
3723 | " 8 0.327691 0.864193 0.828125 \n",
3724 | " 9 0.328807 0.806045 0.81325 \n",
3725 | " 10 0.34561 0.859754 0.82925 \n",
3726 | " 11 0.334322 0.871514 0.828 \n",
3727 | " 12 0.320915 0.854894 0.823 \n",
3728 | " 13 0.30964 0.863667 0.8175 \n",
3729 | " 14 0.3198 0.776783 0.81925 \n",
3730 | " 15 0.312285 0.829952 0.83175 \n",
3731 | " 16 0.315182 0.793502 0.8215 \n",
3732 | " 17 0.304829 0.833349 0.82575 \n",
3733 | " 18 0.305454 0.814216 0.833375 \n",
3734 | " 19 0.309406 0.807178 0.82525 \n",
3735 | "\n",
3736 | "0.8333750000000001\n",
3737 | "fold_id3\n",
3738 | "5_fold_resnext50_128_scse_3\n"
3739 | ]
3740 | },
3741 | {
3742 | "data": {
3743 | "application/vnd.jupyter.widget-view+json": {
3744 | "model_id": "7e9d4a962e6644e7a36ccbb7aa2958ef",
3745 | "version_major": 2,
3746 | "version_minor": 0
3747 | },
3748 | "text/plain": [
3749 | "HBox(children=(IntProgress(value=0, description='Epoch', max=20, style=ProgressStyle(description_width='initia…"
3750 | ]
3751 | },
3752 | "metadata": {
3753 | "tags": []
3754 | },
3755 | "output_type": "display_data"
3756 | },
3757 | {
3758 | "name": "stdout",
3759 | "output_type": "stream",
3760 | "text": [
3761 | "epoch trn_loss val_loss my_eval \n",
3762 | " 0 0.303919 0.798046 0.829125 \n",
3763 | " 1 0.330936 0.707061 0.83625 \n",
3764 | " 2 0.344996 0.750156 0.829125 \n",
3765 | " 3 0.344157 0.924753 0.833125 \n",
3766 | " 4 0.326378 0.760323 0.833875 \n",
3767 | " 5 0.328699 0.870832 0.825625 \n",
3768 | " 6 0.331057 0.820294 0.83625 \n",
3769 | " 7 0.34331 0.716414 0.838625 \n",
3770 | " 8 0.335372 0.924066 0.83275 \n",
3771 | " 9 0.323994 0.739209 0.83925 \n",
3772 | " 10 0.334121 0.761376 0.834 \n",
3773 | " 11 0.322592 0.740011 0.83975 \n",
3774 | " 12 0.330691 0.819215 0.837125 \n",
3775 | " 13 0.313002 0.78467 0.842875 \n",
3776 | " 14 0.318578 0.795953 0.83725 \n",
3777 | " 15 0.316326 0.765305 0.842625 \n",
3778 | " 16 0.310763 0.723363 0.84525 \n",
3779 | " 17 0.296692 0.825693 0.84075 \n",
3780 | " 18 0.29445 0.823472 0.842625 \n",
3781 | " 19 0.297867 0.751366 0.83925 \n",
3782 | "\n",
3783 | "0.8452500000000001\n",
3784 | "fold_id4\n",
3785 | "5_fold_resnext50_128_scse_4\n"
3786 | ]
3787 | },
3788 | {
3789 | "data": {
3790 | "application/vnd.jupyter.widget-view+json": {
3791 | "model_id": "7a025c00a503477aa9cfcba1f1ab3cd2",
3792 | "version_major": 2,
3793 | "version_minor": 0
3794 | },
3795 | "text/plain": [
3796 | "HBox(children=(IntProgress(value=0, description='Epoch', max=20, style=ProgressStyle(description_width='initia…"
3797 | ]
3798 | },
3799 | "metadata": {
3800 | "tags": []
3801 | },
3802 | "output_type": "display_data"
3803 | },
3804 | {
3805 | "name": "stdout",
3806 | "output_type": "stream",
3807 | "text": [
3808 | "epoch trn_loss val_loss my_eval \n",
3809 | " 0 0.32856 0.730436 0.834625 \n",
3810 | " 1 0.335845 0.899728 0.837125 \n",
3811 | " 2 0.348538 0.913843 0.8295 \n",
3812 | " 3 0.33758 0.769679 0.8235 \n",
3813 | " 4 0.346654 0.802165 0.835625 \n",
3814 | " 5 0.339364 0.746379 0.83825 \n",
3815 | " 6 0.343552 0.756106 0.828125 \n",
3816 | " 7 0.331363 0.792459 0.8335 \n",
3817 | " 8 0.31865 0.735218 0.836125 \n",
3818 | " 9 0.325964 0.810198 0.839 \n",
3819 | " 10 0.311348 0.795061 0.841875 \n",
3820 | " 11 0.321848 0.720627 0.83475 \n",
3821 | " 12 0.311088 0.884234 0.8375 \n",
3822 | " 13 0.304974 0.749801 0.836 \n",
3823 | " 14 0.305366 0.71613 0.83375 \n",
3824 | " 15 0.314715 0.693329 0.84125 \n",
3825 | " 16 0.318737 0.787573 0.83675 \n",
3826 | " 17 0.304981 0.747158 0.83425 \n",
3827 | " 18 0.302457 0.767533 0.834 \n",
3828 | " 19 0.299126 0.759485 0.833625 \n",
3829 | "\n",
3830 | "0.841875\n"
3831 | ]
3832 | }
3833 | ],
3834 | "source": [
3835 | "model = 'resnext50_128_scse'\n",
3836 | "arch = resnext50\n",
3837 | "bst_acc=[]\n",
3838 | "use_clr_min=20\n",
3839 | "use_clr_div=10\n",
3840 | "aug_tfms = [RandomFlip(tfm_y=TfmType.CLASS)]\n",
3841 | "\n",
3842 | "szs = [(128,64)]\n",
3843 | "for sz,bs in szs:\n",
3844 | " print([sz,bs])\n",
3845 | " for i in range(kf):\n",
3846 | " print(f'fold_id{i}')\n",
3847 | " \n",
3848 | " trn_x = np.array([f'{TRN_DN}/{o}' for o in train_folds[i]])\n",
3849 | " trn_y = np.array([f'{MASK_DN}/{o}' for o in train_folds[i]])\n",
3850 | " val_x = [f'{TRN_DN}/{o}' for o in val_folds[i]]\n",
3851 | " val_y = [f'{MASK_DN}/{o}' for o in val_folds[i]]\n",
3852 | " \n",
3853 | " tfms = tfms_from_model(arch, sz=sz, pad=0,crop_type=CropType.NO, tfm_y=TfmType.CLASS, aug_tfms=aug_tfms)\n",
3854 | " datasets = ImageData.get_ds(MatchedFilesDataset, (trn_x,trn_y), (val_x,val_y), tfms,test=test_x,path=PATH)\n",
3855 | " md = ImageData(PATH, datasets, bs, num_workers=16, classes=None)\n",
3856 | " denorm = md.trn_ds.denorm\n",
3857 | " md.trn_dl.dataset = DepthDataset(md.trn_ds,dpth_dict)\n",
3858 | " md.val_dl.dataset = DepthDataset(md.val_ds,dpth_dict)\n",
3859 | " md.test_dl.dataset = DepthDataset(md.test_ds,dpth_dict)\n",
3860 | " learn = get_tgs_model() \n",
3861 | " learn.opt_fn = optim.Adam\n",
3862 | " learn.metrics=[my_eval]\n",
3863 | " pa = f'{kf}_fold_{model}_{i}'\n",
3864 | " print(pa)\n",
3865 | "\n",
3866 | " lr=1e-2\n",
3867 | " wd=1e-7\n",
3868 | " lrs = np.array([lr/150,lr/20,lr])\n",
3869 | "\n",
3870 | " learn.unfreeze()\n",
3871 | " learn.bn_freeze(True)\n",
3872 | " learn.crit = lovasz_hinge\n",
3873 | " learn.load(pa)\n",
3874 | " learn.fit(lrs/10,1, wds=wd, cycle_len=20,use_clr=(20,20),best_save_name=pa)\n",
3875 | " \n",
3876 | " learn.load(pa) \n",
3877 | " # Calculating mean IoU score\n",
3878 | " v_targ = md.val_ds.ds[:][1]\n",
3879 | " v_preds = np.zeros((len(v_targ),sz,sz)) \n",
3880 | " v_pred = learn.predict()\n",
3881 | " v_pred = to_np(torch.sigmoid(torch.from_numpy(v_pred)))\n",
3882 | " p = ((v_pred)>0.5).astype(np.uint8)\n",
3883 | " bst_acc.append(intersection_over_union_thresholds(v_targ,p))\n",
3884 | " print(bst_acc[-1])\n",
3885 | "\n",
3886 | " if i < kf - 1:\n",
3887 | " del learn\n",
3888 | " del md\n",
3889 | " gc.collect()\n",
3890 | " \n",
3891 | "save_model_weights()"
3892 | ]
3893 | },
3894 | {
3895 | "cell_type": "code",
3896 | "execution_count": 0,
3897 | "metadata": {
3898 | "colab": {
3899 | "base_uri": "https://localhost:8080/",
3900 | "height": 1730
3901 | },
3902 | "colab_type": "code",
3903 | "id": "ITXgqK7jjdZL",
3904 | "outputId": "0b712303-2fb6-4477-9985-1c67849c6563"
3905 | },
3906 | "outputs": [
3907 | {
3908 | "name": "stdout",
3909 | "output_type": "stream",
3910 | "text": [
3911 | "[128, 64]\n",
3912 | "fold_id0\n",
3913 | "5_fold_resnext50_128_scse_0\n"
3914 | ]
3915 | },
3916 | {
3917 | "data": {
3918 | "application/vnd.jupyter.widget-view+json": {
3919 | "model_id": "80b17c7b21dd4f91b3af5c1ad0da18a9",
3920 | "version_major": 2,
3921 | "version_minor": 0
3922 | },
3923 | "text/plain": [
3924 | "HBox(children=(IntProgress(value=0, description='Epoch', max=20, style=ProgressStyle(description_width='initia…"
3925 | ]
3926 | },
3927 | "metadata": {
3928 | "tags": []
3929 | },
3930 | "output_type": "display_data"
3931 | },
3932 | {
3933 | "name": "stdout",
3934 | "output_type": "stream",
3935 | "text": [
3936 | "epoch trn_loss val_loss my_eval \n",
3937 | " 0 0.352473 0.759049 0.829625 \n",
3938 | " 1 0.363571 0.880663 0.826 \n",
3939 | " 2 0.35344 0.807012 0.824875 \n",
3940 | " 3 0.344199 0.811331 0.82425 \n",
3941 | " 4 0.341503 0.780056 0.826375 \n",
3942 | " 5 0.349747 0.841056 0.820375 \n",
3943 | " 6 0.350928 0.883612 0.823375 \n",
3944 | " 7 0.346245 0.738591 0.819875 \n",
3945 | " 8 0.339464 0.790437 0.826375 \n",
3946 | " 9 0.339998 0.844214 0.826875 \n",
3947 | " 10 0.347131 0.780476 0.83 \n",
3948 | " 11 0.341404 0.798782 0.829 \n",
3949 | " 12 0.339125 0.789289 0.829875 \n",
3950 | " 13 0.334993 0.868688 0.827625 \n",
3951 | " 14 0.328118 0.880147 0.83275 \n",
3952 | " 15 0.306373 0.860532 0.831625 \n",
3953 | " 16 0.304504 0.829747 0.830875 \n",
3954 | " 17 0.300264 0.837326 0.832625 \n",
3955 | " 18 0.317846 0.774235 0.836375 \n",
3956 | " 19 0.310489 0.787925 0.833875 \n",
3957 | "\n",
3958 | "0.8363749999999999\n",
3959 | "fold_id1\n",
3960 | "5_fold_resnext50_128_scse_1\n"
3961 | ]
3962 | },
3963 | {
3964 | "data": {
3965 | "application/vnd.jupyter.widget-view+json": {
3966 | "model_id": "40209053452647b2858ae4b6ad13be2a",
3967 | "version_major": 2,
3968 | "version_minor": 0
3969 | },
3970 | "text/plain": [
3971 | "HBox(children=(IntProgress(value=0, description='Epoch', max=20, style=ProgressStyle(description_width='initia…"
3972 | ]
3973 | },
3974 | "metadata": {
3975 | "tags": []
3976 | },
3977 | "output_type": "display_data"
3978 | },
3979 | {
3980 | "name": "stdout",
3981 | "output_type": "stream",
3982 | "text": [
3983 | "epoch trn_loss val_loss my_eval \n",
3984 | " 0 0.327861 0.839905 0.7995 \n",
3985 | " 1 0.332155 0.878513 0.80625 \n",
3986 | " 2 0.327904 0.842675 0.78325 \n",
3987 | " 3 0.323279 1.045407 0.800125 \n",
3988 | " 4 0.34445 0.838635 0.794375 \n",
3989 | " 5 0.336768 0.906733 0.8 \n",
3990 | " 6 0.325074 0.921411 0.804625 \n",
3991 | " 7 0.331317 0.867638 0.802125 \n",
3992 | " 8 0.336123 0.883447 0.796875 \n",
3993 | " 9 0.331685 0.851472 0.80425 \n",
3994 | " 10 0.315916 0.857878 0.790625 \n",
3995 | " 11 0.317591 1.001778 0.802125 \n",
3996 | " 12 0.322167 0.868535 0.79975 \n",
3997 | " 13 0.324085 0.90506 0.800625 \n",
3998 | " 14 0.309266 0.923913 0.795375 \n",
3999 | " 15 0.311752 0.973415 0.79975 \n",
4000 | " 16 0.300871 0.953171 0.8045 \n",
4001 | " 17 0.290202 0.962913 0.800125 \n",
4002 | " 18 0.297091 0.981357 0.80025 \n",
4003 | " 19 0.299129 0.941654 0.802625 \n",
4004 | "\n",
4005 | "0.80625\n",
4006 | "fold_id2\n",
4007 | "5_fold_resnext50_128_scse_2\n"
4008 | ]
4009 | },
4010 | {
4011 | "data": {
4012 | "application/vnd.jupyter.widget-view+json": {
4013 | "model_id": "91f55feb2d0541cd800bc6b78ff08643",
4014 | "version_major": 2,
4015 | "version_minor": 0
4016 | },
4017 | "text/plain": [
4018 | "HBox(children=(IntProgress(value=0, description='Epoch', max=20, style=ProgressStyle(description_width='initia…"
4019 | ]
4020 | },
4021 | "metadata": {
4022 | "tags": []
4023 | },
4024 | "output_type": "display_data"
4025 | },
4026 | {
4027 | "name": "stdout",
4028 | "output_type": "stream",
4029 | "text": [
4030 | "epoch trn_loss val_loss my_eval \n",
4031 | " 0 0.337081 0.823228 0.83475 \n",
4032 | " 1 0.344549 0.913375 0.831625 \n",
4033 | " 2 0.329596 0.828047 0.831125 \n",
4034 | " 3 0.325009 0.777783 0.8285 \n",
4035 | " 4 0.33951 0.821626 0.808375 \n",
4036 | " 5 0.338317 0.940252 0.820875 \n",
4037 | " 6 0.330713 0.791693 0.83 \n",
4038 | " 7 0.332787 0.776885 0.82775 \n",
4039 | " 8 0.332073 0.806915 0.82675 \n",
4040 | " 9 0.329484 0.799763 0.817375 \n",
4041 | " 10 0.329214 0.797196 0.82475 \n",
4042 | " 11 0.31536 0.82193 0.832625 \n",
4043 | " 12 0.321162 0.843901 0.827375 \n",
4044 | " 13 0.310198 0.90338 0.828375 \n",
4045 | " 14 0.292332 0.859527 0.8295 \n",
4046 | " 15 0.291901 0.843471 0.8235 \n",
4047 | " 16 0.296143 0.85044 0.819875 \n",
4048 | " 17 0.302732 0.847719 0.827 \n",
4049 | " 18 0.303203 0.819016 0.827875 \n",
4050 | " 19 0.303959 0.815134 0.8255 \n",
4051 | "\n",
4052 | "0.8347500000000001\n",
4053 | "fold_id3\n",
4054 | "5_fold_resnext50_128_scse_3\n"
4055 | ]
4056 | },
4057 | {
4058 | "data": {
4059 | "application/vnd.jupyter.widget-view+json": {
4060 | "model_id": "28a3862558874de39a8aed9c8a9bf792",
4061 | "version_major": 2,
4062 | "version_minor": 0
4063 | },
4064 | "text/plain": [
4065 | "HBox(children=(IntProgress(value=0, description='Epoch', max=20, style=ProgressStyle(description_width='initia…"
4066 | ]
4067 | },
4068 | "metadata": {
4069 | "tags": []
4070 | },
4071 | "output_type": "display_data"
4072 | },
4073 | {
4074 | "name": "stdout",
4075 | "output_type": "stream",
4076 | "text": [
4077 | "epoch trn_loss val_loss my_eval \n",
4078 | " 0 0.316421 0.866826 0.83525 \n",
4079 | " 1 0.337529 0.992894 0.828 \n",
4080 | " 2 0.355061 0.758108 0.83175 \n",
4081 | " 3 0.341312 0.793649 0.82925 \n",
4082 | " 4 0.3637 0.741804 0.83575 \n",
4083 | " 5 0.330458 0.743388 0.839375 \n",
4084 | " 6 0.334183 0.73421 0.841375 \n",
4085 | " 7 0.30901 0.724088 0.83675 \n",
4086 | " 8 0.322369 0.710755 0.830625 \n",
4087 | " 9 0.312626 0.907613 0.841 \n",
4088 | " 10 0.314347 0.75424 0.838875 \n",
4089 | " 11 0.287641 0.762639 0.836125 \n",
4090 | " 12 0.291472 0.738226 0.844875 \n",
4091 | " 13 0.294669 0.741112 0.841375 \n",
4092 | " 14 0.294856 0.747115 0.842375 \n",
4093 | " 15 0.295397 0.75156 0.843125 \n",
4094 | " 16 0.287345 0.807438 0.841875 \n",
4095 | " 17 0.295449 0.754707 0.843625 \n",
4096 | " 82%|████████▏ | 41/50 [00:54<00:10, 1.22s/it, loss=0.29] "
4097 | ]
4098 | }
4099 | ],
4100 | "source": [
4101 | "model = 'resnext50_128_scse'\n",
4102 | "arch = resnext50\n",
4103 | "bst_acc=[]\n",
4104 | "use_clr_min=20\n",
4105 | "use_clr_div=10\n",
4106 | "aug_tfms = [RandomFlip(tfm_y=TfmType.CLASS)]\n",
4107 | "\n",
4108 | "szs = [(128,64)]\n",
4109 | "for sz,bs in szs:\n",
4110 | " print([sz,bs])\n",
4111 | " for i in range(kf):\n",
4112 | " print(f'fold_id{i}')\n",
4113 | " \n",
4114 | " trn_x = np.array([f'{TRN_DN}/{o}' for o in train_folds[i]])\n",
4115 | " trn_y = np.array([f'{MASK_DN}/{o}' for o in train_folds[i]])\n",
4116 | " val_x = [f'{TRN_DN}/{o}' for o in val_folds[i]]\n",
4117 | " val_y = [f'{MASK_DN}/{o}' for o in val_folds[i]]\n",
4118 | " \n",
4119 | " tfms = tfms_from_model(arch, sz=sz, pad=0,crop_type=CropType.NO, tfm_y=TfmType.CLASS, aug_tfms=aug_tfms)\n",
4120 | " datasets = ImageData.get_ds(MatchedFilesDataset, (trn_x,trn_y), (val_x,val_y), tfms,test=test_x,path=PATH)\n",
4121 | " md = ImageData(PATH, datasets, bs, num_workers=16, classes=None)\n",
4122 | " denorm = md.trn_ds.denorm\n",
4123 | " md.trn_dl.dataset = DepthDataset(md.trn_ds,dpth_dict)\n",
4124 | " md.val_dl.dataset = DepthDataset(md.val_ds,dpth_dict)\n",
4125 | " md.test_dl.dataset = DepthDataset(md.test_ds,dpth_dict)\n",
4126 | " learn = get_tgs_model() \n",
4127 | " learn.opt_fn = optim.Adam\n",
4128 | " learn.metrics=[my_eval]\n",
4129 | " pa = f'{kf}_fold_{model}_{i}'\n",
4130 | " print(pa)\n",
4131 | "\n",
4132 | " lr=1e-2\n",
4133 | " wd=1e-7\n",
4134 | " lrs = np.array([lr/150,lr/20,lr])\n",
4135 | "\n",
4136 | " learn.unfreeze()\n",
4137 | " learn.bn_freeze(True)\n",
4138 | " learn.crit = lovasz_hinge\n",
4139 | " learn.load(pa)\n",
4140 | " learn.fit(lrs/10,1, wds=wd, cycle_len=20,use_clr=(20,20),best_save_name=pa)\n",
4141 | " \n",
4142 | " learn.load(pa) \n",
4143 | " # Calculating mean IoU score\n",
4144 | " v_targ = md.val_ds.ds[:][1]\n",
4145 | " v_preds = np.zeros((len(v_targ),sz,sz)) \n",
4146 | " v_pred = learn.predict()\n",
4147 | " v_pred = to_np(torch.sigmoid(torch.from_numpy(v_pred)))\n",
4148 | " p = ((v_pred)>0.5).astype(np.uint8)\n",
4149 | " bst_acc.append(intersection_over_union_thresholds(v_targ,p))\n",
4150 | " print(bst_acc[-1])\n",
4151 | "\n",
4152 | " if i < kf - 1:\n",
4153 | " del learn\n",
4154 | " del md\n",
4155 | " gc.collect()\n",
4156 | " \n",
4157 | "save_model_weights()"
4158 | ]
4159 | },
4160 | {
4161 | "cell_type": "code",
4162 | "execution_count": 0,
4163 | "metadata": {
4164 | "colab": {},
4165 | "colab_type": "code",
4166 | "id": "lTpda6ZVbF0b"
4167 | },
4168 | "outputs": [],
4169 | "source": [
4170 | "save_model_weights()"
4171 | ]
4172 | },
4173 | {
4174 | "cell_type": "code",
4175 | "execution_count": 0,
4176 | "metadata": {
4177 | "colab": {},
4178 | "colab_type": "code",
4179 | "id": "xP7Kn5wYs3NG"
4180 | },
4181 | "outputs": [],
4182 | "source": [
4183 | "load_model_weights(time_stamp='20181010')"
4184 | ]
4185 | },
4186 | {
4187 | "cell_type": "code",
4188 | "execution_count": 0,
4189 | "metadata": {
4190 | "colab": {
4191 | "base_uri": "https://localhost:8080/",
4192 | "height": 50
4193 | },
4194 | "colab_type": "code",
4195 | "id": "lz43bF8JWcKC",
4196 | "outputId": "3efd88a3-094c-448b-b5f2-5b6ae26e2a4e"
4197 | },
4198 | "outputs": [
4199 | {
4200 | "data": {
4201 | "text/plain": [
4202 | "([0.835, 0.8065000000000001, 0.8315, 0.8447499999999999, 0.8440000000000001],\n",
4203 | " 0.8323500000000001)"
4204 | ]
4205 | },
4206 | "execution_count": 74,
4207 | "metadata": {
4208 | "tags": []
4209 | },
4210 | "output_type": "execute_result"
4211 | }
4212 | ],
4213 | "source": [
4214 | "bst_acc, np.mean(bst_acc)  # per-fold scores and their mean, 128x128 input"
4215 | ]
4216 | },
4217 | {
4218 | "cell_type": "markdown",
4219 | "metadata": {
4220 | "colab_type": "text",
4221 | "id": "rF8h3iawT6PI"
4222 | },
4223 | "source": [
4224 | "### Submit"
4225 | ]
4226 | },
4227 | {
4228 | "cell_type": "code",
4229 | "execution_count": 0,
4230 | "metadata": {
4231 | "colab": {},
4232 | "colab_type": "code",
4233 | "id": "LqbkAvBMMKR6"
4234 | },
4235 | "outputs": [],
4236 | "source": [
4237 | "# sz = 128\n",
4238 | "# bs = 64\n",
4239 | "# model = 'resnext50_128_scse'\n",
4240 | "\n",
4241 | "# # trn_x = np.array([f'{TRN_DN}/{o}' for o in train_folds[0]])\n",
4242 | "# # trn_y = np.array([f'{MASK_DN}/{o}' for o in train_folds[0]])\n",
4243 | "# # val_x = [f'{TRN_DN}/{o}' for o in val_folds[0]]\n",
4244 | "# # val_y = [f'{MASK_DN}/{o}' for o in val_folds[0]]\n",
4245 | " \n",
4246 | "# aug_tfms = [RandomFlip(tfm_y=TfmType.CLASS)]\n",
4247 | "# tfms = tfms_from_model(arch, sz=sz, pad=0,crop_type=CropType.NO, tfm_y=TfmType.CLASS, aug_tfms=aug_tfms)\n",
4248 | "# datasets = ImageData.get_ds(MatchedFilesDataset, (trn_x,trn_y), (val_x,val_y), tfms,test=test_x,path=PATH)\n",
4249 | "# md = ImageData(PATH, datasets, bs, num_workers=16, classes=None)\n",
4250 | "# # tfms = tfms_from_model(resnet34, sz=sz, pad=0,crop_type=CropType.NO, tfm_y=TfmType.CLASS, aug_tfms=aug_tfms)\n",
4251 | "# # datasets = ImageData.get_ds(MatchedFilesDataset, (trn_x,trn_y), (val_x,val_y), tfms,test=test_x,path=PATH)\n",
4252 | "# # md = ImageData(PATH, datasets, bs, num_workers=16, classes=None)\n",
4253 | "\n",
4254 | "# learn = get_tgs_model() \n",
4255 | "# learn.opt_fn = optim.Adam\n",
4256 | "# learn.crit = lovasz_hinge\n",
4257 | "# learn.metrics=[my_eval]"
4258 | ]
4259 | },
4260 | {
4261 | "cell_type": "code",
4262 | "execution_count": 0,
4263 | "metadata": {
4264 | "colab": {
4265 | "base_uri": "https://localhost:8080/",
4266 | "height": 252
4267 | },
4268 | "colab_type": "code",
4269 | "id": "ms-zAHP2WcGb",
4270 | "outputId": "e1af7781-86a4-4154-82cf-4d6eb7e55897"
4271 | },
4272 | "outputs": [
4273 | {
4274 | "data": {
4275 | "application/vnd.jupyter.widget-view+json": {
4276 | "model_id": "ef25c63e5456401eb8edde4089530451",
4277 | "version_major": 2,
4278 | "version_minor": 0
4279 | },
4280 | "text/plain": [
4281 | "HBox(children=(IntProgress(value=0, max=5), HTML(value='')))"
4282 | ]
4283 | },
4284 | "metadata": {
4285 | "tags": []
4286 | },
4287 | "output_type": "display_data"
4288 | },
4289 | {
4290 | "name": "stdout",
4291 | "output_type": "stream",
4292 | "text": [
4293 | "5_fold_resnext50_128_scse_0\n",
4294 | "5_fold_resnext50_128_scse_1\n",
4295 | "5_fold_resnext50_128_scse_2\n",
4296 | "5_fold_resnext50_128_scse_3\n",
4297 | "5_fold_resnext50_128_scse_4\n",
4298 | "\n"
4299 | ]
4300 | },
4301 | {
4302 | "data": {
4303 | "application/vnd.jupyter.widget-view+json": {
4304 | "model_id": "5c581f68713c40a090b48c2b02908785",
4305 | "version_major": 2,
4306 | "version_minor": 0
4307 | },
4308 | "text/plain": [
4309 | "HBox(children=(IntProgress(value=0, max=5), HTML(value='')))"
4310 | ]
4311 | },
4312 | "metadata": {
4313 | "tags": []
4314 | },
4315 | "output_type": "display_data"
4316 | },
4317 | {
4318 | "name": "stdout",
4319 | "output_type": "stream",
4320 | "text": [
4321 | "5_fold_resnext50_128_scse_0\n",
4322 | "5_fold_resnext50_128_scse_1\n",
4323 | "5_fold_resnext50_128_scse_2\n",
4324 | "5_fold_resnext50_128_scse_3\n",
4325 | "5_fold_resnext50_128_scse_4\n",
4326 | "\n"
4327 | ]
4328 | }
4329 | ],
4330 | "source": [
4331 | "preds = np.zeros(shape=(len(test_x), sz, sz))  # one accumulator per test image (18000 in TGS)\n",
4332 | "for o in [True,False]:\n",
4333 | " md.test_dl.dataset = TestFilesDataset(test_x,test_x,tfms[1],flip=o,path=PATH)\n",
4334 | " md.test_dl.dataset = DepthDatasetV2(md.test_dl.dataset,dpth_dict)\n",
4335 | " \n",
4336 | " for i in tqdm_notebook(range(kf)):\n",
4337 | " pa = f'{kf}_fold_{model}_{i}'\n",
4338 | " print(pa)\n",
4339 | " learn.load(pa)\n",
4340 | " pred = learn.predict(is_test=True)\n",
4341 | " pred = to_np(torch.sigmoid(torch.from_numpy(pred))) \n",
4342 | " for im_idx,im in enumerate(pred):\n",
4343 | " preds[im_idx] += np.fliplr(im) if o else im\n",
4344 | " del pred"
4345 | ]
4346 | },
4347 | {
4348 | "cell_type": "code",
4349 | "execution_count": 0,
4350 | "metadata": {
4351 | "colab": {},
4352 | "colab_type": "code",
4353 | "id": "ZHKkmKrqWb7W"
4354 | },
4355 | "outputs": [],
4356 | "source": [
4357 | "# plt.imshow(((preds[16]/10)>0.5).astype(np.uint8))"
4358 | ]
4359 | },
4360 | {
4361 | "cell_type": "code",
4362 | "execution_count": 0,
4363 | "metadata": {
4364 | "colab": {},
4365 | "colab_type": "code",
4366 | "id": "pe59_TSGWb4K"
4367 | },
4368 | "outputs": [],
4369 | "source": [
4370 | "p = [cv2.resize(o/(2*kf), dsize=(101,101)) for o in preds]  # mean over kf folds x 2 flip states, back to 101x101\n",
4371 | "p = [(o>0.5).astype(np.uint8) for o in p]"
4372 | ]
4373 | },
4374 | {
4375 | "cell_type": "code",
4376 | "execution_count": 0,
4377 | "metadata": {
4378 | "colab": {
4379 | "base_uri": "https://localhost:8080/",
4380 | "height": 34
4381 | },
4382 | "colab_type": "code",
4383 | "id": "qu8Gin-NTMef",
4384 | "outputId": "acaf9176-75b0-4b53-d2f5-ece8890b9862"
4385 | },
4386 | "outputs": [
4387 | {
4388 | "data": {
4389 | "text/plain": [
4390 | "(101, 101)"
4391 | ]
4392 | },
4393 | "execution_count": 67,
4394 | "metadata": {
4395 | "tags": []
4396 | },
4397 | "output_type": "execute_result"
4398 | }
4399 | ],
4400 | "source": [
4401 | "p[0].shape"
4402 | ]
4403 | },
4404 | {
4405 | "cell_type": "code",
4406 | "execution_count": 0,
4407 | "metadata": {
4408 | "colab": {
4409 | "base_uri": "https://localhost:8080/",
4410 | "height": 50
4411 | },
4412 | "colab_type": "code",
4413 | "id": "7ZRdo1lAWb08",
4414 | "outputId": "615f7553-793c-4f59-d6a1-22ccdedbbf6d"
4415 | },
4416 | "outputs": [
4417 | {
4418 | "data": {
4419 | "application/vnd.jupyter.widget-view+json": {
4420 | "model_id": "6a28697675af4cbc90c8dcd4afba37d2",
4421 | "version_major": 2,
4422 | "version_minor": 0
4423 | },
4424 | "text/plain": [
4425 | "HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))"
4426 | ]
4427 | },
4428 | "metadata": {
4429 | "tags": []
4430 | },
4431 | "output_type": "display_data"
4432 | },
4433 | {
4434 | "name": "stdout",
4435 | "output_type": "stream",
4436 | "text": [
4437 | "\n"
4438 | ]
4439 | }
4440 | ],
4441 | "source": [
4442 | "pred_dict = {id_.name[:-4]: RLenc(p[i]) for i, id_ in tqdm_notebook(enumerate(test_x), total=len(test_x))}\n",
4443 | "sub = pd.DataFrame.from_dict(pred_dict,orient='index')\n",
4444 | "sub.index.names = ['id']\n",
4445 | "sub.columns = ['rle_mask']\n",
4446 | "sub.to_csv('submission.csv')"
4447 | ]
4448 | },
4449 | {
4450 | "cell_type": "code",
4451 | "execution_count": 0,
4452 | "metadata": {
4453 | "colab": {
4454 | "base_uri": "https://localhost:8080/",
4455 | "height": 34
4456 | },
4457 | "colab_type": "code",
4458 | "id": "et8aAjFVWbxh",
4459 | "outputId": "3af2bb8f-dcdc-4187-a48e-9dabfaebf9f4"
4460 | },
4461 | "outputs": [
4462 | {
4463 | "name": "stdout",
4464 | "output_type": "stream",
4465 | "text": [
4466 | "Successfully submitted to TGS Salt Identification Challenge"
4467 | ]
4468 | }
4469 | ],
4470 | "source": [
4471 | "!kaggle competitions submit -c tgs-salt-identification-challenge -f submission.csv -m \"\""
4472 | ]
4473 | },
4474 | {
4475 | "cell_type": "code",
4476 | "execution_count": null,
4477 | "metadata": {
4478 | "colab": {
4479 | "base_uri": "https://localhost:8080/",
4480 | "height": 386
4481 | },
4482 | "colab_type": "code",
4483 | "id": "38KczbM-WbuM",
4484 | "outputId": "0a1ab856-ec21-45ad-803e-fa3dc06ec9dc"
4485 | },
4486 | "outputs": [],
4487 | "source": [
4488 | "!kaggle competitions submissions -c tgs-salt-identification-challenge"
4489 | ]
4490 | },
4491 | {
4492 | "cell_type": "code",
4493 | "execution_count": 0,
4494 | "metadata": {
4495 | "colab": {},
4496 | "colab_type": "code",
4497 | "id": "7rbnFVohWbqd"
4498 | },
4499 | "outputs": [],
4500 | "source": [
4501 | "!fusermount -u drive"
4502 | ]
4503 | },
4504 | {
4505 | "cell_type": "code",
4506 | "execution_count": 0,
4507 | "metadata": {
4508 | "colab": {},
4509 | "colab_type": "code",
4510 | "id": "emlLQBDaj6fk"
4511 | },
4512 | "outputs": [],
4513 | "source": []
4514 | }
4515 | ],
4516 | "metadata": {
4517 | "accelerator": "GPU",
4518 | "colab": {
4519 | "collapsed_sections": [
4520 | "Dz4H7ksT0tMo"
4521 | ],
4522 | "name": "TGS_Salt_Unet50_5-Fold_scSE_score_8.42.ipynb",
4523 | "provenance": [],
4524 | "version": "0.3.2"
4525 | },
4526 | "kernelspec": {
4527 | "display_name": "Python 3",
4528 | "language": "python",
4529 | "name": "python3"
4530 | },
4531 | "language_info": {
4532 | "codemirror_mode": {
4533 | "name": "ipython",
4534 | "version": 3
4535 | },
4536 | "file_extension": ".py",
4537 | "mimetype": "text/x-python",
4538 | "name": "python",
4539 | "nbconvert_exporter": "python",
4540 | "pygments_lexer": "ipython3",
4541 | "version": "3.5.4"
4542 | }
4543 | },
4544 | "nbformat": 4,
4545 | "nbformat_minor": 1
4546 | }
4547 |
--------------------------------------------------------------------------------
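Each fold is scored with `intersection_over_union_thresholds(v_targ, p)`, a helper defined elsewhere in the notebook. The competition's evaluation page describes the metric this way. A prediction counts as a hit at threshold t when its IoU with the ground truth exceeds t. Hits are averaged over t = 0.50 to 0.95 in steps of 0.05, and those scores are then averaged over images.

Below is a minimal NumPy sketch of that computation. It rests on two assumptions: each mask is treated as a single object, and an empty prediction on an empty mask scores 1, which is the convention public kernels commonly used. `image_score` and `mean_iou_over_thresholds` are hypothetical names and may not match the notebook's helper exactly.

```python
import numpy as np

# IoU thresholds from 0.50 to 0.95 in steps of 0.05 (10 values)
THRESHOLDS = np.linspace(0.5, 0.95, 10)


def image_score(target, pred):
    """Fraction of IoU thresholds at which `pred` counts as a hit on `target`.

    Treats each mask as a single object. An empty prediction on an empty
    target is scored 1.0 (assumed convention, not confirmed by the notebook).
    """
    target = np.asarray(target).astype(bool)
    pred = np.asarray(pred).astype(bool)
    union = np.logical_or(target, pred).sum()
    if union == 0:  # both masks empty: perfect score
        return 1.0
    iou = np.logical_and(target, pred).sum() / union
    # a hit at threshold t requires IoU strictly greater than t
    return float(np.mean(iou > THRESHOLDS))


def mean_iou_over_thresholds(targets, preds):
    """Average the per-image score over a batch of (target, prediction) pairs."""
    return float(np.mean([image_score(t, p) for t, p in zip(targets, preds)]))
```

As an example, suppose a prediction covers 31 pixels of a 50-pixel salt mask, giving an IoU of 0.62. It clears only the 0.50, 0.55 and 0.60 thresholds, so it scores 0.3.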
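The submission cell turns each thresholded 101x101 mask into a string with `RLenc` before writing `submission.csv`. The sketch below assumes the competition's run-length convention: pixels are numbered from 1, top to bottom and then left to right, and the string lists `start length` pairs. It is an illustration of that format, not the notebook's `RLenc`. `rle_encode` and `rle_decode` are hypothetical names.

```python
import numpy as np


def rle_encode(mask):
    """Encode a binary mask as 'start length' pairs (1-indexed, column-major)."""
    # column-major flattening numbers pixels top to bottom, then left to right
    pixels = np.asarray(mask).flatten(order="F").astype(np.int64)
    # pad with zeros so every run has both a start and an end transition
    padded = np.concatenate([[0], pixels, [0]])
    # positions where the value changes; +1 converts to 1-indexed pixel numbers
    runs = np.where(padded[1:] != padded[:-1])[0] + 1
    # every second entry is one past a run's last pixel; subtract the start
    # to turn it into the run length
    runs[1::2] -= runs[::2]
    return " ".join(str(x) for x in runs)


def rle_decode(rle, shape):
    """Inverse of rle_encode: rebuild a uint8 mask of the given (H, W) shape."""
    flat = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    nums = [int(x) for x in rle.split()]
    for start, length in zip(nums[::2], nums[1::2]):
        flat[start - 1:start - 1 + length] = 1
    return flat.reshape(shape, order="F")
```

Before running `kaggle competitions submit`, you can decode a few encoded strings and compare them with the original masks. This is a cheap check that catches row-major vs column-major mix-ups. An empty mask encodes to an empty string.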
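The prediction cell performs test-time augmentation. It runs every fold's model on the test set twice. The second run uses `flip=o` in `TestFilesDataset`, which presumably flips the images horizontally. Flipped-input predictions are flipped back with `np.fliplr` before being added to `preds`. The sum is later divided by the number of passes: 5 folds x 2 flip states = 10.

The same averaging can be sketched in plain NumPy. This sketch assumes single-channel inputs of shape (N, H, W) and models given as plain callables that return per-pixel probabilities of the same shape. `tta_average` is a hypothetical name, not a fastai API.

```python
import numpy as np


def tta_average(models, images):
    """Mean probability over every model and both flip states of the input.

    images: array of shape (N, H, W); each model maps such an array to
    per-pixel probabilities of the same shape (assumed interface).
    """
    acc = np.zeros(images.shape, dtype=np.float64)
    n_passes = 0
    for flip in (True, False):
        # horizontal flip = reverse the width axis
        inputs = images[:, :, ::-1] if flip else images
        for model in models:
            probs = model(inputs)
            # undo the flip so the prediction lines up with the original image
            acc += probs[:, :, ::-1] if flip else probs
            n_passes += 1
    return acc / n_passes  # n_passes == 2 * len(models)
```

Flipping the predictions back is what makes the two passes comparable. Suppose a model's output depends on pixel position. Averaging the two aligned passes then symmetrizes that bias instead of smearing the mask across mismatched columns.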