├── .editorconfig ├── .gitignore ├── README.md ├── mnet_deep_cdr ├── .ipynb_checkpoints │ ├── Demo_Full-checkpoint.ipynb │ └── Demo_ipython-checkpoint.ipynb ├── Model_DiscSeg.py ├── Model_MNet.py ├── REFUGE_result │ ├── REFUGE_result.zip │ └── ROC_curve.png ├── Step_1_Disc_Crop.py ├── Step_2_MNet_train.py ├── Step_3_MNet_test.py ├── Step_4_CDR_output.m ├── __init__.py ├── _data │ └── logging.yml ├── deep_model │ ├── Model_DiscSeg_ORIGA.h5 │ └── Model_MNet_REFUGE.h5 ├── mat_scr.zip ├── mnet_utils.py ├── result │ └── .gitkeep └── test_img │ ├── CS50041_R.jpg │ └── CS50106_R.jpg ├── requirements.txt └── setup.py /.editorconfig: -------------------------------------------------------------------------------- 1 | root = true 2 | 3 | [*] 4 | charset = utf-8 5 | end_of_line = lf 6 | insert_final_newline = true 7 | indent_style = space 8 | indent_size = 4 9 | 10 | [*.py] 11 | max_line_length = 119 12 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | 3 | # Jetbrains IDEs 4 | .idea 5 | out/ 6 | 7 | # File-based project format 8 | *.iws 9 | 10 | # JIRA plugin 11 | atlassian-ide-plugin.xml 12 | 13 | # Crashlytics plugin (for Android Studio and IntelliJ) 14 | com_crashlytics_export_strings.xml 15 | crashlytics.properties 16 | crashlytics-build.properties 17 | fabric.properties 18 | 19 | ### Python ### 20 | # Byte-compiled / optimized / DLL files 21 | __pycache__/ 22 | *.py[cod] 23 | *$py.class 24 | 25 | # C extensions 26 | *.so 27 | 28 | # Distribution / packaging 29 | .Python 30 | build/ 31 | develop-eggs/ 32 | dist/ 33 | downloads/ 34 | eggs/ 35 | .eggs/ 36 | lib/ 37 | lib64/ 38 | parts/ 39 | sdist/ 40 | var/ 41 | wheels/ 42 | pip-wheel-metadata/ 43 | share/python-wheels/ 44 | *.egg-info/ 45 | .installed.cfg 46 | *.egg 47 | MANIFEST 48 | 49 | # PyInstaller 50 | # Usually these files are written by a python script from a template 51 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 52 | *.manifest 53 | *.spec 54 | 55 | # Installer logs 56 | pip-log.txt 57 | pip-delete-this-directory.txt 58 | 59 | # Unit test / coverage reports 60 | htmlcov/ 61 | .tox/ 62 | .nox/ 63 | .coverage 64 | .coverage.* 65 | .cache 66 | nosetests.xml 67 | coverage.xml 68 | *.cover 69 | .hypothesis/ 70 | .pytest_cache/ 71 | 72 | # Translations 73 | *.mo 74 | *.pot 75 | 76 | # Django stuff: 77 | *.log 78 | local_settings.py 79 | db.sqlite3 80 | db.sqlite3-journal 81 | 82 | # Flask stuff: 83 | instance/ 84 | .webassets-cache 85 | 86 | # Scrapy stuff: 87 | .scrapy 88 | 89 | # Sphinx documentation 90 | docs/_build/ 91 | 92 | # PyBuilder 93 | target/ 94 | 95 | # Jupyter Notebook (maybe ignore?) 96 | ## mnet_deep_cdr/.ipynb_checkpoints 97 | 98 | # IPython 99 | profile_default/ 100 | ipython_config.py 101 | 102 | # pyenv 103 | .python-version 104 | 105 | # pipenv 106 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 107 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 108 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 109 | # install all needed dependencies. 
110 | #Pipfile.lock 111 | 112 | # celery beat schedule file 113 | celerybeat-schedule 114 | 115 | # SageMath parsed files 116 | *.sage.py 117 | 118 | # Environments 119 | .env 120 | .venv 121 | env/ 122 | venv/ 123 | ENV/ 124 | env.bak/ 125 | venv.bak/ 126 | 127 | # Spyder project settings 128 | .spyderproject 129 | .spyproject 130 | 131 | # Rope project settings 132 | .ropeproject 133 | 134 | # mkdocs documentation 135 | /site 136 | 137 | # mypy 138 | .mypy_cache/ 139 | .dmypy.json 140 | dmypy.json 141 | 142 | # Pyre type checker 143 | .pyre/ 144 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | mnet_deep_cdr 2 | ============= 3 | ![Python version range](https://img.shields.io/badge/python-2.7%E2%80%933.6+-blue.svg) 4 | **Code for TMI 2018 "Joint Optic Disc and Cup Segmentation Based on Multi-label Deep Network and Polar Transformation"** 5 | 6 | Project homepage: http://hzfu.github.io/proj_glaucoma_fundus.html 7 | 8 | ## Install dependencies 9 | 10 | pip install -r requirements.txt 11 | 12 | ## Install package 13 | 14 | pip install . 15 | 16 | OpenCV will need to be installed separately. 17 | 18 | --- 19 | 20 | 1. The code is based on: *TensorFlow 1.14 (with Keras) + Matlab* 21 | 2. The deep network output is the raw segmentation result without ellipse fitting. The Matlab code performs the ellipse fitting and CDR calculation (using the PDollar toolbox: https://pdollar.github.io/toolbox/). 22 | 3. You can run 'Step\_3\_MNet\_test.py' directly to test any new image. 23 | 4. We also provide the validation and test results on the [REFUGE dataset](https://refuge.grand-challenge.org/home/) in the 'REFUGE\_result' folder. 24 | 5. **Note: Because 'scipy.misc.imresize' (present in SciPy 1.0.0) was removed in SciPy 1.3.0, the original trained model 'Model\_MNet\_REFUGE.h5' is no longer suitable. If you want to segment the disc/cup from fundus images, you can consider our newest methods, CE-Net and AG-Net, which obtain better performance and are also released at:** 25 | - CE-Net: [https://github.com/Guzaiwang/CE-Net](https://github.com/Guzaiwang/CE-Net) 26 | - AG-Net: [https://github.com/HzFu/AGNet](https://github.com/HzFu/AGNet) 27 | 6. A PyTorch implementation of M-Net can be found in **AG-Net**: [https://github.com/HzFu/AGNet](https://github.com/HzFu/AGNet) 28 | 29 | 30 | --- 31 | 32 | **Main files:** 33 | 34 | 1. 'Step\_1\_Disc\_Crop.py': The disc detection code for whole fundus images. 35 | 2. 'Step\_2\_MNet\_train.py': The M-Net training code. 36 | 3. 'Step\_3\_MNet\_test.py': The M-Net testing code. 37 | 4. 'Step\_4\_CDR\_output.m': The ellipse fitting for disc and cup, and the CDR calculation. 38 | 39 | --- 40 | 41 | **If you use this code, please cite the following papers:** 42 | 43 | 1. Huazhu Fu, Jun Cheng, Yanwu Xu, Damon Wing Kee Wong, Jiang Liu, and Xiaochun Cao, "Joint Optic Disc and Cup Segmentation Based on Multi-label Deep Network and Polar Transformation", IEEE Transactions on Medical Imaging (TMI), vol. 37, no. 7, pp. 1597–1605, 2018. [[PDF]](https://arxiv.org/abs/1801.00926) 44 | 2. Huazhu Fu, Jun Cheng, Yanwu Xu, Changqing Zhang, Damon Wing Kee Wong, Jiang Liu, and Xiaochun Cao, "Disc-aware Ensemble Network for Glaucoma Screening from Fundus Image", IEEE Transactions on Medical Imaging (TMI), vol. 37, no. 11, pp. 2493–2501, 2018.
[[PDF]](http://arxiv.org/abs/1805.07549) 45 | 46 | 47 | **There are also some related works on medical image segmentation for your reference:** 48 | 49 | 1. "Attention Guided Network for Retinal Image Segmentation," in MICCAI, 2019. [[PDF]](http://arxiv.org/abs/1907.12930) [[Github Code]](https://github.com/HzFu/AGNet) 50 | 2. "CE-Net: Context Encoder Network for 2D Medical Image Segmentation," IEEE TMI, 2019. [[PDF]](https://arxiv.org/abs/1903.02740) [[Github Code]](https://github.com/Guzaiwang/CE-Net) 51 | 52 | --- 53 | 54 | **Note: for the ORIGA and SCES datasets** 55 | 56 | Unfortunately, the ORIGA and SCES datasets cannot be released due to clinical policy. 57 | However, here is another glaucoma challenge, [**Retinal Fundus Glaucoma Challenge (REFUGE)**](https://refuge.grand-challenge.org/home/), covering disc/cup segmentation, glaucoma screening, and fovea localization. 58 | 59 | --- 60 | **License** 61 | 62 | The code is released under the [Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License](https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode) for NonCommercial use only. Any commercial use requires formal permission first. 63 | 64 | --- 65 | 66 | Update log: 67 | 68 | - 19.01.22: Added training code and uploaded the results on the REFUGE dataset. 69 | - 18.06.30: Added ellipse fitting code (based on Matlab) and fixed the bug for macula-centered fundus images. 70 | - 18.06.29: Added disc detection code (based on U-Net). 71 | - 18.02.26: Added CDR calculation code (based on Matlab). 72 | - 18.02.24: Released the code. 73 | -------------------------------------------------------------------------------- /mnet_deep_cdr/Model_DiscSeg.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from __future__ import absolute_import 4 | from __future__ import print_function 5 | 6 | from tensorflow.python.keras.layers import (Input, concatenate, Conv2D, MaxPooling2D, 7 | Conv2DTranspose, UpSampling2D, average) 8 | from tensorflow.python.keras.models import Model 9 | 10 | 11 | def DeepModel(size_set=640): 12 | img_input = Input(shape=(size_set, size_set, 3)) 13 | 14 | conv1 = Conv2D(32, (3, 3), activation='relu', padding='same', name='block1_conv1')(img_input) 15 | conv1 = Conv2D(32, (3, 3), activation='relu', padding='same', name='block1_conv2')(conv1) 16 | pool1 = MaxPooling2D(pool_size=(2, 2))(conv1) 17 | 18 | conv2 = Conv2D(64, (3, 3), activation='relu', padding='same', name='block2_conv1')(pool1) 19 | conv2 = Conv2D(64, (3, 3), activation='relu', padding='same', name='block2_conv2')(conv2) 20 | pool2 = MaxPooling2D(pool_size=(2, 2))(conv2) 21 | 22 | conv3 = Conv2D(128, (3, 3), activation='relu', padding='same', name='block3_conv1')(pool2) 23 | conv3 = Conv2D(128, (3, 3), activation='relu', padding='same', name='block3_conv2')(conv3) 24 | pool3 = MaxPooling2D(pool_size=(2, 2))(conv3) 25 | 26 | conv4 = Conv2D(256, (3, 3), activation='relu', padding='same', name='block4_conv1')(pool3) 27 | conv4 = Conv2D(256, (3, 3), activation='relu', padding='same', name='block4_conv2')(conv4) 28 | pool4 = MaxPooling2D(pool_size=(2, 2))(conv4) 29 | 30 | conv5 = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv1')(pool4) 31 | conv5 = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv2')(conv5) 32 | 33 | up6 = concatenate( 34 | [Conv2DTranspose(256, (2, 2), strides=(2, 2), padding='same', name='block6_dconv')(conv5), conv4], 35 | axis=3) 36 | conv6 = Conv2D(256,
(3, 3), activation='relu', padding='same', name='block6_conv1')(up6) 37 | conv6 = Conv2D(256, (3, 3), activation='relu', padding='same', name='block6_conv2')(conv6) 38 | 39 | up7 = concatenate( 40 | [Conv2DTranspose(128, (2, 2), strides=(2, 2), padding='same', name='block7_dconv')(conv6), conv3], 41 | axis=3) 42 | conv7 = Conv2D(128, (3, 3), activation='relu', padding='same', name='block7_conv1')(up7) 43 | conv7 = Conv2D(128, (3, 3), activation='relu', padding='same', name='block7_conv2')(conv7) 44 | 45 | up8 = concatenate([Conv2DTranspose(64, (2, 2), strides=(2, 2), padding='same', name='block8_dconv')(conv7), conv2], 46 | axis=3) 47 | conv8 = Conv2D(64, (3, 3), activation='relu', padding='same', name='block8_conv1')(up8) 48 | conv8 = Conv2D(64, (3, 3), activation='relu', padding='same', name='block8_conv2')(conv8) 49 | 50 | up9 = concatenate([Conv2DTranspose(32, (2, 2), strides=(2, 2), padding='same', name='block9_dconv')(conv8), conv1], 51 | axis=3) 52 | conv9 = Conv2D(32, (3, 3), activation='relu', padding='same', name='block9_conv1')(up9) 53 | conv9 = Conv2D(32, (3, 3), activation='relu', padding='same', name='block9_conv2')(conv9) 54 | 55 | side6 = UpSampling2D(size=(8, 8))(conv6) 56 | side7 = UpSampling2D(size=(4, 4))(conv7) 57 | side8 = UpSampling2D(size=(2, 2))(conv8) 58 | out6 = Conv2D(1, (1, 1), activation='sigmoid', name='side_6')(side6) 59 | out7 = Conv2D(1, (1, 1), activation='sigmoid', name='side_7')(side7) 60 | out8 = Conv2D(1, (1, 1), activation='sigmoid', name='side_8')(side8) 61 | out9 = Conv2D(1, (1, 1), activation='sigmoid', name='side_9')(conv9) 62 | 63 | out10 = average([out6, out7, out8, out9]) 64 | # out10 = Conv2D(1, (1, 1), activation='sigmoid', name='side_10')(out10) 65 | 66 | return Model(inputs=[img_input], outputs=[out10]) 67 | -------------------------------------------------------------------------------- /mnet_deep_cdr/Model_MNet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from __future__ import absolute_import 4 | from __future__ import print_function 5 | 6 | from tensorflow.python.keras.layers import (Input, concatenate, Conv2D, MaxPooling2D, AveragePooling2D, 7 | Conv2DTranspose, UpSampling2D, average) 8 | from tensorflow.python.keras.models import Model 9 | 10 | 11 | def DeepModel(size_set=800): 12 | img_input = Input(shape=(size_set, size_set, 3)) 13 | 14 | scale_img_2 = AveragePooling2D(pool_size=(2, 2))(img_input) 15 | scale_img_3 = AveragePooling2D(pool_size=(2, 2))(scale_img_2) 16 | scale_img_4 = AveragePooling2D(pool_size=(2, 2))(scale_img_3) 17 | 18 | conv1 = Conv2D(32, (3, 3), padding='same', activation='relu', name='block1_conv1')(img_input) 19 | conv1 = Conv2D(32, (3, 3), padding='same', activation='relu', name='block1_conv2')(conv1) 20 | pool1 = MaxPooling2D(pool_size=(2, 2))(conv1) 21 | 22 | input2 = Conv2D(64, (3, 3), padding='same', activation='relu', name='block2_input1')(scale_img_2) 23 | input2 = concatenate([input2, pool1], axis=3) 24 | conv2 = Conv2D(64, (3, 3), padding='same', activation='relu', name='block2_conv1')(input2) 25 | conv2 = Conv2D(64, (3, 3), padding='same', activation='relu', name='block2_conv2')(conv2) 26 | pool2 = MaxPooling2D(pool_size=(2, 2))(conv2) 27 | 28 | input3 = Conv2D(128, (3, 3), padding='same', activation='relu', name='block3_input1')(scale_img_3) 29 | input3 = concatenate([input3, pool2], axis=3) 30 | conv3 = Conv2D(128, (3, 3), padding='same', activation='relu', name='block3_conv1')(input3) 31 | conv3 = Conv2D(128, (3, 3), 
padding='same', activation='relu', name='block3_conv2')(conv3) 32 | pool3 = MaxPooling2D(pool_size=(2, 2))(conv3) 33 | 34 | input4 = Conv2D(256, (3, 3), padding='same', activation='relu', name='block4_input1')(scale_img_4) 35 | input4 = concatenate([input4, pool3], axis=3) 36 | conv4 = Conv2D(256, (3, 3), padding='same', activation='relu', name='block4_conv1')(input4) 37 | conv4 = Conv2D(256, (3, 3), padding='same', activation='relu', name='block4_conv2')(conv4) 38 | pool4 = MaxPooling2D(pool_size=(2, 2))(conv4) 39 | 40 | conv5 = Conv2D(512, (3, 3), padding='same', activation='relu', name='block5_conv1')(pool4) 41 | conv5 = Conv2D(512, (3, 3), padding='same', activation='relu', name='block5_conv2')(conv5) 42 | 43 | up6 = concatenate( 44 | [Conv2DTranspose(256, (2, 2), strides=(2, 2), padding='same', name='block6_dconv')(conv5), conv4], 45 | axis=3) 46 | conv6 = Conv2D(256, (3, 3), padding='same', activation='relu', name='block6_conv1')(up6) 47 | conv6 = Conv2D(256, (3, 3), padding='same', activation='relu', name='block6_conv2')(conv6) 48 | 49 | up7 = concatenate( 50 | [Conv2DTranspose(128, (2, 2), strides=(2, 2), padding='same', name='block7_dconv')(conv6), conv3], 51 | axis=3) 52 | conv7 = Conv2D(128, (3, 3), padding='same', activation='relu', name='block7_conv1')(up7) 53 | conv7 = Conv2D(128, (3, 3), padding='same', activation='relu', name='block7_conv2')(conv7) 54 | 55 | up8 = concatenate([Conv2DTranspose(64, (2, 2), strides=(2, 2), padding='same', name='block8_dconv')(conv7), conv2], 56 | axis=3) 57 | conv8 = Conv2D(64, (3, 3), padding='same', activation='relu', name='block8_conv1')(up8) 58 | conv8 = Conv2D(64, (3, 3), padding='same', activation='relu', name='block8_conv2')(conv8) 59 | 60 | up9 = concatenate([Conv2DTranspose(32, (2, 2), strides=(2, 2), padding='same', name='block9_dconv')(conv8), conv1], 61 | axis=3) 62 | conv9 = Conv2D(32, (3, 3), padding='same', activation='relu', name='block9_conv1')(up9) 63 | conv9 = Conv2D(32, (3, 3), padding='same', activation='relu', name='block9_conv2')(conv9) 64 | 65 | side6 = UpSampling2D(size=(8, 8))(conv6) 66 | side7 = UpSampling2D(size=(4, 4))(conv7) 67 | side8 = UpSampling2D(size=(2, 2))(conv8) 68 | out6 = Conv2D(2, (1, 1), activation='sigmoid', name='side_63')(side6) 69 | out7 = Conv2D(2, (1, 1), activation='sigmoid', name='side_73')(side7) 70 | out8 = Conv2D(2, (1, 1), activation='sigmoid', name='side_83')(side8) 71 | out9 = Conv2D(2, (1, 1), activation='sigmoid', name='side_93')(conv9) 72 | 73 | out10 = average([out6, out7, out8, out9]) 74 | 75 | return Model(inputs=[img_input], outputs=[out6, out7, out8, out9, out10]) 76 | -------------------------------------------------------------------------------- /mnet_deep_cdr/REFUGE_result/REFUGE_result.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HzFu/MNet_DeepCDR/e094023d5390ffc1606aba682e48eacf272fdba9/mnet_deep_cdr/REFUGE_result/REFUGE_result.zip -------------------------------------------------------------------------------- /mnet_deep_cdr/REFUGE_result/ROC_curve.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HzFu/MNet_DeepCDR/e094023d5390ffc1606aba682e48eacf272fdba9/mnet_deep_cdr/REFUGE_result/ROC_curve.png -------------------------------------------------------------------------------- /mnet_deep_cdr/Step_1_Disc_Crop.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from 
__future__ import print_function 4 | 5 | from os import path 6 | from sys import modules 7 | 8 | import cv2 9 | import matplotlib.pyplot as plt 10 | import numpy as np 11 | from PIL import Image 12 | from pkg_resources import resource_filename 13 | from skimage.measure import label, regionprops 14 | from skimage.transform import rotate, resize 15 | from tensorflow.python.keras.preprocessing import image 16 | 17 | from mnet_deep_cdr import Model_DiscSeg as DiscModel 18 | from mnet_deep_cdr.mnet_utils import BW_img, disc_crop, mk_dir, files_with_ext 19 | 20 | disc_list = [400, 500, 600, 700, 800] 21 | DiscROI_size = 800 22 | DiscSeg_size = 640 23 | CDRSeg_size = 400 24 | 25 | data_type = '.jpg' 26 | parent_dir = path.dirname(resource_filename(modules[__name__].__name__, '__init__.py')) 27 | data_img_path = path.abspath(path.join(parent_dir, 'data', 'REFUGE-Training400', 'Training400', 'Glaucoma')) 28 | label_img_path = path.abspath(path.join(parent_dir, 'data', 'Annotation-Training400', 29 | 'Annotation-Training400', 'Disc_Cup_Masks', 'Glaucoma')) 30 | 31 | data_save_path = mk_dir(path.join(parent_dir, 'training_crop', 'data')) 32 | label_save_path = mk_dir(path.join(parent_dir, 'training_crop', 'label')) 33 | 34 | file_test_list = files_with_ext(data_img_path, data_type) 35 | 36 | DiscSeg_model = DiscModel.DeepModel(size_set=DiscSeg_size) 37 | DiscSeg_model.load_weights(path.join(parent_dir, 'deep_model', 'Model_DiscSeg_ORIGA.h5')) 38 | 39 | Disc_flat = None 40 | 41 | for lineIdx, temp_txt in enumerate(file_test_list): 42 | print('Processing Img {idx}: {temp_txt}'.format(idx=lineIdx + 1, temp_txt=temp_txt)) 43 | 44 | # load image 45 | org_img = np.asarray(image.load_img(path.join(data_img_path, temp_txt))) 46 | 47 | # load label 48 | org_label = np.asarray(image.load_img(path.join(label_img_path, temp_txt[:-4] + '.bmp')))[:, :, 0] 49 | new_label = np.zeros(np.shape(org_label) + (3,), dtype=np.uint8) 50 | new_label[org_label < 200, 0] = 255 51 | new_label[org_label < 100, 1] = 255 52 | 53 | # Disc region detection by U-Net 54 | temp_img = resize(org_img, (DiscSeg_size, DiscSeg_size, 3)) * 255 55 | temp_img = np.reshape(temp_img, (1,) + temp_img.shape) 56 | disc_map = DiscSeg_model.predict([temp_img]) 57 | 58 | disc_map = BW_img(np.reshape(disc_map, (DiscSeg_size, DiscSeg_size)), 0.5) 59 | 60 | regions = regionprops(label(disc_map)) 61 | C_x = int(regions[0].centroid[0] * org_img.shape[0] / DiscSeg_size) 62 | C_y = int(regions[0].centroid[1] * org_img.shape[1] / DiscSeg_size) 63 | 64 | for disc_idx, DiscROI_size in enumerate(disc_list): 65 | disc_region, err_coord, crop_coord = disc_crop(org_img, DiscROI_size, C_x, C_y) 66 | label_region, _, _ = disc_crop(new_label, DiscROI_size, C_x, C_y) 67 | Disc_flat = rotate(cv2.linearPolar(disc_region, (DiscROI_size / 2, DiscROI_size / 2), DiscROI_size / 2, 68 | cv2.INTER_NEAREST + cv2.WARP_FILL_OUTLIERS), -90) 69 | Label_flat = rotate(cv2.linearPolar(label_region, (DiscROI_size / 2, DiscROI_size / 2), DiscROI_size / 2, 70 | cv2.INTER_NEAREST + cv2.WARP_FILL_OUTLIERS), -90) 71 | 72 | disc_result = Image.fromarray((Disc_flat * 255).astype(np.uint8)) 73 | filename = '{}_{}.png'.format(temp_txt[:-4], DiscROI_size) 74 | disc_result.save(path.join(data_save_path, filename)) 75 | label_result = Image.fromarray((Label_flat * 255).astype(np.uint8)) 76 | label_result.save(path.join(label_save_path, filename)) 77 | 78 | plt.imshow(Disc_flat) 79 | plt.show() 80 | -------------------------------------------------------------------------------- 
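The polar transformation is the step that 'Step\_1\_Disc\_Crop.py' and 'Step\_3\_MNet\_test.py' share: the cropped disc ROI is mapped from Cartesian to polar coordinates before segmentation, and the predicted maps are mapped back afterwards. Below is a minimal, hedged sketch of that round trip (not part of the repository); 'my_disc_roi.png' is a placeholder file name, and the sketch assumes OpenCV and scikit-image are installed as in requirements.txt.

```python
# Hedged sketch of the forward/inverse polar transform used in Step_1/Step_3.
# 'my_disc_roi.png' is a hypothetical square disc crop (DiscROI_size x DiscROI_size).
import cv2
import numpy as np
from skimage.transform import rotate

DiscROI_size = 600
disc_region = cv2.imread('my_disc_roi.png')  # placeholder input, shape (600, 600, 3)

# Cartesian -> polar: radii become rows, angles become columns; the -90 degree
# rotation matches the orientation the M-Net scripts feed to the network.
Disc_flat = rotate(
    cv2.linearPolar(disc_region, (DiscROI_size / 2, DiscROI_size / 2),
                    DiscROI_size / 2, cv2.WARP_FILL_OUTLIERS),
    -90).astype(np.float32)

# Polar -> Cartesian: undo the rotation, then invert the mapping, mirroring how
# Step_3 projects the predicted disc/cup probability maps back onto the ROI.
restored = cv2.linearPolar(
    rotate(Disc_flat, 90).astype(np.float32),
    (DiscROI_size / 2, DiscROI_size / 2),
    DiscROI_size / 2, cv2.WARP_FILL_OUTLIERS + cv2.WARP_INVERSE_MAP)
print(disc_region.shape, Disc_flat.shape, restored.shape)
```

In 'Step\_1\_Disc\_Crop.py' the forward transform is applied to both the image crop and its label crop, so the polar-space training pairs written to 'training_crop' stay aligned.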
/mnet_deep_cdr/Step_2_MNet_train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from __future__ import print_function 4 | 5 | import random 6 | from os import path 7 | from sys import modules 8 | 9 | from pkg_resources import resource_filename 10 | from tensorflow.python.keras.optimizers import SGD 11 | 12 | from mnet_deep_cdr import Model_MNet as DeepModel 13 | from mnet_deep_cdr.mnet_utils import dice_coef_loss, train_loader, mk_dir, files_with_ext 14 | 15 | parent_dir = path.dirname(resource_filename(modules[__name__].__name__, '__init__.py')) 16 | result_path = mk_dir(path.join(parent_dir, 'deep_model')) 17 | pre_model_file = path.join(parent_dir, 'deep_model', 'Model_MNet_REFUGE.h5') 18 | save_model_file = path.join(parent_dir, 'deep_model', 'Model_MNet_REFUGE_v2.h5') 19 | 20 | root_path = path.join(parent_dir, 'training_crop') 21 | train_data_path = path.join(root_path, 'data') 22 | train_mask_path = path.join(root_path, 'label') 23 | 24 | val_data_path = path.join(root_path, 'val_data', 'data') 25 | val_mask_path = path.join(root_path, 'val_data', 'label') 26 | 27 | # load training data 28 | train_list = files_with_ext(train_data_path, '.png') 29 | val_list = files_with_ext(val_data_path, '.png') 30 | 31 | Total_iter = 100 32 | nb_epoch_setting = 3 33 | input_size = 400 34 | optimizer_setting = SGD(lr=0.0001, momentum=0.9) 35 | 36 | my_model = DeepModel.DeepModel(size_set=input_size) 37 | my_model.load_weights(pre_model_file, by_name=True) 38 | 39 | my_model.compile(optimizer=optimizer_setting, loss=dice_coef_loss, loss_weights=[0.1, 0.1, 0.1, 0.1, 0.6]) 40 | 41 | loss_max = 10000 42 | 43 | for idx_iter in range(Total_iter): 44 | random.shuffle(train_list) 45 | model_return = my_model.fit_generator( 46 | generator=train_loader(train_list, train_data_path, train_mask_path, input_size), 47 | steps_per_epoch=len(train_list), 48 | validation_data=train_loader(val_list, val_data_path, val_mask_path, input_size), 49 | validation_steps=len(val_list), 50 | verbose=0 51 | ) 52 | val_loss = model_return.history['val_loss'][0] 53 | train_loss = model_return.history['loss'][0] 54 | if val_loss < loss_max: 55 | my_model.save(save_model_file) 56 | loss_max = val_loss 57 | print('[Save] training iter: {idx}, train_loss: {train_loss}, val_loss: {val_loss}'.format( 58 | idx=idx_iter + 1, train_loss=train_loss, val_loss=val_loss) 59 | ) 60 | else: 61 | print('[None] training iter: {idx}, train_loss: {train_loss}, val_loss: {val_loss}'.format( 62 | idx=idx_iter + 1, train_loss=train_loss, val_loss=val_loss) 63 | ) 64 | -------------------------------------------------------------------------------- /mnet_deep_cdr/Step_3_MNet_test.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from __future__ import print_function 4 | 5 | from os import path 6 | from sys import modules 7 | from time import time 8 | 9 | import cv2 10 | import numpy as np 11 | from PIL import Image 12 | from pkg_resources import resource_filename 13 | from skimage.measure import label, regionprops 14 | from skimage.transform import rotate, resize 15 | from tensorflow.python.keras.preprocessing import image 16 | 17 | from mnet_deep_cdr import Model_DiscSeg as DiscModel, Model_MNet as MNetModel 18 | from mnet_deep_cdr.mnet_utils import pro_process, BW_img, disc_crop, mk_dir, files_with_ext 19 | 20 | DiscROI_size = 600 21 | DiscSeg_size = 640 22 | CDRSeg_size = 400 23 | 24 | parent_dir =
path.dirname(resource_filename(modules[__name__].__name__, '__init__.py')) 25 | 26 | test_data_path = path.join(parent_dir, 'test_img') 27 | data_save_path = mk_dir(path.join(parent_dir, 'test_img')) 28 | 29 | file_test_list = files_with_ext(test_data_path, '.jpg') 30 | 31 | DiscSeg_model = DiscModel.DeepModel(size_set=DiscSeg_size) 32 | DiscSeg_model.load_weights(path.join(parent_dir, 'deep_model', 'Model_DiscSeg_ORIGA.h5')) 33 | 34 | CDRSeg_model = MNetModel.DeepModel(size_set=CDRSeg_size) 35 | CDRSeg_model.load_weights(path.join(parent_dir, 'deep_model', 'Model_MNet_REFUGE.h5')) 36 | 37 | for lineIdx, temp_txt in enumerate(file_test_list): 38 | # load image 39 | org_img = np.asarray(image.load_img(path.join(test_data_path, temp_txt))) 40 | # Disc region detection by U-Net 41 | temp_img = resize(org_img, (DiscSeg_size, DiscSeg_size, 3)) * 255 42 | temp_img = np.reshape(temp_img, (1,) + temp_img.shape) 43 | disc_map = DiscSeg_model.predict([temp_img]) 44 | disc_map = BW_img(np.reshape(disc_map, (DiscSeg_size, DiscSeg_size)), 0.5) 45 | 46 | regions = regionprops(label(disc_map)) 47 | C_x = int(regions[0].centroid[0] * org_img.shape[0] / DiscSeg_size) 48 | C_y = int(regions[0].centroid[1] * org_img.shape[1] / DiscSeg_size) 49 | disc_region, err_xy, crop_xy = disc_crop(org_img, DiscROI_size, C_x, C_y) 50 | 51 | # Disc and Cup segmentation by M-Net 52 | run_start = time() 53 | Disc_flat = rotate(cv2.linearPolar(disc_region, (DiscROI_size / 2, DiscROI_size / 2), 54 | DiscROI_size / 2, cv2.WARP_FILL_OUTLIERS), -90) 55 | 56 | temp_img = pro_process(Disc_flat, CDRSeg_size) 57 | temp_img = np.reshape(temp_img, (1,) + temp_img.shape) 58 | [_, _, _, _, prob_10] = CDRSeg_model.predict(temp_img) 59 | run_end = time() 60 | 61 | # Extract mask 62 | prob_map = np.reshape(prob_10, (prob_10.shape[1], prob_10.shape[2], prob_10.shape[3])) 63 | disc_map = np.array(Image.fromarray(prob_map[:, :, 0]).resize((DiscROI_size, DiscROI_size))) 64 | cup_map = np.array(Image.fromarray(prob_map[:, :, 1]).resize((DiscROI_size, DiscROI_size))) 65 | disc_map[-round(DiscROI_size / 3):, :] = 0 66 | cup_map[-round(DiscROI_size / 2):, :] = 0 67 | De_disc_map = cv2.linearPolar(rotate(disc_map, 90), (DiscROI_size / 2, DiscROI_size / 2), 68 | DiscROI_size / 2, cv2.WARP_FILL_OUTLIERS + cv2.WARP_INVERSE_MAP) 69 | De_cup_map = cv2.linearPolar(rotate(cup_map, 90), (DiscROI_size / 2, DiscROI_size / 2), 70 | DiscROI_size / 2, cv2.WARP_FILL_OUTLIERS + cv2.WARP_INVERSE_MAP) 71 | 72 | De_disc_map = np.array(BW_img(De_disc_map, 0.5), dtype=int) 73 | De_cup_map = np.array(BW_img(De_cup_map, 0.5), dtype=int) 74 | 75 | print('Processing Img {idx}: {temp_txt}, running time: {running_time}'.format( 76 | idx=lineIdx + 1, temp_txt=temp_txt, running_time=run_end - run_start 77 | )) 78 | 79 | # Save raw mask 80 | ROI_result = np.array(BW_img(De_disc_map, 0.5), dtype=int) + np.array(BW_img(De_cup_map, 0.5), dtype=int) 81 | Img_result = np.zeros((org_img.shape[0], org_img.shape[1]), dtype=np.int8) 82 | Img_result[crop_xy[0]:crop_xy[1], crop_xy[2]:crop_xy[3], ] = ROI_result[err_xy[0]:err_xy[1], err_xy[2]:err_xy[3], ] 83 | save_result = Image.fromarray((Img_result * 127).astype(np.uint8)) 84 | save_result.save(path.join(data_save_path, temp_txt[:-4] + '.png')) 85 | -------------------------------------------------------------------------------- /mnet_deep_cdr/Step_4_CDR_output.m: -------------------------------------------------------------------------------- 1 | % You'll need to extract 'mat_scr.zip' for this to work. 
2 | clear; close all; 3 | addpath(genpath(['mat_scr' filesep])); 4 | 5 | raw_result = 'result_refuge'; 6 | img_list = dir([raw_result filesep '*.png']); 7 | 8 | img_num = size(img_list, 1); 9 | 10 | for idx = 1 : img_num 11 | img_name = img_list(idx).name; 12 | img_map = imread([raw_result filesep img_name]); 13 | [img_h, img_w, img_c] = size(img_map); 14 | 15 | Disc_map = fun_Ell_Fit(img_map > 100, img_h, img_w, 1); 16 | Cup_map = fun_Ell_Fit(img_map > 200, img_h, img_w, 1); 17 | CDR_value = fun_CalCDR(Disc_map.fit_map, Cup_map.fit_map); 18 | 19 | Seg_map = Disc_map.fit_map + Cup_map.fit_map; 20 | Seg_map(Seg_map == 0) = 255; 21 | Seg_map(Seg_map == 1) = 128; 22 | Seg_map(Seg_map == 2) = 0; 23 | 24 | save(['final_result' filesep img_name(1 : end - 4) '.mat'], 'CDR_value'); 25 | imwrite(uint8(Seg_map), ['final_result' filesep img_name(1 : end - 4) '.bmp']); 26 | end 27 | -------------------------------------------------------------------------------- /mnet_deep_cdr/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import logging 5 | from logging.config import dictConfig as _dictConfig 6 | from os import path 7 | 8 | import yaml 9 | 10 | __author__ = 'HzFu ' 11 | __version__ = '0.0.2' 12 | 13 | 14 | def get_logger(name=None): 15 | with open(path.join(path.dirname(__file__), '_data', 'logging.yml'), 'rt') as f: 16 | data = yaml.safe_load(f) 17 | _dictConfig(data) 18 | return logging.getLogger(name=name) 19 | 20 | 21 | root_logger = get_logger() 22 | -------------------------------------------------------------------------------- /mnet_deep_cdr/_data/logging.yml: -------------------------------------------------------------------------------- 1 | version: 1 2 | formatters: 3 | simple: 4 | format: '%(asctime)s - %(name)s - %(levelname)s - %(message)s' 5 | datefmt: '%Y-%m-%d %H:%M:%S' 6 | handlers: 7 | console: 8 | class: logging.StreamHandler 9 | level: DEBUG 10 | formatter: simple 11 | stream: ext://sys.stdout 12 | loggers: 13 | simpleExample: 14 | level: DEBUG 15 | handlers: [console] 16 | propagate: no 17 | root: 18 | level: DEBUG 19 | handlers: [console] 20 | -------------------------------------------------------------------------------- /mnet_deep_cdr/deep_model/Model_DiscSeg_ORIGA.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HzFu/MNet_DeepCDR/e094023d5390ffc1606aba682e48eacf272fdba9/mnet_deep_cdr/deep_model/Model_DiscSeg_ORIGA.h5 -------------------------------------------------------------------------------- /mnet_deep_cdr/deep_model/Model_MNet_REFUGE.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HzFu/MNet_DeepCDR/e094023d5390ffc1606aba682e48eacf272fdba9/mnet_deep_cdr/deep_model/Model_MNet_REFUGE.h5 -------------------------------------------------------------------------------- /mnet_deep_cdr/mat_scr.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HzFu/MNet_DeepCDR/e094023d5390ffc1606aba682e48eacf272fdba9/mnet_deep_cdr/mat_scr.zip -------------------------------------------------------------------------------- /mnet_deep_cdr/mnet_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from __future__ import print_function 4 | 5 | import os 6 | 7 | import numpy as np 8 | from PIL import Image 9 | from
scipy.ndimage import binary_fill_holes 10 | from skimage.measure import label, regionprops 11 | from tensorflow.python.keras import backend as K 12 | from tensorflow.python.keras.preprocessing import image 13 | 14 | 15 | def pro_process(temp_img, input_size): 16 | img = np.asarray(temp_img).astype('float32') 17 | # img = np.array(Image.fromarray(img).resize((input_size, input_size)).convert(3)) 18 | img = np.array(Image.fromarray(img, mode='RGB').resize((input_size, input_size))) 19 | return img 20 | 21 | 22 | def train_loader(data_list, data_path, mask_path, input_size): 23 | while 1: 24 | for lineIdx, temp_txt in enumerate(data_list): 25 | train_img = np.asarray(image.load_img(os.path.join(data_path, temp_txt), 26 | target_size=(input_size, input_size, 3)) 27 | ).astype('float32') 28 | img_mask = np.asarray( 29 | image.load_img(os.path.join(mask_path, temp_txt), 30 | target_size=(input_size, input_size, 3)) 31 | ) / 255.0 32 | 33 | train_img = np.reshape(train_img, (1,) + train_img.shape) 34 | img_mask = np.reshape(img_mask, (1,) + img_mask.shape) 35 | yield ([train_img], [img_mask, img_mask, img_mask, img_mask, img_mask]) 36 | 37 | 38 | def BW_img(input, thresholding): 39 | if input.max() > thresholding: 40 | binary = input > thresholding 41 | else: 42 | binary = input > input.max() / 2.0 43 | 44 | label_image = label(binary) 45 | regions = regionprops(label_image) 46 | area_list = [region.area for region in regions] 47 | if area_list: 48 | idx_max = np.argmax(area_list) 49 | binary[label_image != idx_max + 1] = 0 50 | return binary_fill_holes(np.asarray(binary).astype(int)) 51 | 52 | 53 | def dice_coef(y_true, y_pred): 54 | smooth = 1. 55 | y_true_f = K.flatten(y_true) 56 | y_pred_f = K.flatten(y_pred) 57 | intersection = K.sum(y_true_f * y_pred_f) 58 | return (2. 
* intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth) 59 | 60 | 61 | def dice_coef2(y_true, y_pred): 62 | score0 = dice_coef(y_true[:, :, :, 0], y_pred[:, :, :, 0]) 63 | score1 = dice_coef(y_true[:, :, :, 1], y_pred[:, :, :, 1]) 64 | score = 0.5 * score0 + 0.5 * score1 65 | 66 | return score 67 | 68 | 69 | def dice_coef_loss(y_true, y_pred): 70 | return -dice_coef2(y_true, y_pred) 71 | 72 | 73 | def disc_crop(org_img, DiscROI_size, C_x, C_y): 74 | tmp_size = int(DiscROI_size / 2) 75 | disc_region = np.zeros((DiscROI_size, DiscROI_size, 3), dtype=org_img.dtype) 76 | crop_coord = np.array([C_x - tmp_size, C_x + tmp_size, C_y - tmp_size, C_y + tmp_size], dtype=int) 77 | err_coord = [0, DiscROI_size, 0, DiscROI_size] 78 | 79 | if crop_coord[0] < 0: 80 | err_coord[0] = abs(crop_coord[0]) 81 | crop_coord[0] = 0 82 | 83 | if crop_coord[2] < 0: 84 | err_coord[2] = abs(crop_coord[2]) 85 | crop_coord[2] = 0 86 | 87 | if crop_coord[1] > org_img.shape[0]: 88 | err_coord[1] = err_coord[1] - (crop_coord[1] - org_img.shape[0]) 89 | crop_coord[1] = org_img.shape[0] 90 | 91 | if crop_coord[3] > org_img.shape[1]: 92 | err_coord[3] = err_coord[3] - (crop_coord[3] - org_img.shape[1]) 93 | crop_coord[3] = org_img.shape[1] 94 | 95 | disc_region[err_coord[0]:err_coord[1], err_coord[2]:err_coord[3], ] = org_img[ 96 | crop_coord[0]:crop_coord[1], 97 | crop_coord[2]:crop_coord[3], 98 | ] 99 | 100 | return disc_region, err_coord, crop_coord 101 | 102 | 103 | def mk_dir(dir_path): 104 | if not os.path.exists(dir_path): 105 | os.makedirs(dir_path) 106 | return dir_path 107 | 108 | 109 | def files_with_ext(data_path, data_type): 110 | file_list = [file for file in os.listdir(data_path) if file.lower().endswith(data_type)] 111 | print(len(file_list)) 112 | return file_list 113 | -------------------------------------------------------------------------------- /mnet_deep_cdr/result/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HzFu/MNet_DeepCDR/e094023d5390ffc1606aba682e48eacf272fdba9/mnet_deep_cdr/result/.gitkeep -------------------------------------------------------------------------------- /mnet_deep_cdr/test_img/CS50041_R.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HzFu/MNet_DeepCDR/e094023d5390ffc1606aba682e48eacf272fdba9/mnet_deep_cdr/test_img/CS50041_R.jpg -------------------------------------------------------------------------------- /mnet_deep_cdr/test_img/CS50106_R.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HzFu/MNet_DeepCDR/e094023d5390ffc1606aba682e48eacf272fdba9/mnet_deep_cdr/test_img/CS50106_R.jpg -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib 2 | numpy 3 | pillow 4 | scikit-image 5 | scipy 6 | tensorflow==1.15.2 7 | pyyaml 8 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from ast import parse 4 | from distutils.sysconfig import get_python_lib 5 | from functools import partial 6 | from os import path, listdir 7 | from platform import python_version_tuple 8 | 9 | from setuptools import setup, find_packages 10 | 11 | if python_version_tuple()[0] == '3': 12 
| imap = map 13 | ifilter = filter 14 | else: 15 | from itertools import imap, ifilter 16 | 17 | if __name__ == '__main__': 18 | package_name = 'mnet_deep_cdr' 19 | 20 | with open(path.join(package_name, '__init__.py')) as f: 21 | __author__, __version__ = imap( 22 | lambda buf: next(imap(lambda e: e.value.s, parse(buf).body)), 23 | ifilter(lambda line: line.startswith('__version__') or line.startswith('__author__'), f) 24 | ) 25 | 26 | to_funcs = lambda *paths: (partial(path.join, path.dirname(__file__), package_name, *paths), 27 | partial(path.join, get_python_lib(prefix=''), package_name, *paths)) 28 | _data_join, _data_install_dir = to_funcs('_data') 29 | ipynb_checkpoints_join, ipynb_checkpoints_install_dir = to_funcs('.ipynb_checkpoints') 30 | deep_model_join, deep_model_install_dir = to_funcs('deep_model') 31 | # mat_scr_join, mat_scr_install_dir = to_funcs('mat_scr') 32 | REFUGE_result_join, REFUGE_result_install_dir = to_funcs('REFUGE_result') 33 | result_join, result_install_dir = to_funcs('result') 34 | test_img_join, test_img_install_dir = to_funcs('test_img') 35 | 36 | setup( 37 | name=package_name, 38 | author=__author__, 39 | version=__version__, 40 | install_requires=['pyyaml'], 41 | test_suite=package_name + '.tests', 42 | packages=find_packages(), 43 | package_dir={package_name: package_name}, 44 | data_files=[ 45 | (_data_install_dir(), list(imap(_data_join, listdir(_data_join())))), 46 | (ipynb_checkpoints_install_dir(), list(imap(ipynb_checkpoints_join, listdir(ipynb_checkpoints_join())))), 47 | (deep_model_install_dir(), list(imap(deep_model_join, listdir(deep_model_join())))), 48 | # (mat_scr_install_dir(), list(imap(mat_scr_join, listdir(mat_scr_join())))), 49 | (REFUGE_result_install_dir(), list(imap(REFUGE_result_join, listdir(REFUGE_result_join())))), 50 | (result_install_dir(), list(imap(result_join, listdir(result_join())))), 51 | (test_img_install_dir(), list(imap(test_img_join, listdir(test_img_join())))) 52 | ] 53 | ) 54 | --------------------------------------------------------------------------------
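To close the loop, here is a minimal usage sketch (not part of the repository) that runs the bundled disc-segmentation U-Net on one of the shipped test images. It assumes a source checkout, so that 'deep_model' and 'test_img' sit next to the package modules, and TensorFlow 1.15 as pinned in requirements.txt; resolving paths via `DiscModel.__file__` is an illustrative simplification, whereas the repository scripts use `pkg_resources.resource_filename`.

```python
# Hedged usage sketch: run the bundled disc-segmentation U-Net on a shipped test image.
# Assumes a source checkout so 'deep_model' and 'test_img' sit beside the package modules.
from os import path

import numpy as np
from skimage.transform import resize
from tensorflow.python.keras.preprocessing import image

from mnet_deep_cdr import Model_DiscSeg as DiscModel
from mnet_deep_cdr.mnet_utils import BW_img

DiscSeg_size = 640
pkg_dir = path.dirname(DiscModel.__file__)  # illustrative; the scripts use resource_filename

disc_model = DiscModel.DeepModel(size_set=DiscSeg_size)
disc_model.load_weights(path.join(pkg_dir, 'deep_model', 'Model_DiscSeg_ORIGA.h5'))

# Resize the fundus photo to the U-Net input size and predict the disc probability map.
org_img = np.asarray(image.load_img(path.join(pkg_dir, 'test_img', 'CS50041_R.jpg')))
temp_img = resize(org_img, (DiscSeg_size, DiscSeg_size, 3)) * 255
disc_map = disc_model.predict(temp_img[np.newaxis, ...])

# Threshold and keep the largest connected region, as Step_1/Step_3 do.
disc_mask = BW_img(np.reshape(disc_map, (DiscSeg_size, DiscSeg_size)), 0.5)
print('Disc pixels found:', int(disc_mask.sum()))
```

From here, 'Step\_3\_MNet\_test.py' crops the ROI around the detected disc centroid, applies the polar transform, and feeds the result to M-Net; the sketch stops at the disc mask on purpose, since the M-Net weights are the ones affected by the scipy.misc.imresize note in the README.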