├── .gitignore
├── LICENSE
├── README.md
├── model.py
└── unet-architecture.png

/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# dotenv
.env

# virtualenv
.venv
venv/
ENV/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2017 Jackson Huang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# U-Net implementation in PyTorch

The U-Net is an encoder-decoder neural network used for **semantic segmentation**.
The implementation in this repository is a modified version of the U-Net proposed in [this paper](https://arxiv.org/abs/1505.04597).

![U-Net Architecture](unet-architecture.png)

## Features

1. **You can alter the U-Net's depth.**
   The original U-Net uses a depth of 5, as depicted in the diagram above. Here, "depth" refers to the number of *distinct* spatial resolutions among the network's convolutional outputs. This implementation lets you vary the depth through a single constructor argument (`depth`).

2. **You can merge the decoder and encoder pathways in two ways.**
   In the original U-Net, the decoder and encoder activations are merged by concatenating channels. I've also implemented a ResNet-style merge (specified via `merge_mode='add'`) that *adds* the decoder and encoder activations instead. This was easy to code up, but it may not make sense theoretically and has not been tested.

## Pixel-wise loss for semantic segmentation

I had some trouble getting the pixel-wise loss working correctly for a semantic segmentation task. Here's how I got it working in the end.

```python
from model import UNet

model = UNet(num_classes)  # num_classes is required; depth, merge_mode, etc. are optional

# set up dataloaders, etc.

output = model(some_input_data)

# permute is like np.transpose: (N, C, H, W) => (H, W, N, C)
# contiguous is required because of this issue: https://github.com/pytorch/pytorch/issues/764
# view reshapes the output tensor to (H * W * N, num_classes)
# NOTE: num_classes == C (the number of output channels)
output = output.permute(2, 3, 0, 1).contiguous().view(-1, num_classes)
```

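The flattened output can then be scored against per-pixel labels with a cross-entropy loss, which applies log-softmax internally (the model deliberately emits raw logits — see the note in `model.py`'s forward pass). Here is a minimal sketch, assuming a hypothetical `targets` tensor of shape `(N, H, W)` holding integer class indices:

```python
import torch.nn.functional as F

# Flatten targets in the same H, W, N order used for the output above,
# so each row of `output` lines up with exactly one target pixel.
targets = targets.permute(1, 2, 0).contiguous().view(-1)

# cross_entropy combines log-softmax and NLL loss, so it expects the
# un-normalized (H * W * N, num_classes) logits produced above.
loss = F.cross_entropy(output, targets)
loss.backward()
```
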
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init


def conv3x3(in_channels, out_channels, stride=1,
            padding=1, bias=True, groups=1):
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=3,
        stride=stride,
        padding=padding,
        bias=bias,
        groups=groups)

def upconv2x2(in_channels, out_channels, mode='transpose'):
    if mode == 'transpose':
        return nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=2,
            stride=2)
    else:
        # non-parametric upsampling cannot change the channel count,
        # so a 1x1 convolution maps in_channels to out_channels
        return nn.Sequential(
            nn.Upsample(mode='bilinear', scale_factor=2),
            conv1x1(in_channels, out_channels))

def conv1x1(in_channels, out_channels, groups=1):
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=1,
        groups=groups,
        stride=1)


class DownConv(nn.Module):
    """
    A helper Module that performs 2 convolutions and 1 MaxPool.
    A ReLU activation follows each convolution.
    """
    def __init__(self, in_channels, out_channels, pooling=True):
        super(DownConv, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.pooling = pooling

        self.conv1 = conv3x3(self.in_channels, self.out_channels)
        self.conv2 = conv3x3(self.out_channels, self.out_channels)

        if self.pooling:
            self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        before_pool = x
        if self.pooling:
            x = self.pool(x)
        return x, before_pool


class UpConv(nn.Module):
    """
    A helper Module that performs 2 convolutions and 1 UpConvolution.
    A ReLU activation follows each convolution.
    """
    def __init__(self, in_channels, out_channels,
                 merge_mode='concat', up_mode='transpose'):
        super(UpConv, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.merge_mode = merge_mode
        self.up_mode = up_mode

        self.upconv = upconv2x2(self.in_channels, self.out_channels,
                                mode=self.up_mode)

        if self.merge_mode == 'concat':
            # concatenation doubles the channel count going into conv1
            self.conv1 = conv3x3(
                2*self.out_channels, self.out_channels)
        else:
            # additive merging leaves the channel count unchanged
            self.conv1 = conv3x3(self.out_channels, self.out_channels)
        self.conv2 = conv3x3(self.out_channels, self.out_channels)

    def forward(self, from_down, from_up):
        """ Forward pass
        Arguments:
            from_down: tensor from the encoder pathway
            from_up: upconv'd tensor from the decoder pathway
        """
        from_up = self.upconv(from_up)
        if self.merge_mode == 'concat':
            x = torch.cat((from_up, from_down), 1)
        else:
            x = from_up + from_down
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        return x

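
# Illustrative only -- this helper is not part of the original file. It
# sketches the channel/spatial bookkeeping of the two blocks above: DownConv
# halves the spatial size (keeping the pre-pool activation as the skip
# connection), and UpConv doubles it back while merging in that skip.
def _sanity_check_blocks():
    down = DownConv(64, 128)
    pooled, skip = down(torch.randn(1, 64, 128, 128))
    assert pooled.shape == (1, 128, 64, 64)
    assert skip.shape == (1, 128, 128, 128)

    # UpConv(128, 64) upsamples back to 128x128 and, with the default
    # 'concat' merge, expects a 64-channel encoder skip at that resolution.
    up = UpConv(128, 64)
    merged = up(torch.randn(1, 64, 128, 128), pooled)
    assert merged.shape == (1, 64, 128, 128)
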

class UNet(nn.Module):
    """ `UNet` class is based on https://arxiv.org/abs/1505.04597

    The U-Net is a convolutional encoder-decoder neural network.
    Contextual spatial information (from the decoding,
    expansive pathway) about an input tensor is merged with
    information representing the localization of details
    (from the encoding, compressive pathway).

    Modifications to the original paper:
    (1) padding is used in 3x3 convolutions to prevent loss
        of border pixels
    (2) merging outputs does not require cropping due to (1)
    (3) residual connections can be used by specifying
        UNet(merge_mode='add')
    (4) if non-parametric upsampling is used in the decoder
        pathway (specified by up_mode='upsample'), then an
        additional 1x1 2d convolution occurs after upsampling
        to reduce channel dimensionality by a factor of 2.
        With up_mode='transpose', the transpose convolution
        itself performs this channel halving.
    """

    def __init__(self, num_classes, in_channels=3, depth=5,
                 start_filts=64, up_mode='transpose',
                 merge_mode='concat'):
        """
        Arguments:
            num_classes: int, number of output channels, i.e. the
                number of segmentation classes.
            in_channels: int, number of channels in the input tensor.
                Default is 3 for RGB images.
            depth: int, number of MaxPools in the U-Net.
            start_filts: int, number of convolutional filters for the
                first conv.
            up_mode: string, type of upconvolution. Choices: 'transpose'
                for transpose convolution or 'upsample' for bilinear
                upsampling followed by a 1x1 convolution.
            merge_mode: string, how encoder and decoder activations are
                merged. Choices: 'concat' or 'add'.
        """
        super(UNet, self).__init__()

        if up_mode in ('transpose', 'upsample'):
            self.up_mode = up_mode
        else:
            raise ValueError("\"{}\" is not a valid mode for "
                             "upsampling. Only \"transpose\" and "
                             "\"upsample\" are allowed.".format(up_mode))

        if merge_mode in ('concat', 'add'):
            self.merge_mode = merge_mode
        else:
            raise ValueError("\"{}\" is not a valid mode for "
                             "merging up and down paths. "
                             "Only \"concat\" and "
                             "\"add\" are allowed.".format(merge_mode))

        # NOTE: up_mode 'upsample' is incompatible with merge_mode 'add'
        if self.up_mode == 'upsample' and self.merge_mode == 'add':
            raise ValueError("up_mode \"upsample\" is incompatible "
                             "with merge_mode \"add\" at the moment "
                             "because it doesn't make sense to use "
                             "non-parametric upsampling to halve the "
                             "channel dimension.")

        self.num_classes = num_classes
        self.in_channels = in_channels
        self.start_filts = start_filts
        self.depth = depth

        self.down_convs = []
        self.up_convs = []

        # create the encoder pathway and add to a list
        for i in range(depth):
            ins = self.in_channels if i == 0 else outs
            outs = self.start_filts*(2**i)
            pooling = i < depth-1  # no pooling in the bottleneck block

            down_conv = DownConv(ins, outs, pooling=pooling)
            self.down_convs.append(down_conv)

        # create the decoder pathway and add to a list
        # - careful! decoding only requires depth-1 blocks
        for i in range(depth-1):
            ins = outs
            outs = ins // 2
            up_conv = UpConv(ins, outs, up_mode=up_mode,
                             merge_mode=merge_mode)
            self.up_convs.append(up_conv)

        self.conv_final = conv1x1(outs, self.num_classes)

        # add the lists of modules to the current module
        self.down_convs = nn.ModuleList(self.down_convs)
        self.up_convs = nn.ModuleList(self.up_convs)

        self.reset_params()

    @staticmethod
    def weight_init(m):
        if isinstance(m, nn.Conv2d):
            init.xavier_normal_(m.weight)
            init.constant_(m.bias, 0)

    def reset_params(self):
        for i, m in enumerate(self.modules()):
            self.weight_init(m)

    def forward(self, x):
        encoder_outs = []

        # encoder pathway, save outputs for merging
        for i, module in enumerate(self.down_convs):
            x, before_pool = module(x)
            encoder_outs.append(before_pool)

        for i, module in enumerate(self.up_convs):
            before_pool = encoder_outs[-(i+2)]
            x = module(before_pool, x)

        # No softmax is used, so you need to use nn.CrossEntropyLoss
        # in your training script, as that loss includes the softmax already.
        x = self.conv_final(x)
        return x


if __name__ == "__main__":
    """
    testing
    """
    model = UNet(3, depth=5, merge_mode='concat')
    x = torch.randn(1, 3, 320, 320)
    out = model(x)
    loss = torch.sum(out)
    loss.backward()

--------------------------------------------------------------------------------
/unet-architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaxony/unet-pytorch/2bb0e44a2d5c7c04b589d5d942673aee52bcff58/unet-architecture.png
--------------------------------------------------------------------------------