├── .gitignore
├── Figures
│   ├── fig_inference_time.png
│   ├── fig_prediction_comparison.png
│   └── fig_unet++.png
├── LICENSE
├── README.md
└── model.py
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# dotenv
.env

# virtualenv
.venv
venv/
ENV/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
--------------------------------------------------------------------------------
/Figures/fig_inference_time.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CarryHJR/Nested-UNet/e477bf8ef4d37e849d11e81cfbb991c7fa2f48d0/Figures/fig_inference_time.png
--------------------------------------------------------------------------------
/Figures/fig_prediction_comparison.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CarryHJR/Nested-UNet/e477bf8ef4d37e849d11e81cfbb991c7fa2f48d0/Figures/fig_prediction_comparison.png
--------------------------------------------------------------------------------
/Figures/fig_unet++.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CarryHJR/Nested-UNet/e477bf8ef4d37e849d11e81cfbb991c7fa2f48d0/Figures/fig_unet++.png
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2018 Zongwei Zhou

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# UNet++: A Nested U-Net Architecture for Medical Image Segmentation

This is an implementation of ["UNet++: A Nested U-Net Architecture for Medical Image Segmentation"](https://arxiv.org/pdf/1807.10165.pdf) in Python, built on the Keras deep learning framework with TensorFlow as the backend. We propose a new architecture, **UNet++** (nested U-Net architecture), for more precise segmentation. We introduce intermediate layers into the U-Net; these naturally form multiple new up-sampling expanding paths of different depths, resulting in an ensemble of U-Nets that share part of their contracting path.
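
To make the nested wiring concrete: in `model.py` each node is named `conv{i}_{j}`, where `i` is the resolution level (1 = full resolution) and `j` is the node's position along that level. Every decoder node `conv{i}_{j}` with `j > 1` concatenates the up-sampled `conv{i+1}_{j-1}` with all earlier nodes on its own level. The sketch below uses hypothetical helpers (`nested_inputs`, `decoder_nodes`) that are not part of this repository; they only list those inputs, in the order `Nest_Net` passes them to `concatenate`, and build no layers.

```python
def nested_inputs(i, j):
    """Names of the tensors concatenated to form node conv{i}_{j} in UNet++.

    Hypothetical helper mirroring the wiring in model.py: the up-sampled
    output of conv{i+1}_{j-1} (named up{i}_{j}), followed by every earlier
    node on level i, conv{i}_1 ... conv{i}_{j-1}.
    """
    if j < 2:
        raise ValueError("conv{i}_1 nodes belong to the encoder and have no skip inputs")
    upsampled = f"up{i}_{j}"  # Conv2DTranspose applied to conv{i+1}_{j-1}
    skips = [f"conv{i}_{k}" for k in range(1, j)]  # dense same-level skips
    return [upsampled] + skips


def decoder_nodes(depth=5):
    """All decoder nodes (i, j) of a UNet++ with `depth` resolution levels."""
    # Level i has decoder nodes j = 2 .. depth - i + 1; the deepest level has none.
    return [(i, j) for i in range(1, depth) for j in range(2, depth - i + 2)]
```

For example, `nested_inputs(1, 5)` returns `['up1_5', 'conv1_1', 'conv1_2', 'conv1_3', 'conv1_4']`, which matches the `merge15` concatenation in `Nest_Net`. A plain U-Net keeps only the last node on each level and only the `conv{i}_1` skip.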

## License

UNet++ is released under the [MIT License](https://github.com/MrGiovanni/Nested-UNet/blob/master/LICENSE).

## Citing UNet++

If you use UNet++ in your research, please cite it using the following BibTeX entry.

```bibtex
@inproceedings{zhou2018nest,
  title={UNet++: A Nested U-Net Architecture for Medical Image Segmentation},
  author={Zhou, Zongwei and Siddiquee, Md Mahfuzur Rahman and Tajbakhsh, Nima and Liang, Jianming},
  booktitle={Deep Learning in Medical Image Analysis},
  year={2018}
}
```

## Contacts (Maintainers)

* Zongwei Zhou, homepage: [zongweiz.com](https://www.zongweiz.com)
* Md Mahfuzur Rahman Siddiquee, GitHub: [mahfuzmohammad](https://github.com/mahfuzmohammad)
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
'''
Keras models for 2D medical image segmentation:

* U_Net    - standard U-Net [Ronneberger et al., 2015]
* wU_Net   - wide U-Net (more filters per stage), for comparison
* Nest_Net - UNet++ [Zhou et al., 2018], optionally with deep supervision

Targets standalone Keras 2.x on a TensorFlow 1.x backend
(mean_iou relies on TF 1.x graph-mode APIs).
'''

import keras
import tensorflow as tf
from keras.models import Model
from keras import backend as K
from keras.layers import Input, Conv2D, Conv2DTranspose, MaxPooling2D, Dropout, concatenate
from keras.regularizers import l2

import numpy as np

smooth = 1.          # additive smoothing term of the Dice coefficient
dropout_rate = 0.5   # dropout applied after every convolution in standard_unit


def mean_iou(y_true, y_pred):
    # Mean IoU averaged over binarization thresholds 0.5, 0.55, ..., 0.95.
    # NOTE: uses TF 1.x APIs (tf.metrics.mean_iou, K.get_session).
    prec = []
    for t in np.arange(0.5, 1.0, 0.05):
        y_pred_ = tf.cast(y_pred > t, tf.int32)
        score, up_opt = tf.metrics.mean_iou(y_true, y_pred_, 2)
        K.get_session().run(tf.local_variables_initializer())
        with tf.control_dependencies([up_opt]):
            score = tf.identity(score)
        prec.append(score)
    return K.mean(K.stack(prec), axis=0)


# Dice coefficient (used as a metric and inside bce_dice_loss)
def dice_coef(y_true, y_pred):
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)


# Custom loss: 0.5 * binary cross-entropy minus Dice. This differs from
# 0.5 * BCE + (1 - Dice) only by a constant, so the gradients are the same,
# but the reported loss value can be negative.
def bce_dice_loss(y_true, y_pred):
    return 0.5 * keras.losses.binary_crossentropy(y_true, y_pred) - dice_coef(y_true, y_pred)


########################################
# 2D Standard
########################################

def standard_unit(input_tensor, stage, nb_filter, kernel_size=3):
    """Two (Conv2D + ELU -> Dropout) blocks; every node of the networks below is one unit."""
    act = 'elu'

    x = Conv2D(nb_filter, (kernel_size, kernel_size), activation=act, name='conv'+stage+'_1', kernel_initializer='he_normal', padding='same', kernel_regularizer=l2(1e-4))(input_tensor)
    x = Dropout(dropout_rate, name='dp'+stage+'_1')(x)
    x = Conv2D(nb_filter, (kernel_size, kernel_size), activation=act, name='conv'+stage+'_2', kernel_initializer='he_normal', padding='same', kernel_regularizer=l2(1e-4))(x)
    x = Dropout(dropout_rate, name='dp'+stage+'_2')(x)

    return x

########################################

def U_Net(img_rows, img_cols, color_type=1, num_class=1):
    """
    Standard U-Net [Ronneberger et al., 2015]
    Total params (color_type=1, num_class=1): 7,759,521
    """
    nb_filter = [32, 64, 128, 256, 512]

    # Handle dimension ordering for different backends; bn_axis is the channel axis
    global bn_axis
    if K.image_data_format() == 'channels_last':
        bn_axis = 3
        img_input = Input(shape=(img_rows, img_cols, color_type), name='main_input')
    else:
        bn_axis = 1
        img_input = Input(shape=(color_type, img_rows, img_cols), name='main_input')

    # Contracting path
    conv1_1 = standard_unit(img_input, stage='11', nb_filter=nb_filter[0])
    pool1 = MaxPooling2D((2, 2), strides=(2, 2), name='pool1')(conv1_1)

    conv2_1 = standard_unit(pool1, stage='21', nb_filter=nb_filter[1])
    pool2 = MaxPooling2D((2, 2), strides=(2, 2), name='pool2')(conv2_1)

    conv3_1 = standard_unit(pool2, stage='31', nb_filter=nb_filter[2])
    pool3 = MaxPooling2D((2, 2), strides=(2, 2), name='pool3')(conv3_1)

    conv4_1 = standard_unit(pool3, stage='41', nb_filter=nb_filter[3])
    pool4 = MaxPooling2D((2, 2), strides=(2, 2), name='pool4')(conv4_1)

    conv5_1 = standard_unit(pool4, stage='51', nb_filter=nb_filter[4])

    # Expanding path: up-sample, concatenate with the skip connection, refine
    up4_2 = Conv2DTranspose(nb_filter[3], (2, 2), strides=(2, 2), name='up42', padding='same')(conv5_1)
    conv4_2 = concatenate([up4_2, conv4_1], name='merge42', axis=bn_axis)
    conv4_2 = standard_unit(conv4_2, stage='42', nb_filter=nb_filter[3])

    up3_3 = Conv2DTranspose(nb_filter[2], (2, 2), strides=(2, 2), name='up33', padding='same')(conv4_2)
    conv3_3 = concatenate([up3_3, conv3_1], name='merge33', axis=bn_axis)
    conv3_3 = standard_unit(conv3_3, stage='33', nb_filter=nb_filter[2])

    up2_4 = Conv2DTranspose(nb_filter[1], (2, 2), strides=(2, 2), name='up24', padding='same')(conv3_3)
    conv2_4 = concatenate([up2_4, conv2_1], name='merge24', axis=bn_axis)
    conv2_4 = standard_unit(conv2_4, stage='24', nb_filter=nb_filter[1])

    up1_5 = Conv2DTranspose(nb_filter[0], (2, 2), strides=(2, 2), name='up15', padding='same')(conv2_4)
    conv1_5 = concatenate([up1_5, conv1_1], name='merge15', axis=bn_axis)
    conv1_5 = standard_unit(conv1_5, stage='15', nb_filter=nb_filter[0])

    unet_output = Conv2D(num_class, (1, 1), activation='sigmoid', name='output', kernel_initializer='he_normal', padding='same', kernel_regularizer=l2(1e-4))(conv1_5)

    model = Model(inputs=img_input, outputs=unet_output)

    return model


def wU_Net(img_rows, img_cols, color_type=1, num_class=1):
    """
    Wide U-Net (wU-Net) for comparison: the U-Net above with more filters per
    stage, so its parameter count is comparable to UNet++.
    Total params (color_type=1, num_class=1): 9,282,246
    """
    # nb_filter = [32, 64, 128, 256, 512]
    nb_filter = [35, 70, 140, 280, 560]

    # Handle dimension ordering for different backends; bn_axis is the channel axis
    global bn_axis
    if K.image_data_format() == 'channels_last':
        bn_axis = 3
        img_input = Input(shape=(img_rows, img_cols, color_type), name='main_input')
    else:
        bn_axis = 1
        img_input = Input(shape=(color_type, img_rows, img_cols), name='main_input')

    # Contracting path
    conv1_1 = standard_unit(img_input, stage='11', nb_filter=nb_filter[0])
    pool1 = MaxPooling2D((2, 2), strides=(2, 2), name='pool1')(conv1_1)

    conv2_1 = standard_unit(pool1, stage='21', nb_filter=nb_filter[1])
    pool2 = MaxPooling2D((2, 2), strides=(2, 2), name='pool2')(conv2_1)

    conv3_1 = standard_unit(pool2, stage='31', nb_filter=nb_filter[2])
    pool3 = MaxPooling2D((2, 2), strides=(2, 2), name='pool3')(conv3_1)

    conv4_1 = standard_unit(pool3, stage='41', nb_filter=nb_filter[3])
    pool4 = MaxPooling2D((2, 2), strides=(2, 2), name='pool4')(conv4_1)

    conv5_1 = standard_unit(pool4, stage='51', nb_filter=nb_filter[4])

    # Expanding path: up-sample, concatenate with the skip connection, refine
    up4_2 = Conv2DTranspose(nb_filter[3], (2, 2), strides=(2, 2), name='up42', padding='same')(conv5_1)
    conv4_2 = concatenate([up4_2, conv4_1], name='merge42', axis=bn_axis)
    conv4_2 = standard_unit(conv4_2, stage='42', nb_filter=nb_filter[3])

    up3_3 = Conv2DTranspose(nb_filter[2], (2, 2), strides=(2, 2), name='up33', padding='same')(conv4_2)
    conv3_3 = concatenate([up3_3, conv3_1], name='merge33', axis=bn_axis)
    conv3_3 = standard_unit(conv3_3, stage='33', nb_filter=nb_filter[2])

    up2_4 = Conv2DTranspose(nb_filter[1], (2, 2), strides=(2, 2), name='up24', padding='same')(conv3_3)
    conv2_4 = concatenate([up2_4, conv2_1], name='merge24', axis=bn_axis)
    conv2_4 = standard_unit(conv2_4, stage='24', nb_filter=nb_filter[1])

    up1_5 = Conv2DTranspose(nb_filter[0], (2, 2), strides=(2, 2), name='up15', padding='same')(conv2_4)
    conv1_5 = concatenate([up1_5, conv1_1], name='merge15', axis=bn_axis)
    conv1_5 = standard_unit(conv1_5, stage='15', nb_filter=nb_filter[0])

    unet_output = Conv2D(num_class, (1, 1), activation='sigmoid', name='output', kernel_initializer='he_normal', padding='same', kernel_regularizer=l2(1e-4))(conv1_5)

    model = Model(inputs=img_input, outputs=unet_output)

    return model


def Nest_Net(img_rows, img_cols, color_type=1, num_class=1, deep_supervision=False):
    """
    Standard UNet++ [Zhou et al., 2018]
    Total params (color_type=1, num_class=1, deep_supervision=False): 9,041,601
    """
    nb_filter = [32, 64, 128, 256, 512]

    # Handle dimension ordering for different backends; bn_axis is the channel axis
    global bn_axis
    if K.image_data_format() == 'channels_last':
        bn_axis = 3
        img_input = Input(shape=(img_rows, img_cols, color_type), name='main_input')
    else:
        bn_axis = 1
        img_input = Input(shape=(color_type, img_rows, img_cols), name='main_input')

    # Nodes are built column by column: conv{i}_{j} concatenates the up-sampled
    # conv{i+1}_{j-1} with all earlier nodes on level i (conv{i}_1 .. conv{i}_{j-1}).
    conv1_1 = standard_unit(img_input, stage='11', nb_filter=nb_filter[0])
    pool1 = MaxPooling2D((2, 2), strides=(2, 2), name='pool1')(conv1_1)

    conv2_1 = standard_unit(pool1, stage='21', nb_filter=nb_filter[1])
    pool2 = MaxPooling2D((2, 2), strides=(2, 2), name='pool2')(conv2_1)

    up1_2 = Conv2DTranspose(nb_filter[0], (2, 2), strides=(2, 2), name='up12', padding='same')(conv2_1)
    conv1_2 = concatenate([up1_2, conv1_1], name='merge12', axis=bn_axis)
    conv1_2 = standard_unit(conv1_2, stage='12', nb_filter=nb_filter[0])

    conv3_1 = standard_unit(pool2, stage='31', nb_filter=nb_filter[2])
    pool3 = MaxPooling2D((2, 2), strides=(2, 2), name='pool3')(conv3_1)

    up2_2 = Conv2DTranspose(nb_filter[1], (2, 2), strides=(2, 2), name='up22', padding='same')(conv3_1)
    conv2_2 = concatenate([up2_2, conv2_1], name='merge22', axis=bn_axis)
    conv2_2 = standard_unit(conv2_2, stage='22', nb_filter=nb_filter[1])

    up1_3 = Conv2DTranspose(nb_filter[0], (2, 2), strides=(2, 2), name='up13', padding='same')(conv2_2)
    conv1_3 = concatenate([up1_3, conv1_1, conv1_2], name='merge13', axis=bn_axis)
    conv1_3 = standard_unit(conv1_3, stage='13', nb_filter=nb_filter[0])

    conv4_1 = standard_unit(pool3, stage='41', nb_filter=nb_filter[3])
    pool4 = MaxPooling2D((2, 2), strides=(2, 2), name='pool4')(conv4_1)

    up3_2 = Conv2DTranspose(nb_filter[2], (2, 2), strides=(2, 2), name='up32', padding='same')(conv4_1)
    conv3_2 = concatenate([up3_2, conv3_1], name='merge32', axis=bn_axis)
    conv3_2 = standard_unit(conv3_2, stage='32', nb_filter=nb_filter[2])

    up2_3 = Conv2DTranspose(nb_filter[1], (2, 2), strides=(2, 2), name='up23', padding='same')(conv3_2)
    conv2_3 = concatenate([up2_3, conv2_1, conv2_2], name='merge23', axis=bn_axis)
    conv2_3 = standard_unit(conv2_3, stage='23', nb_filter=nb_filter[1])

    up1_4 = Conv2DTranspose(nb_filter[0], (2, 2), strides=(2, 2), name='up14', padding='same')(conv2_3)
    conv1_4 = concatenate([up1_4, conv1_1, conv1_2, conv1_3], name='merge14', axis=bn_axis)
    conv1_4 = standard_unit(conv1_4, stage='14', nb_filter=nb_filter[0])

    conv5_1 = standard_unit(pool4, stage='51', nb_filter=nb_filter[4])

    up4_2 = Conv2DTranspose(nb_filter[3], (2, 2), strides=(2, 2), name='up42', padding='same')(conv5_1)
    conv4_2 = concatenate([up4_2, conv4_1], name='merge42', axis=bn_axis)
    conv4_2 = standard_unit(conv4_2, stage='42', nb_filter=nb_filter[3])

    up3_3 = Conv2DTranspose(nb_filter[2], (2, 2), strides=(2, 2), name='up33', padding='same')(conv4_2)
    conv3_3 = concatenate([up3_3, conv3_1, conv3_2], name='merge33', axis=bn_axis)
    conv3_3 = standard_unit(conv3_3, stage='33', nb_filter=nb_filter[2])

    up2_4 = Conv2DTranspose(nb_filter[1], (2, 2), strides=(2, 2), name='up24', padding='same')(conv3_3)
    conv2_4 = concatenate([up2_4, conv2_1, conv2_2, conv2_3], name='merge24', axis=bn_axis)
    conv2_4 = standard_unit(conv2_4, stage='24', nb_filter=nb_filter[1])

    up1_5 = Conv2DTranspose(nb_filter[0], (2, 2), strides=(2, 2), name='up15', padding='same')(conv2_4)
    conv1_5 = concatenate([up1_5, conv1_1, conv1_2, conv1_3, conv1_4], name='merge15', axis=bn_axis)
    conv1_5 = standard_unit(conv1_5, stage='15', nb_filter=nb_filter[0])

    # One 1x1 sigmoid head per full-resolution node; with deep supervision all
    # four are model outputs, otherwise only the full-depth output_4 is used.
    nestnet_output_1 = Conv2D(num_class, (1, 1), activation='sigmoid', name='output_1', kernel_initializer='he_normal', padding='same', kernel_regularizer=l2(1e-4))(conv1_2)
    nestnet_output_2 = Conv2D(num_class, (1, 1), activation='sigmoid', name='output_2', kernel_initializer='he_normal', padding='same', kernel_regularizer=l2(1e-4))(conv1_3)
    nestnet_output_3 = Conv2D(num_class, (1, 1), activation='sigmoid', name='output_3', kernel_initializer='he_normal', padding='same', kernel_regularizer=l2(1e-4))(conv1_4)
    nestnet_output_4 = Conv2D(num_class, (1, 1), activation='sigmoid', name='output_4', kernel_initializer='he_normal', padding='same', kernel_regularizer=l2(1e-4))(conv1_5)

    if deep_supervision:
        model = Model(inputs=img_input, outputs=[nestnet_output_1,
                                                 nestnet_output_2,
                                                 nestnet_output_3,
                                                 nestnet_output_4])
    else:
        model = Model(inputs=img_input, outputs=[nestnet_output_4])

    return model


if __name__ == '__main__':

    model = U_Net(96, 96, 1)
    model.summary()

    model = wU_Net(96, 96, 1)
    model.summary()

    model = Nest_Net(96, 96, 1)
    model.summary()
--------------------------------------------------------------------------------
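
The `Total params` figures in the `model.py` docstrings can be reproduced by hand. The sketch below uses hypothetical helpers (`conv_params`, `unit_params`, `count_params`), not part of the repository, and assumes the layer configuration in `model.py`: each `standard_unit` is two 3x3 convolutions with bias, each up-sampling step is a 2x2 `Conv2DTranspose` with bias, the head is a single 1x1 convolution, and Dropout and pooling add no weights. A k x k convolution from c_in to c_out channels has k*k*c_in*c_out + c_out parameters.

```python
def conv_params(k, c_in, c_out):
    """Weights plus biases of a k x k convolution (same count for Conv2DTranspose)."""
    return k * k * c_in * c_out + c_out


def unit_params(c_in, f):
    """Two 3x3 convolutions with f filters each, as in standard_unit()."""
    return conv_params(3, c_in, f) + conv_params(3, f, f)


def count_params(nb_filter, nested, color_type=1, num_class=1):
    """Parameters of U_Net/wU_Net (nested=False) or Nest_Net without deep supervision."""
    depth = len(nb_filter)

    # Encoder nodes conv{i}_1, i = 1..depth
    total = unit_params(color_type, nb_filter[0])
    for level in range(1, depth):
        total += unit_params(nb_filter[level - 1], nb_filter[level])

    # Decoder node (i, j): up-sample level i+1 to f = nb_filter[i-1] channels, then
    # concatenate with j-1 same-level nodes (UNet++) or with conv{i}_1 only (U-Net).
    for i in range(1, depth):
        last_j = depth - i + 1
        first_j = 2 if nested else last_j  # U-Net keeps only the last node per level
        for j in range(first_j, last_j + 1):
            f = nb_filter[i - 1]
            total += conv_params(2, nb_filter[i], f)  # Conv2DTranspose
            skips = (j - 1) if nested else 1
            total += unit_params(f * (1 + skips), f)

    # 1x1 sigmoid output head on the full-resolution node
    total += conv_params(1, nb_filter[0], num_class)
    return total
```

Under these assumptions, `count_params([32, 64, 128, 256, 512], nested=False)` gives 7,759,521, `count_params([35, 70, 140, 280, 560], nested=False)` gives 9,282,246 and `count_params([32, 64, 128, 256, 512], nested=True)` gives 9,041,601, matching the three docstrings. With `deep_supervision=True`, `Nest_Net` also keeps the heads `output_1` to `output_3`, which for `num_class=1` adds 3 x 33 = 99 parameters.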