├── README.md
├── images
│   ├── README.md
│   ├── architecture.jpg
│   ├── qualitative.jpg
│   ├── result-1.png
│   └── result-2.png
├── metrics.py
├── model.py
├── resnet.py
├── test.py
├── train.py
└── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # DilatedSegNet: A Deep Dilated Segmentation Network for Polyp Segmentation
2 |
3 | ## 1. Abstract
4 |
5 | Colorectal cancer (CRC) is the second leading cause of cancer-related death worldwide. Excision of polyps during colonoscopy helps reduce CRC-related mortality and morbidity. Powered by deep learning, computer-aided diagnosis (CAD) systems can detect regions in the colon that physicians overlook during colonoscopy. The lack of high accuracy and real-time speed is the essential obstacle to be overcome for successful clinical integration of such systems. While the literature focuses on improving accuracy, speed is often ignored. Toward this critical need, we develop a novel real-time deep learning-based architecture, DilatedSegNet, that performs polyp segmentation on the fly. DilatedSegNet is an encoder-decoder network that uses a pre-trained ResNet50 as the encoder, from which we extract four levels of feature maps. Each of these feature maps is passed through a dilated convolution pooling (DCP) block. The outputs from the DCP blocks are concatenated and passed through a series of four decoder blocks that predict the segmentation mask. The proposed method achieves a real-time operation speed of 33.68 frames per second with an average dice coefficient of 0.90 and mIoU of 0.83. Additionally, we provide heatmaps along with the qualitative results to explain the predicted polyp locations, which increases the trustworthiness of the method. The results on the publicly available Kvasir-SEG and BKAI-IGH datasets suggest that DilatedSegNet can give real-time feedback while retaining a high dice coefficient, indicating its potential for use in real clinical settings in the near future.
6 |
7 |
8 | ## 2. Architecture
9 |
10 | ![DilatedSegNet architecture](images/architecture.jpg)
11 | ## 3. Implementation
12 | The proposed architecture is implemented using the PyTorch framework (1.9.0+cu111) and trained on a single GeForce RTX 3090 GPU with 24 GB of memory.
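
The code additionally depends on a few common Python packages. The list below is inferred from the imports across this repository; the exact versions (other than PyTorch) are assumptions:

```
torch==1.9.0
albumentations
opencv-python
numpy
scikit-learn
tqdm
imageio
ptflops
```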
13 |
14 | ### 3.1 Dataset
15 | We have used the following datasets:
16 | - [Kvasir-SEG](https://datasets.simula.no/downloads/kvasir-seg.zip)
17 | - [BKAI](https://www.kaggle.com/competitions/bkai-igh-neopolyp/data)
18 |
19 | The BKAI dataset follows an 80:10:10 split for training, validation, and testing, while Kvasir-SEG follows the official split of 880 training and 120 testing images.
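
`train.py` expects each dataset root to contain `images/` and `masks/` directories, plus `train.txt` and `val.txt` files listing the image names one per line without the file extension (inferred from `load_names()` in `train.py`). A sketch of the expected layout:

```
Kvasir-SEG/
├── images/       # RGB images, <name>.jpg
├── masks/        # binary masks, <name>.jpg
├── train.txt     # training image names, one per line
└── val.txt       # validation image names
```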
20 |
21 | ### 3.2 Weight file
22 | - [Kvasir-SEG](https://drive.google.com/file/d/1diYckKDMqDWSDD6O5Jm6InCxWEkU0GJC/view?usp=sharing)
23 | - [BKAI-IGH](https://drive.google.com/file/d/1ojGaQThD56mRhGQaVoJVpAw0oVwSzX8N/view?usp=sharing)
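
A minimal inference sketch with a downloaded checkpoint, based on `test.py` (the local filename `checkpoint.pth` and the 256x256 input size are assumptions):

```python
import torch
from model import RUPNet  # DilatedSegNet is implemented as the RUPNet class

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = RUPNet().to(device)
model.load_state_dict(torch.load("checkpoint.pth", map_location=device))
model.eval()

# Stand-in for a preprocessed RGB image: HWC -> CHW, scaled to [0, 1]
x = torch.randn((1, 3, 256, 256), device=device)
with torch.no_grad():
    prob = torch.sigmoid(model(x))  # probability map, shape [1, 1, 256, 256]
    mask = (prob > 0.5).float()     # binary polyp mask
```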
24 |
25 | ## 4. Results
26 |
27 | ### 4.1 Quantitative Results: Same Dataset
28 |
29 | ![Quantitative results on the same dataset](images/result-1.png)
30 | ### 4.2 Quantitative Results: Different Dataset
31 |
32 | ![Quantitative results on a different dataset](images/result-2.png)
33 | ### 4.3 Qualitative Results
34 |
35 | ![Qualitative results](images/qualitative.jpg)
36 | ## 5. Citation
37 | Updated soon.
38 |
39 | ## 6. License
40 | The source code is free for research and education use only. Any commercial use requires formal permission from the first author.
41 |
42 | ## 7. Contact
43 | Please contact nikhilroxtomar@gmail.com for any further questions.
44 |
--------------------------------------------------------------------------------
/images/README.md:
--------------------------------------------------------------------------------
1 | # Images
2 |
--------------------------------------------------------------------------------
/images/architecture.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nikhilroxtomar/DilatedSegNet/5b193e1dd20a925ec4aed28d08dd0ac6f495d8a2/images/architecture.jpg
--------------------------------------------------------------------------------
/images/qualitative.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nikhilroxtomar/DilatedSegNet/5b193e1dd20a925ec4aed28d08dd0ac6f495d8a2/images/qualitative.jpg
--------------------------------------------------------------------------------
/images/result-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nikhilroxtomar/DilatedSegNet/5b193e1dd20a925ec4aed28d08dd0ac6f495d8a2/images/result-1.png
--------------------------------------------------------------------------------
/images/result-2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nikhilroxtomar/DilatedSegNet/5b193e1dd20a925ec4aed28d08dd0ac6f495d8a2/images/result-2.png
--------------------------------------------------------------------------------
/metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | """ Loss Functions -------------------------------------- """
6 | class DiceLoss(nn.Module):
7 | def __init__(self, weight=None, size_average=True):
8 | super(DiceLoss, self).__init__()
9 |
10 | def forward(self, inputs, targets, smooth=1):
11 | inputs = torch.sigmoid(inputs)
12 |
13 | inputs = inputs.view(-1)
14 | targets = targets.view(-1)
15 |
16 | intersection = (inputs * targets).sum()
17 | dice = (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth)
18 |
19 | return 1 - dice
20 |
21 | class DiceBCELoss(nn.Module):
22 | def __init__(self, weight=None, size_average=True):
23 | super(DiceBCELoss, self).__init__()
24 |
25 | def forward(self, inputs, targets, smooth=1):
26 | inputs = torch.sigmoid(inputs)
27 |
28 | inputs = inputs.view(-1)
29 | targets = targets.view(-1)
30 |
31 | intersection = (inputs * targets).sum()
32 | dice_loss = 1 - (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth)
33 | BCE = F.binary_cross_entropy(inputs, targets, reduction='mean')
34 | Dice_BCE = BCE + dice_loss
35 |
36 | return Dice_BCE
37 |
38 | """ Metrics ------------------------------------------ """
39 | def precision(y_true, y_pred):
40 | intersection = (y_true * y_pred).sum()
41 | return (intersection + 1e-15) / (y_pred.sum() + 1e-15)
42 |
43 | def recall(y_true, y_pred):
44 | intersection = (y_true * y_pred).sum()
45 | return (intersection + 1e-15) / (y_true.sum() + 1e-15)
46 |
47 | def F2(y_true, y_pred, beta=2):
48 | p = precision(y_true,y_pred)
49 | r = recall(y_true, y_pred)
50 |     return (1 + beta**2) * (p * r) / (beta**2 * p + r + 1e-15)
51 |
52 | def dice_score(y_true, y_pred):
53 | return (2 * (y_true * y_pred).sum() + 1e-15) / (y_true.sum() + y_pred.sum() + 1e-15)
54 |
55 | def jac_score(y_true, y_pred):
56 | intersection = (y_true * y_pred).sum()
57 | union = y_true.sum() + y_pred.sum() - intersection
58 | return (intersection + 1e-15) / (union + 1e-15)
59 |
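# A minimal sanity check (a sketch, not part of the training pipeline): the Dice
# score of a perfect prediction should be ~1.0 and the IoU of disjoint masks ~0.0.
if __name__ == "__main__":
    import numpy as np
    a = np.array([1, 1, 0, 0], dtype=np.uint8)
    print(dice_score(a, a))      # ~1.0
    print(jac_score(a, 1 - a))   # ~0.0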
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 | from resnet import resnet50
5 | import numpy as np
6 | import cv2
7 |
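# save_feats_mean converts a batch of feature maps into a JET-colormapped heatmap of
# the channel-wise mean of the first sample; used for the qualitative heatmaps in test.py.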
8 | def save_feats_mean(x, size=(256, 256)):
9 | b, c, h, w = x.shape
10 | with torch.no_grad():
11 | x = x.detach().cpu().numpy()
12 | x = np.transpose(x[0], (1, 2, 0))
13 | x = np.mean(x, axis=-1)
14 |         x = x / (np.max(x) + 1e-8)  # guard against division by zero on all-zero maps
15 | x = x * 255.0
16 | x = x.astype(np.uint8)
17 |
18 |         if (w, h) != size:
19 | x = cv2.resize(x, size)
20 |
21 | x = cv2.applyColorMap(x, cv2.COLORMAP_JET)
22 | x = np.array(x, dtype=np.uint8)
23 | return x
24 |
25 | def get_mean_attention_map(x):
26 | x = torch.mean(x, axis=1)
27 | x = torch.unsqueeze(x, 1)
28 | x = x / torch.max(x)
29 | return x
30 |
31 | class ResidualBlock(nn.Module):
32 | def __init__(self, in_c, out_c):
33 | super().__init__()
34 |
35 | self.relu = nn.ReLU()
36 | self.conv = nn.Sequential(
37 | nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
38 | nn.BatchNorm2d(out_c),
39 | nn.ReLU(),
40 | nn.Conv2d(out_c, out_c, kernel_size=3, padding=1),
41 | nn.BatchNorm2d(out_c)
42 | )
43 | self.shortcut = nn.Sequential(
44 | nn.Conv2d(in_c, out_c, kernel_size=1, padding=0),
45 | nn.BatchNorm2d(out_c)
46 | )
47 |
48 | def forward(self, inputs):
49 | x1 = self.conv(inputs)
50 | x2 = self.shortcut(inputs)
51 | x = self.relu(x1 + x2)
52 | return x
53 |
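# Dilated convolution (DCP) block: four parallel 3x3 convolutions with dilation
# rates 1, 3, 6 and 9, concatenated and fused by a 1x1 convolution.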
54 | class DilatedConv(nn.Module):
55 | def __init__(self, in_c, out_c):
56 | super().__init__()
57 |
58 | self.c1 = nn.Sequential(
59 | nn.Conv2d(in_c, out_c, kernel_size=3, padding=1, dilation=1),
60 | nn.BatchNorm2d(out_c),
61 | nn.ReLU()
62 | )
63 |
64 | self.c2 = nn.Sequential(
65 | nn.Conv2d(in_c, out_c, kernel_size=3, padding=3, dilation=3),
66 | nn.BatchNorm2d(out_c),
67 | nn.ReLU()
68 | )
69 |
70 | self.c3 = nn.Sequential(
71 | nn.Conv2d(in_c, out_c, kernel_size=3, padding=6, dilation=6),
72 | nn.BatchNorm2d(out_c),
73 | nn.ReLU()
74 | )
75 |
76 | self.c4 = nn.Sequential(
77 | nn.Conv2d(in_c, out_c, kernel_size=3, padding=9, dilation=9),
78 | nn.BatchNorm2d(out_c),
79 | nn.ReLU()
80 | )
81 |
82 | self.c5 = nn.Sequential(
83 | nn.Conv2d(out_c*4, out_c, kernel_size=1, padding=0),
84 | nn.BatchNorm2d(out_c),
85 | nn.ReLU()
86 | )
87 |
88 | def forward(self, inputs):
89 | x1 = self.c1(inputs)
90 | x2 = self.c2(inputs)
91 | x3 = self.c3(inputs)
92 | x4 = self.c4(inputs)
93 | x = torch.cat([x1, x2, x3, x4], axis=1)
94 | x = self.c5(x)
95 | return x
96 |
97 | class ChannelAttention(nn.Module):
98 | def __init__(self, in_planes, ratio=16):
99 | super(ChannelAttention, self).__init__()
100 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
101 | self.max_pool = nn.AdaptiveMaxPool2d(1)
102 |
103 |         self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
104 |         self.relu1 = nn.ReLU()
105 |         self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
106 |
107 | self.sigmoid = nn.Sigmoid()
108 |
109 | def forward(self, x):
110 | x0 = x
111 | avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
112 | max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
113 | out = avg_out + max_out
114 | return x0 * self.sigmoid(out)
115 |
116 |
117 | class SpatialAttention(nn.Module):
118 | def __init__(self, kernel_size=7):
119 | super(SpatialAttention, self).__init__()
120 |
121 | assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
122 | padding = 3 if kernel_size == 7 else 1
123 |
124 | self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
125 | self.sigmoid = nn.Sigmoid()
126 |
127 | def forward(self, x):
128 | x0 = x
129 | avg_out = torch.mean(x, dim=1, keepdim=True)
130 | max_out, _ = torch.max(x, dim=1, keepdim=True)
131 | x = torch.cat([avg_out, max_out], dim=1)
132 | x = self.conv1(x)
133 | return x0 * self.sigmoid(x)
134 |
135 | class DecoderBlock(nn.Module):
136 | def __init__(self, in_c, out_c):
137 | super().__init__()
138 |
139 | self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
140 | self.r1 = ResidualBlock(in_c[0]+in_c[1], out_c)
141 | self.r2 = ResidualBlock(out_c, out_c)
142 |
143 | self.ca = ChannelAttention(out_c)
144 | self.sa = SpatialAttention()
145 |
146 | def forward(self, x, s):
147 | x = self.up(x)
148 | x = torch.cat([x, s], axis=1)
149 | x = self.r1(x)
150 | x = self.r2(x)
151 |
152 | x = self.ca(x)
153 | x = self.sa(x)
154 | return x
155 |
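# DilatedSegNet (the class keeps the working name RUPNet): a ResNet50 encoder, four
# DCP blocks max-pooled to a common h/16 x w/16 resolution, and four attention-guided
# decoder blocks followed by a 1x1 output convolution.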
156 | class RUPNet(nn.Module):
157 | def __init__(self):
158 | super().__init__()
159 |
160 | """ ResNet50 """
161 | backbone = resnet50()
162 | self.layer0 = nn.Sequential(backbone.conv1, backbone.bn1, backbone.relu)
163 | self.layer1 = nn.Sequential(backbone.maxpool, backbone.layer1)
164 | self.layer2 = backbone.layer2
165 | self.layer3 = backbone.layer3
166 |
167 | """ Dilated Conv + Pooling """
168 | self.r1 = nn.Sequential(DilatedConv(64, 64), nn.MaxPool2d((8, 8)))
169 | self.r2 = nn.Sequential(DilatedConv(256, 64), nn.MaxPool2d((4, 4)))
170 | self.r3 = nn.Sequential(DilatedConv(512, 64), nn.MaxPool2d((2, 2)))
171 | self.r4 = DilatedConv(1024, 64)
172 |
173 | """ Decoder """
174 | self.d1 = DecoderBlock([256, 512], 256)
175 | self.d2 = DecoderBlock([256, 256], 128)
176 | self.d3 = DecoderBlock([128, 64], 64)
177 | self.d4 = DecoderBlock([64, 3], 32)
178 |
181 | """ Output """
182 | self.y = nn.Conv2d(32, 1, kernel_size=1, padding=0)
183 |
184 | def forward(self, x, heatmap=None):
185 | """ ResNet50 """
186 | s0 = x
187 | s1 = self.layer0(s0) ## [-1, 64, h/2, w/2]
188 | s2 = self.layer1(s1) ## [-1, 256, h/4, w/4]
189 | s3 = self.layer2(s2) ## [-1, 512, h/8, w/8]
190 | s4 = self.layer3(s3) ## [-1, 1024, h/16, w/16]
191 |
192 | """ Dilated Conv + Pooling """
193 | r1 = self.r1(s1)
194 | r2 = self.r2(s2)
195 | r3 = self.r3(s3)
196 | r4 = self.r4(s4)
197 |
198 | rx = torch.cat([r1, r2, r3, r4], axis=1)
199 |
200 | """ Decoder """
201 | d1 = self.d1(rx, s3)
202 | d2 = self.d2(d1, s2)
203 | d3 = self.d3(d2, s1)
204 | d4 = self.d4(d3, s0)
205 |
206 | y = self.y(d4)
207 |
208 |         if heatmap is not None:
209 | hmap = save_feats_mean(d4)
210 | return hmap, y
211 | else:
212 | return y
213 |
214 | if __name__ == "__main__":
215 | x = torch.randn((8, 3, 256, 256))
216 | model = RUPNet()
217 | y = model(x)
218 | print(y.shape)
219 |
220 | from ptflops import get_model_complexity_info
221 | flops, params = get_model_complexity_info(model, input_res=(3, 256, 256), as_strings=True, print_per_layer_stat=False)
222 | print(' - Flops: ' + flops)
223 | print(' - Params: ' + params)
224 |
--------------------------------------------------------------------------------
/resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | # from torchvision.models.utils import load_state_dict_from_url
4 | from torch.hub import load_state_dict_from_url
5 |
6 |
7 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
8 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
9 | 'wide_resnet50_2', 'wide_resnet101_2']
10 |
11 |
12 | model_urls = {
13 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
14 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
15 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
16 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
17 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
18 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
19 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
20 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
21 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
22 | }
23 |
24 |
25 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
26 | """3x3 convolution with padding"""
27 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
28 | padding=dilation, groups=groups, bias=False, dilation=dilation)
29 |
30 |
31 | def conv1x1(in_planes, out_planes, stride=1):
32 | """1x1 convolution"""
33 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
34 |
35 |
36 | class BasicBlock(nn.Module):
37 | expansion = 1
38 |
39 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
40 | base_width=64, dilation=1, norm_layer=None):
41 | super(BasicBlock, self).__init__()
42 | if norm_layer is None:
43 | norm_layer = nn.BatchNorm2d
44 | if groups != 1 or base_width != 64:
45 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
46 | if dilation > 1:
47 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
48 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
49 | self.conv1 = conv3x3(inplanes, planes, stride)
50 | self.bn1 = norm_layer(planes)
51 | self.relu = nn.ReLU(inplace=True)
52 | self.conv2 = conv3x3(planes, planes)
53 | self.bn2 = norm_layer(planes)
54 | self.downsample = downsample
55 | self.stride = stride
56 |
57 | def forward(self, x):
58 | identity = x
59 |
60 | out = self.conv1(x)
61 | out = self.bn1(out)
62 | out = self.relu(out)
63 |
64 | out = self.conv2(out)
65 | out = self.bn2(out)
66 |
67 | if self.downsample is not None:
68 | identity = self.downsample(x)
69 |
70 | out += identity
71 | out = self.relu(out)
72 |
73 | return out
74 |
75 |
76 | class Bottleneck(nn.Module):
77 | expansion = 4
78 |
79 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
80 | base_width=64, dilation=1, norm_layer=None):
81 | super(Bottleneck, self).__init__()
82 | if norm_layer is None:
83 | norm_layer = nn.BatchNorm2d
84 | width = int(planes * (base_width / 64.)) * groups
85 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
86 | self.conv1 = conv1x1(inplanes, width)
87 | self.bn1 = norm_layer(width)
88 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
89 | self.bn2 = norm_layer(width)
90 | self.conv3 = conv1x1(width, planes * self.expansion)
91 | self.bn3 = norm_layer(planes * self.expansion)
92 | self.relu = nn.ReLU(inplace=True)
93 | self.downsample = downsample
94 | self.stride = stride
95 |
96 | def forward(self, x):
97 | identity = x
98 |
99 | out = self.conv1(x)
100 | out = self.bn1(out)
101 | out = self.relu(out)
102 |
103 | out = self.conv2(out)
104 | out = self.bn2(out)
105 | out = self.relu(out)
106 |
107 | out = self.conv3(out)
108 | out = self.bn3(out)
109 |
110 | if self.downsample is not None:
111 | identity = self.downsample(x)
112 |
113 | out += identity
114 | out = self.relu(out)
115 |
116 | return out
117 |
118 |
119 | class ResNet(nn.Module):
120 |
121 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
122 | groups=1, width_per_group=64, replace_stride_with_dilation=None,
123 | norm_layer=None):
124 | super(ResNet, self).__init__()
125 | if norm_layer is None:
126 | norm_layer = nn.BatchNorm2d
127 | self._norm_layer = norm_layer
128 |
129 | self.inplanes = 64
130 | self.dilation = 1
131 | if replace_stride_with_dilation is None:
132 | # each element in the tuple indicates if we should replace
133 | # the 2x2 stride with a dilated convolution instead
134 | replace_stride_with_dilation = [False, False, False]
135 | if len(replace_stride_with_dilation) != 3:
136 | raise ValueError("replace_stride_with_dilation should be None "
137 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
138 | self.groups = groups
139 | self.base_width = width_per_group
140 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
141 | bias=False)
142 | self.bn1 = norm_layer(self.inplanes)
143 | self.relu = nn.ReLU(inplace=True)
144 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
145 | self.layer1 = self._make_layer(block, 64, layers[0])
146 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
147 | dilate=replace_stride_with_dilation[0])
148 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
149 | dilate=replace_stride_with_dilation[1])
150 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
151 | dilate=replace_stride_with_dilation[2])
152 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
153 | self.fc = nn.Linear(512 * block.expansion, num_classes)
154 |
155 | for m in self.modules():
156 | if isinstance(m, nn.Conv2d):
157 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
158 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
159 | nn.init.constant_(m.weight, 1)
160 | nn.init.constant_(m.bias, 0)
161 |
162 | # Zero-initialize the last BN in each residual branch,
163 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
164 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
165 | if zero_init_residual:
166 | for m in self.modules():
167 | if isinstance(m, Bottleneck):
168 | nn.init.constant_(m.bn3.weight, 0)
169 | elif isinstance(m, BasicBlock):
170 | nn.init.constant_(m.bn2.weight, 0)
171 |
172 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
173 | norm_layer = self._norm_layer
174 | downsample = None
175 | previous_dilation = self.dilation
176 | if dilate:
177 | self.dilation *= stride
178 | stride = 1
179 | if stride != 1 or self.inplanes != planes * block.expansion:
180 | downsample = nn.Sequential(
181 | conv1x1(self.inplanes, planes * block.expansion, stride),
182 | norm_layer(planes * block.expansion),
183 | )
184 |
185 | layers = []
186 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
187 | self.base_width, previous_dilation, norm_layer))
188 | self.inplanes = planes * block.expansion
189 | for _ in range(1, blocks):
190 | layers.append(block(self.inplanes, planes, groups=self.groups,
191 | base_width=self.base_width, dilation=self.dilation,
192 | norm_layer=norm_layer))
193 |
194 | return nn.Sequential(*layers)
195 |
196 | def forward(self, x):
197 | x = self.conv1(x)
198 | x = self.bn1(x)
199 | x = self.relu(x)
200 | x = self.maxpool(x)
201 |
202 | x = self.layer1(x)
203 | x = self.layer2(x)
204 | x = self.layer3(x)
205 | x = self.layer4(x)
206 |
207 | x = self.avgpool(x)
208 | x = torch.flatten(x, 1)
209 | x = self.fc(x)
210 |
211 | return x
212 |
213 |
214 | def _resnet(arch, block, layers, pretrained, progress, **kwargs):
215 | model = ResNet(block, layers, **kwargs)
216 | if pretrained:
217 | state_dict = load_state_dict_from_url(model_urls[arch],
218 | progress=progress)
219 | model.load_state_dict(state_dict)
220 | return model
221 |
222 |
223 | def resnet18(pretrained=False, progress=True, **kwargs):
224 | r"""ResNet-18 model from
225 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
226 |
227 | Args:
228 | pretrained (bool): If True, returns a model pre-trained on ImageNet
229 | progress (bool): If True, displays a progress bar of the download to stderr
230 | """
231 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
232 | **kwargs)
233 |
234 |
235 | def resnet34(pretrained=False, progress=True, **kwargs):
236 | r"""ResNet-34 model from
237 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
238 |
239 | Args:
240 | pretrained (bool): If True, returns a model pre-trained on ImageNet
241 | progress (bool): If True, displays a progress bar of the download to stderr
242 | """
243 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
244 | **kwargs)
245 |
246 |
247 | def resnet50(pretrained=True, progress=True, **kwargs):
248 | r"""ResNet-50 model from
249 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
250 |
251 | Args:
252 | pretrained (bool): If True, returns a model pre-trained on ImageNet
253 | progress (bool): If True, displays a progress bar of the download to stderr
254 | """
255 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
256 | **kwargs)
257 |
258 |
259 | def resnet101(pretrained=True, progress=True, **kwargs):
260 | r"""ResNet-101 model from
261 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
262 |
263 | Args:
264 | pretrained (bool): If True, returns a model pre-trained on ImageNet
265 | progress (bool): If True, displays a progress bar of the download to stderr
266 | """
267 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
268 | **kwargs)
269 |
270 |
271 | def resnet152(pretrained=False, progress=True, **kwargs):
272 | r"""ResNet-152 model from
273 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
274 |
275 | Args:
276 | pretrained (bool): If True, returns a model pre-trained on ImageNet
277 | progress (bool): If True, displays a progress bar of the download to stderr
278 | """
279 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
280 | **kwargs)
281 |
282 |
283 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
284 | r"""ResNeXt-50 32x4d model from
285 |     `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/abs/1611.05431>`_
286 |
287 | Args:
288 | pretrained (bool): If True, returns a model pre-trained on ImageNet
289 | progress (bool): If True, displays a progress bar of the download to stderr
290 | """
291 | kwargs['groups'] = 32
292 | kwargs['width_per_group'] = 4
293 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
294 | pretrained, progress, **kwargs)
295 |
296 |
297 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
298 | r"""ResNeXt-101 32x8d model from
299 |     `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/abs/1611.05431>`_
300 |
301 | Args:
302 | pretrained (bool): If True, returns a model pre-trained on ImageNet
303 | progress (bool): If True, displays a progress bar of the download to stderr
304 | """
305 | kwargs['groups'] = 32
306 | kwargs['width_per_group'] = 8
307 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
308 | pretrained, progress, **kwargs)
309 |
310 |
311 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
312 | r"""Wide ResNet-50-2 model from
313 |     `"Wide Residual Networks" <https://arxiv.org/abs/1605.07146>`_
314 |
315 | The model is the same as ResNet except for the bottleneck number of channels
316 | which is twice larger in every block. The number of channels in outer 1x1
317 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
318 | channels, and in Wide ResNet-50-2 has 2048-1024-2048.
319 |
320 | Args:
321 | pretrained (bool): If True, returns a model pre-trained on ImageNet
322 | progress (bool): If True, displays a progress bar of the download to stderr
323 | """
324 | kwargs['width_per_group'] = 64 * 2
325 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
326 | pretrained, progress, **kwargs)
327 |
328 |
329 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
330 | r"""Wide ResNet-101-2 model from
331 |     `"Wide Residual Networks" <https://arxiv.org/abs/1605.07146>`_
332 |
333 | The model is the same as ResNet except for the bottleneck number of channels
334 | which is twice larger in every block. The number of channels in outer 1x1
335 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
336 | channels, and in Wide ResNet-50-2 has 2048-1024-2048.
337 |
338 | Args:
339 | pretrained (bool): If True, returns a model pre-trained on ImageNet
340 | progress (bool): If True, displays a progress bar of the download to stderr
341 | """
342 | kwargs['width_per_group'] = 64 * 2
343 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
344 | pretrained, progress, **kwargs)
345 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 |
2 | import os, time
3 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
4 | from operator import add
5 | import numpy as np
6 | from glob import glob
7 | import cv2
8 | from tqdm import tqdm
9 | import imageio
10 | import torch
11 | from model import RUPNet
12 | from utils import create_dir, seeding
13 | from utils import calculate_metrics
14 | from train import load_data
15 |
16 |
17 | def evaluate(model, save_path, test_x, test_y, size):
18 |     """ Loading other comparative model masks """
19 | comparison_path = "/media/nikhil/LAB/ML/ME/COMPARISON/Kvasir-SEG/"
20 |
21 |
22 | deeplabv3plus_mask = sorted(glob(os.path.join(comparison_path, "DeepLabV3+_50", "results", "Kvasir-SEG", "mask", "*")))
23 | pranet_mask = sorted(glob(os.path.join(comparison_path, "PraNet", "results", "Kvasir-SEG", "mask", "*")))
24 |
25 |     metrics_score = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
26 | time_taken = []
27 |
28 | for i, (x, y) in tqdm(enumerate(zip(test_x, test_y)), total=len(test_x)):
29 | name = y.split("/")[-1].split(".")[0]
30 |
31 | """ Image """
32 | image = cv2.imread(x, cv2.IMREAD_COLOR)
33 | image = cv2.resize(image, size)
34 | save_img = image
35 | image = np.transpose(image, (2, 0, 1))
36 | image = image/255.0
37 | image = np.expand_dims(image, axis=0)
38 | image = image.astype(np.float32)
39 | image = torch.from_numpy(image)
40 | image = image.to(device)
41 |
42 | """ Mask """
43 | mask = cv2.imread(y, cv2.IMREAD_GRAYSCALE)
44 | mask = cv2.resize(mask, size)
45 | save_mask = mask
46 | save_mask = np.expand_dims(save_mask, axis=-1)
47 | save_mask = np.concatenate([save_mask, save_mask, save_mask], axis=2)
48 | mask = np.expand_dims(mask, axis=0)
49 | mask = mask/255.0
50 | mask = np.expand_dims(mask, axis=0)
51 | mask = mask.astype(np.float32)
52 | mask = torch.from_numpy(mask)
53 | mask = mask.to(device)
54 |
55 | with torch.no_grad():
56 | """ FPS calculation """
57 | start_time = time.time()
58 | heatmap, y_pred = model(image, heatmap=True)
59 | y_pred = torch.sigmoid(y_pred)
60 | end_time = time.time() - start_time
61 | time_taken.append(end_time)
62 |
63 | """ Evaluation metrics """
64 | score = calculate_metrics(mask, y_pred)
65 | metrics_score = list(map(add, metrics_score, score))
66 |
67 | """ Predicted Mask """
68 | y_pred = y_pred[0].cpu().numpy()
69 | y_pred = np.squeeze(y_pred, axis=0)
70 | y_pred = y_pred > 0.5
71 | y_pred = y_pred.astype(np.int32)
72 | y_pred = y_pred * 255
73 | y_pred = np.array(y_pred, dtype=np.uint8)
74 | y_pred = np.expand_dims(y_pred, axis=-1)
75 | y_pred = np.concatenate([y_pred, y_pred, y_pred], axis=2)
76 |
77 | """ Save the image - mask - pred """
78 |         line = np.ones((size[0], 10, 3), dtype=np.uint8) * 255
79 | cat_images = np.concatenate([
80 | save_img, line,
81 | save_mask, line,
82 | cv2.imread(deeplabv3plus_mask[i], cv2.IMREAD_COLOR), line,
83 | cv2.imread(pranet_mask[i], cv2.IMREAD_COLOR), line,
84 | y_pred, line,
85 | heatmap], axis=1)
86 |
87 | cv2.imwrite(f"{save_path}/joint/{name}.jpg", cat_images)
88 | cv2.imwrite(f"{save_path}/mask/{name}.jpg", y_pred)
89 | cv2.imwrite(f"{save_path}/heatmap/{name}.jpg", heatmap)
90 |
91 | jaccard = metrics_score[0]/len(test_x)
92 | f1 = metrics_score[1]/len(test_x)
93 | recall = metrics_score[2]/len(test_x)
94 | precision = metrics_score[3]/len(test_x)
95 | acc = metrics_score[4]/len(test_x)
96 | f2 = metrics_score[5]/len(test_x)
97 |
98 | print(f"Jaccard: {jaccard:1.4f} - F1: {f1:1.4f} - Recall: {recall:1.4f} - Precision: {precision:1.4f} - Acc: {acc:1.4f} - F2: {f2:1.4f}")
99 |
100 | mean_time_taken = np.mean(time_taken)
101 | mean_fps = 1/mean_time_taken
102 | print("Mean FPS: ", mean_fps)
103 |
104 |
105 | if __name__ == "__main__":
106 | """ Seeding """
107 | seeding(42)
108 |
109 | """ Load the checkpoint """
110 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
111 | model = RUPNet()
112 | model = model.to(device)
113 | checkpoint_path = "files/checkpoint.pth"
114 | model.load_state_dict(torch.load(checkpoint_path, map_location=device))
115 | model.eval()
116 |
117 | """ Test dataset """
118 | path = "/media/nikhil/Seagate Backup Plus Drive/ML_DATASET/Kvasir-SEG"
119 | (train_x, train_y), (test_x, test_y) = load_data(path)
120 |
121 | test_x = sorted(test_x)
122 | test_y = sorted(test_y)
123 |
124 | save_path = f"results/Kvasir-SEG"
125 | for item in ["mask", "joint", "heatmap"]:
126 | create_dir(f"{save_path}/{item}")
127 |
128 | size = (256, 256)
129 | create_dir(save_path)
130 | evaluate(model, save_path, test_x, test_y, size)
131 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import random
4 | import time
5 | import datetime
6 | import numpy as np
7 | import albumentations as A
8 | import cv2
9 | from glob import glob
10 | import torch
11 | import torch.nn as nn
12 | from torch.utils.data import Dataset, DataLoader
13 | from torchvision import transforms
14 | from utils import seeding, create_dir, print_and_save, shuffling, epoch_time, calculate_metrics
15 | from model import RUPNet
16 | from metrics import DiceLoss, DiceBCELoss
17 |
18 | def load_names(path, file_path):
19 | f = open(file_path, "r")
20 | data = f.read().split("\n")[:-1]
21 | images = [os.path.join(path,"images", name) + ".jpg" for name in data]
22 | masks = [os.path.join(path,"masks", name) + ".jpg" for name in data]
23 | return images, masks
24 |
25 | def load_data(path):
26 | train_names_path = f"{path}/train.txt"
27 | valid_names_path = f"{path}/val.txt"
28 |
29 | train_x, train_y = load_names(path, train_names_path)
30 | valid_x, valid_y = load_names(path, valid_names_path)
31 |
32 | return (train_x, train_y), (valid_x, valid_y)
33 |
34 | class DATASET(Dataset):
35 | def __init__(self, images_path, masks_path, size, transform=None):
36 | super().__init__()
37 |
38 |         self.images_path = images_path
39 |         self.masks_path = masks_path
40 |         self.size = size
41 |         self.transform = transform
42 |
43 | def __getitem__(self, index):
44 | """ Image """
45 | image = cv2.imread(self.images_path[index], cv2.IMREAD_COLOR)
46 | mask = cv2.imread(self.masks_path[index], cv2.IMREAD_GRAYSCALE)
47 |
48 | if self.transform is not None:
49 | augmentations = self.transform(image=image, mask=mask)
50 | image = augmentations["image"]
51 | mask = augmentations["mask"]
52 |
53 |         image = cv2.resize(image, self.size)
54 | image = np.transpose(image, (2, 0, 1))
55 | image = image/255.0
56 |
57 |         mask = cv2.resize(mask, self.size)
58 | mask = np.expand_dims(mask, axis=0)
59 | mask = mask/255.0
60 |
61 | return image, mask
62 |
63 | def __len__(self):
64 |         return len(self.images_path)
65 |
66 | def train(model, loader, optimizer, loss_fn, device):
67 | model.train()
68 |
69 | epoch_loss = 0.0
70 | epoch_jac = 0.0
71 | epoch_f1 = 0.0
72 | epoch_recall = 0.0
73 | epoch_precision = 0.0
74 |
75 | for i, (x, y) in enumerate(loader):
76 | x = x.to(device, dtype=torch.float32)
77 | y = y.to(device, dtype=torch.float32)
78 |
79 | optimizer.zero_grad()
80 |         y_pred = model(x)
81 |         loss = loss_fn(y_pred, y)
82 | loss.backward()
83 | optimizer.step()
84 | epoch_loss += loss.item()
85 |
86 | """ Calculate the metrics """
87 | batch_jac = []
88 | batch_f1 = []
89 | batch_recall = []
90 | batch_precision = []
91 |
92 |         for yt, yp in zip(y, y_pred):
93 | score = calculate_metrics(yt, yp)
94 | batch_jac.append(score[0])
95 | batch_f1.append(score[1])
96 | batch_recall.append(score[2])
97 | batch_precision.append(score[3])
98 |
99 | epoch_jac += np.mean(batch_jac)
100 | epoch_f1 += np.mean(batch_f1)
101 | epoch_recall += np.mean(batch_recall)
102 | epoch_precision += np.mean(batch_precision)
103 |
104 | epoch_loss = epoch_loss/len(loader)
105 | epoch_jac = epoch_jac/len(loader)
106 | epoch_f1 = epoch_f1/len(loader)
107 | epoch_recall = epoch_recall/len(loader)
108 | epoch_precision = epoch_precision/len(loader)
109 |
110 | return epoch_loss, [epoch_jac, epoch_f1, epoch_recall, epoch_precision]
111 |
112 | def evaluate(model, loader, loss_fn, device):
113 | model.eval()
114 |
116 | epoch_loss = 0.0
117 | epoch_jac = 0.0
118 | epoch_f1 = 0.0
119 | epoch_recall = 0.0
120 | epoch_precision = 0.0
121 |
122 | with torch.no_grad():
123 | for i, (x, y) in enumerate(loader):
124 | x = x.to(device, dtype=torch.float32)
125 | y = y.to(device, dtype=torch.float32)
126 |
127 |             y_pred = model(x)
128 |             loss = loss_fn(y_pred, y)
129 | epoch_loss += loss.item()
130 |
131 | """ Calculate the metrics """
132 | batch_jac = []
133 | batch_f1 = []
134 | batch_recall = []
135 | batch_precision = []
136 |
137 |             for yt, yp in zip(y, y_pred):
138 | score = calculate_metrics(yt, yp)
139 | batch_jac.append(score[0])
140 | batch_f1.append(score[1])
141 | batch_recall.append(score[2])
142 | batch_precision.append(score[3])
143 |
144 | epoch_jac += np.mean(batch_jac)
145 | epoch_f1 += np.mean(batch_f1)
146 | epoch_recall += np.mean(batch_recall)
147 | epoch_precision += np.mean(batch_precision)
148 |
149 | epoch_loss = epoch_loss/len(loader)
150 | epoch_jac = epoch_jac/len(loader)
151 | epoch_f1 = epoch_f1/len(loader)
152 | epoch_recall = epoch_recall/len(loader)
153 | epoch_precision = epoch_precision/len(loader)
154 |
155 | return epoch_loss, [epoch_jac, epoch_f1, epoch_recall, epoch_precision]
156 |
157 | if __name__ == "__main__":
158 | """ Seeding """
159 | seeding(42)
160 |
161 | """ Directories """
162 | create_dir("files")
163 |
164 | """ Training logfile """
165 | train_log_path = "files/train_log.txt"
166 | if os.path.exists(train_log_path):
167 | print("Log file exists")
168 | else:
169 |         train_log = open(train_log_path, "w")
170 | train_log.write("\n")
171 | train_log.close()
172 |
173 | """ Record Date & Time """
174 | datetime_object = str(datetime.datetime.now())
175 | print_and_save(train_log_path, datetime_object)
176 | print("")
177 |
178 | """ Hyperparameters """
179 | image_size = 256
180 | size = (image_size, image_size)
181 | batch_size = 16
182 | num_epochs = 500
183 | lr = 1e-4
184 | early_stopping_patience = 50
185 | checkpoint_path = "files/checkpoint.pth"
186 | path = "/media/nikhil/Seagate Backup Plus Drive/ML_DATASET/Kvasir-SEG"
187 |
188 | data_str = f"Image Size: {size}\nBatch Size: {batch_size}\nLR: {lr}\nEpochs: {num_epochs}\n"
189 | data_str += f"Early Stopping Patience: {early_stopping_patience}\n"
190 | print_and_save(train_log_path, data_str)
191 |
192 | """ Dataset """
193 | (train_x, train_y), (valid_x, valid_y) = load_data(path)
194 | train_x, train_y = shuffling(train_x, train_y)
195 | data_str = f"Dataset Size:\nTrain: {len(train_x)} - Valid: {len(valid_x)}\n"
196 | print_and_save(train_log_path, data_str)
197 |
198 | """ Data augmentation: Transforms """
199 | transform = A.Compose([
200 | A.Rotate(limit=35, p=0.3),
201 | A.HorizontalFlip(p=0.3),
202 | A.VerticalFlip(p=0.3),
203 | A.CoarseDropout(p=0.3, max_holes=10, max_height=32, max_width=32)
204 | ])
205 |
206 | """ Dataset and loader """
207 | train_dataset = DATASET(train_x, train_y, size, transform=transform)
208 | valid_dataset = DATASET(valid_x, valid_y, size, transform=None)
209 |
210 | # create_dir("data")
211 | # for i, (x, y) in enumerate(train_dataset):
212 | # x = np.transpose(x, (1, 2, 0)) * 255
213 | # y = np.transpose(y, (1, 2, 0)) * 255
214 | # y = np.concatenate([y, y, y], axis=-1)
215 | # cv2.imwrite(f"data/{i}.png", np.concatenate([x, y], axis=1))
216 |
217 | train_loader = DataLoader(
218 | dataset=train_dataset,
219 | batch_size=batch_size,
220 | shuffle=True,
221 | num_workers=2
222 | )
223 |
224 | valid_loader = DataLoader(
225 | dataset=valid_dataset,
226 | batch_size=batch_size,
227 | shuffle=False,
228 | num_workers=2
229 | )
230 |
231 | """ Model """
232 | device = torch.device('cuda')
233 | model = RUPNet()
234 | model = model.to(device)
235 |
236 | optimizer = torch.optim.Adam(model.parameters(), lr=lr)
237 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, verbose=True)
238 | loss_fn = DiceBCELoss()
239 | loss_name = "BCE Dice Loss"
240 | data_str = f"Optimizer: Adam\nLoss: {loss_name}\n"
241 | print_and_save(train_log_path, data_str)
242 |
243 | """ Training the model """
244 | best_valid_metrics = 0.0
245 | early_stopping_count = 0
246 |
247 | for epoch in range(num_epochs):
248 | start_time = time.time()
249 |
250 | train_loss, train_metrics = train(model, train_loader, optimizer, loss_fn, device)
251 | valid_loss, valid_metrics = evaluate(model, valid_loader, loss_fn, device)
252 | scheduler.step(valid_loss)
253 |
254 | if valid_metrics[1] > best_valid_metrics:
255 | data_str = f"Valid F1 improved from {best_valid_metrics:2.4f} to {valid_metrics[1]:2.4f}. Saving checkpoint: {checkpoint_path}"
256 | print_and_save(train_log_path, data_str)
257 |
258 | best_valid_metrics = valid_metrics[1]
259 | torch.save(model.state_dict(), checkpoint_path)
260 | early_stopping_count = 0
261 |
262 |         else:
263 | early_stopping_count += 1
264 |
265 | end_time = time.time()
266 | epoch_mins, epoch_secs = epoch_time(start_time, end_time)
267 |
268 | data_str = f"Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s\n"
269 | data_str += f"\tTrain Loss: {train_loss:.4f} - Jaccard: {train_metrics[0]:.4f} - F1: {train_metrics[1]:.4f} - Recall: {train_metrics[2]:.4f} - Precision: {train_metrics[3]:.4f}\n"
270 | data_str += f"\t Val. Loss: {valid_loss:.4f} - Jaccard: {valid_metrics[0]:.4f} - F1: {valid_metrics[1]:.4f} - Recall: {valid_metrics[2]:.4f} - Precision: {valid_metrics[3]:.4f}\n"
271 | print_and_save(train_log_path, data_str)
272 |
273 | if early_stopping_count == early_stopping_patience:
274 |             data_str = f"Early stopping: validation F1 did not improve for the last {early_stopping_patience} epochs.\n"
275 | print_and_save(train_log_path, data_str)
276 | break
277 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import random
4 | import numpy as np
5 | import cv2
6 | from tqdm import tqdm
7 | import torch
8 | from sklearn.utils import shuffle
9 | from metrics import precision, recall, F2, dice_score, jac_score
10 | from sklearn.metrics import accuracy_score, confusion_matrix
11 |
12 | """ Seeding the randomness. """
13 | def seeding(seed):
14 | random.seed(seed)
15 | os.environ["PYTHONHASHSEED"] = str(seed)
16 | np.random.seed(seed)
17 | torch.manual_seed(seed)
18 | torch.cuda.manual_seed(seed)
19 | torch.backends.cudnn.deterministic = True
20 |
21 | """ Create a directory """
22 | def create_dir(path):
23 | if not os.path.exists(path):
24 | os.makedirs(path)
25 |
26 | """ Shuffle the dataset. """
27 | def shuffling(x, y):
28 | x, y = shuffle(x, y, random_state=42)
29 | return x, y
30 |
31 | def epoch_time(start_time, end_time):
32 | elapsed_time = end_time - start_time
33 | elapsed_mins = int(elapsed_time / 60)
34 | elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
35 | return elapsed_mins, elapsed_secs
36 |
37 | def print_and_save(file_path, data_str):
38 | print(data_str)
39 | with open(file_path, "a") as file:
40 | file.write(data_str)
41 | file.write("\n")
42 |
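# Thresholds predictions and ground truth at 0.5, flattens them, and returns
# [jaccard, f1/dice, recall, precision, accuracy, F2].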
43 | def calculate_metrics(y_true, y_pred):
44 | y_true = y_true.detach().cpu().numpy()
45 | y_pred = y_pred.detach().cpu().numpy()
46 |
47 | y_pred = y_pred > 0.5
48 | y_pred = y_pred.reshape(-1)
49 | y_pred = y_pred.astype(np.uint8)
50 |
51 | y_true = y_true > 0.5
52 | y_true = y_true.reshape(-1)
53 | y_true = y_true.astype(np.uint8)
54 |
55 | ## Score
56 | score_jaccard = jac_score(y_true, y_pred)
57 | score_f1 = dice_score(y_true, y_pred)
58 | score_recall = recall(y_true, y_pred)
59 | score_precision = precision(y_true, y_pred)
60 | score_fbeta = F2(y_true, y_pred)
61 | score_acc = accuracy_score(y_true, y_pred)
62 |
63 | return [score_jaccard, score_f1, score_recall, score_precision, score_acc, score_fbeta]
64 |
--------------------------------------------------------------------------------