├── README.md ├── LICENSE.md └── u_net_resnet_50_encoder.py /README.md: -------------------------------------------------------------------------------- 1 | # pytorch-unet-resnet-50-encoder 2 | 3 | This model is a U-Net with a pretrained ResNet-50 encoder. For most segmentation tasks I've encountered, using a pretrained encoder yields better results than training everything from scratch. However, extracting the bottleneck layers from PyTorch's ResNet implementation is a bit of a hassle, so hopefully this will help someone! 4 | 5 | You will need PyTorch version >= 0.3.0 and TorchVision version >= 0.2.0 6 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Kevin Lu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
# u_net_resnet_50_encoder.py
"""U-Net segmentation model whose encoder is torchvision's ResNet-50."""
import torch
import torch.nn as nn
import torchvision


class ConvBlock(nn.Module):
    """
    Helper module that consists of a Conv -> BN -> ReLU.

    With ``with_nonlinearity=False`` the ReLU is skipped and the block is
    just Conv -> BN.
    """

    def __init__(self, in_channels, out_channels, padding=1, kernel_size=3, stride=1, with_nonlinearity=True):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, padding=padding, kernel_size=kernel_size, stride=stride)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.with_nonlinearity = with_nonlinearity

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        if self.with_nonlinearity:
            x = self.relu(x)
        return x


class Bridge(nn.Module):
    """
    Middle layer of the U-Net, between the encoder and the decoder.

    It consists of two ConvBlocks with the default 3x3 kernel, padding 1 and
    stride 1, so the spatial size is preserved. The first block maps
    ``in_channels -> out_channels`` and the second keeps ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.bridge = nn.Sequential(
            ConvBlock(in_channels, out_channels),
            ConvBlock(out_channels, out_channels)
        )

    def forward(self, x):
        return self.bridge(x)


class UpBlockForUNetWithResNet50(nn.Module):
    """
    One decoder step: Upsample -> concat skip connection -> ConvBlock -> ConvBlock.

    :param in_channels: channels entering the first ConvBlock, i.e. the upsampled
        channels plus the skip-connection channels
    :param out_channels: output channels of both ConvBlocks
    :param up_conv_in_channels: channels of ``up_x`` fed to the upsampling step;
        defaults to ``in_channels``
    :param up_conv_out_channels: channels produced by the upsampling step;
        defaults to ``out_channels``
    :param upsampling_method: ``"conv_transpose"`` (2x2 transposed conv, stride 2)
        or ``"bilinear"`` (bilinear x2 upsample followed by a 1x1 conv)
    :raises ValueError: if ``upsampling_method`` is not one of the above
    """

    def __init__(self, in_channels, out_channels, up_conv_in_channels=None, up_conv_out_channels=None,
                 upsampling_method="conv_transpose"):
        super().__init__()

        if up_conv_in_channels is None:
            up_conv_in_channels = in_channels
        if up_conv_out_channels is None:
            up_conv_out_channels = out_channels

        if upsampling_method == "conv_transpose":
            self.upsample = nn.ConvTranspose2d(up_conv_in_channels, up_conv_out_channels, kernel_size=2, stride=2)
        elif upsampling_method == "bilinear":
            # The 1x1 conv sees the channels of up_x, exactly like the transposed
            # conv above. Using in_channels/out_channels here would break blocks
            # whose up_conv_* channels differ from in/out_channels.
            self.upsample = nn.Sequential(
                nn.Upsample(mode='bilinear', scale_factor=2),
                nn.Conv2d(up_conv_in_channels, up_conv_out_channels, kernel_size=1, stride=1)
            )
        else:
            raise ValueError(
                f"upsampling_method must be 'conv_transpose' or 'bilinear', got {upsampling_method!r}"
            )
        self.conv_block_1 = ConvBlock(in_channels, out_channels)
        self.conv_block_2 = ConvBlock(out_channels, out_channels)

    def forward(self, up_x, down_x):
        """
        :param up_x: this is the output from the previous up block (or the bridge)
        :param down_x: this is the skip connection from the matching encoder stage
        :return: upsampled feature map with ``out_channels`` channels
        """
        x = self.upsample(up_x)
        # Concatenate along the channel dimension.
        x = torch.cat([x, down_x], 1)
        x = self.conv_block_1(x)
        x = self.conv_block_2(x)
        return x


class UNetWithResnet50Encoder(nn.Module):
    """
    U-Net whose encoder is torchvision's ResNet-50.

    Skip connections come from:
    - the raw input,
    - the stem (conv1 -> bn1 -> relu),
    - ResNet ``layer1``..``layer3``.

    The output of ``layer4`` feeds the bridge.

    :param n_classes: number of output channels of the final 1x1 conv
    :param pretrained: forwarded to ``torchvision.models.resnet.resnet50`` to
        decide whether the encoder starts from torchvision's pretrained weights
    """
    DEPTH = 6

    def __init__(self, n_classes=2, pretrained=True):
        super().__init__()
        resnet = torchvision.models.resnet.resnet50(pretrained=pretrained)
        # Stem and encoder stages are picked by name rather than by child order.
        # The submodule names (and hence state_dict keys) are identical to the
        # previous children()-based construction.
        self.input_block = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu)
        self.input_pool = resnet.maxpool
        self.down_blocks = nn.ModuleList([resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4])
        self.bridge = Bridge(2048, 2048)
        up_blocks = [
            # in_channels = upsampled channels + skip-connection channels.
            UpBlockForUNetWithResNet50(2048, 1024),
            UpBlockForUNetWithResNet50(1024, 512),
            UpBlockForUNetWithResNet50(512, 256),
            UpBlockForUNetWithResNet50(in_channels=128 + 64, out_channels=128,
                                       up_conv_in_channels=256, up_conv_out_channels=128),
            UpBlockForUNetWithResNet50(in_channels=64 + 3, out_channels=64,
                                       up_conv_in_channels=128, up_conv_out_channels=64),
        ]
        self.up_blocks = nn.ModuleList(up_blocks)

        self.out = nn.Conv2d(64, n_classes, kernel_size=1, stride=1)

    def forward(self, x, with_output_feature_map=False):
        """
        Run the encoder, bridge and decoder.

        :param x: input batch. Assumed to be (N, 3, H, W) with H and W divisible
            by 32 so the skip connections line up; TODO confirm for other sizes.
        :param with_output_feature_map: also return the 64-channel feature map
            that precedes the final 1x1 conv
        :return: logits with ``n_classes`` channels, or
            ``(logits, feature_map)`` when ``with_output_feature_map`` is True
        """
        pre_pools = {"layer_0": x}
        x = self.input_block(x)
        pre_pools["layer_1"] = x
        x = self.input_pool(x)

        for i, block in enumerate(self.down_blocks, 2):
            x = block(x)
            # The deepest encoder output goes straight to the bridge and is not
            # used as a skip connection.
            if i == (UNetWithResnet50Encoder.DEPTH - 1):
                continue
            pre_pools[f"layer_{i}"] = x

        x = self.bridge(x)

        # Consume the skip connections from deepest (layer_4) to shallowest (layer_0).
        for i, block in enumerate(self.up_blocks, 1):
            key = f"layer_{UNetWithResnet50Encoder.DEPTH - 1 - i}"
            x = block(x, pre_pools[key])
        output_feature_map = x
        x = self.out(x)
        if with_output_feature_map:
            return x, output_feature_map
        return x


if __name__ == "__main__":
    # Smoke test only. Kept out of import time so importing this module does not
    # download weights, allocate a GPU, or run a forward pass.
    model = UNetWithResnet50Encoder()
    inp = torch.rand((2, 3, 512, 512))
    if torch.cuda.is_available():
        model = model.cuda()
        inp = inp.cuda()
    out = model(inp)
    print(out.size())