├── Architecture.jpg ├── IntroTransNetR.png ├── IntroTransNetR1.png ├── Kvasir-seg.png ├── LICENSE ├── OOD.png ├── OOD_K.png ├── README.md ├── Residual-Transformer-Block.jpg ├── Test_PolypGen.jpg ├── bkai_crossdata.jpg ├── metrics.py ├── model.py ├── polypgen-samples.jpg ├── resnet.py ├── results.jpg ├── supplementry_C1.jpeg ├── supplementry_C6.jpg ├── test.py ├── train.py └── utils.py /Architecture.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DebeshJha/TransNetR/57567bb27fabe767af81c33001e38284a53555e9/Architecture.jpg -------------------------------------------------------------------------------- /IntroTransNetR.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DebeshJha/TransNetR/57567bb27fabe767af81c33001e38284a53555e9/IntroTransNetR.png -------------------------------------------------------------------------------- /IntroTransNetR1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DebeshJha/TransNetR/57567bb27fabe767af81c33001e38284a53555e9/IntroTransNetR1.png -------------------------------------------------------------------------------- /Kvasir-seg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DebeshJha/TransNetR/57567bb27fabe767af81c33001e38284a53555e9/Kvasir-seg.png -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Debesh Jha 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /OOD.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DebeshJha/TransNetR/57567bb27fabe767af81c33001e38284a53555e9/OOD.png -------------------------------------------------------------------------------- /OOD_K.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DebeshJha/TransNetR/57567bb27fabe767af81c33001e38284a53555e9/OOD_K.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TransNetR: Transformer-based Residual Network for Polyp Segmentation with Multi-Center Out-of-Distribution Testing (MIDL 2023) 2 | 3 | Paper Link: https://arxiv.org/pdf/2303.07428.pdf 4 | 5 | TransNetR is an encoder-decoder network which can be used for efficient biomedical image segmentation on both in-distribution and out-of-distribution datasets. 6 | 7 | ## In-distribution and Out-of-distribution dataset 8 |
9 |
10 |
11 | Figure 1: Illustration of different scenarios expected to arise in real-world settings. The proposed work conducts both in-distribution and out-of-distribution validation. C1 to C6 represent data from the different centers present in the PolypGen dataset.
12 |
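A minimal sketch of this out-of-distribution evaluation loop, built from the helpers in `metrics.py` and the network in `model.py`. The checkpoint path and the dataloader are hypothetical stand-ins (random tensors here); the repository's actual pipeline lives in `train.py` and `test.py`:

```python
import torch

from metrics import dice_score, jac_score
from model import Model

model = Model()
# model.load_state_dict(torch.load("checkpoint.pth", map_location="cpu"))  # hypothetical path
model.eval()

# Stand-in for a real out-of-distribution dataloader (e.g., one PolypGen center).
ood_loader = [(torch.randn(1, 3, 256, 256), torch.randint(0, 2, (1, 1, 256, 256)).float())]

dice, iou = [], []
with torch.no_grad():
    for image, mask in ood_loader:
        pred = (torch.sigmoid(model(image)) > 0.5).float()  # binarize the logits
        dice.append(dice_score(mask, pred).item())
        iou.append(jac_score(mask, pred).item())
print(f"Dice: {sum(dice)/len(dice):.4f}  IoU: {sum(iou)/len(iou):.4f}")
```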
16 |
17 |
18 | Figure 2: Block diagram of TransNetR along with the Residual Transformer block
19 |
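A quick way to sanity-check this architecture is to instantiate the network from `model.py` and run a dummy forward pass; this mirrors the `__main__` block at the bottom of that file. Note that the `dim` arguments of the residual transformer blocks (64 and 256) equal the patch counts at a 256×256 input, so other input resolutions would need different `dim` values:

```python
import torch
from model import Model

model = Model()                    # ResNet50 encoder + residual transformer decoder
x = torch.randn((4, 3, 256, 256))  # batch of four 256x256 RGB images
y = model(x)
print(y.shape)                     # torch.Size([4, 1, 256, 256]) -- one-channel mask logits
```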
24 |
25 |
26 | Figure 3: Qualitative example showing polyp segmentation on Kvasir-SEG
27 |
31 |
32 | Table 1: Quantitative results on the Kvasir-SEG test dataset. The parameters are in millions and FLOPs are in GMac.
33 |
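The parameter count in Table 1 can be reproduced directly from the PyTorch model. GMac numbers are usually obtained with an external profiler such as `ptflops`; that tool is an assumption about tooling, not part of this repository:

```python
from model import Model

model = Model()
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Parameters: {n_params / 1e6:.2f} M")  # parameters in millions, as in Table 1

# GMac (multiply-accumulate operations) via ptflops, if installed:
# from ptflops import get_model_complexity_info
# macs, params = get_model_complexity_info(model, (3, 256, 256), as_strings=True)
# print(macs, params)
```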
36 |
37 |
40 |
41 |
44 |
45 |
49 |
50 |
51 | Figure 4: Cross-data results when models are trained on Kvasir-SEG and tested on BKAI-IGH.
52 |
56 |
57 |
58 | Figure 5: Center-wise example images from the PolypGen dataset. Here, the variability among the data from different centers can be observed. There is a difference in image resolutions and sizes, shapes, colors, textures and appearances, and collection protocols.
59 |
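Because image resolution differs across centers, frames are typically resized to a common size before inference. A minimal preprocessing sketch follows; the fixed 256×256 size matches the dummy input used in `model.py`, and the use of OpenCV here is an assumption, not a requirement of this repository:

```python
import cv2
import numpy as np
import torch

def load_image(path, size=(256, 256)):
    image = cv2.imread(path, cv2.IMREAD_COLOR)   # HxWx3, BGR, uint8
    image = cv2.resize(image, size)              # unify resolution across centers
    image = image.astype(np.float32) / 255.0     # scale to [0, 1]
    image = np.transpose(image, (2, 0, 1))       # HWC -> CHW
    return torch.from_numpy(image).unsqueeze(0)  # 1x3xHxW tensor
```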
64 |
65 |
66 | Figure 6: Qualitative result when TransNetR is trained on Kvasir-SEG and tested on (a) PolypGen (center 6 (C6))
67 |
71 |
72 |
73 | Figure 7: Qualitative result when TransNetR is trained on Kvasir-SEG and tested on PolypGen (center 1 (C1))
74 |
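The scores behind these comparisons come from the helpers in `metrics.py`. A minimal sketch of calling them on a binarized prediction and its ground-truth mask, with random tensors standing in for real data:

```python
import torch
from metrics import F2, dice_score, jac_score, precision, recall

y_true = (torch.rand(1, 1, 256, 256) > 0.5).float()  # stand-in ground truth
y_pred = (torch.rand(1, 1, 256, 256) > 0.5).float()  # stand-in prediction

print("Precision:", precision(y_true, y_pred).item())
print("Recall:   ", recall(y_true, y_pred).item())
print("F2:       ", F2(y_true, y_pred).item())
print("Dice:     ", dice_score(y_true, y_pred).item())
print("IoU:      ", jac_score(y_true, y_pred).item())
```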
79 | @INPROCEEDINGS{JhaTrans2023, 80 | author={D.{Jha} and N.{Tomar} and V.{Sharma} and U.{Bagci}}, 81 | booktitle={Proceedings of the Medical Imaging with Deep Learning}, 82 | title={TransNetR: Transformer-based Residual Network for Polyp Segmentation with Multi-Center Out-of-Distribution Testing}, 83 | year={2023}} 84 |85 | 86 | ## Contact 87 | Please contact debesh.jha@northwestern.edu and nikhilroxtomar@gmail.com for any further questions. 88 | -------------------------------------------------------------------------------- /Residual-Transformer-Block.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DebeshJha/TransNetR/57567bb27fabe767af81c33001e38284a53555e9/Residual-Transformer-Block.jpg -------------------------------------------------------------------------------- /Test_PolypGen.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DebeshJha/TransNetR/57567bb27fabe767af81c33001e38284a53555e9/Test_PolypGen.jpg -------------------------------------------------------------------------------- /bkai_crossdata.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DebeshJha/TransNetR/57567bb27fabe767af81c33001e38284a53555e9/bkai_crossdata.jpg -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from scipy.spatial.distance import directed_hausdorff 5 | 6 | """ Loss Functions -------------------------------------- """ 7 | class DiceLoss(nn.Module): 8 | def __init__(self, weight=None, size_average=True): 9 | super(DiceLoss, self).__init__() 10 | 11 | def forward(self, inputs, targets, smooth=1): 12 | inputs = torch.sigmoid(inputs) 13 | 14 | inputs = inputs.view(-1) 15 | targets = targets.view(-1) 16 | 17 | intersection = (inputs * targets).sum() 18 | dice = (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth) 19 | 20 | return 1 - dice 21 | 22 | class DiceBCELoss(nn.Module): 23 | def __init__(self, weight=None, size_average=True): 24 | super(DiceBCELoss, self).__init__() 25 | 26 | def forward(self, inputs, targets, smooth=1): 27 | inputs = torch.sigmoid(inputs) 28 | 29 | inputs = inputs.view(-1) 30 | targets = targets.view(-1) 31 | 32 | intersection = (inputs * targets).sum() 33 | dice_loss = 1 - (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth) 34 | BCE = F.binary_cross_entropy(inputs, targets, reduction='mean') 35 | Dice_BCE = BCE + dice_loss 36 | 37 | return Dice_BCE 38 | 39 | """ Metrics ------------------------------------------ """ 40 | def precision(y_true, y_pred): 41 | intersection = (y_true * y_pred).sum() 42 | return (intersection + 1e-15) / (y_pred.sum() + 1e-15) 43 | 44 | def recall(y_true, y_pred): 45 | intersection = (y_true * y_pred).sum() 46 | return (intersection + 1e-15) / (y_true.sum() + 1e-15) 47 | 48 | def F2(y_true, y_pred, beta=2): 49 | p = precision(y_true,y_pred) 50 | r = recall(y_true, y_pred) 51 | return (1+beta**2.) 
*(p*r) / float(beta**2*p + r + 1e-15) 52 | 53 | def dice_score(y_true, y_pred): 54 | return (2 * (y_true * y_pred).sum() + 1e-15) / (y_true.sum() + y_pred.sum() + 1e-15) 55 | 56 | def jac_score(y_true, y_pred): 57 | intersection = (y_true * y_pred).sum() 58 | union = y_true.sum() + y_pred.sum() - intersection 59 | return (intersection + 1e-15) / (union + 1e-15) 60 | 61 | ## https://www.kaggle.com/competitions/uw-madison-gi-tract-image-segmentation/discussion/319452 62 | def hd_dist(preds, targets): 63 | haussdorf_dist = directed_hausdorff(preds, targets)[0] 64 | return haussdorf_dist 65 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | from resnet import resnet50, resnet34 5 | 6 | class Conv2D(nn.Module): 7 | def __init__(self, in_c, out_c, kernel_size=3, padding=1, stride=1, dilation=1, bias=True, act=True): 8 | super().__init__() 9 | 10 | self.act = act 11 | self.conv = nn.Sequential( 12 | nn.Conv2d(in_c, out_c, kernel_size, padding=padding, dilation=dilation, stride=stride, bias=bias), 13 | nn.BatchNorm2d(out_c) 14 | ) 15 | self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=True) 16 | 17 | def forward(self, x): 18 | x = self.conv(x) 19 | if self.act == True: 20 | x = self.relu(x) 21 | return x 22 | 23 | class residual_block(nn.Module): 24 | def __init__(self, in_c, out_c): 25 | super().__init__() 26 | 27 | self.conv = nn.Sequential( 28 | nn.Conv2d(in_c, out_c, kernel_size=3, padding=1), 29 | nn.BatchNorm2d(out_c), 30 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 31 | nn.Conv2d(out_c, out_c, kernel_size=3, padding=1), 32 | nn.BatchNorm2d(out_c) 33 | ) 34 | self.shortcut = nn.Sequential( 35 | nn.Conv2d(in_c, out_c, kernel_size=1, padding=0), 36 | nn.BatchNorm2d(out_c) 37 | ) 38 | self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=True) 39 | 40 | def forward(self, inputs): 41 | x = self.conv(inputs) 42 | s = self.shortcut(inputs) 43 | return self.relu(x + s) 44 | 45 | class residual_transformer_block(nn.Module): 46 | def __init__(self, in_c, out_c, patch_size=4, num_heads=4, num_layers=2, dim=None): 47 | super().__init__() 48 | 49 | self.ps = patch_size 50 | self.c1 = Conv2D(in_c, out_c) 51 | 52 | encoder_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=num_heads) 53 | self.te = nn.TransformerEncoder(encoder_layer, num_layers=num_layers) 54 | 55 | self.c2 = Conv2D(out_c, out_c, kernel_size=1, padding=0, act=False) 56 | self.c3 = Conv2D(in_c, out_c, kernel_size=1, padding=0, act=False) 57 | self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=True) 58 | self.r1 = residual_block(out_c, out_c) 59 | 60 | def forward(self, inputs): 61 | x = self.c1(inputs) 62 | 63 | b, c, h, w = x.shape 64 | num_patches = (h*w)//(self.ps**2) 65 | x = torch.reshape(x, (b, (self.ps**2)*c, num_patches)) 66 | x = self.te(x) 67 | x = torch.reshape(x, (b, c, h, w)) 68 | 69 | x = self.c2(x) 70 | s = self.c3(inputs) 71 | x = self.relu(x + s) 72 | x = self.r1(x) 73 | return x 74 | 75 | class Model(nn.Module): 76 | def __init__(self): 77 | super().__init__() 78 | 79 | """ Encoder """ 80 | backbone = resnet50() 81 | self.layer0 = nn.Sequential(backbone.conv1, backbone.bn1, backbone.relu) 82 | self.layer1 = nn.Sequential(backbone.maxpool, backbone.layer1) 83 | self.layer2 = backbone.layer2 84 | self.layer3 = backbone.layer3 85 | self.layer4 = backbone.layer4 86 | 87 | self.e1 = Conv2D(64, 64, kernel_size=1, padding=0) 88 | self.e2 = 
Conv2D(256, 64, kernel_size=1, padding=0) 89 | self.e3 = Conv2D(512, 64, kernel_size=1, padding=0) 90 | self.e4 = Conv2D(1024, 64, kernel_size=1, padding=0) 91 | 92 | 93 | """ Decoder """ 94 | self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True) 95 | self.r1 = residual_transformer_block(64+64, 64, dim=64) 96 | self.r2 = residual_transformer_block(64+64, 64, dim=256) 97 | self.r3 = residual_block(64+64, 64) 98 | 99 | """ Classifier """ 100 | self.outputs = nn.Conv2d(64, 1, kernel_size=1, padding=0) 101 | 102 | def forward(self, inputs): 103 | """ Encoder """ 104 | x0 = inputs 105 | x1 = self.layer0(x0) ## [-1, 64, h/2, w/2] 106 | x2 = self.layer1(x1) ## [-1, 256, h/4, w/4] 107 | x3 = self.layer2(x2) ## [-1, 512, h/8, w/8] 108 | x4 = self.layer3(x3) ## [-1, 1024, h/16, w/16] 109 | # print(x1.shape, x2.shape, x3.shape, x4.shape) 110 | 111 | e1 = self.e1(x1) 112 | e2 = self.e2(x2) 113 | e3 = self.e3(x3) 114 | e4 = self.e4(x4) 115 | 116 | """ Decoder """ 117 | x = self.up(e4) 118 | x = torch.cat([x, e3], axis=1) 119 | x = self.r1(x) 120 | 121 | x = self.up(x) 122 | x = torch.cat([x, e2], axis=1) 123 | x = self.r2(x) 124 | 125 | x = self.up(x) 126 | x = torch.cat([x, e1], axis=1) 127 | x = self.r3(x) 128 | 129 | x = self.up(x) 130 | 131 | """ Classifier """ 132 | outputs = self.outputs(x) 133 | return outputs 134 | 135 | if __name__ == "__main__": 136 | x = torch.randn((4, 3, 256, 256)) 137 | model = Model() 138 | y = model(x) 139 | print(y.shape) 140 | -------------------------------------------------------------------------------- /polypgen-samples.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DebeshJha/TransNetR/57567bb27fabe767af81c33001e38284a53555e9/polypgen-samples.jpg -------------------------------------------------------------------------------- /resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.hub import load_state_dict_from_url 4 | 5 | 6 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 7 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 8 | 'wide_resnet50_2', 'wide_resnet101_2'] 9 | 10 | 11 | model_urls = { 12 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 13 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 14 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 15 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 16 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 17 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 18 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 19 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 20 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 21 | } 22 | 23 | 24 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 25 | """3x3 convolution with padding""" 26 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 27 | padding=dilation, groups=groups, bias=False, dilation=dilation) 28 | 29 | 30 | def conv1x1(in_planes, out_planes, stride=1): 31 | """1x1 convolution""" 32 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 33 | 34 | 35 | class BasicBlock(nn.Module): 36 | expansion = 1 37 
| 38 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 39 | base_width=64, dilation=1, norm_layer=None): 40 | super(BasicBlock, self).__init__() 41 | if norm_layer is None: 42 | norm_layer = nn.BatchNorm2d 43 | if groups != 1 or base_width != 64: 44 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 45 | if dilation > 1: 46 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 47 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 48 | self.conv1 = conv3x3(inplanes, planes, stride) 49 | self.bn1 = norm_layer(planes) 50 | self.relu = nn.ReLU(inplace=True) 51 | self.conv2 = conv3x3(planes, planes) 52 | self.bn2 = norm_layer(planes) 53 | self.downsample = downsample 54 | self.stride = stride 55 | 56 | def forward(self, x): 57 | identity = x 58 | 59 | out = self.conv1(x) 60 | out = self.bn1(out) 61 | out = self.relu(out) 62 | 63 | out = self.conv2(out) 64 | out = self.bn2(out) 65 | 66 | if self.downsample is not None: 67 | identity = self.downsample(x) 68 | 69 | out += identity 70 | out = self.relu(out) 71 | 72 | return out 73 | 74 | 75 | class Bottleneck(nn.Module): 76 | expansion = 4 77 | 78 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 79 | base_width=64, dilation=1, norm_layer=None): 80 | super(Bottleneck, self).__init__() 81 | if norm_layer is None: 82 | norm_layer = nn.BatchNorm2d 83 | width = int(planes * (base_width / 64.)) * groups 84 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 85 | self.conv1 = conv1x1(inplanes, width) 86 | self.bn1 = norm_layer(width) 87 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 88 | self.bn2 = norm_layer(width) 89 | self.conv3 = conv1x1(width, planes * self.expansion) 90 | self.bn3 = norm_layer(planes * self.expansion) 91 | self.relu = nn.ReLU(inplace=True) 92 | self.downsample = downsample 93 | self.stride = stride 94 | 95 | def forward(self, x): 96 | identity = x 97 | 98 | out = self.conv1(x) 99 | out = self.bn1(out) 100 | out = self.relu(out) 101 | 102 | out = self.conv2(out) 103 | out = self.bn2(out) 104 | out = self.relu(out) 105 | 106 | out = self.conv3(out) 107 | out = self.bn3(out) 108 | 109 | if self.downsample is not None: 110 | identity = self.downsample(x) 111 | 112 | out += identity 113 | out = self.relu(out) 114 | 115 | return out 116 | 117 | 118 | class ResNet(nn.Module): 119 | 120 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 121 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 122 | norm_layer=None): 123 | super(ResNet, self).__init__() 124 | if norm_layer is None: 125 | norm_layer = nn.BatchNorm2d 126 | self._norm_layer = norm_layer 127 | 128 | self.inplanes = 64 129 | self.dilation = 1 130 | if replace_stride_with_dilation is None: 131 | # each element in the tuple indicates if we should replace 132 | # the 2x2 stride with a dilated convolution instead 133 | replace_stride_with_dilation = [False, False, False] 134 | if len(replace_stride_with_dilation) != 3: 135 | raise ValueError("replace_stride_with_dilation should be None " 136 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 137 | self.groups = groups 138 | self.base_width = width_per_group 139 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 140 | bias=False) 141 | self.bn1 = norm_layer(self.inplanes) 142 | self.relu = nn.ReLU(inplace=True) 143 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, 
padding=1) 144 | self.layer1 = self._make_layer(block, 64, layers[0]) 145 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 146 | dilate=replace_stride_with_dilation[0]) 147 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 148 | dilate=replace_stride_with_dilation[1]) 149 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 150 | dilate=replace_stride_with_dilation[2]) 151 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 152 | self.fc = nn.Linear(512 * block.expansion, num_classes) 153 | 154 | for m in self.modules(): 155 | if isinstance(m, nn.Conv2d): 156 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 157 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 158 | nn.init.constant_(m.weight, 1) 159 | nn.init.constant_(m.bias, 0) 160 | 161 | # Zero-initialize the last BN in each residual branch, 162 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 163 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 164 | if zero_init_residual: 165 | for m in self.modules(): 166 | if isinstance(m, Bottleneck): 167 | nn.init.constant_(m.bn3.weight, 0) 168 | elif isinstance(m, BasicBlock): 169 | nn.init.constant_(m.bn2.weight, 0) 170 | 171 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 172 | norm_layer = self._norm_layer 173 | downsample = None 174 | previous_dilation = self.dilation 175 | if dilate: 176 | self.dilation *= stride 177 | stride = 1 178 | if stride != 1 or self.inplanes != planes * block.expansion: 179 | downsample = nn.Sequential( 180 | conv1x1(self.inplanes, planes * block.expansion, stride), 181 | norm_layer(planes * block.expansion), 182 | ) 183 | 184 | layers = [] 185 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 186 | self.base_width, previous_dilation, norm_layer)) 187 | self.inplanes = planes * block.expansion 188 | for _ in range(1, blocks): 189 | layers.append(block(self.inplanes, planes, groups=self.groups, 190 | base_width=self.base_width, dilation=self.dilation, 191 | norm_layer=norm_layer)) 192 | 193 | return nn.Sequential(*layers) 194 | 195 | def forward(self, x): 196 | x = self.conv1(x) 197 | x = self.bn1(x) 198 | x = self.relu(x) 199 | x = self.maxpool(x) 200 | 201 | x = self.layer1(x) 202 | x = self.layer2(x) 203 | x = self.layer3(x) 204 | x = self.layer4(x) 205 | 206 | x = self.avgpool(x) 207 | x = torch.flatten(x, 1) 208 | x = self.fc(x) 209 | 210 | return x 211 | 212 | 213 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 214 | model = ResNet(block, layers, **kwargs) 215 | if pretrained: 216 | state_dict = load_state_dict_from_url(model_urls[arch], 217 | progress=progress) 218 | model.load_state_dict(state_dict) 219 | return model 220 | 221 | 222 | def resnet18(pretrained=False, progress=True, **kwargs): 223 | r"""ResNet-18 model from 224 | `"Deep Residual Learning for Image Recognition"