├── Architecture.jpg ├── IntroTransNetR.png ├── IntroTransNetR1.png ├── Kvasir-seg.png ├── LICENSE ├── OOD.png ├── OOD_K.png ├── README.md ├── Residual-Transformer-Block.jpg ├── Test_PolypGen.jpg ├── bkai_crossdata.jpg ├── metrics.py ├── model.py ├── polypgen-samples.jpg ├── resnet.py ├── results.jpg ├── supplementry_C1.jpeg ├── supplementry_C6.jpg ├── test.py ├── train.py └── utils.py /Architecture.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DebeshJha/TransNetR/57567bb27fabe767af81c33001e38284a53555e9/Architecture.jpg -------------------------------------------------------------------------------- /IntroTransNetR.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DebeshJha/TransNetR/57567bb27fabe767af81c33001e38284a53555e9/IntroTransNetR.png -------------------------------------------------------------------------------- /IntroTransNetR1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DebeshJha/TransNetR/57567bb27fabe767af81c33001e38284a53555e9/IntroTransNetR1.png -------------------------------------------------------------------------------- /Kvasir-seg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DebeshJha/TransNetR/57567bb27fabe767af81c33001e38284a53555e9/Kvasir-seg.png -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Debesh Jha 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /OOD.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DebeshJha/TransNetR/57567bb27fabe767af81c33001e38284a53555e9/OOD.png -------------------------------------------------------------------------------- /OOD_K.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DebeshJha/TransNetR/57567bb27fabe767af81c33001e38284a53555e9/OOD_K.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TransNetR: Transformer-based Residual Network for Polyp Segmentation with Multi-Center Out-of-Distribution Testing (MIDL 2023) 2 | 3 | Paper Link: https://arxiv.org/pdf/2303.07428.pdf 4 | 5 | TransNetR is an encoder-decoder network for efficient biomedical image segmentation on both in-distribution and out-of-distribution datasets. 6 | 7 | ## In-distribution and Out-of-distribution datasets 8 |

9 | 10 | 11 | Figure 1: Illustration of different scenarios expected to arise in real-world settings. The proposed work performs both in-distribution and out-of-distribution validation. C1 to C6 denote the data from the six different centers in the PolypGen dataset. 12 |

13 | 14 | # TransNetR 15 |

16 | 17 | 18 | Figure 2: Block diagram of TransNetR along with the Residual Transformer block 19 |
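To sanity-check the architecture, `model.py` exposes the full network as `Model`; a minimal forward pass, mirroring the `__main__` block at the bottom of `model.py`:

```python
import torch
from model import Model  # model.py in this repository

model = Model()                     # ResNet50 encoder + residual transformer decoder
x = torch.randn((4, 3, 256, 256))   # batch of 4 RGB images at 256x256
y = model(x)                        # raw logits; apply torch.sigmoid() for probabilities
print(y.shape)                      # torch.Size([4, 1, 256, 256])
```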

20 | 21 | 22 | ## Results (Qualitative results) 23 |

24 | 25 | 26 | Figure 3: Qualitative example showing polyp segmentation on Kvasir-SEG 27 |

28 | 29 | ## Results (Quantitative results) 30 |

31 | 32 | Table 1: Quantitative results on the Kvasir-SEG test dataset. The parameters are in millions and FLOPs are in GMac. 33 |
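The parameter counts and GMac figures in Table 1 can be reproduced with a model-complexity profiler; a sketch assuming the `ptflops` package (GMac is its reporting unit; the repository itself does not pin a specific tool):

```python
from ptflops import get_model_complexity_info  # pip install ptflops (assumed dependency)
from model import Model

model = Model()
macs, params = get_model_complexity_info(
    model, (3, 256, 256),  # channels-first input resolution used throughout this repository
    as_strings=True, print_per_layer_stat=False
)
print(f"Params: {params} | Complexity: {macs}")
```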

34 | 35 |

36 | 37 |

38 | 39 |

40 | 41 |

42 | 43 |

44 | 45 |

46 | 47 | ## Results (Cross-data qualitative results) 48 |

49 | 50 | 51 | Figure 4: Cross-data results when models are trained on Kvasir-SEG and tested on BKAI-IGH. 52 |

53 | 54 | ## PolypGen dataset (center-wise samples) 55 |

56 | 57 | 58 | Figure 5: Center-wise example images from the PolypGen dataset. The variability among the data from the different centers can be observed: there are differences in image resolution and size, shape, color, texture, appearance, and collection protocol. 59 |

60 | 61 | 62 | ## Results (Samples of OOD (PolypGen datasets from 6 different centers)) 63 |

64 | 65 | 66 | Figure 6: Qualitative results when TransNetR is trained on Kvasir-SEG and tested on (a) PolypGen (center 6 (C6)) 67 |

68 | 69 | ## Qualitative results 70 |

71 | 72 | 73 | Figure 7: Qualitative results when TransNetR is trained on Kvasir-SEG and tested on PolypGen (center 1 (C1)) 74 |
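## Training and testing

`train.py` and `test.py` read Kvasir-SEG through `load_data` in `train.py`; a sketch of the assumed directory layout and invocation (edit the `path` variable at the top of each script to point at your copy of the dataset):

```python
# Directory layout inferred from load_names()/load_data() in train.py:
#
#   Kvasir-SEG/
#   ├── images/<name>.jpg
#   ├── masks/<name>.jpg
#   ├── train.txt   # one image stem per line, without extension
#   └── val.txt
#
# python train.py   -> logs to files/train_log.txt, saves files/checkpoint.pth
# python test.py    -> loads files/checkpoint.pth, writes results/Kvasir-SEG/{mask,joint}
```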

75 | 76 | ## Citation 77 | Please cite our paper if you find the work useful: 78 |
79 |   @INPROCEEDINGS{JhaTrans2023,
80 |   author={Jha, Debesh and Tomar, Nikhil Kumar and Sharma, Vanshali and Bagci, Ulas}, 
81 |   booktitle={Proceedings of Medical Imaging with Deep Learning (MIDL)}, 
82 |   title={TransNetR: Transformer-based Residual Network for Polyp Segmentation with Multi-Center Out-of-Distribution Testing}, 
83 |   year={2023}}
84 | 
85 | 86 | ## Contact 87 | Please contact debesh.jha@northwestern.edu and nikhilroxtomar@gmail.com for any further questions. 88 | -------------------------------------------------------------------------------- /Residual-Transformer-Block.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DebeshJha/TransNetR/57567bb27fabe767af81c33001e38284a53555e9/Residual-Transformer-Block.jpg -------------------------------------------------------------------------------- /Test_PolypGen.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DebeshJha/TransNetR/57567bb27fabe767af81c33001e38284a53555e9/Test_PolypGen.jpg -------------------------------------------------------------------------------- /bkai_crossdata.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DebeshJha/TransNetR/57567bb27fabe767af81c33001e38284a53555e9/bkai_crossdata.jpg -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from scipy.spatial.distance import directed_hausdorff 5 | 6 | """ Loss Functions -------------------------------------- """ 7 | class DiceLoss(nn.Module): 8 | def __init__(self, weight=None, size_average=True): 9 | super(DiceLoss, self).__init__() 10 | 11 | def forward(self, inputs, targets, smooth=1): 12 | inputs = torch.sigmoid(inputs) 13 | 14 | inputs = inputs.view(-1) 15 | targets = targets.view(-1) 16 | 17 | intersection = (inputs * targets).sum() 18 | dice = (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth) 19 | 20 | return 1 - dice 21 | 22 | class DiceBCELoss(nn.Module): 23 | def __init__(self, weight=None, size_average=True): 24 | super(DiceBCELoss, self).__init__() 25 | 26 | def forward(self, inputs, targets, smooth=1): 27 | inputs = torch.sigmoid(inputs) 28 | 29 | inputs = inputs.view(-1) 30 | targets = targets.view(-1) 31 | 32 | intersection = (inputs * targets).sum() 33 | dice_loss = 1 - (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth) 34 | BCE = F.binary_cross_entropy(inputs, targets, reduction='mean') 35 | Dice_BCE = BCE + dice_loss 36 | 37 | return Dice_BCE 38 | 39 | """ Metrics ------------------------------------------ """ 40 | def precision(y_true, y_pred): 41 | intersection = (y_true * y_pred).sum() 42 | return (intersection + 1e-15) / (y_pred.sum() + 1e-15) 43 | 44 | def recall(y_true, y_pred): 45 | intersection = (y_true * y_pred).sum() 46 | return (intersection + 1e-15) / (y_true.sum() + 1e-15) 47 | 48 | def F2(y_true, y_pred, beta=2): 49 | p = precision(y_true,y_pred) 50 | r = recall(y_true, y_pred) 51 | return (1+beta**2.) 
*(p*r) / float(beta**2*p + r + 1e-15) 52 | 53 | def dice_score(y_true, y_pred): 54 | return (2 * (y_true * y_pred).sum() + 1e-15) / (y_true.sum() + y_pred.sum() + 1e-15) 55 | 56 | def jac_score(y_true, y_pred): 57 | intersection = (y_true * y_pred).sum() 58 | union = y_true.sum() + y_pred.sum() - intersection 59 | return (intersection + 1e-15) / (union + 1e-15) 60 | 61 | ## https://www.kaggle.com/competitions/uw-madison-gi-tract-image-segmentation/discussion/319452 62 | def hd_dist(preds, targets): 63 | haussdorf_dist = directed_hausdorff(preds, targets)[0] 64 | return haussdorf_dist 65 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | from resnet import resnet50, resnet34 5 | 6 | class Conv2D(nn.Module): 7 | def __init__(self, in_c, out_c, kernel_size=3, padding=1, stride=1, dilation=1, bias=True, act=True): 8 | super().__init__() 9 | 10 | self.act = act 11 | self.conv = nn.Sequential( 12 | nn.Conv2d(in_c, out_c, kernel_size, padding=padding, dilation=dilation, stride=stride, bias=bias), 13 | nn.BatchNorm2d(out_c) 14 | ) 15 | self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=True) 16 | 17 | def forward(self, x): 18 | x = self.conv(x) 19 | if self.act == True: 20 | x = self.relu(x) 21 | return x 22 | 23 | class residual_block(nn.Module): 24 | def __init__(self, in_c, out_c): 25 | super().__init__() 26 | 27 | self.conv = nn.Sequential( 28 | nn.Conv2d(in_c, out_c, kernel_size=3, padding=1), 29 | nn.BatchNorm2d(out_c), 30 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 31 | nn.Conv2d(out_c, out_c, kernel_size=3, padding=1), 32 | nn.BatchNorm2d(out_c) 33 | ) 34 | self.shortcut = nn.Sequential( 35 | nn.Conv2d(in_c, out_c, kernel_size=1, padding=0), 36 | nn.BatchNorm2d(out_c) 37 | ) 38 | self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=True) 39 | 40 | def forward(self, inputs): 41 | x = self.conv(inputs) 42 | s = self.shortcut(inputs) 43 | return self.relu(x + s) 44 | 45 | class residual_transformer_block(nn.Module): 46 | def __init__(self, in_c, out_c, patch_size=4, num_heads=4, num_layers=2, dim=None): 47 | super().__init__() 48 | 49 | self.ps = patch_size 50 | self.c1 = Conv2D(in_c, out_c) 51 | 52 | encoder_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=num_heads) 53 | self.te = nn.TransformerEncoder(encoder_layer, num_layers=num_layers) 54 | 55 | self.c2 = Conv2D(out_c, out_c, kernel_size=1, padding=0, act=False) 56 | self.c3 = Conv2D(in_c, out_c, kernel_size=1, padding=0, act=False) 57 | self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=True) 58 | self.r1 = residual_block(out_c, out_c) 59 | 60 | def forward(self, inputs): 61 | x = self.c1(inputs) 62 | 63 | b, c, h, w = x.shape 64 | num_patches = (h*w)//(self.ps**2) 65 | x = torch.reshape(x, (b, (self.ps**2)*c, num_patches)) 66 | x = self.te(x) 67 | x = torch.reshape(x, (b, c, h, w)) 68 | 69 | x = self.c2(x) 70 | s = self.c3(inputs) 71 | x = self.relu(x + s) 72 | x = self.r1(x) 73 | return x 74 | 75 | class Model(nn.Module): 76 | def __init__(self): 77 | super().__init__() 78 | 79 | """ Encoder """ 80 | backbone = resnet50() 81 | self.layer0 = nn.Sequential(backbone.conv1, backbone.bn1, backbone.relu) 82 | self.layer1 = nn.Sequential(backbone.maxpool, backbone.layer1) 83 | self.layer2 = backbone.layer2 84 | self.layer3 = backbone.layer3 85 | self.layer4 = backbone.layer4 86 | 87 | self.e1 = Conv2D(64, 64, kernel_size=1, padding=0) 88 | self.e2 = 
Conv2D(256, 64, kernel_size=1, padding=0) 89 | self.e3 = Conv2D(512, 64, kernel_size=1, padding=0) 90 | self.e4 = Conv2D(1024, 64, kernel_size=1, padding=0) 91 | 92 | 93 | """ Decoder """ 94 | self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True) 95 | self.r1 = residual_transformer_block(64+64, 64, dim=64) 96 | self.r2 = residual_transformer_block(64+64, 64, dim=256) 97 | self.r3 = residual_block(64+64, 64) 98 | 99 | """ Classifier """ 100 | self.outputs = nn.Conv2d(64, 1, kernel_size=1, padding=0) 101 | 102 | def forward(self, inputs): 103 | """ Encoder """ 104 | x0 = inputs 105 | x1 = self.layer0(x0) ## [-1, 64, h/2, w/2] 106 | x2 = self.layer1(x1) ## [-1, 256, h/4, w/4] 107 | x3 = self.layer2(x2) ## [-1, 512, h/8, w/8] 108 | x4 = self.layer3(x3) ## [-1, 1024, h/16, w/16] 109 | # print(x1.shape, x2.shape, x3.shape, x4.shape) 110 | 111 | e1 = self.e1(x1) 112 | e2 = self.e2(x2) 113 | e3 = self.e3(x3) 114 | e4 = self.e4(x4) 115 | 116 | """ Decoder """ 117 | x = self.up(e4) 118 | x = torch.cat([x, e3], axis=1) 119 | x = self.r1(x) 120 | 121 | x = self.up(x) 122 | x = torch.cat([x, e2], axis=1) 123 | x = self.r2(x) 124 | 125 | x = self.up(x) 126 | x = torch.cat([x, e1], axis=1) 127 | x = self.r3(x) 128 | 129 | x = self.up(x) 130 | 131 | """ Classifier """ 132 | outputs = self.outputs(x) 133 | return outputs 134 | 135 | if __name__ == "__main__": 136 | x = torch.randn((4, 3, 256, 256)) 137 | model = Model() 138 | y = model(x) 139 | print(y.shape) 140 | -------------------------------------------------------------------------------- /polypgen-samples.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DebeshJha/TransNetR/57567bb27fabe767af81c33001e38284a53555e9/polypgen-samples.jpg -------------------------------------------------------------------------------- /resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.hub import load_state_dict_from_url 4 | 5 | 6 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 7 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 8 | 'wide_resnet50_2', 'wide_resnet101_2'] 9 | 10 | 11 | model_urls = { 12 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 13 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 14 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 15 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 16 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 17 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 18 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 19 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 20 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 21 | } 22 | 23 | 24 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 25 | """3x3 convolution with padding""" 26 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 27 | padding=dilation, groups=groups, bias=False, dilation=dilation) 28 | 29 | 30 | def conv1x1(in_planes, out_planes, stride=1): 31 | """1x1 convolution""" 32 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 33 | 34 | 35 | class BasicBlock(nn.Module): 36 | expansion = 1 37 
| 38 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 39 | base_width=64, dilation=1, norm_layer=None): 40 | super(BasicBlock, self).__init__() 41 | if norm_layer is None: 42 | norm_layer = nn.BatchNorm2d 43 | if groups != 1 or base_width != 64: 44 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 45 | if dilation > 1: 46 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 47 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 48 | self.conv1 = conv3x3(inplanes, planes, stride) 49 | self.bn1 = norm_layer(planes) 50 | self.relu = nn.ReLU(inplace=True) 51 | self.conv2 = conv3x3(planes, planes) 52 | self.bn2 = norm_layer(planes) 53 | self.downsample = downsample 54 | self.stride = stride 55 | 56 | def forward(self, x): 57 | identity = x 58 | 59 | out = self.conv1(x) 60 | out = self.bn1(out) 61 | out = self.relu(out) 62 | 63 | out = self.conv2(out) 64 | out = self.bn2(out) 65 | 66 | if self.downsample is not None: 67 | identity = self.downsample(x) 68 | 69 | out += identity 70 | out = self.relu(out) 71 | 72 | return out 73 | 74 | 75 | class Bottleneck(nn.Module): 76 | expansion = 4 77 | 78 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 79 | base_width=64, dilation=1, norm_layer=None): 80 | super(Bottleneck, self).__init__() 81 | if norm_layer is None: 82 | norm_layer = nn.BatchNorm2d 83 | width = int(planes * (base_width / 64.)) * groups 84 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 85 | self.conv1 = conv1x1(inplanes, width) 86 | self.bn1 = norm_layer(width) 87 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 88 | self.bn2 = norm_layer(width) 89 | self.conv3 = conv1x1(width, planes * self.expansion) 90 | self.bn3 = norm_layer(planes * self.expansion) 91 | self.relu = nn.ReLU(inplace=True) 92 | self.downsample = downsample 93 | self.stride = stride 94 | 95 | def forward(self, x): 96 | identity = x 97 | 98 | out = self.conv1(x) 99 | out = self.bn1(out) 100 | out = self.relu(out) 101 | 102 | out = self.conv2(out) 103 | out = self.bn2(out) 104 | out = self.relu(out) 105 | 106 | out = self.conv3(out) 107 | out = self.bn3(out) 108 | 109 | if self.downsample is not None: 110 | identity = self.downsample(x) 111 | 112 | out += identity 113 | out = self.relu(out) 114 | 115 | return out 116 | 117 | 118 | class ResNet(nn.Module): 119 | 120 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 121 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 122 | norm_layer=None): 123 | super(ResNet, self).__init__() 124 | if norm_layer is None: 125 | norm_layer = nn.BatchNorm2d 126 | self._norm_layer = norm_layer 127 | 128 | self.inplanes = 64 129 | self.dilation = 1 130 | if replace_stride_with_dilation is None: 131 | # each element in the tuple indicates if we should replace 132 | # the 2x2 stride with a dilated convolution instead 133 | replace_stride_with_dilation = [False, False, False] 134 | if len(replace_stride_with_dilation) != 3: 135 | raise ValueError("replace_stride_with_dilation should be None " 136 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 137 | self.groups = groups 138 | self.base_width = width_per_group 139 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 140 | bias=False) 141 | self.bn1 = norm_layer(self.inplanes) 142 | self.relu = nn.ReLU(inplace=True) 143 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, 
padding=1) 144 | self.layer1 = self._make_layer(block, 64, layers[0]) 145 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 146 | dilate=replace_stride_with_dilation[0]) 147 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 148 | dilate=replace_stride_with_dilation[1]) 149 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 150 | dilate=replace_stride_with_dilation[2]) 151 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 152 | self.fc = nn.Linear(512 * block.expansion, num_classes) 153 | 154 | for m in self.modules(): 155 | if isinstance(m, nn.Conv2d): 156 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 157 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 158 | nn.init.constant_(m.weight, 1) 159 | nn.init.constant_(m.bias, 0) 160 | 161 | # Zero-initialize the last BN in each residual branch, 162 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 163 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 164 | if zero_init_residual: 165 | for m in self.modules(): 166 | if isinstance(m, Bottleneck): 167 | nn.init.constant_(m.bn3.weight, 0) 168 | elif isinstance(m, BasicBlock): 169 | nn.init.constant_(m.bn2.weight, 0) 170 | 171 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 172 | norm_layer = self._norm_layer 173 | downsample = None 174 | previous_dilation = self.dilation 175 | if dilate: 176 | self.dilation *= stride 177 | stride = 1 178 | if stride != 1 or self.inplanes != planes * block.expansion: 179 | downsample = nn.Sequential( 180 | conv1x1(self.inplanes, planes * block.expansion, stride), 181 | norm_layer(planes * block.expansion), 182 | ) 183 | 184 | layers = [] 185 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 186 | self.base_width, previous_dilation, norm_layer)) 187 | self.inplanes = planes * block.expansion 188 | for _ in range(1, blocks): 189 | layers.append(block(self.inplanes, planes, groups=self.groups, 190 | base_width=self.base_width, dilation=self.dilation, 191 | norm_layer=norm_layer)) 192 | 193 | return nn.Sequential(*layers) 194 | 195 | def forward(self, x): 196 | x = self.conv1(x) 197 | x = self.bn1(x) 198 | x = self.relu(x) 199 | x = self.maxpool(x) 200 | 201 | x = self.layer1(x) 202 | x = self.layer2(x) 203 | x = self.layer3(x) 204 | x = self.layer4(x) 205 | 206 | x = self.avgpool(x) 207 | x = torch.flatten(x, 1) 208 | x = self.fc(x) 209 | 210 | return x 211 | 212 | 213 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 214 | model = ResNet(block, layers, **kwargs) 215 | if pretrained: 216 | state_dict = load_state_dict_from_url(model_urls[arch], 217 | progress=progress) 218 | model.load_state_dict(state_dict) 219 | return model 220 | 221 | 222 | def resnet18(pretrained=False, progress=True, **kwargs): 223 | r"""ResNet-18 model from 224 | `"Deep Residual Learning for Image Recognition" `_ 225 | 226 | Args: 227 | pretrained (bool): If True, returns a model pre-trained on ImageNet 228 | progress (bool): If True, displays a progress bar of the download to stderr 229 | """ 230 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 231 | **kwargs) 232 | 233 | 234 | def resnet34(pretrained=False, progress=True, **kwargs): 235 | r"""ResNet-34 model from 236 | `"Deep Residual Learning for Image Recognition" `_ 237 | 238 | Args: 239 | pretrained (bool): If True, returns a model pre-trained on ImageNet 240 | progress (bool): If 
True, displays a progress bar of the download to stderr 241 | """ 242 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 243 | **kwargs) 244 | 245 | 246 | def resnet50(pretrained=True, progress=True, **kwargs): 247 | r"""ResNet-50 model from 248 | `"Deep Residual Learning for Image Recognition" `_ 249 | 250 | Args: 251 | pretrained (bool): If True, returns a model pre-trained on ImageNet 252 | progress (bool): If True, displays a progress bar of the download to stderr 253 | """ 254 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 255 | **kwargs) 256 | 257 | 258 | def resnet101(pretrained=False, progress=True, **kwargs): 259 | r"""ResNet-101 model from 260 | `"Deep Residual Learning for Image Recognition" `_ 261 | 262 | Args: 263 | pretrained (bool): If True, returns a model pre-trained on ImageNet 264 | progress (bool): If True, displays a progress bar of the download to stderr 265 | """ 266 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 267 | **kwargs) 268 | 269 | 270 | def resnet152(pretrained=False, progress=True, **kwargs): 271 | r"""ResNet-152 model from 272 | `"Deep Residual Learning for Image Recognition" `_ 273 | 274 | Args: 275 | pretrained (bool): If True, returns a model pre-trained on ImageNet 276 | progress (bool): If True, displays a progress bar of the download to stderr 277 | """ 278 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 279 | **kwargs) 280 | 281 | 282 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs): 283 | r"""ResNeXt-50 32x4d model from 284 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 285 | 286 | Args: 287 | pretrained (bool): If True, returns a model pre-trained on ImageNet 288 | progress (bool): If True, displays a progress bar of the download to stderr 289 | """ 290 | kwargs['groups'] = 32 291 | kwargs['width_per_group'] = 4 292 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 293 | pretrained, progress, **kwargs) 294 | 295 | 296 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs): 297 | r"""ResNeXt-101 32x8d model from 298 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 299 | 300 | Args: 301 | pretrained (bool): If True, returns a model pre-trained on ImageNet 302 | progress (bool): If True, displays a progress bar of the download to stderr 303 | """ 304 | kwargs['groups'] = 32 305 | kwargs['width_per_group'] = 8 306 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 307 | pretrained, progress, **kwargs) 308 | 309 | 310 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs): 311 | r"""Wide ResNet-50-2 model from 312 | `"Wide Residual Networks" `_ 313 | 314 | The model is the same as ResNet except for the bottleneck number of channels 315 | which is twice larger in every block. The number of channels in outer 1x1 316 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 317 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 
318 | 319 | Args: 320 | pretrained (bool): If True, returns a model pre-trained on ImageNet 321 | progress (bool): If True, displays a progress bar of the download to stderr 322 | """ 323 | kwargs['width_per_group'] = 64 * 2 324 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], 325 | pretrained, progress, **kwargs) 326 | 327 | 328 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs): 329 | r"""Wide ResNet-101-2 model from 330 | `"Wide Residual Networks" `_ 331 | 332 | The model is the same as ResNet except for the bottleneck number of channels 333 | which is twice larger in every block. The number of channels in outer 1x1 334 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 335 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 336 | 337 | Args: 338 | pretrained (bool): If True, returns a model pre-trained on ImageNet 339 | progress (bool): If True, displays a progress bar of the download to stderr 340 | """ 341 | kwargs['width_per_group'] = 64 * 2 342 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], 343 | pretrained, progress, **kwargs) 344 | -------------------------------------------------------------------------------- /results.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DebeshJha/TransNetR/57567bb27fabe767af81c33001e38284a53555e9/results.jpg -------------------------------------------------------------------------------- /supplementry_C1.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DebeshJha/TransNetR/57567bb27fabe767af81c33001e38284a53555e9/supplementry_C1.jpeg -------------------------------------------------------------------------------- /supplementry_C6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DebeshJha/TransNetR/57567bb27fabe767af81c33001e38284a53555e9/supplementry_C6.jpg -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | 2 | import os, time 3 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" 4 | from operator import add 5 | import numpy as np 6 | from glob import glob 7 | import cv2 8 | from tqdm import tqdm 9 | import imageio 10 | import torch 11 | from model import Model 12 | from utils import create_dir, seeding 13 | from utils import calculate_metrics 14 | from train import load_data 15 | from metrics import hd_dist 16 | 17 | def calculate_hd(y_true, y_pred): 18 | y_true = y_true[0][0].detach().cpu().numpy() 19 | y_pred = y_pred[0][0].detach().cpu().numpy() 20 | 21 | y_pred = y_pred > 0.5 22 | y_pred = y_pred.astype(np.uint8) 23 | 24 | y_true = y_true > 0.5 25 | y_true = y_true.astype(np.uint8) 26 | 27 | return hd_dist(y_true, y_pred) 28 | 29 | 30 | def evaluate(model, save_path, test_x, test_y, size): 31 | metrics_score = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] 32 | time_taken = [] 33 | 34 | for i, (x, y) in tqdm(enumerate(zip(test_x, test_y)), total=len(test_x)): 35 | name = y.split("/")[-1].split(".")[0] 36 | 37 | """ Image """ 38 | image = cv2.imread(x, cv2.IMREAD_COLOR) 39 | image = cv2.resize(image, size) 40 | save_img = image 41 | image = np.transpose(image, (2, 0, 1)) 42 | image = image/255.0 43 | image = np.expand_dims(image, axis=0) 44 | image = image.astype(np.float32) 45 | image = torch.from_numpy(image) 46 | image = image.to(device) 47 | 48 | """ 
Mask """ 49 | mask = cv2.imread(y, cv2.IMREAD_GRAYSCALE) 50 | mask = cv2.resize(mask, size) 51 | save_mask = mask 52 | save_mask = np.expand_dims(save_mask, axis=-1) 53 | save_mask = np.concatenate([save_mask, save_mask, save_mask], axis=2) 54 | mask = np.expand_dims(mask, axis=0) 55 | mask = mask/255.0 56 | mask = np.expand_dims(mask, axis=0) 57 | mask = mask.astype(np.float32) 58 | mask = torch.from_numpy(mask) 59 | mask = mask.to(device) 60 | 61 | with torch.no_grad(): 62 | """ FPS calculation """ 63 | start_time = time.time() 64 | y_pred = model(image) 65 | y_pred = torch.sigmoid(y_pred) 66 | end_time = time.time() - start_time 67 | time_taken.append(end_time) 68 | 69 | """ Evaluation metrics """ 70 | score = calculate_metrics(mask, y_pred) 71 | hd = calculate_hd(mask, y_pred) 72 | score.append(hd) 73 | metrics_score = list(map(add, metrics_score, score)) 74 | 75 | """ Predicted Mask """ 76 | y_pred = y_pred[0].cpu().numpy() 77 | y_pred = np.squeeze(y_pred, axis=0) 78 | y_pred = y_pred > 0.5 79 | y_pred = y_pred.astype(np.int32) 80 | y_pred = y_pred * 255 81 | y_pred = np.array(y_pred, dtype=np.uint8) 82 | y_pred = np.expand_dims(y_pred, axis=-1) 83 | y_pred = np.concatenate([y_pred, y_pred, y_pred], axis=2) 84 | 85 | """ Save the image - mask - pred """ 86 | line = np.ones((size[0], 10, 3)) * 255 87 | cat_images = np.concatenate([save_img, line, save_mask, line, y_pred], axis=1) 88 | cv2.imwrite(f"{save_path}/joint/{name}.jpg", cat_images) 89 | cv2.imwrite(f"{save_path}/mask/{name}.jpg", y_pred) 90 | 91 | jaccard = metrics_score[0]/len(test_x) 92 | f1 = metrics_score[1]/len(test_x) 93 | recall = metrics_score[2]/len(test_x) 94 | precision = metrics_score[3]/len(test_x) 95 | acc = metrics_score[4]/len(test_x) 96 | f2 = metrics_score[5]/len(test_x) 97 | hd = metrics_score[6]/len(test_x) 98 | 99 | print(f"Jaccard: {jaccard:1.4f} - F1: {f1:1.4f} - Recall: {recall:1.4f} - Precision: {precision:1.4f} - Acc: {acc:1.4f} - F2: {f2:1.4f} - HD: {hd:2.4f}") 100 | 101 | mean_time_taken = np.mean(time_taken) 102 | mean_fps = 1/mean_time_taken 103 | print("Mean FPS: ", mean_fps) 104 | 105 | 106 | if __name__ == "__main__": 107 | """ Seeding """ 108 | seeding(42) 109 | 110 | """ Load the checkpoint """ 111 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 112 | model = Model() 113 | model = model.to(device) 114 | checkpoint_path = "files/checkpoint.pth" 115 | model.load_state_dict(torch.load(checkpoint_path, map_location=device)) 116 | model.eval() 117 | 118 | """ Test dataset """ 119 | path = "/../Kvasir-SEG" 120 | (train_x, train_y), (test_x, test_y) = load_data(path) 121 | 122 | save_path = f"results/Kvasir-SEG" 123 | for item in ["mask", "joint"]: 124 | create_dir(f"{save_path}/{item}") 125 | 126 | size = (256, 256) 127 | create_dir(save_path) 128 | evaluate(model, save_path, test_x, test_y, size) 129 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import random 4 | import time 5 | import datetime 6 | import numpy as np 7 | import albumentations as A 8 | import cv2 9 | from PIL import Image 10 | from glob import glob 11 | import torch 12 | import torch.nn as nn 13 | from torch.utils.data import Dataset, DataLoader 14 | from torchvision import transforms 15 | from utils import seeding, create_dir, print_and_save, shuffling, epoch_time, calculate_metrics 16 | from model import Model 17 | from metrics import DiceLoss, DiceBCELoss 
18 | 19 | def load_names(path, file_path): 20 | f = open(file_path, "r") 21 | data = f.read().split("\n")[:-1] 22 | images = [os.path.join(path,"images", name) + ".jpg" for name in data] 23 | masks = [os.path.join(path,"masks", name) + ".jpg" for name in data] 24 | return images, masks 25 | 26 | def load_data(path): 27 | train_names_path = f"{path}/train.txt" 28 | valid_names_path = f"{path}/val.txt" 29 | 30 | train_x, train_y = load_names(path, train_names_path) 31 | valid_x, valid_y = load_names(path, valid_names_path) 32 | 33 | return (train_x, train_y), (valid_x, valid_y) 34 | 35 | class DATASET(Dataset): 36 | def __init__(self, images_path, masks_path, size, transform=None): 37 | super().__init__() 38 | 39 | self.images_path = images_path 40 | self.masks_path = masks_path 41 | self.transform = transform 42 | self.n_samples = len(images_path) 43 | 44 | def __getitem__(self, index): 45 | """ Image """ 46 | image = cv2.imread(self.images_path[index], cv2.IMREAD_COLOR) 47 | mask = cv2.imread(self.masks_path[index], cv2.IMREAD_GRAYSCALE) 48 | 49 | if self.transform is not None: 50 | augmentations = self.transform(image=image, mask=mask) 51 | image = augmentations["image"] 52 | mask = augmentations["mask"] 53 | 54 | image = cv2.resize(image, size) 55 | image = np.transpose(image, (2, 0, 1)) 56 | image = image/255.0 57 | 58 | mask = cv2.resize(mask, size) 59 | mask = np.expand_dims(mask, axis=0) 60 | mask = mask/255.0 61 | 62 | return image, mask 63 | 64 | def __len__(self): 65 | return self.n_samples 66 | 67 | def train(model, loader, optimizer, loss_fn, device): 68 | model.train() 69 | 70 | epoch_loss = 0.0 71 | epoch_jac = 0.0 72 | epoch_f1 = 0.0 73 | epoch_recall = 0.0 74 | epoch_precision = 0.0 75 | 76 | for i, (x, y) in enumerate(loader): 77 | x = x.to(device, dtype=torch.float32) 78 | y = y.to(device, dtype=torch.float32) 79 | 80 | optimizer.zero_grad() 81 | y_pred = model(x) 82 | loss = loss_fn(y_pred, y) 83 | loss.backward() 84 | optimizer.step() 85 | epoch_loss += loss.item() 86 | 87 | """ Calculate the metrics """ 88 | batch_jac = [] 89 | batch_f1 = [] 90 | batch_recall = [] 91 | batch_precision = [] 92 | 93 | for yt, yp in zip(y, y_pred): 94 | score = calculate_metrics(yt, yp) 95 | batch_jac.append(score[0]) 96 | batch_f1.append(score[1]) 97 | batch_recall.append(score[2]) 98 | batch_precision.append(score[3]) 99 | 100 | epoch_jac += np.mean(batch_jac) 101 | epoch_f1 += np.mean(batch_f1) 102 | epoch_recall += np.mean(batch_recall) 103 | epoch_precision += np.mean(batch_precision) 104 | 105 | epoch_loss = epoch_loss/len(loader) 106 | epoch_jac = epoch_jac/len(loader) 107 | epoch_f1 = epoch_f1/len(loader) 108 | epoch_recall = epoch_recall/len(loader) 109 | epoch_precision = epoch_precision/len(loader) 110 | 111 | return epoch_loss, [epoch_jac, epoch_f1, epoch_recall, epoch_precision] 112 | 113 | def evaluate(model, loader, loss_fn, device): 114 | model.eval() 115 | 116 | epoch_loss = 0 117 | epoch_loss = 0.0 118 | epoch_jac = 0.0 119 | epoch_f1 = 0.0 120 | epoch_recall = 0.0 121 | epoch_precision = 0.0 122 | 123 | with torch.no_grad(): 124 | for i, (x, y) in enumerate(loader): 125 | x = x.to(device, dtype=torch.float32) 126 | y = y.to(device, dtype=torch.float32) 127 | 128 | y_pred = model(x) 129 | loss = loss_fn(y_pred, y) 130 | epoch_loss += loss.item() 131 | 132 | """ Calculate the metrics """ 133 | batch_jac = [] 134 | batch_f1 = [] 135 | batch_recall = [] 136 | batch_precision = [] 137 | 138 | for yt, yp in zip(y, y_pred): 139 | score = calculate_metrics(yt, yp) 140 | 
batch_jac.append(score[0]) 141 | batch_f1.append(score[1]) 142 | batch_recall.append(score[2]) 143 | batch_precision.append(score[3]) 144 | 145 | epoch_jac += np.mean(batch_jac) 146 | epoch_f1 += np.mean(batch_f1) 147 | epoch_recall += np.mean(batch_recall) 148 | epoch_precision += np.mean(batch_precision) 149 | 150 | epoch_loss = epoch_loss/len(loader) 151 | epoch_jac = epoch_jac/len(loader) 152 | epoch_f1 = epoch_f1/len(loader) 153 | epoch_recall = epoch_recall/len(loader) 154 | epoch_precision = epoch_precision/len(loader) 155 | 156 | return epoch_loss, [epoch_jac, epoch_f1, epoch_recall, epoch_precision] 157 | 158 | if __name__ == "__main__": 159 | """ Seeding """ 160 | seeding(42) 161 | 162 | """ Directories """ 163 | create_dir("files") 164 | 165 | """ Training logfile """ 166 | train_log_path = "files/train_log.txt" 167 | if os.path.exists(train_log_path): 168 | print("Log file exists") 169 | else: 170 | train_log = open("files/train_log.txt", "w") 171 | train_log.write("\n") 172 | train_log.close() 173 | 174 | """ Record Date & Time """ 175 | datetime_object = str(datetime.datetime.now()) 176 | print_and_save(train_log_path, datetime_object) 177 | print("") 178 | 179 | """ Hyperparameters """ 180 | image_size = 256 181 | size = (image_size, image_size) 182 | batch_size = 8 183 | num_epochs = 500 184 | lr = 1e-4 185 | early_stopping_patience = 50 186 | checkpoint_path = "files/checkpoint.pth" 187 | path = "/../Kvasir-SEG" 188 | 189 | 190 | data_str = f"Image Size: {size}\nBatch Size: {batch_size}\nLR: {lr}\nEpochs: {num_epochs}\n" 191 | data_str += f"Early Stopping Patience: {early_stopping_patience}\n" 192 | print_and_save(train_log_path, data_str) 193 | 194 | """ Dataset """ 195 | (train_x, train_y), (valid_x, valid_y) = load_data(path) 196 | train_x, train_y = shuffling(train_x, train_y) 197 | data_str = f"Dataset Size:\nTrain: {len(train_x)} - Valid: {len(valid_x)}\n" 198 | print_and_save(train_log_path, data_str) 199 | 200 | """ Data augmentation: Transforms """ 201 | transform = A.Compose([ 202 | A.Rotate(limit=35, p=0.3), 203 | A.HorizontalFlip(p=0.3), 204 | A.VerticalFlip(p=0.3), 205 | A.CoarseDropout(p=0.3, max_holes=10, max_height=32, max_width=32) 206 | ]) 207 | 208 | """ Dataset and loader """ 209 | train_dataset = DATASET(train_x, train_y, size, transform=transform) 210 | valid_dataset = DATASET(valid_x, valid_y, size, transform=None) 211 | 212 | train_loader = DataLoader( 213 | dataset=train_dataset, 214 | batch_size=batch_size, 215 | shuffle=True, 216 | num_workers=2 217 | ) 218 | 219 | valid_loader = DataLoader( 220 | dataset=valid_dataset, 221 | batch_size=batch_size, 222 | shuffle=False, 223 | num_workers=2 224 | ) 225 | 226 | """ Model """ 227 | device = torch.device('cuda') 228 | model = Model() 229 | model = model.to(device) 230 | 231 | optimizer = torch.optim.Adam(model.parameters(), lr=lr) 232 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, verbose=True) 233 | loss_fn = DiceBCELoss() 234 | loss_name = "BCE Dice Loss" 235 | data_str = f"Optimizer: Adam\nLoss: {loss_name}\n" 236 | print_and_save(train_log_path, data_str) 237 | 238 | """ Training the model """ 239 | best_valid_metrics = 0.0 240 | early_stopping_count = 0 241 | 242 | for epoch in range(num_epochs): 243 | start_time = time.time() 244 | 245 | train_loss, train_metrics = train(model, train_loader, optimizer, loss_fn, device) 246 | valid_loss, valid_metrics = evaluate(model, valid_loader, loss_fn, device) 247 | scheduler.step(valid_loss) 248 | 249 | if 
valid_metrics[1] > best_valid_metrics: 250 | data_str = f"Valid F1 improved from {best_valid_metrics:2.4f} to {valid_metrics[1]:2.4f}. Saving checkpoint: {checkpoint_path}" 251 | print_and_save(train_log_path, data_str) 252 | 253 | best_valid_metrics = valid_metrics[1] 254 | torch.save(model.state_dict(), checkpoint_path) 255 | early_stopping_count = 0 256 | 257 | elif valid_metrics[1] < best_valid_metrics: 258 | early_stopping_count += 1 259 | 260 | end_time = time.time() 261 | epoch_mins, epoch_secs = epoch_time(start_time, end_time) 262 | 263 | data_str = f"Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s\n" 264 | data_str += f"\tTrain Loss: {train_loss:.4f} - Jaccard: {train_metrics[0]:.4f} - F1: {train_metrics[1]:.4f} - Recall: {train_metrics[2]:.4f} - Precision: {train_metrics[3]:.4f}\n" 265 | data_str += f"\t Val. Loss: {valid_loss:.4f} - Jaccard: {valid_metrics[0]:.4f} - F1: {valid_metrics[1]:.4f} - Recall: {valid_metrics[2]:.4f} - Precision: {valid_metrics[3]:.4f}\n" 266 | print_and_save(train_log_path, data_str) 267 | 268 | if early_stopping_count == early_stopping_patience: 269 | data_str = f"Early stopping: validation loss stops improving from last {early_stopping_patience} continously.\n" 270 | print_and_save(train_log_path, data_str) 271 | break 272 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import random 4 | import numpy as np 5 | import cv2 6 | from tqdm import tqdm 7 | import torch 8 | from sklearn.utils import shuffle 9 | from metrics import precision, recall, F2, dice_score, jac_score, hd_dist 10 | from sklearn.metrics import accuracy_score, confusion_matrix 11 | 12 | """ Seeding the randomness. """ 13 | def seeding(seed): 14 | random.seed(seed) 15 | os.environ["PYTHONHASHSEED"] = str(seed) 16 | np.random.seed(seed) 17 | torch.manual_seed(seed) 18 | torch.cuda.manual_seed(seed) 19 | torch.backends.cudnn.deterministic = True 20 | 21 | """ Create a directory """ 22 | def create_dir(path): 23 | if not os.path.exists(path): 24 | os.makedirs(path) 25 | 26 | """ Shuffle the dataset. """ 27 | def shuffling(x, y): 28 | x, y = shuffle(x, y, random_state=42) 29 | return x, y 30 | 31 | def epoch_time(start_time, end_time): 32 | elapsed_time = end_time - start_time 33 | elapsed_mins = int(elapsed_time / 60) 34 | elapsed_secs = int(elapsed_time - (elapsed_mins * 60)) 35 | return elapsed_mins, elapsed_secs 36 | 37 | def print_and_save(file_path, data_str): 38 | print(data_str) 39 | with open(file_path, "a") as file: 40 | file.write(data_str) 41 | file.write("\n") 42 | 43 | def calculate_metrics(y_true, y_pred): 44 | y_true = y_true.detach().cpu().numpy() 45 | y_pred = y_pred.detach().cpu().numpy() 46 | 47 | y_pred = y_pred > 0.5 48 | y_pred = y_pred.reshape(-1) 49 | y_pred = y_pred.astype(np.uint8) 50 | 51 | y_true = y_true > 0.5 52 | y_true = y_true.reshape(-1) 53 | y_true = y_true.astype(np.uint8) 54 | 55 | ## Score 56 | score_jaccard = jac_score(y_true, y_pred) 57 | score_f1 = dice_score(y_true, y_pred) 58 | score_recall = recall(y_true, y_pred) 59 | score_precision = precision(y_true, y_pred) 60 | score_fbeta = F2(y_true, y_pred) 61 | score_acc = accuracy_score(y_true, y_pred) 62 | 63 | return [score_jaccard, score_f1, score_recall, score_precision, score_acc, score_fbeta] 64 | --------------------------------------------------------------------------------