├── README.md ├── images ├── README.md ├── architecture.jpg ├── qualitative.jpg ├── result-1.png └── result-2.png ├── metrics.py ├── model.py ├── resnet.py ├── test.py ├── train.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # DilatedSegNet: A Deep Dilated Segmentation Network for Polyp Segmentation 2 | 3 | ## 1. Abstract 4 |
5 | Colorectal cancer (CRC) is the second leading cause of cancer-related death worldwide. Excision of polyps during colonoscopy helps reduce mortality and morbidity for CRC. Powered by deep learning, computer-aided diagnosis (CAD) systems can detect regions in the colon overlooked by physicians during colonoscopy. Insufficient accuracy and the lack of real-time speed are the main obstacles to successful clinical integration of such systems. While the literature focuses on improving accuracy, speed is often ignored. To address this need, we develop a novel real-time deep learning-based architecture, DilatedSegNet, to perform polyp segmentation on the fly. DilatedSegNet is an encoder-decoder network that uses a pre-trained ResNet50 as the encoder, from which we extract four levels of feature maps. Each of these feature maps is passed through a dilated convolution pooling (DCP) block. The outputs from the DCP blocks are concatenated and passed through a series of four decoder blocks that predict the segmentation mask. The proposed method achieves a real-time operation speed of 33.68 frames per second with an average dice coefficient of 0.90 and mIoU of 0.83. Additionally, we provide heatmaps along with the qualitative results that show the explanation for the polyp location, which increases the trustworthiness of the method. The results on the publicly available Kvasir-SEG and BKAI-IGH datasets suggest that DilatedSegNet can give real-time feedback while retaining a high dice coefficient, indicating high potential for using such models in real clinical settings in the near future. 6 |
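For readers unfamiliar with the two overlap metrics quoted in the abstract, the following is a minimal sketch of how the dice coefficient and IoU are computed for a single pair of binary masks. It is not the repository's `metrics.py`, which operates on tensors. The function name `dice_and_iou`, the toy masks, and the use of plain Python lists are assumptions made for illustration only.

```python
def dice_and_iou(y_true, y_pred, eps=1e-15):
    """Return (dice, iou) for two flat binary masks given as lists of 0/1."""
    # Number of pixels that are foreground in both masks.
    intersection = sum(t * p for t, p in zip(y_true, y_pred))
    total = sum(y_true) + sum(y_pred)
    union = total - intersection
    # eps guards against division by zero when both masks are empty.
    dice = (2 * intersection + eps) / (total + eps)
    iou = (intersection + eps) / (union + eps)
    return dice, iou


# Hypothetical 2x4 masks, flattened row by row.
truth = [0, 1, 1, 1, 0, 0, 1, 0]
pred = [0, 1, 1, 0, 0, 1, 1, 0]

dice, iou = dice_and_iou(truth, pred)
print(f"dice={dice:.2f} iou={iou:.2f}")  # dice=0.75 iou=0.60
```

For a single mask pair the two scores are tied by dice = 2 * IoU / (1 + IoU). Averages over a test set, such as the 0.90 and 0.83 reported above, need not satisfy this identity exactly.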
7 | 8 | ## 2. Architecture 9 | 10 | 11 | ## 3. Implementation 12 | The proposed architecture is implemented using the PyTorch framework (1.9.0+cu111) with a single GeForce RTX 3090 GPU of 24 GB memory. 13 | 14 | ### 3.1 Dataset 15 | We have used the following datasets: 16 | - [Kvasir-SEG](https://datasets.simula.no/downloads/kvasir-seg.zip) 17 | - [BKAI](https://www.kaggle.com/competitions/bkai-igh-neopolyp/data) 18 | 19 | The BKAI dataset follows an 80:10:10 split for training, validation and testing, while Kvasir-SEG follows the official split of 880/120. 20 | 21 | ### 3.2 Weight file 22 | - [Kvasir-SEG](https://drive.google.com/file/d/1diYckKDMqDWSDD6O5Jm6InCxWEkU0GJC/view?usp=sharing) 23 | - [BKAI-IGH](https://drive.google.com/file/d/1ojGaQThD56mRhGQaVoJVpAw0oVwSzX8N/view?usp=sharing) 24 | 25 | ## 4. Results 26 | 27 | ### 4.1 Quantitative Results: Same Dataset 28 | 29 | 30 | ### 4.2 Quantitative Results: Different Dataset 31 | 32 | 33 | ### 4.3 Qualitative Results 34 | 35 | 36 | ## 5. Citation 37 | To be updated soon. 38 | 39 | ## 6. License 40 | The source code is free for research and education use only. Any commercial use requires formal permission from the first author. 41 | 42 | ## 7. Contact 43 | Please contact nikhilroxtomar@gmail.com for any further questions. 
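The 80:10:10 split mentioned in Section 3.1 could be produced along the following lines. This is a sketch under stated assumptions: the source does not say how the BKAI split was actually generated, and the function name `split_80_10_10`, the seed, and the example file names are hypothetical.

```python
import random


def split_80_10_10(names, seed=42):
    """Shuffle names reproducibly and split them 80:10:10 (train/valid/test)."""
    # Sort first so the result does not depend on directory listing order.
    names = sorted(names)
    random.Random(seed).shuffle(names)

    # Integer arithmetic avoids floating-point rounding surprises.
    n_train = len(names) * 8 // 10
    n_valid = len(names) // 10

    train = names[:n_train]
    valid = names[n_train:n_train + n_valid]
    test = names[n_train + n_valid:]  # whatever remains goes to the test set
    return train, valid, test


# Hypothetical image identifiers, purely for illustration.
names = [f"img_{i:04d}" for i in range(1000)]
train, valid, test = split_80_10_10(names)
print(len(train), len(valid), len(test))  # 800 100 100
```

Fixing the seed keeps the three subsets stable across runs, so the names can be written once to files such as `train.txt` and `val.txt`, which is the format `load_data` in `train.py` reads.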
44 | -------------------------------------------------------------------------------- /images/README.md: -------------------------------------------------------------------------------- 1 | # Images 2 | -------------------------------------------------------------------------------- /images/architecture.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nikhilroxtomar/DilatedSegNet/5b193e1dd20a925ec4aed28d08dd0ac6f495d8a2/images/architecture.jpg -------------------------------------------------------------------------------- /images/qualitative.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nikhilroxtomar/DilatedSegNet/5b193e1dd20a925ec4aed28d08dd0ac6f495d8a2/images/qualitative.jpg -------------------------------------------------------------------------------- /images/result-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nikhilroxtomar/DilatedSegNet/5b193e1dd20a925ec4aed28d08dd0ac6f495d8a2/images/result-1.png -------------------------------------------------------------------------------- /images/result-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nikhilroxtomar/DilatedSegNet/5b193e1dd20a925ec4aed28d08dd0ac6f495d8a2/images/result-2.png -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | """ Loss Functions -------------------------------------- """ 6 | class DiceLoss(nn.Module): 7 | def __init__(self, weight=None, size_average=True): 8 | super(DiceLoss, self).__init__() 9 | 10 | def forward(self, inputs, targets, smooth=1): 11 | inputs = torch.sigmoid(inputs) 12 | 13 | 
inputs = inputs.view(-1) 14 | targets = targets.view(-1) 15 | 16 | intersection = (inputs * targets).sum() 17 | dice = (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth) 18 | 19 | return 1 - dice 20 | 21 | class DiceBCELoss(nn.Module): 22 | def __init__(self, weight=None, size_average=True): 23 | super(DiceBCELoss, self).__init__() 24 | 25 | def forward(self, inputs, targets, smooth=1): 26 | inputs = torch.sigmoid(inputs) 27 | 28 | inputs = inputs.view(-1) 29 | targets = targets.view(-1) 30 | 31 | intersection = (inputs * targets).sum() 32 | dice_loss = 1 - (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth) 33 | BCE = F.binary_cross_entropy(inputs, targets, reduction='mean') 34 | Dice_BCE = BCE + dice_loss 35 | 36 | return Dice_BCE 37 | 38 | """ Metrics ------------------------------------------ """ 39 | def precision(y_true, y_pred): 40 | intersection = (y_true * y_pred).sum() 41 | return (intersection + 1e-15) / (y_pred.sum() + 1e-15) 42 | 43 | def recall(y_true, y_pred): 44 | intersection = (y_true * y_pred).sum() 45 | return (intersection + 1e-15) / (y_true.sum() + 1e-15) 46 | 47 | def F2(y_true, y_pred, beta=2): 48 | p = precision(y_true,y_pred) 49 | r = recall(y_true, y_pred) 50 | return (1+beta**2.) 
*(p*r) / float(beta**2*p + r + 1e-15) 51 | 52 | def dice_score(y_true, y_pred): 53 | return (2 * (y_true * y_pred).sum() + 1e-15) / (y_true.sum() + y_pred.sum() + 1e-15) 54 | 55 | def jac_score(y_true, y_pred): 56 | intersection = (y_true * y_pred).sum() 57 | union = y_true.sum() + y_pred.sum() - intersection 58 | return (intersection + 1e-15) / (union + 1e-15) 59 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | from resnet import resnet50 5 | import numpy as np 6 | import cv2 7 | 8 | def save_feats_mean(x, size=(256, 256)): 9 | b, c, h, w = x.shape 10 | with torch.no_grad(): 11 | x = x.detach().cpu().numpy() 12 | x = np.transpose(x[0], (1, 2, 0)) 13 | x = np.mean(x, axis=-1) 14 | x = x/np.max(x) 15 | x = x * 255.0 16 | x = x.astype(np.uint8) 17 | 18 | if h != size[1]: 19 | x = cv2.resize(x, size) 20 | 21 | x = cv2.applyColorMap(x, cv2.COLORMAP_JET) 22 | x = np.array(x, dtype=np.uint8) 23 | return x 24 | 25 | def get_mean_attention_map(x): 26 | x = torch.mean(x, axis=1) 27 | x = torch.unsqueeze(x, 1) 28 | x = x / torch.max(x) 29 | return x 30 | 31 | class ResidualBlock(nn.Module): 32 | def __init__(self, in_c, out_c): 33 | super().__init__() 34 | 35 | self.relu = nn.ReLU() 36 | self.conv = nn.Sequential( 37 | nn.Conv2d(in_c, out_c, kernel_size=3, padding=1), 38 | nn.BatchNorm2d(out_c), 39 | nn.ReLU(), 40 | nn.Conv2d(out_c, out_c, kernel_size=3, padding=1), 41 | nn.BatchNorm2d(out_c) 42 | ) 43 | self.shortcut = nn.Sequential( 44 | nn.Conv2d(in_c, out_c, kernel_size=1, padding=0), 45 | nn.BatchNorm2d(out_c) 46 | ) 47 | 48 | def forward(self, inputs): 49 | x1 = self.conv(inputs) 50 | x2 = self.shortcut(inputs) 51 | x = self.relu(x1 + x2) 52 | return x 53 | 54 | class DilatedConv(nn.Module): 55 | def __init__(self, in_c, out_c): 56 | super().__init__() 57 | 58 | self.c1 = nn.Sequential( 59 | 
nn.Conv2d(in_c, out_c, kernel_size=3, padding=1, dilation=1), 60 | nn.BatchNorm2d(out_c), 61 | nn.ReLU() 62 | ) 63 | 64 | self.c2 = nn.Sequential( 65 | nn.Conv2d(in_c, out_c, kernel_size=3, padding=3, dilation=3), 66 | nn.BatchNorm2d(out_c), 67 | nn.ReLU() 68 | ) 69 | 70 | self.c3 = nn.Sequential( 71 | nn.Conv2d(in_c, out_c, kernel_size=3, padding=6, dilation=6), 72 | nn.BatchNorm2d(out_c), 73 | nn.ReLU() 74 | ) 75 | 76 | self.c4 = nn.Sequential( 77 | nn.Conv2d(in_c, out_c, kernel_size=3, padding=9, dilation=9), 78 | nn.BatchNorm2d(out_c), 79 | nn.ReLU() 80 | ) 81 | 82 | self.c5 = nn.Sequential( 83 | nn.Conv2d(out_c*4, out_c, kernel_size=1, padding=0), 84 | nn.BatchNorm2d(out_c), 85 | nn.ReLU() 86 | ) 87 | 88 | def forward(self, inputs): 89 | x1 = self.c1(inputs) 90 | x2 = self.c2(inputs) 91 | x3 = self.c3(inputs) 92 | x4 = self.c4(inputs) 93 | x = torch.cat([x1, x2, x3, x4], axis=1) 94 | x = self.c5(x) 95 | return x 96 | 97 | class ChannelAttention(nn.Module): 98 | def __init__(self, in_planes, ratio=16): 99 | super(ChannelAttention, self).__init__() 100 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 101 | self.max_pool = nn.AdaptiveMaxPool2d(1) 102 | 103 | self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False) ## use the ratio argument instead of a hard-coded 16 104 | self.relu1 = nn.ReLU() 105 | self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False) 106 | 107 | self.sigmoid = nn.Sigmoid() 108 | 109 | def forward(self, x): 110 | x0 = x 111 | avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x)))) 112 | max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x)))) 113 | out = avg_out + max_out 114 | return x0 * self.sigmoid(out) 115 | 116 | 117 | class SpatialAttention(nn.Module): 118 | def __init__(self, kernel_size=7): 119 | super(SpatialAttention, self).__init__() 120 | 121 | assert kernel_size in (3, 7), 'kernel size must be 3 or 7' 122 | padding = 3 if kernel_size == 7 else 1 123 | 124 | self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False) 125 | self.sigmoid = nn.Sigmoid() 126 | 
127 | def forward(self, x): 128 | x0 = x 129 | avg_out = torch.mean(x, dim=1, keepdim=True) 130 | max_out, _ = torch.max(x, dim=1, keepdim=True) 131 | x = torch.cat([avg_out, max_out], dim=1) 132 | x = self.conv1(x) 133 | return x0 * self.sigmoid(x) 134 | 135 | class DecoderBlock(nn.Module): 136 | def __init__(self, in_c, out_c): 137 | super().__init__() 138 | 139 | self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True) 140 | self.r1 = ResidualBlock(in_c[0]+in_c[1], out_c) 141 | self.r2 = ResidualBlock(out_c, out_c) 142 | 143 | self.ca = ChannelAttention(out_c) 144 | self.sa = SpatialAttention() 145 | 146 | def forward(self, x, s): 147 | x = self.up(x) 148 | x = torch.cat([x, s], axis=1) 149 | x = self.r1(x) 150 | x = self.r2(x) 151 | 152 | x = self.ca(x) 153 | x = self.sa(x) 154 | return x 155 | 156 | class RUPNet(nn.Module): 157 | def __init__(self): 158 | super().__init__() 159 | 160 | """ ResNet50 """ 161 | backbone = resnet50() 162 | self.layer0 = nn.Sequential(backbone.conv1, backbone.bn1, backbone.relu) 163 | self.layer1 = nn.Sequential(backbone.maxpool, backbone.layer1) 164 | self.layer2 = backbone.layer2 165 | self.layer3 = backbone.layer3 166 | 167 | """ Dilated Conv + Pooling """ 168 | self.r1 = nn.Sequential(DilatedConv(64, 64), nn.MaxPool2d((8, 8))) 169 | self.r2 = nn.Sequential(DilatedConv(256, 64), nn.MaxPool2d((4, 4))) 170 | self.r3 = nn.Sequential(DilatedConv(512, 64), nn.MaxPool2d((2, 2))) 171 | self.r4 = DilatedConv(1024, 64) 172 | 173 | """ Decoder """ 174 | self.d1 = DecoderBlock([256, 512], 256) 175 | self.d2 = DecoderBlock([256, 256], 128) 176 | self.d3 = DecoderBlock([128, 64], 64) 177 | self.d4 = DecoderBlock([64, 3], 32) 178 | 179 | """ """ 180 | 181 | """ Output """ 182 | self.y = nn.Conv2d(32, 1, kernel_size=1, padding=0) 183 | 184 | def forward(self, x, heatmap=None): 185 | """ ResNet50 """ 186 | s0 = x 187 | s1 = self.layer0(s0) ## [-1, 64, h/2, w/2] 188 | s2 = self.layer1(s1) ## [-1, 256, h/4, w/4] 189 | s3 = 
self.layer2(s2) ## [-1, 512, h/8, w/8] 190 | s4 = self.layer3(s3) ## [-1, 1024, h/16, w/16] 191 | 192 | """ Dilated Conv + Pooling """ 193 | r1 = self.r1(s1) 194 | r2 = self.r2(s2) 195 | r3 = self.r3(s3) 196 | r4 = self.r4(s4) 197 | 198 | rx = torch.cat([r1, r2, r3, r4], axis=1) 199 | 200 | """ Decoder """ 201 | d1 = self.d1(rx, s3) 202 | d2 = self.d2(d1, s2) 203 | d3 = self.d3(d2, s1) 204 | d4 = self.d4(d3, s0) 205 | 206 | y = self.y(d4) 207 | 208 | if heatmap is not None: 209 | hmap = save_feats_mean(d4) 210 | return hmap, y 211 | else: 212 | return y 213 | 214 | if __name__ == "__main__": 215 | x = torch.randn((8, 3, 256, 256)) 216 | model = RUPNet() 217 | y = model(x) 218 | print(y.shape) 219 | 220 | from ptflops import get_model_complexity_info 221 | flops, params = get_model_complexity_info(model, input_res=(3, 256, 256), as_strings=True, print_per_layer_stat=False) 222 | print(' - Flops: ' + flops) 223 | print(' - Params: ' + params) 224 | -------------------------------------------------------------------------------- /resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | # from torchvision.models.utils import load_state_dict_from_url 4 | from torch.hub import load_state_dict_from_url 5 | 6 | 7 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 8 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 9 | 'wide_resnet50_2', 'wide_resnet101_2'] 10 | 11 | 12 | model_urls = { 13 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 14 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 15 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 16 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 17 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 18 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 19 | 
'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 20 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 21 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 22 | } 23 | 24 | 25 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 26 | """3x3 convolution with padding""" 27 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 28 | padding=dilation, groups=groups, bias=False, dilation=dilation) 29 | 30 | 31 | def conv1x1(in_planes, out_planes, stride=1): 32 | """1x1 convolution""" 33 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 34 | 35 | 36 | class BasicBlock(nn.Module): 37 | expansion = 1 38 | 39 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 40 | base_width=64, dilation=1, norm_layer=None): 41 | super(BasicBlock, self).__init__() 42 | if norm_layer is None: 43 | norm_layer = nn.BatchNorm2d 44 | if groups != 1 or base_width != 64: 45 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 46 | if dilation > 1: 47 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 48 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 49 | self.conv1 = conv3x3(inplanes, planes, stride) 50 | self.bn1 = norm_layer(planes) 51 | self.relu = nn.ReLU(inplace=True) 52 | self.conv2 = conv3x3(planes, planes) 53 | self.bn2 = norm_layer(planes) 54 | self.downsample = downsample 55 | self.stride = stride 56 | 57 | def forward(self, x): 58 | identity = x 59 | 60 | out = self.conv1(x) 61 | out = self.bn1(out) 62 | out = self.relu(out) 63 | 64 | out = self.conv2(out) 65 | out = self.bn2(out) 66 | 67 | if self.downsample is not None: 68 | identity = self.downsample(x) 69 | 70 | out += identity 71 | out = self.relu(out) 72 | 73 | return out 74 | 75 | 76 | class Bottleneck(nn.Module): 77 | expansion = 4 78 | 79 | 
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 80 | base_width=64, dilation=1, norm_layer=None): 81 | super(Bottleneck, self).__init__() 82 | if norm_layer is None: 83 | norm_layer = nn.BatchNorm2d 84 | width = int(planes * (base_width / 64.)) * groups 85 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 86 | self.conv1 = conv1x1(inplanes, width) 87 | self.bn1 = norm_layer(width) 88 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 89 | self.bn2 = norm_layer(width) 90 | self.conv3 = conv1x1(width, planes * self.expansion) 91 | self.bn3 = norm_layer(planes * self.expansion) 92 | self.relu = nn.ReLU(inplace=True) 93 | self.downsample = downsample 94 | self.stride = stride 95 | 96 | def forward(self, x): 97 | identity = x 98 | 99 | out = self.conv1(x) 100 | out = self.bn1(out) 101 | out = self.relu(out) 102 | 103 | out = self.conv2(out) 104 | out = self.bn2(out) 105 | out = self.relu(out) 106 | 107 | out = self.conv3(out) 108 | out = self.bn3(out) 109 | 110 | if self.downsample is not None: 111 | identity = self.downsample(x) 112 | 113 | out += identity 114 | out = self.relu(out) 115 | 116 | return out 117 | 118 | 119 | class ResNet(nn.Module): 120 | 121 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 122 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 123 | norm_layer=None): 124 | super(ResNet, self).__init__() 125 | if norm_layer is None: 126 | norm_layer = nn.BatchNorm2d 127 | self._norm_layer = norm_layer 128 | 129 | self.inplanes = 64 130 | self.dilation = 1 131 | if replace_stride_with_dilation is None: 132 | # each element in the tuple indicates if we should replace 133 | # the 2x2 stride with a dilated convolution instead 134 | replace_stride_with_dilation = [False, False, False] 135 | if len(replace_stride_with_dilation) != 3: 136 | raise ValueError("replace_stride_with_dilation should be None " 137 | "or a 3-element tuple, got 
{}".format(replace_stride_with_dilation)) 138 | self.groups = groups 139 | self.base_width = width_per_group 140 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 141 | bias=False) 142 | self.bn1 = norm_layer(self.inplanes) 143 | self.relu = nn.ReLU(inplace=True) 144 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 145 | self.layer1 = self._make_layer(block, 64, layers[0]) 146 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 147 | dilate=replace_stride_with_dilation[0]) 148 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 149 | dilate=replace_stride_with_dilation[1]) 150 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 151 | dilate=replace_stride_with_dilation[2]) 152 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 153 | self.fc = nn.Linear(512 * block.expansion, num_classes) 154 | 155 | for m in self.modules(): 156 | if isinstance(m, nn.Conv2d): 157 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 158 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 159 | nn.init.constant_(m.weight, 1) 160 | nn.init.constant_(m.bias, 0) 161 | 162 | # Zero-initialize the last BN in each residual branch, 163 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
164 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 165 | if zero_init_residual: 166 | for m in self.modules(): 167 | if isinstance(m, Bottleneck): 168 | nn.init.constant_(m.bn3.weight, 0) 169 | elif isinstance(m, BasicBlock): 170 | nn.init.constant_(m.bn2.weight, 0) 171 | 172 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 173 | norm_layer = self._norm_layer 174 | downsample = None 175 | previous_dilation = self.dilation 176 | if dilate: 177 | self.dilation *= stride 178 | stride = 1 179 | if stride != 1 or self.inplanes != planes * block.expansion: 180 | downsample = nn.Sequential( 181 | conv1x1(self.inplanes, planes * block.expansion, stride), 182 | norm_layer(planes * block.expansion), 183 | ) 184 | 185 | layers = [] 186 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 187 | self.base_width, previous_dilation, norm_layer)) 188 | self.inplanes = planes * block.expansion 189 | for _ in range(1, blocks): 190 | layers.append(block(self.inplanes, planes, groups=self.groups, 191 | base_width=self.base_width, dilation=self.dilation, 192 | norm_layer=norm_layer)) 193 | 194 | return nn.Sequential(*layers) 195 | 196 | def forward(self, x): 197 | x = self.conv1(x) 198 | x = self.bn1(x) 199 | x = self.relu(x) 200 | x = self.maxpool(x) 201 | 202 | x = self.layer1(x) 203 | x = self.layer2(x) 204 | x = self.layer3(x) 205 | x = self.layer4(x) 206 | 207 | x = self.avgpool(x) 208 | x = torch.flatten(x, 1) 209 | x = self.fc(x) 210 | 211 | return x 212 | 213 | 214 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 215 | model = ResNet(block, layers, **kwargs) 216 | if pretrained: 217 | state_dict = load_state_dict_from_url(model_urls[arch], 218 | progress=progress) 219 | model.load_state_dict(state_dict) 220 | return model 221 | 222 | 223 | def resnet18(pretrained=False, progress=True, **kwargs): 224 | r"""ResNet-18 model from 225 | `"Deep Residual Learning for Image 
Recognition" `_ 226 | 227 | Args: 228 | pretrained (bool): If True, returns a model pre-trained on ImageNet 229 | progress (bool): If True, displays a progress bar of the download to stderr 230 | """ 231 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 232 | **kwargs) 233 | 234 | 235 | def resnet34(pretrained=False, progress=True, **kwargs): 236 | r"""ResNet-34 model from 237 | `"Deep Residual Learning for Image Recognition" `_ 238 | 239 | Args: 240 | pretrained (bool): If True, returns a model pre-trained on ImageNet 241 | progress (bool): If True, displays a progress bar of the download to stderr 242 | """ 243 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 244 | **kwargs) 245 | 246 | 247 | def resnet50(pretrained=True, progress=True, **kwargs): 248 | r"""ResNet-50 model from 249 | `"Deep Residual Learning for Image Recognition" `_ 250 | 251 | Args: 252 | pretrained (bool): If True, returns a model pre-trained on ImageNet 253 | progress (bool): If True, displays a progress bar of the download to stderr 254 | """ 255 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 256 | **kwargs) 257 | 258 | 259 | def resnet101(pretrained=True, progress=True, **kwargs): 260 | r"""ResNet-101 model from 261 | `"Deep Residual Learning for Image Recognition" `_ 262 | 263 | Args: 264 | pretrained (bool): If True, returns a model pre-trained on ImageNet 265 | progress (bool): If True, displays a progress bar of the download to stderr 266 | """ 267 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 268 | **kwargs) 269 | 270 | 271 | def resnet152(pretrained=False, progress=True, **kwargs): 272 | r"""ResNet-152 model from 273 | `"Deep Residual Learning for Image Recognition" `_ 274 | 275 | Args: 276 | pretrained (bool): If True, returns a model pre-trained on ImageNet 277 | progress (bool): If True, displays a progress bar of the download to stderr 278 | """ 279 | return 
_resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 280 | **kwargs) 281 | 282 | 283 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs): 284 | r"""ResNeXt-50 32x4d model from 285 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 286 | 287 | Args: 288 | pretrained (bool): If True, returns a model pre-trained on ImageNet 289 | progress (bool): If True, displays a progress bar of the download to stderr 290 | """ 291 | kwargs['groups'] = 32 292 | kwargs['width_per_group'] = 4 293 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 294 | pretrained, progress, **kwargs) 295 | 296 | 297 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs): 298 | r"""ResNeXt-101 32x8d model from 299 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 300 | 301 | Args: 302 | pretrained (bool): If True, returns a model pre-trained on ImageNet 303 | progress (bool): If True, displays a progress bar of the download to stderr 304 | """ 305 | kwargs['groups'] = 32 306 | kwargs['width_per_group'] = 8 307 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 308 | pretrained, progress, **kwargs) 309 | 310 | 311 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs): 312 | r"""Wide ResNet-50-2 model from 313 | `"Wide Residual Networks" `_ 314 | 315 | The model is the same as ResNet except for the bottleneck number of channels 316 | which is twice larger in every block. The number of channels in outer 1x1 317 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 318 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 
319 | 320 | Args: 321 | pretrained (bool): If True, returns a model pre-trained on ImageNet 322 | progress (bool): If True, displays a progress bar of the download to stderr 323 | """ 324 | kwargs['width_per_group'] = 64 * 2 325 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], 326 | pretrained, progress, **kwargs) 327 | 328 | 329 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs): 330 | r"""Wide ResNet-101-2 model from 331 | `"Wide Residual Networks" `_ 332 | 333 | The model is the same as ResNet except for the bottleneck number of channels 334 | which is twice larger in every block. The number of channels in outer 1x1 335 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 336 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 337 | 338 | Args: 339 | pretrained (bool): If True, returns a model pre-trained on ImageNet 340 | progress (bool): If True, displays a progress bar of the download to stderr 341 | """ 342 | kwargs['width_per_group'] = 64 * 2 343 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], 344 | pretrained, progress, **kwargs) 345 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | 2 | import os, time 3 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" 4 | from operator import add 5 | import numpy as np 6 | from glob import glob 7 | import cv2 8 | from tqdm import tqdm 9 | import imageio 10 | import torch 11 | from model import RUPNet 12 | from utils import create_dir, seeding 13 | from utils import calculate_metrics 14 | from train import load_data 15 | 16 | 17 | def evaluate(model, save_path, test_x, test_y, size): 18 | """ Loading other comparitive model masks """ 19 | comparison_path = "/media/nikhil/LAB/ML/ME/COMPARISON/Kvasir-SEG/" 20 | 21 | 22 | deeplabv3plus_mask = sorted(glob(os.path.join(comparison_path, "DeepLabV3+_50", "results", "Kvasir-SEG", "mask", 
"*"))) 23 | pranet_mask = sorted(glob(os.path.join(comparison_path, "PraNet", "results", "Kvasir-SEG", "mask", "*"))) 24 | 25 | metrics_score = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] 26 | time_taken = [] 27 | 28 | for i, (x, y) in tqdm(enumerate(zip(test_x, test_y)), total=len(test_x)): 29 | name = y.split("/")[-1].split(".")[0] 30 | 31 | """ Image """ 32 | image = cv2.imread(x, cv2.IMREAD_COLOR) 33 | image = cv2.resize(image, size) 34 | save_img = image 35 | image = np.transpose(image, (2, 0, 1)) 36 | image = image/255.0 37 | image = np.expand_dims(image, axis=0) 38 | image = image.astype(np.float32) 39 | image = torch.from_numpy(image) 40 | image = image.to(device) 41 | 42 | """ Mask """ 43 | mask = cv2.imread(y, cv2.IMREAD_GRAYSCALE) 44 | mask = cv2.resize(mask, size) 45 | save_mask = mask 46 | save_mask = np.expand_dims(save_mask, axis=-1) 47 | save_mask = np.concatenate([save_mask, save_mask, save_mask], axis=2) 48 | mask = np.expand_dims(mask, axis=0) 49 | mask = mask/255.0 50 | mask = np.expand_dims(mask, axis=0) 51 | mask = mask.astype(np.float32) 52 | mask = torch.from_numpy(mask) 53 | mask = mask.to(device) 54 | 55 | with torch.no_grad(): 56 | """ FPS calculation """ 57 | start_time = time.time() 58 | heatmap, y_pred = model(image, heatmap=True) 59 | y_pred = torch.sigmoid(y_pred) 60 | end_time = time.time() - start_time 61 | time_taken.append(end_time) 62 | 63 | """ Evaluation metrics """ 64 | score = calculate_metrics(mask, y_pred) 65 | metrics_score = list(map(add, metrics_score, score)) 66 | 67 | """ Predicted Mask """ 68 | y_pred = y_pred[0].cpu().numpy() 69 | y_pred = np.squeeze(y_pred, axis=0) 70 | y_pred = y_pred > 0.5 71 | y_pred = y_pred.astype(np.int32) 72 | y_pred = y_pred * 255 73 | y_pred = np.array(y_pred, dtype=np.uint8) 74 | y_pred = np.expand_dims(y_pred, axis=-1) 75 | y_pred = np.concatenate([y_pred, y_pred, y_pred], axis=2) 76 | 77 | """ Save the image - mask - pred """ 78 | line = np.ones((size[0], 10, 3)) * 255 79 | cat_images = 
np.concatenate([ 80 | save_img, line, 81 | save_mask, line, 82 | cv2.imread(deeplabv3plus_mask[i], cv2.IMREAD_COLOR), line, 83 | cv2.imread(pranet_mask[i], cv2.IMREAD_COLOR), line, 84 | y_pred, line, 85 | heatmap], axis=1) 86 | 87 | cv2.imwrite(f"{save_path}/joint/{name}.jpg", cat_images) 88 | cv2.imwrite(f"{save_path}/mask/{name}.jpg", y_pred) 89 | cv2.imwrite(f"{save_path}/heatmap/{name}.jpg", heatmap) 90 | 91 | jaccard = metrics_score[0]/len(test_x) 92 | f1 = metrics_score[1]/len(test_x) 93 | recall = metrics_score[2]/len(test_x) 94 | precision = metrics_score[3]/len(test_x) 95 | acc = metrics_score[4]/len(test_x) 96 | f2 = metrics_score[5]/len(test_x) 97 | 98 | print(f"Jaccard: {jaccard:1.4f} - F1: {f1:1.4f} - Recall: {recall:1.4f} - Precision: {precision:1.4f} - Acc: {acc:1.4f} - F2: {f2:1.4f}") 99 | 100 | mean_time_taken = np.mean(time_taken) 101 | mean_fps = 1/mean_time_taken 102 | print("Mean FPS: ", mean_fps) 103 | 104 | 105 | if __name__ == "__main__": 106 | """ Seeding """ 107 | seeding(42) 108 | 109 | """ Load the checkpoint """ 110 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 111 | model = RUPNet() 112 | model = model.to(device) 113 | checkpoint_path = "files/checkpoint.pth" 114 | model.load_state_dict(torch.load(checkpoint_path, map_location=device)) 115 | model.eval() 116 | 117 | """ Test dataset """ 118 | path = "/media/nikhil/Seagate Backup Plus Drive/ML_DATASET/Kvasir-SEG" 119 | (train_x, train_y), (test_x, test_y) = load_data(path) 120 | 121 | test_x = sorted(test_x) 122 | test_y = sorted(test_y) 123 | 124 | save_path = f"results/Kvasir-SEG" 125 | for item in ["mask", "joint", "heatmap"]: 126 | create_dir(f"{save_path}/{item}") 127 | 128 | size = (256, 256) 129 | create_dir(save_path) 130 | evaluate(model, save_path, test_x, test_y, size) 131 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | 2 
| import os 3 | import random 4 | import time 5 | import datetime 6 | import numpy as np 7 | import albumentations as A 8 | import cv2 9 | from glob import glob 10 | import torch 11 | import torch.nn as nn 12 | from torch.utils.data import Dataset, DataLoader 13 | from torchvision import transforms 14 | from utils import seeding, create_dir, print_and_save, shuffling, epoch_time, calculate_metrics 15 | from model import RUPNet 16 | from metrics import DiceLoss, DiceBCELoss 17 | 18 | def load_names(path, file_path): 19 | with open(file_path, "r") as f: ## close the file handle once the names are read 20 | data = f.read().split("\n")[:-1] 21 | images = [os.path.join(path,"images", name) + ".jpg" for name in data] 22 | masks = [os.path.join(path,"masks", name) + ".jpg" for name in data] 23 | return images, masks 24 | 25 | def load_data(path): 26 | train_names_path = f"{path}/train.txt" 27 | valid_names_path = f"{path}/val.txt" 28 | 29 | train_x, train_y = load_names(path, train_names_path) 30 | valid_x, valid_y = load_names(path, valid_names_path) 31 | 32 | return (train_x, train_y), (valid_x, valid_y) 33 | 34 | class DATASET(Dataset): 35 | def __init__(self, images_path, masks_path, size, transform=None): 36 | super().__init__() 37 | self.size = size ## keep the target size on the instance instead of relying on a global 38 | self.images_path = images_path 39 | self.masks_path = masks_path 40 | self.transform = transform 41 | self.n_samples = len(images_path) 42 | 43 | def __getitem__(self, index): 44 | """ Image """ 45 | image = cv2.imread(self.images_path[index], cv2.IMREAD_COLOR) 46 | mask = cv2.imread(self.masks_path[index], cv2.IMREAD_GRAYSCALE) 47 | 48 | if self.transform is not None: 49 | augmentations = self.transform(image=image, mask=mask) 50 | image = augmentations["image"] 51 | mask = augmentations["mask"] 52 | 53 | image = cv2.resize(image, self.size) 54 | image = np.transpose(image, (2, 0, 1)) 55 | image = image/255.0 56 | 57 | mask = cv2.resize(mask, self.size) 58 | mask = np.expand_dims(mask, axis=0) 59 | mask = mask/255.0 60 | 61 | return image, mask 62 | 63 | def __len__(self): 64 | return 
self.n_samples 65 | 66 | def train(model, loader, optimizer, loss_fn, device): 67 | model.train() 68 | 69 | epoch_loss = 0.0 70 | epoch_jac = 0.0 71 | epoch_f1 = 0.0 72 | epoch_recall = 0.0 73 | epoch_precision = 0.0 74 | 75 | for i, (x, y) in enumerate(loader): 76 | x = x.to(device, dtype=torch.float32) 77 | y = y.to(device, dtype=torch.float32) 78 | 79 | optimizer.zero_grad() 80 | p1, p2, p3, p4 = model(x) 81 | loss = loss_fn(p1, y) + loss_fn(p2, y) + loss_fn(p3, y) + loss_fn(p4, y) 82 | loss.backward() 83 | optimizer.step() 84 | epoch_loss += loss.item() 85 | 86 | """ Calculate the metrics """ 87 | batch_jac = [] 88 | batch_f1 = [] 89 | batch_recall = [] 90 | batch_precision = [] 91 | 92 | for yt, yp in zip(y, p4): 93 | score = calculate_metrics(yt, yp) 94 | batch_jac.append(score[0]) 95 | batch_f1.append(score[1]) 96 | batch_recall.append(score[2]) 97 | batch_precision.append(score[3]) 98 | 99 | epoch_jac += np.mean(batch_jac) 100 | epoch_f1 += np.mean(batch_f1) 101 | epoch_recall += np.mean(batch_recall) 102 | epoch_precision += np.mean(batch_precision) 103 | 104 | epoch_loss = epoch_loss/len(loader) 105 | epoch_jac = epoch_jac/len(loader) 106 | epoch_f1 = epoch_f1/len(loader) 107 | epoch_recall = epoch_recall/len(loader) 108 | epoch_precision = epoch_precision/len(loader) 109 | 110 | return epoch_loss, [epoch_jac, epoch_f1, epoch_recall, epoch_precision] 111 | 112 | def evaluate(model, loader, loss_fn, device): 113 | model.eval() 114 | 115 | epoch_loss = 0 116 | epoch_loss = 0.0 117 | epoch_jac = 0.0 118 | epoch_f1 = 0.0 119 | epoch_recall = 0.0 120 | epoch_precision = 0.0 121 | 122 | with torch.no_grad(): 123 | for i, (x, y) in enumerate(loader): 124 | x = x.to(device, dtype=torch.float32) 125 | y = y.to(device, dtype=torch.float32) 126 | 127 | p1, p2, p3, p4 = model(x) 128 | loss = loss_fn(p1, y) + loss_fn(p2, y) + loss_fn(p3, y) + loss_fn(p4, y) 129 | epoch_loss += loss.item() 130 | 131 | """ Calculate the metrics """ 132 | batch_jac = [] 133 | batch_f1 = 
[] 134 | batch_recall = [] 135 | batch_precision = [] 136 | 137 | for yt, yp in zip(y, p4): 138 | score = calculate_metrics(yt, yp) 139 | batch_jac.append(score[0]) 140 | batch_f1.append(score[1]) 141 | batch_recall.append(score[2]) 142 | batch_precision.append(score[3]) 143 | 144 | epoch_jac += np.mean(batch_jac) 145 | epoch_f1 += np.mean(batch_f1) 146 | epoch_recall += np.mean(batch_recall) 147 | epoch_precision += np.mean(batch_precision) 148 | 149 | epoch_loss = epoch_loss/len(loader) 150 | epoch_jac = epoch_jac/len(loader) 151 | epoch_f1 = epoch_f1/len(loader) 152 | epoch_recall = epoch_recall/len(loader) 153 | epoch_precision = epoch_precision/len(loader) 154 | 155 | return epoch_loss, [epoch_jac, epoch_f1, epoch_recall, epoch_precision] 156 | 157 | if __name__ == "__main__": 158 | """ Seeding """ 159 | seeding(42) 160 | 161 | """ Directories """ 162 | create_dir("files") 163 | 164 | """ Training logfile """ 165 | train_log_path = "files/train_log.txt" 166 | if os.path.exists(train_log_path): 167 | print("Log file exists") 168 | else: 169 | train_log = open("files/train_log.txt", "w") 170 | train_log.write("\n") 171 | train_log.close() 172 | 173 | """ Record Date & Time """ 174 | datetime_object = str(datetime.datetime.now()) 175 | print_and_save(train_log_path, datetime_object) 176 | print("") 177 | 178 | """ Hyperparameters """ 179 | image_size = 256 180 | size = (image_size, image_size) 181 | batch_size = 16 182 | num_epochs = 500 183 | lr = 1e-4 184 | early_stopping_patience = 50 185 | checkpoint_path = "files/checkpoint.pth" 186 | path = "/media/nikhil/Seagate Backup Plus Drive/ML_DATASET/Kvasir-SEG" 187 | 188 | data_str = f"Image Size: {size}\nBatch Size: {batch_size}\nLR: {lr}\nEpochs: {num_epochs}\n" 189 | data_str += f"Early Stopping Patience: {early_stopping_patience}\n" 190 | print_and_save(train_log_path, data_str) 191 | 192 | """ Dataset """ 193 | (train_x, train_y), (valid_x, valid_y) = load_data(path) 194 | train_x, train_y = shuffling(train_x, 
train_y) 195 | data_str = f"Dataset Size:\nTrain: {len(train_x)} - Valid: {len(valid_x)}\n" 196 | print_and_save(train_log_path, data_str) 197 | 198 | """ Data augmentation: Transforms """ 199 | transform = A.Compose([ 200 | A.Rotate(limit=35, p=0.3), 201 | A.HorizontalFlip(p=0.3), 202 | A.VerticalFlip(p=0.3), 203 | A.CoarseDropout(p=0.3, max_holes=10, max_height=32, max_width=32) 204 | ]) 205 | 206 | """ Dataset and loader """ 207 | train_dataset = DATASET(train_x, train_y, size, transform=transform) 208 | valid_dataset = DATASET(valid_x, valid_y, size, transform=None) 209 | 210 | # create_dir("data") 211 | # for i, (x, y) in enumerate(train_dataset): 212 | # x = np.transpose(x, (1, 2, 0)) * 255 213 | # y = np.transpose(y, (1, 2, 0)) * 255 214 | # y = np.concatenate([y, y, y], axis=-1) 215 | # cv2.imwrite(f"data/{i}.png", np.concatenate([x, y], axis=1)) 216 | 217 | train_loader = DataLoader( 218 | dataset=train_dataset, 219 | batch_size=batch_size, 220 | shuffle=True, 221 | num_workers=2 222 | ) 223 | 224 | valid_loader = DataLoader( 225 | dataset=valid_dataset, 226 | batch_size=batch_size, 227 | shuffle=False, 228 | num_workers=2 229 | ) 230 | 231 | """ Model """ 232 | device = torch.device('cuda') 233 | model = RUPNet() 234 | model = model.to(device) 235 | 236 | optimizer = torch.optim.Adam(model.parameters(), lr=lr) 237 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, verbose=True) 238 | loss_fn = DiceBCELoss() 239 | loss_name = "BCE Dice Loss" 240 | data_str = f"Optimizer: Adam\nLoss: {loss_name}\n" 241 | print_and_save(train_log_path, data_str) 242 | 243 | """ Training the model """ 244 | best_valid_metrics = 0.0 245 | early_stopping_count = 0 246 | 247 | for epoch in range(num_epochs): 248 | start_time = time.time() 249 | 250 | train_loss, train_metrics = train(model, train_loader, optimizer, loss_fn, device) 251 | valid_loss, valid_metrics = evaluate(model, valid_loader, loss_fn, device) 252 | scheduler.step(valid_loss) 
        # Checkpoint on the best validation F1 (index 1 of the metrics list).
        if valid_metrics[1] > best_valid_metrics:
            data_str = f"Valid F1 improved from {best_valid_metrics:2.4f} to {valid_metrics[1]:2.4f}. Saving checkpoint: {checkpoint_path}"
            print_and_save(train_log_path, data_str)

            best_valid_metrics = valid_metrics[1]
            torch.save(model.state_dict(), checkpoint_path)
            early_stopping_count = 0

        elif valid_metrics[1] < best_valid_metrics:
            early_stopping_count += 1

        end_time = time.time()
        epoch_mins, epoch_secs = epoch_time(start_time, end_time)

        data_str = f"Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s\n"
        data_str += f"\tTrain Loss: {train_loss:.4f} - Jaccard: {train_metrics[0]:.4f} - F1: {train_metrics[1]:.4f} - Recall: {train_metrics[2]:.4f} - Precision: {train_metrics[3]:.4f}\n"
        data_str += f"\t Val. Loss: {valid_loss:.4f} - Jaccard: {valid_metrics[0]:.4f} - F1: {valid_metrics[1]:.4f} - Recall: {valid_metrics[2]:.4f} - Precision: {valid_metrics[3]:.4f}\n"
        print_and_save(train_log_path, data_str)

        if early_stopping_count == early_stopping_patience:
            # The stopping criterion is validation F1, not validation loss.
            data_str = f"Early stopping: validation F1 has not improved for the last {early_stopping_patience} epochs.\n"
            print_and_save(train_log_path, data_str)
            break

--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------

import os
import random
import numpy as np
import cv2
from tqdm import tqdm
import torch
from sklearn.utils import shuffle
from metrics import precision, recall, F2, dice_score, jac_score
from sklearn.metrics import accuracy_score, confusion_matrix

""" Seeding the randomness. """
def seeding(seed):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

""" Create a directory """
def create_dir(path):
    if not os.path.exists(path):
        os.makedirs(path)

""" Shuffle the dataset. """
def shuffling(x, y):
    x, y = shuffle(x, y, random_state=42)
    return x, y

def epoch_time(start_time, end_time):
    # Split the elapsed wall-clock time into whole minutes and seconds.
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

def print_and_save(file_path, data_str):
    print(data_str)
    with open(file_path, "a") as file:
        file.write(data_str)
        file.write("\n")

def calculate_metrics(y_true, y_pred):
    y_true = y_true.detach().cpu().numpy()
    y_pred = y_pred.detach().cpu().numpy()

    # Binarize prediction and ground truth at 0.5 and flatten to 1-D.
    y_pred = y_pred > 0.5
    y_pred = y_pred.reshape(-1)
    y_pred = y_pred.astype(np.uint8)

    y_true = y_true > 0.5
    y_true = y_true.reshape(-1)
    y_true = y_true.astype(np.uint8)

    ## Score
    score_jaccard = jac_score(y_true, y_pred)
    score_f1 = dice_score(y_true, y_pred)
    score_recall = recall(y_true, y_pred)
    score_precision = precision(y_true, y_pred)
    score_fbeta = F2(y_true, y_pred)
    score_acc = accuracy_score(y_true, y_pred)

    return [score_jaccard, score_f1, score_recall, score_precision, score_acc, score_fbeta]

--------------------------------------------------------------------------------