├── README.rst
└── network
    ├── deeplabv3_3d.py
    ├── aspp.py
    └── resnet.py

--------------------------------------------------------------------------------
/README.rst:
--------------------------------------------------------------------------------
======================
Pytorch Deeplabv3+ 3D
======================
This is a PyTorch implementation of 3D Deeplabv3+.

- Reference 1. https://github.com/jfzhang95/pytorch-deeplab-xception
- Reference 2. https://github.com/fregu856/deeplabv3

---------------
How to use
---------------
.. code-block:: python

    from network.deeplabv3_3d import DeepLabV3_3D

    num_classes = 10            # Number of classes (= number of output channels)
    input_channels = 3          # Number of input channels
    resnet = 'resnet18_os16'    # Base resnet architecture: 'resnet18_os16', 'resnet34_os16',
                                # 'resnet50_os16', 'resnet101_os16', 'resnet152_os16',
                                # 'resnet18_os8' or 'resnet34_os8'
    last_activation = 'softmax' # 'softmax', 'sigmoid' or None

    model = DeepLabV3_3D(num_classes=num_classes, input_channels=input_channels, resnet=resnet, last_activation=last_activation)
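
The model expects a 5D input of shape ``(batch, input_channels, D, H, W)`` and returns a tensor of shape ``(batch, num_classes, D, H, W)``. A minimal forward pass looks like this (the batch size and volume size below are arbitrary assumptions):

.. code-block:: python

    import torch

    x = torch.randn(1, input_channels, 64, 64, 64)  # dummy (N, C, D, H, W) volume

    model.eval()  # BatchNorm layers need eval() (or batch size > 1) for a single-sample forward
    with torch.no_grad():
        y = model(x)

    print(y.shape)  # torch.Size([1, 10, 64, 64, 64])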
--------------------------------------------------------------------------------
/network/deeplabv3_3d.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

from network.resnet import ResNet18_OS16, ResNet34_OS16, ResNet50_OS16, ResNet101_OS16, ResNet152_OS16, ResNet18_OS8, ResNet34_OS8
from network.aspp import ASPP, ASPP_Bottleneck

class DeepLabV3_3D(nn.Module):
    def __init__(self, num_classes, input_channels, resnet, last_activation = None):
        super(DeepLabV3_3D, self).__init__()
        self.num_classes = num_classes
        self.last_activation = last_activation

        if resnet.lower() == 'resnet18_os16':
            self.resnet = ResNet18_OS16(input_channels)

        elif resnet.lower() == 'resnet34_os16':
            self.resnet = ResNet34_OS16(input_channels)

        elif resnet.lower() == 'resnet50_os16':
            self.resnet = ResNet50_OS16(input_channels)

        elif resnet.lower() == 'resnet101_os16':
            self.resnet = ResNet101_OS16(input_channels)

        elif resnet.lower() == 'resnet152_os16':
            self.resnet = ResNet152_OS16(input_channels)

        elif resnet.lower() == 'resnet18_os8':
            self.resnet = ResNet18_OS8(input_channels)

        elif resnet.lower() == 'resnet34_os8':
            self.resnet = ResNet34_OS8(input_channels)

        else:
            raise ValueError('Unknown resnet architecture: {}'.format(resnet))

        if resnet.lower() in ['resnet50_os16', 'resnet101_os16', 'resnet152_os16']:
            self.aspp = ASPP_Bottleneck(num_classes=self.num_classes)
        else:
            self.aspp = ASPP(num_classes=self.num_classes)

    def forward(self, x):
        # x has shape (batch, input_channels, D, H, W)
        d = x.size()[2]
        h = x.size()[3]
        w = x.size()[4]

        feature_map = self.resnet(x)

        output = self.aspp(feature_map)

        # Upsample the coarse prediction back to the input resolution.
        output = F.interpolate(output, size=(d, h, w), mode='trilinear', align_corners=True)

        # last_activation defaults to None, so guard before calling .lower() on it.
        if self.last_activation is not None:
            if self.last_activation.lower() == 'sigmoid':
                output = torch.sigmoid(output)

            elif self.last_activation.lower() == 'softmax':
                output = F.softmax(output, dim=1)   # softmax over the class channel

        return output
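
# Minimal smoke test (illustrative; the backbone choice and volume size are
# assumptions). Run from the repository root, e.g. `python -m network.deeplabv3_3d`.
if __name__ == '__main__':
    model = DeepLabV3_3D(num_classes=10, input_channels=3, resnet='resnet18_os16', last_activation='softmax')
    model.eval()   # BatchNorm needs eval() (or batch size > 1) for a single-sample forward
    with torch.no_grad():
        out = model(torch.randn(1, 3, 64, 64, 64))
    print(out.shape)   # expected: torch.Size([1, 10, 64, 64, 64])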
--------------------------------------------------------------------------------
/network/aspp.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

class ASPP(nn.Module):
    def __init__(self, num_classes):
        super(ASPP, self).__init__()

        # Input feature maps have 512 channels (output of the BasicBlock backbones).
        self.conv_1x1_1 = nn.Conv3d(512, 256, kernel_size=1)
        self.bn_conv_1x1_1 = nn.BatchNorm3d(256)

        self.conv_3x3_1 = nn.Conv3d(512, 256, kernel_size=3, stride=1, padding=6, dilation=6)
        self.bn_conv_3x3_1 = nn.BatchNorm3d(256)

        self.conv_3x3_2 = nn.Conv3d(512, 256, kernel_size=3, stride=1, padding=12, dilation=12)
        self.bn_conv_3x3_2 = nn.BatchNorm3d(256)

        self.conv_3x3_3 = nn.Conv3d(512, 256, kernel_size=3, stride=1, padding=18, dilation=18)
        self.bn_conv_3x3_3 = nn.BatchNorm3d(256)

        self.avg_pool = nn.AdaptiveAvgPool3d(1)

        self.conv_1x1_2 = nn.Conv3d(512, 256, kernel_size=1)
        self.bn_conv_1x1_2 = nn.BatchNorm3d(256)

        self.conv_1x1_3 = nn.Conv3d(1280, 256, kernel_size=1)   # (1280 = 5*256)
        self.bn_conv_1x1_3 = nn.BatchNorm3d(256)

        self.conv_1x1_4 = nn.Conv3d(256, num_classes, kernel_size=1)

    def forward(self, feature_map):
        # feature_map has shape (batch, 512, D/16, H/16, W/16), or D/8 etc. for the OS8 backbones
        feature_map_d = feature_map.size()[2]
        feature_map_h = feature_map.size()[3]
        feature_map_w = feature_map.size()[4]

        out_1x1 = F.relu(self.bn_conv_1x1_1(self.conv_1x1_1(feature_map)))
        out_3x3_1 = F.relu(self.bn_conv_3x3_1(self.conv_3x3_1(feature_map)))
        out_3x3_2 = F.relu(self.bn_conv_3x3_2(self.conv_3x3_2(feature_map)))
        out_3x3_3 = F.relu(self.bn_conv_3x3_3(self.conv_3x3_3(feature_map)))

        # Image-level branch: global average pool, 1x1 conv, then upsample back.
        out_img = self.avg_pool(feature_map)
        out_img = F.relu(self.bn_conv_1x1_2(self.conv_1x1_2(out_img)))
        out_img = F.interpolate(out_img, size=(feature_map_d, feature_map_h, feature_map_w), mode='trilinear', align_corners=True)

        out = torch.cat([out_1x1, out_3x3_1, out_3x3_2, out_3x3_3, out_img], 1)
        out = F.relu(self.bn_conv_1x1_3(self.conv_1x1_3(out)))
        out = self.conv_1x1_4(out)

        return out

class ASPP_Bottleneck(nn.Module):
    def __init__(self, num_classes):
        super(ASPP_Bottleneck, self).__init__()

        # Input feature maps have 4*512 = 2048 channels (output of the Bottleneck backbones).
        self.conv_1x1_1 = nn.Conv3d(4*512, 256, kernel_size=1)
        self.bn_conv_1x1_1 = nn.BatchNorm3d(256)

        self.conv_3x3_1 = nn.Conv3d(4*512, 256, kernel_size=3, stride=1, padding=6, dilation=6)
        self.bn_conv_3x3_1 = nn.BatchNorm3d(256)

        self.conv_3x3_2 = nn.Conv3d(4*512, 256, kernel_size=3, stride=1, padding=12, dilation=12)
        self.bn_conv_3x3_2 = nn.BatchNorm3d(256)

        self.conv_3x3_3 = nn.Conv3d(4*512, 256, kernel_size=3, stride=1, padding=18, dilation=18)
        self.bn_conv_3x3_3 = nn.BatchNorm3d(256)

        self.avg_pool = nn.AdaptiveAvgPool3d(1)

        self.conv_1x1_2 = nn.Conv3d(4*512, 256, kernel_size=1)
        self.bn_conv_1x1_2 = nn.BatchNorm3d(256)

        self.conv_1x1_3 = nn.Conv3d(1280, 256, kernel_size=1)   # (1280 = 5*256)
        self.bn_conv_1x1_3 = nn.BatchNorm3d(256)

        self.conv_1x1_4 = nn.Conv3d(256, num_classes, kernel_size=1)

    def forward(self, feature_map):
        # feature_map has shape (batch, 4*512, D/16, H/16, W/16)
        feature_map_d = feature_map.size()[2]
        feature_map_h = feature_map.size()[3]
        feature_map_w = feature_map.size()[4]

        out_1x1 = F.relu(self.bn_conv_1x1_1(self.conv_1x1_1(feature_map)))
        out_3x3_1 = F.relu(self.bn_conv_3x3_1(self.conv_3x3_1(feature_map)))
        out_3x3_2 = F.relu(self.bn_conv_3x3_2(self.conv_3x3_2(feature_map)))
        out_3x3_3 = F.relu(self.bn_conv_3x3_3(self.conv_3x3_3(feature_map)))

        out_img = self.avg_pool(feature_map)
        out_img = F.relu(self.bn_conv_1x1_2(self.conv_1x1_2(out_img)))
        out_img = F.interpolate(out_img, size=(feature_map_d, feature_map_h, feature_map_w), mode='trilinear', align_corners=True)

        out = torch.cat([out_1x1, out_3x3_1, out_3x3_2, out_3x3_3, out_img], 1)
        out = F.relu(self.bn_conv_1x1_3(self.conv_1x1_3(out)))
        out = self.conv_1x1_4(out)

        return out
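
# Minimal shape check (illustrative; the feature-map size is an assumption):
if __name__ == '__main__':
    aspp = ASPP(num_classes=10)
    aspp.eval()   # BatchNorm needs eval() (or batch size > 1) for a single-sample forward
    with torch.no_grad():
        out = aspp(torch.randn(1, 512, 4, 4, 4))
    print(out.shape)   # expected: torch.Size([1, 10, 4, 4, 4])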
--------------------------------------------------------------------------------
/network/resnet.py:
--------------------------------------------------------------------------------
# NOTE! OS: output stride, the ratio of input image resolution to final output resolution

import torch
import torch.nn as nn
import torch.nn.functional as F

# Reference (1) : https://pytorch.org/docs/stable/_modules/torchvision/models/resnet.html#resnet18
# Reference (2) : https://github.com/fregu856/deeplabv3


# ----------------------------------------------- 3D Resnet -----------------------------------------------
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
           'resnet152']


def conv3x3(in_planes, out_planes, stride=1):
    """3x3x3 convolution with padding"""
    return nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1x1 convolution"""
    return nn.Conv3d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm3d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm3d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = conv1x1(inplanes, planes)
        self.bn1 = nn.BatchNorm3d(planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn2 = nn.BatchNorm3d(planes)
        self.conv3 = conv1x1(planes, planes * self.expansion)
        self.bn3 = nn.BatchNorm3d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self, input_channels, block, layers, num_classes=1, zero_init_residual=False):
        super(ResNet, self).__init__()
        self.inplanes = 64
        self.conv1 = nn.Conv3d(input_channels, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm3d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool3d(1)   # pool each channel to a single voxel (3D output size)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm3d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm3d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


def resnet18(input_channels, **kwargs):
    model = ResNet(input_channels, BasicBlock, [2, 2, 2, 2], **kwargs)
    return model


def resnet34(input_channels, **kwargs):
    model = ResNet(input_channels, BasicBlock, [3, 4, 6, 3], **kwargs)
    return model


def resnet50(input_channels, **kwargs):
    model = ResNet(input_channels, Bottleneck, [3, 4, 6, 3], **kwargs)
    return model


def resnet101(input_channels, **kwargs):
    model = ResNet(input_channels, Bottleneck, [3, 4, 23, 3], **kwargs)
    return model


def resnet152(input_channels, **kwargs):
    model = ResNet(input_channels, Bottleneck, [3, 8, 36, 3], **kwargs)
    return model
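
# The classifier above reaches a 1/32-resolution feature map through five
# stride-2 stages (conv1, maxpool, layer2, layer3, layer4). The DeepLab
# backbones below keep the early stages and replace the last one(s) with
# dilated, stride-1 layers: output stride 16 when layer4 is replaced, 8 when
# layer3 and layer4 are both replaced.
#
# Illustrative use of the plain classifier (shapes below are assumptions):
#
#   >>> net = resnet18(input_channels=3)               # num_classes defaults to 1
#   >>> net(torch.randn(2, 3, 32, 32, 32)).shape
#   torch.Size([2, 1])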
# -------------------------------------- Resnet for Deeplab --------------------------------------
def make_layer(block, in_channels, channels, num_blocks, stride=1, dilation=1):
    strides = [stride] + [1]*(num_blocks - 1)

    blocks = []
    for stride in strides:
        blocks.append(block(in_channels=in_channels, channels=channels, stride=stride, dilation=dilation))
        in_channels = block.expansion*channels

    layer = nn.Sequential(*blocks)   # (*blocks: call with unpacked list entries as arguments)

    return layer

# NOTE: the BasicBlock and Bottleneck classes below shadow the torchvision-style
# ones above, so the resnet18/.../resnet152 factories actually build their layers
# from these DeepLab variants. ResNet._make_layer passes its `downsample` module
# as the fourth positional argument, which lands in `dilation` here; the
# isinstance() guard in each __init__ absorbs it, and each block builds its own
# shortcut instead.
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_channels, channels, stride=1, dilation=1):
        super(BasicBlock, self).__init__()

        out_channels = self.expansion*channels

        if not isinstance(dilation, int):   # absorbs ResNet._make_layer's `downsample` argument
            dilation = 1

        self.conv1 = nn.Conv3d(in_channels, channels, kernel_size=3, stride=stride, padding=dilation, dilation=dilation, bias=False)
        self.bn1 = nn.BatchNorm3d(channels)

        self.conv2 = nn.Conv3d(channels, channels, kernel_size=3, stride=1, padding=dilation, dilation=dilation, bias=False)
        self.bn2 = nn.BatchNorm3d(channels)

        if (stride != 1) or (in_channels != out_channels):
            conv = nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)
            bn = nn.BatchNorm3d(out_channels)
            self.downsample = nn.Sequential(conv, bn)
        else:
            self.downsample = nn.Sequential()

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        out = out + self.downsample(x)

        out = F.relu(out)

        return out

class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_channels, channels, stride=1, dilation=1):
        super(Bottleneck, self).__init__()

        out_channels = self.expansion*channels

        if not isinstance(dilation, int):   # same guard as BasicBlock; without it, building resnet50/101/152 crashes
            dilation = 1

        self.conv1 = nn.Conv3d(in_channels, channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm3d(channels)

        self.conv2 = nn.Conv3d(channels, channels, kernel_size=3, stride=stride, padding=dilation, dilation=dilation, bias=False)
        self.bn2 = nn.BatchNorm3d(channels)

        self.conv3 = nn.Conv3d(channels, out_channels, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm3d(out_channels)

        if (stride != 1) or (in_channels != out_channels):
            conv = nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)
            bn = nn.BatchNorm3d(out_channels)
            self.downsample = nn.Sequential(conv, bn)
        else:
            self.downsample = nn.Sequential()

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        out = out + self.downsample(x)

        out = F.relu(out)

        return out

class ResNet_Bottleneck_OS16(nn.Module):
    def __init__(self, num_layers, input_channels):
        super(ResNet_Bottleneck_OS16, self).__init__()

        if num_layers == 50:
            resnet = resnet50(input_channels)
            # Drop layer4, avgpool and fc; keep everything up to layer3.
            self.resnet = nn.Sequential(*list(resnet.children())[:-3])
        elif num_layers == 101:
            resnet = resnet101(input_channels)
            self.resnet = nn.Sequential(*list(resnet.children())[:-3])
        elif num_layers == 152:
            resnet = resnet152(input_channels)
            self.resnet = nn.Sequential(*list(resnet.children())[:-3])
        else:
            raise ValueError("num_layers must be in {50, 101, 152}!")

        # layer3 of a Bottleneck ResNet outputs 4*256 = 1024 channels; layer5
        # replaces layer4 with a dilated, stride-1 stage to keep OS at 16.
        self.layer5 = make_layer(Bottleneck, in_channels=4*256, channels=512, num_blocks=3, stride=1, dilation=2)

    def forward(self, x):
        c4 = self.resnet(x)

        output = self.layer5(c4)

        return output

class ResNet_BasicBlock_OS16(nn.Module):
    def __init__(self, num_layers, input_channels):
        super(ResNet_BasicBlock_OS16, self).__init__()

        if num_layers == 18:
            resnet = resnet18(input_channels)
            self.resnet = nn.Sequential(*list(resnet.children())[:-3])

            num_blocks = 2

        elif num_layers == 34:
            resnet = resnet34(input_channels)
            self.resnet = nn.Sequential(*list(resnet.children())[:-3])

            num_blocks = 3
        else:
            raise ValueError("num_layers must be in {18, 34}!")

        self.layer5 = make_layer(BasicBlock, in_channels=256, channels=512, num_blocks=num_blocks, stride=1, dilation=2)

    def forward(self, x):
        c4 = self.resnet(x)

        output = self.layer5(c4)

        return output

class ResNet_BasicBlock_OS8(nn.Module):
    def __init__(self, num_layers, input_channels):
        super(ResNet_BasicBlock_OS8, self).__init__()

        if num_layers == 18:
            resnet = resnet18(input_channels)

            # Drop layer3, layer4, avgpool and fc; keep everything up to layer2.
            self.resnet = nn.Sequential(*list(resnet.children())[:-4])

            num_blocks_layer_4 = 2
            num_blocks_layer_5 = 2

        elif num_layers == 34:
            resnet = resnet34(input_channels)

            self.resnet = nn.Sequential(*list(resnet.children())[:-4])

            num_blocks_layer_4 = 6
            num_blocks_layer_5 = 3
        else:
            raise ValueError("num_layers must be in {18, 34}!")

        self.layer4 = make_layer(BasicBlock, in_channels=128, channels=256, num_blocks=num_blocks_layer_4, stride=1, dilation=2)

        self.layer5 = make_layer(BasicBlock, in_channels=256, channels=512, num_blocks=num_blocks_layer_5, stride=1, dilation=4)

    def forward(self, x):
        c3 = self.resnet(x)

        output = self.layer4(c3)
        output = self.layer5(output)

        return output

def ResNet18_OS16(input_channels):
    return ResNet_BasicBlock_OS16(num_layers=18, input_channels=input_channels)

def ResNet50_OS16(input_channels):
    return ResNet_Bottleneck_OS16(num_layers=50, input_channels=input_channels)

def ResNet101_OS16(input_channels):
    return ResNet_Bottleneck_OS16(num_layers=101, input_channels=input_channels)

def ResNet152_OS16(input_channels):
    return ResNet_Bottleneck_OS16(num_layers=152, input_channels=input_channels)

def ResNet34_OS16(input_channels):
    return ResNet_BasicBlock_OS16(num_layers=34, input_channels=input_channels)

def ResNet18_OS8(input_channels):
    return ResNet_BasicBlock_OS8(num_layers=18, input_channels=input_channels)

def ResNet34_OS8(input_channels):
    return ResNet_BasicBlock_OS8(num_layers=34, input_channels=input_channels)
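
# Minimal shape check for a backbone (illustrative; the channel count and
# volume size below are assumptions):
if __name__ == '__main__':
    backbone = ResNet18_OS8(input_channels=3)
    backbone.eval()
    with torch.no_grad():
        features = backbone(torch.randn(1, 3, 64, 64, 64))
    print(features.shape)   # expected: torch.Size([1, 512, 8, 8, 8]), i.e. output stride 8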
--------------------------------------------------------------------------------