├── README.md
└── unet.py

/README.md:
--------------------------------------------------------------------------------
# U-Net for Keras

This is an implementation of the U-Net model for Keras. With the default parameters it matches the original architecture, except that this code uses padded (`'same'`) convolutions, so spatial dimensions are preserved at each level. The original paper also does not state the amount of Dropout used; since 0.5 is a common choice in practice, it is the default value here. If you are aware of other differences, please contact me.

U-Net: Convolutional Networks for Biomedical Image Segmentation (https://arxiv.org/abs/1505.04597)

Beyond that, this implementation adds a few bells and whistles: transposed convolutions (deconvolutions) for upsampling, strided convolutions for downsampling, batch normalization, and residual connections. If you have further ideas, let me know.
--------------------------------------------------------------------------------
/unet.py:
--------------------------------------------------------------------------------
from keras.models import Model
from keras.layers import Input, Conv2D, Conv2DTranspose, Concatenate, MaxPooling2D
from keras.layers import UpSampling2D, Dropout, BatchNormalization

'''
U-Net: Convolutional Networks for Biomedical Image Segmentation
(https://arxiv.org/abs/1505.04597)
---
img_shape: (height, width, channels)
out_ch: number of output channels
start_ch: number of channels of the first conv
depth: zero-indexed depth of the U-structure
inc_rate: rate at which the conv channels increase with depth
activation: activation function after convolutions
dropout: amount of dropout in the contracting part
batchnorm: adds Batch Normalization if true
maxpool: use strided conv instead of maxpooling if false
upconv: use transposed conv instead of upsampling + conv if false
residual: add residual connections around each conv block if true
'''

def conv_block(m, dim, acti, bn, res, do=0):
    # Two 3x3 convolutions with optional batch norm and dropout in between
    n = Conv2D(dim, 3, activation=acti, padding='same')(m)
    n = BatchNormalization()(n) if bn else n
    n = Dropout(do)(n) if do else n
    n = Conv2D(dim, 3, activation=acti, padding='same')(n)
    n = BatchNormalization()(n) if bn else n
    # Residual connection implemented as a channel-wise concatenation
    return Concatenate()([m, n]) if res else n

def level_block(m, dim, depth, inc, acti, do, bn, mp, up, res):
    if depth > 0:
        # Contracting path: conv block, then downsample by max pooling or strided conv
        n = conv_block(m, dim, acti, bn, res)
        m = MaxPooling2D()(n) if mp else Conv2D(dim, 3, strides=2, padding='same')(n)
        # Recurse one level deeper with inc-times as many channels
        m = level_block(m, int(inc*dim), depth-1, inc, acti, do, bn, mp, up, res)
        # Expanding path: upsample by upsampling + conv or by transposed conv
        if up:
            m = UpSampling2D()(m)
            m = Conv2D(dim, 2, activation=acti, padding='same')(m)
        else:
            m = Conv2DTranspose(dim, 3, strides=2, activation=acti, padding='same')(m)
        # Skip connection from the contracting path, then another conv block
        n = Concatenate()([n, m])
        m = conv_block(n, dim, acti, bn, res)
    else:
        # Bottleneck at the bottom of the U; dropout is applied only here
        m = conv_block(m, dim, acti, bn, res, do)
    return m

def UNet(img_shape, out_ch=1, start_ch=64, depth=4, inc_rate=2., activation='relu',
         dropout=0.5, batchnorm=False, maxpool=True, upconv=True, residual=False):
    i = Input(shape=img_shape)
    o = level_block(i, start_ch, depth, inc_rate, activation, dropout, batchnorm, maxpool, upconv, residual)
    o = Conv2D(out_ch, 1, activation='sigmoid')(o)
    return Model(inputs=i, outputs=o)
--------------------------------------------------------------------------------
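
As a quick sanity check on the `start_ch`/`inc_rate`/`depth` parameters: `level_block` passes `int(inc*dim)` to each deeper recursion, so the channel count grows geometrically down the contracting path. The small sketch below (the helper name `channels_per_level` is ours, not part of the repository) reproduces that count for the defaults:

```python
def channels_per_level(start_ch=64, inc_rate=2., depth=4):
    # Mirrors level_block's recursion: each level multiplies the channel
    # count by inc_rate, from the top of the U down to the bottleneck.
    # depth is zero-indexed, so there are depth + 1 levels in total.
    return [int(start_ch * inc_rate ** d) for d in range(depth + 1)]

print(channels_per_level())  # [64, 128, 256, 512, 1024]
```

With the defaults this matches the original paper's 64 → 1024 channel progression; a smaller `start_ch` or `inc_rate` shrinks the model accordingly.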