├── .gitignore
├── Figures
│   ├── fig_inference_time.png
│   ├── fig_prediction_comparison.png
│   └── fig_unet++.png
├── LICENSE
├── README.md
└── model.py

/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# dotenv
.env

# virtualenv
.venv
venv/
ENV/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

--------------------------------------------------------------------------------
/Figures/fig_inference_time.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CarryHJR/Nested-UNet/e477bf8ef4d37e849d11e81cfbb991c7fa2f48d0/Figures/fig_inference_time.png
--------------------------------------------------------------------------------
/Figures/fig_prediction_comparison.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CarryHJR/Nested-UNet/e477bf8ef4d37e849d11e81cfbb991c7fa2f48d0/Figures/fig_prediction_comparison.png
--------------------------------------------------------------------------------
/Figures/fig_unet++.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CarryHJR/Nested-UNet/e477bf8ef4d37e849d11e81cfbb991c7fa2f48d0/Figures/fig_unet++.png
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2018 Zongwei Zhou

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# UNet++: A Nested U-Net Architecture for Medical Image Segmentation

This is an implementation of ["UNet++: A Nested U-Net Architecture for Medical Image Segmentation"](https://arxiv.org/pdf/1807.10165.pdf) in Python, built on the Keras deep learning framework with TensorFlow as the backend. We propose a new architecture, **UNet++** (nested U-Net architecture), for more precise segmentation. UNet++ adds intermediate layers to U-Net; these naturally form multiple new up-sampling expanding paths of different depths, resulting in an ensemble of U-Nets with a partially shared contracting path.
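The nested wiring described above can be made concrete with a short pure-Python sketch that mirrors `model.py` without importing Keras. It assumes the layer naming used in `Nest_Net` (node `conv{i}_{j}` at resolution level `i` and column `j`), two 3x3 convolutions with biases in each `standard_unit`, 2x2 transposed convolutions for up-sampling, a single-channel input, one output class and `deep_supervision=False`. The helper functions (`node_inputs`, `conv_params`, `standard_unit_params`, `count_params`) are hypothetical and not part of the repository. Under those assumptions the sketch reproduces the parameter totals quoted in the docstrings of `model.py`.

```python
# Hypothetical helpers (not part of the repository) that mirror the layer
# wiring of U_Net, wU_Net and Nest_Net in model.py without importing Keras.


def node_inputs(i, j):
    """Names of the tensors concatenated to build node conv{i}_{j} in Nest_Net."""
    if j == 1:
        # Backbone (encoder) nodes take the image or the pooled map one level up.
        return ["main_input"] if i == 1 else [f"pool{i - 1}"]
    # up{i}_{j} up-samples conv{i+1}_{j-1}; it is joined with every earlier
    # node on the same level i (the dense skip pathway of UNet++).
    return [f"up{i}_{j}"] + [f"conv{i}_{k}" for k in range(1, j)]


def conv_params(k, c_in, c_out):
    """Weights plus biases of a k x k (transposed) convolution."""
    return k * k * c_in * c_out + c_out


def standard_unit_params(c_in, c):
    """Two 3x3 convolutions, as in standard_unit (Dropout has no weights)."""
    return conv_params(3, c_in, c) + conv_params(3, c, c)


def count_params(nb_filter, nested, in_channels=1, num_class=1):
    """Trainable parameters of U_Net/wU_Net (nested=False) or Nest_Net (nested=True)."""
    depth = len(nb_filter)
    # Encoder nodes conv1_1 ... conv{depth}_1.
    total = standard_unit_params(in_channels, nb_filter[0])
    for lvl in range(1, depth):
        total += standard_unit_params(nb_filter[lvl - 1], nb_filter[lvl])
    for i in range(1, depth):
        c = nb_filter[i - 1]
        last_j = depth + 1 - i  # conv1_5, conv2_4, conv3_3, conv4_2 for depth 5
        # UNet++ builds every column j = 2..last_j; the plain U-Net only the last.
        cols = range(2, last_j + 1) if nested else [last_j]
        for j in cols:
            n_inputs = len(node_inputs(i, j)) if nested else 2
            total += conv_params(2, nb_filter[i], c)        # up{i}_{j}
            total += standard_unit_params(n_inputs * c, c)  # conv{i}_{j}
    # A single 1x1 sigmoid output layer (deep_supervision=False).
    return total + conv_params(1, nb_filter[0], num_class)


if __name__ == "__main__":
    narrow = [32, 64, 128, 256, 512]
    wide = [35, 70, 140, 280, 560]
    print(count_params(narrow, nested=False))  # 7759521, as in the U_Net docstring
    print(count_params(wide, nested=False))    # 9282246, as in the wU_Net docstring
    print(count_params(narrow, nested=True))   # 9041601, as in the Nest_Net docstring
    print(node_inputs(1, 5))
```

Node `conv1_5`, for example, concatenates `up1_5` with `conv1_1` to `conv1_4`, so its first convolution sees 5 x 32 = 160 channels, whereas the matching node in the plain U-Net sees only 64. With `deep_supervision=True`, the three extra 1x1 sigmoid heads on `conv1_2`, `conv1_3` and `conv1_4` would add 3 x 33 = 99 parameters under the same assumptions.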

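The training objective in `model.py` can also be checked without Keras. `dice_coef` computes the smoothed Dice coefficient `(2 * sum(y_true * y_pred) + smooth) / (sum(y_true) + sum(y_pred) + smooth)` with `smooth = 1`, and `bce_dice_loss` returns `0.5 * binary_crossentropy - dice_coef`, so it decreases as the overlap improves and can be negative. The sketch below re-implements both on flat Python lists as an illustration only; the function names echo those in `model.py` for readability, and it assumes the per-pixel cross-entropy is averaged over all pixels with probabilities clipped to `[1e-7, 1 - 1e-7]` (Keras's default `K.epsilon()`).

```python
import math

SMOOTH = 1.0  # same smoothing constant as dice_coef in model.py
EPS = 1e-7    # assumed clipping epsilon, matching Keras's default K.epsilon()


def dice_coef(y_true, y_pred, smooth=SMOOTH):
    """Smoothed (soft) Dice coefficient on flattened masks."""
    intersection = sum(t * p for t, p in zip(y_true, y_pred))
    return (2.0 * intersection + smooth) / (sum(y_true) + sum(y_pred) + smooth)


def binary_crossentropy(y_true, y_pred, eps=EPS):
    """Mean per-pixel binary cross-entropy with clipped probabilities."""
    total = 0.0
    for t, p in zip(y_true, y_pred):
        p = min(max(p, eps), 1.0 - eps)
        total += -(t * math.log(p) + (1.0 - t) * math.log(1.0 - p))
    return total / len(y_true)


def bce_dice_loss(y_true, y_pred):
    """0.5 * BCE - Dice: lower is better, and it can go below zero."""
    return 0.5 * binary_crossentropy(y_true, y_pred) - dice_coef(y_true, y_pred)


if __name__ == "__main__":
    mask = [1, 1, 0, 0]
    print(dice_coef(mask, mask))          # 1.0 for a perfect prediction
    print(dice_coef(mask, [0, 0, 1, 1]))  # 0.2: only the smoothing term remains
    print(bce_dice_loss(mask, [0.9, 0.8, 0.1, 0.2]))
```

Two disjoint masks still score 0.2 here because only the smoothing term survives in the numerator; the same term keeps the ratio defined when both masks are empty.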

## License

UNet++ is released under the [MIT License](https://github.com/MrGiovanni/Nested-UNet/blob/master/LICENSE).

## Citing UNet++

If you use UNet++ in your research, please consider citing it with the following BibTeX entry.

```
@inproceedings{zhou2018nest,
  title={UNet++: A Nested U-Net Architecture for Medical Image Segmentation},
  author={Zhou, Zongwei and Siddiquee, Md Mahfuzur Rahman and Tajbakhsh, Nima and Liang, Jianming},
  booktitle={Deep Learning in Medical Image Analysis},
  year={2018}
}
```

## Contacts (Maintainers)

* Zongwei Zhou, homepage: [zongweiz.com](https://www.zongweiz.com)
* Md Mahfuzur Rahman Siddiquee, GitHub: [mahfuzmohammad](https://github.com/mahfuzmohammad)

--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
'''
Keras implementations of U-Net (U_Net), a wide U-Net (wU_Net) and
UNet++ (Nest_Net) for 2D image segmentation.
'''


import keras
import tensorflow as tf
from keras.models import Model
from keras import backend as K
from keras.layers import Input, Conv2D, ZeroPadding2D, UpSampling2D, Dense, concatenate, Conv2DTranspose
from keras.layers import MaxPooling2D, GlobalAveragePooling2D
from keras.layers import Dropout, Activation, BatchNormalization, Flatten, Lambda
from keras.layers import ELU, LeakyReLU, GaussianDropout
from keras.optimizers import Adam, RMSprop, SGD
from keras.regularizers import l2

import numpy as np

smooth = 1.
dropout_rate = 0.5

def mean_iou(y_true, y_pred):
    # Two-class mean IoU, averaged over binarization thresholds 0.5, 0.55, ..., 0.95.
    # Relies on TensorFlow 1.x APIs (tf.to_int32, tf.metrics.mean_iou, K.get_session).
    prec = []
    for t in np.arange(0.5, 1.0, 0.05):
        y_pred_ = tf.to_int32(y_pred > t)
        score, up_opt = tf.metrics.mean_iou(y_true, y_pred_, 2)
        K.get_session().run(tf.local_variables_initializer())
        with tf.control_dependencies([up_opt]):
            score = tf.identity(score)
        prec.append(score)
    return K.mean(K.stack(prec), axis=0)

# Custom loss functions
def dice_coef(y_true, y_pred):
    # Smoothed Dice coefficient: (2 * |A.B| + smooth) / (|A| + |B| + smooth).
    smooth = 1.
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)

def bce_dice_loss(y_true, y_pred):
    # Half the binary cross-entropy minus the Dice coefficient (lower is better).
    return 0.5 * keras.losses.binary_crossentropy(y_true, y_pred) - dice_coef(y_true, y_pred)


########################################
# 2D Standard
########################################

def standard_unit(input_tensor, stage, nb_filter, kernel_size=3):
    # Two convolutions (ELU, He-normal init, L2 weight decay), each followed by dropout.

    act = 'elu'

    x = Conv2D(nb_filter, (kernel_size, kernel_size), activation=act, name='conv'+stage+'_1', kernel_initializer = 'he_normal', padding='same', kernel_regularizer=l2(1e-4))(input_tensor)
    x = Dropout(dropout_rate, name='dp'+stage+'_1')(x)
    x = Conv2D(nb_filter, (kernel_size, kernel_size), activation=act, name='conv'+stage+'_2', kernel_initializer = 'he_normal', padding='same', kernel_regularizer=l2(1e-4))(x)
    x = Dropout(dropout_rate, name='dp'+stage+'_2')(x)

    return x

########################################

"""
Standard U-Net [Ronneberger et al., 2015]
Total params: 7,759,521
"""
def U_Net(img_rows, img_cols, color_type=1, num_class=1):

    nb_filter = [32,64,128,256,512]
    act = 'elu'

    # Handle channel ordering (channels_last vs. channels_first)
    global bn_axis
    if K.image_data_format() == 'channels_last':
        bn_axis = 3
        img_input = Input(shape=(img_rows, img_cols, color_type), name='main_input')
    else:
        bn_axis = 1
        img_input = Input(shape=(color_type, img_rows, img_cols), name='main_input')

    conv1_1 = standard_unit(img_input, stage='11', nb_filter=nb_filter[0])
    pool1 = MaxPooling2D((2, 2), strides=(2, 2), name='pool1')(conv1_1)

    conv2_1 = standard_unit(pool1, stage='21', nb_filter=nb_filter[1])
    pool2 = MaxPooling2D((2, 2), strides=(2, 2), name='pool2')(conv2_1)

    conv3_1 = standard_unit(pool2, stage='31', nb_filter=nb_filter[2])
    pool3 = MaxPooling2D((2, 2), strides=(2, 2), name='pool3')(conv3_1)

    conv4_1 = standard_unit(pool3, stage='41', nb_filter=nb_filter[3])
    pool4 = MaxPooling2D((2, 2), strides=(2, 2), name='pool4')(conv4_1)

    conv5_1 = standard_unit(pool4, stage='51', nb_filter=nb_filter[4])

    up4_2 = Conv2DTranspose(nb_filter[3], (2, 2), strides=(2, 2), name='up42', padding='same')(conv5_1)
    conv4_2 = concatenate([up4_2, conv4_1], name='merge42', axis=bn_axis)
    conv4_2 = standard_unit(conv4_2, stage='42', nb_filter=nb_filter[3])

    up3_3 = Conv2DTranspose(nb_filter[2], (2, 2), strides=(2, 2), name='up33', padding='same')(conv4_2)
    conv3_3 = concatenate([up3_3, conv3_1], name='merge33', axis=bn_axis)
    conv3_3 = standard_unit(conv3_3, stage='33', nb_filter=nb_filter[2])

    up2_4 = Conv2DTranspose(nb_filter[1], (2, 2), strides=(2, 2), name='up24', padding='same')(conv3_3)
    conv2_4 = concatenate([up2_4, conv2_1], name='merge24', axis=bn_axis)
    conv2_4 = standard_unit(conv2_4, stage='24', nb_filter=nb_filter[1])

    up1_5 = Conv2DTranspose(nb_filter[0], (2, 2), strides=(2, 2), name='up15', padding='same')(conv2_4)
    conv1_5 = concatenate([up1_5, conv1_1], name='merge15', axis=bn_axis)
    conv1_5 = standard_unit(conv1_5, stage='15', nb_filter=nb_filter[0])

    unet_output = Conv2D(num_class, (1, 1), activation='sigmoid', name='output', kernel_initializer = 'he_normal', padding='same', kernel_regularizer=l2(1e-4))(conv1_5)

    model = Model(inputs=img_input, outputs=unet_output)

    return model

"""
Wide U-Net (wU-Net) for comparison: same layout as U_Net with more filters per level
Total params: 9,282,246
"""
def wU_Net(img_rows, img_cols, color_type=1, num_class=1):

    # nb_filter = [32,64,128,256,512]
    nb_filter = [35,70,140,280,560]
    act = 'elu'

    # Handle channel ordering (channels_last vs. channels_first)
    global bn_axis
    if K.image_data_format() == 'channels_last':
        bn_axis = 3
        img_input = Input(shape=(img_rows, img_cols, color_type), name='main_input')
    else:
        bn_axis = 1
        img_input = Input(shape=(color_type, img_rows, img_cols), name='main_input')

    conv1_1 = standard_unit(img_input, stage='11', nb_filter=nb_filter[0])
    pool1 = MaxPooling2D((2, 2), strides=(2, 2), name='pool1')(conv1_1)

    conv2_1 = standard_unit(pool1, stage='21', nb_filter=nb_filter[1])
    pool2 = MaxPooling2D((2, 2), strides=(2, 2), name='pool2')(conv2_1)

    conv3_1 = standard_unit(pool2, stage='31', nb_filter=nb_filter[2])
    pool3 = MaxPooling2D((2, 2), strides=(2, 2), name='pool3')(conv3_1)

    conv4_1 = standard_unit(pool3, stage='41', nb_filter=nb_filter[3])
    pool4 = MaxPooling2D((2, 2), strides=(2, 2), name='pool4')(conv4_1)

    conv5_1 = standard_unit(pool4, stage='51', nb_filter=nb_filter[4])

    up4_2 = Conv2DTranspose(nb_filter[3], (2, 2), strides=(2, 2), name='up42', padding='same')(conv5_1)
    conv4_2 = concatenate([up4_2, conv4_1], name='merge42', axis=bn_axis)
    conv4_2 = standard_unit(conv4_2, stage='42', nb_filter=nb_filter[3])

    up3_3 = Conv2DTranspose(nb_filter[2], (2, 2), strides=(2, 2), name='up33', padding='same')(conv4_2)
    conv3_3 = concatenate([up3_3, conv3_1], name='merge33', axis=bn_axis)
    conv3_3 = standard_unit(conv3_3, stage='33', nb_filter=nb_filter[2])

    up2_4 = Conv2DTranspose(nb_filter[1], (2, 2), strides=(2, 2), name='up24', padding='same')(conv3_3)
    conv2_4 = concatenate([up2_4, conv2_1], name='merge24', axis=bn_axis)
    conv2_4 = standard_unit(conv2_4, stage='24', nb_filter=nb_filter[1])

    up1_5 = Conv2DTranspose(nb_filter[0], (2, 2), strides=(2, 2), name='up15', padding='same')(conv2_4)
    conv1_5 = concatenate([up1_5, conv1_1], name='merge15', axis=bn_axis)
    conv1_5 = standard_unit(conv1_5, stage='15', nb_filter=nb_filter[0])

    unet_output = Conv2D(num_class, (1, 1), activation='sigmoid', name='output', kernel_initializer = 'he_normal', padding='same', kernel_regularizer=l2(1e-4))(conv1_5)

    model = Model(inputs=img_input, outputs=unet_output)

    return model

"""
Standard UNet++ [Zhou et al., 2018]
Total params: 9,041,601
"""
def Nest_Net(img_rows, img_cols, color_type=1, num_class=1, deep_supervision=False):

    nb_filter = [32,64,128,256,512]
    act = 'elu'

    # Handle channel ordering (channels_last vs. channels_first)
    global bn_axis
    if K.image_data_format() == 'channels_last':
        bn_axis = 3
        img_input = Input(shape=(img_rows, img_cols, color_type), name='main_input')
    else:
        bn_axis = 1
        img_input = Input(shape=(color_type, img_rows, img_cols), name='main_input')

    conv1_1 = standard_unit(img_input, stage='11', nb_filter=nb_filter[0])
    pool1 = MaxPooling2D((2, 2), strides=(2, 2), name='pool1')(conv1_1)

    conv2_1 = standard_unit(pool1, stage='21', nb_filter=nb_filter[1])
    pool2 = MaxPooling2D((2, 2), strides=(2, 2), name='pool2')(conv2_1)

    up1_2 = Conv2DTranspose(nb_filter[0], (2, 2), strides=(2, 2), name='up12', padding='same')(conv2_1)
    conv1_2 = concatenate([up1_2, conv1_1], name='merge12', axis=bn_axis)
    conv1_2 = standard_unit(conv1_2, stage='12', nb_filter=nb_filter[0])

    conv3_1 = standard_unit(pool2, stage='31', nb_filter=nb_filter[2])
    pool3 = MaxPooling2D((2, 2), strides=(2, 2), name='pool3')(conv3_1)

    up2_2 = Conv2DTranspose(nb_filter[1], (2, 2), strides=(2, 2), name='up22', padding='same')(conv3_1)
    conv2_2 = concatenate([up2_2, conv2_1], name='merge22', axis=bn_axis)
    conv2_2 = standard_unit(conv2_2, stage='22', nb_filter=nb_filter[1])

    up1_3 = Conv2DTranspose(nb_filter[0], (2, 2), strides=(2, 2), name='up13', padding='same')(conv2_2)
    conv1_3 = concatenate([up1_3, conv1_1, conv1_2], name='merge13', axis=bn_axis)
    conv1_3 = standard_unit(conv1_3, stage='13', nb_filter=nb_filter[0])

    conv4_1 = standard_unit(pool3, stage='41', nb_filter=nb_filter[3])
    pool4 = MaxPooling2D((2, 2), strides=(2, 2), name='pool4')(conv4_1)

    up3_2 = Conv2DTranspose(nb_filter[2], (2, 2), strides=(2, 2), name='up32', padding='same')(conv4_1)
    conv3_2 = concatenate([up3_2, conv3_1], name='merge32', axis=bn_axis)
    conv3_2 = standard_unit(conv3_2, stage='32', nb_filter=nb_filter[2])

    up2_3 = Conv2DTranspose(nb_filter[1], (2, 2), strides=(2, 2), name='up23', padding='same')(conv3_2)
    conv2_3 = concatenate([up2_3, conv2_1, conv2_2], name='merge23', axis=bn_axis)
    conv2_3 = standard_unit(conv2_3, stage='23', nb_filter=nb_filter[1])

    up1_4 = Conv2DTranspose(nb_filter[0], (2, 2), strides=(2, 2), name='up14', padding='same')(conv2_3)
    conv1_4 = concatenate([up1_4, conv1_1, conv1_2, conv1_3], name='merge14', axis=bn_axis)
    conv1_4 = standard_unit(conv1_4, stage='14', nb_filter=nb_filter[0])

    conv5_1 = standard_unit(pool4, stage='51', nb_filter=nb_filter[4])

    up4_2 = Conv2DTranspose(nb_filter[3], (2, 2), strides=(2, 2), name='up42', padding='same')(conv5_1)
    conv4_2 = concatenate([up4_2, conv4_1], name='merge42', axis=bn_axis)
    conv4_2 = standard_unit(conv4_2, stage='42', nb_filter=nb_filter[3])

    up3_3 = Conv2DTranspose(nb_filter[2], (2, 2), strides=(2, 2), name='up33', padding='same')(conv4_2)
    conv3_3 = concatenate([up3_3, conv3_1, conv3_2], name='merge33', axis=bn_axis)
    conv3_3 = standard_unit(conv3_3, stage='33', nb_filter=nb_filter[2])

    up2_4 = Conv2DTranspose(nb_filter[1], (2, 2), strides=(2, 2), name='up24', padding='same')(conv3_3)
    conv2_4 = concatenate([up2_4, conv2_1, conv2_2, conv2_3], name='merge24', axis=bn_axis)
    conv2_4 = standard_unit(conv2_4, stage='24', nb_filter=nb_filter[1])

    up1_5 = Conv2DTranspose(nb_filter[0], (2, 2), strides=(2, 2), name='up15', padding='same')(conv2_4)
    conv1_5 = concatenate([up1_5, conv1_1, conv1_2, conv1_3, conv1_4], name='merge15', axis=bn_axis)
    conv1_5 = standard_unit(conv1_5, stage='15', nb_filter=nb_filter[0])

    # 1x1 sigmoid heads on the four full-resolution nodes; with deep_supervision
    # all four are model outputs, otherwise only the deepest one (output_4).
    nestnet_output_1 = Conv2D(num_class, (1, 1), activation='sigmoid', name='output_1', kernel_initializer = 'he_normal', padding='same', kernel_regularizer=l2(1e-4))(conv1_2)
    nestnet_output_2 = Conv2D(num_class, (1, 1), activation='sigmoid', name='output_2', kernel_initializer = 'he_normal', padding='same', kernel_regularizer=l2(1e-4))(conv1_3)
    nestnet_output_3 = Conv2D(num_class, (1, 1), activation='sigmoid', name='output_3', kernel_initializer = 'he_normal', padding='same', kernel_regularizer=l2(1e-4))(conv1_4)
    nestnet_output_4 = Conv2D(num_class, (1, 1), activation='sigmoid', name='output_4', kernel_initializer = 'he_normal', padding='same', kernel_regularizer=l2(1e-4))(conv1_5)

    if deep_supervision:
        model = Model(inputs=img_input, outputs=[nestnet_output_1,
                                                 nestnet_output_2,
                                                 nestnet_output_3,
                                                 nestnet_output_4])
    else:
        model = Model(inputs=img_input, outputs=[nestnet_output_4])

    return model


if __name__ == '__main__':

    model = U_Net(96,96,1)
    model.summary()

    model = wU_Net(96,96,1)
    model.summary()

    model = Nest_Net(96,96,1)
    model.summary()
--------------------------------------------------------------------------------